Refactor: modularize RMA and AMO WQE posting functions (#331)
* Refactor: modularize RMA and AMO WQE posting functions - Extract shared logic for SQ/CQ waiting, doorbell ringing, and WQE building * Remove unused variables * Update return buffer address calculation for atomics
This commit is contained in:
کامیت شده توسط
GitHub
والد
d5bcb3a201
کامیت
1acf454048
@@ -122,41 +122,47 @@ __device__ void QueuePair::mlx5_quiet() {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void QueuePair::mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
|
||||
uint64_t activemask = get_active_lane_mask();
|
||||
uint8_t num_active_lanes = get_active_lane_count(activemask);
|
||||
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
|
||||
bool is_leader{my_logical_lane_id == 0};
|
||||
const uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask);
|
||||
uint8_t num_wqes{num_active_lanes};
|
||||
uint64_t wave_sq_counter{0};
|
||||
|
||||
if (is_leader) {
|
||||
wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_AGENT);
|
||||
}
|
||||
wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id);
|
||||
uint64_t my_sq_counter = wave_sq_counter + my_logical_lane_id;
|
||||
uint64_t my_sq_index = my_sq_counter % sq_wqe_cnt;
|
||||
|
||||
__device__ __forceinline__ void QueuePair::mlx5_wait_for_free_sq_slots(
|
||||
uint64_t wave_sq_counter, uint8_t num_active_lanes) {
|
||||
while (true) {
|
||||
uint64_t db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
uint64_t sunk = __hip_atomic_load(&sq_sunk, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
int64_t num_active_sq_entries = db_touched - sunk;
|
||||
uint64_t db_touched = __hip_atomic_load(&sq_db_touched,
|
||||
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
|
||||
uint64_t sunk = __hip_atomic_load(&sq_sunk,
|
||||
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
|
||||
int64_t num_active_sq_entries =
|
||||
static_cast<int64_t>(db_touched) -
|
||||
static_cast<int64_t>(sunk);
|
||||
|
||||
if (num_active_sq_entries < 0) {
|
||||
continue;
|
||||
}
|
||||
uint64_t num_free_entries = min(sq_wqe_cnt, cq_cnt) - num_active_sq_entries;
|
||||
uint64_t num_entries_until_wave_last_entry = wave_sq_counter + num_active_lanes - db_touched;
|
||||
|
||||
uint64_t num_free_entries =
|
||||
min(sq_wqe_cnt, cq_cnt) -
|
||||
static_cast<uint64_t>(num_active_sq_entries);
|
||||
|
||||
uint64_t num_entries_until_wave_last_entry =
|
||||
wave_sq_counter + num_active_lanes - db_touched;
|
||||
|
||||
if (num_free_entries > num_entries_until_wave_last_entry) {
|
||||
break;
|
||||
}
|
||||
|
||||
mlx5_quiet();
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void QueuePair::mlx5_build_rma_wqe(
|
||||
uint64_t my_sq_counter, uint64_t my_sq_index, uintptr_t *laddr,
|
||||
uintptr_t *raddr, int32_t size, uint8_t opcode) {
|
||||
outstanding_wqes[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter;
|
||||
|
||||
SegmentBuilder seg_build(my_sq_index, sq_buf);
|
||||
seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num, MLX5_WQE_CTRL_CQ_UPDATE, 3, 0, 0);
|
||||
|
||||
seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num,
|
||||
MLX5_WQE_CTRL_CQ_UPDATE, 3, 0, 0);
|
||||
seg_build.update_raddr_seg(raddr, rkey);
|
||||
|
||||
if (size <= inline_threshold && opcode == gda_op_rdma_write) {
|
||||
@@ -164,105 +170,163 @@ __device__ void QueuePair::mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *la
|
||||
} else {
|
||||
seg_build.update_data_seg(laddr, size, lkey);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void QueuePair::mlx5_wait_for_db_touched_eq(
|
||||
uint64_t target_sq_counter) {
|
||||
uint64_t db_touched {0};
|
||||
do {
|
||||
db_touched = __hip_atomic_load(&sq_db_touched,
|
||||
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
} while (db_touched != target_sq_counter);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void QueuePair::mlx5_ring_doorbell(
|
||||
uint64_t wave_sq_counter, uint8_t num_wqes) {
|
||||
mlx5_wait_for_db_touched_eq(wave_sq_counter);
|
||||
|
||||
uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf);
|
||||
uint64_t* ctrl_wqe_8B_for_db =
|
||||
reinterpret_cast<uint64_t*>(&base_ptr[64 *
|
||||
((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]);
|
||||
|
||||
mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
|
||||
|
||||
__hip_atomic_fetch_add(&quiet_posted, num_wqes,
|
||||
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
|
||||
__hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes,
|
||||
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
}
|
||||
|
||||
__device__ void QueuePair::mlx5_post_wqe_rma(int32_t size, uintptr_t *laddr,
|
||||
uintptr_t *raddr, uint8_t opcode) {
|
||||
uint64_t activemask = get_active_lane_mask();
|
||||
uint8_t num_active_lanes = get_active_lane_count(activemask);
|
||||
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
|
||||
bool is_leader = {my_logical_lane_id == 0};
|
||||
uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask);
|
||||
|
||||
uint8_t num_wqes = num_active_lanes;
|
||||
uint64_t wave_sq_counter = 0;
|
||||
uint64_t my_sq_counter = 0;
|
||||
uint64_t my_sq_index = 0;
|
||||
|
||||
// 1. Leader allocates SQ entries for the whole wave
|
||||
if (is_leader) {
|
||||
wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes,
|
||||
__ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_AGENT);
|
||||
}
|
||||
wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id);
|
||||
my_sq_counter = wave_sq_counter + my_logical_lane_id;
|
||||
my_sq_index = my_sq_counter % sq_wqe_cnt;
|
||||
|
||||
// 2. Wait for SQ space for the whole wave
|
||||
mlx5_wait_for_free_sq_slots(wave_sq_counter, num_active_lanes);
|
||||
|
||||
// 3. Build the WQE for this lane
|
||||
mlx5_build_rma_wqe(my_sq_counter, my_sq_index, laddr, raddr, size, opcode);
|
||||
|
||||
__atomic_signal_fence(__ATOMIC_SEQ_CST);
|
||||
|
||||
// 4. Leader rings doorbell for the wave
|
||||
if (is_leader) {
|
||||
uint64_t db_touched {0};
|
||||
do {
|
||||
db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
} while (db_touched != wave_sq_counter);
|
||||
|
||||
uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf);
|
||||
uint64_t* ctrl_wqe_8B_for_db = reinterpret_cast<uint64_t*>(&base_ptr[64 * ((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]);
|
||||
mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
|
||||
|
||||
__hip_atomic_fetch_add(&quiet_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
__hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
mlx5_ring_doorbell(wave_sq_counter, num_wqes);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ uint64_t QueuePair::mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode,
|
||||
int64_t atomic_data, int64_t atomic_cmp, bool fetching) {
|
||||
uint64_t activemask = get_active_lane_mask();
|
||||
uint8_t num_active_lanes = get_active_lane_count(activemask);
|
||||
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
|
||||
bool is_leader{my_logical_lane_id == 0};
|
||||
const uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask);
|
||||
uint8_t num_wqes{num_active_lanes};
|
||||
uint64_t wave_sq_counter{0};
|
||||
|
||||
if (is_leader) {
|
||||
wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
}
|
||||
wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id);
|
||||
uint64_t my_sq_counter = wave_sq_counter + my_logical_lane_id;
|
||||
uint64_t my_sq_index = my_sq_counter % sq_wqe_cnt;
|
||||
|
||||
while (true) {
|
||||
uint64_t db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
uint64_t sunk = __hip_atomic_load(&sq_sunk, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
int64_t num_active_sq_entries = db_touched - sunk;
|
||||
if (num_active_sq_entries < 0) {
|
||||
continue;
|
||||
}
|
||||
uint64_t num_free_entries = min(sq_wqe_cnt, cq_cnt) - num_active_sq_entries;
|
||||
uint64_t num_entries_until_wave_last_entry = wave_sq_counter + num_active_lanes - db_touched;
|
||||
if (num_free_entries > num_entries_until_wave_last_entry) {
|
||||
break;
|
||||
}
|
||||
mlx5_quiet();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint64_t*
|
||||
QueuePair::mlx5_allocate_wave_fetching_atomic_buffer(
|
||||
uint64_t wave_sq_counter, bool is_leader,
|
||||
uint64_t leader_phys_lane_id) {
|
||||
uint64_t* wave_fetch_atomic{nullptr};
|
||||
if (fetching) {
|
||||
if (is_leader) {
|
||||
uint64_t db_touched {0};
|
||||
do {
|
||||
db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
} while (db_touched != wave_sq_counter);
|
||||
if (is_leader) {
|
||||
mlx5_wait_for_db_touched_eq(wave_sq_counter);
|
||||
|
||||
auto res = fetching_atomic_freelist->pop_front();
|
||||
while (!res.success) {
|
||||
res = fetching_atomic_freelist->pop_front();
|
||||
}
|
||||
wave_fetch_atomic = res.value;
|
||||
auto res = fetching_atomic_freelist->pop_front();
|
||||
while (!res.success) {
|
||||
res = fetching_atomic_freelist->pop_front();
|
||||
}
|
||||
wave_fetch_atomic = (uint64_t*)__shfl((uint64_t)wave_fetch_atomic, leader_phys_lane_id);
|
||||
wave_fetch_atomic = res.value;
|
||||
}
|
||||
wave_fetch_atomic = (uint64_t*)__shfl((uint64_t)wave_fetch_atomic,
|
||||
leader_phys_lane_id);
|
||||
return wave_fetch_atomic;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void QueuePair::mlx5_build_amo_wqe(
|
||||
uint64_t my_sq_counter, uint64_t my_sq_index, uintptr_t *raddr,
|
||||
uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetching,
|
||||
uint64_t *wave_fetch_atomic) {
|
||||
outstanding_wqes[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter;
|
||||
|
||||
SegmentBuilder seg_build(my_sq_index, sq_buf);
|
||||
seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num, MLX5_WQE_CTRL_CQ_UPDATE, 4, 0, 0);
|
||||
seg_build.update_ctrl_seg(my_sq_counter, opcode, 0, qp_num,
|
||||
MLX5_WQE_CTRL_CQ_UPDATE, 4, 0, 0);
|
||||
seg_build.update_raddr_seg(raddr, rkey);
|
||||
seg_build.update_atomic_seg(atomic_data, atomic_cmp);
|
||||
|
||||
if (fetching) {
|
||||
seg_build.update_data_seg(wave_fetch_atomic + my_logical_lane_id, 8, fetching_atomic_lkey);
|
||||
seg_build.update_data_seg(wave_fetch_atomic, 8, fetching_atomic_lkey);
|
||||
} else {
|
||||
seg_build.update_data_seg(nonfetching_atomic, 8, nonfetching_atomic_lkey);
|
||||
}
|
||||
__atomic_signal_fence(__ATOMIC_SEQ_CST);
|
||||
}
|
||||
|
||||
__device__ uint64_t QueuePair::mlx5_post_wqe_amo(int32_t size,
|
||||
uintptr_t *raddr, uint8_t opcode, int64_t atomic_data,
|
||||
int64_t atomic_cmp, bool fetching) {
|
||||
uint64_t activemask = get_active_lane_mask();
|
||||
uint8_t num_active_lanes = get_active_lane_count(activemask);
|
||||
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
|
||||
bool is_leader = {my_logical_lane_id == 0};
|
||||
uint64_t leader_phys_lane_id = get_first_active_lane_id(activemask);
|
||||
|
||||
uint8_t num_wqes = num_active_lanes;
|
||||
uint64_t wave_sq_counter = 0;
|
||||
uint64_t my_sq_counter = 0;
|
||||
uint64_t my_sq_index = 0;
|
||||
|
||||
// 1. Leader allocates SQ entries for the whole wave
|
||||
if (is_leader) {
|
||||
uint64_t db_touched {0};
|
||||
do {
|
||||
db_touched = __hip_atomic_load(&sq_db_touched, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
} while (db_touched != wave_sq_counter);
|
||||
wave_sq_counter = __hip_atomic_fetch_add(&sq_posted, num_wqes,
|
||||
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
}
|
||||
wave_sq_counter = __shfl(wave_sq_counter, leader_phys_lane_id);
|
||||
my_sq_counter = wave_sq_counter + my_logical_lane_id;
|
||||
my_sq_index = my_sq_counter % sq_wqe_cnt;
|
||||
|
||||
uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf);
|
||||
uint64_t* ctrl_wqe_8B_for_db = reinterpret_cast<uint64_t*>(&base_ptr[64 * ((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]);
|
||||
mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
|
||||
// 2. Wait for SQ space for the whole wave
|
||||
mlx5_wait_for_free_sq_slots(wave_sq_counter, num_active_lanes);
|
||||
|
||||
__hip_atomic_fetch_add(&quiet_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
__hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
uint64_t* wave_fetch_atomic{nullptr};
|
||||
if (fetching) {
|
||||
wave_fetch_atomic = mlx5_allocate_wave_fetching_atomic_buffer(
|
||||
wave_sq_counter,
|
||||
is_leader,
|
||||
leader_phys_lane_id);
|
||||
}
|
||||
|
||||
// 3. Build the WQE for this lane
|
||||
mlx5_build_amo_wqe(my_sq_counter, my_sq_index, raddr, opcode,
|
||||
atomic_data, atomic_cmp, fetching,
|
||||
wave_fetch_atomic + my_logical_lane_id);
|
||||
|
||||
__atomic_signal_fence(__ATOMIC_SEQ_CST);
|
||||
|
||||
// 4. Leader rings doorbell for the wave
|
||||
if (is_leader) {
|
||||
mlx5_ring_doorbell(wave_sq_counter, num_wqes);
|
||||
}
|
||||
|
||||
// 5. Fetch result if requested
|
||||
uint64_t ret{0};
|
||||
if (fetching) {
|
||||
mlx5_quiet();
|
||||
ret = wave_fetch_atomic[my_logical_lane_id];
|
||||
|
||||
__atomic_signal_fence(__ATOMIC_SEQ_CST);
|
||||
|
||||
if (is_leader) {
|
||||
fetching_atomic_freelist->push_back(wave_fetch_atomic);
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ __device__ void QueuePair::post_wqe_rma_mt(int pe, int32_t size, uintptr_t *ladd
|
||||
switch (gda_provider_) {
|
||||
#if defined(GDA_MLX5)
|
||||
case GDAProvider::MLX5:
|
||||
mlx5_post_wqe_rma(pe, size, laddr, raddr, opcode);
|
||||
mlx5_post_wqe_rma(size, laddr, raddr, opcode);
|
||||
return;
|
||||
#endif
|
||||
#if defined(GDA_BNXT)
|
||||
@@ -188,7 +188,7 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
|
||||
switch (gda_provider_) {
|
||||
#if defined(GDA_MLX5)
|
||||
case GDAProvider::MLX5:
|
||||
return mlx5_post_wqe_amo(pe, size, raddr, opcode, atomic_data, atomic_cmp, fetching);
|
||||
return mlx5_post_wqe_amo(size, raddr, opcode, atomic_data, atomic_cmp, fetching);
|
||||
#endif
|
||||
#if defined(GDA_BNXT)
|
||||
case GDAProvider::BNXT:
|
||||
|
||||
@@ -180,9 +180,40 @@ class QueuePair {
|
||||
__device__ __attribute__((noinline)) void post_wqe_rma_mt(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
|
||||
|
||||
#if defined(GDA_MLX5)
|
||||
__device__ uint64_t mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode, int64_t atomic_data, int64_t atomic_cmp, bool fetch);
|
||||
__device__ void mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode);
|
||||
__device__ void mlx5_quiet();
|
||||
__device__ __forceinline__ void
|
||||
mlx5_wait_for_free_sq_slots(uint64_t wave_sq_counter,
|
||||
uint8_t num_active_lanes);
|
||||
|
||||
__device__ __forceinline__ void
|
||||
mlx5_wait_for_db_touched_eq(uint64_t target_sq_counter);
|
||||
|
||||
__device__ __forceinline__ void
|
||||
mlx5_build_rma_wqe(uint64_t my_sq_counter, uint64_t my_sq_index,
|
||||
uintptr_t *laddr, uintptr_t *raddr, int32_t size, uint8_t opcode);
|
||||
|
||||
__device__ __forceinline__ void
|
||||
mlx5_build_amo_wqe(uint64_t my_sq_counter, uint64_t my_sq_index,
|
||||
uintptr_t *raddr, uint8_t opcode, int64_t atomic_data,
|
||||
int64_t atomic_cmp, bool fetching, uint64_t *wave_fetch_atomic);
|
||||
|
||||
__device__ __forceinline__ uint64_t*
|
||||
mlx5_allocate_wave_fetching_atomic_buffer(uint64_t wave_sq_counter,
|
||||
bool is_leader, uint64_t leader_phys_lane_id);
|
||||
|
||||
__device__ __forceinline__ void
|
||||
mlx5_ring_doorbell(uint64_t wave_sq_counter, uint8_t num_wqes);
|
||||
|
||||
__device__ uint64_t
|
||||
mlx5_post_wqe_amo(int32_t size, uintptr_t *raddr, uint8_t opcode,
|
||||
int64_t atomic_data, int64_t atomic_cmp, bool fetch);
|
||||
|
||||
__device__ void
|
||||
mlx5_post_wqe_rma(int32_t size, uintptr_t *laddr,
|
||||
uintptr_t *raddr, uint8_t opcode);
|
||||
|
||||
__device__ void
|
||||
mlx5_quiet();
|
||||
|
||||
#endif
|
||||
#if defined(GDA_BNXT)
|
||||
|
||||
|
||||
مرجع در شماره جدید
Block a user