From 1347d5d62868cfbdac3bc2df6c2b4e9d762a68e3 Mon Sep 17 00:00:00 2001 From: Yiltan Date: Wed, 19 Nov 2025 14:25:29 -0500 Subject: [PATCH] [GDA] Alltoall optimization - single warp (#319) * Remove testing of data types As the collective is templated, we are just testing if sizeof(T) works * Added single threaded variants * Applied thread puts optimization to barrier * Apply single threaded optimization to alltoall * This optimization only works on bnxt, so place a switch to protect it * Handle the edge case where the thread count is smaller than the number of PEs --- src/gda/backend_gda.cpp | 4 +- src/gda/bnxt/queue_pair_bnxt.cpp | 88 +++++++++++++++++++++++++++++ src/gda/context_gda_device.cpp | 7 ++- src/gda/context_gda_device.hpp | 13 ++++- src/gda/context_gda_device_coll.cpp | 49 ++++++++++++++++ src/gda/context_gda_tmpl_device.hpp | 38 ++++++++++++- src/gda/gda_context_proxy.hpp | 3 +- src/gda/queue_pair.cpp | 39 ++++++++++++- src/gda/queue_pair.hpp | 5 ++ tests/functional_tests/tester.cpp | 6 -- 10 files changed, 236 insertions(+), 16 deletions(-) diff --git a/src/gda/backend_gda.cpp b/src/gda/backend_gda.cpp index a456c24f3a..841c7c6c14 100644 --- a/src/gda/backend_gda.cpp +++ b/src/gda/backend_gda.cpp @@ -156,7 +156,7 @@ void GDABackend::setup_host_ctx() { void GDABackend::setup_default_ctx() { TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world; - default_context_proxy_ = GDADefaultContextProxyT(this, tinfo); + default_context_proxy_ = GDADefaultContextProxyT(this, tinfo, gda_provider); } void GDABackend::setup_ctxs() { @@ -166,7 +166,7 @@ void GDABackend::setup_ctxs() { CHECK_HIP(hipMalloc(&ctx_array, sizeof(GDAContext) * envvar::max_num_contexts)); // 0th context is default context for (size_t i = 0; i < envvar::max_num_contexts; i++) { - new (&ctx_array[i]) GDAContext(this, i + 1); + new (&ctx_array[i]) GDAContext(this, i + 1, gda_provider); ctx_free_list.get()->push_back(ctx_array + i); } } diff --git a/src/gda/bnxt/queue_pair_bnxt.cpp 
b/src/gda/bnxt/queue_pair_bnxt.cpp index 148d5be462..807a4d199b 100644 --- a/src/gda/bnxt/queue_pair_bnxt.cpp +++ b/src/gda/bnxt/queue_pair_bnxt.cpp @@ -213,6 +213,10 @@ __device__ void QueuePair::bnxt_quiet() { } } +__device__ void QueuePair::bnxt_quiet_single() { + poll_cq_until(sq.depth); +} + __device__ void QueuePair::bnxt_post_wqe_rma(int pe, int32_t length, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) { uint64_t active_lane_mask; uint8_t active_lane_count; @@ -301,6 +305,90 @@ __device__ void QueuePair::bnxt_post_wqe_rma(int pe, int32_t length, uintptr_t * } } +__device__ void QueuePair::bnxt_post_wqe_rma_single(int pe, int32_t length, uintptr_t *laddr, + uintptr_t *raddr, uint8_t opcode) { + uint64_t active_lane_mask; + uint8_t active_lane_count; + uint8_t active_lane_id; + struct bnxt_re_bsqe hdr; + struct bnxt_re_rdma rdma; + struct bnxt_re_sge sge; + struct bnxt_re_bsqe *hdr_ptr; + struct bnxt_re_rdma *rdma_ptr; + struct bnxt_re_sge *sge_ptr; + uint32_t wqe_size; + uint32_t wqe_type; + uint32_t hdr_flags; + uint32_t inline_msg; + + aquire_lock(&sq.lock); + + inline_msg = length <= inline_threshold && + opcode == gda_op_rdma_write; + + poll_cq_until(GDA_BNXT_WQE_SLOT_COUNT); + + hdr_ptr = (struct bnxt_re_bsqe*) bnxt_re_get_hwqe(&sq, 0); + rdma_ptr = (struct bnxt_re_rdma*) bnxt_re_get_hwqe(&sq, 1); + sge_ptr = (struct bnxt_re_sge*) bnxt_re_get_hwqe(&sq, 2); + + /* Populate Header Segment */ + wqe_type = BNXT_RE_HDR_WT_MASK & opcode; + wqe_size = BNXT_RE_HDR_WS_MASK & GDA_BNXT_WQE_SLOT_COUNT; + hdr_flags = ((uint32_t) BNXT_RE_HDR_FLAGS_MASK) + & ((uint32_t) BNXT_RE_WR_FLAGS_SIGNALED); + + if (inline_msg) { + hdr_flags |= ((uint32_t) BNXT_RE_WR_FLAGS_INLINE); + } + + hdr.rsv_ws_fl_wt = (wqe_size << BNXT_RE_HDR_WS_SHIFT) + | (hdr_flags << BNXT_RE_HDR_FLAGS_SHIFT) + | wqe_type; + hdr.key_immd = 0; + hdr.lhdr.qkey_len = length; + + /* Populate RDMA Segment */ + rdma.rva = (uint64_t) raddr; + rdma.rkey = rkey; + + if (!inline_msg) { + /* Populate SG 
Segment */ + sge.pa = (uint64_t) laddr; + sge.lkey = lkey; + sge.length = length; + } + + /* Write WQE to SQ */ + memcpy(hdr_ptr, &hdr, sizeof(struct bnxt_re_bsqe)); + memcpy(rdma_ptr, &rdma, sizeof(struct bnxt_re_rdma)); + + if (inline_msg) { + memcpy(sge_ptr, laddr, length); + } else { + memcpy(sge_ptr, &sge, sizeof(struct bnxt_re_sge)); + } + + /* Populate MSN Table */ + bnxt_re_fill_psns_for_msntbl(&sq, length); + + /* Update SQ Pointer */ + bnxt_re_incr_tail(&sq, GDA_BNXT_WQE_SLOT_COUNT); + + /* Ring Doorbell + * Doorbell ring must be serialized as we cannot have all threads write to the same address */ + active_lane_mask = get_active_lane_mask(); + active_lane_count = get_active_lane_count(active_lane_mask); + active_lane_id = get_active_lane_num(active_lane_mask); + + for (int i = 0; i < active_lane_count; i++) { + if (i == active_lane_id) { + bnxt_ring_doorbell(sq.tail); + release_lock(&sq.lock); + } + } +} + __device__ uint64_t QueuePair::bnxt_post_wqe_amo(int pe, int32_t length, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetching) { uint64_t active_lane_mask; diff --git a/src/gda/context_gda_device.cpp b/src/gda/context_gda_device.cpp index e0488b2c95..e179ba29cd 100644 --- a/src/gda/context_gda_device.cpp +++ b/src/gda/context_gda_device.cpp @@ -34,7 +34,7 @@ namespace rocshmem { -__host__ GDAContext::GDAContext(Backend *b, unsigned int ctx_id) +__host__ GDAContext::GDAContext(Backend *b, unsigned int ctx_id, int gda_provider) : Context(b, false) { GDABackend *backend{static_cast(b)}; base_heap = backend->heap.get_heap_bases().data(); @@ -56,6 +56,7 @@ __host__ GDAContext::GDAContext(Backend *b, unsigned int ctx_id) ipcImpl_.pes_with_ipc_avail = backend->ipcImpl.pes_with_ipc_avail; ctx_id_ = ctx_id; + gda_provider_ = gda_provider; } __host__ GDAContext::~GDAContext() { @@ -147,6 +148,10 @@ __device__ void GDAContext::pe_quiet(size_t pe) { qps[pe].quiet(); } +__device__ void GDAContext::pe_quiet_single(size_t pe) { + 
qps[pe].quiet_single(); +} + __device__ void *GDAContext::shmem_ptr(const void *dest, int pe) { void *ret = nullptr; int local_pe{-1}; diff --git a/src/gda/context_gda_device.hpp b/src/gda/context_gda_device.hpp index 08cf232f28..1fe9172644 100644 --- a/src/gda/context_gda_device.hpp +++ b/src/gda/context_gda_device.hpp @@ -34,7 +34,7 @@ class QueuePair; class GDAContext : public Context { public: - __host__ GDAContext(Backend *b, unsigned int ctx_id); + __host__ GDAContext(Backend *b, unsigned int ctx_id, int gda_provider); __host__ ~GDAContext(); @@ -63,6 +63,7 @@ class GDAContext : public Context { __device__ void quiet_wave(); __device__ void pe_quiet(size_t pe); + __device__ void pe_quiet_single(size_t pe); __device__ void *shmem_ptr(const void *dest, int pe); @@ -257,6 +258,10 @@ class GDAContext : public Context { __device__ void alltoall_linear(rocshmem_team_t team, T *dest, const T *source, int nelems); + template + __device__ void alltoall_linear_thread_puts(rocshmem_team_t team, T *dest, + const T *source, int nelems); + __device__ void internal_sync(int pe, int PE_start, int stride, int PE_size, int64_t *pSync); @@ -272,6 +277,10 @@ class GDAContext : public Context { __device__ void internal_direct_barrier_wg(int pe, int PE_start, int stride, int n_pes, int64_t *pSync); + __device__ void internal_direct_barrier_wg_thread_puts(int pe, int PE_start, + int stride, int n_pes, + int64_t *pSync); + __device__ void internal_atomic_barrier(int pe, int PE_start, int stride, int n_pes, int64_t *pSync); @@ -298,6 +307,8 @@ class GDAContext : public Context { */ unsigned int ctx_id_{}; + int gda_provider_{0}; + public: QueuePair *qps{nullptr}; diff --git a/src/gda/context_gda_device_coll.cpp b/src/gda/context_gda_device_coll.cpp index 7610e4b2c9..9ebadd318b 100644 --- a/src/gda/context_gda_device_coll.cpp +++ b/src/gda/context_gda_device_coll.cpp @@ -121,6 +121,55 @@ __device__ void GDAContext::internal_direct_barrier_wg(int pe, int PE_start, } } +__device__ void 
GDAContext::internal_direct_barrier_wg_thread_puts(int pe, int PE_start, + int stride, int n_pes, + int64_t *pSync) { + int64_t flag_val{1}; + + if (pe == PE_start) { + int tid = get_flat_block_id(); + int step_size = min(get_flat_block_size(), WF_SIZE); + + // Go through all PE offsets (except current offset = 0) + // and wait until they all reach + for (int j = tid + 1; j < n_pes; j+= step_size) { + wait_until(&pSync[j], ROCSHMEM_CMP_EQ, flag_val); + pSync[j] = ROCSHMEM_SYNC_VALUE; + } + + __syncthreads(); + + // Announce to other PEs that all have reached + for (int i = tid + 1, j = PE_start + stride + tid; + i < n_pes; + i+= step_size, j += (step_size * stride)) { + uint64_t L_offset = reinterpret_cast(&pSync[0]) - base_heap[my_pe]; + qps[j].put_nbi_single(base_heap[j] + L_offset, &flag_val, sizeof(long), j); + } + + for (int i = tid + 1, j = PE_start + stride + tid; + i < n_pes; + i+= step_size, j += (step_size * stride)) { + pe_quiet_single(j); + } + + __syncthreads(); + + if (is_thread_zero_in_block()) { + pSync[0] = ROCSHMEM_SYNC_VALUE; + } + } else { + if (is_thread_zero_in_block()) { + // Mark current PE offset as reached + size_t pe_offset = (pe - PE_start) / stride; + putmem(&pSync[pe_offset], &flag_val, sizeof(long), PE_start); + wait_until(&pSync[0], ROCSHMEM_CMP_EQ, flag_val); + pSync[0] = ROCSHMEM_SYNC_VALUE; + __threadfence_system(); + } + } +} + __device__ void GDAContext::internal_atomic_barrier(int pe, int PE_start, int stride, int n_pes, int64_t *pSync) { diff --git a/src/gda/context_gda_tmpl_device.hpp b/src/gda/context_gda_tmpl_device.hpp index c21bef8942..8fbb33af01 100644 --- a/src/gda/context_gda_tmpl_device.hpp +++ b/src/gda/context_gda_tmpl_device.hpp @@ -32,6 +32,7 @@ #include "gda_team.hpp" #include "queue_pair.hpp" #include "rocshmem_calc.hpp" +#include "backend_gda.hpp" #include @@ -604,7 +605,11 @@ __device__ void GDAContext::internal_broadcast(T *dst, const T *src, int nelems, template __device__ void 
GDAContext::alltoall(rocshmem_team_t team, T *dst, const T *src, int nelems) { - alltoall_linear(team, dst, src, nelems); + if (gda_provider_ == GDAProvider::BNXT) { + alltoall_linear_thread_puts(team, dst, src, nelems); + } else { + alltoall_linear(team, dst, src, nelems); + } } template @@ -620,7 +625,6 @@ __device__ void GDAContext::alltoall_linear(rocshmem_team_t team, T *dst, int wf_id = get_flat_block_id() / WF_SIZE; int wf_count = (int) ceil((double)get_flat_block_size() / (double)WF_SIZE); - bool wf_leader = 0 == get_active_lane_num(); // Have each PE put their designated data to the other PEs for (int j = wf_id; j < pe_size; j+= wf_count) { @@ -637,6 +641,36 @@ __device__ void GDAContext::alltoall_linear(rocshmem_team_t team, T *dst, internal_sync_wg(my_pe, pe_start, stride, pe_size, pSync); } +template +__device__ void GDAContext::alltoall_linear_thread_puts(rocshmem_team_t team, T *dst, + const T *src, int nelems) { + GDATeam *team_obj = reinterpret_cast(team); + + int pe_start = team_obj->tinfo_wrt_world->pe_start; + int pe_size = team_obj->num_pes; + int stride = team_obj->tinfo_wrt_world->stride; + long *pSync = team_obj->alltoall_pSync; + int my_pe_in_team = team_obj->my_pe; + + int tid = get_flat_block_id(); + int step_size = min(get_flat_block_size(), WF_SIZE); + + // Have each PE put their designated data to the other PEs + for (int j = tid; j < pe_size; j+= step_size) { + int dest_pe = team_obj->get_pe_in_world(j); + uint64_t L_offset = reinterpret_cast(&dst[my_pe_in_team * nelems]) - base_heap[my_pe]; + qps[dest_pe].put_nbi_single(base_heap[dest_pe] + L_offset, &src[j * nelems], nelems * sizeof(T), dest_pe); + } + + for (int j = tid; j < pe_size; j+= step_size) { + int dest_pe = team_obj->get_pe_in_world(j); + pe_quiet_single(dest_pe); + } + + // wait until everyone has obtained their designated data + internal_direct_barrier_wg_thread_puts(my_pe, pe_start, stride, pe_size, pSync); +} + template __device__ void 
GDAContext::fcollect(rocshmem_team_t team, T *dst, const T *src, int nelems) { diff --git a/src/gda/gda_context_proxy.hpp b/src/gda/gda_context_proxy.hpp index 14cac518f5..0bd726023c 100644 --- a/src/gda/gda_context_proxy.hpp +++ b/src/gda/gda_context_proxy.hpp @@ -44,10 +44,11 @@ class GDADefaultContextProxy { * Placement new the memory which is allocated by proxy_ */ explicit GDADefaultContextProxy(GDABackend* backend, TeamInfo *tinfo, + int gda_provider, size_t num_elems = 1) : constructed_{true}, proxy_{num_elems} { auto ctx{proxy_.get()}; - new (ctx) GDAContext(reinterpret_cast(backend), 0); + new (ctx) GDAContext(reinterpret_cast(backend), 0, gda_provider); ctx->tinfo = tinfo; rocshmem_ctx_t local{ctx, tinfo}; set_internal_ctx(&local); diff --git a/src/gda/queue_pair.cpp b/src/gda/queue_pair.cpp index 3ac890f81a..07e3ee3e7b 100644 --- a/src/gda/queue_pair.cpp +++ b/src/gda/queue_pair.cpp @@ -138,19 +138,19 @@ __device__ void QueuePair::post_wqe_rma_turn(int pe, int32_t size, uintptr_t *la uint8_t lane = __ffsll((unsigned long long)turns) - 1; int pe_turn = __shfl(pe, lane); if (pe_turn == pe) { - post_wqe_rma_single(pe, size, laddr, raddr, opcode); + post_wqe_rma_mt(pe, size, laddr, raddr, opcode); need_turn = false; } turns = __ballot(need_turn); } } else { if (is_thread_zero_in_wave()) { - post_wqe_rma_single(pe, size, laddr, raddr, opcode); + post_wqe_rma_mt(pe, size, laddr, raddr, opcode); } } } -__device__ void QueuePair::post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) { +__device__ void QueuePair::post_wqe_rma_mt(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) { switch (gda_provider_) { #if defined(GDA_MLX5) case GDAProvider::MLX5: @@ -167,6 +167,19 @@ __device__ void QueuePair::post_wqe_rma_single(int pe, int32_t size, uintptr_t * } } +__device__ void QueuePair::post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) { + switch (gda_provider_) 
{ +#if defined(GDA_BNXT) + case GDAProvider::BNXT: + return bnxt_post_wqe_rma_single(pe, size, laddr, raddr, opcode); +#endif + case GDAProvider::IONIC: + case GDAProvider::MLX5: + default: + assert(false /* invalid nic provider */); + } +} + __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetching) { switch (gda_provider_) { @@ -214,6 +227,20 @@ __device__ void QueuePair::quiet(Collectivity cy) { } } +__device__ void QueuePair::quiet_single() { + switch (gda_provider_) { +#if defined(GDA_BNXT) + case GDAProvider::BNXT: + bnxt_quiet_single(); + return; +#endif + case GDAProvider::MLX5: + case GDAProvider::IONIC: + default: + assert(false /* invalid nic provider */); + } +} + /****************************************************************************** ****************************** SHMEM INTERFACE ******************************* *****************************************************************************/ @@ -223,6 +250,12 @@ __device__ void QueuePair::put_nbi(void *dest, const void *source, size_t nelems post_wqe_rma(pe, nelems, src, dst, gda_op_rdma_write, cy); } +__device__ void QueuePair::put_nbi_single(void *dest, const void *source, size_t nelems, int pe) { + uintptr_t *src = reinterpret_cast(const_cast(source)); + uintptr_t *dst = reinterpret_cast(dest); + post_wqe_rma_single(pe, nelems, src, dst, gda_op_rdma_write); +} + __device__ void QueuePair::get_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy) { uintptr_t *src = reinterpret_cast(const_cast(source)); uintptr_t *dst = reinterpret_cast(dest); diff --git a/src/gda/queue_pair.hpp b/src/gda/queue_pair.hpp index e415666389..84a84a5d0d 100644 --- a/src/gda/queue_pair.hpp +++ b/src/gda/queue_pair.hpp @@ -77,6 +77,7 @@ class QueuePair { * @param[in] pe Destination processing element of data transmission. 
*/ __device__ void put_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy = THREAD); + __device__ void put_nbi_single(void *dest, const void *source, size_t nelems, int pe); /** * @brief Create and enqueue a non-blocking get work queue entry (wqe). @@ -92,6 +93,7 @@ class QueuePair { * @brief Empty all completions from the completion queue. */ __device__ void quiet(Collectivity cy = THREAD); + __device__ void quiet_single(); /** * @brief Create and enqueue an atomic fetch work queue entry (wqe). @@ -165,6 +167,7 @@ class QueuePair { __device__ __attribute__((noinline)) void post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy); __device__ __attribute__((noinline)) void post_wqe_rma_turn(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy); __device__ __attribute__((noinline)) void post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode); + __device__ __attribute__((noinline)) void post_wqe_rma_mt(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode); #if defined(GDA_MLX5) __device__ uint64_t mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch); @@ -174,7 +177,9 @@ class QueuePair { #if defined(GDA_BNXT) __device__ uint64_t bnxt_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch); __device__ void bnxt_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode); + __device__ void bnxt_post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode); __device__ void bnxt_quiet(); + __device__ void bnxt_quiet_single(); #endif #if defined(GDA_IONIC) __device__ uint64_t ionic_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch); diff --git 
a/tests/functional_tests/tester.cpp b/tests/functional_tests/tester.cpp index 727cb4a0d5..cd550db80a 100644 --- a/tests/functional_tests/tester.cpp +++ b/tests/functional_tests/tester.cpp @@ -225,13 +225,7 @@ std::vector Tester::create(TesterArguments args) { if (rank == 0) { std::cout << "Alltoall Test ###" << std::endl; } - testers.push_back(new TeamAlltoallTester(args)); - testers.push_back(new TeamAlltoallTester(args)); - testers.push_back(new TeamAlltoallTester(args)); testers.push_back(new TeamAlltoallTester(args)); - testers.push_back(new TeamAlltoallTester(args)); - testers.push_back(new TeamAlltoallTester(args)); - testers.push_back(new TeamAlltoallTester(args)); return testers; case TeamFCollectTestType: if (rank == 0) {