From dc61bca0666737fcdcdc5d15c6661599f2efd33f Mon Sep 17 00:00:00 2001 From: Avinash Kethineedi Date: Tue, 8 Apr 2025 11:25:31 -0500 Subject: [PATCH] Update `Barrier` and `Sync` APIs (#73) * Add thread, wavefront, and workgroup-level `barrier` APIs in IPC and RO conduits; remove collectives on default context - Implemented `barrier` APIs for thread, wavefront, and workgroup scopes - Added support into both IPC and RO conduits - Added functional tests to cover all `barrier` APIs - Removed collective operations on default context * Add thread, wavefront, and workgroup-level `sync` APIs in IPC and RO conduits. - Implemented `sync` APIs for thread, wavefront, and workgroup scopes - Added support into both IPC and RO conduits - Added functional tests to cover all `sync` APIs * update naming convention for context-based `barrier` APIs --- include/rocshmem/rocshmem.hpp | 82 +++++++++++++++---- scripts/functional_tests/driver.sh | 32 +++++++- src/backend_bc.cpp | 6 ++ src/context.hpp | 8 ++ src/context_device.cpp | 28 ++++++- src/ipc/context_ipc_device.hpp | 8 ++ src/ipc/context_ipc_device_coll.cpp | 52 ++++++++++++ src/reverse_offload/context_ro_device.cpp | 32 ++++++++ src/reverse_offload/context_ro_device.hpp | 8 ++ src/rocshmem_gpu.cpp | 48 ++++++----- src/stats.hpp | 6 ++ tests/functional_tests/sync_tester.cpp | 29 ++++--- .../functional_tests/team_barrier_tester.cpp | 39 ++++++--- tests/functional_tests/tester.cpp | 22 ++++- tests/functional_tests/tester.hpp | 6 +- tests/functional_tests/tester_arguments.cpp | 8 +- 16 files changed, 347 insertions(+), 67 deletions(-) diff --git a/include/rocshmem/rocshmem.hpp b/include/rocshmem/rocshmem.hpp index 5113baa64a..a9a2251017 100644 --- a/include/rocshmem/rocshmem.hpp +++ b/include/rocshmem/rocshmem.hpp @@ -503,8 +503,6 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team, __device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all( rocshmem_ctx_t ctx); -__device__ ATTR_NO_INLINE void rocshmem_barrier_all(); - /** * @brief perform a collective barrier between all PEs in the system. * The caller is blocked until the barrier is resolved. @@ -518,8 +516,6 @@ __device__ ATTR_NO_INLINE void rocshmem_barrier_all(); __device__ ATTR_NO_INLINE void rocshmem_ctx_wave_barrier_all( rocshmem_ctx_t ctx); -__device__ ATTR_NO_INLINE void rocshmem_wave_barrier_all(); - /** * @brief perform a collective barrier between all PEs in the system. * The caller is blocked until the barrier is resolved. @@ -533,17 +529,47 @@ __device__ ATTR_NO_INLINE void rocshmem_wave_barrier_all(); __device__ ATTR_NO_INLINE void rocshmem_ctx_wg_barrier_all( rocshmem_ctx_t ctx); -__device__ ATTR_NO_INLINE void rocshmem_wg_barrier_all(); +/** + * @brief perform a collective barrier between all PEs in the team. + * The caller is blocked until the barrier is resolved. + * + * This function must be invoked by a single thread within the PE. + * + * @param[in] handle GPU side handle. + * + * @param[in] team The team on which to perform barrier synchronization + * + * @return void + */ +__device__ void rocshmem_ctx_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team); /** * @brief perform a collective barrier between all PEs in the team. * The caller is blocked until the barrier is resolved. * - * @param[in] team The team on which to perform barrier synchronization + * This function must be called as a wave-front collective. + * + * @param[in] handle GPU side handle. + * + * @param[in] team The team on which to perform barrier synchronization * * @return void */ -__device__ void rocshmem_barrier(rocshmem_team_t); +__device__ void rocshmem_ctx_wave_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team); + +/** + * @brief perform a collective barrier between all PEs in the team. + * The caller is blocked until the barrier is resolved. + * + * This function must be called as a work-group collective. + * + * @param[in] handle GPU side handle. + * + * @param[in] team The team on which to perform barrier synchronization + * + * @return void + */ +__device__ void rocshmem_ctx_wg_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team); /** * @brief registers the arrival of a PE at a barrier. @@ -561,8 +587,6 @@ __device__ void rocshmem_barrier(rocshmem_team_t); */ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx); -__device__ ATTR_NO_INLINE void rocshmem_sync_all(); - /** * @brief registers the arrival of a PE at a barrier. * The caller is blocked until the synchronization is resolved. @@ -579,8 +603,6 @@ __device__ ATTR_NO_INLINE void rocshmem_sync_all(); */ __device__ ATTR_NO_INLINE void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx); -__device__ ATTR_NO_INLINE void rocshmem_wave_sync_all(); - /** * @brief registers the arrival of a PE at a barrier. * The caller is blocked until the synchronization is resolved. @@ -597,7 +619,41 @@ __device__ ATTR_NO_INLINE void rocshmem_wave_sync_all(); */ __device__ ATTR_NO_INLINE void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx); -__device__ ATTR_NO_INLINE void rocshmem_wg_sync_all(); +/** + * @brief registers the arrival of a PE at a barrier. + * The caller is blocked until the synchronization is resolved. + * + * In contrast with the shmem_barrier_all routine, shmem_team_sync only ensures + * completion and visibility of previously issued memory stores and does not + * ensure completion of remote memory updates issued via OpenSHMEM routines. + * + * This function must be invoked by a single thread within the PE. + * + * @param[in] handle GPU side handle. + * @param[in] team Handle of the team being synchronized + * + * @return void + */ +__device__ ATTR_NO_INLINE void rocshmem_ctx_team_sync( + rocshmem_ctx_t ctx, rocshmem_team_t team); + +/** + * @brief registers the arrival of a PE at a barrier. + * The caller is blocked until the synchronization is resolved. + * + * In contrast with the shmem_barrier_all routine, shmem_team_sync only ensures + * completion and visibility of previously issued memory stores and does not + * ensure completion of remote memory updates issued via OpenSHMEM routines. + * + * This function must be called as a wave-front collective. + * + * @param[in] handle GPU side handle. + * @param[in] team Handle of the team being synchronized + * + * @return void + */ +__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_team_sync( + rocshmem_ctx_t ctx, rocshmem_team_t team); /** * @brief registers the arrival of a PE at a barrier. @@ -617,8 +673,6 @@ __device__ ATTR_NO_INLINE void rocshmem_wg_sync_all(); __device__ ATTR_NO_INLINE void rocshmem_ctx_wg_team_sync( rocshmem_ctx_t ctx, rocshmem_team_t team); -__device__ ATTR_NO_INLINE void rocshmem_wg_team_sync(rocshmem_team_t team); - /** * @brief Query a local pointer to a symmetric data object on the * specified \pe . Returns an address that may be used to directly reference diff --git a/scripts/functional_tests/driver.sh b/scripts/functional_tests/driver.sh index 0c8994fb55..698437b312 100755 --- a/scripts/functional_tests/driver.sh +++ b/scripts/functional_tests/driver.sh @@ -80,7 +80,7 @@ declare -A TEST_NUMBERS=( ["signalfetch"]="55" ["wgsignalfetch"]="56" ["wavesignalfetch"]="57" - ["teambarrier"]="58" + ["teamwgbarrier"]="58" ["defaultctxget"]="59" ["defaultctxgetnbi"]="60" ["defaultctxput"]="61" @@ -91,6 +91,10 @@ declare -A TEST_NUMBERS=( ["wgbarrierall"]="66" ["wavesyncall"]="67" ["wgsyncall"]="68" + ["teambarrier"]="69" + ["teamwavebarrier"]="70" + ["wavesync"]="71" + ["wgsync"]="72" ) ExecTest() { @@ -328,8 +332,34 @@ TestColl() { ExecTest "wgbarrierall" 2 64 1024 ExecTest "teambarrier" 2 1 1 + ExecTest "teambarrier" 2 16 64 + ExecTest "teambarrier" 2 32 256 + ExecTest "teambarrier" 2 39 1024 + + ExecTest "teamwavebarrier" 2 1 1 + ExecTest "teamwavebarrier" 2 16 64 + ExecTest "teamwavebarrier" 2 32 256 + ExecTest "teamwavebarrier" 2 39 1024 + + ExecTest "teamwgbarrier" 2 1 1 + ExecTest "teamwgbarrier" 2 16 64 + ExecTest "teamwgbarrier" 2 32 256 + ExecTest "teamwgbarrier" 2 39 1024 ExecTest "sync" 2 1 1 + ExecTest "sync" 2 16 64 + ExecTest "sync" 2 32 256 + ExecTest "sync" 2 39 1024 + + ExecTest "wavesync" 2 1 1 + ExecTest "wavesync" 2 16 64 + ExecTest "wavesync" 2 32 256 + ExecTest "wavesync" 2 39 1024 + + ExecTest "wgsync" 2 1 1 + ExecTest "wgsync" 2 16 64 + ExecTest "wgsync" 2 32 256 + ExecTest "wgsync" 2 39 1024 ExecTest "syncall" 2 1 1 ExecTest "syncall" 2 16 64 diff --git a/src/backend_bc.cpp b/src/backend_bc.cpp index 72dc366288..7b7623154c 100644 --- a/src/backend_bc.cpp +++ b/src/backend_bc.cpp @@ -132,6 +132,9 @@ void Backend::dump_stats() { printf("BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL)); printf("WAVE_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WAVE)); printf("WG_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WG)); + printf("Barrier %llu\n", device_stats.getStat(NUM_BARRIER)); + printf("WAVE_Barrier %llu\n", device_stats.getStat(NUM_BARRIER_WAVE)); + printf("WG_Barrier %llu\n", device_stats.getStat(NUM_BARRIER_WG)); printf("Wait Until %llu\n", device_stats.getStat(NUM_WAIT_UNTIL)); printf("Wait Until Any %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ANY)); printf("Wait Until All %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ALL)); @@ -157,6 +160,9 @@ void Backend::dump_stats() { printf("SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL)); printf("WAVE_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WAVE)); printf("WG_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WG)); + printf("Sync %llu\n", device_stats.getStat(NUM_SYNC)); + printf("WAVE_Sync %llu\n", device_stats.getStat(NUM_SYNC_WAVE)); + printf("WG_Sync %llu\n", device_stats.getStat(NUM_SYNC_WG)); const auto& host_stats{globalHostStats}; printf("HOST STATS\n"); diff --git a/src/context.hpp b/src/context.hpp index 22d4cda4e8..e74299fe0c 100644 --- a/src/context.hpp +++ b/src/context.hpp @@ -143,12 +143,20 @@ class Context { __device__ void barrier(rocshmem_team_t team); + __device__ void barrier_wave(rocshmem_team_t team); + + __device__ void barrier_wg(rocshmem_team_t team); + __device__ void sync_all(); __device__ void sync_all_wave(); __device__ void sync_all_wg(); + __device__ void sync(rocshmem_team_t team); + + __device__ void sync_wave(rocshmem_team_t team); + __device__ void sync_wg(rocshmem_team_t team); template diff --git a/src/context_device.cpp b/src/context_device.cpp index c7c9d29583..8682192bd2 100644 --- a/src/context_device.cpp +++ b/src/context_device.cpp @@ -161,11 +161,23 @@ __device__ void Context::barrier_all_wg() { } __device__ void Context::barrier(rocshmem_team_t team) { - ctxStats.incStat(NUM_BARRIER_ALL); + ctxStats.incStat(NUM_BARRIER); DISPATCH(barrier(team)); } +__device__ void Context::barrier_wave(rocshmem_team_t team) { + ctxStats.incStat(NUM_BARRIER_WAVE); + + DISPATCH(barrier_wave(team)); +} + +__device__ void Context::barrier_wg(rocshmem_team_t team) { + ctxStats.incStat(NUM_BARRIER_WG); + + DISPATCH(barrier_wg(team)); +} + __device__ void Context::sync_all() { ctxStats.incStat(NUM_SYNC_ALL); @@ -184,8 +196,20 @@ __device__ void Context::sync_all_wg() { DISPATCH(sync_all_wg()); } +__device__ void Context::sync(rocshmem_team_t team) { + ctxStats.incStat(NUM_SYNC); + + DISPATCH(sync(team)); +} + +__device__ void Context::sync_wave(rocshmem_team_t team) { + ctxStats.incStat(NUM_SYNC_WAVE); + + DISPATCH(sync_wave(team)); +} + __device__ void Context::sync_wg(rocshmem_team_t team) { - ctxStats.incStat(NUM_SYNC_ALL_WG); + ctxStats.incStat(NUM_SYNC_WG); DISPATCH(sync_wg(team)); } diff --git a/src/ipc/context_ipc_device.hpp b/src/ipc/context_ipc_device.hpp index ac7f7a8154..5e8acd5a53 100644 --- a/src/ipc/context_ipc_device.hpp +++ b/src/ipc/context_ipc_device.hpp @@ -67,12 +67,20 @@ class IPCContext : public Context { __device__ void barrier(rocshmem_team_t team); + __device__ void barrier_wave(rocshmem_team_t team); + + __device__ void barrier_wg(rocshmem_team_t team); + __device__ void sync_all(); __device__ void sync_all_wave(); __device__ void sync_all_wg(); + __device__ void sync(rocshmem_team_t team); + + __device__ void sync_wave(rocshmem_team_t team); + __device__ void sync_wg(rocshmem_team_t team); template diff --git a/src/ipc/context_ipc_device_coll.cpp b/src/ipc/context_ipc_device_coll.cpp index 6191ad0e8f..ab92f64007 100644 --- a/src/ipc/context_ipc_device_coll.cpp +++ b/src/ipc/context_ipc_device_coll.cpp @@ -118,6 +118,30 @@ __device__ void IPCContext::internal_sync_wg(int pe, int PE_start, int stride, __syncthreads(); } +__device__ void IPCContext::sync(rocshmem_team_t team) { + IPCTeam *team_obj = reinterpret_cast(team); + + int pe = team_obj->my_pe_in_world; + int pe_start = team_obj->tinfo_wrt_world->pe_start; + int pe_stride = team_obj->tinfo_wrt_world->stride; + int pe_size = team_obj->num_pes; + long *p_sync = team_obj->barrier_pSync; + + internal_sync(pe, pe_start, pe_stride, pe_size, p_sync); +} + +__device__ void IPCContext::sync_wave(rocshmem_team_t team) { + IPCTeam *team_obj = reinterpret_cast(team); + + int pe = team_obj->my_pe_in_world; + int pe_start = team_obj->tinfo_wrt_world->pe_start; + int pe_stride = team_obj->tinfo_wrt_world->stride; + int pe_size = team_obj->num_pes; + long *p_sync = team_obj->barrier_pSync; + + internal_sync_wave(pe, pe_start, pe_stride, pe_size, p_sync); +} + __device__ void IPCContext::sync_wg(rocshmem_team_t team) { IPCTeam *team_obj = reinterpret_cast(team); @@ -171,6 +195,34 @@ __device__ void IPCContext::barrier(rocshmem_team_t team) { int pe_size = team_obj->num_pes; long *p_sync = team_obj->barrier_pSync; + quiet(); + internal_sync(pe, pe_start, pe_stride, pe_size, p_sync); +} + +__device__ void IPCContext::barrier_wave(rocshmem_team_t team) { + IPCTeam *team_obj = reinterpret_cast(team); + + int pe = team_obj->my_pe_in_world; + int pe_start = team_obj->tinfo_wrt_world->pe_start; + int pe_stride = team_obj->tinfo_wrt_world->stride; + int pe_size = team_obj->num_pes; + long *p_sync = team_obj->barrier_pSync; + + if (is_thread_zero_in_wave()) { + quiet(); + } + internal_sync_wave(pe, pe_start, pe_stride, pe_size, p_sync); +} + +__device__ void IPCContext::barrier_wg(rocshmem_team_t team) { + IPCTeam *team_obj = reinterpret_cast(team); + + int pe = team_obj->my_pe_in_world; + int pe_start = team_obj->tinfo_wrt_world->pe_start; + int pe_stride = team_obj->tinfo_wrt_world->stride; + int pe_size = team_obj->num_pes; + long *p_sync = team_obj->barrier_pSync; + if (is_thread_zero_in_block()) { quiet(); } diff --git a/src/reverse_offload/context_ro_device.cpp b/src/reverse_offload/context_ro_device.cpp index 451fb4209d..36d389da7e 100644 --- a/src/reverse_offload/context_ro_device.cpp +++ b/src/reverse_offload/context_ro_device.cpp @@ -193,6 +193,22 @@ __device__ void ROContext::barrier_all_wg() { } __device__ void ROContext::barrier(rocshmem_team_t team) { + ROTeam *team_obj = reinterpret_cast(team); + build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, + nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, + true, get_status_flag(), is_default_ctx); +} + +__device__ void ROContext::barrier_wave(rocshmem_team_t team) { + ROTeam *team_obj = reinterpret_cast(team); + if (is_thread_zero_in_wave()) { + build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, + nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, + true, get_status_flag(), is_default_ctx); + } +} + +__device__ void ROContext::barrier_wg(rocshmem_team_t team) { ROTeam *team_obj = reinterpret_cast(team); if (is_thread_zero_in_block()) { build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, @@ -225,6 +241,22 @@ __device__ void ROContext::sync_all_wg() { __syncthreads(); } +__device__ void ROContext::sync(rocshmem_team_t team) { + ROTeam *team_obj = reinterpret_cast(team); + build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, + nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, + true, get_status_flag(), is_default_ctx); +} + +__device__ void ROContext::sync_wave(rocshmem_team_t team) { + ROTeam *team_obj = reinterpret_cast(team); + if (is_thread_zero_in_wave()) { + build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, + nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle, + true, get_status_flag(), is_default_ctx); + } +} + __device__ void ROContext::sync_wg(rocshmem_team_t team) { ROTeam *team_obj = reinterpret_cast(team); if (is_thread_zero_in_block()) { diff --git a/src/reverse_offload/context_ro_device.hpp b/src/reverse_offload/context_ro_device.hpp index aa3b675f2e..ba285662cd 100644 --- a/src/reverse_offload/context_ro_device.hpp +++ b/src/reverse_offload/context_ro_device.hpp @@ -71,12 +71,20 @@ class ROContext : public Context { __device__ void barrier(rocshmem_team_t team); + __device__ void barrier_wave(rocshmem_team_t team); + + __device__ void barrier_wg(rocshmem_team_t team); + __device__ void sync_all(); __device__ void sync_all_wave(); __device__ void sync_all_wg(); + __device__ void sync(rocshmem_team_t team); + + __device__ void sync_wave(rocshmem_team_t team); + __device__ void sync_wg(rocshmem_team_t team); template diff --git a/src/rocshmem_gpu.cpp b/src/rocshmem_gpu.cpp index f3a52d668b..5c9e5279e8 100644 --- a/src/rocshmem_gpu.cpp +++ b/src/rocshmem_gpu.cpp @@ -588,14 +588,22 @@ __device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) { get_internal_ctx(ctx)->barrier_all_wg(); } -__device__ void rocshmem_wg_barrier_all() { - rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT); -} - -__device__ void rocshmem_barrier(rocshmem_team_t team) { +__device__ void rocshmem_ctx_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) { GPU_DPRINTF("Function: rocshmem_barrier\n"); - get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier(team); + get_internal_ctx(ctx)->barrier(team); +} + +__device__ void rocshmem_ctx_wave_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) { + GPU_DPRINTF("Function: rocshmem_wave_barrier\n"); + + get_internal_ctx(ctx)->barrier_wave(team); +} + +__device__ void rocshmem_ctx_wg_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) { + GPU_DPRINTF("Function: rocshmem_wg_barrier\n"); + + get_internal_ctx(ctx)->barrier_wg(team); } __device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) { @@ -604,39 +612,37 @@ __device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) { get_internal_ctx(ctx)->sync_all(); } -__device__ void rocshmem_sync_all() { - rocshmem_ctx_sync_all(ROCSHMEM_CTX_DEFAULT); -} - __device__ void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx) { GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n"); get_internal_ctx(ctx)->sync_all_wave(); } -__device__ void rocshmem_wave_sync_all() { - rocshmem_ctx_wave_sync_all(ROCSHMEM_CTX_DEFAULT); -} - __device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) { GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n"); get_internal_ctx(ctx)->sync_all_wg(); } -__device__ void rocshmem_wg_sync_all() { - rocshmem_ctx_wg_sync_all(ROCSHMEM_CTX_DEFAULT); -} - -__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx, +__device__ void rocshmem_ctx_team_sync(rocshmem_ctx_t ctx, rocshmem_team_t team) { GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n"); get_internal_ctx(ctx)->sync_wg(team); } -__device__ void rocshmem_wg_team_sync(rocshmem_team_t team) { - rocshmem_ctx_wg_team_sync(ROCSHMEM_CTX_DEFAULT, team); +__device__ void rocshmem_ctx_wave_team_sync(rocshmem_ctx_t ctx, + rocshmem_team_t team) { +GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n"); + +get_internal_ctx(ctx)->sync_wg(team); +} + +__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx, + rocshmem_team_t team) { +GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n"); + +get_internal_ctx(ctx)->sync_wg(team); } __device__ int rocshmem_ctx_n_pes(rocshmem_ctx_t ctx) { diff --git a/src/stats.hpp b/src/stats.hpp index ab4d5eb785..62a4fbf933 100644 --- a/src/stats.hpp +++ b/src/stats.hpp @@ -45,6 +45,9 @@ enum rocshmem_stats { NUM_BARRIER_ALL, NUM_BARRIER_ALL_WAVE, NUM_BARRIER_ALL_WG, + NUM_BARRIER, + NUM_BARRIER_WAVE, + NUM_BARRIER_WG, NUM_WAIT_UNTIL, NUM_WAIT_UNTIL_ANY, NUM_WAIT_UNTIL_ALL, @@ -74,6 +77,9 @@ enum rocshmem_stats { NUM_SYNC_ALL, NUM_SYNC_ALL_WAVE, NUM_SYNC_ALL_WG, + NUM_SYNC, + NUM_SYNC_WAVE, + NUM_SYNC_WG, NUM_BROADCAST, NUM_PUT_WG, NUM_PUT_NBI_WG, diff --git a/tests/functional_tests/sync_tester.cpp b/tests/functional_tests/sync_tester.cpp index fd5c67fce9..b130e6fdb7 100644 --- a/tests/functional_tests/sync_tester.cpp +++ b/tests/functional_tests/sync_tester.cpp @@ -27,9 +27,12 @@ *****************************************************************************/ __global__ void SyncTest(int loop, int skip, long long int *start_time, long long int *end_time, TestType type, - ShmemContextType ctx_type, rocshmem_team_t *teams) { + ShmemContextType ctx_type, int wf_size, + rocshmem_team_t *teams) { __shared__ rocshmem_ctx_t ctx; + int t_id = get_flat_block_id(); int wg_id = get_flat_grid_id(); + int wf_id = t_id / wf_size; rocshmem_wg_init(); rocshmem_wg_ctx_create(ctx_type, &ctx); @@ -39,16 +42,25 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time, start_time[wg_id] = wall_clock64(); } - __syncthreads(); switch (type) { case SyncTestType: + if(t_id == 0) { + rocshmem_ctx_team_sync(ctx, teams[wg_id]); + } + break; + case WAVESyncTestType: + if(wf_id == 0) { + rocshmem_ctx_wave_team_sync(ctx, teams[wg_id]); + } + break; + case WGSyncTestType: rocshmem_ctx_wg_team_sync(ctx, teams[wg_id]); break; default: break; } + __syncthreads(); } - __syncthreads(); if (hipThreadIdx_x == 0) { end_time[wg_id] = wall_clock64(); @@ -100,15 +112,10 @@ void SyncTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, hipLaunchKernelGGL(SyncTest, gridSize, blockSize, shared_bytes, stream, loop, args.skip, start_time, end_time, _type, _shmem_context, - team_sync_world_dup); + wf_size, team_sync_world_dup); - num_msgs = loop + args.skip; - num_timed_msgs = loop; - - if(_type == SyncTestType) { - num_msgs *= gridSize.x; - num_timed_msgs *= gridSize.x; - } + num_msgs = (loop + args.skip) * gridSize.x; + num_timed_msgs = loop * gridSize.x; } void SyncTester::postLaunchKernel() { diff --git a/tests/functional_tests/team_barrier_tester.cpp b/tests/functional_tests/team_barrier_tester.cpp index 263f2a2c62..ff81154aed 100644 --- a/tests/functional_tests/team_barrier_tester.cpp +++ b/tests/functional_tests/team_barrier_tester.cpp @@ -28,28 +28,41 @@ rocshmem_team_t team_barrier_world_dup; *****************************************************************************/ __global__ void TeamBarrierTest(int loop, int skip, long long int *start_time, long long int *end_time, - ShmemContextType ctx_type, - rocshmem_team_t *teams) { + ShmemContextType ctx_type, TestType type, + int wf_size, rocshmem_team_t *teams) { __shared__ rocshmem_ctx_t ctx; + int t_id = get_flat_block_id(); int wg_id = get_flat_grid_id(); + int wf_id = t_id / wf_size; rocshmem_wg_init(); rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx); - int n_pes = rocshmem_ctx_n_pes(ctx); - - __syncthreads(); - for (int i = 0; i < loop + skip; i++) { if (i == skip && hipThreadIdx_x == 0) { start_time[wg_id] = wall_clock64(); } - rocshmem_barrier(teams[wg_id]); + switch (type) { + case TeamBarrierTestType: + if(t_id == 0) { + rocshmem_ctx_barrier(ctx, teams[wg_id]); + } + break; + case TeamWAVEBarrierTestType: + if(wf_id == 0) { + rocshmem_ctx_wave_barrier(ctx, teams[wg_id]); + } + break; + case TeamWGBarrierTestType: + rocshmem_ctx_wg_barrier(ctx, teams[wg_id]); + break; + default: + break; + } + __syncthreads(); } - __syncthreads(); - if (hipThreadIdx_x == 0) { end_time[wg_id] = wall_clock64(); } @@ -95,10 +108,10 @@ void TeamBarrierTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, uint64_t size) { size_t shared_bytes = 0; - hipLaunchKernelGGL(TeamBarrierTest, gridSize, blockSize, - shared_bytes, stream, loop, args.skip, - start_time, end_time, _shmem_context, - team_barrier_world_dup); + hipLaunchKernelGGL(TeamBarrierTest, gridSize, blockSize, shared_bytes, + stream, loop, args.skip, start_time, end_time, + _shmem_context, _type, wf_size, + team_barrier_world_dup); num_msgs = (loop + args.skip) * gridSize.x; num_timed_msgs = loop * gridSize.x; diff --git a/tests/functional_tests/tester.cpp b/tests/functional_tests/tester.cpp index d5b6dea7a6..a348b33a30 100644 --- a/tests/functional_tests/tester.cpp +++ b/tests/functional_tests/tester.cpp @@ -330,6 +330,14 @@ std::vector Tester::create(TesterArguments args) { if (rank == 0) std::cout << "Team Barrier Test ###" << std::endl; testers.push_back(new TeamBarrierTester(args)); return testers; + case TeamWAVEBarrierTestType: + if (rank == 0) std::cout << "Team WAVE Barrier Test ###" << std::endl; + testers.push_back(new TeamBarrierTester(args)); + return testers; + case TeamWGBarrierTestType: + if (rank == 0) std::cout << "Team WG Barrier Test ###" << std::endl; + testers.push_back(new TeamBarrierTester(args)); + return testers; case SyncAllTestType: if (rank == 0) std::cout << "SyncAll ###" << std::endl; testers.push_back(new SyncTester(args)); @@ -346,6 +354,14 @@ std::vector Tester::create(TesterArguments args) { if (rank == 0) std::cout << "Sync ###" << std::endl; testers.push_back(new SyncTester(args)); return testers; + case WAVESyncTestType: + if (rank == 0) std::cout << "WAVE Sync ###" << std::endl; + testers.push_back(new SyncTester(args)); + return testers; + case WGSyncTestType: + if (rank == 0) std::cout << "WG Sync ###" << std::endl; + testers.push_back(new SyncTester(args)); + return testers; case RandomAccessTestType: if (rank == 0) std::cout << "Random_Access ###" << std::endl; testers.push_back(new RandomAccessTester(args)); @@ -528,10 +544,12 @@ bool Tester::peLaunchesKernel() { (_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) || (_type == PingPongTestType) || (_type == BarrierAllTestType) || (_type == WAVEBarrierAllTestType) || (_type == WGBarrierAllTestType) || - (_type == SyncTestType) || (_type == SyncAllTestType) || + (_type == SyncTestType) || (_type == WAVESyncTestType) || + (_type == WGSyncTestType) || (_type == SyncAllTestType) || (_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) || (_type == RandomAccessTestType) || (_type == PingAllTestType) || - (_type == TeamBarrierTestType); + (_type == TeamBarrierTestType) || (_type == TeamWAVEBarrierTestType) || + (_type == TeamWGBarrierTestType); return is_launcher; } diff --git a/tests/functional_tests/tester.hpp b/tests/functional_tests/tester.hpp index 3a057f23ee..022e81004d 100644 --- a/tests/functional_tests/tester.hpp +++ b/tests/functional_tests/tester.hpp @@ -92,7 +92,7 @@ enum TestType { SignalFetchTestType = 55, WGSignalFetchTestType = 56, WAVESignalFetchTestType = 57, - TeamBarrierTestType = 58, + TeamWGBarrierTestType = 58, DefaultCTXGetTestType = 59, DefaultCTXGetNBITestType = 60, DefaultCTXPutTestType = 61, @@ -103,6 +103,10 @@ enum TestType { WGBarrierAllTestType = 66, WAVESyncAllTestType = 67, WGSyncAllTestType = 68, + TeamBarrierTestType = 69, + TeamWAVEBarrierTestType = 70, + WAVESyncTestType = 71, + WGSyncTestType = 72, }; enum OpType { PutType = 0, GetType = 1 }; diff --git a/tests/functional_tests/tester_arguments.cpp b/tests/functional_tests/tester_arguments.cpp index 1a3eaac6a7..533b9d4093 100644 --- a/tests/functional_tests/tester_arguments.cpp +++ b/tests/functional_tests/tester_arguments.cpp @@ -88,6 +88,8 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { case WAVEBarrierAllTestType: case WGBarrierAllTestType: case TeamBarrierTestType: + case TeamWAVEBarrierTestType: + case TeamWGBarrierTestType: case SyncAllTestType: case WAVESyncAllTestType: case WGSyncAllTestType: @@ -140,10 +142,12 @@ void TesterArguments::get_rocshmem_arguments() { if ((type != BarrierAllTestType) && (type != WAVEBarrierAllTestType) && (type != WGBarrierAllTestType) && (type != SyncAllTestType) && (type != WAVESyncAllTestType) && (type != WGSyncAllTestType) && - (type != SyncTestType) && (type != TeamAllToAllTestType) && + (type != SyncTestType) && (type != WAVESyncTestType) && + (type != WGSyncTestType) && (type != TeamAllToAllTestType) && (type != TeamFCollectTestType) && (type != TeamReductionTestType) && (type != TeamBroadcastTestType) && (type != PingAllTestType) && - (type != TeamBarrierTestType)) { + (type != TeamBarrierTestType) && (type != TeamWAVEBarrierTestType) && + (type != TeamWGBarrierTestType)) { if (numprocs != 2) { if (myid == 0) { std::cerr << "This test requires exactly two processes, we have "