diff --git a/projects/rocshmem/docs/api/coll.rst b/projects/rocshmem/docs/api/coll.rst index 809f6e53d2..43f4cb4c1a 100644 --- a/projects/rocshmem/docs/api/coll.rst +++ b/projects/rocshmem/docs/api/coll.rst @@ -22,6 +22,20 @@ This routine performs a collective barrier across all PEs in the system. The caller is blocked until the barrier is resolved and all updates local and remote are completed. These APIs should be called from only one thread/wavefront/workgroup within the grid to avoid undefined behavior. +ROCSHMEM_BARRIER_ALL_ON_STREAM +------------------------------- + +.. cpp:function:: __host__ void rocshmem_barrier_all_on_stream(hipStream_t stream) + + :param stream: HIP stream on which to enqueue the operation. + :returns: None. + +**Description:** +This routine enqueues a collective barrier operation on a HIP stream. The barrier is performed +across all PEs in the system. The operation is enqueued on the specified stream and will execute +asynchronously. The caller must synchronize the stream (e.g., using ``hipStreamSynchronize``) +to ensure completion. + ROCSHMEM_BARRIER ---------------- @@ -109,7 +123,6 @@ execute asynchronously. The caller must synchronize the stream (e.g., using This function creates a separate context for each workgroup to avoid contention on the default context, allowing parallel execution across multiple streams. -If ``stream`` is ``nullptr``, the operation will use ``hipStreamDefault``. ROCSHMEM_BROADCAST ------------------ @@ -131,6 +144,28 @@ The caller is blocked until the broadcast completes. Valid ``TYPENAME`` and ``TYPE`` values are listed in :ref:`RMA_TYPES`. +ROCSHMEM_BROADCASTMEM_ON_STREAM +-------------------------------- + +.. cpp:function:: __host__ void rocshmem_broadcastmem_on_stream(rocshmem_team_t team, void *dest, const void *source, size_t nelems, int pe_root, hipStream_t stream) + + :param team: The team participating in the collective. + :param dest: Destination address. Must be an address on the symmetric heap. + :param source: Source address. Must be an address on the symmetric heap. + :param nelems: Number of bytes to broadcast. + :param pe_root: Root PE (relative to team) from which to broadcast. + :param stream: HIP stream on which to enqueue the operation. + :returns: None. + +**Description:** +This routine enqueues a broadcast collective operation on a HIP stream. The function broadcasts +data from the root PE to all other PEs participating in the collective routine. The operation +is enqueued on the specified stream and will execute asynchronously. The caller must synchronize +the stream (e.g., using ``hipStreamSynchronize``) to ensure completion. + +This function creates a separate context for each workgroup to avoid contention on the +default context, allowing parallel execution across multiple streams. + ROCSHMEM_FCOLLECT ----------------- diff --git a/projects/rocshmem/docs/api/pt2pt_sync.rst b/projects/rocshmem/docs/api/pt2pt_sync.rst index 10fc3cf452..6985b61e02 100644 --- a/projects/rocshmem/docs/api/pt2pt_sync.rst +++ b/projects/rocshmem/docs/api/pt2pt_sync.rst @@ -96,6 +96,25 @@ ROCSHMEM_TEST **Description:** This routine tests if the condition ``(*ivars cmp val)`` is true. +ROCSHMEM_SIGNAL_WAIT_UNTIL_ON_STREAM +------------------------------------- + +.. cpp:function:: __host__ void rocshmem_signal_wait_until_on_stream(uint64_t *sig_addr, int cmp, uint64_t cmp_value, hipStream_t stream) + + :param sig_addr: Address of the signal variable on the symmetric heap. + :param cmp: Comparison operator (e.g., ROCSHMEM_CMP_EQ, ROCSHMEM_CMP_GE, etc.). + :param cmp_value: Value to compare against. + :param stream: HIP stream on which to enqueue the operation. + :returns: None. + +**Description:** +This routine enqueues a wait operation on a HIP stream. The function blocks the calling thread +until the signal variable at ``sig_addr`` satisfies the comparison condition ``(*sig_addr cmp cmp_value)``. +The wait operation is executed asynchronously on the specified stream. The caller must synchronize +the stream (e.g., using ``hipStreamSynchronize``) to ensure the wait condition has been satisfied. + +Valid ``cmp`` values are listed in :ref:`CMP_VALUES`. + .. _CMP_VALUES: Supported comparisons diff --git a/projects/rocshmem/docs/api/rma.rst b/projects/rocshmem/docs/api/rma.rst index 2ae956fa40..d4bf4b09b5 100644 --- a/projects/rocshmem/docs/api/rma.rst +++ b/projects/rocshmem/docs/api/rma.rst @@ -67,6 +67,25 @@ ROCSHMEM_PUTMEM **Description:** This routine writes contiguous data of ``nelems`` bytes from source on the calling PE to ``dest`` at ``pe``. +ROCSHMEM_PUTMEM_ON_STREAM +-------------------------- + +.. cpp:function:: __host__ void rocshmem_putmem_on_stream(void *dest, const void *source, size_t nelems, int pe, hipStream_t stream) + + :param dest: Destination address. Must be an address on the symmetric heap. + :param source: Source address. Must be an address on the symmetric heap. + :param nelems: Size of the transfer in bytes. + :param pe: PE of the remote process. + :param stream: HIP stream on which to enqueue the operation. + + :returns: None. + +**Description:** +This routine enqueues a putmem RMA operation on a HIP stream. The function writes contiguous +data of ``nelems`` bytes from source on the calling PE to ``dest`` at ``pe``. The operation +is enqueued on the specified stream and will execute asynchronously. The caller must +synchronize the stream (e.g., using ``hipStreamSynchronize``) to ensure completion. + ROCSHMEM_P ---------- @@ -137,6 +156,25 @@ ROCSHMEM_GETMEM **Description:** This routine reads contiguous data of ``nelems`` bytes from source on ``pe`` to ``dest`` on the calling PE. +ROCSHMEM_GETMEM_ON_STREAM +-------------------------- + +.. cpp:function:: __host__ void rocshmem_getmem_on_stream(void *dest, const void *source, size_t nelems, int pe, hipStream_t stream) + + :param dest: Destination address. Must be an address on the symmetric heap. + :param source: Source address. Must be an address on the symmetric heap. + :param nelems: Size of the transfer in bytes. + :param pe: PE of the remote process. + :param stream: HIP stream on which to enqueue the operation. + + :returns: None. + +**Description:** +This routine enqueues a getmem RMA operation on a HIP stream. The function reads contiguous +data of ``nelems`` bytes from source on ``pe`` to ``dest`` on the calling PE. The operation +is enqueued on the specified stream and will execute asynchronously. The caller must +synchronize the stream (e.g., using ``hipStreamSynchronize``) to ensure completion. + ROCSHMEM_G ---------- .. cpp:function:: __device__ float rocshmem_ctx_float_g(rocshmem_ctx_t ctx, const float *source, int pe) diff --git a/projects/rocshmem/docs/api/sigops.rst b/projects/rocshmem/docs/api/sigops.rst index ac53ed21e3..578eb860bb 100644 --- a/projects/rocshmem/docs/api/sigops.rst +++ b/projects/rocshmem/docs/api/sigops.rst @@ -71,6 +71,30 @@ then applies ``sig_op`` at ``sig_addr`` with the signal value. Valid ``sig_op values`` are listed in SIGNAL_OPERATORS_. Valid ``TYPENAME`` and ``TYPE`` values are listed in :ref:`RMA_TYPES`. +ROCSHMEM_PUTMEM_SIGNAL_ON_STREAM +--------------------------------- + +.. cpp:function:: __host__ void rocshmem_putmem_signal_on_stream(void *dest, const void *source, size_t nelems, uint64_t *sig_addr, uint64_t signal, int sig_op, int pe, hipStream_t stream) + + :param dest: Destination address on the remote PE. Must be an address on the symmetric heap. + :param source: Source address on the local PE. Must be an address on the symmetric heap. + :param nelems: Size of the transfer in bytes. + :param sig_addr: Address of signal variable on the remote PE. Must be an address on the symmetric heap. + :param signal: Signal value to be written. + :param sig_op: Signal operation (ROCSHMEM_SIGNAL_SET or ROCSHMEM_SIGNAL_ADD). + :param pe: PE number of the remote PE. + :param stream: HIP stream on which to enqueue the operation. + :returns: None. + +**Description:** +This routine enqueues a put-with-signal operation on a HIP stream. The function writes contiguous +data of ``nelems`` bytes from source on the calling PE to ``dest`` at ``pe``, then applies ``sig_op`` +at ``sig_addr`` with the signal value. The operation is enqueued on the specified stream and will +execute asynchronously. The caller must synchronize the stream (e.g., using ``hipStreamSynchronize``) +to ensure completion. + +Valid ``sig_op`` values are listed in SIGNAL_OPERATORS_. + ROCSHMEM_SIGNAL_FETCH --------------------- diff --git a/projects/rocshmem/include/rocshmem/rocshmem.hpp b/projects/rocshmem/include/rocshmem/rocshmem.hpp index cbe0ca2d6a..ad4148651a 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem.hpp @@ -365,6 +365,102 @@ __host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest, const void *source, size_t size, hipStream_t stream); +/** + * @brief enqueues a broadcast collective operation on given stream. + * + * @param[in] team The team participating in the collective. + * @param[in] dest Destination address. Must be an address on the symmetric + * heap. + * @param[in] source Source address. Must be an address on the symmetric heap. + * @param[in] nelems Number of bytes to broadcast. + * @param[in] pe_root Root PE (relative to team) from which to broadcast. + * @param[in] stream HIP stream on which to enqueue the operation. + * + * @return void + */ +__host__ void rocshmem_broadcastmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t nelems, + int pe_root, hipStream_t stream); + +/** + * @brief enqueues a getmem RMA operation on given stream. + * + * @param[in] dest Destination address. Must be an address on the symmetric + * heap. + * @param[in] source Source address. Must be an address on the symmetric heap. + * @param[in] nelems Size of the transfer in bytes. + * @param[in] pe PE of the remote process. + * @param[in] stream HIP stream on which to enqueue the operation. + * + * @return void + */ +__host__ void rocshmem_getmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream); + +/** + * @brief enqueues a putmem RMA operation on given stream. + * + * @param[in] dest Destination address. Must be an address on the symmetric + * heap. + * @param[in] source Source address. Must be an address on the symmetric heap. + * @param[in] nelems Size of the transfer in bytes. + * @param[in] pe PE of the remote process. + * @param[in] stream HIP stream on which to enqueue the operation. + * + * @return void + */ +__host__ void rocshmem_putmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream); + +/** + * @brief Perform a put operation with signal on a HIP stream. + * + * This routine initiates a remote memory transfer on a specified HIP stream. + * The source data is copied from the local PE to the remote PE's destination + * address. After the put operation completes, a signal operation is performed + * on a remote symmetric signal variable. + * + * @param[in] dest Destination address on the remote PE + * @param[in] source Source address on the local PE + * @param[in] nelems Size of the transfer in bytes + * @param[in] sig_addr Address of signal variable on the remote PE + * @param[in] signal Signal value to be written + * @param[in] sig_op Signal operation (ROCSHMEM_SIGNAL_SET or + * ROCSHMEM_SIGNAL_ADD) + * @param[in] pe PE number of the remote PE + * @param[in] stream HIP stream on which to enqueue the operation + * + * @return void + */ +__host__ void rocshmem_putmem_signal_on_stream(void *dest, const void *source, + size_t nelems, + uint64_t *sig_addr, + uint64_t signal, int sig_op, + int pe, hipStream_t stream); + +/** + * @brief Wait on a signal variable until it satisfies the specified condition, + * with the operation enqueued on a HIP stream. + * + * This function blocks the calling thread until the signal variable at + * \p sig_addr satisfies the comparison condition (* \p sig_addr \p cmp + * \p cmp_value). The wait operation is executed asynchronously on the + * specified HIP stream. + * + * @param[in] sig_addr Address of the signal variable on the symmetric heap + * @param[in] cmp Comparison operator (e.g., ROCSHMEM_CMP_EQ, + * ROCSHMEM_CMP_GE, ROCSHMEM_CMP_NE, etc.) + * @param[in] cmp_value Value to compare against + * @param[in] stream HIP stream on which to enqueue the operation + * + * @return void + */ +__host__ void rocshmem_signal_wait_until_on_stream(uint64_t *sig_addr, int cmp, + uint64_t cmp_value, + hipStream_t stream); + /** * @brief registers the arrival of a PE at a barrier. * The caller is blocked until the synchronization is resolved. diff --git a/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp b/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp index af105c9bfc..0d13f082bb 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem_COLL.hpp @@ -624,6 +624,23 @@ __global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team, const void *source, size_t size); +/** + * @brief kernel for performing a broadcast collective operation. + * Caller enqueues the kernel on given stream + * + * @param[in] team The team participating in the collective. + * @param[in] dest Destination address. Must be an address on the symmetric + * heap. + * @param[in] source Source address. Must be an address on the symmetric heap. + * @param[in] nelems Number of bytes to broadcast. + * @param[in] pe_root Root PE (relative to team) from which to broadcast. + * + * @return void + */ +__global__ ATTR_NO_INLINE void rocshmem_broadcastmem_kernel( + rocshmem_team_t team, void *dest, const void *source, size_t nelems, + int pe_root); + /** * @brief perform a collective barrier between all PEs in the system. * The caller is blocked until the barrier is resolved. diff --git a/projects/rocshmem/include/rocshmem/rocshmem_P2P_SYNC.hpp b/projects/rocshmem/include/rocshmem/rocshmem_P2P_SYNC.hpp index 6ec7277397..4cc22ab940 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem_P2P_SYNC.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem_P2P_SYNC.hpp @@ -576,6 +576,47 @@ __host__ size_t rocshmem_ulonglong_wait_until_some_vector( unsigned long long *ivars, size_t nelems, size_t* indices, const int* status, int cmp, unsigned long long val); +__device__ void rocshmem_uint64_wait_until( + uint64_t *ivars, int cmp, uint64_t val); +__device__ size_t rocshmem_uint64_wait_until_any( + uint64_t *ivars, size_t nelems, const int* status, + int cmp, uint64_t val); +__device__ void rocshmem_uint64_wait_until_all( + uint64_t *ivars, size_t nelems, const int* status, + int cmp, uint64_t val); +__device__ size_t rocshmem_uint64_wait_until_some( + uint64_t *ivars, size_t nelems, size_t* indices, const int* status, + int cmp, uint64_t val); +__device__ size_t rocshmem_uint64_wait_until_any_vector( + uint64_t *ivars, size_t nelems, const int* status, + int cmp, uint64_t val); +__device__ void rocshmem_uint64_wait_until_all_vector( + uint64_t *ivars, size_t nelems, const int* status, + int cmp, uint64_t val); +__device__ size_t rocshmem_uint64_wait_until_some_vector( + uint64_t *ivars, size_t nelems, size_t* indices, const int* status, + int cmp, uint64_t val); +__host__ void rocshmem_uint64_wait_until( + uint64_t *ivars, int cmp, uint64_t val); +__host__ size_t rocshmem_uint64_wait_until_any( + uint64_t *ivars, size_t nelems, const int* status, + int cmp, uint64_t val); +__host__ void rocshmem_uint64_wait_until_all( + uint64_t *ivars, size_t nelems, const int* status, + int cmp, uint64_t val); +__host__ size_t rocshmem_uint64_wait_until_some( + uint64_t *ivars, size_t nelems, size_t* indices, const int* status, + int cmp, uint64_t val); +__host__ size_t rocshmem_uint64_wait_until_any_vector( + uint64_t *ivars, size_t nelems, const int* status, + int cmp, uint64_t val); +__host__ void rocshmem_uint64_wait_until_all_vector( + uint64_t *ivars, size_t nelems, const int* status, + int cmp, uint64_t val); +__host__ size_t rocshmem_uint64_wait_until_some_vector( + uint64_t *ivars, size_t nelems, size_t* indices, const int* status, + int cmp, uint64_t val); + /** * @name SHMEM_TEST @@ -658,6 +699,11 @@ __device__ int rocshmem_ulonglong_test( __host__ int rocshmem_ulonglong_test( unsigned long long *ivars, int cmp, unsigned long long val); +__device__ int rocshmem_uint64_test( + uint64_t *ivars, int cmp, uint64_t val); +__host__ int rocshmem_uint64_test( + uint64_t *ivars, int cmp, uint64_t val); + } // namespace rocshmem diff --git a/projects/rocshmem/include/rocshmem/rocshmem_RMA.hpp b/projects/rocshmem/include/rocshmem/rocshmem_RMA.hpp index 8d627ad633..38af00ffad 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem_RMA.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem_RMA.hpp @@ -1209,6 +1209,37 @@ __host__ void rocshmem_ctx_getmem_nbi(rocshmem_ctx_t ctx, void *dest, __host__ void rocshmem_getmem_nbi(void *dest, const void *source, size_t nelems, int pe); +/** + * @brief kernel for performing a getmem RMA operation. + * Caller enqueues the kernel on given stream + * + * @param[in] dest Destination address. Must be an address on the symmetric + * heap. + * @param[in] source Source address. Must be an address on the symmetric heap. + * @param[in] nelems Size of the transfer in bytes. + * @param[in] pe PE of the remote process. + * + * @return void + */ +__global__ ATTR_NO_INLINE void rocshmem_getmem_kernel(void *dest, + const void *source, + size_t nelems, int pe); + +/** + * @brief kernel for performing a putmem RMA operation. + * Caller enqueues the kernel on given stream + * + * @param[in] dest Destination address. Must be an address on the symmetric + * heap. + * @param[in] source Source address. Must be an address on the symmetric heap. + * @param[in] nelems Size of the transfer in bytes. + * @param[in] pe PE of the remote process. + * + * @return void + */ +__global__ ATTR_NO_INLINE void rocshmem_putmem_kernel(void *dest, + const void *source, + size_t nelems, int pe); } // namespace rocshmem diff --git a/projects/rocshmem/include/rocshmem/rocshmem_SIG_OP.hpp b/projects/rocshmem/include/rocshmem/rocshmem_SIG_OP.hpp index 3182fd1e9e..cb0b9d788b 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem_SIG_OP.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem_SIG_OP.hpp @@ -619,6 +619,35 @@ __device__ ATTR_NO_INLINE uint64_t rocshmem_signal_fetch(const uint64_t *sig_add __device__ ATTR_NO_INLINE uint64_t rocshmem_signal_fetch_wg(const uint64_t *sig_addr); __device__ ATTR_NO_INLINE uint64_t rocshmem_signal_fetch_wave(const uint64_t *sig_addr); +/** + * @brief Kernel wrapper for putmem_signal operation on stream + * + * @param[in] dest Destination address on remote PE + * @param[in] source Source address on local PE + * @param[in] nelems Size of the transfer in bytes + * @param[in] sig_addr Address of signal variable on remote PE + * @param[in] signal Signal value to write + * @param[in] sig_op Signal operation (ROCSHMEM_SIGNAL_SET or + * ROCSHMEM_SIGNAL_ADD) + * @param[in] pe PE of the remote process + * + * @return void + */ +__global__ ATTR_NO_INLINE void rocshmem_putmem_signal_kernel( + void *dest, const void *source, size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe); + +/** + * @brief Kernel wrapper for signal_wait_until operation on stream + * + * @param[in] sig_addr Address of signal variable on the symmetric heap + * @param[in] cmp Comparison operator + * @param[in] cmp_value Value to compare against + * + * @return void + */ +__global__ ATTR_NO_INLINE void rocshmem_signal_wait_until_kernel( + uint64_t *sig_addr, int cmp, uint64_t cmp_value); } // namespace rocshmem diff --git a/projects/rocshmem/scripts/functional_tests/driver.sh b/projects/rocshmem/scripts/functional_tests/driver.sh index f5010ee624..0da1aa8b08 100755 --- a/projects/rocshmem/scripts/functional_tests/driver.sh +++ b/projects/rocshmem/scripts/functional_tests/driver.sh @@ -110,6 +110,12 @@ declare -A TEST_NUMBERS=( ["teamctxblockinfra"]="74" ["teamctxoddeveninfra"]="75" ["alltoallmem_on_stream"]="76" + ["barrier_all_on_stream"]="77" + ["broadcastmem_on_stream"]="78" + ["getmem_on_stream"]="79" + ["putmem_on_stream"]="80" + ["putmem_signal_on_stream"]="81" + ["signal_wait_until_on_stream"]="82" ) ExecTest() { @@ -150,6 +156,11 @@ ExecTest() { OPTIONS+=" -x UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384" OPTIONS+=" --map-by numa --timeout $TIMEOUT" + if [[ "" != "$ROCSHMEM_TEST_USE_DEFAULT_STREAM" ]] + then + OPTIONS+=" -x ROCSHMEM_TEST_USE_DEFAULT_STREAM=$ROCSHMEM_TEST_USE_DEFAULT_STREAM" + fi + if [[ "" != "$HOSTFILE" ]] then OPTIONS+=" --hostfile $HOSTFILE" @@ -222,6 +233,12 @@ TestRMAPut() { ExecTest "shmemptr" 2 8 1 8 ExecTest "shmemptr" 2 16 128 8 + ExecTest "putmem_on_stream" 2 1 1 1048576 + + export ROCSHMEM_TEST_USE_DEFAULT_STREAM=1 + ExecTest "putmem_on_stream" 2 1 1 1048576 + unset ROCSHMEM_TEST_USE_DEFAULT_STREAM + ################################ Non-Blocking ################################ ExecTest "putnbi" 2 1 1 1048576 @@ -274,6 +291,8 @@ TestRMAGet() { ExecTest "g" 2 8 1 32 ExecTest "g" 2 16 128 4 + ExecTest "getmem_on_stream" 2 1 1 1048576 + ################################ Non-Blocking ################################ ExecTest "getnbi" 2 1 1 1048576 @@ -373,6 +392,9 @@ TestSigOps() { ExecTest "wgsignalfetch" 2 2 32 ExecTest "wavesignalfetch" 2 1 32 ExecTest "wavesignalfetch" 2 1 64 + + ExecTest "putmem_signal_on_stream" 2 1 1 1048576 + ExecTest "signal_wait_until_on_stream" 2 1 1 } TestColl() { @@ -430,7 +452,9 @@ TestColl() { ExecTest "teamreduction" 2 1 1 32768 - ExecTest "alltoallmem_on_stream" 2 1 1 32768 + ExecTest "alltoallmem_on_stream" 2 1 1 1048576 + ExecTest "broadcastmem_on_stream" 2 1 1 1048576 + ExecTest "barrier_all_on_stream" 2 1 1 } TestOther() { diff --git a/projects/rocshmem/src/context.hpp b/projects/rocshmem/src/context.hpp index 5eb82ef007..0862383396 100644 --- a/projects/rocshmem/src/context.hpp +++ b/projects/rocshmem/src/context.hpp @@ -400,6 +400,25 @@ class Context { const void *source, size_t size, hipStream_t stream); + __host__ void broadcastmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t nelems, + int pe_root, hipStream_t stream); + + __host__ void getmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_signal_on_stream(void *dest, const void *source, + size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe, + hipStream_t stream); + + __host__ void signal_wait_until_on_stream(uint64_t *sig_addr, int cmp, + uint64_t cmp_value, + hipStream_t stream); + __host__ void sync_all(); template diff --git a/projects/rocshmem/src/context_host.cpp b/projects/rocshmem/src/context_host.cpp index 653f6ddaf0..135ccc7cb1 100644 --- a/projects/rocshmem/src/context_host.cpp +++ b/projects/rocshmem/src/context_host.cpp @@ -129,4 +129,46 @@ __host__ void Context::alltoallmem_on_stream(rocshmem_team_t team, void *dest, HOST_DISPATCH(alltoallmem_on_stream(team, dest, source, size, stream)); } +__host__ void Context::broadcastmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t nelems, + int pe_root, hipStream_t stream) { + ctxHostStats.incStat(NUM_HOST_BROADCAST); + + HOST_DISPATCH( + broadcastmem_on_stream(team, dest, source, nelems, pe_root, stream)); +} + +__host__ void Context::getmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + ctxHostStats.incStat(NUM_HOST_GET); + + HOST_DISPATCH(getmem_on_stream(dest, source, nelems, pe, stream)); +} + +__host__ void Context::putmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + ctxHostStats.incStat(NUM_HOST_PUT); + + HOST_DISPATCH(putmem_on_stream(dest, source, nelems, pe, stream)); +} + +__host__ void Context::putmem_signal_on_stream(void *dest, const void *source, + size_t nelems, + uint64_t *sig_addr, + uint64_t signal, int sig_op, + int pe, hipStream_t stream) { + ctxHostStats.incStat(NUM_HOST_PUT); + + HOST_DISPATCH(putmem_signal_on_stream(dest, source, nelems, sig_addr, signal, + sig_op, pe, stream)); +} + +__host__ void Context::signal_wait_until_on_stream(uint64_t *sig_addr, int cmp, + uint64_t cmp_value, + hipStream_t stream) { + HOST_DISPATCH(signal_wait_until_on_stream(sig_addr, cmp, cmp_value, stream)); +} + } // namespace rocshmem diff --git a/projects/rocshmem/src/gda/context_gda_host.cpp b/projects/rocshmem/src/gda/context_gda_host.cpp index 61c1e59cec..c480f15fda 100644 --- a/projects/rocshmem/src/gda/context_gda_host.cpp +++ b/projects/rocshmem/src/gda/context_gda_host.cpp @@ -113,6 +113,10 @@ __host__ void GDAHostContext::barrier_all() { host_interface->barrier_all(context_window_info); } +__host__ void GDAHostContext::barrier_all_on_stream(hipStream_t stream) { + host_interface->barrier_all_on_stream(stream); +} + __host__ void GDAHostContext::alltoallmem_on_stream(rocshmem_team_t team, void *dest, const void *source, @@ -121,4 +125,39 @@ __host__ void GDAHostContext::alltoallmem_on_stream(rocshmem_team_t team, host_interface->alltoallmem_on_stream(team, dest, source, size, stream); } +__host__ void GDAHostContext::broadcastmem_on_stream(rocshmem_team_t team, + void *dest, + const void *source, + size_t nelems, int pe_root, + hipStream_t stream) { + host_interface->broadcastmem_on_stream(team, dest, source, nelems, pe_root, + stream); +} + +__host__ void GDAHostContext::getmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + host_interface->getmem_on_stream(dest, source, nelems, pe, stream); +} + +__host__ void GDAHostContext::putmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + host_interface->putmem_on_stream(dest, source, nelems, pe, stream); +} + +__host__ void GDAHostContext::putmem_signal_on_stream( + void *dest, const void *source, size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe, hipStream_t stream) { + host_interface->putmem_signal_on_stream(dest, source, nelems, sig_addr, + signal, sig_op, pe, stream); +} + +__host__ void GDAHostContext::signal_wait_until_on_stream(uint64_t *sig_addr, + int cmp, + uint64_t cmp_value, + hipStream_t stream) { + host_interface->signal_wait_until_on_stream(sig_addr, cmp, cmp_value, stream); +} + } // namespace rocshmem diff --git a/projects/rocshmem/src/gda/context_gda_host.hpp b/projects/rocshmem/src/gda/context_gda_host.hpp index 3479b999f2..500473f333 100644 --- a/projects/rocshmem/src/gda/context_gda_host.hpp +++ b/projects/rocshmem/src/gda/context_gda_host.hpp @@ -82,10 +82,31 @@ class GDAHostContext : public Context { __host__ void barrier_all(); + __host__ void barrier_all_on_stream(hipStream_t stream); + __host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest, const void *source, size_t size, hipStream_t stream); + __host__ void broadcastmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t nelems, + int pe_root, hipStream_t stream); + + __host__ void getmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_signal_on_stream(void *dest, const void *source, + size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe, + hipStream_t stream); + + __host__ void signal_wait_until_on_stream(uint64_t *sig_addr, int cmp, + uint64_t cmp_value, + hipStream_t stream); + __host__ void sync_all(); template diff --git a/projects/rocshmem/src/host/host.cpp b/projects/rocshmem/src/host/host.cpp index 7b6a782dcd..9c086b3d2c 100644 --- a/projects/rocshmem/src/host/host.cpp +++ b/projects/rocshmem/src/host/host.cpp @@ -25,6 +25,7 @@ #include "host.hpp" #include "rocshmem/rocshmem_config.h" // NOLINT(build/include_subdir) +#include "rocshmem/rocshmem_SIG_OP.hpp" #include "envvar.hpp" #include "host_helpers.hpp" #include "memory/window_info.hpp" @@ -325,12 +326,8 @@ __host__ void HostInterface::barrier_all(WindowInfo* window_info) { } __host__ void HostInterface::barrier_all_on_stream(hipStream_t stream) { - // launch kernel to do barrier with given stream, if non, use default stream - if (stream == nullptr) { - stream = hipStreamDefault; - } - - rocshmem_barrier_all_kernel<<<1, 1, 0, stream>>>(); + // Launch kernel to do barrier with given stream + rocshmem_barrier_all_kernel<<<1, 1, 0, stream>>>(); } __host__ void HostInterface::alltoallmem_on_stream(rocshmem_team_t team, @@ -338,11 +335,6 @@ __host__ void HostInterface::alltoallmem_on_stream(rocshmem_team_t team, const void *source, size_t size, hipStream_t stream) { - // launch kernel to do alltoall with given stream, if none, use default stream - if (stream == nullptr) { - stream = hipStreamDefault; - } - // Use dynamic block size determination: // - Query optimal block size using occupancy API // - Limit block size to size (number of bytes) to avoid over-subscription @@ -357,13 +349,117 @@ __host__ void HostInterface::alltoallmem_on_stream(rocshmem_team_t team, int num_threads_per_block = (optimal_block_size > static_cast(size)) ? static_cast(size) : optimal_block_size; - + + // Launch kernel to do alltoall with given stream dim3 gridSize(1); dim3 blockSize(num_threads_per_block); rocshmem_alltoallmem_kernel<<>>(team, dest, source, size); } +__host__ void HostInterface::broadcastmem_on_stream(rocshmem_team_t team, + void *dest, + const void *source, + size_t nelems, int pe_root, + hipStream_t stream) { + // Use dynamic block size determination: + // - Query optimal block size using occupancy API + // - Limit block size to nelems (number of bytes) to avoid over-subscription + // - Always use 1 block (single workgroup collective) + int optimal_block_size = 0; + int grid_size = 0; + CHECK_HIP(hipOccupancyMaxPotentialBlockSize(&grid_size, + &optimal_block_size, + rocshmem_broadcastmem_kernel, + 0, + 0)); + + // Limit block size to nelems (bytes) to avoid over-subscription + int num_threads_per_block = (optimal_block_size > static_cast(nelems)) + ? static_cast(nelems) + : optimal_block_size; + + // Launch kernel to do broadcast with given stream + dim3 gridSize(1); + dim3 blockSize(num_threads_per_block); + rocshmem_broadcastmem_kernel<<>>(team, + dest, + source, + nelems, + pe_root); +} + +__host__ void HostInterface::getmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + int optimal_block_size = 0; + int grid_size = 0; + CHECK_HIP(hipOccupancyMaxPotentialBlockSize(&grid_size, &optimal_block_size, + rocshmem_getmem_kernel, 0, 0)); + + // Limit block size to nelems to avoid over-subscription + int num_threads_per_block = (optimal_block_size > static_cast(nelems)) + ? static_cast(nelems) + : optimal_block_size; + + // Launch kernel to do getmem with given stream + dim3 gridSize(1); + dim3 blockSize(num_threads_per_block); + rocshmem_getmem_kernel<<>>(dest, source, + nelems, pe); +} + +__host__ void HostInterface::putmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + int optimal_block_size = 0; + int grid_size = 0; + CHECK_HIP(hipOccupancyMaxPotentialBlockSize(&grid_size, &optimal_block_size, + rocshmem_putmem_kernel, 0, 0)); + + // Limit block size to nelems to avoid over-subscription + int num_threads_per_block = (optimal_block_size > static_cast(nelems)) + ? static_cast(nelems) + : optimal_block_size; + + // Launch kernel to do putmem with given stream + dim3 gridSize(1); + dim3 blockSize(num_threads_per_block); + rocshmem_putmem_kernel<<>>(dest, source, + nelems, pe); +} + +__host__ void HostInterface::putmem_signal_on_stream( + void *dest, const void *source, size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe, hipStream_t stream) { + int optimal_block_size = 0; + int grid_size = 0; + CHECK_HIP(hipOccupancyMaxPotentialBlockSize( + &grid_size, &optimal_block_size, rocshmem_putmem_signal_kernel, 0, 0)); + + // Limit block size to nelems to avoid over-subscription + int num_threads_per_block = (optimal_block_size > static_cast(nelems)) + ? static_cast(nelems) + : optimal_block_size; + + // Launch kernel to do putmem_signal with given stream + dim3 gridSize(1); + dim3 blockSize(num_threads_per_block); + rocshmem_putmem_signal_kernel<<>>( + dest, source, nelems, sig_addr, signal, sig_op, pe); +} + +__host__ void HostInterface::signal_wait_until_on_stream(uint64_t *sig_addr, + int cmp, + uint64_t cmp_value, + hipStream_t stream) { + // Use a single thread to wait on the signal + dim3 gridSize(1); + dim3 blockSize(1); + rocshmem_signal_wait_until_kernel<<>>( + sig_addr, cmp, cmp_value); +} + __host__ void HostInterface::barrier_for_sync() { if (host_comm_world_ != MPI_COMM_NULL) { mpilib_ftable_.Barrier(host_comm_world_); diff --git a/projects/rocshmem/src/host/host.hpp b/projects/rocshmem/src/host/host.hpp index 286458344d..5abe4371a4 100644 --- a/projects/rocshmem/src/host/host.hpp +++ b/projects/rocshmem/src/host/host.hpp @@ -200,6 +200,25 @@ class HostInterface { const void *source, size_t size, hipStream_t stream); + __host__ void broadcastmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t nelems, + int pe_root, hipStream_t stream); + + __host__ void getmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_signal_on_stream(void *dest, const void *source, + size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe, + hipStream_t stream); + + __host__ void signal_wait_until_on_stream(uint64_t *sig_addr, int cmp, + uint64_t cmp_value, + hipStream_t stream); + __host__ void barrier_for_sync(); __host__ void sync_all(WindowInfo* window_info); diff --git a/projects/rocshmem/src/ipc/context_ipc_host.cpp b/projects/rocshmem/src/ipc/context_ipc_host.cpp index 5714aa4dd8..e65c455383 100644 --- a/projects/rocshmem/src/ipc/context_ipc_host.cpp +++ b/projects/rocshmem/src/ipc/context_ipc_host.cpp @@ -113,4 +113,39 @@ __host__ void IPCHostContext::alltoallmem_on_stream(rocshmem_team_t team, host_interface->alltoallmem_on_stream(team, dest, source, size, stream); } +__host__ void IPCHostContext::broadcastmem_on_stream(rocshmem_team_t team, + void *dest, + const void *source, + size_t nelems, int pe_root, + hipStream_t stream) { + host_interface->broadcastmem_on_stream(team, dest, source, nelems, pe_root, + stream); +} + +__host__ void IPCHostContext::getmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + host_interface->getmem_on_stream(dest, source, nelems, pe, stream); +} + +__host__ void IPCHostContext::putmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + host_interface->putmem_on_stream(dest, source, nelems, pe, stream); +} + +__host__ void IPCHostContext::putmem_signal_on_stream( + void *dest, const void *source, size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe, hipStream_t stream) { + host_interface->putmem_signal_on_stream(dest, source, nelems, sig_addr, + signal, sig_op, pe, stream); +} + +__host__ void IPCHostContext::signal_wait_until_on_stream(uint64_t *sig_addr, + int cmp, + uint64_t cmp_value, + hipStream_t stream) { + host_interface->signal_wait_until_on_stream(sig_addr, cmp, cmp_value, stream); +} + } // namespace rocshmem diff --git a/projects/rocshmem/src/ipc/context_ipc_host.hpp b/projects/rocshmem/src/ipc/context_ipc_host.hpp index f5317a87c3..60448ec2b5 100644 --- a/projects/rocshmem/src/ipc/context_ipc_host.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_host.hpp @@ -88,6 +88,25 @@ class IPCHostContext : public Context { const void *source, size_t size, hipStream_t stream); + __host__ void broadcastmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t nelems, + int pe_root, hipStream_t stream); + + __host__ void getmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_signal_on_stream(void *dest, const void *source, + size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe, + hipStream_t stream); + + __host__ void signal_wait_until_on_stream(uint64_t *sig_addr, int cmp, + uint64_t cmp_value, + hipStream_t stream); + __host__ void sync_all(); template diff --git a/projects/rocshmem/src/reverse_offload/context_ro_host.cpp b/projects/rocshmem/src/reverse_offload/context_ro_host.cpp index 87ef5cc768..c6b1f493e1 100644 --- a/projects/rocshmem/src/reverse_offload/context_ro_host.cpp +++ b/projects/rocshmem/src/reverse_offload/context_ro_host.cpp @@ -133,6 +133,12 @@ __host__ void ROHostContext::barrier_all() { host_interface->barrier_for_sync(); } +__host__ void ROHostContext::barrier_all_on_stream(hipStream_t stream) { + DPRINTF("Function: ro_net_host_barrier_all_on_stream\n"); + + host_interface->barrier_all_on_stream(stream); +} + __host__ void ROHostContext::alltoallmem_on_stream(rocshmem_team_t team, void *dest, const void *source, @@ -143,4 +149,49 @@ __host__ void ROHostContext::alltoallmem_on_stream(rocshmem_team_t team, host_interface->alltoallmem_on_stream(team, dest, source, size, stream); } +__host__ void ROHostContext::broadcastmem_on_stream(rocshmem_team_t team, + void *dest, + const void *source, + size_t nelems, int pe_root, + hipStream_t stream) { + DPRINTF("Function: ro_net_host_broadcastmem_on_stream\n"); + + host_interface->broadcastmem_on_stream(team, dest, source, nelems, pe_root, + stream); +} + +__host__ void ROHostContext::getmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + DPRINTF("Function: ro_net_host_getmem_on_stream\n"); + + host_interface->getmem_on_stream(dest, source, nelems, pe, stream); +} + +__host__ void ROHostContext::putmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + DPRINTF("Function: ro_net_host_putmem_on_stream\n"); + + host_interface->putmem_on_stream(dest, source, nelems, pe, stream); +} + +__host__ void ROHostContext::putmem_signal_on_stream( + void *dest, const void *source, size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe, hipStream_t stream) { + DPRINTF("Function: ro_net_host_putmem_signal_on_stream\n"); + + host_interface->putmem_signal_on_stream(dest, source, nelems, sig_addr, + signal, sig_op, pe, stream); +} + +__host__ void ROHostContext::signal_wait_until_on_stream(uint64_t *sig_addr, + int cmp, + uint64_t cmp_value, + hipStream_t stream) { + DPRINTF("Function: ro_net_host_signal_wait_until_on_stream\n"); + + host_interface->signal_wait_until_on_stream(sig_addr, cmp, cmp_value, stream); +} + } // namespace rocshmem diff --git a/projects/rocshmem/src/reverse_offload/context_ro_host.hpp b/projects/rocshmem/src/reverse_offload/context_ro_host.hpp index 049ce0be6d..50ba48d7e5 100644 --- a/projects/rocshmem/src/reverse_offload/context_ro_host.hpp +++ b/projects/rocshmem/src/reverse_offload/context_ro_host.hpp @@ -131,10 +131,31 @@ class ROHostContext : public Context { __host__ void barrier_all(); + __host__ void barrier_all_on_stream(hipStream_t stream); + __host__ void alltoallmem_on_stream(rocshmem_team_t team, void *dest, const void *source, size_t size, hipStream_t stream); + __host__ void broadcastmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t nelems, + int pe_root, hipStream_t stream); + + __host__ void getmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_on_stream(void *dest, const void *source, size_t nelems, + int pe, hipStream_t stream); + + __host__ void putmem_signal_on_stream(void *dest, const void *source, + size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe, + hipStream_t stream); + + __host__ void signal_wait_until_on_stream(uint64_t *sig_addr, int cmp, + uint64_t cmp_value, + hipStream_t stream); + __host__ void sync_all(); template diff --git a/projects/rocshmem/src/rocshmem.cpp b/projects/rocshmem/src/rocshmem.cpp index 7bf191a0ad..9353f70136 100644 --- a/projects/rocshmem/src/rocshmem.cpp +++ b/projects/rocshmem/src/rocshmem.cpp @@ -1007,6 +1007,54 @@ __host__ void rocshmem_alltoallmem_on_stream(rocshmem_team_t team, void *dest, ->alltoallmem_on_stream(team, dest, source, size, stream); } +__host__ void rocshmem_broadcastmem_on_stream(rocshmem_team_t team, void *dest, + const void *source, size_t nelems, + int pe_root, hipStream_t stream) { + DPRINTF("Host function: rocshmem_broadcastmem_on_stream\n"); + + get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT) + ->broadcastmem_on_stream(team, dest, source, nelems, pe_root, stream); +} + +__host__ void rocshmem_getmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + DPRINTF("Host function: rocshmem_getmem_on_stream\n"); + + get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT) + ->getmem_on_stream(dest, source, nelems, pe, stream); +} + +__host__ void rocshmem_putmem_on_stream(void *dest, const void *source, + size_t nelems, int pe, + hipStream_t stream) { + DPRINTF("Host function: rocshmem_putmem_on_stream\n"); + + get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT) + ->putmem_on_stream(dest, source, nelems, pe, stream); +} + +__host__ void rocshmem_putmem_signal_on_stream(void *dest, const void *source, + size_t nelems, + uint64_t *sig_addr, + uint64_t signal, int sig_op, + int pe, hipStream_t stream) { + DPRINTF("Host function: rocshmem_putmem_signal_on_stream\n"); + + get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT) + ->putmem_signal_on_stream(dest, source, nelems, sig_addr, signal, sig_op, + pe, stream); +} + +__host__ void rocshmem_signal_wait_until_on_stream(uint64_t *sig_addr, int cmp, + uint64_t cmp_value, + hipStream_t stream) { + DPRINTF("Host function: rocshmem_signal_wait_until_on_stream\n"); + + get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT) + ->signal_wait_until_on_stream(sig_addr, cmp, cmp_value, stream); +} + __host__ void rocshmem_sync_all() { DPRINTF("Host function: rocshmem_sync_all\n"); @@ -1681,6 +1729,7 @@ WAIT_DEF_GEN(unsigned short, ushort) WAIT_DEF_GEN(unsigned int, uint) WAIT_DEF_GEN(unsigned long, ulong) WAIT_DEF_GEN(unsigned long long, ulonglong) +WAIT_DEF_GEN(uint64_t, uint64) // clang-format on } // namespace rocshmem diff --git a/projects/rocshmem/src/rocshmem_gpu.cpp b/projects/rocshmem/src/rocshmem_gpu.cpp index 1fe0a58614..721c164b31 100644 --- a/projects/rocshmem/src/rocshmem_gpu.cpp +++ b/projects/rocshmem/src/rocshmem_gpu.cpp @@ -674,6 +674,57 @@ __global__ ATTR_NO_INLINE void rocshmem_alltoallmem_kernel(rocshmem_team_t team, } } +__global__ ATTR_NO_INLINE void rocshmem_broadcastmem_kernel( + rocshmem_team_t team, void *dest, const void *source, size_t nelems, + int pe_root) { + __shared__ rocshmem_ctx_t ctx; + __shared__ int ctx_result; + + ctx_result = rocshmem_wg_team_create_ctx(team, 0, &ctx); + + // If context creation failed, fall back to default context + if (ctx_result != 0) { + ctx = ROCSHMEM_CTX_DEFAULT; + __syncthreads(); + } + + // Call device broadcast function with created context and provided team + // Using char type since nelems is in bytes (1 byte per element) + rocshmem_broadcast_wg(ctx, team, (char *) dest, (const char *) source, + (int) nelems, pe_root); + + if (ctx_result == 0) { + rocshmem_wg_ctx_destroy(&ctx); + } +} + +__global__ ATTR_NO_INLINE void rocshmem_getmem_kernel(void *dest, + const void *source, + size_t nelems, int pe) { + // Use work-group collective getmem with default context + rocshmem_getmem_wg(dest, source, nelems, pe); +} + +__global__ ATTR_NO_INLINE void rocshmem_putmem_kernel(void *dest, + const void *source, + size_t nelems, int pe) { + // Use work-group collective putmem with default context + rocshmem_putmem_wg(dest, source, nelems, pe); +} + +__global__ ATTR_NO_INLINE void rocshmem_putmem_signal_kernel( + void *dest, const void *source, size_t nelems, uint64_t *sig_addr, + uint64_t signal, int sig_op, int pe) { + // Use work-group collective putmem_signal with default context + rocshmem_putmem_signal_wg(dest, source, nelems, sig_addr, signal, sig_op, pe); +} + +__global__ ATTR_NO_INLINE void rocshmem_signal_wait_until_kernel( + uint64_t *sig_addr, int cmp, uint64_t cmp_value) { + // Use default context to wait on signal + rocshmem_uint64_wait_until(sig_addr, cmp, cmp_value); +} + __device__ void rocshmem_barrier_all() { GPU_DPRINTF("Function: rocshmem_barrier_all (ctx=%zd)\n", get_internal_ctx(ROCSHMEM_CTX_DEFAULT)); @@ -1867,6 +1918,7 @@ WAIT_DEF_GEN(unsigned short, ushort) WAIT_DEF_GEN(unsigned int, uint) WAIT_DEF_GEN(unsigned long, ulong) WAIT_DEF_GEN(unsigned long long, ulonglong) +WAIT_DEF_GEN(uint64_t, uint64) // clang-format on } // namespace rocshmem diff --git a/projects/rocshmem/tests/functional_tests/barrier_all_on_stream_tester.cpp b/projects/rocshmem/tests/functional_tests/barrier_all_on_stream_tester.cpp new file mode 100644 index 0000000000..a53b40bd6c --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/barrier_all_on_stream_tester.cpp @@ -0,0 +1,151 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#include "barrier_all_on_stream_tester.hpp" + +#include +#include +#include +#include +#include + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +BarrierAllOnStreamTester::BarrierAllOnStreamTester(TesterArguments args) + : Tester(args) { + my_pe = rocshmem_my_pe(); + n_pes = rocshmem_n_pes(); + + char *value{nullptr}; + if ((value = getenv("ROCSHMEM_TEST_NUM_STREAMS"))) { + num_streams = atoi(value); + } else { + // Default to 1 stream + num_streams = 1; + } + + // Check if we should test with nullptr (default stream) + use_default_stream = false; + if ((value = getenv("ROCSHMEM_TEST_USE_DEFAULT_STREAM"))) { + use_default_stream = (atoi(value) != 0); + if (use_default_stream) { + num_streams = 1; // Only test with one nullptr stream + } + } + + streams.resize(num_streams); + start_events_timed.resize(num_streams); + stop_events_timed.resize(num_streams); + for (int i = 0; i < num_streams; i++) { + if (use_default_stream) { + streams[i] = nullptr; // Use default stream (0) + } else { + CHECK_HIP(hipStreamCreate(&streams[i])); + } + CHECK_HIP(hipEventCreate(&start_events_timed[i])); + CHECK_HIP(hipEventCreate(&stop_events_timed[i])); + } +} + +BarrierAllOnStreamTester::~BarrierAllOnStreamTester() { + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipEventDestroy(stop_events_timed[i])); + CHECK_HIP(hipEventDestroy(start_events_timed[i])); + // Don't destroy default stream (nullptr) + if (!use_default_stream) { + CHECK_HIP(hipStreamDestroy(streams[i])); + } + } +} + +void BarrierAllOnStreamTester::preLaunchKernel() { + // No specific setup needed for barrier +} + +void BarrierAllOnStreamTester::postLaunchKernel() { + // Synchronize all streams to ensure events are recorded + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + // Get elapsed time for each stream from HIP events + for (int stream_id = 0; stream_id < num_streams && stream_id < num_timers; + stream_id++) { + float elapsed_time_ms = 0.0f; + CHECK_HIP(hipEventElapsedTime(&elapsed_time_ms, + start_events_timed[stream_id], + stop_events_timed[stream_id])); + + // Convert milliseconds to GPU cycles + // wall_clk_rate is in kHz, so: cycles = ms * wall_clk_rate + long long int elapsed_cycles = + static_cast(elapsed_time_ms * + static_cast(wall_clk_rate)); + + start_time[stream_id] = 0; + end_time[stream_id] = elapsed_cycles; + } + + // Fill remaining timers with zero if num_timers > num_streams + for (int i = num_streams; i < num_timers; i++) { + start_time[i] = 0; + end_time[i] = 0; + } +} + +void BarrierAllOnStreamTester::resetBuffers(size_t size) {} + +void BarrierAllOnStreamTester::launchKernel(dim3 gridSize, dim3 blockSize, + int loop, size_t size) { + // Execute warmup iterations (skip) + for (int i = 0; i < args.skip; i++) { + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + rocshmem_barrier_all_on_stream(streams[stream_id]); + } + } + + for (int i = 0; i < loop; i++) { + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + // Record start event for this stream on first iteration + if (i == 0) { + CHECK_HIP(hipEventRecord(start_events_timed[stream_id], + streams[stream_id])); + } + + rocshmem_barrier_all_on_stream(streams[stream_id]); + + // Record stop event for this stream on last iteration + if (i == loop - 1) { + CHECK_HIP(hipEventRecord(stop_events_timed[stream_id], + streams[stream_id])); + } + } + } + + num_msgs = (loop + args.skip) * num_streams; + num_timed_msgs = loop * num_streams; +} + +void BarrierAllOnStreamTester::verifyResults(size_t size) {} diff --git a/projects/rocshmem/tests/functional_tests/barrier_all_on_stream_tester.hpp b/projects/rocshmem/tests/functional_tests/barrier_all_on_stream_tester.hpp new file mode 100644 index 0000000000..36e28f5fa8 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/barrier_all_on_stream_tester.hpp @@ -0,0 +1,67 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef _BARRIER_ALL_ON_STREAM_TESTER_HPP_ +#define _BARRIER_ALL_ON_STREAM_TESTER_HPP_ + +#include "tester.hpp" +#include +#include + +using namespace rocshmem; + +/****************************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class BarrierAllOnStreamTester : public Tester { + public: + explicit BarrierAllOnStreamTester(TesterArguments args); + virtual ~BarrierAllOnStreamTester(); + + protected: + virtual void resetBuffers(size_t size) override; + + virtual void preLaunchKernel() override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + size_t size) override; + + virtual void postLaunchKernel() override; + + virtual void verifyResults(size_t size) override; + + private: + int my_pe; + int n_pes; + int num_streams = 1; + bool use_default_stream = false; + std::vector streams; + std::vector start_events_timed; + std::vector stop_events_timed; +}; + +#include "barrier_all_on_stream_tester.cpp" + +#endif + diff --git a/projects/rocshmem/tests/functional_tests/getmem_on_stream_tester.cpp b/projects/rocshmem/tests/functional_tests/getmem_on_stream_tester.cpp new file mode 100644 index 0000000000..c4dfd5a43f --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/getmem_on_stream_tester.cpp @@ -0,0 +1,205 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#include "getmem_on_stream_tester.hpp" + +#include +#include +#include +#include +#include + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +GetmemOnStreamTester::GetmemOnStreamTester(TesterArguments args) + : Tester(args) { + my_pe = rocshmem_my_pe(); + n_pes = rocshmem_n_pes(); + + char *value{nullptr}; + if ((value = getenv("ROCSHMEM_TEST_NUM_STREAMS"))) { + num_streams = atoi(value); + } else { + // Default to 1 stream + num_streams = 1; + } + + // Set target PE to get from (default: next PE in ring) + pe_target = (my_pe + 1) % n_pes; + if ((value = getenv("ROCSHMEM_TEST_GETMEM_TARGET"))) { + pe_target = atoi(value); + if (pe_target < 0 || pe_target >= n_pes) { + std::cerr << "Invalid ROCSHMEM_TEST_GETMEM_TARGET value. Using next PE." + << std::endl; + pe_target = (my_pe + 1) % n_pes; + } + } + + int num_bytes_stream = args.max_msg_size; + int total_bytes = num_bytes_stream * num_streams; + buf_size = total_bytes; + + source_buf = static_cast(rocshmem_malloc(buf_size)); + dest_buf = static_cast(rocshmem_malloc(buf_size)); + + if (source_buf == nullptr || dest_buf == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "source: " << source_buf << ", dest: " << dest_buf + << std::endl; + rocshmem_global_exit(1); + } + + streams.resize(num_streams); + start_events_timed.resize(num_streams); + stop_events_timed.resize(num_streams); + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamCreate(&streams[i])); + CHECK_HIP(hipEventCreate(&start_events_timed[i])); + CHECK_HIP(hipEventCreate(&stop_events_timed[i])); + } +} + +GetmemOnStreamTester::~GetmemOnStreamTester() { + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipEventDestroy(stop_events_timed[i])); + CHECK_HIP(hipEventDestroy(start_events_timed[i])); + CHECK_HIP(hipStreamDestroy(streams[i])); + } + rocshmem_free(source_buf); + rocshmem_free(dest_buf); +} + +void GetmemOnStreamTester::preLaunchKernel() { + bw_factor = 1; // Point-to-point operation +} + +void GetmemOnStreamTester::postLaunchKernel() { + // Synchronize all streams to ensure events are recorded + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + // Get elapsed time for each stream from HIP events + for (int stream_id = 0; stream_id < num_streams && stream_id < num_timers; + stream_id++) { + float elapsed_time_ms = 0.0f; + CHECK_HIP(hipEventElapsedTime(&elapsed_time_ms, + start_events_timed[stream_id], + stop_events_timed[stream_id])); + + // Convert milliseconds to GPU cycles + // wall_clk_rate is in kHz, so: cycles = ms * wall_clk_rate + long long int elapsed_cycles = + static_cast(elapsed_time_ms * + static_cast(wall_clk_rate)); + + start_time[stream_id] = 0; + end_time[stream_id] = elapsed_cycles; + } + + // Fill remaining timers with zero if num_timers > num_streams + for (int i = num_streams; i < num_timers; i++) { + start_time[i] = 0; + end_time[i] = 0; + } +} + +void GetmemOnStreamTester::resetBuffers(size_t size) { + // Initialize source buffer on all PEs + // Each stream has its own portion + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + int idx = stream_id * size; + // Each PE fills its source buffer with a unique value + int value = (my_pe + 1) * 100 + stream_id; + std::memset(source_buf + idx, value, size); + } + + // Clear destination buffer + std::memset(dest_buf, 0, buf_size); +} + +void GetmemOnStreamTester::launchKernel(dim3 gridSize, dim3 blockSize, + int loop, size_t size) { + // Execute warmup iterations (skip) + for (int i = 0; i < args.skip; i++) { + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + char *stream_dest = dest_buf + stream_id * size; + char *stream_source = source_buf + stream_id * size; + rocshmem_getmem_on_stream(stream_dest, stream_source, size, pe_target, + streams[stream_id]); + } + } + + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + for (int i = 0; i < loop; i++) { + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + // Record start event for this stream on first iteration + if (i == 0) { + CHECK_HIP(hipEventRecord(start_events_timed[stream_id], + streams[stream_id])); + } + + char *stream_dest = dest_buf + stream_id * size; + char *stream_source = source_buf + stream_id * size; + rocshmem_getmem_on_stream(stream_dest, stream_source, size, pe_target, + streams[stream_id]); + + // Record stop event for this stream on last iteration + if (i == loop - 1) { + CHECK_HIP(hipEventRecord(stop_events_timed[stream_id], + streams[stream_id])); + } + } + } + + num_msgs = (loop + args.skip) * num_streams; + num_timed_msgs = loop * num_streams; +} + +void GetmemOnStreamTester::verifyResults(size_t size) { + // Verify correctness: after getmem, local dest buffer should have + // the data from target PE's source buffer + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + int idx = stream_id * size; + // Expected value is from pe_target + int expected_value = (pe_target + 1) * 100 + stream_id; + + for (size_t k = 0; k < size; k++) { + if (static_cast(dest_buf[idx + k]) != + static_cast(expected_value)) { + std::cerr << "PE " << my_pe << ": Verification failed for stream " + << stream_id << " at byte " << k << std::endl; + std::cerr << "Expected value: " << expected_value + << ", Got: " << static_cast(dest_buf[idx + k]) + << std::endl; + rocshmem_global_exit(1); + } + } + } +} + diff --git a/projects/rocshmem/tests/functional_tests/getmem_on_stream_tester.hpp b/projects/rocshmem/tests/functional_tests/getmem_on_stream_tester.hpp new file mode 100644 index 0000000000..1a79e65987 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/getmem_on_stream_tester.hpp @@ -0,0 +1,70 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef _GETMEM_ON_STREAM_TESTER_HPP_ +#define _GETMEM_ON_STREAM_TESTER_HPP_ + +#include "tester.hpp" +#include +#include + +using namespace rocshmem; + +/****************************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class GetmemOnStreamTester : public Tester { + public: + explicit GetmemOnStreamTester(TesterArguments args); + virtual ~GetmemOnStreamTester(); + + protected: + virtual void resetBuffers(size_t size) override; + + virtual void preLaunchKernel() override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + size_t size) override; + + virtual void postLaunchKernel() override; + + virtual void verifyResults(size_t size) override; + + private: + char *source_buf; + char *dest_buf; + int my_pe; + int n_pes; + size_t buf_size; + int num_streams = 1; + int pe_target; // Target PE to get from + std::vector streams; + std::vector start_events_timed; + std::vector stop_events_timed; +}; + +#include "getmem_on_stream_tester.cpp" + +#endif + diff --git a/projects/rocshmem/tests/functional_tests/putmem_on_stream_tester.cpp b/projects/rocshmem/tests/functional_tests/putmem_on_stream_tester.cpp new file mode 100644 index 0000000000..8b876511f0 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/putmem_on_stream_tester.cpp @@ -0,0 +1,225 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#include "putmem_on_stream_tester.hpp" + +#include +#include +#include +#include +#include + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +PutmemOnStreamTester::PutmemOnStreamTester(TesterArguments args) + : Tester(args) { + my_pe = rocshmem_my_pe(); + n_pes = rocshmem_n_pes(); + + char *value{nullptr}; + if ((value = getenv("ROCSHMEM_TEST_NUM_STREAMS"))) { + num_streams = atoi(value); + } else { + // Default to 1 stream + num_streams = 1; + } + + // Check if we should test with nullptr (default stream) + use_default_stream = false; + if ((value = getenv("ROCSHMEM_TEST_USE_DEFAULT_STREAM"))) { + use_default_stream = (atoi(value) != 0); + if (use_default_stream) { + num_streams = 1; // Only test with one nullptr stream + } + } + + // Set target PE to put to (default: next PE in ring) + pe_target = (my_pe + 1) % n_pes; + if ((value = getenv("ROCSHMEM_TEST_PUTMEM_TARGET"))) { + pe_target = atoi(value); + if (pe_target < 0 || pe_target >= n_pes) { + std::cerr << "Invalid ROCSHMEM_TEST_PUTMEM_TARGET value. Using next PE." + << std::endl; + pe_target = (my_pe + 1) % n_pes; + } + } + + int num_bytes_stream = args.max_msg_size; + int total_bytes = num_bytes_stream * num_streams; + buf_size = total_bytes; + + source_buf = static_cast(rocshmem_malloc(buf_size)); + dest_buf = static_cast(rocshmem_malloc(buf_size)); + + if (source_buf == nullptr || dest_buf == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "source: " << source_buf << ", dest: " << dest_buf + << std::endl; + rocshmem_global_exit(1); + } + + streams.resize(num_streams); + start_events_timed.resize(num_streams); + stop_events_timed.resize(num_streams); + for (int i = 0; i < num_streams; i++) { + if (use_default_stream) { + streams[i] = nullptr; // Use default stream (0) + } else { + CHECK_HIP(hipStreamCreate(&streams[i])); + } + CHECK_HIP(hipEventCreate(&start_events_timed[i])); + CHECK_HIP(hipEventCreate(&stop_events_timed[i])); + } +} + +PutmemOnStreamTester::~PutmemOnStreamTester() { + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipEventDestroy(stop_events_timed[i])); + CHECK_HIP(hipEventDestroy(start_events_timed[i])); + // Don't destroy default stream (nullptr) + if (!use_default_stream) { + CHECK_HIP(hipStreamDestroy(streams[i])); + } + } + rocshmem_free(source_buf); + rocshmem_free(dest_buf); +} + +void PutmemOnStreamTester::preLaunchKernel() { + bw_factor = 1; // Point-to-point operation +} + +void PutmemOnStreamTester::postLaunchKernel() { + // Synchronize all streams to ensure events are recorded + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + // Get elapsed time for each stream from HIP events + for (int stream_id = 0; stream_id < num_streams && stream_id < num_timers; + stream_id++) { + float elapsed_time_ms = 0.0f; + CHECK_HIP(hipEventElapsedTime(&elapsed_time_ms, + start_events_timed[stream_id], + stop_events_timed[stream_id])); + + // Convert milliseconds to GPU cycles + // wall_clk_rate is in kHz, so: cycles = ms * wall_clk_rate + long long int elapsed_cycles = + static_cast(elapsed_time_ms * + static_cast(wall_clk_rate)); + + start_time[stream_id] = 0; + end_time[stream_id] = elapsed_cycles; + } + + // Fill remaining timers with zero if num_timers > num_streams + for (int i = num_streams; i < num_timers; i++) { + start_time[i] = 0; + end_time[i] = 0; + } +} + +void PutmemOnStreamTester::resetBuffers(size_t size) { + // Initialize source buffer on all PEs + // Each stream has its own portion + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + int idx = stream_id * size; + // Each PE fills its source buffer with a unique value + int value = (my_pe + 1) * 100 + stream_id; + std::memset(source_buf + idx, value, size); + } + + // Clear destination buffer (will receive data from other PEs) + std::memset(dest_buf, 0, buf_size); +} + +void PutmemOnStreamTester::launchKernel(dim3 gridSize, dim3 blockSize, + int loop, size_t size) { + // Execute warmup iterations (skip) + for (int i = 0; i < args.skip; i++) { + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + char *stream_source = source_buf + stream_id * size; + char *stream_dest = dest_buf + stream_id * size; + rocshmem_putmem_on_stream(stream_dest, stream_source, size, pe_target, + streams[stream_id]); + } + } + + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + for (int i = 0; i < loop; i++) { + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + // Record start event for this stream on first iteration + if (i == 0) { + CHECK_HIP(hipEventRecord(start_events_timed[stream_id], + streams[stream_id])); + } + + char *stream_source = source_buf + stream_id * size; + char *stream_dest = dest_buf + stream_id * size; + rocshmem_putmem_on_stream(stream_dest, stream_source, size, pe_target, + streams[stream_id]); + + // Record stop event for this stream on last iteration + if (i == loop - 1) { + CHECK_HIP(hipEventRecord(stop_events_timed[stream_id], + streams[stream_id])); + } + } + } + + num_msgs = (loop + args.skip) * num_streams; + num_timed_msgs = loop * num_streams; +} + +void PutmemOnStreamTester::verifyResults(size_t size) { + // Verify correctness: after putmem, my dest buffer should have + // the data that was put from the PE that targets me + // We need to find which PE writes to me: pe_source where (pe_source + 1) % n_pes == my_pe + int pe_source = (my_pe - 1 + n_pes) % n_pes; + + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + int idx = stream_id * size; + // Expected value is from pe_source + int expected_value = (pe_source + 1) * 100 + stream_id; + + for (size_t k = 0; k < size; k++) { + if (static_cast(dest_buf[idx + k]) != + static_cast(expected_value)) { + std::cerr << "PE " << my_pe << ": Verification failed for stream " + << stream_id << " at byte " << k << std::endl; + std::cerr << "Expected value from PE " << pe_source << ": " + << expected_value + << ", Got: " << static_cast(dest_buf[idx + k]) + << std::endl; + rocshmem_global_exit(1); + } + } + } +} + diff --git a/projects/rocshmem/tests/functional_tests/putmem_on_stream_tester.hpp b/projects/rocshmem/tests/functional_tests/putmem_on_stream_tester.hpp new file mode 100644 index 0000000000..65b8bf7745 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/putmem_on_stream_tester.hpp @@ -0,0 +1,71 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef _PUTMEM_ON_STREAM_TESTER_HPP_ +#define _PUTMEM_ON_STREAM_TESTER_HPP_ + +#include "tester.hpp" +#include +#include + +using namespace rocshmem; + +/****************************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class PutmemOnStreamTester : public Tester { + public: + explicit PutmemOnStreamTester(TesterArguments args); + virtual ~PutmemOnStreamTester(); + + protected: + virtual void resetBuffers(size_t size) override; + + virtual void preLaunchKernel() override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + size_t size) override; + + virtual void postLaunchKernel() override; + + virtual void verifyResults(size_t size) override; + + private: + char *source_buf; + char *dest_buf; + int my_pe; + int n_pes; + size_t buf_size; + int num_streams = 1; + bool use_default_stream = false; + int pe_target; // Target PE to put to + std::vector streams; + std::vector start_events_timed; + std::vector stop_events_timed; +}; + +#include "putmem_on_stream_tester.cpp" + +#endif + diff --git a/projects/rocshmem/tests/functional_tests/putmem_signal_on_stream_tester.cpp b/projects/rocshmem/tests/functional_tests/putmem_signal_on_stream_tester.cpp new file mode 100644 index 0000000000..e87d360f3a --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/putmem_signal_on_stream_tester.cpp @@ -0,0 +1,236 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#include "putmem_signal_on_stream_tester.hpp" + +#include +#include +#include +#include +#include + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +PutmemSignalOnStreamTester::PutmemSignalOnStreamTester(TesterArguments args) + : Tester(args) { + my_pe = rocshmem_my_pe(); + n_pes = rocshmem_n_pes(); + + char *value{nullptr}; + if ((value = getenv("ROCSHMEM_TEST_NUM_STREAMS"))) { + num_streams = atoi(value); + } else { + // Default to 1 stream + num_streams = 1; + } + + // Set target PE to put to (default: next PE in ring) + pe_target = (my_pe + 1) % n_pes; + if ((value = getenv("ROCSHMEM_TEST_PUTMEM_TARGET"))) { + pe_target = atoi(value); + if (pe_target < 0 || pe_target >= n_pes) { + std::cerr << "Invalid ROCSHMEM_TEST_PUTMEM_TARGET value. Using next PE." + << std::endl; + pe_target = (my_pe + 1) % n_pes; + } + } + + int num_bytes_stream = args.max_msg_size; + int total_bytes = num_bytes_stream * num_streams; + buf_size = total_bytes; + + source_buf = static_cast(rocshmem_malloc(buf_size)); + dest_buf = static_cast(rocshmem_malloc(buf_size)); + sig_addr = static_cast(rocshmem_malloc(num_streams * sizeof(uint64_t))); + + if (source_buf == nullptr || dest_buf == nullptr || sig_addr == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "source: " << source_buf << ", dest: " << dest_buf + << ", sig_addr: " << sig_addr << std::endl; + rocshmem_global_exit(1); + } + + streams.resize(num_streams); + start_events_timed.resize(num_streams); + stop_events_timed.resize(num_streams); + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamCreate(&streams[i])); + CHECK_HIP(hipEventCreate(&start_events_timed[i])); + CHECK_HIP(hipEventCreate(&stop_events_timed[i])); + } +} + +PutmemSignalOnStreamTester::~PutmemSignalOnStreamTester() { + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipEventDestroy(stop_events_timed[i])); + CHECK_HIP(hipEventDestroy(start_events_timed[i])); + CHECK_HIP(hipStreamDestroy(streams[i])); + } + rocshmem_free(source_buf); + rocshmem_free(dest_buf); + rocshmem_free(sig_addr); +} + +void PutmemSignalOnStreamTester::preLaunchKernel() { + bw_factor = 1; // Point-to-point operation +} + +void PutmemSignalOnStreamTester::postLaunchKernel() { + // Synchronize all streams to ensure events are recorded + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + // Get elapsed time for each stream from HIP events + for (int stream_id = 0; stream_id < num_streams && stream_id < num_timers; + stream_id++) { + float elapsed_time_ms = 0.0f; + CHECK_HIP(hipEventElapsedTime(&elapsed_time_ms, + start_events_timed[stream_id], + stop_events_timed[stream_id])); + + // Convert milliseconds to GPU cycles + // wall_clk_rate is in kHz, so: cycles = ms * wall_clk_rate + long long int elapsed_cycles = + static_cast(elapsed_time_ms * + static_cast(wall_clk_rate)); + + start_time[stream_id] = 0; + end_time[stream_id] = elapsed_cycles; + } + + // Fill remaining timers with zero if num_timers > num_streams + for (int i = num_streams; i < num_timers; i++) { + start_time[i] = 0; + end_time[i] = 0; + } +} + +void PutmemSignalOnStreamTester::resetBuffers(size_t size) { + // Initialize source buffer on all PEs + // Each stream has its own portion + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + int idx = stream_id * size; + // Each PE fills its source buffer with a unique value + int value = (my_pe + 1) * 100 + stream_id; + std::memset(source_buf + idx, value, size); + } + + // Clear destination buffer (will receive data from other PEs) + std::memset(dest_buf, 0, buf_size); + + // Clear signal addresses + std::memset(sig_addr, 0, num_streams * sizeof(uint64_t)); +} + +void PutmemSignalOnStreamTester::launchKernel(dim3 gridSize, dim3 blockSize, + int loop, size_t size) { + uint64_t signal_value = 1; + + // Execute warmup iterations (skip) + for (int i = 0; i < args.skip; i++) { + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + char *stream_source = source_buf + stream_id * size; + char *stream_dest = dest_buf + stream_id * size; + rocshmem_putmem_signal_on_stream(stream_dest, stream_source, size, + &sig_addr[stream_id], signal_value, + sig_op, pe_target, streams[stream_id]); + } + } + + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + // Reset signal addresses after warmup and synchronize across PEs + std::memset(sig_addr, 0, num_streams * sizeof(uint64_t)); + rocshmem_barrier_all(); + + for (int i = 0; i < loop; i++) { + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + // Record start event for this stream on first iteration + if (i == 0) { + CHECK_HIP(hipEventRecord(start_events_timed[stream_id], + streams[stream_id])); + } + + char *stream_source = source_buf + stream_id * size; + char *stream_dest = dest_buf + stream_id * size; + rocshmem_putmem_signal_on_stream(stream_dest, stream_source, size, + &sig_addr[stream_id], signal_value, + sig_op, pe_target, streams[stream_id]); + + // Record stop event for this stream on last iteration + if (i == loop - 1) { + CHECK_HIP(hipEventRecord(stop_events_timed[stream_id], + streams[stream_id])); + } + } + } + + num_msgs = (loop + args.skip) * num_streams; + num_timed_msgs = loop * num_streams; +} + +void PutmemSignalOnStreamTester::verifyResults(size_t size) { + // Synchronize to ensure all operations completed + rocshmem_barrier_all(); + + // Verify correctness: after putmem_signal, my dest buffer should have + // the data that was put from the PE that targets me + // We need to find which PE writes to me: pe_source where (pe_source + 1) % n_pes == my_pe + int pe_source = (my_pe - 1 + n_pes) % n_pes; + + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + int idx = stream_id * size; + // Expected value is from pe_source + int expected_value = (pe_source + 1) * 100 + stream_id; + + // Verify data + for (size_t k = 0; k < size; k++) { + if (static_cast(dest_buf[idx + k]) != + static_cast(expected_value)) { + std::cerr << "PE " << my_pe << ": Data verification failed for stream " + << stream_id << " at byte " << k << std::endl; + std::cerr << "Expected value from PE " << pe_source << ": " + << expected_value + << ", Got: " << static_cast(dest_buf[idx + k]) + << std::endl; + rocshmem_global_exit(1); + } + } + + // Verify signal + uint64_t expected_signal = 1; + if (sig_addr[stream_id] != expected_signal) { + std::cerr << "PE " << my_pe << ": Signal verification failed for stream " + << stream_id << std::endl; + std::cerr << "Expected signal: " << expected_signal + << ", Got: " << sig_addr[stream_id] << std::endl; + rocshmem_global_exit(1); + } + } +} + diff --git a/projects/rocshmem/tests/functional_tests/putmem_signal_on_stream_tester.hpp b/projects/rocshmem/tests/functional_tests/putmem_signal_on_stream_tester.hpp new file mode 100644 index 0000000000..232c9e6cb6 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/putmem_signal_on_stream_tester.hpp @@ -0,0 +1,72 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef _PUTMEM_SIGNAL_ON_STREAM_TESTER_HPP_ +#define _PUTMEM_SIGNAL_ON_STREAM_TESTER_HPP_ + +#include "tester.hpp" +#include +#include + +using namespace rocshmem; + +/****************************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class PutmemSignalOnStreamTester : public Tester { + public: + explicit PutmemSignalOnStreamTester(TesterArguments args); + virtual ~PutmemSignalOnStreamTester(); + + protected: + virtual void resetBuffers(size_t size) override; + + virtual void preLaunchKernel() override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + size_t size) override; + + virtual void postLaunchKernel() override; + + virtual void verifyResults(size_t size) override; + + private: + char *source_buf; + char *dest_buf; + uint64_t *sig_addr; + int my_pe; + int n_pes; + size_t buf_size; + int num_streams = 1; + int pe_target; // Target PE to put to + int sig_op = ROCSHMEM_SIGNAL_SET; // Signal operation + std::vector streams; + std::vector start_events_timed; + std::vector stop_events_timed; +}; + +#include "putmem_signal_on_stream_tester.cpp" + +#endif + diff --git a/projects/rocshmem/tests/functional_tests/signal_wait_until_on_stream_tester.cpp b/projects/rocshmem/tests/functional_tests/signal_wait_until_on_stream_tester.cpp new file mode 100644 index 0000000000..1a9b9f27cf --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/signal_wait_until_on_stream_tester.cpp @@ -0,0 +1,204 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#include +#include + +#include +#include + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +SignalWaitUntilOnStreamTester::SignalWaitUntilOnStreamTester( + TesterArguments args) + : Tester(args) { + my_pe = rocshmem_my_pe(); + n_pes = rocshmem_n_pes(); + + char *value{nullptr}; + if ((value = getenv("ROCSHMEM_TEST_NUM_STREAMS"))) { + num_streams = atoi(value); + } else { + // Default to 1 stream + num_streams = 1; + } + + // Set target PE (next PE in ring) + pe_target = (my_pe + 1) % n_pes; + + // Allocate signal addresses on symmetric heap + sig_addr = + static_cast(rocshmem_malloc(num_streams * sizeof(uint64_t))); + source_buf = + static_cast(rocshmem_malloc(num_streams * sizeof(uint64_t))); + + if (sig_addr == nullptr || source_buf == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "sig_addr: " << sig_addr << ", source_buf: " << source_buf + << std::endl; + rocshmem_global_exit(1); + } + + streams.resize(num_streams); + start_events_timed.resize(num_streams); + stop_events_timed.resize(num_streams); + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamCreate(&streams[i])); + CHECK_HIP(hipEventCreate(&start_events_timed[i])); + CHECK_HIP(hipEventCreate(&stop_events_timed[i])); + } +} + +SignalWaitUntilOnStreamTester::~SignalWaitUntilOnStreamTester() { + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipEventDestroy(stop_events_timed[i])); + CHECK_HIP(hipEventDestroy(start_events_timed[i])); + CHECK_HIP(hipStreamDestroy(streams[i])); + } + rocshmem_free(sig_addr); + rocshmem_free(source_buf); +} + +void SignalWaitUntilOnStreamTester::preLaunchKernel() { + bw_factor = 1; // Point-to-point operation +} + +void SignalWaitUntilOnStreamTester::postLaunchKernel() { + // Synchronize all streams to ensure events are recorded + for (int i = 0; i < num_streams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + // Get elapsed time for each stream from HIP events + for (int stream_id = 0; stream_id < num_streams && stream_id < num_timers; + stream_id++) { + float elapsed_time_ms = 0.0f; + CHECK_HIP(hipEventElapsedTime(&elapsed_time_ms, + start_events_timed[stream_id], + stop_events_timed[stream_id])); + + // Convert milliseconds to GPU cycles + long long int elapsed_cycles = + static_cast(elapsed_time_ms * + static_cast(wall_clk_rate)); + + start_time[stream_id] = 0; + end_time[stream_id] = elapsed_cycles; + } + + // Fill remaining timers with zero if num_timers > num_streams + for (int i = num_streams; i < num_timers; i++) { + start_time[i] = 0; + end_time[i] = 0; + } +} + +void SignalWaitUntilOnStreamTester::resetBuffers(size_t size) { + // Clear signal addresses + std::memset(sig_addr, 0, num_streams * sizeof(uint64_t)); +} + +void SignalWaitUntilOnStreamTester::launchKernel(dim3 gridSize, dim3 blockSize, + int loop, size_t size) { + // Execute warmup + timed iterations + for (int i = 0; i < args.skip + loop; i++) { + // Increment signal value for each iteration + uint64_t signal_value = i + 1; + + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + // Record start event after warmup on first timed iteration for all streams + if (i == args.skip) { + CHECK_HIP(hipEventRecord(start_events_timed[stream_id], + streams[stream_id])); + } + + // PE 0 starts the ring by signaling PE 1 + if (my_pe == 0) { + rocshmem_putmem_signal_on_stream(&sig_addr[stream_id], + &source_buf[stream_id], + sizeof(uint64_t), &sig_addr[stream_id], + signal_value, sig_op, pe_target, + streams[stream_id]); + } else { + // All other PEs wait for signal from previous PE + rocshmem_signal_wait_until_on_stream(&sig_addr[stream_id], + ROCSHMEM_CMP_GE, signal_value, + streams[stream_id]); + + // Forward the signal to next PE (unless we're the last PE) + if (my_pe != n_pes - 1) { + rocshmem_putmem_signal_on_stream(&sig_addr[stream_id], + &source_buf[stream_id], + sizeof(uint64_t), &sig_addr[stream_id], + signal_value, sig_op, pe_target, + streams[stream_id]); + } + } + + // Record stop event on last timed iteration for all streams + if (i == args.skip + loop - 1) { + CHECK_HIP(hipEventRecord(stop_events_timed[stream_id], + streams[stream_id])); + } + } + + // Wait for all streams to complete + for (int j = 0; j < num_streams; j++) { + CHECK_HIP(hipStreamSynchronize(streams[j])); + } + + // Barrier to ensure all RMA operations completed across all PEs + rocshmem_barrier_all(); + } + + num_msgs = (loop + args.skip) * num_streams; + num_timed_msgs = loop * num_streams; +} + +void SignalWaitUntilOnStreamTester::verifyResults(size_t size) { + // Synchronize to ensure all operations completed + rocshmem_barrier_all(); + + // Verify signal values + // All PEs except PE 0 should have received the final signal value + uint64_t expected_signal = args.skip + args.loop; + + for (int stream_id = 0; stream_id < num_streams; stream_id++) { + // PE 0 doesn't receive signals (it initiates), so skip verification + if (my_pe == 0) { + continue; + } + + // Verify signal + if (sig_addr[stream_id] != expected_signal) { + std::cerr << "PE " << my_pe << ": Signal verification failed for stream " + << stream_id << std::endl; + std::cerr << "Expected signal: " << expected_signal + << ", Got: " << sig_addr[stream_id] << std::endl; + rocshmem_global_exit(1); + } + } +} + diff --git a/projects/rocshmem/tests/functional_tests/signal_wait_until_on_stream_tester.hpp b/projects/rocshmem/tests/functional_tests/signal_wait_until_on_stream_tester.hpp new file mode 100644 index 0000000000..7c305ca675 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/signal_wait_until_on_stream_tester.hpp @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef _SIGNAL_WAIT_UNTIL_ON_STREAM_TESTER_HPP_ +#define _SIGNAL_WAIT_UNTIL_ON_STREAM_TESTER_HPP_ + +#include "tester.hpp" +#include +#include + +using namespace rocshmem; + +/****************************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class SignalWaitUntilOnStreamTester : public Tester { + public: + explicit SignalWaitUntilOnStreamTester(TesterArguments args); + virtual ~SignalWaitUntilOnStreamTester(); + +protected: + virtual void resetBuffers(size_t size) override; + + virtual void preLaunchKernel() override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + size_t size) override; + + virtual void postLaunchKernel() override; + + virtual void verifyResults(size_t size) override; + + private: + uint64_t *sig_addr; + uint64_t *source_buf; // Source buffer in symmetric heap + int my_pe; + int n_pes; + int num_streams = 1; + int pe_target; // Target PE to signal next + int sig_op = ROCSHMEM_SIGNAL_SET; // Signal operation + std::vector streams; + std::vector start_events_timed; + std::vector stop_events_timed; +}; + +#include "signal_wait_until_on_stream_tester.cpp" + +#endif diff --git a/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.cpp b/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.cpp index 9d8900e603..746279f6ae 100644 --- a/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/team_alltoallmem_on_stream_tester.cpp @@ -39,7 +39,7 @@ TeamAlltoallmemOnStreamTester::TeamAlltoallmemOnStreamTester(TesterArguments arg n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); char* value{nullptr}; - if ((value = getenv("ROCSHMEM_MAX_NUM_TEAMS"))) { + if ((value = getenv("ROCSHMEM_TEST_MAX_NUM_TEAMS"))) { num_teams = atoi(value); } else { // Default to number of work groups @@ -162,6 +162,10 @@ void TeamAlltoallmemOnStreamTester::launchKernel(dim3 gridSize, } } + for (int i = 0; i < num_teams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + for (int i = 0; i < loop; i++) { for (int wg_id = 0; wg_id < num_teams; wg_id++) { // Record start event for this work group on first iteration diff --git a/projects/rocshmem/tests/functional_tests/team_broadcastmem_on_stream_tester.cpp b/projects/rocshmem/tests/functional_tests/team_broadcastmem_on_stream_tester.cpp new file mode 100644 index 0000000000..007eb76125 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/team_broadcastmem_on_stream_tester.cpp @@ -0,0 +1,240 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#include "team_broadcastmem_on_stream_tester.hpp" + +#include +#include +#include +#include +#include + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +TeamBroadcastmemOnStreamTester::TeamBroadcastmemOnStreamTester(TesterArguments args) + : Tester(args) { + my_pe = rocshmem_team_my_pe(ROCSHMEM_TEAM_WORLD); + n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); + + char* value{nullptr}; + if ((value = getenv("ROCSHMEM_TEST_MAX_NUM_TEAMS"))) { + num_teams = atoi(value); + } else { + // Default to number of work groups + num_teams = args.num_wgs; + } + + // Set root PE to 0 by default, can be modified via environment variable + if ((value = getenv("ROCSHMEM_TEST_BROADCAST_ROOT"))) { + pe_root = atoi(value); + if (pe_root < 0 || pe_root >= n_pes) { + std::cerr << "Invalid ROCSHMEM_TEST_BROADCAST_ROOT value. Using PE 0." + << std::endl; + pe_root = 0; + } + } + + int num_bytes_wg = args.max_msg_size; + int total_bytes = num_bytes_wg * num_teams; + buf_size = total_bytes; + + source_buf = static_cast(rocshmem_malloc(buf_size)); + dest_buf = static_cast(rocshmem_malloc(buf_size)); + + if (source_buf == nullptr || dest_buf == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "source: " << source_buf << ", dest: " << dest_buf + << std::endl; + rocshmem_global_exit(1); + } + + team_world_dup.resize(num_teams); + + streams.resize(num_teams); + start_events_timed.resize(num_teams); + stop_events_timed.resize(num_teams); + for (int i = 0; i < num_teams; i++) { + CHECK_HIP(hipStreamCreate(&streams[i])); + CHECK_HIP(hipEventCreate(&start_events_timed[i])); + CHECK_HIP(hipEventCreate(&stop_events_timed[i])); + } +} + +TeamBroadcastmemOnStreamTester::~TeamBroadcastmemOnStreamTester() { + for (int i = 0; i < num_teams; i++) { + CHECK_HIP(hipEventDestroy(stop_events_timed[i])); + CHECK_HIP(hipEventDestroy(start_events_timed[i])); + CHECK_HIP(hipStreamDestroy(streams[i])); + } + rocshmem_free(source_buf); + rocshmem_free(dest_buf); +} + +void TeamBroadcastmemOnStreamTester::preLaunchKernel() { + bw_factor = 1; // Broadcast is one-to-all + + for (int team_i = 0; team_i < num_teams; team_i++) { + team_world_dup[team_i] = ROCSHMEM_TEAM_INVALID; + rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, + &team_world_dup[team_i]); + if (team_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) { + std::cerr << "Team " << team_i << " is invalid!" << std::endl; + abort(); + } + } +} + +void TeamBroadcastmemOnStreamTester::postLaunchKernel() { + // Synchronize all streams to ensure events are recorded + for (int i = 0; i < num_teams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + // Get elapsed time for each work group from HIP events + for (int wg_id = 0; wg_id < num_teams && wg_id < num_timers; wg_id++) { + float elapsed_time_ms = 0.0f; + CHECK_HIP(hipEventElapsedTime(&elapsed_time_ms, start_events_timed[wg_id], + stop_events_timed[wg_id])); + + // Convert milliseconds to GPU cycles + // wall_clk_rate is in kHz, so: cycles = ms * wall_clk_rate + long long int elapsed_cycles = static_cast( + elapsed_time_ms * static_cast(wall_clk_rate)); + + start_time[wg_id] = 0; + end_time[wg_id] = elapsed_cycles; + } + + // Fill remaining timers with zero if num_timers > num_teams + for (int i = num_teams; i < num_timers; i++) { + start_time[i] = 0; + end_time[i] = 0; + } + + for (int team_i = 0; team_i < num_teams; team_i++) { + rocshmem_team_destroy(team_world_dup[team_i]); + } +} + +void TeamBroadcastmemOnStreamTester::resetBuffers(size_t size) { + // Initialize source buffer on all PEs + // Each work group has its own portion + for (int wg_id = 0; wg_id < num_teams; wg_id++) { + int idx = wg_id * size; + if (my_pe == pe_root) { + // Root PE fills its source buffer with broadcast value + int value = (pe_root + 1) * 100 + wg_id; + std::memset(source_buf + idx, value, size); + } else { + // Non-root PEs source buffer (not used in broadcast) + std::memset(source_buf + idx, 0xFF, size); + } + } + + // Initialize destination buffer on all PEs + // Root PE keeps its initial dest value (broadcast doesn't copy to root's + // dest) Non-root PEs set to 0 (will receive broadcast data) + for (int wg_id = 0; wg_id < num_teams; wg_id++) { + int idx = wg_id * size; + if (my_pe == pe_root) { + // Root PE's dest buffer stays with a different value + int root_dest_value = 0xAA; + std::memset(dest_buf + idx, root_dest_value, size); + } else { + std::memset(dest_buf + idx, 0, size); + } + } +} + +void TeamBroadcastmemOnStreamTester::launchKernel(dim3 gridSize, + dim3 blockSize, + int loop, + size_t size) { + // Execute warmup iterations (skip) + for (int i = 0; i < args.skip; i++) { + for (int wg_id = 0; wg_id < num_teams; wg_id++) { + char *wg_source = source_buf + wg_id * size; + char *wg_dest = dest_buf + wg_id * size; + rocshmem_broadcastmem_on_stream(team_world_dup[wg_id], wg_dest, + wg_source, size, pe_root, streams[wg_id]); + } + } + + for (int i = 0; i < num_teams; i++) { + CHECK_HIP(hipStreamSynchronize(streams[i])); + } + + for (int i = 0; i < loop; i++) { + for (int wg_id = 0; wg_id < num_teams; wg_id++) { + // Record start event for this work group on first iteration + if (i == 0) { + CHECK_HIP(hipEventRecord(start_events_timed[wg_id], streams[wg_id])); + } + + char *wg_source = source_buf + wg_id * size; + char *wg_dest = dest_buf + wg_id * size; + rocshmem_broadcastmem_on_stream(team_world_dup[wg_id], wg_dest, + wg_source, size, pe_root, streams[wg_id]); + + // Record stop event for this work group on last iteration + if (i == loop - 1) { + CHECK_HIP(hipEventRecord(stop_events_timed[wg_id], streams[wg_id])); + } + } + } + + num_msgs = (loop + args.skip) * num_teams; + num_timed_msgs = loop * num_teams; +} + +void TeamBroadcastmemOnStreamTester::verifyResults(size_t size) { + // Verify correctness: after broadcast, non-root PEs receive the broadcast + // data Root PE's dest buffer is NOT modified (per OpenSHMEM/rocSHMEM spec) + for (int wg_id = 0; wg_id < num_teams; wg_id++) { + int idx = wg_id * size; + int expected_value; + + if (my_pe == pe_root) { + // Root PE's dest buffer should remain unchanged (0xAA) + expected_value = 0xAA; + } else { + // Non-root PEs should have received the broadcast value + expected_value = (pe_root + 1) * 100 + wg_id; + } + + for (size_t k = 0; k < size; k++) { + if (static_cast(dest_buf[idx + k]) != + static_cast(expected_value)) { + std::cerr << "PE " << my_pe << ": Verification failed for WG " + << wg_id << " at byte " << k << std::endl; + std::cerr << "Expected value: " << expected_value + << ", Got: " << static_cast(dest_buf[idx + k]) + << std::endl; + rocshmem_global_exit(1); + } + } + } +} + diff --git a/projects/rocshmem/tests/functional_tests/team_broadcastmem_on_stream_tester.hpp b/projects/rocshmem/tests/functional_tests/team_broadcastmem_on_stream_tester.hpp new file mode 100644 index 0000000000..df5d5662e4 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/team_broadcastmem_on_stream_tester.hpp @@ -0,0 +1,71 @@ +/****************************************************************************** + * Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef _TEAM_BROADCASTMEM_ON_STREAM_TESTER_HPP_ +#define _TEAM_BROADCASTMEM_ON_STREAM_TESTER_HPP_ + +#include "tester.hpp" +#include +#include + +using namespace rocshmem; + +/****************************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class TeamBroadcastmemOnStreamTester : public Tester { + public: + explicit TeamBroadcastmemOnStreamTester(TesterArguments args); + virtual ~TeamBroadcastmemOnStreamTester(); + + protected: + virtual void resetBuffers(size_t size) override; + + virtual void preLaunchKernel() override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + size_t size) override; + + virtual void postLaunchKernel() override; + + virtual void verifyResults(size_t size) override; + + private: + char *source_buf; + char *dest_buf; + int my_pe; + int n_pes; + size_t buf_size; + int num_teams = 1; + int pe_root = 0; // Root PE for broadcast + std::vector team_world_dup; + std::vector streams; + std::vector start_events_timed; + std::vector stop_events_timed; +}; + +#include "team_broadcastmem_on_stream_tester.cpp" + +#endif + diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index 955afb48dc..bf6ce6c23c 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -36,7 +36,12 @@ #include "amo_standard_tester.hpp" #include "default_ctx_primitive_tester.hpp" #include "barrier_all_tester.hpp" +#include "barrier_all_on_stream_tester.hpp" #include "empty_tester.hpp" +#include "getmem_on_stream_tester.hpp" +#include "putmem_on_stream_tester.hpp" +#include "putmem_signal_on_stream_tester.hpp" +#include "signal_wait_until_on_stream_tester.hpp" #include "ping_all_tester.hpp" #include "ping_pong_tester.hpp" #include "primitive_mr_tester.hpp" @@ -48,6 +53,7 @@ #include "team_sync_tester.hpp" #include "team_alltoall_tester.hpp" #include "team_alltoallmem_on_stream_tester.hpp" +#include "team_broadcastmem_on_stream_tester.hpp" #include "team_barrier_tester.hpp" #include "team_broadcast_tester.hpp" #include "team_ctx_infra_tester.hpp" @@ -233,6 +239,36 @@ std::vector Tester::create(TesterArguments args) { std::cout << "Alltoallmem_On_Stream ###" << std::endl; testers.push_back(new TeamAlltoallmemOnStreamTester(args)); return testers; + case BarrierAllOnStreamTestType: + if (rank == 0) + std::cout << "Barrier_All_On_Stream ###" << std::endl; + testers.push_back(new BarrierAllOnStreamTester(args)); + return testers; + case TeamBroadcastmemOnStreamTestType: + if (rank == 0) + std::cout << "Broadcastmem_On_Stream ###" << std::endl; + testers.push_back(new TeamBroadcastmemOnStreamTester(args)); + return testers; + case GetmemOnStreamTestType: + if (rank == 0) + std::cout << "Getmem_On_Stream ###" << std::endl; + testers.push_back(new GetmemOnStreamTester(args)); + return testers; + case PutmemOnStreamTestType: + if (rank == 0) + std::cout << "Putmem_On_Stream ###" << std::endl; + testers.push_back(new PutmemOnStreamTester(args)); + return testers; + case PutmemSignalOnStreamTestType: + if (rank == 0) + std::cout << "Putmem_Signal_On_Stream ###" << std::endl; + testers.push_back(new PutmemSignalOnStreamTester(args)); + return testers; + case SignalWaitUntilOnStreamTestType: + if (rank == 0) + std::cout << "Signal_Wait_Until_On_Stream ###" << std::endl; + testers.push_back(new SignalWaitUntilOnStreamTester(args)); + return testers; case TeamFCollectTestType: if (rank == 0) { std::cout << "Fcollect Test ###" << std::endl; @@ -569,30 +605,50 @@ void Tester::execute() { } bool Tester::peLaunchesKernel() { - bool is_launcher; - /** * The PE assigned 0 is always active in these tests. */ - is_launcher = args.myid == 0; + bool is_launcher = (args.myid == 0); /** * Some test types are active on both sides. */ - is_launcher = is_launcher || (_type == TeamReductionTestType) || - (_type == TeamBroadcastTestType) || (_type == TeamCtxInfraTestType) || - (_type == TeamCtxInfraTestSingleType) || (_type == TeamCtxInfraTestBlockType) || - (_type == TeamCtxInfraTestOddEvenType) || - (_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) || - (_type == PingPongTestType) || (_type == BarrierAllTestType) || - (_type == WAVEBarrierAllTestType) || (_type == WGBarrierAllTestType) || - (_type == TeamSyncTestType) || (_type == TeamWAVESyncTestType) || - (_type == TeamWGSyncTestType) || (_type == SyncAllTestType) || - (_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) || - (_type == RandomAccessTestType) || (_type == PingAllTestType) || - (_type == TeamBarrierTestType) || (_type == TeamWAVEBarrierTestType) || - (_type == TeamWGBarrierTestType) || - (_type == TeamAlltoallmemOnStreamTestType); + switch (_type) { + case TeamReductionTestType: + case TeamBroadcastTestType: + case TeamCtxInfraTestType: + case TeamCtxInfraTestSingleType: + case TeamCtxInfraTestBlockType: + case TeamCtxInfraTestOddEvenType: + case TeamAllToAllTestType: + case TeamFCollectTestType: + case PingPongTestType: + case BarrierAllTestType: + case WAVEBarrierAllTestType: + case WGBarrierAllTestType: + case TeamSyncTestType: + case TeamWAVESyncTestType: + case TeamWGSyncTestType: + case SyncAllTestType: + case WAVESyncAllTestType: + case WGSyncAllTestType: + case RandomAccessTestType: + case PingAllTestType: + case TeamBarrierTestType: + case TeamWAVEBarrierTestType: + case TeamWGBarrierTestType: + case TeamAlltoallmemOnStreamTestType: + case BarrierAllOnStreamTestType: + case TeamBroadcastmemOnStreamTestType: + case GetmemOnStreamTestType: + case PutmemOnStreamTestType: + case PutmemSignalOnStreamTestType: + case SignalWaitUntilOnStreamTestType: + is_launcher = true; + break; + default: + break; + } return is_launcher; } diff --git a/projects/rocshmem/tests/functional_tests/tester.hpp b/projects/rocshmem/tests/functional_tests/tester.hpp index 4d67df060a..a11d5ccba9 100644 --- a/projects/rocshmem/tests/functional_tests/tester.hpp +++ b/projects/rocshmem/tests/functional_tests/tester.hpp @@ -114,6 +114,12 @@ enum TestType { TeamCtxInfraTestBlockType = 74, TeamCtxInfraTestOddEvenType = 75, TeamAlltoallmemOnStreamTestType = 76, + BarrierAllOnStreamTestType = 77, + TeamBroadcastmemOnStreamTestType = 78, + GetmemOnStreamTestType = 79, + PutmemOnStreamTestType = 80, + PutmemSignalOnStreamTestType = 81, + SignalWaitUntilOnStreamTestType = 82, }; enum OpType { PutType = 0, GetType = 1 }; diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp index afbe541df8..f99cd9fc92 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp @@ -112,10 +112,12 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { case TeamBarrierTestType: case TeamWAVEBarrierTestType: case TeamWGBarrierTestType: + case BarrierAllOnStreamTestType: case SyncAllTestType: case WAVESyncAllTestType: case WGSyncAllTestType: case TeamSyncTestType: + case SignalWaitUntilOnStreamTestType: min_msg_size = 8; max_msg_size = 8; break; @@ -125,6 +127,8 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { max_msg_size = 4; break; case RandomAccessTestType: + case TeamAlltoallmemOnStreamTestType: + case TeamBroadcastmemOnStreamTestType: min_msg_size = 4; break; case TeamFCollectTestType: @@ -173,23 +177,49 @@ void TesterArguments::get_arguments() { myid = rocshmem_my_pe(); TestType type = (TestType)algorithm; - if ((type != BarrierAllTestType) && (type != WAVEBarrierAllTestType) && - (type != WGBarrierAllTestType) && (type != SyncAllTestType) && - (type != WAVESyncAllTestType) && (type != WGSyncAllTestType) && - (type != TeamSyncTestType) && (type != TeamWAVESyncTestType) && - (type != TeamWGSyncTestType) && (type != TeamAllToAllTestType) && - (type != TeamFCollectTestType) && (type != TeamReductionTestType) && - (type != TeamBroadcastTestType) && (type != PingAllTestType) && - (type != TeamBarrierTestType) && (type != TeamWAVEBarrierTestType) && - (type != TeamWGBarrierTestType) && (type != TeamCtxInfraTestBlockType) && - (type != TeamCtxInfraTestOddEvenType) && - (type != TeamAlltoallmemOnStreamTestType)) { - if (numprocs != 2) { - if (myid == 0) { - std::cerr << "This test requires exactly two processes, we have " - << numprocs << "\n"; - } - exit(-1); + // Check if test requires exactly 2 PEs + // Tests that support arbitrary number of PEs are excluded + bool requires_two_pes = true; + switch (type) { + // Collective/barrier tests - support any number of PEs + case BarrierAllTestType: + case WAVEBarrierAllTestType: + case WGBarrierAllTestType: + case SyncAllTestType: + case WAVESyncAllTestType: + case WGSyncAllTestType: + case TeamSyncTestType: + case TeamWAVESyncTestType: + case TeamWGSyncTestType: + case TeamAllToAllTestType: + case TeamFCollectTestType: + case TeamReductionTestType: + case TeamBroadcastTestType: + case PingAllTestType: + case TeamBarrierTestType: + case TeamWAVEBarrierTestType: + case TeamWGBarrierTestType: + case TeamCtxInfraTestBlockType: + case TeamCtxInfraTestOddEvenType: + // On-stream tests - support any number of PEs + case TeamAlltoallmemOnStreamTestType: + case BarrierAllOnStreamTestType: + case TeamBroadcastmemOnStreamTestType: + case GetmemOnStreamTestType: + case PutmemOnStreamTestType: + case PutmemSignalOnStreamTestType: + case SignalWaitUntilOnStreamTestType: + requires_two_pes = false; + break; + default: + break; + } + + if (requires_two_pes && numprocs != 2) { + if (myid == 0) { + std::cerr << "This test requires exactly two processes, we have " + << numprocs << "\n"; } + exit(-1); } } diff --git a/projects/rocshmem/utils/header_files_gen/P2P_SYNC.py b/projects/rocshmem/utils/header_files_gen/P2P_SYNC.py index 72aac15847..98d6b3d590 100644 --- a/projects/rocshmem/utils/header_files_gen/P2P_SYNC.py +++ b/projects/rocshmem/utils/header_files_gen/P2P_SYNC.py @@ -38,6 +38,7 @@ types = [ ("unsigned int", "uint"), ("unsigned long", "ulong"), ("unsigned long long", "ulonglong"), + ("uint64_t", "uint64"), ]