Select device NIC vendor code at runtime (#263)

* Runtime selection of device implementation for post_wqe, quiet,
ring_doorbell

* Normalize function naming
Esse commit está contido em:
Aurelien Bouteiller
2025-09-26 00:27:41 -04:00
commit de GitHub
commit 16a4f10203
3 arquivos alterados com 128 adições e 36 exclusões
+92 -26
Ver Arquivo
@@ -67,7 +67,7 @@ QueuePair::QueuePair(struct ibv_pd* pd, int gda_vendor) {
}
/* Set Correct opcodes for each NIC */
#ifdef GDA_IONIC
#if defined(GDA_IONIC)
gda_op_rdma_write = IONIC_V2_OP_RDMA_WRITE;
gda_op_atomic_fa = IONIC_V2_OP_ATOMIC_FA;
gda_op_atomic_cs = IONIC_V2_OP_ATOMIC_CS;
@@ -83,6 +83,7 @@ QueuePair::QueuePair(struct ibv_pd* pd, int gda_vendor) {
gda_op_atomic_fa = MLX5_OPCODE_ATOMIC_FA;
gda_op_atomic_cs = MLX5_OPCODE_ATOMIC_CS;
}
gda_vendor_ = gda_vendor;
}
QueuePair::~QueuePair() {
@@ -104,7 +105,7 @@ QueuePair::~QueuePair() {
/******************************************************************************
************************ PROVIDER-SPECIFIC HELPERS ***************************
*****************************************************************************/
#ifdef GDA_IONIC
#if defined(GDA_IONIC)
__device__ uint64_t QueuePair::get_same_qp_lane_mask() {
uint64_t lane_mask = get_active_lane_mask();
uintptr_t this_val = reinterpret_cast<uintptr_t>(this);
@@ -150,7 +151,7 @@ __device__ uint32_t QueuePair::reserve_sq(uint64_t activemask, uint32_t num_wqes
my_sq_prod = __shfl(my_sq_prod, get_first_active_lane_id(activemask));
// wait for that space to be available
quiet_internal(activemask, my_sq_prod + num_wqes - sq_mask);
ionic_quiet_internal(activemask, my_sq_prod + num_wqes - sq_mask);
return my_sq_prod;
}
@@ -166,7 +167,7 @@ __device__ uint32_t QueuePair::commit_sq(bool last, uint32_t my_sq_prod, uint32_
// spin
}
ring_doorbell(dbprod);
ionic_ring_doorbell(dbprod);
__hip_atomic_exchange(&sq_dbprod, dbprod, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
}
@@ -207,7 +208,7 @@ __device__ void QueuePair::poll_wave_cqes(uint64_t activemask) {
/* Report if the completion indicates an error. */
if (!!(qtf_be & swap_endian_val<uint32_t>(IONIC_V1_CQE_ERROR))) {
#ifdef DEBUG
#if defined(DEBUG)
uint32_t qtf = swap_endian_val<uint32_t>(qtf_be);
uint32_t qid = qtf >> IONIC_V1_CQE_QID_SHIFT;
uint32_t type = (qtf >> IONIC_V1_CQE_TYPE_SHIFT) & IONIC_V1_CQE_TYPE_MASK;
@@ -245,7 +246,7 @@ __device__ void QueuePair::poll_wave_cqes(uint64_t activemask) {
sq_msn = msn;
}
__device__ void QueuePair::quiet_internal(uint64_t activemask, uint32_t cons) {
__device__ void QueuePair::ionic_quiet_internal(uint64_t activemask, uint32_t cons) {
/* wait for sq_msn to catch up or pass cons. */
/* 0x800000 - sign bit for 24-bit fields */
while ((sq_msn - cons) & 0x800000) {
@@ -264,8 +265,8 @@ __device__ void QueuePair::quiet_internal(uint64_t activemask, uint32_t cons) {
}
#endif // GDA_IONIC
#ifdef GDA_IONIC
__device__ void QueuePair::ring_doorbell(uint32_t pos) {
#if defined(GDA_IONIC)
__device__ void QueuePair::ionic_ring_doorbell(uint32_t pos) {
// TODO When threads write at once to the same address, not all writes reach the bus.
for (int i = 0; i < 64; ++i) {
if (__lane_id() == i) {
@@ -278,7 +279,7 @@ __device__ void QueuePair::ring_doorbell(uint32_t pos) {
#endif
#if defined(GDA_MLX5)
__device__ void QueuePair::ring_doorbell(uint64_t db_val, uint64_t my_sq_counter) {
__device__ void QueuePair::mlx5_ring_doorbell(uint64_t db_val, uint64_t my_sq_counter) {
swap_endian_store(const_cast<uint32_t*>(dbrec), (uint32_t)my_sq_counter);
__atomic_signal_fence(__ATOMIC_SEQ_CST);
@@ -289,14 +290,14 @@ __device__ void QueuePair::ring_doorbell(uint64_t db_val, uint64_t my_sq_counter
}
#endif // GDA_MLX5
#ifdef GDA_IONIC
__device__ void QueuePair::quiet() {
quiet_internal(get_same_qp_lane_mask(), sq_prod);
#if defined(GDA_IONIC)
__device__ void QueuePair::ionic_quiet() {
ionic_quiet_internal(get_same_qp_lane_mask(), sq_prod);
}
#endif
#if defined(GDA_MLX5)
__device__ void QueuePair::quiet() {
__device__ void QueuePair::mlx5_quiet() {
constexpr size_t BROADCAST_SIZE = 1024 / WF_SIZE;
__shared__ uint64_t wqe_broadcast[BROADCAST_SIZE];
uint8_t wavefront_id = get_flat_block_id() / WF_SIZE;
@@ -379,8 +380,73 @@ __device__ void QueuePair::quiet() {
}
#endif // GDA_MLX5
#ifdef GDA_IONIC
__device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
switch (gda_vendor_) {
#if defined(GDA_MLX5)
case GDAVendor::MLX5:
mlx5_post_wqe_rma(pe, size, laddr, raddr, opcode);
return;
#endif
#if defined(GDA_BNXT)
case GDAVendor::BNXT:
bnxt_post_wqe_rma(pe, size, laddr, raddr, opcode);
return;
#endif
#if defined(GDA_IONIC)
case GDAVendor::IONIC:
ionic_post_wqe_rma(pe, size, laddr, raddr, opcode);
return;
#endif
default:
assert(false /* invalid nic provider */);
}
}
__device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode,
int64_t atomic_data, int64_t atomic_cmp, bool fetching) {
switch (gda_vendor_) {
#if defined(GDA_MLX5)
case GDAVendor::MLX5:
return mlx5_post_wqe_amo(pe, size, raddr, opcode, atomic_data, atomic_cmp, fetching);
#endif
#if defined(GDA_BNXT)
case GDAVendor::BNXT:
return bnxt_post_wqe_amo(pe, size, raddr, opcode, atomic_data, atomic_cmp, fetching);
#endif
#if defined(GDA_IONIC)
case GDAVendor::IONIC:
return ionic_post_wqe_amo(pe, size, raddr, opcode, atomic_data, atomic_cmp, fetching);
#endif
default:
assert(false /* invalid nic provider */);
return 0;
}
}
__device__ void QueuePair::quiet() {
switch (gda_vendor_) {
#if defined(GDA_MLX5)
case GDAVendor::MLX5:
mlx5_quiet();
return;
#endif
#if defined(GDA_BNXT)
case GDAVendor::BNXT:
bnxt_quiet();
return;
#endif
#if defined(GDA_IONIC)
case GDAVendor::IONIC:
ionic_quiet();
return;
#endif
default:
assert(false /* invalid nic provider */);
}
}
#if defined(GDA_IONIC)
__device__ void QueuePair::ionic_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
uint64_t activemask = get_same_qp_lane_mask();
uint32_t num_wqes = get_active_lane_count(activemask);
uint32_t my_logical_lane_id = get_active_lane_num(activemask);
@@ -426,7 +492,7 @@ __device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr,
#endif
#if defined (GDA_MLX5)
__device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
__device__ void QueuePair::mlx5_post_wqe_rma(int pe, int32_t size, uintptr_t *laddr, uintptr_t *raddr, uint8_t opcode) {
uint64_t activemask = get_active_lane_mask();
uint8_t num_active_lanes = get_active_lane_count(activemask);
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
@@ -454,7 +520,7 @@ __device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr,
if (num_free_entries > num_entries_until_wave_last_entry) {
break;
}
quiet();
mlx5_quiet();
}
outstanding_wqes[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter;
@@ -479,7 +545,7 @@ __device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr,
uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf);
uint64_t* ctrl_wqe_8B_for_db = reinterpret_cast<uint64_t*>(&base_ptr[64 * ((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]);
ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
__hip_atomic_fetch_add(&quiet_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
__hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
@@ -487,9 +553,9 @@ __device__ void QueuePair::post_wqe_rma(int pe, int32_t size, uintptr_t *laddr,
}
#endif // GDA_MLX5
#ifdef GDA_IONIC
__device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode,
int64_t atomic_data, int64_t atomic_cmp, bool fetching) {
#if defined(GDA_IONIC)
__device__ uint64_t QueuePair::ionic_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode,
int64_t atomic_data, int64_t atomic_cmp, bool fetching) {
uint64_t activemask = get_same_qp_lane_mask();
uint32_t num_wqes = get_active_lane_count(activemask);
uint32_t my_logical_lane_id = get_active_lane_num(activemask);
@@ -538,7 +604,7 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
uint64_t ret{0};
if (fetching) {
quiet_internal(activemask, cons);
ionic_quiet_internal(activemask, cons);
ret = wave_fetch_atomic[my_logical_lane_id];
__atomic_signal_fence(__ATOMIC_SEQ_CST);
if (is_leader) {
@@ -550,8 +616,8 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
#endif
#if defined(GDA_MLX5)
__device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode,
int64_t atomic_data, int64_t atomic_cmp, bool fetching) {
__device__ uint64_t QueuePair::mlx5_post_wqe_amo(int pe, int32_t size, uintptr_t *raddr, uint8_t opcode,
int64_t atomic_data, int64_t atomic_cmp, bool fetching) {
uint64_t activemask = get_active_lane_mask();
uint8_t num_active_lanes = get_active_lane_count(activemask);
uint8_t my_logical_lane_id = get_active_lane_num(activemask);
@@ -579,7 +645,7 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
if (num_free_entries > num_entries_until_wave_last_entry) {
break;
}
quiet();
mlx5_quiet();
}
uint64_t* wave_fetch_atomic{nullptr};
@@ -620,7 +686,7 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
uint8_t *base_ptr = reinterpret_cast<uint8_t*>(sq_buf);
uint64_t* ctrl_wqe_8B_for_db = reinterpret_cast<uint64_t*>(&base_ptr[64 * ((wave_sq_counter + num_wqes - 1) % sq_wqe_cnt)]);
ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
mlx5_ring_doorbell(*ctrl_wqe_8B_for_db, wave_sq_counter + num_wqes);
__hip_atomic_fetch_add(&quiet_posted, num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
__hip_atomic_store(&sq_db_touched, wave_sq_counter + num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
@@ -628,7 +694,7 @@ __device__ uint64_t QueuePair::post_wqe_amo(int pe, int32_t size, uintptr_t *rad
uint64_t ret{0};
if (fetching) {
quiet();
mlx5_quiet();
ret = wave_fetch_atomic[my_logical_lane_id];
__atomic_signal_fence(__ATOMIC_SEQ_CST);
if (is_leader) {