Add team information to the context

* Update the roc_shmem_ctx_fence API to use team-relative PE numbering
* Update the backends to populate the team_opaque member of ROC_SHMEM_CTX_DEFAULT (used to store the team's information with respect to TEAM_WORLD)


[ROCm/rocshmem commit: 92fb1abaf2]
This commit is contained in:
avinashkethineedi
2024-10-04 17:56:15 +00:00
parent 69784a7423
commit 37b1de86cd
8 changed files with 33 additions and 14 deletions
@@ -296,7 +296,8 @@ void GPUIBBackend::setup_default_ctx() {
CHECK_HIP(hipGetSymbolAddress(reinterpret_cast<void **>(&symbol_address),
HIP_SYMBOL(ROC_SHMEM_CTX_DEFAULT)));
roc_shmem_ctx_t ctx_default_host{default_ctx_, nullptr};
TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world;
roc_shmem_ctx_t ctx_default_host{default_ctx_, tinfo};
hipStream_t stream;
CHECK_HIP(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
@@ -82,10 +82,11 @@ IPCBackend::IPCBackend(MPI_Comm comm)
allocate_atomic_region(&bp->atomic_ret, MAX_NUM_BLOCKS);
default_context_proxy_ = IPCDefaultContextProxyT(this);
setup_team_world();
TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world;
default_context_proxy_ = IPCDefaultContextProxyT(this, tinfo);
roc_shmem_collective_init();
setup_fence_buffer();
@@ -143,6 +144,8 @@ __device__ bool IPCBackend::create_ctx(int64_t options, roc_shmem_ctx_t *ctx) {
ctx_ = pop_result.value;
ctx->ctx_opaque = ctx_;
ctx_->tinfo = reinterpret_cast<TeamInfo *>(ctx->team_opaque);
return true;
}
@@ -88,8 +88,8 @@ __device__ void IPCContext::getmem_nbi(void *dest, const void *source,
}
__device__ void IPCContext::fence() {
for (int i{0}; i < num_pes; i++) {
detail::atomic::store<int, detail::atomic::memory_scope_system>(&fence_pool[i], 1, orders_);
for (int i{0}, j{tinfo->pe_start}; i < tinfo->size; i++, j += tinfo->stride) {
detail::atomic::store<int, detail::atomic::memory_scope_system>(&fence_pool[j], 1, orders_);
}
}
@@ -25,6 +25,7 @@
#include "../context.hpp"
#include "../atomic.hpp"
#include "../team.hpp"
namespace rocshmem {
@@ -239,6 +240,14 @@ class IPCContext : public Context {
//Buffer targeted by atomic stores to enforce memory ordering
int *fence_pool{nullptr};
public:
//TODO(Avinash):
//Make tinfo a private variable; this requires changes to the context
//creation API in the backend
//Team information for the team associated with the context
TeamInfo *tinfo{nullptr};
};
} // namespace rocshmem
@@ -41,10 +41,11 @@ class IPCDefaultContextProxy {
/*
* Construct the context with placement new in the memory allocated by proxy_
*/
explicit IPCDefaultContextProxy(IPCBackend* backend) : constructed_{true} {
explicit IPCDefaultContextProxy(IPCBackend* backend, TeamInfo *tinfo)
: constructed_{true} {
auto ctx{proxy_.get()};
new (ctx) IPCContext(reinterpret_cast<Backend*>(backend));
roc_shmem_ctx_t local{ctx, nullptr};
roc_shmem_ctx_t local{ctx, tinfo};
set_internal_ctx(&local);
}
@@ -94,7 +94,9 @@ ROBackend::ROBackend(MPI_Comm comm)
default_block_handle_proxy_ = DefaultBlockHandleProxyT(
bp->g_ret, bp->atomic_ret, &queue_, &ipcImpl, hdp_proxy_.get());
default_context_proxy_ = DefaultContextProxyT(this);
TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world;
default_context_proxy_ = DefaultContextProxyT(this, tinfo);
block_handle_proxy_ = BlockHandleProxyT(bp->g_ret, bp->atomic_ret, &queue_,
&ipcImpl, hdp_proxy_.get());
@@ -42,10 +42,11 @@ class DefaultContextProxy {
/*
* Construct the context with placement new in the memory allocated by proxy_
*/
explicit DefaultContextProxy(ROBackend* backend) : constructed_{true} {
explicit DefaultContextProxy(ROBackend* backend, TeamInfo *tinfo)
: constructed_{true} {
auto ctx{proxy_.get()};
new (ctx) ROContext(reinterpret_cast<Backend*>(backend), -1);
roc_shmem_ctx_t local{ctx, nullptr};
roc_shmem_ctx_t local{ctx, tinfo};
set_internal_ctx(&local);
}
@@ -262,9 +262,9 @@ __device__ int roc_shmem_wg_ctx_create(long option, roc_shmem_ctx_t *ctx) {
GPU_DPRINTF("Function: roc_shmem_ctx_create\n");
bool result{true};
if (get_flat_block_id() == 0) {
ctx->team_opaque = reinterpret_cast<TeamInfo *>(ROC_SHMEM_CTX_DEFAULT.team_opaque);
device_backend_proxy->create_ctx(option, ctx);
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(option);
ctx->team_opaque = nullptr;
}
__syncthreads();
return result == true ? 0 : -1;
@@ -279,11 +279,11 @@ __device__ int roc_shmem_wg_team_create_ctx(roc_shmem_team_t team, long options,
bool result{true};
if (get_flat_block_id() == 0) {
result = device_backend_proxy->create_ctx(options, ctx);
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(options);
Team *team_obj{get_internal_team(team)};
TeamInfo *info_wrt_world = team_obj->tinfo_wrt_world;
ctx->team_opaque = info_wrt_world;
result = device_backend_proxy->create_ctx(options, ctx);
reinterpret_cast<Context *>(ctx->ctx_opaque)->setFence(options);
}
__syncthreads();
@@ -412,7 +412,9 @@ __device__ void roc_shmem_ctx_fence(roc_shmem_ctx_t ctx) {
__device__ void roc_shmem_ctx_fence(roc_shmem_ctx_t ctx, int pe) {
GPU_DPRINTF("Function: roc_shmem_ctx_fence\n");
get_internal_ctx(ctx)->fence(pe);
int pe_in_world = translate_pe(ctx, pe);
get_internal_ctx(ctx)->fence(pe_in_world);
}
__device__ void roc_shmem_ctx_quiet(roc_shmem_ctx_t ctx) {