Add host APIs for querying device ctx and remote heap pointer (#200)
* Add host APIs for querying device ctx and remote heap pointer * Host API to query device pointer for ROCSHMEM_DEFAULT_CONTEXT, this is needed to support dynamic module initialization via device kernel library bitcode. * Host API to query remote symmetric heap pointer that can be used in custom device kernel for RMA operations. * Added rocshmem_ptr implementation within the Host Context class * Enables pointer retrieval functionality for symmetric data objects * Copy IPC pointers to host memory in RO host context --------- Co-authored-by: avinashkethineedi <avinash.kethineedi@amd.com>
This commit is contained in:
committato da
GitHub
parent
42e28835ad
commit
87f99e7ec6
@@ -65,6 +65,29 @@ constexpr char VERSION[] = "3.0.0";
|
||||
*/
|
||||
__host__ void rocshmem_init(MPI_Comm comm = MPI_COMM_WORLD);
|
||||
|
||||
/**
|
||||
* @brief Query rocSHMEM context from host API
|
||||
*
|
||||
* @param[out] ctx Returns ROCSHMEM_CTX_DEFAULT device pointer that users
|
||||
* can query from one instance of rocshmem host library and
|
||||
* use use later for dynamic module initialization in
|
||||
* kernel bitcode device library in the same application
|
||||
*/
|
||||
__host__ void * rocshmem_get_device_ctx();
|
||||
|
||||
/**
|
||||
* @brief Query rocSHMEM remote symmetric heap pointer
|
||||
*
|
||||
* @param[in] dest local symmetric heap allocation pointer for current pe/device
|
||||
*
|
||||
* @param[in] pe remote PE
|
||||
*
|
||||
* @param[out] ptr Returns remote symmetric heap device pointer from host-side API.
|
||||
* This can be used to issue load/store from custom kernels
|
||||
* instead of using rocshmem device side get/put APIs for RMA operations.
|
||||
*/
|
||||
__host__ void *rocshmem_ptr(void *dest, int pe);
|
||||
|
||||
/**
|
||||
* @brief Initialize the rocSHMEM runtime and underlying transport layer
|
||||
* with an attempt to enable the requested thread support.
|
||||
|
||||
+23
-8
@@ -68,12 +68,12 @@ enum class BackendType { RO_BACKEND, IPC_BACKEND };
|
||||
* @brief Device static dispatch method call with a return value.
|
||||
*/
|
||||
#ifdef USE_RO
|
||||
#define DISPATCH_RET(Func) \
|
||||
#define DISPATCH_RET(Func) \
|
||||
auto ret_val = static_cast<ROContext *>(this)->Func; \
|
||||
return ret_val;
|
||||
#else
|
||||
#define DISPATCH_RET(Func) \
|
||||
auto ret_val{0}; \
|
||||
#define DISPATCH_RET(Func) \
|
||||
auto ret_val{0}; \
|
||||
ret_val = static_cast<IPCContext *>(this)->Func; \
|
||||
return ret_val;
|
||||
#endif
|
||||
@@ -86,8 +86,8 @@ enum class BackendType { RO_BACKEND, IPC_BACKEND };
|
||||
ret_val = static_cast<ROContext *>(this)->Func; \
|
||||
return ret_val;
|
||||
#else
|
||||
#define DISPATCH_RET_PTR(Func) \
|
||||
void *ret_val{nullptr}; \
|
||||
#define DISPATCH_RET_PTR(Func) \
|
||||
void *ret_val{nullptr}; \
|
||||
ret_val = static_cast<IPCContext *>(this)->Func; \
|
||||
return ret_val;
|
||||
#endif
|
||||
@@ -113,12 +113,27 @@ enum class BackendType { RO_BACKEND, IPC_BACKEND };
|
||||
*/
|
||||
|
||||
#ifdef USE_RO
|
||||
#define HOST_DISPATCH_RET(Func) \
|
||||
#define HOST_DISPATCH_RET(Func) \
|
||||
auto ret_val = static_cast<ROHostContext *>(this)->Func; \
|
||||
return ret_val;
|
||||
#else
|
||||
#define HOST_DISPATCH_RET(Func) \
|
||||
auto ret_val{0}; \
|
||||
#define HOST_DISPATCH_RET(Func) \
|
||||
auto ret_val{0}; \
|
||||
ret_val = static_cast<IPCHostContext *>(this)->Func; \
|
||||
return ret_val;
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief Host static dispatch method call with a return type of pointer.
|
||||
*/
|
||||
#ifdef USE_RO
|
||||
#define HOST_DISPATCH_RET_PTR(Func) \
|
||||
void *ret_val{nullptr}; \
|
||||
ret_val = static_cast<ROHostContext *>(this)->Func; \
|
||||
return ret_val;
|
||||
#else
|
||||
#define HOST_DISPATCH_RET_PTR(Func) \
|
||||
void *ret_val{nullptr}; \
|
||||
ret_val = static_cast<IPCHostContext *>(this)->Func; \
|
||||
return ret_val;
|
||||
#endif
|
||||
|
||||
@@ -59,6 +59,8 @@ class Context {
|
||||
|
||||
__device__ Context(Backend* handle, bool shareable);
|
||||
|
||||
__host__ virtual ~Context();
|
||||
|
||||
/*
|
||||
* Dispatch functions to get runtime polymorphism without 'virtual' or
|
||||
* function pointers. Each one of these guys will use 'type' to
|
||||
@@ -387,6 +389,8 @@ class Context {
|
||||
|
||||
__host__ void quiet();
|
||||
|
||||
__host__ void* shmem_ptr(const void* dest, int pe);
|
||||
|
||||
__host__ void barrier_all();
|
||||
|
||||
__host__ void sync_all();
|
||||
|
||||
+11
-1
@@ -31,7 +31,11 @@ namespace rocshmem {
|
||||
__host__ Context::Context(Backend* handle, bool shareable)
|
||||
: num_pes(handle->getNumPEs()),
|
||||
my_pe(handle->getMyPE()),
|
||||
fence_(shareable) {}
|
||||
fence_(shareable) {
|
||||
}
|
||||
|
||||
__host__ Context::~Context() {
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
********************** CONTEXT DISPATCH IMPLEMENTATIONS **********************
|
||||
@@ -93,6 +97,12 @@ __host__ void Context::quiet() {
|
||||
HOST_DISPATCH(quiet());
|
||||
}
|
||||
|
||||
__host__ void* Context::shmem_ptr(const void* dest, int pe) {
|
||||
ctxHostStats.incStat(NUM_HOST_SHMEM_PTR);
|
||||
|
||||
HOST_DISPATCH_RET_PTR(shmem_ptr(dest, pe));
|
||||
}
|
||||
|
||||
__host__ void Context::sync_all() {
|
||||
ctxHostStats.incStat(NUM_HOST_SYNC_ALL);
|
||||
|
||||
|
||||
@@ -42,9 +42,20 @@ __host__ IPCHostContext::IPCHostContext(Backend *backend,
|
||||
host_interface = b->host_interface;
|
||||
|
||||
context_window_info = host_interface->acquire_window_context();
|
||||
|
||||
char** ipc_bases = new char*[b->ipcImpl.shm_size];
|
||||
|
||||
CHECK_HIP(hipMemcpy(ipc_bases,
|
||||
b->ipcImpl.ipc_bases,
|
||||
b->ipcImpl.shm_size * sizeof(char *),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ipcImpl_.ipc_bases = ipc_bases;
|
||||
}
|
||||
|
||||
__host__ IPCHostContext::~IPCHostContext() {
|
||||
delete[] ipcImpl_.ipc_bases;
|
||||
|
||||
host_interface->release_window_context(context_window_info);
|
||||
}
|
||||
|
||||
@@ -76,6 +87,15 @@ __host__ void IPCHostContext::quiet() {
|
||||
host_interface->quiet(context_window_info);
|
||||
}
|
||||
|
||||
__host__ void *IPCHostContext::shmem_ptr(const void *dest, int pe) {
|
||||
void *ret = nullptr;
|
||||
void *dst = const_cast<void *>(dest);
|
||||
uint64_t L_offset =
|
||||
reinterpret_cast<char *>(dst) - ipcImpl_.ipc_bases[my_pe];
|
||||
ret = ipcImpl_.ipc_bases[pe] + L_offset;
|
||||
return ret;
|
||||
}
|
||||
|
||||
__host__ void IPCHostContext::sync_all() {
|
||||
host_interface->sync_all(context_window_info);
|
||||
}
|
||||
|
||||
@@ -78,6 +78,8 @@ class IPCHostContext : public Context {
|
||||
|
||||
__host__ void quiet();
|
||||
|
||||
__host__ void *shmem_ptr(const void *dest, int pe);
|
||||
|
||||
__host__ void barrier_all();
|
||||
|
||||
__host__ void sync_all();
|
||||
|
||||
+1
-1
@@ -110,7 +110,7 @@ __host__ void IpcOnImpl::ipcHostInit(int my_pe, const HEAP_BASES_T &heap_bases,
|
||||
free(vec_ipc_handle);
|
||||
|
||||
if (0 == rocshmem_env_.get_ro_disable_ipc()) {
|
||||
int thread_comm_rank;
|
||||
int thread_comm_rank {-1};
|
||||
|
||||
CHECK_HIP(hipMalloc(reinterpret_cast<void**>(&pes_with_ipc_avail), shm_size * sizeof(int)));
|
||||
|
||||
|
||||
+2
-2
@@ -61,7 +61,7 @@ class IpcOnImpl {
|
||||
|
||||
__host__ void ipcHostStop();
|
||||
|
||||
__device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) {
|
||||
__host__ __device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) {
|
||||
if (nullptr == pes_with_ipc_avail) { return false; }
|
||||
|
||||
for (int i=0; i<shm_size; i++) {
|
||||
@@ -146,7 +146,7 @@ class IpcOffImpl {
|
||||
|
||||
__host__ void ipcHostStop() {}
|
||||
|
||||
__device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) { return false; }
|
||||
__host__ __device__ bool isIpcAvailable(int my_pe, int target_pe, int *local_target_pe) { return false; }
|
||||
|
||||
__device__ void ipcGpuInit(Backend *rocshmem_handle, Context *ctx,
|
||||
int thread_id) {}
|
||||
|
||||
@@ -41,10 +41,30 @@ __host__ ROHostContext::ROHostContext(Backend *backend, long options)
|
||||
host_interface = b->host_interface;
|
||||
|
||||
context_window_info = dynamic_cast<WindowInfoMPI*>(host_interface->acquire_window_context());
|
||||
|
||||
int *pes_with_ipc_avail = new int[backend->ipcImpl.shm_size];
|
||||
char** ipc_bases = new char*[b->ipcImpl.shm_size];
|
||||
|
||||
CHECK_HIP(hipMemcpy(pes_with_ipc_avail,
|
||||
backend->ipcImpl.pes_with_ipc_avail,
|
||||
backend->ipcImpl.shm_size * sizeof(int),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
CHECK_HIP(hipMemcpy(ipc_bases,
|
||||
backend->ipcImpl.ipc_bases,
|
||||
backend->ipcImpl.shm_size * sizeof(char *),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ipcImpl_.pes_with_ipc_avail = pes_with_ipc_avail;
|
||||
ipcImpl_.ipc_bases = ipc_bases;
|
||||
ipcImpl_.shm_size = backend->ipcImpl.shm_size;
|
||||
ipcImpl_.shm_rank = backend->ipcImpl.shm_rank;
|
||||
}
|
||||
|
||||
__host__ ROHostContext::~ROHostContext() {
|
||||
// host_interface->release_window_context(context_window_info);
|
||||
delete[] ipcImpl_.pes_with_ipc_avail;
|
||||
delete[] ipcImpl_.ipc_bases;
|
||||
}
|
||||
|
||||
__host__ void ROHostContext::putmem_nbi(void *dest, const void *source,
|
||||
@@ -87,6 +107,20 @@ __host__ void ROHostContext::quiet() {
|
||||
host_interface->quiet(context_window_info);
|
||||
}
|
||||
|
||||
__host__ void *ROHostContext::shmem_ptr(const void *dest, int pe) {
|
||||
DPRINTF("Function: ro_net_host_shmem_ptr\n");
|
||||
|
||||
void *ret = nullptr;
|
||||
int local_pe{-1};
|
||||
if (ipcImpl_.isIpcAvailable(my_pe, pe, &local_pe)) {
|
||||
void *dst = const_cast<void *>(dest);
|
||||
uint64_t L_offset =
|
||||
reinterpret_cast<char *>(dst) - ipcImpl_.ipc_bases[ipcImpl_.shm_rank];
|
||||
ret = ipcImpl_.ipc_bases[local_pe] + L_offset;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
__host__ void ROHostContext::sync_all() {
|
||||
DPRINTF("Function: ro_net_host_sync_all\n");
|
||||
|
||||
|
||||
@@ -127,6 +127,8 @@ class ROHostContext : public Context {
|
||||
|
||||
__host__ void quiet();
|
||||
|
||||
__host__ void *shmem_ptr(const void *dest, int pe);
|
||||
|
||||
__host__ void barrier_all();
|
||||
|
||||
__host__ void sync_all();
|
||||
|
||||
@@ -300,6 +300,13 @@ rocshmem_ctx_t ROCSHMEM_HOST_CTX_DEFAULT;
|
||||
backend->heap.free(ptr);
|
||||
}
|
||||
|
||||
__host__ void * rocshmem_ptr(void * dest, int pe){
|
||||
|
||||
Context *ctx = reinterpret_cast<Context *>(ROCSHMEM_HOST_CTX_DEFAULT.ctx_opaque);
|
||||
|
||||
return ctx->shmem_ptr(dest, pe);
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ void rocshmem_reset_stats() {
|
||||
VERIFY_BACKEND();
|
||||
backend->reset_stats();
|
||||
|
||||
@@ -92,6 +92,21 @@ __device__ void rocshmem_query_thread(int *provided) {
|
||||
|
||||
__device__ void rocshmem_wg_finalize() {}
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* These host APIs use Device side symbol - ROCSHMEM_CTX_DEFAULT so it needs
|
||||
* to stay here to avoid getting pulled into other places in compilation
|
||||
******************************************************************************/
|
||||
|
||||
__host__ void * rocshmem_get_device_ctx() {
|
||||
void *ctx = nullptr;
|
||||
|
||||
CHECK_HIP(hipMemcpyFromSymbol(&ctx, HIP_SYMBOL(ROCSHMEM_CTX_DEFAULT),
|
||||
sizeof(rocshmem_ctx_t)));
|
||||
return ctx;
|
||||
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
************************** Default Context Wrappers **************************
|
||||
*****************************************************************************/
|
||||
|
||||
Fai riferimento in un nuovo problema
Block a user