diff --git a/projects/rocshmem/src/gpu_ib/backend_ib.cpp b/projects/rocshmem/src/gpu_ib/backend_ib.cpp index 9ae3acd482..87b02473d3 100644 --- a/projects/rocshmem/src/gpu_ib/backend_ib.cpp +++ b/projects/rocshmem/src/gpu_ib/backend_ib.cpp @@ -296,7 +296,8 @@ void GPUIBBackend::setup_default_ctx() { CHECK_HIP(hipGetSymbolAddress(reinterpret_cast(&symbol_address), HIP_SYMBOL(ROC_SHMEM_CTX_DEFAULT))); - roc_shmem_ctx_t ctx_default_host{default_ctx_, nullptr}; + TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world; + roc_shmem_ctx_t ctx_default_host{default_ctx_, tinfo}; hipStream_t stream; CHECK_HIP(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); diff --git a/projects/rocshmem/src/ipc/backend_ipc.cpp b/projects/rocshmem/src/ipc/backend_ipc.cpp index b09adee2ff..70a9295486 100644 --- a/projects/rocshmem/src/ipc/backend_ipc.cpp +++ b/projects/rocshmem/src/ipc/backend_ipc.cpp @@ -82,10 +82,11 @@ IPCBackend::IPCBackend(MPI_Comm comm) allocate_atomic_region(&bp->atomic_ret, MAX_NUM_BLOCKS); - default_context_proxy_ = IPCDefaultContextProxyT(this); - setup_team_world(); + TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world; + default_context_proxy_ = IPCDefaultContextProxyT(this, tinfo); + roc_shmem_collective_init(); setup_fence_buffer(); @@ -143,6 +144,8 @@ __device__ bool IPCBackend::create_ctx(int64_t options, roc_shmem_ctx_t *ctx) { ctx_ = pop_result.value; ctx->ctx_opaque = ctx_; + + ctx_->tinfo = reinterpret_cast(ctx->team_opaque); return true; } diff --git a/projects/rocshmem/src/ipc/context_ipc_device.cpp b/projects/rocshmem/src/ipc/context_ipc_device.cpp index 036f5cf77b..4d891cbc53 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device.cpp +++ b/projects/rocshmem/src/ipc/context_ipc_device.cpp @@ -88,8 +88,8 @@ __device__ void IPCContext::getmem_nbi(void *dest, const void *source, } __device__ void IPCContext::fence() { - for (int i{0}; i < num_pes; i++) { - detail::atomic::store(&fence_pool[i], 1, orders_); + for (int i{0}, j{tinfo->pe_start}; i < tinfo->size; i++, j += tinfo->stride) { + detail::atomic::store(&fence_pool[j], 1, orders_); } } diff --git a/projects/rocshmem/src/ipc/context_ipc_device.hpp b/projects/rocshmem/src/ipc/context_ipc_device.hpp index 450bb2a3cc..48fe5acbdf 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_device.hpp @@ -25,6 +25,7 @@ #include "../context.hpp" #include "../atomic.hpp" +#include "../team.hpp" namespace rocshmem { @@ -239,6 +240,14 @@ class IPCContext : public Context { //Buffer to perform Atomic store to enforce memory ordering int *fence_pool{nullptr}; + + public: + //TODO(Avinash): + //Make tinfo private variable, it requires changes to the context + //creation API in backend + + //Team information for the team associated with the context + TeamInfo *tinfo{nullptr}; }; } // namespace rocshmem diff --git a/projects/rocshmem/src/ipc/ipc_context_proxy.hpp b/projects/rocshmem/src/ipc/ipc_context_proxy.hpp index 867f199094..87ca4371e0 100644 --- a/projects/rocshmem/src/ipc/ipc_context_proxy.hpp +++ b/projects/rocshmem/src/ipc/ipc_context_proxy.hpp @@ -41,10 +41,11 @@ class IPCDefaultContextProxy { /* * Placement new the memory which is allocated by proxy_ */ - explicit IPCDefaultContextProxy(IPCBackend* backend) : constructed_{true} { + explicit IPCDefaultContextProxy(IPCBackend* backend, TeamInfo *tinfo) + : constructed_{true} { auto ctx{proxy_.get()}; new (ctx) IPCContext(reinterpret_cast(backend)); - roc_shmem_ctx_t local{ctx, nullptr}; + roc_shmem_ctx_t local{ctx, tinfo}; set_internal_ctx(&local); } diff --git a/projects/rocshmem/src/reverse_offload/backend_ro.cpp b/projects/rocshmem/src/reverse_offload/backend_ro.cpp index 1c1bf4645c..96471181e3 100644 --- a/projects/rocshmem/src/reverse_offload/backend_ro.cpp +++ b/projects/rocshmem/src/reverse_offload/backend_ro.cpp @@ -94,7 +94,9 @@ ROBackend::ROBackend(MPI_Comm comm) default_block_handle_proxy_ = DefaultBlockHandleProxyT( bp->g_ret, bp->atomic_ret, &queue_, &ipcImpl, hdp_proxy_.get()); - default_context_proxy_ = DefaultContextProxyT(this); + + TeamInfo *tinfo = team_tracker.get_team_world()->tinfo_wrt_world; + default_context_proxy_ = DefaultContextProxyT(this, tinfo); block_handle_proxy_ = BlockHandleProxyT(bp->g_ret, bp->atomic_ret, &queue_, &ipcImpl, hdp_proxy_.get()); diff --git a/projects/rocshmem/src/reverse_offload/context_proxy.hpp b/projects/rocshmem/src/reverse_offload/context_proxy.hpp index d89281d921..4ae94d3851 100644 --- a/projects/rocshmem/src/reverse_offload/context_proxy.hpp +++ b/projects/rocshmem/src/reverse_offload/context_proxy.hpp @@ -42,10 +42,11 @@ class DefaultContextProxy { /* * Placement new the memory which is allocated by proxy_ */ - explicit DefaultContextProxy(ROBackend* backend) : constructed_{true} { + explicit DefaultContextProxy(ROBackend* backend, TeamInfo *tinfo) + : constructed_{true} { auto ctx{proxy_.get()}; new (ctx) ROContext(reinterpret_cast(backend), -1); - roc_shmem_ctx_t local{ctx, nullptr}; + roc_shmem_ctx_t local{ctx, tinfo}; set_internal_ctx(&local); } diff --git a/projects/rocshmem/src/roc_shmem_gpu.cpp b/projects/rocshmem/src/roc_shmem_gpu.cpp index 086a6deed1..1afd5b10ff 100644 --- a/projects/rocshmem/src/roc_shmem_gpu.cpp +++ b/projects/rocshmem/src/roc_shmem_gpu.cpp @@ -262,9 +262,9 @@ __device__ int roc_shmem_wg_ctx_create(long option, roc_shmem_ctx_t *ctx) { GPU_DPRINTF("Function: roc_shmem_ctx_create\n"); bool result{true}; if (get_flat_block_id() == 0) { + ctx->team_opaque = reinterpret_cast(ROC_SHMEM_CTX_DEFAULT.team_opaque); device_backend_proxy->create_ctx(option, ctx); reinterpret_cast(ctx->ctx_opaque)->setFence(option); - ctx->team_opaque = nullptr; } __syncthreads(); return result == true ? 0 : -1; @@ -279,11 +279,11 @@ __device__ int roc_shmem_wg_team_create_ctx(roc_shmem_team_t team, long options, bool result{true}; if (get_flat_block_id() == 0) { - result = device_backend_proxy->create_ctx(options, ctx); - reinterpret_cast(ctx->ctx_opaque)->setFence(options); Team *team_obj{get_internal_team(team)}; TeamInfo *info_wrt_world = team_obj->tinfo_wrt_world; ctx->team_opaque = info_wrt_world; + result = device_backend_proxy->create_ctx(options, ctx); + reinterpret_cast(ctx->ctx_opaque)->setFence(options); } __syncthreads(); @@ -412,7 +412,9 @@ __device__ void roc_shmem_ctx_fence(roc_shmem_ctx_t ctx) { __device__ void roc_shmem_ctx_fence(roc_shmem_ctx_t ctx, int pe) { GPU_DPRINTF("Function: roc_shmem_ctx_fence\n"); - get_internal_ctx(ctx)->fence(pe); + int pe_in_world = translate_pe(ctx, pe); + + get_internal_ctx(ctx)->fence(pe_in_world); } __device__ void roc_shmem_ctx_quiet(roc_shmem_ctx_t ctx) {