diff --git a/projects/rocshmem/src/gda/bnxt/bnxt_re_hsi.h b/projects/rocshmem/src/gda/bnxt/bnxt_re_hsi.h
index 710222ba9d..24f32a4aab 100644
--- a/projects/rocshmem/src/gda/bnxt/bnxt_re_hsi.h
+++ b/projects/rocshmem/src/gda/bnxt/bnxt_re_hsi.h
@@ -40,8 +40,6 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-#define true 1
-#define false 0
 
 #define BNXT_RE_SLOT_SIZE_BB 16
 #define BNXT_RE_STATIC_WQE_SIZE_SLOTS 4
diff --git a/projects/rocshmem/src/gda/bnxt/queue_pair_bnxt.cpp b/projects/rocshmem/src/gda/bnxt/queue_pair_bnxt.cpp
index 807a4d199b..54f77b9e3c 100644
--- a/projects/rocshmem/src/gda/bnxt/queue_pair_bnxt.cpp
+++ b/projects/rocshmem/src/gda/bnxt/queue_pair_bnxt.cpp
@@ -217,99 +217,12 @@ __device__ void QueuePair::bnxt_quiet_single() {
   poll_cq_until(sq.depth);
 }
 
-__device__ void QueuePair::bnxt_post_wqe_rma(int pe, int32_t length, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
-  uint64_t active_lane_mask;
-  uint8_t active_lane_count;
-  uint8_t active_lane_id;
-
-  active_lane_mask = get_active_lane_mask();
-  active_lane_count = get_active_lane_count(active_lane_mask);
-  active_lane_id = get_active_lane_num(active_lane_mask);
-
-  if (0 == active_lane_id) {
-    aquire_lock(&sq.lock);
-  }
-
-  for (int i = 0; i < active_lane_count; i++) {
-    if (i == active_lane_id) {
-      struct bnxt_re_bsqe hdr;
-      struct bnxt_re_rdma rdma;
-      struct bnxt_re_sge sge;
-      struct bnxt_re_bsqe *hdr_ptr;
-      struct bnxt_re_rdma *rdma_ptr;
-      struct bnxt_re_sge *sge_ptr;
-      uint32_t wqe_size;
-      uint32_t wqe_type;
-      uint32_t hdr_flags;
-      uint32_t inline_msg;
-
-      inline_msg = length <= inline_threshold &&
-                   opcode == gda_op_rdma_write;
-
-      poll_cq_until(GDA_BNXT_WQE_SLOT_COUNT);
-
-      hdr_ptr = (struct bnxt_re_bsqe*) bnxt_re_get_hwqe(&sq, 0);
-      rdma_ptr = (struct bnxt_re_rdma*) bnxt_re_get_hwqe(&sq, 1);
-      sge_ptr = (struct bnxt_re_sge*) bnxt_re_get_hwqe(&sq, 2);
-
-      /* Populate Header Segment */
-      wqe_type = BNXT_RE_HDR_WT_MASK & opcode;
-      wqe_size = BNXT_RE_HDR_WS_MASK & GDA_BNXT_WQE_SLOT_COUNT;
-      hdr_flags = ((uint32_t) BNXT_RE_HDR_FLAGS_MASK)
-                  & ((uint32_t) BNXT_RE_WR_FLAGS_SIGNALED);
-
-      if (inline_msg) {
-        hdr_flags |= ((uint32_t) BNXT_RE_WR_FLAGS_INLINE);
-      }
-
-      hdr.rsv_ws_fl_wt = (wqe_size << BNXT_RE_HDR_WS_SHIFT)
-                         | (hdr_flags << BNXT_RE_HDR_FLAGS_SHIFT)
-                         | wqe_type;
-      hdr.key_immd = 0;
-      hdr.lhdr.qkey_len = length;
-
-      /* Populate RDMA Segment */
-      rdma.rva = (uint64_t) raddr;
-      rdma.rkey = rkey;
-
-      if (!inline_msg) {
-        /* Populate SG Segment */
-        sge.pa = (uint64_t) laddr;
-        sge.lkey = lkey;
-        sge.length = length;
-      }
-
-      /* Write WQE to SQ */
-      memcpy(hdr_ptr, &hdr, sizeof(struct bnxt_re_bsqe));
-      memcpy(rdma_ptr, &rdma, sizeof(struct bnxt_re_rdma));
-
-      if (inline_msg) {
-        memcpy(sge_ptr, laddr, length);
-      } else {
-        memcpy(sge_ptr, &sge, sizeof(struct bnxt_re_sge));
-      }
-
-      /* Populate MSN Table */
-      bnxt_re_fill_psns_for_msntbl(&sq, length);
-
-      /* Update SQ Pointer */
-      bnxt_re_incr_tail(&sq, GDA_BNXT_WQE_SLOT_COUNT);
-
-      /* Ring Doorbell */
-      bnxt_ring_doorbell(sq.tail);
-    }
-  }
-
-  if (0 == active_lane_id) {
-    release_lock(&sq.lock);
-  }
-}
-
-__device__ void QueuePair::bnxt_post_wqe_rma_single(int pe, int32_t length, uintptr_t *laddr,
-                                                    uintptr_t *raddr, uint8_t opcode) {
-  uint64_t active_lane_mask;
-  uint8_t active_lane_count;
-  uint8_t active_lane_id;
+/*
+ * Writes one RMA WQE into the send queue and advances the tail by
+ * GDA_BNXT_WQE_SLOT_COUNT. The callers in this file take sq.lock around the
+ * call and ring the doorbell themselves.
+ */
+__device__ void QueuePair::bnxt_write_rma_wqe(uintptr_t *raddr, uintptr_t *laddr, int32_t length, uint8_t opcode) {
   struct bnxt_re_bsqe hdr;
   struct bnxt_re_rdma rdma;
   struct bnxt_re_sge sge;
@@ -321,8 +234,6 @@ __device__ void QueuePair::bnxt_post_wqe_rma_single(int pe, int32_t length, uint
   uint32_t hdr_flags;
   uint32_t inline_msg;
 
-  aquire_lock(&sq.lock);
-
   inline_msg = length <= inline_threshold &&
                opcode == gda_op_rdma_write;
 
@@ -374,23 +285,146 @@ __device__ void QueuePair::bnxt_post_wqe_rma_single(int pe, int32_t length, uint
 
   /* Update SQ Pointer */
   bnxt_re_incr_tail(&sq, GDA_BNXT_WQE_SLOT_COUNT);
+}
+
+/*
+ * Lane-cooperative post: the lane whose active_lane_id is 0 takes sq.lock,
+ * then the active lanes take turns writing their WQE and ringing the doorbell.
+ */
+__device__ void QueuePair::bnxt_post_wqe_rma(int pe, int32_t length, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
+  uint64_t active_lane_mask;
+  uint8_t active_lane_count;
+  uint8_t active_lane_id;
 
-  /* Ring Doorbell
-   * Doorbell ring must be serialized as we cannot have all threads write to the same address */
   active_lane_mask = get_active_lane_mask();
   active_lane_count = get_active_lane_count(active_lane_mask);
   active_lane_id = get_active_lane_num(active_lane_mask);
 
+  if (0 == active_lane_id) {
+    aquire_lock(&sq.lock);
+  }
+
   for (int i = 0; i < active_lane_count; i++) {
     if (i == active_lane_id) {
+      /* Write WQE to SQ */
+      bnxt_write_rma_wqe(raddr, laddr, length, opcode);
+
+      /* Ring Doorbell */
       bnxt_ring_doorbell(sq.tail);
-      release_lock(&sq.lock);
     }
   }
+
+  if (0 == active_lane_id) {
+    release_lock(&sq.lock);
+  }
 }
 
-__device__ uint64_t QueuePair::bnxt_post_wqe_amo(int pe, int32_t length, uintptr_t *raddr, uint8_t opcode,
-                                                 int64_t atomic_data, int64_t atomic_cmp, bool fetching) {
+/*
+ * Per-thread post: every calling thread takes sq.lock itself and writes one
+ * WQE. The doorbell is rung only when ring_db is true, so a caller may pass
+ * false and rely on a later post to the same QP to ring it.
+ */
+__device__ void QueuePair::bnxt_post_wqe_rma_single(int32_t length, uintptr_t *laddr,
+                                                    uintptr_t *raddr, uint8_t opcode,
+                                                    bool ring_db) {
+
+  aquire_lock(&sq.lock);
+
+  /* Write WQE to SQ */
+  bnxt_write_rma_wqe(raddr, laddr, length, opcode);
+
+  if (ring_db) {
+    uint64_t active_lane_mask;
+    uint8_t active_lane_count;
+    uint8_t active_lane_id;
+
+    active_lane_mask = get_active_lane_mask();
+    active_lane_count = get_active_lane_count(active_lane_mask);
+    active_lane_id = get_active_lane_num(active_lane_mask);
+
+    /* Ring Doorbell
+     * Doorbell ring must be serialized as we cannot have all threads write to the same address */
+    for (int i = 0; i < active_lane_count; i++) {
+      if (i == active_lane_id) {
+        bnxt_ring_doorbell(sq.tail);
+      }
+    }
+  }
+
+  release_lock(&sq.lock);
+}
+
+/*
+ * Writes one atomic WQE into the send queue and advances the tail. Returns the
+ * index of the fetching_atomic slot used for the result (0 when fetching is
+ * false). Does not lock or ring the doorbell.
+ */
+__device__ uint32_t QueuePair::bnxt_write_amo_wqe(uintptr_t *raddr, uint8_t opcode,
+                                                  int64_t atomic_data, int64_t atomic_cmp,
+                                                  bool fetching) {
+  struct bnxt_re_bsqe hdr;
+  struct bnxt_re_atomic amo;
+  struct bnxt_re_sge sge;
+  struct bnxt_re_bsqe *hdr_ptr;
+  struct bnxt_re_atomic *amo_ptr;
+  struct bnxt_re_sge *sge_ptr;
+  uint32_t wqe_size;
+  uint32_t wqe_type;
+  uint32_t hdr_flags;
+
+  uint32_t atomic_idx = 0;
+  uint32_t length = sizeof(uint64_t);
+
+  poll_cq_until(GDA_BNXT_WQE_SLOT_COUNT);
+
+  hdr_ptr = (struct bnxt_re_bsqe*) bnxt_re_get_hwqe(&sq, 0);
+  amo_ptr = (struct bnxt_re_atomic*) bnxt_re_get_hwqe(&sq, 1);
+  sge_ptr = (struct bnxt_re_sge*) bnxt_re_get_hwqe(&sq, 2);
+
+  /* Populate Header Segment */
+  wqe_size = BNXT_RE_HDR_WS_MASK & GDA_BNXT_WQE_SLOT_COUNT;
+  hdr_flags = ((uint32_t) BNXT_RE_HDR_FLAGS_MASK)
+              & ((uint32_t) BNXT_RE_WR_FLAGS_SIGNALED);
+  wqe_type = BNXT_RE_HDR_WT_MASK & opcode;
+
+  hdr.rsv_ws_fl_wt = (wqe_size << BNXT_RE_HDR_WS_SHIFT)
+                     | (hdr_flags << BNXT_RE_HDR_FLAGS_SHIFT)
+                     | wqe_type;
+  hdr.key_immd = rkey;
+  hdr.lhdr.rva = (uint64_t) raddr;
+
+  /* Populate AMO Segment */
+  amo.swp_dt = atomic_data;
+  amo.cmp_dt = atomic_cmp;
+
+  /* Populate SG Segment - (Return address of atomic) */
+  if (fetching) {
+    atomic_idx = fetching_atomic_idx++ % FETCHING_ATOMIC_CNT;
+    sge.pa = (uint64_t) &fetching_atomic[atomic_idx];
+    sge.lkey = fetching_atomic_lkey;
+  } else {
+    sge.pa = (uint64_t) nonfetching_atomic;
+    sge.lkey = nonfetching_atomic_lkey;
+  }
+  sge.length = length;
+
+  /* Write WQE to SQ */
+  memcpy(hdr_ptr, &hdr, sizeof(struct bnxt_re_bsqe));
+  memcpy(amo_ptr, &amo, sizeof(struct bnxt_re_atomic));
+  memcpy(sge_ptr, &sge, sizeof(struct bnxt_re_sge));
+
+  /* Populate MSN Table */
+  bnxt_re_fill_psns_for_msntbl(&sq, length);
+
+  /* Update SQ Pointer */
+  bnxt_re_incr_tail(&sq, GDA_BNXT_WQE_SLOT_COUNT);
+
+  return atomic_idx;
+}
+
+__device__ uint64_t QueuePair::bnxt_post_wqe_amo(uintptr_t *raddr, uint8_t opcode,
+                                                 int64_t atomic_data, int64_t atomic_cmp,
+                                                 bool fetching) {
   uint64_t active_lane_mask;
   uint8_t active_lane_count;
   uint8_t active_lane_id;
@@ -406,59 +440,7 @@ __device__ uint64_t QueuePair::bnxt_post_wqe_amo(int pe, int32_t length, uintptr
 
   for (int i = 0; i < active_lane_count; i++) {
     if (i == active_lane_id) {
-      struct bnxt_re_bsqe hdr;
-      struct bnxt_re_atomic amo;
-      struct bnxt_re_sge sge;
-      struct bnxt_re_bsqe *hdr_ptr;
-      struct bnxt_re_atomic *amo_ptr;
-      struct bnxt_re_sge *sge_ptr;
-      uint32_t wqe_size;
-      uint32_t wqe_type;
-      uint32_t hdr_flags;
-
-      poll_cq_until(GDA_BNXT_WQE_SLOT_COUNT);
-
-      hdr_ptr = (struct bnxt_re_bsqe*) bnxt_re_get_hwqe(&sq, 0);
-      amo_ptr = (struct bnxt_re_atomic*) bnxt_re_get_hwqe(&sq, 1);
-      sge_ptr = (struct bnxt_re_sge*) bnxt_re_get_hwqe(&sq, 2);
-
-      /* Populate Header Segment */
-      wqe_size = BNXT_RE_HDR_WS_MASK & GDA_BNXT_WQE_SLOT_COUNT;
-      hdr_flags = ((uint32_t) BNXT_RE_HDR_FLAGS_MASK)
-                  & ((uint32_t) BNXT_RE_WR_FLAGS_SIGNALED);
-      wqe_type = BNXT_RE_HDR_WT_MASK & opcode;
-
-      hdr.rsv_ws_fl_wt = (wqe_size << BNXT_RE_HDR_WS_SHIFT)
-                         | (hdr_flags << BNXT_RE_HDR_FLAGS_SHIFT)
-                         | wqe_type;
-      hdr.key_immd = rkey;
-      hdr.lhdr.rva = (uint64_t) raddr;
-
-      /* Populate AMO Segment */
-      amo.swp_dt = atomic_data;
-      amo.cmp_dt = atomic_cmp;
-
-      /* Populate SG Segment - (Return address of atomic) */
-      if (fetching) {
-        atomic_idx = fetching_atomic_idx++ % FETCHING_ATOMIC_CNT;
-        sge.pa = (uint64_t) &fetching_atomic[atomic_idx];
-        sge.lkey = fetching_atomic_lkey;
-      } else {
-        sge.pa = (uint64_t) nonfetching_atomic;
-        sge.lkey = nonfetching_atomic_lkey;
-      }
-      sge.length = length;
-
-      /* Write WQE to SQ */
-      memcpy(hdr_ptr, &hdr, sizeof(struct bnxt_re_bsqe));
-      memcpy(amo_ptr, &amo, sizeof(struct bnxt_re_atomic));
-      memcpy(sge_ptr, &sge, sizeof(struct bnxt_re_sge));
-
-      /* Populate MSN Table */
-      bnxt_re_fill_psns_for_msntbl(&sq, length);
-
-      /* Update SQ Pointer */
-      bnxt_re_incr_tail(&sq, GDA_BNXT_WQE_SLOT_COUNT);
+      atomic_idx = bnxt_write_amo_wqe(raddr, opcode, atomic_data, atomic_cmp, fetching);
 
       /* Ring Doorbell */
       bnxt_ring_doorbell(sq.tail);
@@ -477,4 +459,47 @@ __device__ uint64_t QueuePair::bnxt_post_wqe_amo(int pe, int32_t length, uintptr
   return 0;
 }
 
+/*
+ * Per-thread atomic post: every calling thread takes sq.lock itself, writes
+ * one atomic WQE and rings the doorbell. Returns fetching_atomic[atomic_idx]
+ * after quiet() when fetching is true, otherwise 0.
+ */
+__device__ uint64_t QueuePair::bnxt_post_wqe_amo_single(uintptr_t *raddr, uint8_t opcode,
+                                                        int64_t atomic_data, int64_t atomic_cmp,
+                                                        bool fetching) {
+  uint64_t active_lane_mask;
+  uint8_t active_lane_count;
+  uint8_t active_lane_id;
+  uint32_t atomic_idx = 0;
+  uint64_t fetched = 0;
+
+  active_lane_mask = get_active_lane_mask();
+  active_lane_count = get_active_lane_count(active_lane_mask);
+  active_lane_id = get_active_lane_num(active_lane_mask);
+
+  aquire_lock(&sq.lock);
+
+  /* Write WQE to SQ */
+  atomic_idx = bnxt_write_amo_wqe(raddr, opcode, atomic_data, atomic_cmp, fetching);
+
+  /* Ring Doorbell
+   * Doorbell ring must be serialized as we cannot have all threads write to the same address */
+  for (int i = 0; i < active_lane_count; i++) {
+    if (i == active_lane_id) {
+      bnxt_ring_doorbell(sq.tail);
+    }
+  }
+
+  if (fetching) {
+    quiet();
+    fetched = fetching_atomic[atomic_idx];
+  }
+
+  /* Release on every path: returning from the fetching branch directly would
+   * leave sq.lock held and block every later post on this QP. */
+  release_lock(&sq.lock);
+
+  return fetched;
+}
+
 } // namespace rocshmem
diff --git a/projects/rocshmem/src/gda/context_gda_device.hpp b/projects/rocshmem/src/gda/context_gda_device.hpp
index 1fe9172644..1c1605ed04 100644
--- a/projects/rocshmem/src/gda/context_gda_device.hpp
+++ b/projects/rocshmem/src/gda/context_gda_device.hpp
@@ -277,10 +277,6 @@ class GDAContext : public Context {
   __device__ void internal_direct_barrier_wg(int pe, int PE_start, int stride,
                                              int n_pes, int64_t *pSync);
 
-  __device__ void internal_direct_barrier_wg_thread_puts(int pe, int PE_start,
-                                                         int stride, int n_pes,
-                                                         int64_t *pSync);
-
   __device__ void internal_atomic_barrier(int pe, int PE_start, int stride,
                                           int n_pes, int64_t *pSync);
 
diff --git a/projects/rocshmem/src/gda/context_gda_device_coll.cpp b/projects/rocshmem/src/gda/context_gda_device_coll.cpp
index 9ebadd318b..7610e4b2c9 100644
--- a/projects/rocshmem/src/gda/context_gda_device_coll.cpp
+++ b/projects/rocshmem/src/gda/context_gda_device_coll.cpp
@@ -121,55 +121,6 @@ __device__ void GDAContext::internal_direct_barrier_wg(int pe, int PE_start,
   }
 }
 
-__device__ void GDAContext::internal_direct_barrier_wg_thread_puts(int pe, int PE_start,
-                                                                   int stride, int n_pes,
-                                                                   int64_t *pSync) {
-  int64_t flag_val{1};
-
-  if (pe == PE_start) {
-    int tid = get_flat_block_id();
-    int step_size = min(get_flat_block_size(), WF_SIZE);
-
-    // Go through all PE offsets (except current offset = 0)
-    // and wait until they all reach
-    for (int j = tid + 1; j < n_pes; j+= step_size) {
-      wait_until(&pSync[j], ROCSHMEM_CMP_EQ, flag_val);
-      pSync[j] = ROCSHMEM_SYNC_VALUE;
-    }
-
-    __syncthreads();
-
-    // Announce to other PEs that all have reached
-    for (int i = tid + 1, j = PE_start + stride + tid;
-         i < n_pes;
-         i+= step_size, j += (step_size * stride)) {
-      uint64_t L_offset = reinterpret_cast(&pSync[0]) - base_heap[my_pe];
-      qps[j].put_nbi_single(base_heap[j] + L_offset, &flag_val, sizeof(long), j);
-    }
-
-    for (int i = tid + 1, j = PE_start + stride + tid;
-         i < n_pes;
-         i+= step_size, j += (step_size * stride)) {
-      pe_quiet_single(j);
-    }
-
-    __syncthreads();
-
-    if (is_thread_zero_in_block()) {
-      pSync[0] = ROCSHMEM_SYNC_VALUE;
-    }
-  } else {
-    if (is_thread_zero_in_block()) {
-      // Mark current PE offset as reached
-      size_t pe_offset = (pe - PE_start) / stride;
-      putmem(&pSync[pe_offset], &flag_val, sizeof(long), PE_start);
-      wait_until(&pSync[0], ROCSHMEM_CMP_EQ, flag_val);
-      pSync[0] = ROCSHMEM_SYNC_VALUE;
-      __threadfence_system();
-    }
-  }
-}
-
 __device__ void GDAContext::internal_atomic_barrier(int pe, int PE_start,
                                                     int stride, int n_pes,
                                                     int64_t *pSync) {
diff --git a/projects/rocshmem/src/gda/context_gda_tmpl_device.hpp b/projects/rocshmem/src/gda/context_gda_tmpl_device.hpp
index 2c48cd5f8b..bf6436ad2d 100644
--- a/projects/rocshmem/src/gda/context_gda_tmpl_device.hpp
+++ b/projects/rocshmem/src/gda/context_gda_tmpl_device.hpp
@@ -651,6 +651,7 @@ __device__ void GDAContext::alltoall_linear_thread_puts(rocshmem_team_t team, T
   int stride = team_obj->tinfo_wrt_world->stride;
   long *pSync = team_obj->alltoall_pSync;
   int my_pe_in_team = team_obj->my_pe;
+  uint64_t alltoall_pSync_offset = (team_obj->alltoall_sequence_number % 2) * pe_size;
 
   int tid = get_flat_block_id();
   int step_size = min(get_flat_block_size(), WF_SIZE);
@@ -658,17 +659,31 @@ __device__ void GDAContext::alltoall_linear_thread_puts(rocshmem_team_t team, T
   // Have each PE put their designated data to the other PEs
   for (int j = tid; j < pe_size; j+= step_size) {
     int dest_pe = team_obj->get_pe_in_world(j);
-    uint64_t L_offset = reinterpret_cast(&dst[my_pe_in_team * nelems]) - base_heap[my_pe];
-    qps[dest_pe].put_nbi_single(base_heap[dest_pe] + L_offset, &src[j * nelems], nelems * sizeof(T), dest_pe);
-  }
-
-  for (int j = tid; j < pe_size; j+= step_size) {
-    int dest_pe = team_obj->get_pe_in_world(j);
-    pe_quiet_single(dest_pe);
+    uint64_t base_heap_offset = base_heap[dest_pe] - base_heap[my_pe];
+    qps[dest_pe].put_nbi_single(reinterpret_cast(&dst[my_pe_in_team * nelems]) + base_heap_offset,
+                                &src[j * nelems], nelems * sizeof(T), false);
+    qps[dest_pe].atomic_nofetch_single(reinterpret_cast(&pSync[alltoall_pSync_offset + my_pe_in_team]) + base_heap_offset,
+                                       1);
   }
 
   // wait until everyone has obtained their designated data
-  internal_direct_barrier_wg_thread_puts(my_pe, pe_start, stride, pe_size, pSync);
+  for (int j = tid; j < pe_size; j+= step_size) {
+    int dest_pe = team_obj->get_pe_in_world(j);
+
+    // Senders signal the slot of their team rank (my_pe_in_team above), so
+    // poll and reset slot j rather than the world PE number.
+    volatile long *vol_ivars = &pSync[alltoall_pSync_offset + j];
+    while (uncached_load(vol_ivars) != 1) { }
+
+    pe_quiet_single(dest_pe);
+
+    pSync[alltoall_pSync_offset + j] = ROCSHMEM_SYNC_VALUE;
+  }
+
+  if (is_thread_zero_in_block()) {
+    team_obj->alltoall_sequence_number++;
+  }
+
   __syncthreads();
 }
 
diff --git a/projects/rocshmem/src/gda/gda_team.hpp b/projects/rocshmem/src/gda/gda_team.hpp
index 4d4a4e54b0..fa60429406 100644
--- a/projects/rocshmem/src/gda/gda_team.hpp
+++ b/projects/rocshmem/src/gda/gda_team.hpp
@@ -45,6 +45,7 @@ class GDATeam : public Team {
 
   void* pAta{nullptr};
   int pool_index_{-1};
+  uint64_t alltoall_sequence_number = 0;
 };
 
 } // namespace rocshmem
diff --git a/projects/rocshmem/src/gda/queue_pair.cpp b/projects/rocshmem/src/gda/queue_pair.cpp
index 2499d6cc22..9dced0f938 100644
--- a/projects/rocshmem/src/gda/queue_pair.cpp
+++ b/projects/rocshmem/src/gda/queue_pair.cpp
@@ -170,11 +170,11 @@ __device__ void QueuePair::post_wqe_rma_mt(int pe, int32_t size, uintptr_t *ladd
   }
 }
 
-__device__ void QueuePair::post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
+__device__ void QueuePair::post_wqe_rma_single(int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, bool ring_db) {
   switch (gda_provider_) {
 #if defined(GDA_BNXT)
     case GDAProvider::BNXT:
-      return bnxt_post_wqe_rma_single(pe, size, laddr, raddr, opcode);
+      return bnxt_post_wqe_rma_single(size, laddr, raddr, opcode, ring_db);
 #endif
     case GDAProvider::IONIC:
     case GDAProvider::MLX5:
@@ -192,7 +192,7 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
 #endif
 #if defined(GDA_BNXT)
     case GDAProvider::BNXT:
-      return bnxt_post_wqe_amo(pe, size, raddr, opcode, atomic_data, atomic_cmp, fetching);
+      return bnxt_post_wqe_amo(raddr, opcode, atomic_data, atomic_cmp, fetching);
 #endif
 #if defined(GDA_IONIC)
     case GDAProvider::IONIC:
@@ -204,6 +204,23 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
   }
 }
 
+/* Provider dispatch for the per-thread atomic post; only BNXT implements it. */
+__device__ uint64_t QueuePair::post_wqe_amo_single(uintptr_t *raddr, uint8_t opcode,
+                                                   int64_t atomic_data, int64_t atomic_cmp,
+                                                   bool fetching) {
+  switch (gda_provider_) {
+#if defined(GDA_BNXT)
+    case GDAProvider::BNXT:
+      return bnxt_post_wqe_amo_single(raddr, opcode, atomic_data, atomic_cmp, fetching);
+#endif
+    case GDAProvider::MLX5:
+    case GDAProvider::IONIC:
+    default:
+      assert(false /* invalid nic provider */);
+      return 0;
+  }
+}
+
 __device__ void QueuePair::quiet(Collectivity cy) {
   switch (gda_provider_) {
 #if defined(GDA_MLX5)
@@ -253,10 +270,10 @@ __device__ void QueuePair::put_nbi(void *dest, const void *source, size_t nelems
   post_wqe_rma(pe, nelems, src, dst, gda_op_rdma_write, cy);
 }
 
-__device__ void QueuePair::put_nbi_single(void *dest, const void *source, size_t nelems, int pe) {
+__device__ void QueuePair::put_nbi_single(void *dest, const void *source, size_t nelems, bool ring_db) {
   uintptr_t *src = reinterpret_cast(const_cast(source));
   uintptr_t *dst = reinterpret_cast(dest);
-  post_wqe_rma_single(pe, nelems, src, dst, gda_op_rdma_write);
+  post_wqe_rma_single(nelems, src, dst, gda_op_rdma_write, ring_db);
 }
 
 __device__ void QueuePair::get_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy) {
@@ -285,4 +302,11 @@ __device__ void QueuePair::atomic_nofetch(void *dest, int64_t atomic_data, int64
   post_wqe_amo(pe, sizeof(int64_t), dst, gda_op_atomic_fa, atomic_data, atomic_cmp, false);
 }
 
+/* Posts a non-fetching gda_op_atomic_fa of value to dest via the per-thread path. */
+__device__ void QueuePair::atomic_nofetch_single(void *dest, int64_t value) {
+  const bool fetching = false;
+  uintptr_t *dst = static_cast(dest);
+  post_wqe_amo_single(dst, gda_op_atomic_fa, value, 0, fetching);
+}
+
 } // namespace rocshmem
diff --git a/projects/rocshmem/src/gda/queue_pair.hpp b/projects/rocshmem/src/gda/queue_pair.hpp
index 84a84a5d0d..58684977a7 100644
--- a/projects/rocshmem/src/gda/queue_pair.hpp
+++ b/projects/rocshmem/src/gda/queue_pair.hpp
@@ -77,7 +77,8 @@ class QueuePair {
    * @param[in] pe Destination processing element of data transmission.
    */
   __device__ void put_nbi(void *dest, const void *source, size_t nelems, int pe, Collectivity cy = THREAD);
-  __device__ void put_nbi_single(void *dest, const void *source, size_t nelems, int pe);
+
+  __device__ void put_nbi_single(void *dest, const void *source, size_t nelems, bool ring_db);
 
   /**
    * @brief Create and enqueue a non-blocking get work queue entry (wqe).
@@ -117,6 +118,8 @@ class QueuePair {
    */
   __device__ void atomic_nofetch(void *dest, int64_t value, int64_t cond, int pe);
 
+  __device__ void atomic_nofetch_single(void *dest, int64_t value);
+
   /**
    * @brief Create and enqueue an atomic cas work queue entry (wqe).
    *
@@ -155,6 +158,12 @@ class QueuePair {
    */
   __device__ __attribute__((noinline)) uint64_t post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch);
 
+  __device__ __attribute__((noinline)) uint64_t post_wqe_amo_single(uintptr_t *raddr,
+                                                                    uint8_t opcode,
+                                                                    int64_t atomic_data,
+                                                                    int64_t atomic_cmp,
+                                                                    bool fetching);
+
   /**
    * @brief Helper method to build work requests for the send queue.
    *
@@ -166,7 +175,8 @@ class QueuePair {
    */
   __device__ __attribute__((noinline)) void post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy);
   __device__ __attribute__((noinline)) void post_wqe_rma_turn(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, Collectivity cy);
-  __device__ __attribute__((noinline)) void post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
+
+  __device__ __attribute__((noinline)) void post_wqe_rma_single(int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, bool ring_db);
   __device__ __attribute__((noinline)) void post_wqe_rma_mt(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
 
 #if defined(GDA_MLX5)
@@ -175,9 +185,16 @@ class QueuePair {
   __device__ void mlx5_quiet();
 #endif
 #if defined(GDA_BNXT)
-  __device__ uint64_t bnxt_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch);
+
+  __device__ void bnxt_write_rma_wqe(uintptr_t *raddr, uintptr_t *laddr, int32_t length, uint8_t opcode);
+  __device__ uint32_t bnxt_write_amo_wqe(uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetching);
+
+  __device__ uint64_t bnxt_post_wqe_amo_single(uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetching);
+  __device__ uint64_t bnxt_post_wqe_amo(uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetching);
+
   __device__ void bnxt_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
-  __device__ void bnxt_post_wqe_rma_single(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
+
+  __device__ void bnxt_post_wqe_rma_single(int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode, bool ring_db);
   __device__ void bnxt_quiet();
   __device__ void bnxt_quiet_single();
 #endif