Add team information to the context
* Update roc_shmem_ctx_fence API to use team-relative PE numbering
* Update backend to populate the team_opaque member of ROC_SHMEM_CTX_DEFAULT (used to store information about the team with respect to TEAM_WORLD)
[ROCm/rocshmem commit: 92fb1abaf2]
This commit is contained in:
@@ -296,7 +296,8 @@ void GPUIBBackend::setup_default_ctx() {
|
||||
CHECK_HIP(hipGetSymbolAddress(reinterpret_cast<void **>(&symbol_address),
|
||||
HIP_SYMBOL(ROC_SHMEM_CTX_DEFAULT)));
|
||||
|
||||
roc_shmem_ctx_t ctx_default_host{default_ctx_, nullptr};
|
||||
TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world;
|
||||
roc_shmem_ctx_t ctx_default_host{default_ctx_, tinfo};
|
||||
|
||||
hipStream_t stream;
|
||||
CHECK_HIP(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
|
||||
|
||||
@@ -82,10 +82,11 @@ IPCBackend::IPCBackend(MPI_Comm comm)
|
||||
|
||||
allocate_atomic_region(&bp->atomic_ret, MAX_NUM_BLOCKS);
|
||||
|
||||
default_context_proxy_ = IPCDefaultContextProxyT(this);
|
||||
|
||||
setup_team_world();
|
||||
|
||||
TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world;
|
||||
default_context_proxy_ = IPCDefaultContextProxyT(this, tinfo);
|
||||
|
||||
roc_shmem_collective_init();
|
||||
|
||||
setup_fence_buffer();
|
||||
@@ -143,6 +144,8 @@ __device__ bool IPCBackend::create_ctx(int64_t options, roc_shmem_ctx_t *ctx) {
|
||||
ctx_ = pop_result.value;
|
||||
|
||||
ctx->ctx_opaque = ctx_;
|
||||
|
||||
ctx_->tinfo = reinterpret_cast<TeamInfo *>(ctx->team_opaque);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -88,8 +88,8 @@ __device__ void IPCContext::getmem_nbi(void *dest, const void *source,
|
||||
}
|
||||
|
||||
__device__ void IPCContext::fence() {
|
||||
for (int i{0}; i < num_pes; i++) {
|
||||
detail::atomic::store<int, detail::atomic::memory_scope_system>(&fence_pool[i], 1, orders_);
|
||||
for (int i{0}, j{tinfo->pe_start}; i < tinfo->size; i++, j += tinfo->stride) {
|
||||
detail::atomic::store<int, detail::atomic::memory_scope_system>(&fence_pool[j], 1, orders_);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
|
||||
#include "../context.hpp"
|
||||
#include "../atomic.hpp"
|
||||
#include "../team.hpp"
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
@@ -239,6 +240,14 @@ class IPCContext : public Context {
|
||||
|
||||
//Buffer to perform Atomic store to enforce memory ordering
|
||||
int *fence_pool{nullptr};
|
||||
|
||||
public:
|
||||
//TODO(Avinash):
|
||||
//Make tinfo private variable, it requires changes to the context
|
||||
//creation API in backend
|
||||
|
||||
//Team information for the team associated with the context
|
||||
TeamInfo *tinfo{nullptr};
|
||||
};
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -41,10 +41,11 @@ class IPCDefaultContextProxy {
|
||||
/*
|
||||
* Placement new the memory which is allocated by proxy_
|
||||
*/
|
||||
explicit IPCDefaultContextProxy(IPCBackend* backend) : constructed_{true} {
|
||||
explicit IPCDefaultContextProxy(IPCBackend* backend, TeamInfo *tinfo)
|
||||
: constructed_{true} {
|
||||
auto ctx{proxy_.get()};
|
||||
new (ctx) IPCContext(reinterpret_cast<Backend*>(backend));
|
||||
roc_shmem_ctx_t local{ctx, nullptr};
|
||||
roc_shmem_ctx_t local{ctx, tinfo};
|
||||
set_internal_ctx(&local);
|
||||
}
|
||||
|
||||
|
||||
@@ -94,7 +94,9 @@ ROBackend::ROBackend(MPI_Comm comm)
|
||||
|
||||
default_block_handle_proxy_ = DefaultBlockHandleProxyT(
|
||||
bp->g_ret, bp->atomic_ret, &queue_, &ipcImpl, hdp_proxy_.get());
|
||||
default_context_proxy_ = DefaultContextProxyT(this);
|
||||
|
||||
TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world;
|
||||
default_context_proxy_ = DefaultContextProxyT(this, tinfo);
|
||||
|
||||
block_handle_proxy_ = BlockHandleProxyT(bp->g_ret, bp->atomic_ret, &queue_,
|
||||
&ipcImpl, hdp_proxy_.get());
|
||||
|
||||
@@ -42,10 +42,11 @@ class DefaultContextProxy {
|
||||
/*
|
||||
* Placement new the memory which is allocated by proxy_
|
||||
*/
|
||||
explicit DefaultContextProxy(ROBackend* backend) : constructed_{true} {
|
||||
explicit DefaultContextProxy(ROBackend* backend, TeamInfo *tinfo)
|
||||
: constructed_{true} {
|
||||
auto ctx{proxy_.get()};
|
||||
new (ctx) ROContext(reinterpret_cast<Backend*>(backend), -1);
|
||||
roc_shmem_ctx_t local{ctx, nullptr};
|
||||
roc_shmem_ctx_t local{ctx, tinfo};
|
||||
set_internal_ctx(&local);
|
||||
}
|
||||
|
||||
|
||||
@@ -262,9 +262,9 @@ __device__ int roc_shmem_wg_ctx_create(long option, roc_shmem_ctx_t *ctx) {
|
||||
GPU_DPRINTF("Function: roc_shmem_ctx_create\n");
|
||||
bool result{true};
|
||||
if (get_flat_block_id() == 0) {
|
||||
ctx->team_opaque = reinterpret_cast<TeamInfo *>(ROC_SHMEM_CTX_DEFAULT.team_opaque);
|
||||
device_backend_proxy->create_ctx(option, ctx);
|
||||
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(option);
|
||||
ctx->team_opaque = nullptr;
|
||||
}
|
||||
__syncthreads();
|
||||
return result == true ? 0 : -1;
|
||||
@@ -279,11 +279,11 @@ __device__ int roc_shmem_wg_team_create_ctx(roc_shmem_team_t team, long options,
|
||||
|
||||
bool result{true};
|
||||
if (get_flat_block_id() == 0) {
|
||||
result = device_backend_proxy->create_ctx(options, ctx);
|
||||
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(options);
|
||||
Team *team_obj{get_internal_team(team)};
|
||||
TeamInfo *info_wrt_world = team_obj->tinfo_wrt_world;
|
||||
ctx->team_opaque = info_wrt_world;
|
||||
result = device_backend_proxy->create_ctx(options, ctx);
|
||||
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(options);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -412,7 +412,9 @@ __device__ void roc_shmem_ctx_fence(roc_shmem_ctx_t ctx) {
|
||||
__device__ void roc_shmem_ctx_fence(roc_shmem_ctx_t ctx, int pe) {
|
||||
GPU_DPRINTF("Function: roc_shmem_ctx_fence\n");
|
||||
|
||||
get_internal_ctx(ctx)->fence(pe);
|
||||
int pe_in_world = translate_pe(ctx, pe);
|
||||
|
||||
get_internal_ctx(ctx)->fence(pe_in_world);
|
||||
}
|
||||
|
||||
__device__ void roc_shmem_ctx_quiet(roc_shmem_ctx_t ctx) {
|
||||
|
||||
Viittaa uudessa ongelmassa
Block a user