Add host API for enqueuing barrier on given stream (#274)
* add host API for enqueuing barrier on given stream
[ROCm/rocshmem commit: a44b581997]
이 커밋은 다음에 포함됨:
@@ -341,6 +341,13 @@ __host__ void rocshmem_quiet();
|
||||
*/
|
||||
__host__ void rocshmem_barrier_all();
|
||||
|
||||
/**
|
||||
* @brief Enqueues a collective barrier on the given stream.
*
* @param[in] stream HIP stream on which the barrier is enqueued.
*
* @return void
|
||||
*/
|
||||
__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream);
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
|
||||
@@ -599,6 +599,14 @@ __host__ int rocshmem_ctx_double_prod_reduce(
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team, double *dest, const double *source,
|
||||
int nreduce);
|
||||
|
||||
/**
|
||||
* @brief Kernel that performs a barrier synchronization.
* The caller enqueues this kernel on the given stream.
*
* @return void
|
||||
*/
|
||||
__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel();
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the system.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
|
||||
@@ -393,6 +393,8 @@ class Context {
|
||||
|
||||
__host__ void barrier_all();
|
||||
|
||||
__host__ void barrier_all_on_stream(hipStream_t stream);
|
||||
|
||||
__host__ void sync_all();
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -116,4 +116,10 @@ __host__ void Context::barrier_all() {
|
||||
HOST_DISPATCH(barrier_all());
|
||||
}
|
||||
|
||||
/**
 * @brief Host-side request to enqueue a barrier on a caller-supplied stream.
 *
 * Bumps the NUM_HOST_BARRIER_ALL host statistic, then hands the request to
 * HOST_DISPATCH with the same stream argument.
 *
 * @param[in] stream Stream handle passed through unchanged to the dispatched
 *                   barrier_all_on_stream call.
 */
__host__ void Context::barrier_all_on_stream(hipStream_t stream) {
  // Account for the call before dispatching it.
  ctxHostStats.incStat(NUM_HOST_BARRIER_ALL);

  // NOTE(review): HOST_DISPATCH is defined elsewhere; presumably it routes the
  // call to the concrete host context type -- confirm against the macro.
  HOST_DISPATCH(barrier_all_on_stream(stream));
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -324,6 +324,16 @@ __host__ void HostInterface::barrier_all(WindowInfo* window_info) {
|
||||
return;
|
||||
}
|
||||
|
||||
/**
 * @brief Enqueue a barrier on a stream by launching
 *        rocshmem_barrier_all_kernel.
 *
 * The kernel is launched with one block of one thread and no dynamic shared
 * memory. The launch is only enqueued; this function does not wait for the
 * kernel to finish.
 *
 * @param[in] stream Stream to launch on. A null handle selects the default
 *                   stream.
 */
__host__ void HostInterface::barrier_all_on_stream(hipStream_t stream) {
  // A null hipStream_t already designates the default stream, so the handle
  // is forwarded as-is. hipStreamDefault is a stream-creation flag (value 0)
  // for hipStreamCreateWithFlags, not a stream handle, so remapping a null
  // stream to it would be a no-op.
  rocshmem_barrier_all_kernel<<<1, 1, 0, stream>>>();

  // NOTE(review): the launch status is not inspected here. Consider checking
  // hipGetLastError() with the project's HIP error-checking helper.
}
|
||||
|
||||
|
||||
__host__ void HostInterface::barrier_for_sync() {
|
||||
if (host_comm_world_ != MPI_COMM_NULL) {
|
||||
mpilib_ftable_.Barrier(host_comm_world_);
|
||||
|
||||
@@ -193,6 +193,8 @@ class HostInterface {
|
||||
__host__ void quiet(WindowInfo* window_info);
|
||||
|
||||
__host__ void barrier_all(WindowInfo* window_info);
|
||||
|
||||
__host__ void barrier_all_on_stream(hipStream_t stream);
|
||||
|
||||
__host__ void barrier_for_sync();
|
||||
|
||||
|
||||
@@ -101,4 +101,8 @@ __host__ void IPCHostContext::barrier_all() {
|
||||
host_interface->barrier_all(context_window_info);
|
||||
}
|
||||
|
||||
/**
 * @brief Enqueue a barrier on the given stream for this IPC host context.
 *
 * Thin forwarder: the request is delegated unchanged to the host interface.
 *
 * @param[in] stream Stream handle passed through to
 *                   HostInterface::barrier_all_on_stream.
 */
__host__ void IPCHostContext::barrier_all_on_stream(hipStream_t stream) {
  host_interface->barrier_all_on_stream(stream);
}
|
||||
|
||||
} // namespace rocshmem
|
||||
|
||||
@@ -82,6 +82,8 @@ class IPCHostContext : public Context {
|
||||
|
||||
__host__ void barrier_all();
|
||||
|
||||
__host__ void barrier_all_on_stream(hipStream_t stream);
|
||||
|
||||
__host__ void sync_all();
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -990,6 +990,13 @@ __host__ void rocshmem_barrier_all() {
|
||||
get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT)->barrier_all();
|
||||
}
|
||||
|
||||
|
||||
/**
 * @brief Public host entry point: enqueue a barrier on the given stream.
 *
 * Logs the call, looks up the internal context behind
 * ROCSHMEM_HOST_CTX_DEFAULT, and forwards the stream to that context's
 * barrier_all_on_stream.
 *
 * @param[in] stream Stream handle forwarded unchanged to the default host
 *                   context.
 */
__host__ void rocshmem_barrier_all_on_stream(hipStream_t stream) {
  DPRINTF("Host function: rocshmem_barrier_all_on_stream\n");

  // Resolve the default host context first, then issue the request on it.
  auto *default_host_ctx = get_internal_ctx(ROCSHMEM_HOST_CTX_DEFAULT);
  default_host_ctx->barrier_all_on_stream(stream);
}
|
||||
|
||||
__host__ void rocshmem_sync_all() {
|
||||
DPRINTF("Host function: rocshmem_sync_all\n");
|
||||
|
||||
|
||||
@@ -622,6 +622,11 @@ __device__ int rocshmem_test(T *ivars, int cmp, T val) {
|
||||
return ctx_internal->test(ivars, cmp, val);
|
||||
}
|
||||
|
||||
/**
 * @brief Kernel wrapper around the device-side rocshmem_barrier_all().
 *
 * Takes no arguments; every thread that runs the kernel calls
 * rocshmem_barrier_all() exactly once and then returns.
 */
__global__ ATTR_NO_INLINE void rocshmem_barrier_all_kernel() {
  rocshmem_barrier_all();
}
|
||||
|
||||
|
||||
__device__ void rocshmem_barrier_all() {
|
||||
GPU_DPRINTF("Function: rocshmem_barrier_all (ctx=%zd)\n",
|
||||
get_internal_ctx(ROCSHMEM_CTX_DEFAULT));
|
||||
|
||||
새 이슈에서 참조
사용자 차단