From a44b581997a326124b141a6cd2491c4eccdebc80 Mon Sep 17 00:00:00 2001 From: Dimple Prajapati Date: Wed, 15 Oct 2025 14:29:07 -0700 Subject: [PATCH] Add host API for enqueuing barrier on given stream (#274) * add host API for enqueuing barrier on given stream --- include/rocshmem/rocshmem.hpp | 7 +++++++ include/rocshmem/rocshmem_COLL.hpp | 8 ++++++++ src/context.hpp | 2 ++ src/context_host.cpp | 6 ++++++ src/host/host.cpp | 10 ++++++++++ src/host/host.hpp | 2 ++ src/ipc/context_ipc_host.cpp | 4 ++++ src/ipc/context_ipc_host.hpp | 2 ++ src/rocshmem.cpp | 7 +++++++ src/rocshmem_gpu.cpp | 5 +++++ 10 files changed, 53 insertions(+) diff --git a/include/rocshmem/rocshmem.hpp b/include/rocshmem/rocshmem.hpp index c42c7fbe60..259b23db97 100644 --- a/include/rocshmem/rocshmem.hpp +++ b/include/rocshmem/rocshmem.hpp @@ -341,6 +341,13 @@ __host__ void rocshmem_quiet(); */ __host__ void rocshmem_barrier_all(); +/** + * @brief enqueues a collective barrier on given stream. + * + * @return void + */ +__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream); + /** * @brief registers the arrival of a PE at a barrier. * The caller is blocked until the synchronization is resolved. diff --git a/include/rocshmem/rocshmem_COLL.hpp b/include/rocshmem/rocshmem_COLL.hpp index 748e540d2f..587da8f5c5 100644 --- a/include/rocshmem/rocshmem_COLL.hpp +++ b/include/rocshmem/rocshmem_COLL.hpp @@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce( rocshmem_ctx_t ctx, rocshmem_team_t team, double *dest, const double *source, int nreduce); +/** + * @brief kernel for performing a barrier synchronization. + * Caller enqueues the kernel on given stream + * + * @return void + */ +__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel(); + /** * @brief perform a collective barrier between all PEs in the system. * The caller is blocked until the barrier is resolved. diff --git a/src/context.hpp b/src/context.hpp index 72163bdbdd..5a1615aacd 100644 --- a/src/context.hpp +++ b/src/context.hpp @@ -393,6 +393,8 @@ class Context { __host__ void barrier_all(); + __host__ void barrier_all_on_stream(hipStream_t stream); + __host__ void sync_all(); template diff --git a/src/context_host.cpp b/src/context_host.cpp index 2f4f49455d..0a78f14c1f 100644 --- a/src/context_host.cpp +++ b/src/context_host.cpp @@ -116,4 +116,10 @@ __host__ void Context::barrier_all() { HOST_DISPATCH(barrier_all()); } +__host__ void Context::barrier_all_on_stream(hipStream_t stream) { + ctxHostStats.incStat(NUM_HOST_BARRIER_ALL); + + HOST_DISPATCH(barrier_all_on_stream(stream)); +} + } // namespace rocshmem diff --git a/src/host/host.cpp b/src/host/host.cpp index c7c9f6ff61..fe68cbe499 100644 --- a/src/host/host.cpp +++ b/src/host/host.cpp @@ -324,6 +324,16 @@ __host__ void HostInterface::barrier_all(WindowInfo* window_info) { return; } +__host__ void HostInterface::barrier_all_on_stream(hipStream_t stream) { + // launch kernel to do barrier with given stream, if non, use default stream + if (stream == nullptr) { + stream = hipStreamDefault; + } + + rocshmem_barrier_all_kernel<<<1, 1, 0, stream>>>(); +} + + __host__ void HostInterface::barrier_for_sync() { if (host_comm_world_ != MPI_COMM_NULL) { mpilib_ftable_.Barrier(host_comm_world_); diff --git a/src/host/host.hpp b/src/host/host.hpp index 43f11cc61f..324e513529 100644 --- a/src/host/host.hpp +++ b/src/host/host.hpp @@ -193,6 +193,8 @@ class HostInterface { __host__ void quiet(WindowInfo* window_info); __host__ void barrier_all(WindowInfo* window_info); + + __host__ void barrier_all_on_stream(hipStream_t stream); __host__ void barrier_for_sync(); diff --git a/src/ipc/context_ipc_host.cpp b/src/ipc/context_ipc_host.cpp index 4efd477b3e..dbddb50613 100644 --- a/src/ipc/context_ipc_host.cpp +++ b/src/ipc/context_ipc_host.cpp @@ -101,4 +101,8 @@ __host__ void IPCHostContext::barrier_all() { host_interface->barrier_all(context_window_info); } +__host__ void IPCHostContext::barrier_all_on_stream(hipStream_t stream) { + host_interface->barrier_all_on_stream(stream); +} + } // namespace rocshmem diff --git a/src/ipc/context_ipc_host.hpp b/src/ipc/context_ipc_host.hpp index e14f905035..10f1e6460d 100644 --- a/src/ipc/context_ipc_host.hpp +++ b/src/ipc/context_ipc_host.hpp @@ -82,6 +82,8 @@ class IPCHostContext : public Context { __host__ void barrier_all(); + __host__ void barrier_all_on_stream(hipStream_t stream); + __host__ void sync_all(); template diff --git a/src/rocshmem.cpp b/src/rocshmem.cpp index 7912dc9758..69ad2b6a0c 100644 --- a/src/rocshmem.cpp +++ b/src/rocshmem.cpp @@ -990,6 +990,13 @@ __host__ void rocshmem_barrier_all() { get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT)->barrier_all(); } + +__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream) { + DPRINTF("Host function: rocshmem_barrier_all_on_stream\n"); + + get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT)->barrier_all_on_stream(stream); +} + __host__ void rocshmem_sync_all() { DPRINTF("Host function: rocshmem_sync_all\n"); diff --git a/src/rocshmem_gpu.cpp b/src/rocshmem_gpu.cpp index 7cc9f2c466..bce5c260a3 100644 --- a/src/rocshmem_gpu.cpp +++ b/src/rocshmem_gpu.cpp @@ -622,6 +622,11 @@ __device__ int rocshmem_test(T *ivars, int cmp, T val) { return ctx_internal->test(ivars, cmp, val); } +__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel(){ + rocshmem_barrier_all(); +} + + __device__ void rocshmem_barrier_all() { GPU_DPRINTF("Function: rocshmem_barrier_all (ctx=%zd)\n", get_internal_ctx(ROCSHMEM_CTX_DEFAULT));