* feat(GDA): add `get*` and `get*_nbi` APIs for mlx and bnxt NICs
   - implemented thread, wave and wg variants of `get*` and `get_nbi`.

* test(GDA): enable functional tests for `get*` and `get*_nbi` APIs

[ROCm/rocshmem commit: 671f8187f4]
Этот коммит содержится в:
Avinash Kethineedi
2025-09-10 12:24:53 -04:00
коммит произвёл GitHub
родитель 526784233b
Коммит 6860bc1275
5 изменённых файлов: 91 добавлений и 44 удалений
+32 -32
Просмотреть файл
@@ -463,25 +463,25 @@ TestGDA() {
ExecTest "teamctxput" 2 4 128 1024
ExecTest "teamctxput" 2 16 256 1024
# ExecTest "get" 2 1 1 1048576
# ExecTest "get" 2 1 1024 512
# ExecTest "get" 2 8 1 1048576
# ExecTest "get" 2 16 128 8
# ExecTest "get" 2 32 256 512
# ExecTest "get" 2 64 1024 8
ExecTest "get" 2 1 1 1048576
ExecTest "get" 2 1 1024 512
ExecTest "get" 2 8 1 1048576
ExecTest "get" 2 16 128 8
ExecTest "get" 2 32 256 512
ExecTest "get" 2 64 1024 8
# ExecTest "wgget" 2 1 64 1048576
# ExecTest "wgget" 2 2 64 1048576
# ExecTest "wgget" 2 16 64 8
ExecTest "wgget" 2 1 64 1048576
ExecTest "wgget" 2 2 64 1048576
ExecTest "wgget" 2 16 64 8
# ExecTest "waveget" 2 1 64 1048576
# ExecTest "waveget" 2 2 64 1048576
# ExecTest "waveget" 2 2 128 1048576
# ExecTest "waveget" 2 16 128 8
ExecTest "waveget" 2 1 64 1048576
ExecTest "waveget" 2 2 64 1048576
ExecTest "waveget" 2 2 128 1048576
ExecTest "waveget" 2 16 128 8
# ExecTest "defaultctxget" 2 4 128 1024
# ExecTest "teamctxget" 2 4 128 1024
# ExecTest "teamctxget" 2 16 256 1024
ExecTest "defaultctxget" 2 4 128 1024
ExecTest "teamctxget" 2 4 128 1024
ExecTest "teamctxget" 2 16 256 1024
# ExecTest "g" 2 1 1 128
# ExecTest "g" 2 1 1024 2
@@ -516,25 +516,25 @@ TestGDA() {
ExecTest "teamctxputnbi" 2 4 128 1024
ExecTest "teamctxputnbi" 2 16 256 1024
# ExecTest "getnbi" 2 1 1 1048576
# ExecTest "getnbi" 2 1 1024 512
# ExecTest "getnbi" 2 8 1 1048576
# ExecTest "getnbi" 2 16 128 8
# ExecTest "getnbi" 2 32 256 512
# ExecTest "getnbi" 2 64 1024 8
ExecTest "getnbi" 2 1 1 1048576
ExecTest "getnbi" 2 1 1024 512
ExecTest "getnbi" 2 8 1 1048576
ExecTest "getnbi" 2 16 128 8
ExecTest "getnbi" 2 32 256 512
ExecTest "getnbi" 2 64 1024 8
# ExecTest "wggetnbi" 2 1 64 1048576
# ExecTest "wggetnbi" 2 2 64 1048576
# ExecTest "wggetnbi" 2 16 64 8
ExecTest "wggetnbi" 2 1 64 1048576
ExecTest "wggetnbi" 2 2 64 1048576
ExecTest "wggetnbi" 2 16 64 8
# ExecTest "wavegetnbi" 2 1 64 1048576
# ExecTest "wavegetnbi" 2 2 64 1048576
# ExecTest "wavegetnbi" 2 2 128 1048576
# ExecTest "wavegetnbi" 2 16 128 8
ExecTest "wavegetnbi" 2 1 64 1048576
ExecTest "wavegetnbi" 2 2 64 1048576
ExecTest "wavegetnbi" 2 2 128 1048576
ExecTest "wavegetnbi" 2 16 128 8
# ExecTest "defaultctxgetnbi" 2 4 128 1024
# ExecTest "teamctxgetnbi" 2 4 128 1024
# ExecTest "teamctxgetnbi" 2 16 256 1024
ExecTest "defaultctxgetnbi" 2 4 128 1024
ExecTest "teamctxgetnbi" 2 4 128 1024
ExecTest "teamctxgetnbi" 2 16 256 1024
#TestAMO() {
##############################################################################
+1
Просмотреть файл
@@ -33,6 +33,7 @@ extern "C" {
#define GDA_DEFAULT_GID 3
#define GDA_MAX_ATOMIC 1
#define GDA_OP_RDMA_WRITE BNXT_RE_WR_OPCD_RDMA_WRITE
#define GDA_OP_RDMA_READ BNXT_RE_WR_OPCD_RDMA_READ
#define GDA_OP_ATOMIC_FA BNXT_RE_WR_OPCD_ATOMIC_FA
#define GDA_OP_ATOMIC_CS BNXT_RE_WR_OPCD_ATOMIC_CS
+41 -12
Просмотреть файл
@@ -81,8 +81,20 @@ __device__ void GDAContext::putmem(void *dest, const void *source, size_t nelems
__device__ void GDAContext::getmem(void *dest, const void *source, size_t nelems,
int pe) {
printf("rocshmem::gda:getmem not implemented\n");
abort();
const char *src_typed = reinterpret_cast<const char *>(source);
uint64_t L_offset = const_cast<char *>(src_typed) - base_heap[my_pe];
bool need_turn {true};
uint64_t turns = __ballot(need_turn);
while (turns) {
uint8_t lane = __ffsll((unsigned long long)turns) - 1;
int pe_turn = __shfl(pe, lane);
if (pe_turn == pe) {
qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
qps[pe].quiet();
need_turn = false;
}
turns = __ballot(need_turn);
}
}
__device__ void GDAContext::putmem_nbi(void *dest, const void *source,
@@ -103,8 +115,19 @@ __device__ void GDAContext::putmem_nbi(void *dest, const void *source,
__device__ void GDAContext::getmem_nbi(void *dest, const void *source,
size_t nelems, int pe) {
printf("rocshmem::gda:getmem_nbi not implemented\n");
abort();
const char *src_typed = reinterpret_cast<const char *>(source);
uint64_t L_offset = const_cast<char *>(src_typed) - base_heap[my_pe];
bool need_turn {true};
uint64_t turns = __ballot(need_turn);
while (turns) {
uint8_t lane = __ffsll((unsigned long long)turns) - 1;
int pe_turn = __shfl(pe, lane);
if (pe_turn == pe) {
qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
need_turn = false;
}
turns = __ballot(need_turn);
}
}
__device__ void GDAContext::fence() { //TODO: optimize
@@ -139,9 +162,11 @@ __device__ void GDAContext::putmem_wg(void *dest, const void *source,
__device__ void GDAContext::getmem_wg(void *dest, const void *source,
size_t nelems, int pe) {
const char *src_typed = reinterpret_cast<const char *>(source);
uint64_t L_offset = const_cast<char *>(src_typed) - base_heap[my_pe];
if (is_thread_zero_in_block()) {
printf("rocshmem::gda:getmem_wg not implemented\n");
abort();
qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
qps[pe].quiet();
}
}
@@ -155,9 +180,10 @@ __device__ void GDAContext::putmem_nbi_wg(void *dest, const void *source,
__device__ void GDAContext::getmem_nbi_wg(void *dest, const void *source,
size_t nelems, int pe) {
const char *src_typed = reinterpret_cast<const char *>(source);
uint64_t L_offset = const_cast<char *>(src_typed) - base_heap[my_pe];
if (is_thread_zero_in_block()) {
printf("rocshmem::gda:getmem_nbi_wg not implemented\n");
abort();
qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
}
}
@@ -172,9 +198,11 @@ __device__ void GDAContext::putmem_wave(void *dest, const void *source,
__device__ void GDAContext::getmem_wave(void *dest, const void *source,
size_t nelems, int pe) {
const char *src_typed = reinterpret_cast<const char *>(source);
uint64_t L_offset = const_cast<char *>(src_typed) - base_heap[my_pe];
if (is_thread_zero_in_wave()) {
printf("rocshmem::gda:getmem_wave not implemented\n");
abort();
qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
qps[pe].quiet();
}
}
@@ -188,9 +216,10 @@ __device__ void GDAContext::putmem_nbi_wave(void *dest, const void *source,
__device__ void GDAContext::getmem_nbi_wave(void *dest, const void *source,
size_t nelems, int pe) {
const char *src_typed = reinterpret_cast<const char *>(source);
uint64_t L_offset = const_cast<char *>(src_typed) - base_heap[my_pe];
if (is_thread_zero_in_wave()) {
printf("rocshmem::gda:getmem_nbi_wave not implemented\n");
abort();
qps[pe].get_nbi(dest, base_heap[pe] + L_offset, nelems, pe);
}
}
+6
Просмотреть файл
@@ -626,6 +626,12 @@ __device__ void QueuePair::put_nbi(void *dest, const void *source, size_t nelems
post_wqe_rma(pe, nelems, src, dst, GDA_OP_RDMA_WRITE);
}
__device__ void QueuePair::get_nbi(void *dest, const void *source, size_t nelems, int pe) {
uintptr_t *src = reinterpret_cast<uintptr_t*>(const_cast<void*>(source));
uintptr_t *dst = reinterpret_cast<uintptr_t*>(dest);
post_wqe_rma(pe, nelems, dst, src, GDA_OP_RDMA_READ);
}
__device__ int64_t QueuePair::atomic_fetch(void *dest, int64_t atomic_data, int64_t atomic_cmp, int pe, uint8_t atomic_op) {
uintptr_t *dst = reinterpret_cast<uintptr_t*>(dest);
return post_wqe_amo(pe, sizeof(int64_t), dst, atomic_op, atomic_data, atomic_cmp, true);
+11
Просмотреть файл
@@ -61,6 +61,7 @@ extern "C" {
#elif defined(GDA_MLX5)
#define GDA_MAX_ATOMIC 1
#define GDA_OP_RDMA_WRITE MLX5_OPCODE_RDMA_WRITE
#define GDA_OP_RDMA_READ MLX5_OPCODE_RDMA_READ
#define GDA_OP_ATOMIC_FA MLX5_OPCODE_ATOMIC_FA
#define GDA_OP_ATOMIC_CS MLX5_OPCODE_ATOMIC_CS
#endif
@@ -102,6 +103,16 @@ class QueuePair {
*/
__device__ void put_nbi(void *dest, const void *source, size_t nelems, int pe);
/**
* @brief Create and enqueue a non-blocking get work queue entry (wqe).
*
* @param[in] dest Destination address for data transmission.
* @param[in] source Source address for data transmission.
* @param[in] nelems Size in bytes of data transmission.
* @param[in] pe Destination processing element of data transmission.
*/
__device__ void get_nbi(void *dest, const void *source, size_t nelems, int pe);
/**
* @brief Empty all completions from the completion queue.
*/