diff --git a/projects/rocshmem/include/rocshmem/rocshmem.hpp b/projects/rocshmem/include/rocshmem/rocshmem.hpp index b94ab11613..5113baa64a 100644 --- a/projects/rocshmem/include/rocshmem/rocshmem.hpp +++ b/projects/rocshmem/include/rocshmem/rocshmem.hpp @@ -490,6 +490,36 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team, int src_pe, rocshmem_team_t dest_team); +/** + * @brief perform a collective barrier between all PEs in the system. + * The caller is blocked until the barrier is resolved. + * + * This function must be invoked by a single thread within the PE. + * + * @param[in] handle GPU side handle. + * + * @return void + */ +__device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all( + rocshmem_ctx_t ctx); + +__device__ ATTR_NO_INLINE void rocshmem_barrier_all(); + +/** + * @brief perform a collective barrier between all PEs in the system. + * The caller is blocked until the barrier is resolved. + * + * This function must be called as a wave-front collective. + * + * @param[in] handle GPU side handle. + * + * @return void + */ +__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_barrier_all( + rocshmem_ctx_t ctx); + +__device__ ATTR_NO_INLINE void rocshmem_wave_barrier_all(); + /** * @brief perform a collective barrier between all PEs in the system. * The caller is blocked until the barrier is resolved. @@ -515,6 +545,42 @@ __device__ ATTR_NO_INLINE void rocshmem_wg_barrier_all(); */ __device__ void rocshmem_barrier(rocshmem_team_t); +/** + * @brief registers the arrival of a PE at a barrier. + * The caller is blocked until the synchronization is resolved. + * + * In contrast with the shmem_barrier_all routine, shmem_sync_all only ensures + * completion and visibility of previously issued memory stores and does not + * ensure completion of remote memory updates issued via OpenSHMEM routines. + * + * This function must be invoked by a single thread within the PE. + * + * @param[in] handle GPU side handle. + * + * @return void + */ +__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx); + +__device__ ATTR_NO_INLINE void rocshmem_sync_all(); + +/** + * @brief registers the arrival of a PE at a barrier. + * The caller is blocked until the synchronization is resolved. + * + * In contrast with the shmem_barrier_all routine, shmem_sync_all only ensures + * completion and visibility of previously issued memory stores and does not + * ensure completion of remote memory updates issued via OpenSHMEM routines. + * + * This function must be called as a wave-front collective. + * + * @param[in] handle GPU side handle. + * + * @return void + */ +__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx); + +__device__ ATTR_NO_INLINE void rocshmem_wave_sync_all(); + /** * @brief registers the arrival of a PE at a barrier. * The caller is blocked until the synchronization is resolved. diff --git a/projects/rocshmem/scripts/functional_tests/driver.sh b/projects/rocshmem/scripts/functional_tests/driver.sh index 684e130262..0c8994fb55 100755 --- a/projects/rocshmem/scripts/functional_tests/driver.sh +++ b/projects/rocshmem/scripts/functional_tests/driver.sh @@ -81,6 +81,16 @@ declare -A TEST_NUMBERS=( ["wgsignalfetch"]="56" ["wavesignalfetch"]="57" ["teambarrier"]="58" + ["defaultctxget"]="59" + ["defaultctxgetnbi"]="60" + ["defaultctxput"]="61" + ["defaultctxputnbi"]="62" + ["defaultctxp"]="63" + ["defaultctxg"]="64" + ["wavebarrierall"]="65" + ["wgbarrierall"]="66" + ["wavesyncall"]="67" + ["wgsyncall"]="68" ) ExecTest() { @@ -303,11 +313,38 @@ TestColl() { # | Name | Ranks | Workgroups | Threads | Max Message Size # ############################################################################## ExecTest "barrierall" 2 1 1 + ExecTest "barrierall" 2 16 64 + ExecTest "barrierall" 2 32 256 + ExecTest "barrierall" 2 64 1024 + + ExecTest "wavebarrierall" 2 1 1 + ExecTest "wavebarrierall" 2 16 64 + ExecTest "wavebarrierall" 2 32 256 + ExecTest "wavebarrierall" 2 64 1024 + + ExecTest "wgbarrierall" 2 1 1 + ExecTest "wgbarrierall" 2 16 64 + ExecTest "wgbarrierall" 2 32 256 + ExecTest "wgbarrierall" 2 64 1024 + ExecTest "teambarrier" 2 1 1 ExecTest "sync" 2 1 1 ExecTest "syncall" 2 1 1 + ExecTest "syncall" 2 16 64 + ExecTest "syncall" 2 32 256 + ExecTest "syncall" 2 64 1024 + + ExecTest "wavesyncall" 2 1 1 + ExecTest "wavesyncall" 2 16 64 + ExecTest "wavesyncall" 2 32 256 + ExecTest "wavesyncall" 2 64 1024 + + ExecTest "wgsyncall" 2 1 1 + ExecTest "wgsyncall" 2 16 64 + ExecTest "wgsyncall" 2 32 256 + ExecTest "wgsyncall" 2 64 1024 ExecTest "alltoall" 2 1 1 512 diff --git a/projects/rocshmem/src/backend_bc.cpp b/projects/rocshmem/src/backend_bc.cpp index d8c3ecaac1..72dc366288 100644 --- a/projects/rocshmem/src/backend_bc.cpp +++ b/projects/rocshmem/src/backend_bc.cpp @@ -130,6 +130,8 @@ void Backend::dump_stats() { printf("Quiets %llu\n", device_stats.getStat(NUM_QUIET)); printf("ToAll %llu\n", device_stats.getStat(NUM_TO_ALL)); printf("BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL)); + printf("WAVE_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WAVE)); + printf("WG_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WG)); printf("Wait Until %llu\n", device_stats.getStat(NUM_WAIT_UNTIL)); printf("Wait Until Any %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ANY)); printf("Wait Until All %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ALL)); @@ -153,6 +155,8 @@ void Backend::dump_stats() { printf("Tests %llu\n", device_stats.getStat(NUM_TEST)); printf("SHMEM_PTR %llu\n", device_stats.getStat(NUM_SHMEM_PTR)); printf("SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL)); + printf("WAVE_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WAVE)); + printf("WG_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WG)); const auto& host_stats{globalHostStats}; printf("HOST STATS\n"); diff --git a/projects/rocshmem/src/context.hpp b/projects/rocshmem/src/context.hpp index 8d04c8687f..22d4cda4e8 100644 --- a/projects/rocshmem/src/context.hpp +++ b/projects/rocshmem/src/context.hpp @@ -137,11 +137,19 @@ class Context { __device__ void barrier_all(); + __device__ void barrier_all_wave(); + + __device__ void barrier_all_wg(); + __device__ void barrier(rocshmem_team_t team); __device__ void sync_all(); - __device__ void sync(rocshmem_team_t team); + __device__ void sync_all_wave(); + + __device__ void sync_all_wg(); + + __device__ void sync_wg(rocshmem_team_t team); template __device__ T amo_fetch(void* dst, T value, T cond, int pe, uint8_t atomic_op); diff --git a/projects/rocshmem/src/context_device.cpp b/projects/rocshmem/src/context_device.cpp index ea1ae4494f..c7c9d29583 100644 --- a/projects/rocshmem/src/context_device.cpp +++ b/projects/rocshmem/src/context_device.cpp @@ -148,6 +148,18 @@ __device__ void Context::barrier_all() { DISPATCH(barrier_all()); } +__device__ void Context::barrier_all_wave() { + ctxStats.incStat(NUM_BARRIER_ALL_WAVE); + + DISPATCH(barrier_all_wave()); +} + +__device__ void Context::barrier_all_wg() { + ctxStats.incStat(NUM_BARRIER_ALL_WG); + + DISPATCH(barrier_all_wg()); +} + __device__ void Context::barrier(rocshmem_team_t team) { ctxStats.incStat(NUM_BARRIER_ALL); @@ -160,10 +172,22 @@ __device__ void Context::sync_all() { DISPATCH(sync_all()); } -__device__ void Context::sync(rocshmem_team_t team) { - ctxStats.incStat(NUM_SYNC_ALL); +__device__ void Context::sync_all_wave() { + ctxStats.incStat(NUM_SYNC_ALL_WAVE); - DISPATCH(sync(team)); + DISPATCH(sync_all_wave()); +} + +__device__ void Context::sync_all_wg() { + ctxStats.incStat(NUM_SYNC_ALL_WG); + + DISPATCH(sync_all_wg()); +} + +__device__ void Context::sync_wg(rocshmem_team_t team) { + ctxStats.incStat(NUM_SYNC_ALL_WG); + + DISPATCH(sync_wg(team)); } __device__ void Context::putmem_wg(void* dest, const void* source, diff --git a/projects/rocshmem/src/ipc/backend_ipc.cpp b/projects/rocshmem/src/ipc/backend_ipc.cpp index 172472b474..387e3c27e2 100644 --- a/projects/rocshmem/src/ipc/backend_ipc.cpp +++ b/projects/rocshmem/src/ipc/backend_ipc.cpp @@ -114,8 +114,9 @@ IPCBackend::~IPCBackend() { void IPCBackend::setup_ctxs() { CHECK_HIP(hipMalloc(&ctx_array, sizeof(IPCContext) * maximum_num_contexts_)); + // 0th context is default context for (size_t i = 0; i < maximum_num_contexts_; i++) { - new (&ctx_array[i]) IPCContext(this); + new (&ctx_array[i]) IPCContext(this, i + 1); ctx_free_list.get()->push_back(ctx_array + i); } } @@ -278,9 +279,10 @@ void IPCBackend::init_wrk_sync_buffer() { auto max_num_teams{team_tracker.get_max_num_teams()}; /** - * size of barrier sync + * size of barrier sync for all the contexts */ - Wrk_Sync_buffer_size_ += sizeof(*barrier_sync) * ROCSHMEM_BARRIER_SYNC_SIZE; + Wrk_Sync_buffer_size_ += sizeof(*barrier_sync) * ROCSHMEM_BARRIER_SYNC_SIZE * + (maximum_num_contexts_ + 1); /** * Size of sync arrays for the teams @@ -378,15 +380,18 @@ void IPCBackend::rocshmem_collective_init() { /* * Allocate heap space for barrier_sync */ - size_t one_sync_size_bytes{sizeof(*barrier_sync)}; - size_t sync_size_bytes{one_sync_size_bytes * ROCSHMEM_BARRIER_SYNC_SIZE}; + size_t one_sync_size_bytes {sizeof(*barrier_sync)}; + size_t total_sync_elems { + ROCSHMEM_BARRIER_SYNC_SIZE * (maximum_num_contexts_ + 1)}; + size_t sync_size_bytes {one_sync_size_bytes * total_sync_elems}; + barrier_sync = reinterpret_cast(temp_Wrk_Sync_buff_ptr_); temp_Wrk_Sync_buff_ptr_ += sync_size_bytes; /* * Initialize the barrier synchronization array with default values. */ - for (int i = 0; i < num_pes; i++) { + for (int i = 0; i < total_sync_elems; i++) { barrier_sync[i] = ROCSHMEM_SYNC_VALUE; } diff --git a/projects/rocshmem/src/ipc/context_ipc_device.cpp b/projects/rocshmem/src/ipc/context_ipc_device.cpp index fac91c899b..7abc72c4d3 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device.cpp +++ b/projects/rocshmem/src/ipc/context_ipc_device.cpp @@ -36,15 +36,18 @@ namespace rocshmem { -__host__ IPCContext::IPCContext(Backend *b) +__host__ IPCContext::IPCContext(Backend *b, unsigned int ctx_id) : Context(b, false) { IPCBackend *backend{static_cast(b)}; ipcImpl_.ipc_bases = b->ipcImpl.ipc_bases; ipcImpl_.shm_size = b->ipcImpl.shm_size; - barrier_sync = backend->barrier_sync; + size_t barrier_sync_offset = ctx_id * ROCSHMEM_BARRIER_SYNC_SIZE; + + barrier_sync = backend->barrier_sync + barrier_sync_offset; fence_pool = backend->fence_pool; Wrk_Sync_buffer_bases_ = backend->get_wrk_sync_bases(); + ctx_id_ = ctx_id; orders_.store = detail::atomic::rocshmem_memory_order::memory_order_seq_cst; } diff --git a/projects/rocshmem/src/ipc/context_ipc_device.hpp b/projects/rocshmem/src/ipc/context_ipc_device.hpp index 92abb76c2d..ac7f7a8154 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_device.hpp @@ -31,9 +31,9 @@ namespace rocshmem { class IPCContext : public Context { public: - __host__ IPCContext(Backend *b); + __host__ IPCContext(Backend *b, unsigned int ctx_id); - __device__ IPCContext(Backend *b); + __device__ IPCContext(Backend *b, unsigned int ctx_id); __device__ void threadfence_system(); @@ -61,11 +61,19 @@ class IPCContext : public Context { __device__ void barrier_all(); + __device__ void barrier_all_wave(); + + __device__ void barrier_all_wg(); + __device__ void barrier(rocshmem_team_t team); __device__ void sync_all(); - __device__ void sync(rocshmem_team_t team); + __device__ void sync_all_wave(); + + __device__ void sync_all_wg(); + + __device__ void sync_wg(rocshmem_team_t team); template __device__ void p(T *dest, T value, int pe); @@ -240,6 +248,12 @@ class IPCContext : public Context { __device__ void internal_sync(int pe, int PE_start, int stride, int PE_size, int64_t *pSync); + __device__ void internal_sync_wave(int pe, int PE_start, int stride, int PE_size, + int64_t *pSync); + + __device__ void internal_sync_wg(int pe, int PE_start, int stride, int PE_size, + int64_t *pSync); + __device__ void internal_direct_barrier(int pe, int PE_start, int stride, int n_pes, int64_t *pSync); @@ -289,6 +303,11 @@ class IPCContext : public Context { */ char **Wrk_Sync_buffer_bases_{nullptr}; + /** + * @brief Decive context Id + */ + unsigned int ctx_id_{}; + public: //TODO(Avinash): //Make tinfo private variable, it requires changes to the context diff --git a/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp b/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp index 8ba18a8ae8..6191ad0e8f 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp +++ b/projects/rocshmem/src/ipc/context_ipc_device_coll.cpp @@ -84,9 +84,29 @@ __device__ void IPCContext::internal_atomic_barrier(int pe, int PE_start, } } -// Uses PE values that are relative to world __device__ void IPCContext::internal_sync(int pe, int PE_start, int stride, int PE_size, int64_t *pSync) { + if (PE_size < 64) { + internal_direct_barrier(pe, PE_start, stride, PE_size, pSync); + } else { + internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync); + } +} + +__device__ void IPCContext::internal_sync_wave(int pe, int PE_start, int stride, + int PE_size, int64_t *pSync) { + if (is_thread_zero_in_wave()) { + if (PE_size < 64) { + internal_direct_barrier(pe, PE_start, stride, PE_size, pSync); + } else { + internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync); + } + } +} + +// Uses PE values that are relative to world +__device__ void IPCContext::internal_sync_wg(int pe, int PE_start, int stride, + int PE_size, int64_t *pSync) { __syncthreads(); if (is_thread_zero_in_block()) { if (PE_size < 64) { @@ -98,7 +118,7 @@ __device__ void IPCContext::internal_sync(int pe, int PE_start, int stride, __syncthreads(); } -__device__ void IPCContext::sync(rocshmem_team_t team) { +__device__ void IPCContext::sync_wg(rocshmem_team_t team) { IPCTeam *team_obj = reinterpret_cast(team); int pe = team_obj->my_pe_in_world; @@ -107,18 +127,38 @@ __device__ void IPCContext::sync(rocshmem_team_t team) { int pe_size = team_obj->num_pes; long *p_sync = team_obj->barrier_pSync; - internal_sync(pe, pe_start, pe_stride, pe_size, p_sync); + internal_sync_wg(pe, pe_start, pe_stride, pe_size, p_sync); } __device__ void IPCContext::sync_all() { internal_sync(my_pe, 0, 1, num_pes, barrier_sync); } +__device__ void IPCContext::sync_all_wave() { + internal_sync_wave(my_pe, 0, 1, num_pes, barrier_sync); +} + +__device__ void IPCContext::sync_all_wg() { + internal_sync_wg(my_pe, 0, 1, num_pes, barrier_sync); +} + __device__ void IPCContext::barrier_all() { + quiet(); + sync_all(); +} + +__device__ void IPCContext::barrier_all_wave() { + if (is_thread_zero_in_wave()) { + quiet(); + } + sync_all_wave(); +} + +__device__ void IPCContext::barrier_all_wg() { if (is_thread_zero_in_block()) { quiet(); } - sync_all(); + sync_all_wg(); __syncthreads(); } @@ -134,7 +174,7 @@ __device__ void IPCContext::barrier(rocshmem_team_t team) { if (is_thread_zero_in_block()) { quiet(); } - internal_sync(pe, pe_start, pe_stride, pe_size, p_sync); + internal_sync_wg(pe, pe_start, pe_stride, pe_size, p_sync); __syncthreads(); } diff --git a/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp b/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp index 51fd8ab76d..0140e490e3 100644 --- a/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp @@ -468,7 +468,7 @@ __device__ void IPCContext::internal_broadcast(T *dst, const T *src, int nelems, } // Synchronize on completion of broadcast - internal_sync(my_pe, pe_start, stride, pe_size, p_sync); + internal_sync_wg(my_pe, pe_start, stride, pe_size, p_sync); } template @@ -497,7 +497,7 @@ __device__ void IPCContext::alltoall_linear(rocshmem_team_t team, T *dst, quiet(); } // wait until everyone has obtained their designated data - internal_sync(my_pe, pe_start, stride, pe_size, pSync); + internal_sync_wg(my_pe, pe_start, stride, pe_size, pSync); } template @@ -527,7 +527,7 @@ __device__ void IPCContext::fcollect_linear(rocshmem_team_t team, T *dst, quiet(); } // wait until everyone has obtained their designated data - internal_sync(my_pe, pe_start, stride, pe_size, pSync); + internal_sync_wg(my_pe, pe_start, stride, pe_size, pSync); } // Block/wave functions diff --git a/projects/rocshmem/src/ipc/ipc_context_proxy.hpp b/projects/rocshmem/src/ipc/ipc_context_proxy.hpp index da8f587e8d..1222f799a8 100644 --- a/projects/rocshmem/src/ipc/ipc_context_proxy.hpp +++ b/projects/rocshmem/src/ipc/ipc_context_proxy.hpp @@ -45,7 +45,7 @@ class IPCDefaultContextProxy { size_t num_elems = 1) : constructed_{true}, proxy_{num_elems} { auto ctx{proxy_.get()}; - new (ctx) IPCContext(reinterpret_cast(backend)); + new (ctx) IPCContext(reinterpret_cast(backend), 0); ctx->tinfo = tinfo; rocshmem_ctx_t local{ctx, tinfo}; set_internal_ctx(&local); diff --git a/projects/rocshmem/src/reverse_offload/context_ro_device.cpp b/projects/rocshmem/src/reverse_offload/context_ro_device.cpp index bad1028358..451fb4209d 100644 --- a/projects/rocshmem/src/reverse_offload/context_ro_device.cpp +++ b/projects/rocshmem/src/reverse_offload/context_ro_device.cpp @@ -170,6 +170,20 @@ __device__ void *ROContext::shmem_ptr(const void *dest, int pe) { } __device__ void ROContext::barrier_all() { + build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, + nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, + block_handle, true, get_status_flag(), is_default_ctx); +} + +__device__ void ROContext::barrier_all_wave() { + if (is_thread_zero_in_wave()) { + build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, + nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, + block_handle, true, get_status_flag(), is_default_ctx); + } +} + +__device__ void ROContext::barrier_all_wg() { if (is_thread_zero_in_block()) { build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, @@ -189,6 +203,20 @@ __device__ void ROContext::barrier(rocshmem_team_t team) { } __device__ void ROContext::sync_all() { + build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, + nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, + block_handle, true, get_status_flag(), is_default_ctx); +} + +__device__ void ROContext::sync_all_wave() { + if (is_thread_zero_in_wave()) { + build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, + nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, + block_handle, true, get_status_flag(), is_default_ctx); + } +} + +__device__ void ROContext::sync_all_wg() { if (is_thread_zero_in_block()) { build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id, @@ -197,7 +225,7 @@ __device__ void ROContext::sync_all() { __syncthreads(); } -__device__ void ROContext::sync(rocshmem_team_t team) { +__device__ void ROContext::sync_wg(rocshmem_team_t team) { ROTeam *team_obj = reinterpret_cast(team); if (is_thread_zero_in_block()) { build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr, diff --git a/projects/rocshmem/src/reverse_offload/context_ro_device.hpp b/projects/rocshmem/src/reverse_offload/context_ro_device.hpp index dfda274caa..aa3b675f2e 100644 --- a/projects/rocshmem/src/reverse_offload/context_ro_device.hpp +++ b/projects/rocshmem/src/reverse_offload/context_ro_device.hpp @@ -65,11 +65,19 @@ class ROContext : public Context { __device__ void barrier_all(); + __device__ void barrier_all_wave(); + + __device__ void barrier_all_wg(); + __device__ void barrier(rocshmem_team_t team); __device__ void sync_all(); - __device__ void sync(rocshmem_team_t team); + __device__ void sync_all_wave(); + + __device__ void sync_all_wg(); + + __device__ void sync_wg(rocshmem_team_t team); template __device__ void p(T *dest, T value, int pe); diff --git a/projects/rocshmem/src/rocshmem_gpu.cpp b/projects/rocshmem/src/rocshmem_gpu.cpp index 07f84c60d4..f3a52d668b 100644 --- a/projects/rocshmem/src/rocshmem_gpu.cpp +++ b/projects/rocshmem/src/rocshmem_gpu.cpp @@ -570,12 +570,24 @@ __device__ int rocshmem_test(T *ivars, int cmp, T val) { return ctx_internal->test(ivars, cmp, val); } -__device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) { +__device__ void rocshmem_ctx_barrier_all(rocshmem_ctx_t ctx) { GPU_DPRINTF("Function: rocshmem_ctx_barrier_all\n"); get_internal_ctx(ctx)->barrier_all(); } +__device__ void rocshmem_ctx_wave_barrier_all(rocshmem_ctx_t ctx) { + GPU_DPRINTF("Function: rocshmem_ctx_wave_barrier_all\n"); + + get_internal_ctx(ctx)->barrier_all_wave(); +} + +__device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) { + GPU_DPRINTF("Function: rocshmem_ctx_wg_barrier_all\n"); + + get_internal_ctx(ctx)->barrier_all_wg(); +} + __device__ void rocshmem_wg_barrier_all() { rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT); } @@ -586,12 +598,32 @@ __device__ void rocshmem_barrier(rocshmem_team_t team) { get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier(team); } -__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) { +__device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) { GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n"); get_internal_ctx(ctx)->sync_all(); } +__device__ void rocshmem_sync_all() { + rocshmem_ctx_sync_all(ROCSHMEM_CTX_DEFAULT); +} + +__device__ void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx) { + GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n"); + + get_internal_ctx(ctx)->sync_all_wave(); +} + +__device__ void rocshmem_wave_sync_all() { + rocshmem_ctx_wave_sync_all(ROCSHMEM_CTX_DEFAULT); +} + +__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) { + GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n"); + + get_internal_ctx(ctx)->sync_all_wg(); +} + __device__ void rocshmem_wg_sync_all() { rocshmem_ctx_wg_sync_all(ROCSHMEM_CTX_DEFAULT); } @@ -600,7 +632,7 @@ __device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx, rocshmem_team_t team) { GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n"); - get_internal_ctx(ctx)->sync(team); + get_internal_ctx(ctx)->sync_wg(team); } __device__ void rocshmem_wg_team_sync(rocshmem_team_t team) { diff --git a/projects/rocshmem/src/stats.hpp b/projects/rocshmem/src/stats.hpp index a4fe1937ef..ab4d5eb785 100644 --- a/projects/rocshmem/src/stats.hpp +++ b/projects/rocshmem/src/stats.hpp @@ -43,6 +43,8 @@ enum rocshmem_stats { NUM_QUIET, NUM_TO_ALL, NUM_BARRIER_ALL, + NUM_BARRIER_ALL_WAVE, + NUM_BARRIER_ALL_WG, NUM_WAIT_UNTIL, NUM_WAIT_UNTIL_ANY, NUM_WAIT_UNTIL_ALL, @@ -70,6 +72,8 @@ enum rocshmem_stats { NUM_TEST, NUM_SHMEM_PTR, NUM_SYNC_ALL, + NUM_SYNC_ALL_WAVE, + NUM_SYNC_ALL_WG, NUM_BROADCAST, NUM_PUT_WG, NUM_PUT_NBI_WG, diff --git a/projects/rocshmem/tests/functional_tests/barrier_all_tester.cpp b/projects/rocshmem/tests/functional_tests/barrier_all_tester.cpp index 49f011a3d5..dfd5811726 100644 --- a/projects/rocshmem/tests/functional_tests/barrier_all_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/barrier_all_tester.cpp @@ -30,9 +30,12 @@ using namespace rocshmem; * DEVICE TEST KERNEL *****************************************************************************/ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time, - long long int *end_time) { + long long int *end_time, TestType type, + int wf_size) { __shared__ rocshmem_ctx_t ctx; + int t_id = get_flat_block_id(); int wg_id = get_flat_grid_id(); + int wf_id = t_id / wf_size; rocshmem_wg_init(); rocshmem_wg_ctx_create(ROCSHMEM_CTX_WG_PRIVATE, &ctx); @@ -42,17 +45,25 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time, start_time[wg_id] = wall_clock64(); } - __syncthreads(); - - /** - * The function `rocshmem_ctx_wg_barrier_all` should be called from only - * one group within the grid to avoid unintended behavior. - */ - if (is_block_zero_in_grid()) { - rocshmem_ctx_wg_barrier_all(ctx); + switch (type) { + case BarrierAllTestType: + if(t_id == 0) { + rocshmem_ctx_barrier_all(ctx); + } + break; + case WAVEBarrierAllTestType: + if(wf_id == 0) { + rocshmem_ctx_wave_barrier_all(ctx); + } + break; + case WGBarrierAllTestType: + rocshmem_ctx_wg_barrier_all(ctx); + break; + default: + break; } + __syncthreads(); } - __syncthreads(); if (hipThreadIdx_x == 0) { end_time[wg_id] = wall_clock64(); @@ -74,10 +85,10 @@ void BarrierAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, size_t shared_bytes = 0; hipLaunchKernelGGL(BarrierAllTest, gridSize, blockSize, shared_bytes, stream, - loop, args.skip, start_time, end_time); + loop, args.skip, start_time, end_time, _type, wf_size); - num_msgs = loop + args.skip; - num_timed_msgs = loop; + num_msgs = (loop + args.skip) * gridSize.x; + num_timed_msgs = loop * gridSize.x; } void BarrierAllTester::resetBuffers(uint64_t size) {} diff --git a/projects/rocshmem/tests/functional_tests/sync_all_tester.cpp b/projects/rocshmem/tests/functional_tests/sync_all_tester.cpp new file mode 100644 index 0000000000..8a02db4d54 --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/sync_all_tester.cpp @@ -0,0 +1,96 @@ +/****************************************************************************** + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#include "sync_all_tester.hpp" + +#include + +using namespace rocshmem; + +/****************************************************************************** + * DEVICE TEST KERNEL + *****************************************************************************/ +__global__ void SyncAllTest(int loop, int skip, long long int *start_time, + long long int *end_time, TestType type, + int wf_size) { + __shared__ rocshmem_ctx_t ctx; + int t_id = get_flat_block_id(); + int wg_id = get_flat_grid_id(); + int wf_id = t_id / wf_size; + + rocshmem_wg_init(); + rocshmem_wg_ctx_create(ROCSHMEM_CTX_WG_PRIVATE, &ctx); + + for (int i = 0; i < loop + skip; i++) { + if (hipThreadIdx_x == 0 && i == skip) { + start_time[wg_id] = wall_clock64(); + } + + switch (type) { + case SyncAllTestType: + if(t_id == 0) { + rocshmem_ctx_sync_all(ctx); + } + break; + case WAVESyncAllTestType: + if(wf_id == 0) { + rocshmem_ctx_wave_sync_all(ctx); + } + break; + case WGSyncAllTestType: + rocshmem_ctx_wg_sync_all(ctx); + break; + default: + break; + } + __syncthreads(); + } + + if (hipThreadIdx_x == 0) { + end_time[wg_id] = wall_clock64(); + } + + rocshmem_wg_ctx_destroy(&ctx); + rocshmem_wg_finalize(); +} + +/****************************************************************************** + * HOST TESTER CLASS METHODS + *****************************************************************************/ +SyncAllTester::SyncAllTester(TesterArguments args) : Tester(args) {} + +SyncAllTester::~SyncAllTester() {} + +void SyncAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, + uint64_t size) { + size_t shared_bytes = 0; + + hipLaunchKernelGGL(SyncAllTest, gridSize, blockSize, shared_bytes, stream, + loop, args.skip, start_time, end_time, _type, wf_size); + + num_msgs = (loop + args.skip) * gridSize.x; + num_timed_msgs = loop * gridSize.x; +} + +void SyncAllTester::resetBuffers(uint64_t size) {} + +void SyncAllTester::verifyResults(uint64_t size) {} diff --git a/projects/rocshmem/tests/functional_tests/sync_all_tester.hpp b/projects/rocshmem/tests/functional_tests/sync_all_tester.hpp new file mode 100644 index 0000000000..8062b37cda --- /dev/null +++ b/projects/rocshmem/tests/functional_tests/sync_all_tester.hpp @@ -0,0 +1,50 @@ +/****************************************************************************** + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + *****************************************************************************/ + +#ifndef _BARRIER_ALL_TESTER_HPP_ +#define _BARRIER_ALL_TESTER_HPP_ + +#include "tester.hpp" + +/****************************************************************************** + * DEVICE TEST KERNEL + *****************************************************************************/ +__global__ void SyncAllTest(TestType type); + +/****************************************************************************** + * HOST TESTER CLASS + *****************************************************************************/ +class SyncAllTester : public Tester { + public: + explicit SyncAllTester(TesterArguments args); + virtual ~SyncAllTester(); + + protected: + virtual void resetBuffers(uint64_t size) override; + + virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop, + uint64_t size) override; + + virtual void verifyResults(uint64_t size) override; +}; + +#endif diff --git a/projects/rocshmem/tests/functional_tests/sync_tester.cpp b/projects/rocshmem/tests/functional_tests/sync_tester.cpp index 1843cccc35..fd5c67fce9 100644 --- a/projects/rocshmem/tests/functional_tests/sync_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/sync_tester.cpp @@ -41,15 +41,6 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time, __syncthreads(); switch (type) { - case SyncAllTestType: - /** - * The function `rocshmem_ctx_wg_sync_all` should be called from only - * one group within the grid to avoid unintended behavior. - */ - if (is_block_zero_in_grid()) { - rocshmem_ctx_wg_sync_all(ctx); - } - break; case SyncTestType: rocshmem_ctx_wg_team_sync(ctx, teams[wg_id]); break; diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index bd0f58baa2..d5b6dea7a6 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -43,10 +43,11 @@ #include "random_access_tester.hpp" #include "shmem_ptr_tester.hpp" #include "signaling_operations_tester.hpp" +#include "sync_all_tester.hpp" #include "sync_tester.hpp" #include "team_alltoall_tester.hpp" -#include "team_broadcast_tester.hpp" #include "team_barrier_tester.hpp" +#include "team_broadcast_tester.hpp" #include "team_ctx_infra_tester.hpp" #include "team_ctx_primitive_tester.hpp" #include "team_fcollect_tester.hpp" @@ -317,6 +318,14 @@ std::vector Tester::create(TesterArguments args) { if (rank == 0) std::cout << "Barrier_All ###" << std::endl; testers.push_back(new BarrierAllTester(args)); return testers; + case WAVEBarrierAllTestType: + if (rank == 0) std::cout << "WAVE Barrier_All ###" << std::endl; + testers.push_back(new BarrierAllTester(args)); + return testers; + case WGBarrierAllTestType: + if (rank == 0) std::cout << "WG Barrier_All ###" << std::endl; + testers.push_back(new BarrierAllTester(args)); + return testers; case TeamBarrierTestType: if (rank == 0) std::cout << "Team Barrier Test ###" << std::endl; testers.push_back(new TeamBarrierTester(args)); @@ -325,6 +334,14 @@ std::vector Tester::create(TesterArguments args) { if (rank == 0) std::cout << "SyncAll ###" << std::endl; testers.push_back(new SyncTester(args)); return testers; + case WAVESyncAllTestType: + if (rank == 0) std::cout << "WAVE SyncAll ###" << std::endl; + testers.push_back(new SyncTester(args)); + return testers; + case WGSyncAllTestType: + if (rank == 0) std::cout << "WG SyncAll ###" << std::endl; + testers.push_back(new SyncTester(args)); + return testers; case SyncTestType: if (rank == 0) std::cout << "Sync ###" << std::endl; testers.push_back(new SyncTester(args)); @@ -510,7 +527,9 @@ bool Tester::peLaunchesKernel() { (_type == TeamBroadcastTestType) || (_type == TeamCtxInfraTestType) || (_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) || (_type == PingPongTestType) || (_type == BarrierAllTestType) || + (_type == WAVEBarrierAllTestType) || (_type == WGBarrierAllTestType) || (_type == SyncTestType) || (_type == SyncAllTestType) || + (_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) || (_type == RandomAccessTestType) || (_type == PingAllTestType) || (_type == TeamBarrierTestType); diff --git a/projects/rocshmem/tests/functional_tests/tester.hpp b/projects/rocshmem/tests/functional_tests/tester.hpp index fe691bd1db..3a057f23ee 100644 --- a/projects/rocshmem/tests/functional_tests/tester.hpp +++ b/projects/rocshmem/tests/functional_tests/tester.hpp @@ -99,6 +99,10 @@ enum TestType { DefaultCTXPutNBITestType = 62, DefaultCTXPTestType = 63, DefaultCTXGTestType = 64, + WAVEBarrierAllTestType = 65, + WGBarrierAllTestType = 66, + WAVESyncAllTestType = 67, + WGSyncAllTestType = 68, }; enum OpType { PutType = 0, GetType = 1 }; diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp index c8d8dea516..1a3eaac6a7 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp @@ -85,8 +85,12 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { case AMO_IncTestType: case AMO_FetchTestType: case BarrierAllTestType: + case WAVEBarrierAllTestType: + case WGBarrierAllTestType: case TeamBarrierTestType: case SyncAllTestType: + case WAVESyncAllTestType: + case WGSyncAllTestType: case SyncTestType: case ShmemPtrTestType: min_msg_size = 8; @@ -133,7 +137,9 @@ void TesterArguments::get_rocshmem_arguments() { myid = rocshmem_my_pe(); TestType type = (TestType)algorithm; - if ((type != BarrierAllTestType) && (type != SyncAllTestType) && + if ((type != BarrierAllTestType) && (type != WAVEBarrierAllTestType) && + (type != WGBarrierAllTestType) && (type != SyncAllTestType) && + (type != WAVESyncAllTestType) && (type != WGSyncAllTestType) && (type != SyncTestType) && (type != TeamAllToAllTestType) && (type != TeamFCollectTestType) && (type != TeamReductionTestType) && (type != TeamBroadcastTestType) && (type != PingAllTestType) &&