Add host APIs for querying device ctx and remote heap pointer (#200)

* Add host APIs for querying device ctx and remote heap pointer

* Host API to query the device pointer for ROCSHMEM_CTX_DEFAULT;
  this is needed to support dynamic module initialization via the device kernel
  library bitcode.
* Host API to query remote symmetric heap pointer that can be used in
  custom device kernel for RMA operations.

* Added rocshmem_ptr implementation within the Host Context class
* Enables pointer retrieval functionality for symmetric data objects
* Copy IPC pointers to host memory in RO host context

---------

Co-authored-by: avinashkethineedi <avinash.kethineedi@amd.com>
This commit is contained in:
Dimple Prajapati
2025-07-24 11:03:03 -07:00
committato da GitHub
parent 42e28835ad
commit 87f99e7ec6
12 ha cambiato i file con 144 aggiunte e 12 eliminazioni
+23
Vedi File
@@ -65,6 +65,29 @@ constexpr char VERSION[] = "3.0.0";
*/
__host__ void rocshmem_init(MPI_Comm comm = MPI_COMM_WORLD);
/**
* @brief Query rocSHMEM context from host API
*
* Users can query the pointer from one instance of the rocshmem host library
* and use it later for dynamic module initialization in the kernel bitcode
* device library in the same application.
*
* @return ROCSHMEM_CTX_DEFAULT device pointer
*/
__host__ void * rocshmem_get_device_ctx();
/**
* @brief Query rocSHMEM remote symmetric heap pointer
*
* @param[in] dest local symmetric heap allocation pointer for current pe/device
*
* @param[in] pe remote PE
*
* @return Remote symmetric heap device pointer obtained from the host-side API.
* This can be used to issue load/store from custom kernels
* instead of using rocshmem device side get/put APIs for RMA operations.
* The RO host context returns nullptr when the remote PE is not
* reported as reachable over IPC.
*/
__host__ void *rocshmem_ptr(void *dest, int pe);
/**
* @brief Initialize the rocSHMEM runtime and underlying transport layer
* with an attempt to enable the requested thread support.
+23 -8
Vedi File
@@ -68,12 +68,12 @@ enum class BackendType { RO_BACKEND, IPC_BACKEND };
* @brief Device static dispatch method call with a return value.
*/
#ifdef USE_RO
#define DISPATCH_RET(Func) \
#define DISPATCH_RET(Func) \
auto ret_val = static_cast<ROContext *>(this)->Func; \
return ret_val;
#else
#define DISPATCH_RET(Func) \
auto ret_val{0}; \
#define DISPATCH_RET(Func) \
auto ret_val{0}; \
ret_val = static_cast<IPCContext *>(this)->Func; \
return ret_val;
#endif
@@ -86,8 +86,8 @@ enum class BackendType { RO_BACKEND, IPC_BACKEND };
ret_val = static_cast<ROContext *>(this)->Func; \
return ret_val;
#else
#define DISPATCH_RET_PTR(Func) \
void *ret_val{nullptr}; \
#define DISPATCH_RET_PTR(Func) \
void *ret_val{nullptr}; \
ret_val = static_cast<IPCContext *>(this)->Func; \
return ret_val;
#endif
@@ -113,12 +113,27 @@ enum class BackendType { RO_BACKEND, IPC_BACKEND };
*/
#ifdef USE_RO
#define HOST_DISPATCH_RET(Func) \
#define HOST_DISPATCH_RET(Func) \
auto ret_val = static_cast<ROHostContext *>(this)->Func; \
return ret_val;
#else
#define HOST_DISPATCH_RET(Func) \
auto ret_val{0}; \
#define HOST_DISPATCH_RET(Func) \
auto ret_val{0}; \
ret_val = static_cast<IPCHostContext *>(this)->Func; \
return ret_val;
#endif
/**
* @brief Host static dispatch method call with a return type of pointer.
*
* Expands to a complete function tail: it casts `this` to the host context
* type selected at compile time (ROHostContext when USE_RO is defined,
* IPCHostContext otherwise), invokes `Func` on it, stores the result in a
* `void *` and returns it. Because the expansion contains the `return`,
* it must be the last statement of the calling member function.
*/
#ifdef USE_RO
#define HOST_DISPATCH_RET_PTR(Func) \
void *ret_val{nullptr}; \
ret_val = static_cast<ROHostContext *>(this)->Func; \
return ret_val;
#else
#define HOST_DISPATCH_RET_PTR(Func) \
void *ret_val{nullptr}; \
ret_val = static_cast<IPCHostContext *>(this)->Func; \
return ret_val;
#endif
+4
Vedi File
@@ -59,6 +59,8 @@ class Context {
__device__ Context(Backend* handle, bool shareable);
__host__ virtual ~Context();
/*
* Dispatch functions to get runtime polymorphism without 'virtual' or
* function pointers. Each one of these guys will use 'type' to
@@ -387,6 +389,8 @@ class Context {
__host__ void quiet();
// Host-side pointer query for `dest` on PE `pe`; forwarded to the shmem_ptr
// of the backend-specific host context.
__host__ void* shmem_ptr(const void* dest, int pe);
__host__ void barrier_all();
__host__ void sync_all();
+11 -1
Vedi File
@@ -31,7 +31,11 @@ namespace rocshmem {
__host__ Context::Context(Backend* handle, bool shareable)
: num_pes(handle->getNumPEs()),
my_pe(handle->getMyPE()),
fence_(shareable) {}
fence_(shareable) {
}
// Out-of-line definition of the destructor declared virtual in the class;
// it performs no work of its own.
__host__ Context::~Context() {
}
/******************************************************************************
********************** CONTEXT DISPATCH IMPLEMENTATIONS **********************
@@ -93,6 +97,12 @@ __host__ void Context::quiet() {
HOST_DISPATCH(quiet());
}
/**
 * @brief Host-side query of the address that corresponds to `dest` on `pe`.
 *
 * Increments the NUM_HOST_SHMEM_PTR host statistic and then forwards the call
 * to the backend-specific host context through HOST_DISPATCH_RET_PTR.
 *
 * @param[in] dest Address passed through unchanged to the backend.
 * @param[in] pe   Target PE passed through unchanged to the backend.
 *
 * @return The pointer produced by the backend's shmem_ptr.
 */
__host__ void* Context::shmem_ptr(const void* dest, int pe) {
ctxHostStats.incStat(NUM_HOST_SHMEM_PTR);
// The macro expands to the dispatch call plus the return statement.
HOST_DISPATCH_RET_PTR(shmem_ptr(dest, pe));
}
__host__ void Context::sync_all() {
ctxHostStats.incStat(NUM_HOST_SYNC_ALL);
+20
Vedi File
@@ -42,9 +42,20 @@ __host__ IPCHostContext::IPCHostContext(Backend *backend,
host_interface = b->host_interface;
context_window_info = host_interface->acquire_window_context();
char** ipc_bases = new char*[b->ipcImpl.shm_size];
CHECK_HIP(hipMemcpy(ipc_bases,
b->ipcImpl.ipc_bases,
b->ipcImpl.shm_size * sizeof(char *),
hipMemcpyDeviceToHost));
ipcImpl_.ipc_bases = ipc_bases;
}
/**
 * @brief Releases the resources held by this host context.
 */
__host__ IPCHostContext::~IPCHostContext() {
// Free the array referenced by ipcImpl_.ipc_bases (array form of delete).
delete[] ipcImpl_.ipc_bases;
// Hand the window context back to the host interface.
host_interface->release_window_context(context_window_info);
}
@@ -76,6 +87,15 @@ __host__ void IPCHostContext::quiet() {
host_interface->quiet(context_window_info);
}
/**
 * @brief Translate a local symmetric address into the matching address on
 *        another PE using the host copy of the IPC base pointers.
 *
 * The offset of `dest` from this PE's base (ipc_bases[my_pe]) is applied to
 * the base recorded for `pe`.
 *
 * @param[in] dest Address to translate; expected to lie inside this PE's
 *                 symmetric heap (not verified here).
 * @param[in] pe   Target PE index.
 *
 * @return The translated address, or nullptr when `dest` is null or `pe` is
 *         outside [0, num_pes).
 */
__host__ void *IPCHostContext::shmem_ptr(const void *dest, int pe) {
  // Reject inputs that would otherwise index ipc_bases out of range or
  // produce a meaningless offset.
  if (dest == nullptr || pe < 0 || pe >= num_pes) {
    return nullptr;
  }

  const char *local_addr = static_cast<const char *>(dest);
  const uint64_t heap_offset = local_addr - ipcImpl_.ipc_bases[my_pe];
  return ipcImpl_.ipc_bases[pe] + heap_offset;
}
__host__ void IPCHostContext::sync_all() {
host_interface->sync_all(context_window_info);
}
+2
Vedi File
@@ -78,6 +78,8 @@ class IPCHostContext : public Context {
__host__ void quiet();
// Maps a local address to the corresponding address for `pe` by applying its
// offset from this PE's IPC base to the IPC base recorded for `pe`.
__host__ void *shmem_ptr(const void *dest, int pe);
__host__ void barrier_all();
__host__ void sync_all();
+1 -1
Vedi File
@@ -110,7 +110,7 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases,
free(vec_ipc_handle);
if (0 == rocshmem_env_.get_ro_disable_ipc()) {
int thread_comm_rank;
int thread_comm_rank {-1};
CHECK_HIP(hipMalloc(reinterpret_cast<void**>(&pes_with_ipc_avail), shm_size * sizeof(int)));
+2 -2
Vedi File
@@ -61,7 +61,7 @@ class IpcOnImpl {
__host__ void ipcHostStop();
__device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) {
__host__ __device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) {
if (nullptr == pes_with_ipc_avail) { return false; }
for (int i=0; i<shm_size; i++) {
@@ -146,7 +146,7 @@ class IpcOffImpl {
__host__ void ipcHostStop() {}
__device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) { return false; }
__host__ __device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) { return false; }
__device__ void ipcGpuInit(Backend *rocshmem_handle, Context *ctx,
int thread_id) {}
@@ -41,10 +41,30 @@ __host__ ROHostContext::ROHostContext(Backend *backend, long options)
host_interface = b->host_interface;
context_window_info = dynamic_cast<WindowInfoMPI*>(host_interface->acquire_window_context());
int *pes_with_ipc_avail = new int[backend->ipcImpl.shm_size];
char** ipc_bases = new char*[b->ipcImpl.shm_size];
CHECK_HIP(hipMemcpy(pes_with_ipc_avail,
backend->ipcImpl.pes_with_ipc_avail,
backend->ipcImpl.shm_size * sizeof(int),
hipMemcpyDeviceToHost));
CHECK_HIP(hipMemcpy(ipc_bases,
backend->ipcImpl.ipc_bases,
backend->ipcImpl.shm_size * sizeof(char *),
hipMemcpyDeviceToHost));
ipcImpl_.pes_with_ipc_avail = pes_with_ipc_avail;
ipcImpl_.ipc_bases = ipc_bases;
ipcImpl_.shm_size = backend->ipcImpl.shm_size;
ipcImpl_.shm_rank = backend->ipcImpl.shm_rank;
}
/**
 * @brief Frees the arrays referenced by ipcImpl_.
 */
__host__ ROHostContext::~ROHostContext() {
// NOTE(review): the window-context release is disabled here; confirm whether
// the context is released elsewhere or this leaks it.
// host_interface->release_window_context(context_window_info);
// Both members are freed with the array form of delete.
delete[] ipcImpl_.pes_with_ipc_avail;
delete[] ipcImpl_.ipc_bases;
}
__host__ void ROHostContext::putmem_nbi(void *dest, const void *source,
@@ -87,6 +107,20 @@ __host__ void ROHostContext::quiet() {
host_interface->quiet(context_window_info);
}
/**
 * @brief Translate a local symmetric address into the matching address on
 *        another PE when that PE is reported as reachable over IPC.
 *
 * @param[in] dest Local address to translate.
 * @param[in] pe   Target PE.
 *
 * @return The translated address, or nullptr when isIpcAvailable() reports
 *         that `pe` is not available.
 */
__host__ void *ROHostContext::shmem_ptr(const void *dest, int pe) {
  DPRINTF("Function: ro_net_host_shmem_ptr\n");

  // isIpcAvailable() fills in the index to use for the target PE.
  int peer_index{-1};
  if (!ipcImpl_.isIpcAvailable(my_pe, pe, &peer_index)) {
    return nullptr;
  }

  // Offset of dest from this PE's own base, applied to the peer's base.
  char *local_addr = reinterpret_cast<char *>(const_cast<void *>(dest));
  uint64_t heap_offset = local_addr - ipcImpl_.ipc_bases[ipcImpl_.shm_rank];
  return ipcImpl_.ipc_bases[peer_index] + heap_offset;
}
__host__ void ROHostContext::sync_all() {
DPRINTF("Function: ro_net_host_sync_all\n");
@@ -127,6 +127,8 @@ class ROHostContext : public Context {
__host__ void quiet();
// Maps a local address to the corresponding address for `pe`; yields nullptr
// when `pe` is not reported as available over IPC.
__host__ void *shmem_ptr(const void *dest, int pe);
__host__ void barrier_all();
__host__ void sync_all();
+7
Vedi File
@@ -300,6 +300,13 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
backend->heap.free(ptr);
}
/**
 * @brief Host API: query the address that corresponds to `dest` on `pe`.
 *
 * Forwards to shmem_ptr() on the default host context.
 *
 * @param[in] dest Local symmetric heap address.
 * @param[in] pe   Remote PE.
 *
 * @return The pointer produced by the default host context, or nullptr when
 *         the default host context has not been set.
 */
__host__ void * rocshmem_ptr(void * dest, int pe){
  Context *ctx = reinterpret_cast<Context *>(ROCSHMEM_HOST_CTX_DEFAULT.ctx_opaque);
  // Without a default host context there is nothing to dispatch to; report
  // "no pointer" instead of dereferencing null.
  if (ctx == nullptr) {
    return nullptr;
  }
  return ctx->shmem_ptr(dest, pe);
}
[[maybe_unused]] __host__ void rocshmem_reset_stats() {
VERIFY_BACKEND();
backend->reset_stats();
+15
Vedi File
@@ -92,6 +92,21 @@ __device__ void rocshmem_query_thread(int *provided) {
__device__ void rocshmem_wg_finalize() {}
/******************************************************************************
* These host APIs use the device-side symbol ROCSHMEM_CTX_DEFAULT, so they need
* to stay here to avoid getting pulled into other places in compilation
******************************************************************************/
/**
 * @brief Read the device-side default context from the host.
 *
 * Copies the device symbol ROCSHMEM_CTX_DEFAULT into host memory and returns
 * its opaque context pointer.
 *
 * @return The ctx_opaque member of ROCSHMEM_CTX_DEFAULT as read from the
 *         device.
 */
__host__ void * rocshmem_get_device_ctx() {
  // The transfer size is sizeof(rocshmem_ctx_t), so the destination must be a
  // full rocshmem_ctx_t; a lone void* would be overrun if the struct holds
  // more than one pointer.
  // NOTE(review): the previous code returned the first pointer-sized bytes of
  // the symbol; this assumes that is ctx_opaque - confirm the struct layout.
  rocshmem_ctx_t device_ctx{};
  CHECK_HIP(hipMemcpyFromSymbol(&device_ctx, HIP_SYMBOL(ROCSHMEM_CTX_DEFAULT),
                                sizeof(rocshmem_ctx_t)));
  return device_ctx.ctx_opaque;
}
/******************************************************************************
************************** Default Context Wrappers **************************
*****************************************************************************/