diff --git a/src/gda/mlx5/queue_pair_mlx5.cpp b/src/gda/mlx5/queue_pair_mlx5.cpp index dcbf3974f2..54a34424b2 100644 --- a/src/gda/mlx5/queue_pair_mlx5.cpp +++ b/src/gda/mlx5/queue_pair_mlx5.cpp @@ -122,41 +122,47 @@ __device__ void QueuePair::mlx5_quiet() { } } -__device__ void QueuePair::mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) { - uint64_t activemask = get_active_lane_mask(); - uint8_t num_active_lanes = get_active_lane_count(activemask); - uint8_t my_logical_lane_id = get_active_lane_num(activemask); - bool is_leader{my_logical_lane_id == 0}; - const uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask); - uint8_t num_wqes{num_active_lanes}; - uint64_t wave_sq_counter{0}; - - if (is_leader) { - wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_AGENT); - } - wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id); - uint64_t my_sq_counter = wave_sq_counter + my_logical_lane_id; - uint64_t my_sq_index = my_sq_counter % sq_wqe_cnt; - +__device__ __forceinline__ void QueuePair::mlx5_wait_for_free_sq_slots( + uint64_t wave_sq_counter, uint8_t num_active_lanes) { while (true) { - uint64_t db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - uint64_t sunk = __hip_atomic_load(&sq_sunk, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - int64_t num_active_sq_entries = db_touched - sunk; + uint64_t db_touched = __hip_atomic_load(&sq_db_touched, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + + uint64_t sunk = __hip_atomic_load(&sq_sunk, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + + int64_t num_active_sq_entries = + static_cast<int64_t>(db_touched) - + static_cast<int64_t>(sunk); + if (num_active_sq_entries < 0) { continue; } - uint64_t num_free_entries = min(sq_wqe_cnt, cq_cnt) - num_active_sq_entries; - uint64_t num_entries_until_wave_last_entry = wave_sq_counter + num_active_lanes - db_touched; + + uint64_t num_free_entries 
= + min(sq_wqe_cnt, cq_cnt) - + static_cast<uint64_t>(num_active_sq_entries); + + uint64_t num_entries_until_wave_last_entry = + wave_sq_counter + num_active_lanes - db_touched; + if (num_free_entries > num_entries_until_wave_last_entry) { break; } + mlx5_quiet(); } +} +__device__ __forceinline__ void QueuePair::mlx5_build_rma_wqe( + uint64_t my_sq_counter, uint64_t my_sq_index, uintptr_t *laddr, + uintptr_t *raddr, int32_t size, uint8_t opcode) { outstanding_wqes[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter; SegmentBuilder seg_build(my_sq_index, sq_buf); - seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num, MLX5_WQE_CTRL_CQ_UPDATE, 3, 0, 0); + + seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num, + MLX5_WQE_CTRL_CQ_UPDATE, 3, 0, 0); seg_build.update_raddr_seg(raddr, rkey); if (size <= inline_threshold && opcode == gda_op_rdma_write) { @@ -164,105 +170,163 @@ __device__ void QueuePair::mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *la } else { seg_build.update_data_seg(laddr, size, lkey); } +} + +__device__ __forceinline__ void QueuePair::mlx5_wait_for_db_touched_eq( + uint64_t target_sq_counter) { + uint64_t db_touched {0}; + do { + db_touched = __hip_atomic_load(&sq_db_touched, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + } while (db_touched != target_sq_counter); +} + +__device__ __forceinline__ void QueuePair::mlx5_ring_doorbell( + uint64_t wave_sq_counter, uint8_t num_wqes) { + mlx5_wait_for_db_touched_eq(wave_sq_counter); + + uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf); + uint64_t* ctrl_wqe_8B_for_db = + reinterpret_cast<uint64_t*>(&base_ptr[64 * + ((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]); + + mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes); + + __hip_atomic_fetch_add(&quiet_posted, num_wqes, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + + __hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ void QueuePair::mlx5_post_wqe_rma(int32_t size, 
uintptr_t *laddr, + uintptr_t *raddr, uint8_t opcode) { + uint64_t activemask = get_active_lane_mask(); + uint8_t num_active_lanes = get_active_lane_count(activemask); + uint8_t my_logical_lane_id = get_active_lane_num(activemask); + bool is_leader = {my_logical_lane_id == 0}; + uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask); + + uint8_t num_wqes = num_active_lanes; + uint64_t wave_sq_counter = 0; + uint64_t my_sq_counter = 0; + uint64_t my_sq_index = 0; + + // 1. Leader allocates SQ entries for the whole wave + if (is_leader) { + wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes, + __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_AGENT); + } + wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id); + my_sq_counter = wave_sq_counter + my_logical_lane_id; + my_sq_index = my_sq_counter % sq_wqe_cnt; + + // 2. Wait for SQ space for the whole wave + mlx5_wait_for_free_sq_slots(wave_sq_counter, num_active_lanes); + + // 3. Build the WQE for this lane + mlx5_build_rma_wqe(my_sq_counter, my_sq_index, laddr, raddr, size, opcode); __atomic_signal_fence(__ATOMIC_SEQ_CST); + // 4. 
Leader rings doorbell for the wave if (is_leader) { - uint64_t db_touched {0}; - do { - db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - } while (db_touched != wave_sq_counter); - - uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf); - uint64_t* ctrl_wqe_8B_for_db = reinterpret_cast<uint64_t*>(&base_ptr[64 * ((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]); - mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes); - - __hip_atomic_fetch_add(&quiet_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - __hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + mlx5_ring_doorbell(wave_sq_counter, num_wqes); } } -__device__ uint64_t QueuePair::mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, - int64_t atomic_data, int64_t atomic_cmp, bool fetching) { - uint64_t activemask = get_active_lane_mask(); - uint8_t num_active_lanes = get_active_lane_count(activemask); - uint8_t my_logical_lane_id = get_active_lane_num(activemask); - bool is_leader{my_logical_lane_id == 0}; - const uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask); - uint8_t num_wqes{num_active_lanes}; - uint64_t wave_sq_counter{0}; - - if (is_leader) { - wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - } - wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id); - uint64_t my_sq_counter = wave_sq_counter + my_logical_lane_id; - uint64_t my_sq_index = my_sq_counter % sq_wqe_cnt; - - while (true) { - uint64_t db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - uint64_t sunk = __hip_atomic_load(&sq_sunk, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - int64_t num_active_sq_entries = db_touched - sunk; - if (num_active_sq_entries < 0) { - continue; - } - uint64_t num_free_entries = min(sq_wqe_cnt, cq_cnt) - num_active_sq_entries; - uint64_t 
num_entries_until_wave_last_entry = wave_sq_counter + num_active_lanes - db_touched; - if (num_free_entries > num_entries_until_wave_last_entry) { - break; - } - mlx5_quiet(); - } - +__device__ __forceinline__ uint64_t* +QueuePair::mlx5_allocate_wave_fetching_atomic_buffer( + uint64_t wave_sq_counter, bool is_leader, + uint64_t leader_phys_lane_id) { uint64_t* wave_fetch_atomic{nullptr}; - if (fetching) { - if (is_leader) { - uint64_t db_touched {0}; - do { - db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - } while (db_touched != wave_sq_counter); + if (is_leader) { + mlx5_wait_for_db_touched_eq(wave_sq_counter); - auto res = fetching_atomic_freelist->pop_front(); - while (!res.success) { - res = fetching_atomic_freelist->pop_front(); - } - wave_fetch_atomic = res.value; + auto res = fetching_atomic_freelist->pop_front(); + while (!res.success) { + res = fetching_atomic_freelist->pop_front(); } - wave_fetch_atomic = (uint64_t*)__shfl((uint64_t)wave_fetch_atomic, leader_phys_lane_id); + wave_fetch_atomic = res.value; } + wave_fetch_atomic = (uint64_t*)__shfl((uint64_t)wave_fetch_atomic, + leader_phys_lane_id); + return wave_fetch_atomic; +} +__device__ __forceinline__ void QueuePair::mlx5_build_amo_wqe( + uint64_t my_sq_counter, uint64_t my_sq_index, uintptr_t *raddr, + uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetching, + uint64_t *wave_fetch_atomic) { outstanding_wqes[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter; SegmentBuilder seg_build(my_sq_index, sq_buf); - seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num, MLX5_WQE_CTRL_CQ_UPDATE, 4, 0, 0); + seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num, + MLX5_WQE_CTRL_CQ_UPDATE, 4, 0, 0); seg_build.update_raddr_seg(raddr, rkey); seg_build.update_atomic_seg(atomic_data, atomic_cmp); + if (fetching) { - seg_build.update_data_seg(wave_fetch_atomic + my_logical_lane_id, 8, fetching_atomic_lkey); + 
seg_build.update_data_seg(wave_fetch_atomic, 8, fetching_atomic_lkey); } else { seg_build.update_data_seg(nonfetching_atomic, 8, nonfetching_atomic_lkey); } - __atomic_signal_fence(__ATOMIC_SEQ_CST); +} +__device__ uint64_t QueuePair::mlx5_post_wqe_amo(int32_t size, + uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, + int64_t atomic_cmp, bool fetching) { + uint64_t activemask = get_active_lane_mask(); + uint8_t num_active_lanes = get_active_lane_count(activemask); + uint8_t my_logical_lane_id = get_active_lane_num(activemask); + bool is_leader = {my_logical_lane_id == 0}; + uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask); + + uint8_t num_wqes = num_active_lanes; + uint64_t wave_sq_counter = 0; + uint64_t my_sq_counter = 0; + uint64_t my_sq_index = 0; + + // 1. Leader allocates SQ entries for the whole wave if (is_leader) { - uint64_t db_touched {0}; - do { - db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - } while (db_touched != wave_sq_counter); + wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + } + wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id); + my_sq_counter = wave_sq_counter + my_logical_lane_id; + my_sq_index = my_sq_counter % sq_wqe_cnt; - uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf); - uint64_t* ctrl_wqe_8B_for_db = reinterpret_cast<uint64_t*>(&base_ptr[64 * ((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]); - mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes); + // 2. 
Wait for SQ space for the whole wave + mlx5_wait_for_free_sq_slots(wave_sq_counter, num_active_lanes); - __hip_atomic_fetch_add(&quiet_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - __hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + uint64_t* wave_fetch_atomic{nullptr}; + if (fetching) { + wave_fetch_atomic = mlx5_allocate_wave_fetching_atomic_buffer( + wave_sq_counter, + is_leader, + leader_phys_lane_id); } + // 3. Build the WQE for this lane + mlx5_build_amo_wqe(my_sq_counter, my_sq_index, raddr, opcode, + atomic_data, atomic_cmp, fetching, + wave_fetch_atomic + my_logical_lane_id); + + __atomic_signal_fence(__ATOMIC_SEQ_CST); + + // 4. Leader rings doorbell for the wave + if (is_leader) { + mlx5_ring_doorbell(wave_sq_counter, num_wqes); + } + + // 5. Fetch result if requested uint64_t ret{0}; if (fetching) { mlx5_quiet(); ret = wave_fetch_atomic[my_logical_lane_id]; + __atomic_signal_fence(__ATOMIC_SEQ_CST); + if (is_leader) { fetching_atomic_freelist->push_back(wave_fetch_atomic); } diff --git a/src/gda/queue_pair.cpp b/src/gda/queue_pair.cpp index 9dced0f938..2e6d2b9e0c 100644 --- a/src/gda/queue_pair.cpp +++ b/src/gda/queue_pair.cpp @@ -157,7 +157,7 @@ __device__ void QueuePair::post_wqe_rma_mt(int pe, int32_t size, uintptr_t *ladd switch (gda_provider_) { #if defined(GDA_MLX5) case GDAProvider::MLX5: - mlx5_post_wqe_rma(pe, size, laddr, raddr, opcode); + mlx5_post_wqe_rma(size, laddr, raddr, opcode); return; #endif #if defined(GDA_BNXT) @@ -188,7 +188,7 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad switch (gda_provider_) { #if defined(GDA_MLX5) case GDAProvider::MLX5: - return mlx5_post_wqe_amo(pe, size, raddr, opcode, atomic_data, atomic_cmp, fetching); + return mlx5_post_wqe_amo(size, raddr, opcode, atomic_data, atomic_cmp, fetching); #endif #if defined(GDA_BNXT) case GDAProvider::BNXT: diff --git a/src/gda/queue_pair.hpp 
b/src/gda/queue_pair.hpp index 58684977a7..94496ce261 100644 --- a/src/gda/queue_pair.hpp +++ b/src/gda/queue_pair.hpp @@ -180,9 +180,40 @@ class QueuePair { __device__ __attribute__((noinline)) void post_wqe_rma_mt(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode); #if defined(GDA_MLX5) - __device__ uint64_t mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch); - __device__ void mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode); - __device__ void mlx5_quiet(); + __device__ __forceinline__ void + mlx5_wait_for_free_sq_slots(uint64_t wave_sq_counter, + uint8_t num_active_lanes); + + __device__ __forceinline__ void + mlx5_wait_for_db_touched_eq(uint64_t target_sq_counter); + + __device__ __forceinline__ void + mlx5_build_rma_wqe(uint64_t my_sq_counter, uint64_t my_sq_index, + uintptr_t *laddr, uintptr_t *raddr, int32_t size, uint8_t opcode); + + __device__ __forceinline__ void + mlx5_build_amo_wqe(uint64_t my_sq_counter, uint64_t my_sq_index, + uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, + int64_t atomic_cmp, bool fetching, uint64_t *wave_fetch_atomic); + + __device__ __forceinline__ uint64_t* + mlx5_allocate_wave_fetching_atomic_buffer(uint64_t wave_sq_counter, + bool is_leader, uint64_t leader_phys_lane_id); + + __device__ __forceinline__ void + mlx5_ring_doorbell(uint64_t wave_sq_counter, uint8_t num_wqes); + + __device__ uint64_t + mlx5_post_wqe_amo(int32_t size, uintptr_t *raddr, uint8_t opcode, + int64_t atomic_data, int64_t atomic_cmp, bool fetch); + + __device__ void + mlx5_post_wqe_rma(int32_t size, uintptr_t *laddr, + uintptr_t *raddr, uint8_t opcode); + + __device__ void + mlx5_quiet(); + #endif #if defined(GDA_BNXT)