From e8a737100789cf6817cba549ff2b205172251e53 Mon Sep 17 00:00:00 2001 From: Allen Hubbe Date: Wed, 5 Nov 2025 12:01:14 -0500 Subject: [PATCH] gda ionic: use all threads in wave operations (#295) Use all available threads for polling the cq to increase the maximum message rate. Even when posting a single wqe in the wave, use all available theads for polling the cq to reserve space in the sq. Changes were needed in the rocshmem abstraction to avoid disabling gpu threads, like taking turns or using only the first thread in a wave or wavefront. To avoid breaking other gda implementations, reimplement turn-based or single thread strategy in post_wqe_rma_turn and post_wqe_rma_single. Signed-off-by: Allen Hubbe [ROCm/rocshmem commit: 6de67d5d7c270d9867bafed3897f0d771fd08cd9] --- .../rocshmem/src/gda/context_gda_device.cpp | 94 +++++-------------- .../rocshmem/src/gda/context_gda_device.hpp | 2 + .../src/gda/context_gda_device_coll.cpp | 16 ++-- .../src/gda/ionic/queue_pair_ionic.cpp | 18 +++- projects/rocshmem/src/gda/queue_pair.cpp | 57 ++++++++--- projects/rocshmem/src/gda/queue_pair.hpp | 14 ++- projects/rocshmem/src/util.hpp | 7 ++ .../functional_tests/wavefront_primitives.cpp | 10 +- 8 files changed, 114 insertions(+), 104 deletions(-) diff --git a/projects/rocshmem/src/gda/context_gda_device.cpp b/projects/rocshmem/src/gda/context_gda_device.cpp index 869a0a1d9b..5cba9c9588 100644 --- a/projects/rocshmem/src/gda/context_gda_device.cpp +++ b/projects/rocshmem/src/gda/context_gda_device.cpp @@ -77,18 +77,8 @@ __device__ void GDAContext::putmem(void *dest, const void *source, size_t nelems return; } uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe]; - bool need_turn {true}; - uint64_t turns = __ballot(need_turn); - while (turns) { - uint8_t lane = __ffsll((unsigned long long)turns) - 1; - int pe_turn = __shfl(pe, lane); - if (pe_turn == pe) { - qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe); - qps[pe].quiet(); - need_turn = false; - } - turns = __ballot(need_turn); - } + qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe); + qps[pe].quiet(); } __device__ void GDAContext::getmem(void *dest, const void *source, size_t nelems, @@ -101,18 +91,8 @@ __device__ void GDAContext::getmem(void *dest, const void *source, size_t nelems return; } uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe]; - bool need_turn {true}; - uint64_t turns = __ballot(need_turn); - while (turns) { - uint8_t lane = __ffsll((unsigned long long)turns) - 1; - int pe_turn = __shfl(pe, lane); - if (pe_turn == pe) { - qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe); - qps[pe].quiet(); - need_turn = false; - } - turns = __ballot(need_turn); - } + qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe); + qps[pe].quiet(); } __device__ void GDAContext::putmem_nbi(void *dest, const void *source, @@ -124,17 +104,7 @@ __device__ void GDAContext::putmem_nbi(void *dest, const void *source, return; } uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe]; - bool need_turn {true}; - uint64_t turns = __ballot(need_turn); - while (turns) { - uint8_t lane = __ffsll((unsigned long long)turns) - 1; - int pe_turn = __shfl(pe, lane); - if (pe_turn == pe) { - qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe); - need_turn = false; - } - turns = __ballot(need_turn); - } + qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe); } __device__ void GDAContext::getmem_nbi(void *dest, const void *source, @@ -147,17 +117,7 @@ __device__ void GDAContext::getmem_nbi(void *dest, const void *source, return; } uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe]; - bool need_turn {true}; - uint64_t turns = __ballot(need_turn); - while (turns) { - uint8_t lane = __ffsll((unsigned long long)turns) - 1; - int pe_turn = __shfl(pe, lane); - if (pe_turn == pe) { - qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe); - need_turn = false; - } - turns = __ballot(need_turn); - } + qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe); } __device__ void GDAContext::fence() { //TODO: optimize @@ -177,6 +137,12 @@ __device__ void GDAContext::quiet() { } } +__device__ void GDAContext::quiet_wave() { + for (int i = 0; i < num_pes; i++) { + qps[i].quiet(QueuePair::WAVE); + } +} + __device__ void GDAContext::pe_quiet(size_t pe) { qps[pe].quiet(); } @@ -201,8 +167,8 @@ __device__ void GDAContext::putmem_wg(void *dest, const void *source, return; } uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe]; - if (is_thread_zero_in_block()) { - qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe); + if (is_wave_zero_in_block()) { + qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe, QueuePair::WAVE); qps[pe].quiet(); } } @@ -217,8 +183,8 @@ __device__ void GDAContext::getmem_wg(void *dest, const void *source, return; } uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe]; - if (is_thread_zero_in_block()) { - qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe); + if (is_wave_zero_in_block()) { + qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe, QueuePair::WAVE); qps[pe].quiet(); } } @@ -232,8 +198,8 @@ __device__ void GDAContext::putmem_nbi_wg(void *dest, const void *source, return; } uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe]; - if (is_thread_zero_in_block()) { - qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe); + if (is_wave_zero_in_block()) { + qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe, QueuePair::WAVE); } } @@ -247,8 +213,8 @@ __device__ void GDAContext::getmem_nbi_wg(void *dest, const void *source, return; } uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe]; - if (is_thread_zero_in_block()) { - qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe); + if (is_wave_zero_in_block()) { + qps[pe].get_nbi(base_heap[pe] + L_offset, source, nelems, pe, QueuePair::WAVE); } } @@ -261,10 +227,8 @@ __device__ void GDAContext::putmem_wave(void *dest, const void *source, return; } uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe]; - if (is_thread_zero_in_wave()) { - qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe); - qps[pe].quiet(); - } + qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe, QueuePair::WAVE); + qps[pe].quiet(); } __device__ void GDAContext::getmem_wave(void *dest, const void *source, @@ -277,10 +241,8 @@ __device__ void GDAContext::getmem_wave(void *dest, const void *source, return; } uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe]; - if (is_thread_zero_in_wave()) { - qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe); - qps[pe].quiet(); - } + qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe, QueuePair::WAVE); + qps[pe].quiet(); } __device__ void GDAContext::putmem_nbi_wave(void *dest, const void *source, @@ -292,9 +254,7 @@ __device__ void GDAContext::putmem_nbi_wave(void *dest, const void *source, return; } uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe]; - if (is_thread_zero_in_wave()) { - qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe); - } + qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe, QueuePair::WAVE); } __device__ void GDAContext::getmem_nbi_wave(void *dest, const void *source, @@ -307,9 +267,7 @@ __device__ void GDAContext::getmem_nbi_wave(void *dest, const void *source, return; } uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe]; - if (is_thread_zero_in_wave()) { - qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe); - } + qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe, QueuePair::WAVE); } diff --git a/projects/rocshmem/src/gda/context_gda_device.hpp b/projects/rocshmem/src/gda/context_gda_device.hpp index f00cb4b0a7..08cf232f28 100644 --- a/projects/rocshmem/src/gda/context_gda_device.hpp +++ b/projects/rocshmem/src/gda/context_gda_device.hpp @@ -60,6 +60,8 @@ class GDAContext : public Context { __device__ void quiet(); + __device__ void quiet_wave(); + __device__ void pe_quiet(size_t pe); __device__ void *shmem_ptr(const void *dest, int pe); diff --git a/projects/rocshmem/src/gda/context_gda_device_coll.cpp b/projects/rocshmem/src/gda/context_gda_device_coll.cpp index 7ae16341d7..7610e4b2c9 100644 --- a/projects/rocshmem/src/gda/context_gda_device_coll.cpp +++ b/projects/rocshmem/src/gda/context_gda_device_coll.cpp @@ -232,15 +232,13 @@ __device__ void GDAContext::barrier_all() { } __device__ void GDAContext::barrier_all_wave() { - if (is_thread_zero_in_wave()) { - quiet(); - } + quiet_wave(); sync_all_wave(); } __device__ void GDAContext::barrier_all_wg() { - if (is_thread_zero_in_block()) { - quiet(); + if (is_wave_zero_in_block()) { + quiet_wave(); } sync_all_wg(); __syncthreads(); @@ -268,9 +266,7 @@ __device__ void GDAContext::barrier_wave(rocshmem_team_t team) { int pe_size = team_obj->num_pes; long *p_sync = team_obj->barrier_pSync; - if (is_thread_zero_in_wave()) { - quiet(); - } + quiet_wave(); internal_sync_wave(pe, pe_start, pe_stride, pe_size, p_sync); } @@ -283,8 +279,8 @@ __device__ void GDAContext::barrier_wg(rocshmem_team_t team) { int pe_size = team_obj->num_pes; long *p_sync = team_obj->barrier_pSync; - if (is_thread_zero_in_block()) { - quiet(); + if (is_wave_zero_in_block()) { + quiet_wave(); } internal_sync_wg(pe, pe_start, pe_stride, pe_size, p_sync); __syncthreads(); diff --git a/projects/rocshmem/src/gda/ionic/queue_pair_ionic.cpp b/projects/rocshmem/src/gda/ionic/queue_pair_ionic.cpp index 59872a3d2e..9533ab21be 100644 --- a/projects/rocshmem/src/gda/ionic/queue_pair_ionic.cpp +++ b/projects/rocshmem/src/gda/ionic/queue_pair_ionic.cpp @@ -171,7 +171,8 @@ __device__ void QueuePair::ionic_quiet_internal(uint64_t activemask, uint32_t co } __device__ void QueuePair::ionic_ring_doorbell(uint32_t pos) { - // TODO When threads write at once to the same address, not all writes reach the bus. + // When threads write at once to the same address, not all writes reach the bus. + // Take turns and insert a thread fence between writes to the same address. for (int i = 0; i < 64; ++i) { if (__lane_id() == i) { __threadfence(); @@ -185,11 +186,22 @@ __device__ void QueuePair::ionic_quiet() { ionic_quiet_internal(get_same_qp_lane_mask(), sq_prod); } -__device__ void QueuePair::ionic_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) { +__device__ void QueuePair::ionic_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy) { uint64_t activemask = get_same_qp_lane_mask(); - uint32_t num_wqes = get_active_lane_count(activemask); uint32_t my_logical_lane_id = get_active_lane_num(activemask); + uint32_t num_wqes = 1; + if (cy == THREAD) { + num_wqes = get_active_lane_count(activemask); + } + uint32_t my_sq_prod = reserve_sq(activemask, num_wqes); + if (cy == WAVE) { + if (!is_first_active_lane(activemask)) { + return; + } + activemask &= activemask ^ (activemask - 1); + } + uint32_t my_sq_pos = my_sq_prod + my_logical_lane_id; struct ionic_v1_wqe *wqe = &ionic_sq_buf[my_sq_pos & sq_mask]; uint16_t wqe_flags = 0; diff --git a/projects/rocshmem/src/gda/queue_pair.cpp b/projects/rocshmem/src/gda/queue_pair.cpp index c268ae99e6..729db6490b 100644 --- a/projects/rocshmem/src/gda/queue_pair.cpp +++ b/projects/rocshmem/src/gda/queue_pair.cpp @@ -118,7 +118,39 @@ QueuePair::~QueuePair() { /****************************************************************************** ************************ PROVIDER-SPECIFIC HELPERS *************************** *****************************************************************************/ -__device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) { +__device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy) { + switch (gda_provider_) { +#if defined(GDA_IONIC) + case GDAProvider::IONIC: + ionic_post_wqe_rma(pe, size, laddr, raddr, opcode, cy); + return; +#endif + default: + post_wqe_rma_turn(pe, size, laddr, raddr, opcode, cy); + } +} + +__device__ void QueuePair::post_wqe_rma_turn(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy) { + if (cy == THREAD) { + bool need_turn {true}; + uint64_t turns = __ballot(need_turn); + while (turns) { + uint8_t lane = __ffsll((unsigned long long)turns) - 1; + int pe_turn = __shfl(pe, lane); + if (pe_turn == pe) { + post_wqe_rma_single(pe, size, laddr, raddr, opcode); + need_turn = false; + } + turns = __ballot(need_turn); + } + } else { + if (is_thread_zero_in_wave()) { + post_wqe_rma_single(pe, size, laddr, raddr, opcode); + } + } +} + +__device__ void QueuePair::post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) { switch (gda_provider_) { #if defined(GDA_MLX5) case GDAProvider::MLX5: @@ -129,11 +161,6 @@ __device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, case GDAProvider::BNXT: bnxt_post_wqe_rma(pe, size, laddr, raddr, opcode); return; -#endif -#if defined(GDA_IONIC) - case GDAProvider::IONIC: - ionic_post_wqe_rma(pe, size, laddr, raddr, opcode); - return; #endif default: assert(false /* invalid nic provider */); @@ -161,16 +188,20 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad } } -__device__ void QueuePair::quiet() { +__device__ void QueuePair::quiet(Collectivity cy) { switch (gda_provider_) { #if defined(GDA_MLX5) case GDAProvider::MLX5: - mlx5_quiet(); + if (cy == THREAD || is_thread_zero_in_wave()) { + mlx5_quiet(); + } return; #endif #if defined(GDA_BNXT) case GDAProvider::BNXT: - bnxt_quiet(); + if (cy == THREAD || is_thread_zero_in_wave()) { + bnxt_quiet(); + } return; #endif #if defined(GDA_IONIC) @@ -186,16 +217,16 @@ __device__ void QueuePair::quiet() { /****************************************************************************** ****************************** SHMEM INTERFACE ******************************* *****************************************************************************/ -__device__ void QueuePair::put_nbi(void *dest, const void *source, size_t nelems, int pe) { +__device__ void QueuePair::put_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy) { uintptr_t *src = reinterpret_cast(const_cast(source)); uintptr_t *dst = reinterpret_cast(dest); - post_wqe_rma(pe, nelems, src, dst, gda_op_rdma_write); + post_wqe_rma(pe, nelems, src, dst, gda_op_rdma_write, cy); } -__device__ void QueuePair::get_nbi(void *dest, const void *source, size_t nelems, int pe) { +__device__ void QueuePair::get_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy) { uintptr_t *src = reinterpret_cast(const_cast(source)); uintptr_t *dst = reinterpret_cast(dest); - post_wqe_rma(pe, nelems, dst, src, gda_op_rdma_read); + post_wqe_rma(pe, nelems, dst, src, gda_op_rdma_read, cy); } __device__ int64_t QueuePair::atomic_cas(void *dest, int64_t atomic_data, int64_t atomic_cmp, int pe) { diff --git a/projects/rocshmem/src/gda/queue_pair.hpp b/projects/rocshmem/src/gda/queue_pair.hpp index 8d348bbcb2..cfb7cfc6ba 100644 --- a/projects/rocshmem/src/gda/queue_pair.hpp +++ b/projects/rocshmem/src/gda/queue_pair.hpp @@ -64,6 +64,8 @@ class QueuePair { */ virtual ~QueuePair(); + enum Collectivity { THREAD, WAVE }; + /** * @brief Create and enqueue a non-blocking put work queue entry (wqe). * @@ -72,7 +74,7 @@ class QueuePair { * @param[in] nelems Size in bytes of data transmission. * @param[in] pe Destination processing element of data transmission. */ - __device__ void put_nbi(void *dest, const void *source, size_t nelems, int pe); + __device__ void put_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy = THREAD); /** * @brief Create and enqueue a non-blocking get work queue entry (wqe). @@ -82,12 +84,12 @@ class QueuePair { * @param[in] nelems Size in bytes of data transmission. * @param[in] pe Destination processing element of data transmission. */ - __device__ void get_nbi(void *dest, const void *source, size_t nelems, int pe); + __device__ void get_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy = THREAD); /** * @brief Empty all completions from the completion queue. */ - __device__ void quiet(); + __device__ void quiet(Collectivity cy = THREAD); /** * @brief Create and enqueue an atomic fetch work queue entry (wqe). @@ -158,7 +160,9 @@ class QueuePair { * @param[in] raddr Remote address. * @param[in] opcode Operation to be performed. */ - __device__ __attribute__((noinline)) void post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode); + __device__ __attribute__((noinline)) void post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy); + __device__ __attribute__((noinline)) void post_wqe_rma_turn(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy); + __device__ __attribute__((noinline)) void post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode); #if defined(GDA_MLX5) __device__ uint64_t mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch); @@ -172,7 +176,7 @@ class QueuePair { #endif #if defined(GDA_IONIC) __device__ uint64_t ionic_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch); - __device__ void ionic_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode); + __device__ void ionic_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy); __device__ void ionic_quiet(); #endif diff --git a/projects/rocshmem/src/util.hpp b/projects/rocshmem/src/util.hpp index d76ab0f3ef..10c51e7623 100644 --- a/projects/rocshmem/src/util.hpp +++ b/projects/rocshmem/src/util.hpp @@ -241,6 +241,13 @@ __device__ __forceinline__ bool is_thread_zero_in_wave() { return (get_flat_block_id() % WF_SIZE) == 0; } +/* + * Returns true if the caller's thread flat_id is in the zero'th wave. + */ +__device__ __forceinline__ bool is_wave_zero_in_block() { + return (get_flat_block_id() / WF_SIZE) == 0; +} + __device__ __forceinline__ uint64_t get_active_lane_mask() { return __ballot(true); } diff --git a/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp b/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp index 14bbe572b0..bc221659bf 100644 --- a/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp +++ b/projects/rocshmem/tests/functional_tests/wavefront_primitives.cpp @@ -55,11 +55,11 @@ __global__ void WaveFrontPrimitiveTest(int loop, int skip, for (int i = 0; i < loop + skip; i++) { if (i == skip) { // Ensures all RMA calls from the skip loops are completed - if(is_thread_zero_in_wave()) { - rocshmem_ctx_quiet(ctx); - } + rocshmem_ctx_quiet(ctx); __syncthreads(); - start_time[idx] = wall_clock64(); + if (is_thread_zero_in_wave()) { + start_time[idx] = wall_clock64(); + } } switch (type) { case WAVEGetTestType: @@ -79,8 +79,8 @@ __global__ void WaveFrontPrimitiveTest(int loop, int skip, } } + rocshmem_ctx_quiet(ctx); if (is_thread_zero_in_wave()) { - rocshmem_ctx_quiet(ctx); end_time[idx] = wall_clock64(); }