From 2a7416d016b313be16a66038bbfa1c17cf02fcaa Mon Sep 17 00:00:00 2001 From: Avinash Kethineedi Date: Mon, 28 Jul 2025 12:01:02 -0500 Subject: [PATCH] Implement `rocshmem_ptr` in IPC conduit (#197) * Implement `rocshmem_ptr` in IPC conduit * tests: add functional test for `rocshmem_ptr` - Add safety check for pointer access and condition check before printing results for `rocshmem_ptr` test - Use `rocshmem_put` to store `rocshmem_ptr` availability for data validation [ROCm/rocshmem commit: 526105d3157ca0fdea5febccc4a643efa1ec30bd] --- .../scripts/functional_tests/driver.sh | 5 + .../rocshmem/src/ipc/context_ipc_device.cpp | 4 + .../functional_tests/shmem_ptr_tester.cpp | 125 +++++++++++++++--- .../functional_tests/shmem_ptr_tester.hpp | 2 +- .../tests/functional_tests/tester.cpp | 2 +- .../tests/functional_tests/tester.hpp | 5 +- .../functional_tests/tester_arguments.cpp | 2 +- 7 files changed, 122 insertions(+), 23 deletions(-) diff --git a/projects/rocshmem/scripts/functional_tests/driver.sh b/projects/rocshmem/scripts/functional_tests/driver.sh index f9bd015005..f164ab904f 100755 --- a/projects/rocshmem/scripts/functional_tests/driver.sh +++ b/projects/rocshmem/scripts/functional_tests/driver.sh @@ -200,6 +200,11 @@ TestRMAPut() { ExecTest "p" 2 8 1 32 ExecTest "p" 2 16 128 4 + ExecTest "shmemptr" 2 1 1 8 + ExecTest "shmemptr" 2 1 1024 8 + ExecTest "shmemptr" 2 8 1 8 + ExecTest "shmemptr" 2 16 128 8 + ################################ Non-Blocking ################################ ExecTest "putnbi" 2 1 1 1048576 diff --git a/projects/rocshmem/src/ipc/context_ipc_device.cpp b/projects/rocshmem/src/ipc/context_ipc_device.cpp index 02af694071..f68144ceae 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device.cpp +++ b/projects/rocshmem/src/ipc/context_ipc_device.cpp @@ -106,6 +106,10 @@ __device__ void IPCContext::quiet() { __device__ void *IPCContext::shmem_ptr(const void *dest, int pe) { void *ret = nullptr; + void *dst = const_cast(dest); + uint64_t L_offset = + reinterpret_cast(dst) - ipcImpl_.ipc_bases[my_pe]; + ret = ipcImpl_.ipc_bases[pe] + L_offset; return ret; } diff --git a/projects/rocshmem/tests/functional_tests/shmem_ptr_tester.cpp b/projects/rocshmem/tests/functional_tests/shmem_ptr_tester.cpp index 019493829b..58dbaeab3c 100644 --- a/projects/rocshmem/tests/functional_tests/shmem_ptr_tester.cpp +++ b/projects/rocshmem/tests/functional_tests/shmem_ptr_tester.cpp @@ -31,18 +31,86 @@ using namespace rocshmem; /****************************************************************************** * DEVICE TEST KERNEL *****************************************************************************/ -__global__ void ShmemPtrTest(char *r_buf, int *available) { - rocshmem_wg_init(); +__global__ void ShmemPtrTest(int loop, int skip, long long int *start_time, + long long int *end_time, char *dest, int wf_size, + ShmemContextType ctx_type, int *available) { + __shared__ rocshmem_ctx_t ctx; + int wg_id = get_flat_grid_id(); + int t_id = get_flat_block_id(); + int wf_id = t_id / wf_size; + + rocshmem_wg_init(); + rocshmem_wg_ctx_create(ctx_type, &ctx); + + /** + * Shared array to capture the start time for each wavefront + * Max threads per block = 1024, wavefront size = 64 (in most GPUs) + * Maximum array size required = 1024/64 = 16 + */ + __shared__ long long int wf_start_time[16]; + + + /** + * Calculate start index for each thread within the grid + */ + dest += get_flat_id(); + + char *local_addr = dest; + void *remote_addr = rocshmem_ptr((void *)local_addr, 1); + if (remote_addr != NULL) { + *available = 1; + } + + if(*available) { + for (int i = 0; i < loop + skip; i++) { + if (i == skip) { + __syncthreads(); + // Ensures all RMA calls from the skip loops are completed + if(is_thread_zero_in_block()) { + rocshmem_ctx_quiet(ctx); + } + __syncthreads(); + // Capture the start time of each wavefront to identify the earliest one + wf_start_time[wf_id] = wall_clock64(); + } - if (hipThreadIdx_x == 0) { - char *local_addr = r_buf + 4; - void *remote_addr = rocshmem_ptr((void *)local_addr, 1); - if (remote_addr != NULL) { - *available = 1; ((char *)remote_addr)[0] = '1'; } } + __syncthreads(); + if(is_thread_zero_in_block()) { + rocshmem_ctx_quiet(ctx); + } + + /** + * End time of the last wavefront is recorded by overwriting + * the value previously set by earlier wavefronts. + */ + end_time[wg_id] = wall_clock64(); + + // Find the earliest start time + int num_wfs = (get_flat_block_size() - 1 ) / wf_size + 1; + for (int i = num_wfs / 2; i > 0; i >>= 1 ) { + if(t_id < i) { + wf_start_time[t_id] = min(wf_start_time[t_id], wf_start_time[t_id + i]); + } + } + + // For data validation in remote PE + if( get_flat_id() == 0 ) { + int *store_avail = (int*)(dest + get_flat_grid_size()); + *store_avail = *available; + rocshmem_ctx_int_put(ctx, store_avail, store_avail, 1, 1); + } + + __syncthreads(); + + if (t_id == 0) { + start_time[wg_id] = wf_start_time[0]; + } + + rocshmem_wg_ctx_destroy(&ctx); rocshmem_wg_finalize(); } @@ -50,17 +118,26 @@ __global__ void ShmemPtrTest(char *r_buf, int *available) { * HOST TESTER CLASS METHODS *****************************************************************************/ ShmemPtrTester::ShmemPtrTester(TesterArguments args) : Tester(args) { + size_t buff_size = args.wg_size * args.num_wgs + sizeof(int); CHECK_HIP(hipMalloc((void **)&_available, sizeof(int))); - r_buf = (char *)rocshmem_malloc(args.max_msg_size); + dest = (char *)rocshmem_malloc(buff_size); + + if (dest == nullptr) { + std::cerr << "Error allocating memory from symmetric heap" << std::endl; + std::cerr << "dest: " << dest << std::endl; + + rocshmem_global_exit(1); + } } ShmemPtrTester::~ShmemPtrTester() { CHECK_HIP(hipFree(_available)); - rocshmem_free(r_buf); + rocshmem_free(dest); } void ShmemPtrTester::resetBuffers(size_t size) { - memset(r_buf, '0', args.max_msg_size); + size_t buff_size = args.wg_size * args.num_wgs + sizeof(int); + memset(dest, '0', buff_size); memset(_available, 0, sizeof(int)); } @@ -68,22 +145,32 @@ void ShmemPtrTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, size_t size) { size_t shared_bytes = 0; - hipLaunchKernelGGL(ShmemPtrTest, gridSize, blockSize, shared_bytes, stream, - r_buf, _available); + hipLaunchKernelGGL(ShmemPtrTest, gridSize, blockSize, shared_bytes, + stream, loop, args.skip, start_time, end_time, + dest, wf_size, _type, _available); - num_msgs = 0; - num_timed_msgs = 0; + num_msgs = (loop + args.skip) * gridSize.x * blockSize.x; + num_timed_msgs = loop * gridSize.x * blockSize.x; } void ShmemPtrTester::verifyResults(size_t size) { if (args.myid == 0) { if (*_available == 0) { - fprintf(stderr, "SHMEM_PTR NOT AVAILBLE \n"); + _print_results = false; + std::cout << "rocshmem ptr not available\n" << std::endl; } - } else { - if (r_buf[4] != '1') { - fprintf(stderr, "Data validation error \n"); - fprintf(stderr, "Got %c, Expected %c\n", r_buf[4], '1'); + } + else { + size_t buff_size = args.wg_size * args.num_wgs; + int *available = (int*)(dest + buff_size); + if(*available == 1) { + for (size_t i = 0; i < buff_size; i++) { + if (dest[i] != '1') { + std::cerr << "Data validation error at idx " << i << std::endl; + std::cerr << " Got " << dest[i] << ", Expected 1 " << std::endl; + exit(-1); + } + } } } } diff --git a/projects/rocshmem/tests/functional_tests/shmem_ptr_tester.hpp b/projects/rocshmem/tests/functional_tests/shmem_ptr_tester.hpp index 89fa573814..aabb465672 100644 --- a/projects/rocshmem/tests/functional_tests/shmem_ptr_tester.hpp +++ b/projects/rocshmem/tests/functional_tests/shmem_ptr_tester.hpp @@ -43,7 +43,7 @@ class ShmemPtrTester : public Tester { virtual void verifyResults(size_t size) override; - char *r_buf = nullptr; + char *dest = nullptr; int *_available = nullptr; }; diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index 8a8ee222a9..0d891dcc62 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -560,7 +560,7 @@ bool Tester::peLaunchesKernel() { } void Tester::print(uint64_t size) { - if (args.myid != 0) { + if (args.myid != 0 || !_print_results) { return; } diff --git a/projects/rocshmem/tests/functional_tests/tester.hpp b/projects/rocshmem/tests/functional_tests/tester.hpp index 120616f298..877f8d22f7 100644 --- a/projects/rocshmem/tests/functional_tests/tester.hpp +++ b/projects/rocshmem/tests/functional_tests/tester.hpp @@ -165,8 +165,11 @@ class Tester { bool *verification_error; + protected: + bool _print_results = true; + private: - bool _print_header = 1; + bool _print_header = true; void print(uint64_t size); void barrier(); diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp index 9f842b2f4a..933a1056bc 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp @@ -96,11 +96,11 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { case WAVESyncAllTestType: case WGSyncAllTestType: case SyncTestType: - case ShmemPtrTestType: min_msg_size = 8; max_msg_size = 8; break; case PingPongTestType: + case ShmemPtrTestType: min_msg_size = 4; max_msg_size = 4; break;