From 2f8a1c02a4c8ebc4be0225becf2bf434be323462 Mon Sep 17 00:00:00 2001 From: Yiltan Date: Fri, 31 Oct 2025 10:26:24 -0400 Subject: [PATCH] [GDA] Implement internal_direct_barrier_wg (#299) [ROCm/rocshmem commit: 5f87bb061be6dd9b5806b5826af9c536ab24f308] --- .../rocshmem/src/gda/context_gda_device.hpp | 3 + .../src/gda/context_gda_device_coll.cpp | 59 +++++++++++++++++-- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/projects/rocshmem/src/gda/context_gda_device.hpp b/projects/rocshmem/src/gda/context_gda_device.hpp index f076cbab00..f00cb4b0a7 100644 --- a/projects/rocshmem/src/gda/context_gda_device.hpp +++ b/projects/rocshmem/src/gda/context_gda_device.hpp @@ -267,6 +267,9 @@ class GDAContext : public Context { __device__ void internal_direct_barrier(int pe, int PE_start, int stride, int n_pes, int64_t *pSync); + __device__ void internal_direct_barrier_wg(int pe, int PE_start, int stride, + int n_pes, int64_t *pSync); + __device__ void internal_atomic_barrier(int pe, int PE_start, int stride, int n_pes, int64_t *pSync); diff --git a/projects/rocshmem/src/gda/context_gda_device_coll.cpp b/projects/rocshmem/src/gda/context_gda_device_coll.cpp index be224fb302..7ae16341d7 100644 --- a/projects/rocshmem/src/gda/context_gda_device_coll.cpp +++ b/projects/rocshmem/src/gda/context_gda_device_coll.cpp @@ -70,6 +70,57 @@ __device__ void GDAContext::internal_direct_barrier(int pe, int PE_start, } } +__device__ void GDAContext::internal_direct_barrier_wg(int pe, int PE_start, + int stride, int n_pes, + int64_t *pSync) { + int64_t flag_val{1}; + + if (pe == PE_start) { + int wf_id = get_flat_block_id() / WF_SIZE; + int wf_count = (int) ceil((double)get_flat_block_size() / (double)WF_SIZE); + bool wf_leader = 0 == get_active_lane_num(); + + // Go through all PE offsets (except current offset = 0) + // and wait until they all reach + if (wf_leader) { + for (int j = wf_id + 1; j < n_pes; j+= wf_count) { + wait_until(&pSync[j], ROCSHMEM_CMP_EQ, flag_val); + pSync[j] = ROCSHMEM_SYNC_VALUE; + } + } + + __syncthreads(); + + // Announce to other PEs that all have reached + for (int i = wf_id + 1, j = PE_start + stride + wf_id; + i < n_pes; + i+= wf_count, j += (wf_count * stride)) { + put_nbi_wave(&pSync[0], &flag_val, 1, j); + } + + for (int i = wf_id + 1, j = PE_start + stride + wf_id; + i < n_pes; + i+= wf_count, j += (wf_count * stride)) { + pe_quiet(j); + } + + __syncthreads(); + + if (is_thread_zero_in_block()) { + pSync[0] = ROCSHMEM_SYNC_VALUE; + } + } else { + if (is_thread_zero_in_block()) { + // Mark current PE offset as reached + size_t pe_offset = (pe - PE_start) / stride; + put(&pSync[pe_offset], &flag_val, 1, PE_start); + wait_until(&pSync[0], ROCSHMEM_CMP_EQ, flag_val); + pSync[0] = ROCSHMEM_SYNC_VALUE; + __threadfence_system(); + } + } +} + __device__ void GDAContext::internal_atomic_barrier(int pe, int PE_start, int stride, int n_pes, int64_t *pSync) { @@ -116,10 +167,10 @@ __device__ void GDAContext::internal_sync_wave(int pe, int PE_start, int stride, __device__ void GDAContext::internal_sync_wg(int pe, int PE_start, int stride, int PE_size, int64_t *pSync) { __syncthreads(); - if (is_thread_zero_in_block()) { - if (PE_size < 64) { - internal_direct_barrier(pe, PE_start, stride, PE_size, pSync); - } else { + if (PE_size < 64) { + internal_direct_barrier_wg(pe, PE_start, stride, PE_size, pSync); + } else { + if (is_thread_zero_in_block()) { internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync); } }