Refactor: modularize RMA and AMO WQE posting functions (#331)

* Refactor: modularize RMA and AMO WQE posting functions
  - Extract shared logic for SQ/CQ waiting, doorbell ringing, and WQE building
* Remove unused variables
* Update return buffer address calculation for atomics
This commit is contained in:
Avinash Kethineedi
2025-12-08 14:54:41 -06:00
کامیت شده توسط GitHub
والد d5bcb3a201
کامیت 1acf454048
3فایلهای تغییر یافته به همراه189 افزوده شده و 94 حذف شده
+153 -89
مشاهده پرونده
@@ -122,41 +122,47 @@ __device__ void QueuePair::mlx5_quiet() {
}
}
__device__ void QueuePair::mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
uint64_t activemask = get_active_lane_mask();
uint8_t num_active_lanes = get_active_lane_count(activemask);
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
bool is_leader{my_logical_lane_id == 0};
const uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask);
uint8_t num_wqes{num_active_lanes};
uint64_t wave_sq_counter{0};
if (is_leader) {
wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_AGENT);
}
wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id);
uint64_t my_sq_counter = wave_sq_counter + my_logical_lane_id;
uint64_t my_sq_index = my_sq_counter % sq_wqe_cnt;
__device__ __forceinline__ void QueuePair::mlx5_wait_for_free_sq_slots(
uint64_t wave_sq_counter, uint8_t num_active_lanes) {
while (true) {
uint64_t db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
uint64_t sunk = __hip_atomic_load(&sq_sunk, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
int64_t num_active_sq_entries = db_touched - sunk;
uint64_t db_touched = __hip_atomic_load(&sq_db_touched,
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
uint64_t sunk = __hip_atomic_load(&sq_sunk,
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
int64_t num_active_sq_entries =
static_cast<int64_t>(db_touched) -
static_cast<int64_t>(sunk);
if (num_active_sq_entries < 0) {
continue;
}
uint64_t num_free_entries = min(sq_wqe_cnt, cq_cnt) - num_active_sq_entries;
uint64_t num_entries_until_wave_last_entry = wave_sq_counter + num_active_lanes - db_touched;
uint64_t num_free_entries =
min(sq_wqe_cnt, cq_cnt) -
static_cast<uint64_t>(num_active_sq_entries);
uint64_t num_entries_until_wave_last_entry =
wave_sq_counter + num_active_lanes - db_touched;
if (num_free_entries > num_entries_until_wave_last_entry) {
break;
}
mlx5_quiet();
}
}
__device__ __forceinline__ void QueuePair::mlx5_build_rma_wqe(
uint64_t my_sq_counter, uint64_t my_sq_index, uintptr_t *laddr,
uintptr_t *raddr, int32_t size, uint8_t opcode) {
outstanding_wqes[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter;
SegmentBuilder seg_build(my_sq_index, sq_buf);
seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num, MLX5_WQE_CTRL_CQ_UPDATE, 3, 0, 0);
seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num,
MLX5_WQE_CTRL_CQ_UPDATE, 3, 0, 0);
seg_build.update_raddr_seg(raddr, rkey);
if (size <= inline_threshold && opcode == gda_op_rdma_write) {
@@ -164,105 +170,163 @@ __device__ void QueuePair::mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *la
} else {
seg_build.update_data_seg(laddr, size, lkey);
}
}
__device__ __forceinline__ void QueuePair::mlx5_wait_for_db_touched_eq(
uint64_t target_sq_counter) {
uint64_t db_touched {0};
do {
db_touched = __hip_atomic_load(&sq_db_touched,
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
} while (db_touched != target_sq_counter);
}
__device__ __forceinline__ void QueuePair::mlx5_ring_doorbell(
uint64_t wave_sq_counter, uint8_t num_wqes) {
mlx5_wait_for_db_touched_eq(wave_sq_counter);
uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf);
uint64_t* ctrl_wqe_8B_for_db =
reinterpret_cast<uint64_t*>(&base_ptr[64 *
((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]);
mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
__hip_atomic_fetch_add(&quiet_posted, num_wqes,
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
__hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes,
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
__device__ void QueuePair::mlx5_post_wqe_rma(int32_t size, uintptr_t *laddr,
uintptr_t *raddr, uint8_t opcode) {
uint64_t activemask = get_active_lane_mask();
uint8_t num_active_lanes = get_active_lane_count(activemask);
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
bool is_leader = {my_logical_lane_id == 0};
uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask);
uint8_t num_wqes = num_active_lanes;
uint64_t wave_sq_counter = 0;
uint64_t my_sq_counter = 0;
uint64_t my_sq_index = 0;
// 1. Leader allocates SQ entries for the whole wave
if (is_leader) {
wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes,
__ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_AGENT);
}
wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id);
my_sq_counter = wave_sq_counter + my_logical_lane_id;
my_sq_index = my_sq_counter % sq_wqe_cnt;
// 2. Wait for SQ space for the whole wave
mlx5_wait_for_free_sq_slots(wave_sq_counter, num_active_lanes);
// 3. Build the WQE for this lane
mlx5_build_rma_wqe(my_sq_counter, my_sq_index, laddr, raddr, size, opcode);
__atomic_signal_fence(__ATOMIC_SEQ_CST);
// 4. Leader rings doorbell for the wave
if (is_leader) {
uint64_t db_touched {0};
do {
db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
} while (db_touched != wave_sq_counter);
uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf);
uint64_t* ctrl_wqe_8B_for_db = reinterpret_cast<uint64_t*>(&base_ptr[64 * ((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]);
mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
__hip_atomic_fetch_add(&quiet_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
__hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
mlx5_ring_doorbell(wave_sq_counter, num_wqes);
}
}
__device__ uint64_t QueuePair::mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode,
int64_t atomic_data, int64_t atomic_cmp, bool fetching) {
uint64_t activemask = get_active_lane_mask();
uint8_t num_active_lanes = get_active_lane_count(activemask);
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
bool is_leader{my_logical_lane_id == 0};
const uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask);
uint8_t num_wqes{num_active_lanes};
uint64_t wave_sq_counter{0};
if (is_leader) {
wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id);
uint64_t my_sq_counter = wave_sq_counter + my_logical_lane_id;
uint64_t my_sq_index = my_sq_counter % sq_wqe_cnt;
while (true) {
uint64_t db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
uint64_t sunk = __hip_atomic_load(&sq_sunk, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
int64_t num_active_sq_entries = db_touched - sunk;
if (num_active_sq_entries < 0) {
continue;
}
uint64_t num_free_entries = min(sq_wqe_cnt, cq_cnt) - num_active_sq_entries;
uint64_t num_entries_until_wave_last_entry = wave_sq_counter + num_active_lanes - db_touched;
if (num_free_entries > num_entries_until_wave_last_entry) {
break;
}
mlx5_quiet();
}
__device__ __forceinline__ uint64_t*
QueuePair::mlx5_allocate_wave_fetching_atomic_buffer(
uint64_t wave_sq_counter, bool is_leader,
uint64_t leader_phys_lane_id) {
uint64_t* wave_fetch_atomic{nullptr};
if (fetching) {
if (is_leader) {
uint64_t db_touched {0};
do {
db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
} while (db_touched != wave_sq_counter);
if (is_leader) {
mlx5_wait_for_db_touched_eq(wave_sq_counter);
auto res = fetching_atomic_freelist->pop_front();
while (!res.success) {
res = fetching_atomic_freelist->pop_front();
}
wave_fetch_atomic = res.value;
auto res = fetching_atomic_freelist->pop_front();
while (!res.success) {
res = fetching_atomic_freelist->pop_front();
}
wave_fetch_atomic = (uint64_t*)__shfl((uint64_t)wave_fetch_atomic, leader_phys_lane_id);
wave_fetch_atomic = res.value;
}
wave_fetch_atomic = (uint64_t*)__shfl((uint64_t)wave_fetch_atomic,
leader_phys_lane_id);
return wave_fetch_atomic;
}
__device__ __forceinline__ void QueuePair::mlx5_build_amo_wqe(
uint64_t my_sq_counter, uint64_t my_sq_index, uintptr_t *raddr,
uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetching,
uint64_t *wave_fetch_atomic) {
outstanding_wqes[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter;
SegmentBuilder seg_build(my_sq_index, sq_buf);
seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num, MLX5_WQE_CTRL_CQ_UPDATE, 4, 0, 0);
seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num,
MLX5_WQE_CTRL_CQ_UPDATE, 4, 0, 0);
seg_build.update_raddr_seg(raddr, rkey);
seg_build.update_atomic_seg(atomic_data, atomic_cmp);
if (fetching) {
seg_build.update_data_seg(wave_fetch_atomic + my_logical_lane_id, 8, fetching_atomic_lkey);
seg_build.update_data_seg(wave_fetch_atomic, 8, fetching_atomic_lkey);
} else {
seg_build.update_data_seg(nonfetching_atomic, 8, nonfetching_atomic_lkey);
}
__atomic_signal_fence(__ATOMIC_SEQ_CST);
}
__device__ uint64_t QueuePair::mlx5_post_wqe_amo(int32_t size,
uintptr_t *raddr, uint8_t opcode, int64_t atomic_data,
int64_t atomic_cmp, bool fetching) {
uint64_t activemask = get_active_lane_mask();
uint8_t num_active_lanes = get_active_lane_count(activemask);
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
bool is_leader = {my_logical_lane_id == 0};
uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask);
uint8_t num_wqes = num_active_lanes;
uint64_t wave_sq_counter = 0;
uint64_t my_sq_counter = 0;
uint64_t my_sq_index = 0;
// 1. Leader allocates SQ entries for the whole wave
if (is_leader) {
uint64_t db_touched {0};
do {
db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
} while (db_touched != wave_sq_counter);
wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes,
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id);
my_sq_counter = wave_sq_counter + my_logical_lane_id;
my_sq_index = my_sq_counter % sq_wqe_cnt;
uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf);
uint64_t* ctrl_wqe_8B_for_db = reinterpret_cast<uint64_t*>(&base_ptr[64 * ((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]);
mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
// 2. Wait for SQ space for the whole wave
mlx5_wait_for_free_sq_slots(wave_sq_counter, num_active_lanes);
__hip_atomic_fetch_add(&quiet_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
__hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
uint64_t* wave_fetch_atomic{nullptr};
if (fetching) {
wave_fetch_atomic = mlx5_allocate_wave_fetching_atomic_buffer(
wave_sq_counter,
is_leader,
leader_phys_lane_id);
}
// 3. Build the WQE for this lane
mlx5_build_amo_wqe(my_sq_counter, my_sq_index, raddr, opcode,
atomic_data, atomic_cmp, fetching,
wave_fetch_atomic + my_logical_lane_id);
__atomic_signal_fence(__ATOMIC_SEQ_CST);
// 4. Leader rings doorbell for the wave
if (is_leader) {
mlx5_ring_doorbell(wave_sq_counter, num_wqes);
}
// 5. Fetch result if requested
uint64_t ret{0};
if (fetching) {
mlx5_quiet();
ret = wave_fetch_atomic[my_logical_lane_id];
__atomic_signal_fence(__ATOMIC_SEQ_CST);
if (is_leader) {
fetching_atomic_freelist->push_back(wave_fetch_atomic);
}
+2 -2
مشاهده پرونده
@@ -157,7 +157,7 @@ __device__ void QueuePair::post_wqe_rma_mt(int pe, int32_t size, uintptr_t *ladd
switch (gda_provider_) {
#if defined(GDA_MLX5)
case GDAProvider::MLX5:
mlx5_post_wqe_rma(pe, size, laddr, raddr, opcode);
mlx5_post_wqe_rma(size, laddr, raddr, opcode);
return;
#endif
#if defined(GDA_BNXT)
@@ -188,7 +188,7 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
switch (gda_provider_) {
#if defined(GDA_MLX5)
case GDAProvider::MLX5:
return mlx5_post_wqe_amo(pe, size, raddr, opcode, atomic_data, atomic_cmp, fetching);
return mlx5_post_wqe_amo(size, raddr, opcode, atomic_data, atomic_cmp, fetching);
#endif
#if defined(GDA_BNXT)
case GDAProvider::BNXT:
+34 -3
مشاهده پرونده
@@ -180,9 +180,40 @@ class QueuePair {
__device__ __attribute__((noinline)) void post_wqe_rma_mt(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
#if defined(GDA_MLX5)
__device__ uint64_t mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch);
__device__ void mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
__device__ void mlx5_quiet();
__device__ __forceinline__ void
mlx5_wait_for_free_sq_slots(uint64_t wave_sq_counter,
uint8_t num_active_lanes);
__device__ __forceinline__ void
mlx5_wait_for_db_touched_eq(uint64_t target_sq_counter);
__device__ __forceinline__ void
mlx5_build_rma_wqe(uint64_t my_sq_counter, uint64_t my_sq_index,
uintptr_t *laddr, uintptr_t *raddr, int32_t size, uint8_t opcode);
__device__ __forceinline__ void
mlx5_build_amo_wqe(uint64_t my_sq_counter, uint64_t my_sq_index,
uintptr_t *raddr, uint8_t opcode, int64_t atomic_data,
int64_t atomic_cmp, bool fetching, uint64_t *wave_fetch_atomic);
__device__ __forceinline__ uint64_t*
mlx5_allocate_wave_fetching_atomic_buffer(uint64_t wave_sq_counter,
bool is_leader, uint64_t leader_phys_lane_id);
__device__ __forceinline__ void
mlx5_ring_doorbell(uint64_t wave_sq_counter, uint8_t num_wqes);
__device__ uint64_t
mlx5_post_wqe_amo(int32_t size, uintptr_t *raddr, uint8_t opcode,
int64_t atomic_data, int64_t atomic_cmp, bool fetch);
__device__ void
mlx5_post_wqe_rma(int32_t size, uintptr_t *laddr,
uintptr_t *raddr, uint8_t opcode);
__device__ void
mlx5_quiet();
#endif
#if defined(GDA_BNXT)