Add host API for enqueuing barrier on given stream (#274)

* add host API for enqueuing barrier on given stream

[ROCm/rocshmem commit: a44b581997]
이 커밋은 다음에 포함됨:
Dimple Prajapati
2025-10-15 14:29:07 -07:00
커밋한 사람 GitHub
부모 fc73e4f858
커밋 6c4325d131
10개의 변경된 파일53개의 추가작업 그리고 0개의 파일을 삭제
+7
파일 보기
@@ -341,6 +341,13 @@ __host__ void rocshmem_quiet();
*/
__host__ void rocshmem_barrier_all();
/**
* @brief Enqueues a collective barrier on the given stream.
*
* The barrier is carried out by a kernel launched on @p stream.
* NOTE(review): the host call appears to only enqueue the work; callers
* presumably must synchronize @p stream to know the barrier has completed
* -- confirm against the implementation.
*
* @param[in] stream HIP stream on which the barrier is enqueued. A null
* stream selects the default stream.
*
* @return void
*/
__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream);
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
+8
파일 보기
@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce(
rocshmem_ctx_t ctx, rocshmem_team_t team, double *dest, const double *source,
int nreduce);
/**
* @brief Kernel that performs a barrier synchronization by calling the
* device-side rocshmem_barrier_all().
* The caller enqueues the kernel on the stream of its choice.
*
* @return void
*/
__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel();
/**
* @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved.
+2
파일 보기
@@ -393,6 +393,8 @@ class Context {
__host__ void barrier_all();
// Enqueues a barrier on `stream`; the definition dispatches to the concrete
// host context through HOST_DISPATCH.
__host__ void barrier_all_on_stream(hipStream_t stream);
__host__ void sync_all();
template <typename T>
+6
파일 보기
@@ -116,4 +116,10 @@ __host__ void Context::barrier_all() {
HOST_DISPATCH(barrier_all());
}
// Host entry point for a stream-ordered barrier: bumps a host statistic and
// then hands the request to the concrete context type via HOST_DISPATCH.
__host__ void Context::barrier_all_on_stream(hipStream_t stream) {
// NOTE(review): the counter name suggests it is shared with barrier_all();
// confirm that a separate on-stream counter is not wanted.
ctxHostStats.incStat(NUM_HOST_BARRIER_ALL);
HOST_DISPATCH(barrier_all_on_stream(stream));
}
} // namespace rocshmem
+10
파일 보기
@@ -324,6 +324,16 @@ __host__ void HostInterface::barrier_all(WindowInfo* window_info) {
return;
}
/**
 * @brief Enqueues a barrier on the given stream by launching
 * rocshmem_barrier_all_kernel with a single block of a single thread.
 *
 * @param[in] stream HIP stream to launch on. A null handle is passed through
 * unchanged and designates the default (null) stream.
 *
 * @return void
 */
__host__ void HostInterface::barrier_all_on_stream(hipStream_t stream) {
  // A null stream handle already means "default stream" for a kernel launch,
  // so no remapping is needed. hipStreamDefault is a stream-creation flag
  // (value 0x00), not a stream handle; assigning it to an already-null
  // stream was a no-op.
  // NOTE(review): the launch result is not checked here; consider reporting
  // launch failures using the project's HIP error-handling convention.
  rocshmem_barrier_all_kernel<<<1, 1, 0, stream>>>();
}
__host__ void HostInterface::barrier_for_sync() {
if (host_comm_world_ != MPI_COMM_NULL) {
mpilib_ftable_.Barrier(host_comm_world_);
+2
파일 보기
@@ -193,6 +193,8 @@ class HostInterface {
__host__ void quiet(WindowInfo* window_info);
__host__ void barrier_all(WindowInfo* window_info);
// Launches rocshmem_barrier_all_kernel on `stream`; a null stream selects
// the default stream.
__host__ void barrier_all_on_stream(hipStream_t stream);
__host__ void barrier_for_sync();
+4
파일 보기
@@ -101,4 +101,8 @@ __host__ void IPCHostContext::barrier_all() {
host_interface->barrier_all(context_window_info);
}
// Enqueues a barrier on `stream` by forwarding to the host interface.
// Unlike barrier_all(), no window info is passed along.
__host__ void IPCHostContext::barrier_all_on_stream(hipStream_t stream) {
host_interface->barrier_all_on_stream(stream);
}
} // namespace rocshmem
+2
파일 보기
@@ -82,6 +82,8 @@ class IPCHostContext : public Context {
__host__ void barrier_all();
// Enqueues a barrier on `stream`; the definition forwards to
// host_interface->barrier_all_on_stream(stream).
__host__ void barrier_all_on_stream(hipStream_t stream);
__host__ void sync_all();
template <typename T>
+7
파일 보기
@@ -990,6 +990,13 @@ __host__ void rocshmem_barrier_all() {
get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT)->barrier_all();
}
// Public host API: enqueue a barrier on `stream` using the default host
// context.
__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream) {
  DPRINTF("Host function: rocshmem_barrier_all_on_stream\n");

  // Resolve the default host context first, then hand it the stream.
  auto* const default_host_ctx = get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT);
  default_host_ctx->barrier_all_on_stream(stream);
}
__host__ void rocshmem_sync_all() {
DPRINTF("Host function: rocshmem_sync_all\n");
+5
파일 보기
@@ -622,6 +622,11 @@ __device__ int rocshmem_test(T *ivars, int cmp, T val) {
return ctx_internal->test(ivars, cmp, val);
}
// Kernel wrapper around the device-side rocshmem_barrier_all(), so the
// barrier can be launched from the host onto a stream.
__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel(){
rocshmem_barrier_all();
}
__device__ void rocshmem_barrier_all() {
GPU_DPRINTF("Function: rocshmem_barrier_all (ctx=%zd)\n",
get_internal_ctx(ROCSHMEM_CTX_DEFAULT));