From 87f99e7ec6d94558cc22a90c41f62c2fc2274878 Mon Sep 17 00:00:00 2001 From: Dimple Prajapati Date: Thu, 24 Jul 2025 11:03:03 -0700 Subject: [PATCH] Add host APIs for querying device ctx and remote heap pointer (#200) * Add host APIs for querying device ctx and remote heap pointer * Host API to query device pointer for ROCSHMEM_CTX_DEFAULT, this is needed to support dynamic module initialization via device kernel library bitcode. * Host API to query remote symmetric heap pointer that can be used in custom device kernel for RMA operations. * Added rocshmem_ptr implementation within the Host Context class * Enables pointer retrieval functionality for symmetric data objects * Copy IPC pointers to host memory in RO host context --------- Co-authored-by: avinashkethineedi --- include/rocshmem/rocshmem.hpp | 23 +++++++++++++++++ src/backend_type.hpp | 31 ++++++++++++++++------ src/context.hpp | 4 +++ src/context_host.cpp | 12 ++++++++- src/ipc/context_ipc_host.cpp | 20 +++++++++++++++ src/ipc/context_ipc_host.hpp | 2 ++ src/ipc_policy.cpp | 2 +- src/ipc_policy.hpp | 4 +-- src/reverse_offload/context_ro_host.cpp | 34 +++++++++++++++++++++++++ src/reverse_offload/context_ro_host.hpp | 2 ++ src/rocshmem.cpp | 7 +++++ src/rocshmem_gpu.cpp | 15 +++++++++++ 12 files changed, 144 insertions(+), 12 deletions(-) diff --git a/include/rocshmem/rocshmem.hpp b/include/rocshmem/rocshmem.hpp index be4de8cf16..e2bf79abce 100644 --- a/include/rocshmem/rocshmem.hpp +++ b/include/rocshmem/rocshmem.hpp @@ -65,6 +65,29 @@ constexpr char VERSION[] = "3.0.0"; */ __host__ void rocshmem_init(MPI_Comm comm = MPI_COMM_WORLD); +/** + * @brief Query rocSHMEM context from host API + * + * @return ROCSHMEM_CTX_DEFAULT device pointer that users + * can query from one instance of rocshmem host library and + * use later for dynamic module initialization in + * kernel bitcode device library in the same application + */ +__host__ void * rocshmem_get_device_ctx(); + +/** + * @brief 
Query rocSHMEM remote symmetric heap pointer + * + * @param[in] dest local symmetric heap allocation pointer for current pe/device + * + * @param[in] pe remote PE + * + * @return Remote symmetric heap device pointer from host-side API. + * This can be used to issue load/store from custom kernels + * instead of using rocshmem device side get/put APIs for RMA operations. + */ +__host__ void *rocshmem_ptr(void *dest, int pe); + /** * @brief Initialize the rocSHMEM runtime and underlying transport layer * with an attempt to enable the requested thread support. diff --git a/src/backend_type.hpp b/src/backend_type.hpp index f7c2046939..98268c7422 100644 --- a/src/backend_type.hpp +++ b/src/backend_type.hpp @@ -68,12 +68,12 @@ enum class BackendType { RO_BACKEND, IPC_BACKEND }; * @brief Device static dispatch method call with a return value. */ #ifdef USE_RO -#define DISPATCH_RET(Func) \ +#define DISPATCH_RET(Func) \ auto ret_val = static_cast(this)->Func; \ return ret_val; #else -#define DISPATCH_RET(Func) \ - auto ret_val{0}; \ +#define DISPATCH_RET(Func) \ + auto ret_val{0}; \ ret_val = static_cast(this)->Func; \ return ret_val; #endif @@ -86,8 +86,8 @@ enum class BackendType { RO_BACKEND, IPC_BACKEND }; ret_val = static_cast(this)->Func; \ return ret_val; #else -#define DISPATCH_RET_PTR(Func) \ - void *ret_val{nullptr}; \ +#define DISPATCH_RET_PTR(Func) \ + void *ret_val{nullptr}; \ ret_val = static_cast(this)->Func; \ return ret_val; #endif @@ -113,12 +113,27 @@ enum class BackendType { RO_BACKEND, IPC_BACKEND }; */ #ifdef USE_RO -#define HOST_DISPATCH_RET(Func) \ +#define HOST_DISPATCH_RET(Func) \ auto ret_val = static_cast(this)->Func; \ return ret_val; #else -#define HOST_DISPATCH_RET(Func) \ - auto ret_val{0}; \ +#define HOST_DISPATCH_RET(Func) \ + auto ret_val{0}; \ + ret_val = static_cast(this)->Func; \ + return ret_val; +#endif + +/** + * @brief Host static dispatch method call with a return type of pointer. 
+ */ +#ifdef USE_RO +#define HOST_DISPATCH_RET_PTR(Func) \ + void *ret_val{nullptr}; \ + ret_val = static_cast(this)->Func; \ + return ret_val; +#else +#define HOST_DISPATCH_RET_PTR(Func) \ + void *ret_val{nullptr}; \ ret_val = static_cast(this)->Func; \ return ret_val; #endif diff --git a/src/context.hpp b/src/context.hpp index 8268dc6a59..143893b430 100644 --- a/src/context.hpp +++ b/src/context.hpp @@ -59,6 +59,8 @@ class Context { __device__ Context(Backend* handle, bool shareable); + __host__ virtual ~Context(); + /* * Dispatch functions to get runtime polymorphism without 'virtual' or * function pointers. Each one of these guys will use 'type' to @@ -387,6 +389,8 @@ class Context { __host__ void quiet(); + __host__ void* shmem_ptr(const void* dest, int pe); + __host__ void barrier_all(); __host__ void sync_all(); diff --git a/src/context_host.cpp b/src/context_host.cpp index 7e30d9ec73..e6b3304ebd 100644 --- a/src/context_host.cpp +++ b/src/context_host.cpp @@ -31,7 +31,11 @@ namespace rocshmem { __host__ Context::Context(Backend* handle, bool shareable) : num_pes(handle->getNumPEs()), my_pe(handle->getMyPE()), - fence_(shareable) {} + fence_(shareable) { +} + +__host__ Context::~Context() { +} /****************************************************************************** ********************** CONTEXT DISPATCH IMPLEMENTATIONS ********************** @@ -93,6 +97,12 @@ __host__ void Context::quiet() { HOST_DISPATCH(quiet()); } +__host__ void* Context::shmem_ptr(const void* dest, int pe) { + ctxHostStats.incStat(NUM_HOST_SHMEM_PTR); + + HOST_DISPATCH_RET_PTR(shmem_ptr(dest, pe)); +} + __host__ void Context::sync_all() { ctxHostStats.incStat(NUM_HOST_SYNC_ALL); diff --git a/src/ipc/context_ipc_host.cpp b/src/ipc/context_ipc_host.cpp index fe2e92b8a7..7c459df1ba 100644 --- a/src/ipc/context_ipc_host.cpp +++ b/src/ipc/context_ipc_host.cpp @@ -42,9 +42,20 @@ __host__ IPCHostContext::IPCHostContext(Backend *backend, host_interface = b->host_interface; 
context_window_info = host_interface->acquire_window_context(); + + char** ipc_bases = new char*[b->ipcImpl.shm_size]; + + CHECK_HIP(hipMemcpy(ipc_bases, + b->ipcImpl.ipc_bases, + b->ipcImpl.shm_size * sizeof(char *), + hipMemcpyDeviceToHost)); + + ipcImpl_.ipc_bases = ipc_bases; } __host__ IPCHostContext::~IPCHostContext() { + delete[] ipcImpl_.ipc_bases; + host_interface->release_window_context(context_window_info); } @@ -76,6 +87,15 @@ __host__ void IPCHostContext::quiet() { host_interface->quiet(context_window_info); } +__host__ void *IPCHostContext::shmem_ptr(const void *dest, int pe) { + void *ret = nullptr; + void *dst = const_cast(dest); + uint64_t L_offset = + reinterpret_cast(dst) - ipcImpl_.ipc_bases[my_pe]; + ret = ipcImpl_.ipc_bases[pe] + L_offset; + return ret; +} + __host__ void IPCHostContext::sync_all() { host_interface->sync_all(context_window_info); } diff --git a/src/ipc/context_ipc_host.hpp b/src/ipc/context_ipc_host.hpp index 4665e12f30..ddec120299 100644 --- a/src/ipc/context_ipc_host.hpp +++ b/src/ipc/context_ipc_host.hpp @@ -78,6 +78,8 @@ class IPCHostContext : public Context { __host__ void quiet(); + __host__ void *shmem_ptr(const void *dest, int pe); + __host__ void barrier_all(); __host__ void sync_all(); diff --git a/src/ipc_policy.cpp b/src/ipc_policy.cpp index 0bee8eafb7..7a6d0639e6 100644 --- a/src/ipc_policy.cpp +++ b/src/ipc_policy.cpp @@ -110,7 +110,7 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases, free(vec_ipc_handle); if (0 == rocshmem_env_.get_ro_disable_ipc()) { - int thread_comm_rank; + int thread_comm_rank {-1}; CHECK_HIP(hipMalloc(reinterpret_cast(&pes_with_ipc_avail), shm_size * sizeof(int))); diff --git a/src/ipc_policy.hpp b/src/ipc_policy.hpp index edf350a9f5..7f1e17c925 100644 --- a/src/ipc_policy.hpp +++ b/src/ipc_policy.hpp @@ -61,7 +61,7 @@ class IpcOnImpl { __host__ void ipcHostStop(); - __device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) { + __host__ 
__device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) { if (nullptr == pes_with_ipc_avail) { return false; } for (int i=0; ihost_interface; context_window_info = dynamic_cast(host_interface->acquire_window_context()); + + int *pes_with_ipc_avail = new int[backend->ipcImpl.shm_size]; + char** ipc_bases = new char*[b->ipcImpl.shm_size]; + + CHECK_HIP(hipMemcpy(pes_with_ipc_avail, + backend->ipcImpl.pes_with_ipc_avail, + backend->ipcImpl.shm_size * sizeof(int), + hipMemcpyDeviceToHost)); + + CHECK_HIP(hipMemcpy(ipc_bases, + backend->ipcImpl.ipc_bases, + backend->ipcImpl.shm_size * sizeof(char *), + hipMemcpyDeviceToHost)); + + ipcImpl_.pes_with_ipc_avail = pes_with_ipc_avail; + ipcImpl_.ipc_bases = ipc_bases; + ipcImpl_.shm_size = backend->ipcImpl.shm_size; + ipcImpl_.shm_rank = backend->ipcImpl.shm_rank; } __host__ ROHostContext::~ROHostContext() { // host_interface->release_window_context(context_window_info); + delete[] ipcImpl_.pes_with_ipc_avail; + delete[] ipcImpl_.ipc_bases; } __host__ void ROHostContext::putmem_nbi(void *dest, const void *source, @@ -87,6 +107,20 @@ __host__ void ROHostContext::quiet() { host_interface->quiet(context_window_info); } +__host__ void *ROHostContext::shmem_ptr(const void *dest, int pe) { + DPRINTF("Function: ro_net_host_shmem_ptr\n"); + + void *ret = nullptr; + int local_pe{-1}; + if (ipcImpl_.isIpcAvailable(my_pe, pe, &local_pe)) { + void *dst = const_cast(dest); + uint64_t L_offset = + reinterpret_cast(dst) - ipcImpl_.ipc_bases[ipcImpl_.shm_rank]; + ret = ipcImpl_.ipc_bases[local_pe] + L_offset; + } + return ret; +} + __host__ void ROHostContext::sync_all() { DPRINTF("Function: ro_net_host_sync_all\n"); diff --git a/src/reverse_offload/context_ro_host.hpp b/src/reverse_offload/context_ro_host.hpp index c306c0690e..13d06f94a5 100644 --- a/src/reverse_offload/context_ro_host.hpp +++ b/src/reverse_offload/context_ro_host.hpp @@ -127,6 +127,8 @@ class ROHostContext : public Context { __host__ void quiet(); + 
__host__ void *shmem_ptr(const void *dest, int pe); + __host__ void barrier_all(); __host__ void sync_all(); diff --git a/src/rocshmem.cpp b/src/rocshmem.cpp index f992a70bbb..a5d69cf179 100644 --- a/src/rocshmem.cpp +++ b/src/rocshmem.cpp @@ -300,6 +300,13 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT; backend->heap.free(ptr); } +__host__ void * rocshmem_ptr(void * dest, int pe){ + + Context *ctx = reinterpret_cast(ROCSHMEM_HOST_CTX_DEFAULT.ctx_opaque); + + return ctx->shmem_ptr(dest, pe); +} + [[maybe_unused]] __host__ void rocshmem_reset_stats() { VERIFY_BACKEND(); backend->reset_stats(); diff --git a/src/rocshmem_gpu.cpp b/src/rocshmem_gpu.cpp index e2571c62b6..2a52f2bb8b 100644 --- a/src/rocshmem_gpu.cpp +++ b/src/rocshmem_gpu.cpp @@ -92,6 +92,21 @@ __device__ void rocshmem_query_thread(int *provided) { __device__ void rocshmem_wg_finalize() {} + +/****************************************************************************** +* These host APIs use Device side symbol - ROCSHMEM_CTX_DEFAULT so it needs +* to stay here to avoid getting pulled into other places in compilation +******************************************************************************/ + +__host__ void * rocshmem_get_device_ctx() { + void *ctx = nullptr; + + CHECK_HIP(hipMemcpyFromSymbol(&ctx, HIP_SYMBOL(ROCSHMEM_CTX_DEFAULT), + sizeof(rocshmem_ctx_t))); + return ctx; + +} + /****************************************************************************** ************************** Default Context Wrappers ************************** *****************************************************************************/