Implement rocshmem_ptr in IPC conduit (#197)

* Implement `rocshmem_ptr` in IPC conduit

* tests: add functional test for `rocshmem_ptr`
  - Add safety check for pointer access and condition check before printing results for `rocshmem_ptr` test
  - Use `rocshmem_put` to store `rocshmem_ptr` availability for data validation

[ROCm/rocshmem commit: 526105d315]
This commit is contained in:
Avinash Kethineedi
2025-07-28 12:01:02 -05:00
کامیت شده توسط GitHub
والد 72ed270a5c
کامیت 2a7416d016
7فایلهای تغییر یافته به همراه122 افزوده شده و 23 حذف شده
@@ -200,6 +200,11 @@ TestRMAPut() {
ExecTest "p" 2 8 1 32
ExecTest "p" 2 16 128 4
ExecTest "shmemptr" 2 1 1 8
ExecTest "shmemptr" 2 1 1024 8
ExecTest "shmemptr" 2 8 1 8
ExecTest "shmemptr" 2 16 128 8
################################ Non-Blocking ################################
ExecTest "putnbi" 2 1 1 1048576
@@ -106,6 +106,10 @@ __device__ void IPCContext::quiet() {
__device__ void *IPCContext::shmem_ptr(const void *dest, int pe) {
void *ret = nullptr;
void *dst = const_cast<void *>(dest);
uint64_t L_offset =
reinterpret_cast<char *>(dst) - ipcImpl_.ipc_bases[my_pe];
ret = ipcImpl_.ipc_bases[pe] + L_offset;
return ret;
}
@@ -31,18 +31,86 @@ using namespace rocshmem;
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void ShmemPtrTest(char *r_buf, int *available) {
rocshmem_wg_init();
__global__ void ShmemPtrTest(int loop, int skip, long long int *start_time,
long long int *end_time, char *dest, int wf_size,
ShmemContextType ctx_type, int *available) {
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
int t_id = get_flat_block_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
/**
* Shared array to capture the start time for each wavefront
* Max threads per block = 1024, wavefront size = 64 (in most GPUs)
* Maximum array size required = 1024/64 = 16
*/
__shared__ long long int wf_start_time[16];
/**
* Calculate start index for each thread within the grid
*/
dest += get_flat_id();
char *local_addr = dest;
void *remote_addr = rocshmem_ptr((void *)local_addr, 1);
if (remote_addr != NULL) {
*available = 1;
}
if(*available) {
for (int i = 0; i < loop + skip; i++) {
if (i == skip) {
__syncthreads();
// Ensures all RMA calls from the skip loops are completed
if(is_thread_zero_in_block()) {
rocshmem_ctx_quiet(ctx);
}
__syncthreads();
// Capture the start time of each wavefront to identify the earliest one
wf_start_time[wf_id] = wall_clock64();
}
if (hipThreadIdx_x == 0) {
char *local_addr = r_buf + 4;
void *remote_addr = rocshmem_ptr((void *)local_addr, 1);
if (remote_addr != NULL) {
*available = 1;
((char *)remote_addr)[0] = '1';
}
}
__syncthreads();
if(is_thread_zero_in_block()) {
rocshmem_ctx_quiet(ctx);
}
/**
* End time of the last wavefront is recorded by overwriting
* the value previously set by earlier wavefronts.
*/
end_time[wg_id] = wall_clock64();
// Find the earliest start time
int num_wfs = (get_flat_block_size() - 1 ) / wf_size + 1;
for (int i = num_wfs / 2; i > 0; i >>= 1 ) {
if(t_id < i) {
wf_start_time[t_id] = min(wf_start_time[t_id], wf_start_time[t_id + i]);
}
}
// For data validation in remote PE
if( get_flat_id() == 0 ) {
int *store_avail = (int*)(dest + get_flat_grid_size());
*store_avail = *available;
rocshmem_ctx_int_put(ctx, store_avail, store_avail, 1, 1);
}
__syncthreads();
if (t_id == 0) {
start_time[wg_id] = wf_start_time[0];
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
@@ -50,17 +118,26 @@ __global__ void ShmemPtrTest(char *r_buf, int *available) {
* HOST TESTER CLASS METHODS
*****************************************************************************/
ShmemPtrTester::ShmemPtrTester(TesterArguments args) : Tester(args) {
size_t buff_size = args.wg_size * args.num_wgs + sizeof(int);
CHECK_HIP(hipMalloc((void **)&_available, sizeof(int)));
r_buf = (char *)rocshmem_malloc(args.max_msg_size);
dest = (char *)rocshmem_malloc(buff_size);
if (dest == nullptr) {
std::cerr << "Error allocating memory from symmetric heap" << std::endl;
std::cerr << "dest: " << dest << std::endl;
rocshmem_global_exit(1);
}
}
ShmemPtrTester::~ShmemPtrTester() {
CHECK_HIP(hipFree(_available));
rocshmem_free(r_buf);
rocshmem_free(dest);
}
void ShmemPtrTester::resetBuffers(size_t size) {
memset(r_buf, '0', args.max_msg_size);
size_t buff_size = args.wg_size * args.num_wgs + sizeof(int);
memset(dest, '0', buff_size);
memset(_available, 0, sizeof(int));
}
@@ -68,22 +145,32 @@ void ShmemPtrTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
size_t size) {
size_t shared_bytes = 0;
hipLaunchKernelGGL(ShmemPtrTest, gridSize, blockSize, shared_bytes, stream,
r_buf, _available);
hipLaunchKernelGGL(ShmemPtrTest, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, start_time, end_time,
dest, wf_size, _type, _available);
num_msgs = 0;
num_timed_msgs = 0;
num_msgs = (loop + args.skip) * gridSize.x * blockSize.x;
num_timed_msgs = loop * gridSize.x * blockSize.x;
}
void ShmemPtrTester::verifyResults(size_t size) {
if (args.myid == 0) {
if (*_available == 0) {
fprintf(stderr, "SHMEM_PTR NOT AVAILBLE \n");
_print_results = false;
std::cout << "rocshmem ptr not available\n" << std::endl;
}
} else {
if (r_buf[4] != '1') {
fprintf(stderr, "Data validation error \n");
fprintf(stderr, "Got %c, Expected %c\n", r_buf[4], '1');
}
else {
size_t buff_size = args.wg_size * args.num_wgs;
int *available = (int*)(dest + buff_size);
if(*available == 1) {
for (size_t i = 0; i < buff_size; i++) {
if (dest[i] != '1') {
std::cerr << "Data validation error at idx " << i << std::endl;
std::cerr << " Got " << dest[i] << ", Expected 1 " << std::endl;
exit(-1);
}
}
}
}
}
@@ -43,7 +43,7 @@ class ShmemPtrTester : public Tester {
virtual void verifyResults(size_t size) override;
char *r_buf = nullptr;
char *dest = nullptr;
int *_available = nullptr;
};
@@ -560,7 +560,7 @@ bool Tester::peLaunchesKernel() {
}
void Tester::print(uint64_t size) {
if (args.myid != 0) {
if (args.myid != 0 || !_print_results) {
return;
}
@@ -165,8 +165,11 @@ class Tester {
bool *verification_error;
protected:
bool _print_results = true;
private:
bool _print_header = 1;
bool _print_header = true;
void print(uint64_t size);
void barrier();
@@ -96,11 +96,11 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
case WAVESyncAllTestType:
case WGSyncAllTestType:
case SyncTestType:
case ShmemPtrTestType:
min_msg_size = 8;
max_msg_size = 8;
break;
case PingPongTestType:
case ShmemPtrTestType:
min_msg_size = 4;
max_msg_size = 4;
break;