* Add thread, wavefront, and workgroup-level `barrier` APIs in IPC and RO conduits; remove collectives on default context
 - Implemented `barrier` APIs for thread, wavefront, and workgroup scopes
 - Added support to both the IPC and RO conduits
 - Added functional tests to cover all `barrier` APIs
 - Removed collective operations on default context

* Add thread, wavefront, and workgroup-level `sync` APIs in IPC and RO conduits.
  - Implemented `sync` APIs for thread, wavefront, and workgroup scopes
  - Added support to both the IPC and RO conduits
  - Added functional tests to cover all `sync` APIs

* Update naming convention for context-based `barrier` APIs
Этот коммит содержится в:
Avinash Kethineedi
2025-04-08 11:25:31 -05:00
коммит произвёл GitHub
родитель c652f58cef
Коммит dc61bca066
16 изменённых файлов: 347 добавлений и 67 удалений
+68 -14
Просмотреть файл
@@ -503,8 +503,6 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team,
__device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all(
rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_barrier_all();
/**
* @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved.
@@ -518,8 +516,6 @@ __device__ ATTR_NO_INLINE void rocshmem_barrier_all();
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_barrier_all(
rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_wave_barrier_all();
/**
* @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved.
@@ -533,17 +529,47 @@ __device__ ATTR_NO_INLINE void rocshmem_wave_barrier_all();
__device__ ATTR_NO_INLINE void rocshmem_ctx_wg_barrier_all(
rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_wg_barrier_all();
/**
* @brief perform a collective barrier between all PEs in the team.
* The caller is blocked until the barrier is resolved.
*
* This function must be invoked by a single thread within the PE.
*
* @param[in] ctx GPU side context handle.
*
* @param[in] team The team on which to perform barrier synchronization
*
* @return void
*/
__device__ void rocshmem_ctx_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team);
/**
* @brief perform a collective barrier between all PEs in the team.
* The caller is blocked until the barrier is resolved.
*
* This function must be called as a wave-front collective.
*
* @param[in] ctx GPU side context handle.
*
* @param[in] team The team on which to perform barrier synchronization
*
* @return void
*/
__device__ void rocshmem_barrier(rocshmem_team_t);
__device__ void rocshmem_ctx_wave_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team);
/**
* @brief perform a collective barrier between all PEs in the team.
* The caller is blocked until the barrier is resolved.
*
* This function must be called as a work-group collective.
*
* @param[in] ctx GPU side context handle.
*
* @param[in] team The team on which to perform barrier synchronization
*
* @return void
*/
__device__ void rocshmem_ctx_wg_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team);
/**
* @brief registers the arrival of a PE at a barrier.
@@ -561,8 +587,6 @@ __device__ void rocshmem_barrier(rocshmem_team_t);
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_sync_all();
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
@@ -579,8 +603,6 @@ __device__ ATTR_NO_INLINE void rocshmem_sync_all();
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_wave_sync_all();
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
@@ -597,7 +619,41 @@ __device__ ATTR_NO_INLINE void rocshmem_wave_sync_all();
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_wg_sync_all();
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
*
* In contrast with the shmem_barrier_all routine, shmem_team_sync only ensures
* completion and visibility of previously issued memory stores and does not
* ensure completion of remote memory updates issued via OpenSHMEM routines.
*
* This function must be invoked by a single thread within the PE.
*
* @param[in] ctx GPU side context handle.
* @param[in] team Handle of the team being synchronized
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_team_sync(
rocshmem_ctx_t ctx, rocshmem_team_t team);
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
*
* In contrast with the shmem_barrier_all routine, shmem_team_sync only ensures
* completion and visibility of previously issued memory stores and does not
* ensure completion of remote memory updates issued via OpenSHMEM routines.
*
* This function must be called as a wave-front collective.
*
* @param[in] ctx GPU side context handle.
* @param[in] team Handle of the team being synchronized
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_team_sync(
rocshmem_ctx_t ctx, rocshmem_team_t team);
/**
* @brief registers the arrival of a PE at a barrier.
@@ -617,8 +673,6 @@ __device__ ATTR_NO_INLINE void rocshmem_wg_sync_all();
__device__ ATTR_NO_INLINE void rocshmem_ctx_wg_team_sync(
rocshmem_ctx_t ctx, rocshmem_team_t team);
__device__ ATTR_NO_INLINE void rocshmem_wg_team_sync(rocshmem_team_t team);
/**
* @brief Query a local pointer to a symmetric data object on the
* specified \pe . Returns an address that may be used to directly reference
+31 -1
Просмотреть файл
@@ -80,7 +80,7 @@ declare -A TEST_NUMBERS=(
["signalfetch"]="55"
["wgsignalfetch"]="56"
["wavesignalfetch"]="57"
["teambarrier"]="58"
["teamwgbarrier"]="58"
["defaultctxget"]="59"
["defaultctxgetnbi"]="60"
["defaultctxput"]="61"
@@ -91,6 +91,10 @@ declare -A TEST_NUMBERS=(
["wgbarrierall"]="66"
["wavesyncall"]="67"
["wgsyncall"]="68"
["teambarrier"]="69"
["teamwavebarrier"]="70"
["wavesync"]="71"
["wgsync"]="72"
)
ExecTest() {
@@ -328,8 +332,34 @@ TestColl() {
ExecTest "wgbarrierall" 2 64 1024
ExecTest "teambarrier" 2 1 1
ExecTest "teambarrier" 2 16 64
ExecTest "teambarrier" 2 32 256
ExecTest "teambarrier" 2 39 1024
ExecTest "teamwavebarrier" 2 1 1
ExecTest "teamwavebarrier" 2 16 64
ExecTest "teamwavebarrier" 2 32 256
ExecTest "teamwavebarrier" 2 39 1024
ExecTest "teamwgbarrier" 2 1 1
ExecTest "teamwgbarrier" 2 16 64
ExecTest "teamwgbarrier" 2 32 256
ExecTest "teamwgbarrier" 2 39 1024
ExecTest "sync" 2 1 1
ExecTest "sync" 2 16 64
ExecTest "sync" 2 32 256
ExecTest "sync" 2 39 1024
ExecTest "wavesync" 2 1 1
ExecTest "wavesync" 2 16 64
ExecTest "wavesync" 2 32 256
ExecTest "wavesync" 2 39 1024
ExecTest "wgsync" 2 1 1
ExecTest "wgsync" 2 16 64
ExecTest "wgsync" 2 32 256
ExecTest "wgsync" 2 39 1024
ExecTest "syncall" 2 1 1
ExecTest "syncall" 2 16 64
+6
Просмотреть файл
@@ -132,6 +132,9 @@ void Backend::dump_stats() {
printf("BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL));
printf("WAVE_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WAVE));
printf("WG_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WG));
printf("Barrier %llu\n", device_stats.getStat(NUM_BARRIER));
printf("WAVE_Barrier %llu\n", device_stats.getStat(NUM_BARRIER_WAVE));
printf("WG_Barrier %llu\n", device_stats.getStat(NUM_BARRIER_WG));
printf("Wait Until %llu\n", device_stats.getStat(NUM_WAIT_UNTIL));
printf("Wait Until Any %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ANY));
printf("Wait Until All %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ALL));
@@ -157,6 +160,9 @@ void Backend::dump_stats() {
printf("SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL));
printf("WAVE_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WAVE));
printf("WG_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WG));
printf("Sync %llu\n", device_stats.getStat(NUM_SYNC));
printf("WAVE_Sync %llu\n", device_stats.getStat(NUM_SYNC_WAVE));
printf("WG_Sync %llu\n", device_stats.getStat(NUM_SYNC_WG));
const auto& host_stats{globalHostStats};
printf("HOST STATS\n");
+8
Просмотреть файл
@@ -143,12 +143,20 @@ class Context {
__device__ void barrier(rocshmem_team_t team);
__device__ void barrier_wave(rocshmem_team_t team);
__device__ void barrier_wg(rocshmem_team_t team);
__device__ void sync_all();
__device__ void sync_all_wave();
__device__ void sync_all_wg();
__device__ void sync(rocshmem_team_t team);
__device__ void sync_wave(rocshmem_team_t team);
__device__ void sync_wg(rocshmem_team_t team);
template <typename T>
+26 -2
Просмотреть файл
@@ -161,11 +161,23 @@ __device__ void Context::barrier_all_wg() {
}
/**
 * Team barrier entry point on the generic Context.
 *
 * Records the call under NUM_BARRIER and forwards to the barrier(team)
 * implementation selected by the DISPATCH macro.
 *
 * @param team Team handle passed through unchanged to the implementation.
 */
__device__ void Context::barrier(rocshmem_team_t team) {
  // Count team barriers under NUM_BARRIER only. NUM_BARRIER_ALL is reported
  // separately as "BarrierAll" by dump_stats, so bumping it here as well
  // would inflate that statistic on every team barrier.
  ctxStats.incStat(NUM_BARRIER);

  DISPATCH(barrier(team));
}
// Wave-front variant of the team barrier on the generic Context.
// Increments the NUM_BARRIER_WAVE statistic, then hands the call to
// barrier_wave(team) through the DISPATCH macro (presumably selecting the
// conduit-specific context; IPCContext and ROContext both provide one).
__device__ void Context::barrier_wave(rocshmem_team_t team) {
ctxStats.incStat(NUM_BARRIER_WAVE);
DISPATCH(barrier_wave(team));
}
// Work-group variant of the team barrier on the generic Context.
// Increments the NUM_BARRIER_WG statistic, then hands the call to
// barrier_wg(team) through the DISPATCH macro (presumably selecting the
// conduit-specific context; IPCContext and ROContext both provide one).
__device__ void Context::barrier_wg(rocshmem_team_t team) {
ctxStats.incStat(NUM_BARRIER_WG);
DISPATCH(barrier_wg(team));
}
__device__ void Context::sync_all() {
ctxStats.incStat(NUM_SYNC_ALL);
@@ -184,8 +196,20 @@ __device__ void Context::sync_all_wg() {
DISPATCH(sync_all_wg());
}
// Team sync entry point on the generic Context.
// Increments the NUM_SYNC statistic, then hands the call to sync(team)
// through the DISPATCH macro (presumably selecting the conduit-specific
// context; IPCContext and ROContext both provide one).
__device__ void Context::sync(rocshmem_team_t team) {
ctxStats.incStat(NUM_SYNC);
DISPATCH(sync(team));
}
// Wave-front variant of the team sync on the generic Context.
// Increments the NUM_SYNC_WAVE statistic, then hands the call to
// sync_wave(team) through the DISPATCH macro (presumably selecting the
// conduit-specific context; IPCContext and ROContext both provide one).
__device__ void Context::sync_wave(rocshmem_team_t team) {
ctxStats.incStat(NUM_SYNC_WAVE);
DISPATCH(sync_wave(team));
}
/**
 * Work-group variant of the team sync on the generic Context.
 *
 * Records the call under NUM_SYNC_WG and forwards to the sync_wg(team)
 * implementation selected by the DISPATCH macro.
 *
 * @param team Team handle passed through unchanged to the implementation.
 */
__device__ void Context::sync_wg(rocshmem_team_t team) {
  // Count team syncs under NUM_SYNC_WG only. NUM_SYNC_ALL_WG is reported
  // separately as "WG_SyncAll" by dump_stats, so bumping it here as well
  // would inflate that statistic on every team sync.
  ctxStats.incStat(NUM_SYNC_WG);

  DISPATCH(sync_wg(team));
}
+8
Просмотреть файл
@@ -67,12 +67,20 @@ class IPCContext : public Context {
__device__ void barrier(rocshmem_team_t team);
__device__ void barrier_wave(rocshmem_team_t team);
__device__ void barrier_wg(rocshmem_team_t team);
__device__ void sync_all();
__device__ void sync_all_wave();
__device__ void sync_all_wg();
__device__ void sync(rocshmem_team_t team);
__device__ void sync_wave(rocshmem_team_t team);
__device__ void sync_wg(rocshmem_team_t team);
template <typename T>
+52
Просмотреть файл
@@ -118,6 +118,30 @@ __device__ void IPCContext::internal_sync_wg(int pe, int PE_start, int stride,
__syncthreads();
}
/**
 * Team sync for the IPC conduit (single-caller variant).
 *
 * Reinterprets the opaque team handle as an IPCTeam and runs internal_sync()
 * with the team's my_pe_in_world, the pe_start / stride stored in
 * tinfo_wrt_world, the team's num_pes and its barrier_pSync buffer.
 *
 * @param team Team handle; must refer to an IPCTeam object.
 */
__device__ void IPCContext::sync(rocshmem_team_t team) {
  auto *ipc_team = reinterpret_cast<IPCTeam *>(team);

  internal_sync(ipc_team->my_pe_in_world,
                ipc_team->tinfo_wrt_world->pe_start,
                ipc_team->tinfo_wrt_world->stride,
                ipc_team->num_pes,
                ipc_team->barrier_pSync);
}
/**
 * Team sync for the IPC conduit (wave-front variant).
 *
 * Reinterprets the opaque team handle as an IPCTeam and runs
 * internal_sync_wave() with the team's my_pe_in_world, the pe_start / stride
 * stored in tinfo_wrt_world, the team's num_pes and its barrier_pSync buffer.
 *
 * @param team Team handle; must refer to an IPCTeam object.
 */
__device__ void IPCContext::sync_wave(rocshmem_team_t team) {
  auto *ipc_team = reinterpret_cast<IPCTeam *>(team);

  internal_sync_wave(ipc_team->my_pe_in_world,
                     ipc_team->tinfo_wrt_world->pe_start,
                     ipc_team->tinfo_wrt_world->stride,
                     ipc_team->num_pes,
                     ipc_team->barrier_pSync);
}
__device__ void IPCContext::sync_wg(rocshmem_team_t team) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
@@ -171,6 +195,34 @@ __device__ void IPCContext::barrier(rocshmem_team_t team) {
int pe_size = team_obj->num_pes;
long *p_sync = team_obj->barrier_pSync;
quiet();
internal_sync(pe, pe_start, pe_stride, pe_size, p_sync);
}
/**
 * Team barrier for the IPC conduit (wave-front variant).
 *
 * Gathers the team geometry from the IPCTeam behind the handle, has the
 * thread for which is_thread_zero_in_wave() holds call quiet(), and then
 * runs internal_sync_wave() with that geometry.
 *
 * @param team Team handle; must refer to an IPCTeam object.
 */
__device__ void IPCContext::barrier_wave(rocshmem_team_t team) {
  auto *ipc_team = reinterpret_cast<IPCTeam *>(team);

  // Team geometry is read up front, before quiet() runs.
  const int my_pe = ipc_team->my_pe_in_world;
  const int first_pe = ipc_team->tinfo_wrt_world->pe_start;
  const int stride = ipc_team->tinfo_wrt_world->stride;
  const int team_size = ipc_team->num_pes;
  long *sync_buf = ipc_team->barrier_pSync;

  // quiet() is issued once per wave, by its zero thread, before the sync.
  // NOTE(review): presumably this is what upgrades the sync to a barrier by
  // completing outstanding remote updates first; confirm against quiet().
  const bool wave_leader = is_thread_zero_in_wave();
  if (wave_leader) {
    quiet();
  }

  internal_sync_wave(my_pe, first_pe, stride, team_size, sync_buf);
}
__device__ void IPCContext::barrier_wg(rocshmem_team_t team) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
int pe = team_obj->my_pe_in_world;
int pe_start = team_obj->tinfo_wrt_world->pe_start;
int pe_stride = team_obj->tinfo_wrt_world->stride;
int pe_size = team_obj->num_pes;
long *p_sync = team_obj->barrier_pSync;
if (is_thread_zero_in_block()) {
quiet();
}
+32
Просмотреть файл
@@ -193,6 +193,22 @@ __device__ void ROContext::barrier_all_wg() {
}
/**
 * Team barrier for the RO conduit (single-caller variant).
 *
 * Builds an RO_NET_BARRIER queue element that carries the team's mpi_comm;
 * all data-transfer fields of the element are left null / zero.
 *
 * @param team Team handle; must refer to an ROTeam object.
 */
__device__ void ROContext::barrier(rocshmem_team_t team) {
  auto *ro_team = reinterpret_cast<ROTeam *>(team);

  build_queue_element(
      RO_NET_BARRIER,
      nullptr, nullptr,
      0, 0, 0, 0, 0,
      nullptr, nullptr,
      ro_team->mpi_comm, ro_net_win_id, block_handle,
      true, get_status_flag(), is_default_ctx);
}
/**
 * Team barrier for the RO conduit (wave-front variant).
 *
 * Only the thread for which is_thread_zero_in_wave() holds builds the
 * RO_NET_BARRIER queue element (carrying the team's mpi_comm); every other
 * thread of the wave returns without enqueuing anything.
 *
 * @param team Team handle; must refer to an ROTeam object.
 */
__device__ void ROContext::barrier_wave(rocshmem_team_t team) {
  auto *ro_team = reinterpret_cast<ROTeam *>(team);

  if (!is_thread_zero_in_wave()) {
    return;
  }

  build_queue_element(
      RO_NET_BARRIER,
      nullptr, nullptr,
      0, 0, 0, 0, 0,
      nullptr, nullptr,
      ro_team->mpi_comm, ro_net_win_id, block_handle,
      true, get_status_flag(), is_default_ctx);
}
__device__ void ROContext::barrier_wg(rocshmem_team_t team) {
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
@@ -225,6 +241,22 @@ __device__ void ROContext::sync_all_wg() {
__syncthreads();
}
/**
 * Team sync for the RO conduit (single-caller variant).
 *
 * Builds an RO_NET_SYNC queue element that carries the team's mpi_comm;
 * all data-transfer fields of the element are left null / zero.
 *
 * @param team Team handle; must refer to an ROTeam object.
 */
__device__ void ROContext::sync(rocshmem_team_t team) {
  auto *ro_team = reinterpret_cast<ROTeam *>(team);

  build_queue_element(
      RO_NET_SYNC,
      nullptr, nullptr,
      0, 0, 0, 0, 0,
      nullptr, nullptr,
      ro_team->mpi_comm, ro_net_win_id, block_handle,
      true, get_status_flag(), is_default_ctx);
}
/**
 * Team sync for the RO conduit (wave-front variant).
 *
 * Only the thread for which is_thread_zero_in_wave() holds builds the
 * RO_NET_SYNC queue element (carrying the team's mpi_comm); every other
 * thread of the wave returns without enqueuing anything.
 *
 * @param team Team handle; must refer to an ROTeam object.
 */
__device__ void ROContext::sync_wave(rocshmem_team_t team) {
  auto *ro_team = reinterpret_cast<ROTeam *>(team);

  if (!is_thread_zero_in_wave()) {
    return;
  }

  build_queue_element(
      RO_NET_SYNC,
      nullptr, nullptr,
      0, 0, 0, 0, 0,
      nullptr, nullptr,
      ro_team->mpi_comm, ro_net_win_id, block_handle,
      true, get_status_flag(), is_default_ctx);
}
__device__ void ROContext::sync_wg(rocshmem_team_t team) {
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
if (is_thread_zero_in_block()) {
+8
Просмотреть файл
@@ -71,12 +71,20 @@ class ROContext : public Context {
__device__ void barrier(rocshmem_team_t team);
__device__ void barrier_wave(rocshmem_team_t team);
__device__ void barrier_wg(rocshmem_team_t team);
__device__ void sync_all();
__device__ void sync_all_wave();
__device__ void sync_all_wg();
__device__ void sync(rocshmem_team_t team);
__device__ void sync_wave(rocshmem_team_t team);
__device__ void sync_wg(rocshmem_team_t team);
template <typename T>
+27 -21
Просмотреть файл
@@ -588,14 +588,22 @@ __device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) {
get_internal_ctx(ctx)->barrier_all_wg();
}
// Default-context convenience wrapper: forwards to
// rocshmem_ctx_wg_barrier_all() with ROCSHMEM_CTX_DEFAULT.
__device__ void rocshmem_wg_barrier_all() {
rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT);
}
__device__ void rocshmem_barrier(rocshmem_team_t team) {
/**
 * Context-based team barrier (single-caller variant).
 *
 * Resolves the internal context object for ctx and runs its barrier(team)
 * exactly once.
 *
 * @param ctx  Context on which the barrier is issued.
 * @param team Team to synchronize.
 */
__device__ void rocshmem_ctx_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
  GPU_DPRINTF("Function: rocshmem_ctx_barrier\n");

  // Use the caller-supplied context only. Also issuing the barrier on
  // ROCSHMEM_CTX_DEFAULT would run the collective twice per call and touch a
  // context the caller never asked for.
  get_internal_ctx(ctx)->barrier(team);
}
/**
 * Context-based team barrier (wave-front variant).
 *
 * Resolves the internal context object for ctx and runs its
 * barrier_wave(team).
 *
 * @param ctx  Context on which the barrier is issued.
 * @param team Team to synchronize.
 */
__device__ void rocshmem_ctx_wave_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
  // Trace string carries the full API name, matching the other ctx entry
  // points (e.g. rocshmem_ctx_wave_sync_all).
  GPU_DPRINTF("Function: rocshmem_ctx_wave_barrier\n");

  get_internal_ctx(ctx)->barrier_wave(team);
}
/**
 * Context-based team barrier (work-group variant).
 *
 * Resolves the internal context object for ctx and runs its
 * barrier_wg(team).
 *
 * @param ctx  Context on which the barrier is issued.
 * @param team Team to synchronize.
 */
__device__ void rocshmem_ctx_wg_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
  // Trace string carries the full API name, matching the other ctx entry
  // points (e.g. rocshmem_ctx_wg_sync_all).
  GPU_DPRINTF("Function: rocshmem_ctx_wg_barrier\n");

  get_internal_ctx(ctx)->barrier_wg(team);
}
__device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
@@ -604,39 +612,37 @@ __device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
get_internal_ctx(ctx)->sync_all();
}
// Default-context convenience wrapper: forwards to rocshmem_ctx_sync_all()
// with ROCSHMEM_CTX_DEFAULT.
__device__ void rocshmem_sync_all() {
rocshmem_ctx_sync_all(ROCSHMEM_CTX_DEFAULT);
}
// Context-based sync_all (wave-front variant): emits a GPU_DPRINTF trace,
// then calls sync_all_wave() on the internal context resolved from ctx.
__device__ void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n");
get_internal_ctx(ctx)->sync_all_wave();
}
// Default-context convenience wrapper: forwards to
// rocshmem_ctx_wave_sync_all() with ROCSHMEM_CTX_DEFAULT.
__device__ void rocshmem_wave_sync_all() {
rocshmem_ctx_wave_sync_all(ROCSHMEM_CTX_DEFAULT);
}
// Context-based sync_all (work-group variant): emits a GPU_DPRINTF trace,
// then calls sync_all_wg() on the internal context resolved from ctx.
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n");
get_internal_ctx(ctx)->sync_all_wg();
}
// Default-context convenience wrapper: forwards to
// rocshmem_ctx_wg_sync_all() with ROCSHMEM_CTX_DEFAULT.
__device__ void rocshmem_wg_sync_all() {
rocshmem_ctx_wg_sync_all(ROCSHMEM_CTX_DEFAULT);
}
__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
/**
 * Context-based team sync (single-caller variant).
 *
 * Resolves the internal context object for ctx and runs its sync(team).
 *
 * @param ctx  Context on which the sync is issued.
 * @param team Team to synchronize.
 */
__device__ void rocshmem_ctx_team_sync(rocshmem_ctx_t ctx,
                                       rocshmem_team_t team) {
  GPU_DPRINTF("Function: rocshmem_ctx_team_sync\n");

  // This entry point is documented as being invoked by a single thread, so it
  // must dispatch to the single-caller sync(). Dispatching to sync_wg(), the
  // work-group collective, from one thread would leave the other work-group
  // threads out of a collective that expects them.
  get_internal_ctx(ctx)->sync(team);
}
__device__ void rocshmem_wg_team_sync(rocshmem_team_t team) {
rocshmem_ctx_wg_team_sync(ROCSHMEM_CTX_DEFAULT, team);
/**
 * Context-based team sync (wave-front variant).
 *
 * Resolves the internal context object for ctx and runs its
 * sync_wave(team).
 *
 * @param ctx  Context on which the sync is issued.
 * @param team Team to synchronize.
 */
__device__ void rocshmem_ctx_wave_team_sync(rocshmem_ctx_t ctx,
                                            rocshmem_team_t team) {
  GPU_DPRINTF("Function: rocshmem_ctx_wave_team_sync\n");

  // Wave-front collective: dispatch to sync_wave(), not the work-group
  // sync_wg(), since only one wave of the work-group makes this call.
  get_internal_ctx(ctx)->sync_wave(team);
}
/**
 * Context-based team sync (work-group variant).
 *
 * Resolves the internal context object for ctx and runs its sync_wg(team).
 *
 * @param ctx  Context on which the sync is issued.
 * @param team Team to synchronize.
 */
__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
                                          rocshmem_team_t team) {
  // The trace previously named rocshmem_ctx_wg_sync_all, which is a different
  // entry point; report this function's own name instead.
  GPU_DPRINTF("Function: rocshmem_ctx_wg_team_sync\n");

  get_internal_ctx(ctx)->sync_wg(team);
}
__device__ int rocshmem_ctx_n_pes(rocshmem_ctx_t ctx) {
+6
Просмотреть файл
@@ -45,6 +45,9 @@ enum rocshmem_stats {
NUM_BARRIER_ALL,
NUM_BARRIER_ALL_WAVE,
NUM_BARRIER_ALL_WG,
NUM_BARRIER,
NUM_BARRIER_WAVE,
NUM_BARRIER_WG,
NUM_WAIT_UNTIL,
NUM_WAIT_UNTIL_ANY,
NUM_WAIT_UNTIL_ALL,
@@ -74,6 +77,9 @@ enum rocshmem_stats {
NUM_SYNC_ALL,
NUM_SYNC_ALL_WAVE,
NUM_SYNC_ALL_WG,
NUM_SYNC,
NUM_SYNC_WAVE,
NUM_SYNC_WG,
NUM_BROADCAST,
NUM_PUT_WG,
NUM_PUT_NBI_WG,
+18 -11
Просмотреть файл
@@ -27,9 +27,12 @@
*****************************************************************************/
__global__ void SyncTest(int loop, int skip, long long int *start_time,
long long int *end_time, TestType type,
ShmemContextType ctx_type, rocshmem_team_t *teams) {
ShmemContextType ctx_type, int wf_size,
rocshmem_team_t *teams) {
__shared__ rocshmem_ctx_t ctx;
int t_id = get_flat_block_id();
int wg_id = get_flat_grid_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
@@ -39,16 +42,25 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
start_time[wg_id] = wall_clock64();
}
__syncthreads();
switch (type) {
case SyncTestType:
if(t_id == 0) {
rocshmem_ctx_team_sync(ctx, teams[wg_id]);
}
break;
case WAVESyncTestType:
if(wf_id == 0) {
rocshmem_ctx_wave_team_sync(ctx, teams[wg_id]);
}
break;
case WGSyncTestType:
rocshmem_ctx_wg_team_sync(ctx, teams[wg_id]);
break;
default:
break;
}
__syncthreads();
}
__syncthreads();
if (hipThreadIdx_x == 0) {
end_time[wg_id] = wall_clock64();
@@ -100,15 +112,10 @@ void SyncTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
hipLaunchKernelGGL(SyncTest, gridSize, blockSize, shared_bytes, stream, loop,
args.skip, start_time, end_time, _type, _shmem_context,
team_sync_world_dup);
wf_size, team_sync_world_dup);
num_msgs = loop + args.skip;
num_timed_msgs = loop;
if(_type == SyncTestType) {
num_msgs *= gridSize.x;
num_timed_msgs *= gridSize.x;
}
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
}
void SyncTester::postLaunchKernel() {
+26 -13
Просмотреть файл
@@ -28,28 +28,41 @@ rocshmem_team_t team_barrier_world_dup;
*****************************************************************************/
__global__ void TeamBarrierTest(int loop, int skip, long long int *start_time,
long long int *end_time,
ShmemContextType ctx_type,
rocshmem_team_t *teams) {
ShmemContextType ctx_type, TestType type,
int wf_size, rocshmem_team_t *teams) {
__shared__ rocshmem_ctx_t ctx;
int t_id = get_flat_block_id();
int wg_id = get_flat_grid_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
int n_pes = rocshmem_ctx_n_pes(ctx);
__syncthreads();
for (int i = 0; i < loop + skip; i++) {
if (i == skip && hipThreadIdx_x == 0) {
start_time[wg_id] = wall_clock64();
}
rocshmem_barrier(teams[wg_id]);
switch (type) {
case TeamBarrierTestType:
if(t_id == 0) {
rocshmem_ctx_barrier(ctx, teams[wg_id]);
}
break;
case TeamWAVEBarrierTestType:
if(wf_id == 0) {
rocshmem_ctx_wave_barrier(ctx, teams[wg_id]);
}
break;
case TeamWGBarrierTestType:
rocshmem_ctx_wg_barrier(ctx, teams[wg_id]);
break;
default:
break;
}
__syncthreads();
}
__syncthreads();
if (hipThreadIdx_x == 0) {
end_time[wg_id] = wall_clock64();
}
@@ -95,10 +108,10 @@ void TeamBarrierTester::launchKernel(dim3 gridSize, dim3 blockSize,
int loop, uint64_t size) {
size_t shared_bytes = 0;
hipLaunchKernelGGL(TeamBarrierTest, gridSize, blockSize,
shared_bytes, stream, loop, args.skip,
start_time, end_time, _shmem_context,
team_barrier_world_dup);
hipLaunchKernelGGL(TeamBarrierTest, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, start_time, end_time,
_shmem_context, _type, wf_size,
team_barrier_world_dup);
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
+20 -2
Просмотреть файл
@@ -330,6 +330,14 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
if (rank == 0) std::cout << "Team Barrier Test ###" << std::endl;
testers.push_back(new TeamBarrierTester(args));
return testers;
case TeamWAVEBarrierTestType:
if (rank == 0) std::cout << "Team WAVE Barrier Test ###" << std::endl;
testers.push_back(new TeamBarrierTester(args));
return testers;
case TeamWGBarrierTestType:
if (rank == 0) std::cout << "Team WG Barrier Test ###" << std::endl;
testers.push_back(new TeamBarrierTester(args));
return testers;
case SyncAllTestType:
if (rank == 0) std::cout << "SyncAll ###" << std::endl;
testers.push_back(new SyncTester(args));
@@ -346,6 +354,14 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
if (rank == 0) std::cout << "Sync ###" << std::endl;
testers.push_back(new SyncTester(args));
return testers;
case WAVESyncTestType:
if (rank == 0) std::cout << "WAVE Sync ###" << std::endl;
testers.push_back(new SyncTester(args));
return testers;
case WGSyncTestType:
if (rank == 0) std::cout << "WG Sync ###" << std::endl;
testers.push_back(new SyncTester(args));
return testers;
case RandomAccessTestType:
if (rank == 0) std::cout << "Random_Access ###" << std::endl;
testers.push_back(new RandomAccessTester(args));
@@ -528,10 +544,12 @@ bool Tester::peLaunchesKernel() {
(_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) ||
(_type == PingPongTestType) || (_type == BarrierAllTestType) ||
(_type == WAVEBarrierAllTestType) || (_type == WGBarrierAllTestType) ||
(_type == SyncTestType) || (_type == SyncAllTestType) ||
(_type == SyncTestType) || (_type == WAVESyncTestType) ||
(_type == WGSyncTestType) || (_type == SyncAllTestType) ||
(_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) ||
(_type == RandomAccessTestType) || (_type == PingAllTestType) ||
(_type == TeamBarrierTestType);
(_type == TeamBarrierTestType) || (_type == TeamWAVEBarrierTestType) ||
(_type == TeamWGBarrierTestType);
return is_launcher;
}
+5 -1
Просмотреть файл
@@ -92,7 +92,7 @@ enum TestType {
SignalFetchTestType = 55,
WGSignalFetchTestType = 56,
WAVESignalFetchTestType = 57,
TeamBarrierTestType = 58,
TeamWGBarrierTestType = 58,
DefaultCTXGetTestType = 59,
DefaultCTXGetNBITestType = 60,
DefaultCTXPutTestType = 61,
@@ -103,6 +103,10 @@ enum TestType {
WGBarrierAllTestType = 66,
WAVESyncAllTestType = 67,
WGSyncAllTestType = 68,
TeamBarrierTestType = 69,
TeamWAVEBarrierTestType = 70,
WAVESyncTestType = 71,
WGSyncTestType = 72,
};
enum OpType { PutType = 0, GetType = 1 };
+6 -2
Просмотреть файл
@@ -88,6 +88,8 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
case WAVEBarrierAllTestType:
case WGBarrierAllTestType:
case TeamBarrierTestType:
case TeamWAVEBarrierTestType:
case TeamWGBarrierTestType:
case SyncAllTestType:
case WAVESyncAllTestType:
case WGSyncAllTestType:
@@ -140,10 +142,12 @@ void TesterArguments::get_rocshmem_arguments() {
if ((type != BarrierAllTestType) && (type != WAVEBarrierAllTestType) &&
(type != WGBarrierAllTestType) && (type != SyncAllTestType) &&
(type != WAVESyncAllTestType) && (type != WGSyncAllTestType) &&
(type != SyncTestType) && (type != TeamAllToAllTestType) &&
(type != SyncTestType) && (type != WAVESyncTestType) &&
(type != WGSyncTestType) && (type != TeamAllToAllTestType) &&
(type != TeamFCollectTestType) && (type != TeamReductionTestType) &&
(type != TeamBroadcastTestType) && (type != PingAllTestType) &&
(type != TeamBarrierTestType)) {
(type != TeamBarrierTestType) && (type != TeamWAVEBarrierTestType) &&
(type != TeamWGBarrierTestType)) {
if (numprocs != 2) {
if (myid == 0) {
std::cerr << "This test requires exactly two processes, we have "