Refactor Barrier_all and Sync_all APIs to use default context (#159)

* Refactor `Barrier_all` and `Sync_all` to use default context

- Removed context-specific implementations of barrier_all and sync_all
- Added barrier_all and sync_all to the default context implementation
- Updated functional tests to use the default context for barrier_all and sync_all

* Update `Barrier_all` and `Sync_all` API usage in documentation

* Update `CHANGELOG`

---------

Co-authored-by: Yiltan <ytemucin@amd.com>

[ROCm/rocshmem commit: bf48bcabf2]
Этот коммит содержится в:
Avinash Kethineedi
2025-06-17 11:16:18 -05:00
коммит произвёл GitHub
родитель 56a3181a6f
Коммит 14756a73b1
10 изменённых файлов: 115 добавлений и 130 удалений
+6 -6
Просмотреть файл
@@ -9,15 +9,15 @@
* `rocshmem_ctx_barrier`
* `rocshmem_ctx_barrier_wave`
* `rocshmem_ctx_barrier_wg`
* `rocshmem_ctx_barrier_all`
* `rocshmem_ctx_barrier_all_wave`
* `rocshmem_ctx_barrier_all_wg`
* `rocshmem_barrier_all`
* `rocshmem_barrier_all_wave`
* `rocshmem_barrier_all_wg`
* `rocshmem_ctx_sync`
* `rocshmem_ctx_sync_wave`
* `rocshmem_ctx_sync_wg`
* `rocshmem_ctx_sync_all`
* `rocshmem_ctx_sync_all_wave`
* `rocshmem_ctx_sync_all_wg`
* `rocshmem_sync_all`
* `rocshmem_sync_all_wave`
* `rocshmem_sync_all_wg`
* `rocshmem_init_attr`
* `rocshmem_get_uniqueid`
* `rocshmem_set_attr_uniqueid_args`
+9 -9
Просмотреть файл
@@ -11,16 +11,16 @@ Collective routines
ROCSHMEM_BARRIER_ALL
--------------------
.. cpp:function:: __device__ void rocshmem_ctx_barrier_all(rocshmem_ctx_t ctx)
.. cpp:function:: __device__ void rocshmem_ctx_barrier_all_wave(rocshmem_ctx_t ctx)
.. cpp:function:: __device__ void rocshmem_ctx_barrier_all_wg(rocshmem_ctx_t ctx)
.. cpp:function:: __device__ void rocshmem_barrier_all()
.. cpp:function:: __device__ void rocshmem_barrier_all_wave()
.. cpp:function:: __device__ void rocshmem_barrier_all_wg()
:param ctx: Context with which to perform this operation.
:returns: None.
**Description:**
This routine performs a collective barrier across all PEs in the system.
The caller is blocked until the barrier is resolved and all updates local and remote are completed.
These APIs should be called from only one thread/wavefront/workgroup within the grid to avoid undefined behavior.
ROCSHMEM_BARRIER
----------------
@@ -58,15 +58,15 @@ ensure the completion of remote memory updates issued via OpenSHMEM routines.
ROCSHMEM_SYNC_ALL
-----------------
.. cpp:function:: __device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx)
.. cpp:function:: __device__ void rocshmem_ctx_sync_all_wave(rocshmem_ctx_t ctx)
.. cpp:function:: __device__ void rocshmem_ctx_sync_all_wg(rocshmem_ctx_t ctx)
.. cpp:function:: __device__ void rocshmem_sync_all()
.. cpp:function:: __device__ void rocshmem_sync_all_wave()
.. cpp:function:: __device__ void rocshmem_sync_all_wg()
:param ctx: Context with which to perform this operation.
:returns: None.
**Description:**
This routine behaves the same as ``rocshmem_team_sync_wg`` when called on the world team.
These routines behaves the same way as ``rocshmem_team_sync_*`` when called on the world team.
These APIs should be called from only one thread/wavefront/workgroup within the grid to avoid undefined behavior.
ROSHMEM_ALLTOALL
----------------
+6 -21
Просмотреть файл
@@ -605,12 +605,9 @@ __host__ int rocshmem_ctx_double_prod_reduce(
*
* This function must be invoked by a single thread within the PE.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all(
rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_barrier_all();
/**
* @brief perform a collective barrier between all PEs in the system.
@@ -618,12 +615,9 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all(
*
* This function must be called as a wave-front collective.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all_wave(
rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_barrier_all_wave();
/**
* @brief perform a collective barrier between all PEs in the system.
@@ -631,12 +625,9 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all_wave(
*
* This function must be called as a work-group collective.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all_wg(
rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_barrier_all_wg();
/**
* @brief perform a collective barrier between all PEs in the team.
@@ -690,11 +681,9 @@ __device__ void rocshmem_ctx_barrier_wg(rocshmem_ctx_t ctx, rocshmem_team_t team
*
* This function must be invoked by a single thread within the PE.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_sync_all();
/**
* @brief registers the arrival of a PE at a barrier.
@@ -706,11 +695,9 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx);
*
* This function must be called as a wave-front collective.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all_wave(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_sync_all_wave();
/**
* @brief registers the arrival of a PE at a barrier.
@@ -722,11 +709,9 @@ __device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all_wave(rocshmem_ctx_t ctx);
*
* This function must be called as a work-group collective.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all_wg(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_sync_all_wg();
/**
* @brief registers the arrival of a PE at a barrier.
-18
Просмотреть файл
@@ -350,19 +350,10 @@ TestColl() {
# | Name | Ranks | Workgroups | Threads | Max Message Size #
##############################################################################
ExecTest "barrierall" 2 1 1
ExecTest "barrierall" 2 16 64
ExecTest "barrierall" 2 32 256
ExecTest "barrierall" 2 64 1024
ExecTest "wavebarrierall" 2 1 1
ExecTest "wavebarrierall" 2 16 64
ExecTest "wavebarrierall" 2 32 256
ExecTest "wavebarrierall" 2 64 1024
ExecTest "wgbarrierall" 2 1 1
ExecTest "wgbarrierall" 2 16 64
ExecTest "wgbarrierall" 2 32 256
ExecTest "wgbarrierall" 2 64 1024
ExecTest "teambarrier" 2 1 1
ExecTest "teambarrier" 2 16 64
@@ -395,19 +386,10 @@ TestColl() {
ExecTest "wgsync" 2 39 1024
ExecTest "syncall" 2 1 1
ExecTest "syncall" 2 16 64
ExecTest "syncall" 2 32 256
ExecTest "syncall" 2 64 1024
ExecTest "wavesyncall" 2 1 1
ExecTest "wavesyncall" 2 16 64
ExecTest "wavesyncall" 2 32 256
ExecTest "wavesyncall" 2 64 1024
ExecTest "wgsyncall" 2 1 1
ExecTest "wgsyncall" 2 16 64
ExecTest "wgsyncall" 2 32 256
ExecTest "wgsyncall" 2 64 1024
ExecTest "alltoall" 2 1 1 512
+4 -7
Просмотреть файл
@@ -265,10 +265,9 @@ void IPCBackend::init_wrk_sync_buffer() {
auto max_num_teams{team_tracker.get_max_num_teams()};
/**
* size of barrier sync for all the contexts
* size of barrier sync
*/
Wrk_Sync_buffer_size_ += sizeof(*barrier_sync) * ROCSHMEM_BARRIER_SYNC_SIZE *
(maximum_num_contexts_ + 1);
Wrk_Sync_buffer_size_ += sizeof(*barrier_sync) * ROCSHMEM_BARRIER_SYNC_SIZE;
/**
* Size of sync arrays for the teams
@@ -367,9 +366,7 @@ void IPCBackend::rocshmem_collective_init() {
* Allocate heap space for barrier_sync
*/
size_t one_sync_size_bytes {sizeof(*barrier_sync)};
size_t total_sync_elems {
ROCSHMEM_BARRIER_SYNC_SIZE * (maximum_num_contexts_ + 1)};
size_t sync_size_bytes {one_sync_size_bytes * total_sync_elems};
size_t sync_size_bytes {one_sync_size_bytes * ROCSHMEM_BARRIER_SYNC_SIZE};
barrier_sync = reinterpret_cast<int64_t*>(temp_Wrk_Sync_buff_ptr_);
temp_Wrk_Sync_buff_ptr_ += sync_size_bytes;
@@ -377,7 +374,7 @@ void IPCBackend::rocshmem_collective_init() {
/*
* Initialize the barrier synchronization array with default values.
*/
for (int i = 0; i < total_sync_elems; i++) {
for (int i = 0; i < num_pes; i++) {
barrier_sync[i] = ROCSHMEM_SYNC_VALUE;
}
+1 -3
Просмотреть файл
@@ -44,9 +44,7 @@ __host__ IPCContext::IPCContext(Backend *b, unsigned int ctx_id)
ipcImpl_.ipc_bases = b->ipcImpl.ipc_bases;
ipcImpl_.shm_size = b->ipcImpl.shm_size;
size_t barrier_sync_offset = ctx_id * ROCSHMEM_BARRIER_SYNC_SIZE;
barrier_sync = backend->barrier_sync + barrier_sync_offset;
barrier_sync = backend->barrier_sync;
fence_pool = backend->fence_pool;
Wrk_Sync_buffer_bases_ = backend->get_wrk_sync_bases();
ctx_id_ = ctx_id;
+24 -24
Просмотреть файл
@@ -594,25 +594,25 @@ __device__ int rocshmem_test(T *ivars, int cmp, T val) {
return ctx_internal->test(ivars, cmp, val);
}
__device__ void rocshmem_ctx_barrier_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_barrier_all (ctx=%zd)\n",
ctx.ctx_opaque);
__device__ void rocshmem_barrier_all() {
GPU_DPRINTF("Function: rocshmem_barrier_all (ctx=%zd)\n",
get_internal_ctx(ROCSHMEM_CTX_DEFAULT));
get_internal_ctx(ctx)->barrier_all();
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier_all();
}
__device__ void rocshmem_ctx_barrier_all_wave(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_barrier_all_wave (ctx=%zd)\n",
ctx.ctx_opaque);
__device__ void rocshmem_barrier_all_wave() {
GPU_DPRINTF("Function: rocshmem_barrier_all_wave (ctx=%zd)\n",
get_internal_ctx(ROCSHMEM_CTX_DEFAULT));
get_internal_ctx(ctx)->barrier_all_wave();
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier_all_wave();
}
__device__ void rocshmem_ctx_barrier_all_wg(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_barrier_all_wg (ctx=%zd)\n",
ctx.ctx_opaque);
__device__ void rocshmem_barrier_all_wg() {
GPU_DPRINTF("Function: rocshmem_barrier_all_wg (ctx=%zd)\n",
get_internal_ctx(ROCSHMEM_CTX_DEFAULT));
get_internal_ctx(ctx)->barrier_all_wg();
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier_all_wg();
}
__device__ void rocshmem_ctx_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
@@ -636,25 +636,25 @@ __device__ void rocshmem_ctx_barrier_wg(rocshmem_ctx_t ctx, rocshmem_team_t team
get_internal_ctx(ctx)->barrier_wg(team);
}
__device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_sync_all (ctx=%zd)\n",
ctx.ctx_opaque);
__device__ void rocshmem_sync_all() {
GPU_DPRINTF("Function: rocshmem_sync_all (ctx=%zd)\n",
get_internal_ctx(ROCSHMEM_CTX_DEFAULT));
get_internal_ctx(ctx)->sync_all();
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->sync_all();
}
__device__ void rocshmem_ctx_sync_all_wave(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_sync_all_wave (ctx=%zd)\n",
ctx.ctx_opaque);
__device__ void rocshmem_sync_all_wave() {
GPU_DPRINTF("Function: rocshmem_sync_all_wave (ctx=%zd)\n",
get_internal_ctx(ROCSHMEM_CTX_DEFAULT));
get_internal_ctx(ctx)->sync_all_wave();
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->sync_all_wave();
}
__device__ void rocshmem_ctx_sync_all_wg(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_sync_all_wg (ctx=%zd)\n",
ctx.ctx_opaque);
__device__ void rocshmem_sync_all_wg() {
GPU_DPRINTF("Function: rocshmem_sync_all_wg (ctx=%zd)\n",
get_internal_ctx(ROCSHMEM_CTX_DEFAULT));
get_internal_ctx(ctx)->sync_all_wg();
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->sync_all_wg();
}
__device__ void rocshmem_ctx_sync(rocshmem_ctx_t ctx,
+32 -20
Просмотреть файл
@@ -40,29 +40,42 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ROCSHMEM_CTX_WG_PRIVATE, &ctx);
for (int i = 0; i < loop + skip; i++) {
if (hipThreadIdx_x == 0 && i == skip) {
start_time[wg_id] = wall_clock64();
}
switch (type) {
case BarrierAllTestType:
if(t_id == 0) {
rocshmem_ctx_barrier_all(ctx);
}
break;
case WAVEBarrierAllTestType:
if(wf_id == 0) {
rocshmem_ctx_barrier_all_wave(ctx);
}
break;
case WGBarrierAllTestType:
rocshmem_ctx_barrier_all_wg(ctx);
break;
default:
break;
if (is_block_zero_in_grid()) {
switch (type) {
case BarrierAllTestType:
if(t_id == 0) {
/**
* The function `rocshmem_barrier_all` should be called from only
* one thread within the grid to avoid undefined behavior.
*/
rocshmem_barrier_all();
}
break;
case WAVEBarrierAllTestType:
if(wf_id == 0) {
/**
* The function `rocshmem_barrier_all_wave` should be called from only
* one wavefront within the grid to avoid undefined behavior.
*/
rocshmem_barrier_all_wave();
}
break;
case WGBarrierAllTestType:
/**
* The function `rocshmem_barrier_all_wg` should be called from only
* one workgroup within the grid to avoid undefined behavior.
*/
rocshmem_barrier_all_wg();
break;
default:
break;
}
}
__syncthreads();
}
@@ -71,7 +84,6 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
end_time[wg_id] = wall_clock64();
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
@@ -89,8 +101,8 @@ void BarrierAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
hipLaunchKernelGGL(BarrierAllTest, gridSize, blockSize, shared_bytes, stream,
loop, args.skip, start_time, end_time, _type, wf_size);
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
num_msgs = (loop + args.skip);
num_timed_msgs = loop;
}
void BarrierAllTester::resetBuffers(uint64_t size) {}
+33 -21
Просмотреть файл
@@ -40,38 +40,50 @@ __global__ void SyncAllTest(int loop, int skip, long long int *start_time,
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ROCSHMEM_CTX_WG_PRIVATE, &ctx);
for (int i = 0; i < loop + skip; i++) {
if (hipThreadIdx_x == 0 && i == skip) {
start_time[wg_id] = wall_clock64();
}
switch (type) {
case SyncAllTestType:
if(t_id == 0) {
rocshmem_ctx_sync_all(ctx);
}
break;
case WAVESyncAllTestType:
if(wf_id == 0) {
rocshmem_ctx_sync_all_wave(ctx);
}
break;
case WGSyncAllTestType:
rocshmem_ctx_sync_all_wg(ctx);
break;
default:
break;
if (is_block_zero_in_grid()) {
switch (type) {
case SyncAllTestType:
if(t_id == 0) {
/**
* The function `rocshmem_sync_all` should be called from only
* one thread within the grid to avoid undefined behavior.
*/
rocshmem_sync_all();
}
break;
case WAVESyncAllTestType:
if(wf_id == 0) {
/**
* The function `rocshmem_sync_all_wave` should be called from only
* one thread within the grid to avoid undefined behavior.
*/
rocshmem_sync_all_wave();
}
break;
case WGSyncAllTestType:
/**
* The function `rocshmem_sync_all_wg` should be called from only
* one thread within the grid to avoid undefined behavior.
*/
rocshmem_sync_all_wg();
break;
default:
break;
}
__syncthreads();
}
__syncthreads();
}
if (hipThreadIdx_x == 0) {
end_time[wg_id] = wall_clock64();
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
@@ -89,8 +101,8 @@ void SyncAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
hipLaunchKernelGGL(SyncAllTest, gridSize, blockSize, shared_bytes, stream,
loop, args.skip, start_time, end_time, _type, wf_size);
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
num_msgs = (loop + args.skip);
num_timed_msgs = loop;
}
void SyncAllTester::resetBuffers(uint64_t size) {}
-1
Просмотреть файл
@@ -95,7 +95,6 @@ __global__ void TeamReductionTest(int loop, int skip, long long int *start_time,
start_time[wg_id] = wall_clock64();
}
wg_team_reduce<T1, T2>(ctx, team, r_buf, s_buf, size);
rocshmem_ctx_barrier_all_wg(ctx);
}
__syncthreads();