Reviewer notes (free text ahead of the first "diff --git" header; patch tools
skip it).

What the hunks below do:
  * GDAContext thread-level putmem/getmem/putmem_nbi/getmem_nbi no longer run
    the per-PE __ballot/__shfl turn loop themselves. The same loop now lives in
    QueuePair::post_wqe_rma_turn(), which post_wqe_rma() reaches through its
    default case; the IONIC case calls ionic_post_wqe_rma() directly and
    passes the collectivity through.
  * QueuePair gains enum Collectivity { THREAD, WAVE }. put_nbi(), get_nbi()
    and quiet() take it as a trailing parameter defaulting to THREAD.
  * The *_wg RMA entry points are now gated on is_wave_zero_in_block() (new
    helper in util.hpp) and pass QueuePair::WAVE; the *_wave entry points drop
    their is_thread_zero_in_wave() guard and pass QueuePair::WAVE.
  * GDAContext::quiet_wave() is added and used by the wave/wg barrier paths.

Correction applied in this revision of the patch:
  * GDAContext::getmem_nbi_wg: the added call had put-style arguments,
    get_nbi(base_heap[pe] + L_offset, source, ...), which hands the remote heap
    address in as the local destination and never uses the caller's dest. It
    now reads get_nbi(dest, base_heap[pe] + L_offset, ...), matching the
    removed line and the getmem_wg / getmem_wave / getmem_nbi_wave hunks.
    The post-image blob id on that file's "index" line predates this change.

Caveats for whoever applies this:
  * This text was rebuilt from a whitespace-collapsed copy. Hunk line counts
    were checked against every @@ header, but leading indentation and blank
    context lines are reconstructed - apply with whitespace tolerance.
  * Casts such as reinterpret_cast(dest) and const_cast(src_typed) appear to
    have lost their template arguments during extraction; they are left as
    found and must be restored from the target tree - TODO confirm.
  * The last hunk of context_gda_device.cpp needs three trailing context lines
    but only the closing brace survived extraction; the other two are emitted
    as blank lines - TODO confirm.
  * NOTE(review): the blocking *_wg / *_wave paths still call qps[pe].quiet()
    with the default THREAD collectivity after a WAVE post - confirm that is
    intended rather than quiet(QueuePair::WAVE).

diff --git a/src/gda/context_gda_device.cpp b/src/gda/context_gda_device.cpp
index 869a0a1d9b..5cba9c9588 100644
--- a/src/gda/context_gda_device.cpp
+++ b/src/gda/context_gda_device.cpp
@@ -77,18 +77,8 @@ __device__ void GDAContext::putmem(void *dest, const void *source, size_t nelems
     return;
   }
   uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe];
-  bool need_turn {true};
-  uint64_t turns = __ballot(need_turn);
-  while (turns) {
-    uint8_t lane = __ffsll((unsigned long long)turns) - 1;
-    int pe_turn = __shfl(pe, lane);
-    if (pe_turn == pe) {
-      qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe);
-      qps[pe].quiet();
-      need_turn = false;
-    }
-    turns = __ballot(need_turn);
-  }
+  qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe);
+  qps[pe].quiet();
 }
 
 __device__ void GDAContext::getmem(void *dest, const void *source, size_t nelems,
@@ -101,18 +91,8 @@ __device__ void GDAContext::getmem(void *dest, const void *source, size_t nelems
     return;
   }
   uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe];
-  bool need_turn {true};
-  uint64_t turns = __ballot(need_turn);
-  while (turns) {
-    uint8_t lane = __ffsll((unsigned long long)turns) - 1;
-    int pe_turn = __shfl(pe, lane);
-    if (pe_turn == pe) {
-      qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
-      qps[pe].quiet();
-      need_turn = false;
-    }
-    turns = __ballot(need_turn);
-  }
+  qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
+  qps[pe].quiet();
 }
 
 __device__ void GDAContext::putmem_nbi(void *dest, const void *source,
@@ -124,17 +104,7 @@ __device__ void GDAContext::putmem_nbi(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe];
-  bool need_turn {true};
-  uint64_t turns = __ballot(need_turn);
-  while (turns) {
-    uint8_t lane = __ffsll((unsigned long long)turns) - 1;
-    int pe_turn = __shfl(pe, lane);
-    if (pe_turn == pe) {
-      qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe);
-      need_turn = false;
-    }
-    turns = __ballot(need_turn);
-  }
+  qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe);
 }
 
 __device__ void GDAContext::getmem_nbi(void *dest, const void *source,
@@ -147,17 +117,7 @@ __device__ void GDAContext::getmem_nbi(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe];
-  bool need_turn {true};
-  uint64_t turns = __ballot(need_turn);
-  while (turns) {
-    uint8_t lane = __ffsll((unsigned long long)turns) - 1;
-    int pe_turn = __shfl(pe, lane);
-    if (pe_turn == pe) {
-      qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
-      need_turn = false;
-    }
-    turns = __ballot(need_turn);
-  }
+  qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
 }
 
 __device__ void GDAContext::fence() { //TODO: optimize
@@ -177,6 +137,12 @@ __device__ void GDAContext::quiet() {
   }
 }
 
+__device__ void GDAContext::quiet_wave() {
+  for (int i = 0; i < num_pes; i++) {
+    qps[i].quiet(QueuePair::WAVE);
+  }
+}
+
 __device__ void GDAContext::pe_quiet(size_t pe) {
   qps[pe].quiet();
 }
@@ -201,8 +167,8 @@ __device__ void GDAContext::putmem_wg(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe];
-  if (is_thread_zero_in_block()) {
-    qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe);
+  if (is_wave_zero_in_block()) {
+    qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe, QueuePair::WAVE);
     qps[pe].quiet();
   }
 }
@@ -217,8 +183,8 @@ __device__ void GDAContext::getmem_wg(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe];
-  if (is_thread_zero_in_block()) {
-    qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
+  if (is_wave_zero_in_block()) {
+    qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe, QueuePair::WAVE);
     qps[pe].quiet();
   }
 }
@@ -232,8 +198,8 @@ __device__ void GDAContext::putmem_nbi_wg(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe];
-  if (is_thread_zero_in_block()) {
-    qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe);
+  if (is_wave_zero_in_block()) {
+    qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe, QueuePair::WAVE);
   }
 }
 
@@ -247,8 +213,8 @@ __device__ void GDAContext::getmem_nbi_wg(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe];
-  if (is_thread_zero_in_block()) {
-    qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
+  if (is_wave_zero_in_block()) {
+    qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe, QueuePair::WAVE);
   }
 }
 
@@ -261,10 +227,8 @@ __device__ void GDAContext::putmem_wave(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe];
-  if (is_thread_zero_in_wave()) {
-    qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe);
-    qps[pe].quiet();
-  }
+  qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe, QueuePair::WAVE);
+  qps[pe].quiet();
 }
 
 __device__ void GDAContext::getmem_wave(void *dest, const void *source,
@@ -277,10 +241,8 @@ __device__ void GDAContext::getmem_wave(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe];
-  if (is_thread_zero_in_wave()) {
-    qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
-    qps[pe].quiet();
-  }
+  qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe, QueuePair::WAVE);
+  qps[pe].quiet();
 }
 
 __device__ void GDAContext::putmem_nbi_wave(void *dest, const void *source,
@@ -292,9 +254,7 @@ __device__ void GDAContext::putmem_nbi_wave(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = reinterpret_cast(dest) - base_heap[my_pe];
-  if (is_thread_zero_in_wave()) {
-    qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe);
-  }
+  qps[pe].put_nbi(base_heap[pe] + L_offset, source, nelems, pe, QueuePair::WAVE);
 }
 
 __device__ void GDAContext::getmem_nbi_wave(void *dest, const void *source,
@@ -307,9 +267,7 @@ __device__ void GDAContext::getmem_nbi_wave(void *dest, const void *source,
     return;
   }
   uint64_t L_offset = const_cast(src_typed) - base_heap[my_pe];
-  if (is_thread_zero_in_wave()) {
-    qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
-  }
+  qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe, QueuePair::WAVE);
 }
 
 
diff --git a/src/gda/context_gda_device.hpp b/src/gda/context_gda_device.hpp
index f00cb4b0a7..08cf232f28 100644
--- a/src/gda/context_gda_device.hpp
+++ b/src/gda/context_gda_device.hpp
@@ -60,6 +60,8 @@ class GDAContext : public Context {
 
   __device__ void quiet();
 
+  __device__ void quiet_wave();
+
   __device__ void pe_quiet(size_t pe);
 
   __device__ void *shmem_ptr(const void *dest, int pe);
diff --git a/src/gda/context_gda_device_coll.cpp b/src/gda/context_gda_device_coll.cpp
index 7ae16341d7..7610e4b2c9 100644
--- a/src/gda/context_gda_device_coll.cpp
+++ b/src/gda/context_gda_device_coll.cpp
@@ -232,15 +232,13 @@ __device__ void GDAContext::barrier_all() {
 }
 
 __device__ void GDAContext::barrier_all_wave() {
-  if (is_thread_zero_in_wave()) {
-    quiet();
-  }
+  quiet_wave();
   sync_all_wave();
 }
 
 __device__ void GDAContext::barrier_all_wg() {
-  if (is_thread_zero_in_block()) {
-    quiet();
+  if (is_wave_zero_in_block()) {
+    quiet_wave();
   }
   sync_all_wg();
   __syncthreads();
@@ -268,9 +266,7 @@ __device__ void GDAContext::barrier_wave(rocshmem_team_t team) {
   int pe_size = team_obj->num_pes;
   long *p_sync = team_obj->barrier_pSync;
 
-  if (is_thread_zero_in_wave()) {
-    quiet();
-  }
+  quiet_wave();
   internal_sync_wave(pe, pe_start, pe_stride, pe_size, p_sync);
 }
 
@@ -283,8 +279,8 @@ __device__ void GDAContext::barrier_wg(rocshmem_team_t team) {
   int pe_size = team_obj->num_pes;
   long *p_sync = team_obj->barrier_pSync;
 
-  if (is_thread_zero_in_block()) {
-    quiet();
+  if (is_wave_zero_in_block()) {
+    quiet_wave();
   }
   internal_sync_wg(pe, pe_start, pe_stride, pe_size, p_sync);
   __syncthreads();
diff --git a/src/gda/ionic/queue_pair_ionic.cpp b/src/gda/ionic/queue_pair_ionic.cpp
index 59872a3d2e..9533ab21be 100644
--- a/src/gda/ionic/queue_pair_ionic.cpp
+++ b/src/gda/ionic/queue_pair_ionic.cpp
@@ -171,7 +171,8 @@ __device__ void QueuePair::ionic_quiet_internal(uint64_t activemask, uint32_t co
 }
 
 __device__ void QueuePair::ionic_ring_doorbell(uint32_t pos) {
-  // TODO When threads write at once to the same address, not all writes reach the bus.
+  // When threads write at once to the same address, not all writes reach the bus.
+  // Take turns and insert a thread fence between writes to the same address.
   for (int i = 0; i < 64; ++i) {
     if (__lane_id() == i) {
       __threadfence();
@@ -185,11 +186,22 @@ __device__ void QueuePair::ionic_quiet() {
   ionic_quiet_internal(get_same_qp_lane_mask(), sq_prod);
 }
 
-__device__ void QueuePair::ionic_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
+__device__ void QueuePair::ionic_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy) {
   uint64_t activemask = get_same_qp_lane_mask();
-  uint32_t num_wqes = get_active_lane_count(activemask);
   uint32_t my_logical_lane_id = get_active_lane_num(activemask);
+  uint32_t num_wqes = 1;
+  if (cy == THREAD) {
+    num_wqes = get_active_lane_count(activemask);
+  }
+
   uint32_t my_sq_prod = reserve_sq(activemask, num_wqes);
+  if (cy == WAVE) {
+    if (!is_first_active_lane(activemask)) {
+      return;
+    }
+    activemask &= activemask ^ (activemask - 1);
+  }
+
   uint32_t my_sq_pos = my_sq_prod + my_logical_lane_id;
   struct ionic_v1_wqe *wqe = &ionic_sq_buf[my_sq_pos & sq_mask];
   uint16_t wqe_flags = 0;
diff --git a/src/gda/queue_pair.cpp b/src/gda/queue_pair.cpp
index c268ae99e6..729db6490b 100644
--- a/src/gda/queue_pair.cpp
+++ b/src/gda/queue_pair.cpp
@@ -118,7 +118,39 @@ QueuePair::~QueuePair() {
 /******************************************************************************
  ************************ PROVIDER-SPECIFIC HELPERS ***************************
  *****************************************************************************/
-__device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
+__device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy) {
+  switch (gda_provider_) {
+#if defined(GDA_IONIC)
+    case GDAProvider::IONIC:
+      ionic_post_wqe_rma(pe, size, laddr, raddr, opcode, cy);
+      return;
+#endif
+    default:
+      post_wqe_rma_turn(pe, size, laddr, raddr, opcode, cy);
+  }
+}
+
+__device__ void QueuePair::post_wqe_rma_turn(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy) {
+  if (cy == THREAD) {
+    bool need_turn {true};
+    uint64_t turns = __ballot(need_turn);
+    while (turns) {
+      uint8_t lane = __ffsll((unsigned long long)turns) - 1;
+      int pe_turn = __shfl(pe, lane);
+      if (pe_turn == pe) {
+        post_wqe_rma_single(pe, size, laddr, raddr, opcode);
+        need_turn = false;
+      }
+      turns = __ballot(need_turn);
+    }
+  } else {
+    if (is_thread_zero_in_wave()) {
+      post_wqe_rma_single(pe, size, laddr, raddr, opcode);
+    }
+  }
+}
+
+__device__ void QueuePair::post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
   switch (gda_provider_) {
 #if defined(GDA_MLX5)
     case GDAProvider::MLX5:
@@ -129,11 +161,6 @@ __device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr,
     case GDAProvider::BNXT:
       bnxt_post_wqe_rma(pe, size, laddr, raddr, opcode);
       return;
-#endif
-#if defined(GDA_IONIC)
-    case GDAProvider::IONIC:
-      ionic_post_wqe_rma(pe, size, laddr, raddr, opcode);
-      return;
 #endif
     default:
       assert(false /* invalid nic provider */);
@@ -161,16 +188,20 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
   }
 }
 
-__device__ void QueuePair::quiet() {
+__device__ void QueuePair::quiet(Collectivity cy) {
   switch (gda_provider_) {
 #if defined(GDA_MLX5)
     case GDAProvider::MLX5:
-      mlx5_quiet();
+      if (cy == THREAD || is_thread_zero_in_wave()) {
+        mlx5_quiet();
+      }
       return;
 #endif
 #if defined(GDA_BNXT)
     case GDAProvider::BNXT:
-      bnxt_quiet();
+      if (cy == THREAD || is_thread_zero_in_wave()) {
+        bnxt_quiet();
+      }
       return;
 #endif
 #if defined(GDA_IONIC)
@@ -186,16 +217,16 @@ __device__ void QueuePair::quiet() {
 /******************************************************************************
  ****************************** SHMEM INTERFACE *******************************
  *****************************************************************************/
-__device__ void QueuePair::put_nbi(void *dest, const void *source, size_t nelems, int pe) {
+__device__ void QueuePair::put_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy) {
   uintptr_t *src = reinterpret_cast(const_cast(source));
   uintptr_t *dst = reinterpret_cast(dest);
-  post_wqe_rma(pe, nelems, src, dst, gda_op_rdma_write);
+  post_wqe_rma(pe, nelems, src, dst, gda_op_rdma_write, cy);
 }
 
-__device__ void QueuePair::get_nbi(void *dest, const void *source, size_t nelems, int pe) {
+__device__ void QueuePair::get_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy) {
   uintptr_t *src = reinterpret_cast(const_cast(source));
   uintptr_t *dst = reinterpret_cast(dest);
-  post_wqe_rma(pe, nelems, dst, src, gda_op_rdma_read);
+  post_wqe_rma(pe, nelems, dst, src, gda_op_rdma_read, cy);
 }
 
 __device__ int64_t QueuePair::atomic_cas(void *dest, int64_t atomic_data, int64_t atomic_cmp, int pe) {
diff --git a/src/gda/queue_pair.hpp b/src/gda/queue_pair.hpp
index 8d348bbcb2..cfb7cfc6ba 100644
--- a/src/gda/queue_pair.hpp
+++ b/src/gda/queue_pair.hpp
@@ -64,6 +64,8 @@ class QueuePair {
    */
   virtual ~QueuePair();
 
+  enum Collectivity { THREAD, WAVE };
+
   /**
    * @brief Create and enqueue a non-blocking put work queue entry (wqe).
    *
@@ -72,7 +74,7 @@ class QueuePair {
    * @param[in] nelems Size in bytes of data transmission.
    * @param[in] pe Destination processing element of data transmission.
    */
-  __device__ void put_nbi(void *dest, const void *source, size_t nelems, int pe);
+  __device__ void put_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy = THREAD);
 
   /**
    * @brief Create and enqueue a non-blocking get work queue entry (wqe).
@@ -82,12 +84,12 @@ class QueuePair {
    * @param[in] nelems Size in bytes of data transmission.
    * @param[in] pe Destination processing element of data transmission.
    */
-  __device__ void get_nbi(void *dest, const void *source, size_t nelems, int pe);
+  __device__ void get_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy = THREAD);
 
   /**
    * @brief Empty all completions from the completion queue.
    */
-  __device__ void quiet();
+  __device__ void quiet(Collectivity cy = THREAD);
 
   /**
    * @brief Create and enqueue an atomic fetch work queue entry (wqe).
@@ -158,7 +160,9 @@ class QueuePair {
    * @param[in] raddr Remote address.
    * @param[in] opcode Operation to be performed.
    */
-  __device__ __attribute__((noinline)) void post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
+  __device__ __attribute__((noinline)) void post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy);
+  __device__ __attribute__((noinline)) void post_wqe_rma_turn(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy);
+  __device__ __attribute__((noinline)) void post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
 
 #if defined(GDA_MLX5)
   __device__ uint64_t mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch);
@@ -172,7 +176,7 @@ class QueuePair {
 #endif
 #if defined(GDA_IONIC)
   __device__ uint64_t ionic_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch);
-  __device__ void ionic_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
+  __device__ void ionic_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy);
   __device__ void ionic_quiet();
 #endif
 
diff --git a/src/util.hpp b/src/util.hpp
index d76ab0f3ef..10c51e7623 100644
--- a/src/util.hpp
+++ b/src/util.hpp
@@ -241,6 +241,13 @@ __device__ __forceinline__ bool is_thread_zero_in_wave() {
   return (get_flat_block_id() % WF_SIZE) == 0;
 }
 
+/*
+ * Returns true if the caller's thread flat_id is in the zero'th wave.
+ */
+__device__ __forceinline__ bool is_wave_zero_in_block() {
+  return (get_flat_block_id() / WF_SIZE) == 0;
+}
+
 __device__ __forceinline__ uint64_t get_active_lane_mask() {
   return __ballot(true);
 }
diff --git a/tests/functional_tests/wavefront_primitives.cpp b/tests/functional_tests/wavefront_primitives.cpp
index 14bbe572b0..bc221659bf 100644
--- a/tests/functional_tests/wavefront_primitives.cpp
+++ b/tests/functional_tests/wavefront_primitives.cpp
@@ -55,11 +55,11 @@ __global__ void WaveFrontPrimitiveTest(int loop, int skip,
   for (int i = 0; i < loop + skip; i++) {
     if (i == skip) {
       // Ensures all RMA calls from the skip loops are completed
-      if(is_thread_zero_in_wave()) {
-        rocshmem_ctx_quiet(ctx);
-      }
+      rocshmem_ctx_quiet(ctx);
       __syncthreads();
-      start_time[idx] = wall_clock64();
+      if (is_thread_zero_in_wave()) {
+        start_time[idx] = wall_clock64();
+      }
     }
     switch (type) {
       case WAVEGetTestType:
@@ -79,8 +79,8 @@ __global__ void WaveFrontPrimitiveTest(int loop, int skip,
     }
   }
 
+  rocshmem_ctx_quiet(ctx);
   if (is_thread_zero_in_wave()) {
-    rocshmem_ctx_quiet(ctx);
     end_time[idx] = wall_clock64();
   }
 