Update Barrier and Sync APIs (#73)
* Add thread, wavefront, and workgroup-level `barrier` APIs in IPC and RO conduits; remove collectives on default context - Implemented `barrier` APIs for thread, wavefront, and workgroup scopes - Added support into both IPC and RO conduits - Added functional tests to cover all `barrier` APIs - Removed collective operations on default context * Add thread, wavefront, and workgroup-level `sync` APIs in IPC and RO conduits. - Implemented `sync` APIs for thread, wavefront, and workgroup scopes - Added support into both IPC and RO conduits - Added functional tests to cover all `sync` APIs * update naming convention for context-based `barrier` APIs
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
c652f58cef
Коммит
dc61bca066
@@ -503,8 +503,6 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team,
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all(
|
||||
rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_barrier_all();
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the system.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
@@ -518,8 +516,6 @@ __device__ ATTR_NO_INLINE void rocshmem_barrier_all();
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_barrier_all(
|
||||
rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_wave_barrier_all();
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the system.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
@@ -533,17 +529,47 @@ __device__ ATTR_NO_INLINE void rocshmem_wave_barrier_all();
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_wg_barrier_all(
|
||||
rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_wg_barrier_all();
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the team.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
*
|
||||
* This function must be invoked by a single thread within the PE.
|
||||
*
|
||||
* @param[in] handle GPU side handle.
|
||||
*
|
||||
* @param[in] team The team on which to perform barrier synchronization
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ void rocshmem_ctx_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team);
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the team.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
*
|
||||
* @param[in] team The team on which to perform barrier synchronization
|
||||
* This function must be called as a wave-front collective.
|
||||
*
|
||||
* @param[in] handle GPU side handle.
|
||||
*
|
||||
* @param[in] team The team on which to perform barrier synchronization
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ void rocshmem_barrier(rocshmem_team_t);
|
||||
__device__ void rocshmem_ctx_wave_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team);
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the team.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
*
|
||||
* This function must be called as a work-group collective.
|
||||
*
|
||||
* @param[in] handle GPU side handle.
|
||||
*
|
||||
* @param[in] team The team on which to perform barrier synchronization
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ void rocshmem_ctx_wg_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team);
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
@@ -561,8 +587,6 @@ __device__ void rocshmem_barrier(rocshmem_team_t);
|
||||
*/
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_sync_all();
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
@@ -579,8 +603,6 @@ __device__ ATTR_NO_INLINE void rocshmem_sync_all();
|
||||
*/
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_wave_sync_all();
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
@@ -597,7 +619,41 @@ __device__ ATTR_NO_INLINE void rocshmem_wave_sync_all();
|
||||
*/
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_wg_sync_all();
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
*
|
||||
* In contrast with the shmem_barrier_all routine, shmem_team_sync only ensures
|
||||
* completion and visibility of previously issued memory stores and does not
|
||||
* ensure completion of remote memory updates issued via OpenSHMEM routines.
|
||||
*
|
||||
* This function must be invoked by a single thread within the PE.
|
||||
*
|
||||
* @param[in] handle GPU side handle.
|
||||
* @param[in] team Handle of the team being synchronized
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_team_sync(
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team);
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
*
|
||||
* In contrast with the shmem_barrier_all routine, shmem_team_sync only ensures
|
||||
* completion and visibility of previously issued memory stores and does not
|
||||
* ensure completion of remote memory updates issued via OpenSHMEM routines.
|
||||
*
|
||||
* This function must be called as a wave-front collective.
|
||||
*
|
||||
* @param[in] handle GPU side handle.
|
||||
* @param[in] team Handle of the team being synchronized
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_team_sync(
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team);
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
@@ -617,8 +673,6 @@ __device__ ATTR_NO_INLINE void rocshmem_wg_sync_all();
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_wg_team_sync(
|
||||
rocshmem_ctx_t ctx, rocshmem_team_t team);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_wg_team_sync(rocshmem_team_t team);
|
||||
|
||||
/**
|
||||
* @brief Query a local pointer to a symmetric data object on the
|
||||
* specified \pe . Returns an address that may be used to directly reference
|
||||
|
||||
@@ -80,7 +80,7 @@ declare -A TEST_NUMBERS=(
|
||||
["signalfetch"]="55"
|
||||
["wgsignalfetch"]="56"
|
||||
["wavesignalfetch"]="57"
|
||||
["teambarrier"]="58"
|
||||
["teamwgbarrier"]="58"
|
||||
["defaultctxget"]="59"
|
||||
["defaultctxgetnbi"]="60"
|
||||
["defaultctxput"]="61"
|
||||
@@ -91,6 +91,10 @@ declare -A TEST_NUMBERS=(
|
||||
["wgbarrierall"]="66"
|
||||
["wavesyncall"]="67"
|
||||
["wgsyncall"]="68"
|
||||
["teambarrier"]="69"
|
||||
["teamwavebarrier"]="70"
|
||||
["wavesync"]="71"
|
||||
["wgsync"]="72"
|
||||
)
|
||||
|
||||
ExecTest() {
|
||||
@@ -328,8 +332,34 @@ TestColl() {
|
||||
ExecTest "wgbarrierall" 2 64 1024
|
||||
|
||||
ExecTest "teambarrier" 2 1 1
|
||||
ExecTest "teambarrier" 2 16 64
|
||||
ExecTest "teambarrier" 2 32 256
|
||||
ExecTest "teambarrier" 2 39 1024
|
||||
|
||||
ExecTest "teamwavebarrier" 2 1 1
|
||||
ExecTest "teamwavebarrier" 2 16 64
|
||||
ExecTest "teamwavebarrier" 2 32 256
|
||||
ExecTest "teamwavebarrier" 2 39 1024
|
||||
|
||||
ExecTest "teamwgbarrier" 2 1 1
|
||||
ExecTest "teamwgbarrier" 2 16 64
|
||||
ExecTest "teamwgbarrier" 2 32 256
|
||||
ExecTest "teamwgbarrier" 2 39 1024
|
||||
|
||||
ExecTest "sync" 2 1 1
|
||||
ExecTest "sync" 2 16 64
|
||||
ExecTest "sync" 2 32 256
|
||||
ExecTest "sync" 2 39 1024
|
||||
|
||||
ExecTest "wavesync" 2 1 1
|
||||
ExecTest "wavesync" 2 16 64
|
||||
ExecTest "wavesync" 2 32 256
|
||||
ExecTest "wavesync" 2 39 1024
|
||||
|
||||
ExecTest "wgsync" 2 1 1
|
||||
ExecTest "wgsync" 2 16 64
|
||||
ExecTest "wgsync" 2 32 256
|
||||
ExecTest "wgsync" 2 39 1024
|
||||
|
||||
ExecTest "syncall" 2 1 1
|
||||
ExecTest "syncall" 2 16 64
|
||||
|
||||
@@ -132,6 +132,9 @@ void Backend::dump_stats() {
|
||||
printf("BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL));
|
||||
printf("WAVE_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WAVE));
|
||||
printf("WG_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WG));
|
||||
printf("Barrier %llu\n", device_stats.getStat(NUM_BARRIER));
|
||||
printf("WAVE_Barrier %llu\n", device_stats.getStat(NUM_BARRIER_WAVE));
|
||||
printf("WG_Barrier %llu\n", device_stats.getStat(NUM_BARRIER_WG));
|
||||
printf("Wait Until %llu\n", device_stats.getStat(NUM_WAIT_UNTIL));
|
||||
printf("Wait Until Any %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ANY));
|
||||
printf("Wait Until All %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ALL));
|
||||
@@ -157,6 +160,9 @@ void Backend::dump_stats() {
|
||||
printf("SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL));
|
||||
printf("WAVE_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WAVE));
|
||||
printf("WG_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WG));
|
||||
printf("Sync %llu\n", device_stats.getStat(NUM_SYNC));
|
||||
printf("WAVE_Sync %llu\n", device_stats.getStat(NUM_SYNC_WAVE));
|
||||
printf("WG_Sync %llu\n", device_stats.getStat(NUM_SYNC_WG));
|
||||
|
||||
const auto& host_stats{globalHostStats};
|
||||
printf("HOST STATS\n");
|
||||
|
||||
@@ -143,12 +143,20 @@ class Context {
|
||||
|
||||
__device__ void barrier(rocshmem_team_t team);
|
||||
|
||||
__device__ void barrier_wave(rocshmem_team_t team);
|
||||
|
||||
__device__ void barrier_wg(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_all();
|
||||
|
||||
__device__ void sync_all_wave();
|
||||
|
||||
__device__ void sync_all_wg();
|
||||
|
||||
__device__ void sync(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_wave(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_wg(rocshmem_team_t team);
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -161,11 +161,23 @@ __device__ void Context::barrier_all_wg() {
|
||||
}
|
||||
|
||||
__device__ void Context::barrier(rocshmem_team_t team) {
|
||||
ctxStats.incStat(NUM_BARRIER_ALL);
|
||||
ctxStats.incStat(NUM_BARRIER);
|
||||
|
||||
DISPATCH(barrier(team));
|
||||
}
|
||||
|
||||
__device__ void Context::barrier_wave(rocshmem_team_t team) {
|
||||
ctxStats.incStat(NUM_BARRIER_WAVE);
|
||||
|
||||
DISPATCH(barrier_wave(team));
|
||||
}
|
||||
|
||||
__device__ void Context::barrier_wg(rocshmem_team_t team) {
|
||||
ctxStats.incStat(NUM_BARRIER_WG);
|
||||
|
||||
DISPATCH(barrier_wg(team));
|
||||
}
|
||||
|
||||
__device__ void Context::sync_all() {
|
||||
ctxStats.incStat(NUM_SYNC_ALL);
|
||||
|
||||
@@ -184,8 +196,20 @@ __device__ void Context::sync_all_wg() {
|
||||
DISPATCH(sync_all_wg());
|
||||
}
|
||||
|
||||
__device__ void Context::sync(rocshmem_team_t team) {
|
||||
ctxStats.incStat(NUM_SYNC);
|
||||
|
||||
DISPATCH(sync(team));
|
||||
}
|
||||
|
||||
__device__ void Context::sync_wave(rocshmem_team_t team) {
|
||||
ctxStats.incStat(NUM_SYNC_WAVE);
|
||||
|
||||
DISPATCH(sync_wave(team));
|
||||
}
|
||||
|
||||
__device__ void Context::sync_wg(rocshmem_team_t team) {
|
||||
ctxStats.incStat(NUM_SYNC_ALL_WG);
|
||||
ctxStats.incStat(NUM_SYNC_WG);
|
||||
|
||||
DISPATCH(sync_wg(team));
|
||||
}
|
||||
|
||||
@@ -67,12 +67,20 @@ class IPCContext : public Context {
|
||||
|
||||
__device__ void barrier(rocshmem_team_t team);
|
||||
|
||||
__device__ void barrier_wave(rocshmem_team_t team);
|
||||
|
||||
__device__ void barrier_wg(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_all();
|
||||
|
||||
__device__ void sync_all_wave();
|
||||
|
||||
__device__ void sync_all_wg();
|
||||
|
||||
__device__ void sync(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_wave(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_wg(rocshmem_team_t team);
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -118,6 +118,30 @@ __device__ void IPCContext::internal_sync_wg(int pe, int PE_start, int stride,
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ void IPCContext::sync(rocshmem_team_t team) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
int pe = team_obj->my_pe_in_world;
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_stride = team_obj->tinfo_wrt_world->stride;
|
||||
int pe_size = team_obj->num_pes;
|
||||
long *p_sync = team_obj->barrier_pSync;
|
||||
|
||||
internal_sync(pe, pe_start, pe_stride, pe_size, p_sync);
|
||||
}
|
||||
|
||||
__device__ void IPCContext::sync_wave(rocshmem_team_t team) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
int pe = team_obj->my_pe_in_world;
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_stride = team_obj->tinfo_wrt_world->stride;
|
||||
int pe_size = team_obj->num_pes;
|
||||
long *p_sync = team_obj->barrier_pSync;
|
||||
|
||||
internal_sync_wave(pe, pe_start, pe_stride, pe_size, p_sync);
|
||||
}
|
||||
|
||||
__device__ void IPCContext::sync_wg(rocshmem_team_t team) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
@@ -171,6 +195,34 @@ __device__ void IPCContext::barrier(rocshmem_team_t team) {
|
||||
int pe_size = team_obj->num_pes;
|
||||
long *p_sync = team_obj->barrier_pSync;
|
||||
|
||||
quiet();
|
||||
internal_sync(pe, pe_start, pe_stride, pe_size, p_sync);
|
||||
}
|
||||
|
||||
__device__ void IPCContext::barrier_wave(rocshmem_team_t team) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
int pe = team_obj->my_pe_in_world;
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_stride = team_obj->tinfo_wrt_world->stride;
|
||||
int pe_size = team_obj->num_pes;
|
||||
long *p_sync = team_obj->barrier_pSync;
|
||||
|
||||
if (is_thread_zero_in_wave()) {
|
||||
quiet();
|
||||
}
|
||||
internal_sync_wave(pe, pe_start, pe_stride, pe_size, p_sync);
|
||||
}
|
||||
|
||||
__device__ void IPCContext::barrier_wg(rocshmem_team_t team) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
int pe = team_obj->my_pe_in_world;
|
||||
int pe_start = team_obj->tinfo_wrt_world->pe_start;
|
||||
int pe_stride = team_obj->tinfo_wrt_world->stride;
|
||||
int pe_size = team_obj->num_pes;
|
||||
long *p_sync = team_obj->barrier_pSync;
|
||||
|
||||
if (is_thread_zero_in_block()) {
|
||||
quiet();
|
||||
}
|
||||
|
||||
@@ -193,6 +193,22 @@ __device__ void ROContext::barrier_all_wg() {
|
||||
}
|
||||
|
||||
__device__ void ROContext::barrier(rocshmem_team_t team) {
|
||||
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
|
||||
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
|
||||
true, get_status_flag(), is_default_ctx);
|
||||
}
|
||||
|
||||
__device__ void ROContext::barrier_wave(rocshmem_team_t team) {
|
||||
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
|
||||
if (is_thread_zero_in_wave()) {
|
||||
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
|
||||
true, get_status_flag(), is_default_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void ROContext::barrier_wg(rocshmem_team_t team) {
|
||||
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
@@ -225,6 +241,22 @@ __device__ void ROContext::sync_all_wg() {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ void ROContext::sync(rocshmem_team_t team) {
|
||||
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
|
||||
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
|
||||
true, get_status_flag(), is_default_ctx);
|
||||
}
|
||||
|
||||
__device__ void ROContext::sync_wave(rocshmem_team_t team) {
|
||||
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
|
||||
if (is_thread_zero_in_wave()) {
|
||||
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
|
||||
true, get_status_flag(), is_default_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void ROContext::sync_wg(rocshmem_team_t team) {
|
||||
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
|
||||
if (is_thread_zero_in_block()) {
|
||||
|
||||
@@ -71,12 +71,20 @@ class ROContext : public Context {
|
||||
|
||||
__device__ void barrier(rocshmem_team_t team);
|
||||
|
||||
__device__ void barrier_wave(rocshmem_team_t team);
|
||||
|
||||
__device__ void barrier_wg(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_all();
|
||||
|
||||
__device__ void sync_all_wave();
|
||||
|
||||
__device__ void sync_all_wg();
|
||||
|
||||
__device__ void sync(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_wave(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_wg(rocshmem_team_t team);
|
||||
|
||||
template <typename T>
|
||||
|
||||
+27
-21
@@ -588,14 +588,22 @@ __device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) {
|
||||
get_internal_ctx(ctx)->barrier_all_wg();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_wg_barrier_all() {
|
||||
rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_barrier(rocshmem_team_t team) {
|
||||
__device__ void rocshmem_ctx_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
|
||||
GPU_DPRINTF("Function: rocshmem_barrier\n");
|
||||
|
||||
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier(team);
|
||||
get_internal_ctx(ctx)->barrier(team);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wave_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
|
||||
GPU_DPRINTF("Function: rocshmem_wave_barrier\n");
|
||||
|
||||
get_internal_ctx(ctx)->barrier_wave(team);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
|
||||
GPU_DPRINTF("Function: rocshmem_wg_barrier\n");
|
||||
|
||||
get_internal_ctx(ctx)->barrier_wg(team);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
|
||||
@@ -604,39 +612,37 @@ __device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
|
||||
get_internal_ctx(ctx)->sync_all();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_sync_all() {
|
||||
rocshmem_ctx_sync_all(ROCSHMEM_CTX_DEFAULT);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->sync_all_wave();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_wave_sync_all() {
|
||||
rocshmem_ctx_wave_sync_all(ROCSHMEM_CTX_DEFAULT);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->sync_all_wg();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_wg_sync_all() {
|
||||
rocshmem_ctx_wg_sync_all(ROCSHMEM_CTX_DEFAULT);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
|
||||
__device__ void rocshmem_ctx_team_sync(rocshmem_ctx_t ctx,
|
||||
rocshmem_team_t team) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->sync_wg(team);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_wg_team_sync(rocshmem_team_t team) {
|
||||
rocshmem_ctx_wg_team_sync(ROCSHMEM_CTX_DEFAULT, team);
|
||||
__device__ void rocshmem_ctx_wave_team_sync(rocshmem_ctx_t ctx,
|
||||
rocshmem_team_t team) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->sync_wg(team);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
|
||||
rocshmem_team_t team) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->sync_wg(team);
|
||||
}
|
||||
|
||||
__device__ int rocshmem_ctx_n_pes(rocshmem_ctx_t ctx) {
|
||||
|
||||
@@ -45,6 +45,9 @@ enum rocshmem_stats {
|
||||
NUM_BARRIER_ALL,
|
||||
NUM_BARRIER_ALL_WAVE,
|
||||
NUM_BARRIER_ALL_WG,
|
||||
NUM_BARRIER,
|
||||
NUM_BARRIER_WAVE,
|
||||
NUM_BARRIER_WG,
|
||||
NUM_WAIT_UNTIL,
|
||||
NUM_WAIT_UNTIL_ANY,
|
||||
NUM_WAIT_UNTIL_ALL,
|
||||
@@ -74,6 +77,9 @@ enum rocshmem_stats {
|
||||
NUM_SYNC_ALL,
|
||||
NUM_SYNC_ALL_WAVE,
|
||||
NUM_SYNC_ALL_WG,
|
||||
NUM_SYNC,
|
||||
NUM_SYNC_WAVE,
|
||||
NUM_SYNC_WG,
|
||||
NUM_BROADCAST,
|
||||
NUM_PUT_WG,
|
||||
NUM_PUT_NBI_WG,
|
||||
|
||||
@@ -27,9 +27,12 @@
|
||||
*****************************************************************************/
|
||||
__global__ void SyncTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time, TestType type,
|
||||
ShmemContextType ctx_type, rocshmem_team_t *teams) {
|
||||
ShmemContextType ctx_type, int wf_size,
|
||||
rocshmem_team_t *teams) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int t_id = get_flat_block_id();
|
||||
int wg_id = get_flat_grid_id();
|
||||
int wf_id = t_id / wf_size;
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_ctx_create(ctx_type, &ctx);
|
||||
@@ -39,16 +42,25 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
switch (type) {
|
||||
case SyncTestType:
|
||||
if(t_id == 0) {
|
||||
rocshmem_ctx_team_sync(ctx, teams[wg_id]);
|
||||
}
|
||||
break;
|
||||
case WAVESyncTestType:
|
||||
if(wf_id == 0) {
|
||||
rocshmem_ctx_wave_team_sync(ctx, teams[wg_id]);
|
||||
}
|
||||
break;
|
||||
case WGSyncTestType:
|
||||
rocshmem_ctx_wg_team_sync(ctx, teams[wg_id]);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
end_time[wg_id] = wall_clock64();
|
||||
@@ -100,15 +112,10 @@ void SyncTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
|
||||
hipLaunchKernelGGL(SyncTest, gridSize, blockSize, shared_bytes, stream, loop,
|
||||
args.skip, start_time, end_time, _type, _shmem_context,
|
||||
team_sync_world_dup);
|
||||
wf_size, team_sync_world_dup);
|
||||
|
||||
num_msgs = loop + args.skip;
|
||||
num_timed_msgs = loop;
|
||||
|
||||
if(_type == SyncTestType) {
|
||||
num_msgs *= gridSize.x;
|
||||
num_timed_msgs *= gridSize.x;
|
||||
}
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
}
|
||||
|
||||
void SyncTester::postLaunchKernel() {
|
||||
|
||||
@@ -28,28 +28,41 @@ rocshmem_team_t team_barrier_world_dup;
|
||||
*****************************************************************************/
|
||||
__global__ void TeamBarrierTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time,
|
||||
ShmemContextType ctx_type,
|
||||
rocshmem_team_t *teams) {
|
||||
ShmemContextType ctx_type, TestType type,
|
||||
int wf_size, rocshmem_team_t *teams) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int t_id = get_flat_block_id();
|
||||
int wg_id = get_flat_grid_id();
|
||||
int wf_id = t_id / wf_size;
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_team_create_ctx(teams[wg_id], ctx_type, &ctx);
|
||||
|
||||
int n_pes = rocshmem_ctx_n_pes(ctx);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip && hipThreadIdx_x == 0) {
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
rocshmem_barrier(teams[wg_id]);
|
||||
switch (type) {
|
||||
case TeamBarrierTestType:
|
||||
if(t_id == 0) {
|
||||
rocshmem_ctx_barrier(ctx, teams[wg_id]);
|
||||
}
|
||||
break;
|
||||
case TeamWAVEBarrierTestType:
|
||||
if(wf_id == 0) {
|
||||
rocshmem_ctx_wave_barrier(ctx, teams[wg_id]);
|
||||
}
|
||||
break;
|
||||
case TeamWGBarrierTestType:
|
||||
rocshmem_ctx_wg_barrier(ctx, teams[wg_id]);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
end_time[wg_id] = wall_clock64();
|
||||
}
|
||||
@@ -95,10 +108,10 @@ void TeamBarrierTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
int loop, uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(TeamBarrierTest, gridSize, blockSize,
|
||||
shared_bytes, stream, loop, args.skip,
|
||||
start_time, end_time, _shmem_context,
|
||||
team_barrier_world_dup);
|
||||
hipLaunchKernelGGL(TeamBarrierTest, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, start_time, end_time,
|
||||
_shmem_context, _type, wf_size,
|
||||
team_barrier_world_dup);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
|
||||
@@ -330,6 +330,14 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
if (rank == 0) std::cout << "Team Barrier Test ###" << std::endl;
|
||||
testers.push_back(new TeamBarrierTester(args));
|
||||
return testers;
|
||||
case TeamWAVEBarrierTestType:
|
||||
if (rank == 0) std::cout << "Team WAVE Barrier Test ###" << std::endl;
|
||||
testers.push_back(new TeamBarrierTester(args));
|
||||
return testers;
|
||||
case TeamWGBarrierTestType:
|
||||
if (rank == 0) std::cout << "Team WG Barrier Test ###" << std::endl;
|
||||
testers.push_back(new TeamBarrierTester(args));
|
||||
return testers;
|
||||
case SyncAllTestType:
|
||||
if (rank == 0) std::cout << "SyncAll ###" << std::endl;
|
||||
testers.push_back(new SyncTester(args));
|
||||
@@ -346,6 +354,14 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
if (rank == 0) std::cout << "Sync ###" << std::endl;
|
||||
testers.push_back(new SyncTester(args));
|
||||
return testers;
|
||||
case WAVESyncTestType:
|
||||
if (rank == 0) std::cout << "WAVE Sync ###" << std::endl;
|
||||
testers.push_back(new SyncTester(args));
|
||||
return testers;
|
||||
case WGSyncTestType:
|
||||
if (rank == 0) std::cout << "WG Sync ###" << std::endl;
|
||||
testers.push_back(new SyncTester(args));
|
||||
return testers;
|
||||
case RandomAccessTestType:
|
||||
if (rank == 0) std::cout << "Random_Access ###" << std::endl;
|
||||
testers.push_back(new RandomAccessTester(args));
|
||||
@@ -528,10 +544,12 @@ bool Tester::peLaunchesKernel() {
|
||||
(_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) ||
|
||||
(_type == PingPongTestType) || (_type == BarrierAllTestType) ||
|
||||
(_type == WAVEBarrierAllTestType) || (_type == WGBarrierAllTestType) ||
|
||||
(_type == SyncTestType) || (_type == SyncAllTestType) ||
|
||||
(_type == SyncTestType) || (_type == WAVESyncTestType) ||
|
||||
(_type == WGSyncTestType) || (_type == SyncAllTestType) ||
|
||||
(_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) ||
|
||||
(_type == RandomAccessTestType) || (_type == PingAllTestType) ||
|
||||
(_type == TeamBarrierTestType);
|
||||
(_type == TeamBarrierTestType) || (_type == TeamWAVEBarrierTestType) ||
|
||||
(_type == TeamWGBarrierTestType);
|
||||
|
||||
return is_launcher;
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ enum TestType {
|
||||
SignalFetchTestType = 55,
|
||||
WGSignalFetchTestType = 56,
|
||||
WAVESignalFetchTestType = 57,
|
||||
TeamBarrierTestType = 58,
|
||||
TeamWGBarrierTestType = 58,
|
||||
DefaultCTXGetTestType = 59,
|
||||
DefaultCTXGetNBITestType = 60,
|
||||
DefaultCTXPutTestType = 61,
|
||||
@@ -103,6 +103,10 @@ enum TestType {
|
||||
WGBarrierAllTestType = 66,
|
||||
WAVESyncAllTestType = 67,
|
||||
WGSyncAllTestType = 68,
|
||||
TeamBarrierTestType = 69,
|
||||
TeamWAVEBarrierTestType = 70,
|
||||
WAVESyncTestType = 71,
|
||||
WGSyncTestType = 72,
|
||||
};
|
||||
|
||||
enum OpType { PutType = 0, GetType = 1 };
|
||||
|
||||
@@ -88,6 +88,8 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
|
||||
case WAVEBarrierAllTestType:
|
||||
case WGBarrierAllTestType:
|
||||
case TeamBarrierTestType:
|
||||
case TeamWAVEBarrierTestType:
|
||||
case TeamWGBarrierTestType:
|
||||
case SyncAllTestType:
|
||||
case WAVESyncAllTestType:
|
||||
case WGSyncAllTestType:
|
||||
@@ -140,10 +142,12 @@ void TesterArguments::get_rocshmem_arguments() {
|
||||
if ((type != BarrierAllTestType) && (type != WAVEBarrierAllTestType) &&
|
||||
(type != WGBarrierAllTestType) && (type != SyncAllTestType) &&
|
||||
(type != WAVESyncAllTestType) && (type != WGSyncAllTestType) &&
|
||||
(type != SyncTestType) && (type != TeamAllToAllTestType) &&
|
||||
(type != SyncTestType) && (type != WAVESyncTestType) &&
|
||||
(type != WGSyncTestType) && (type != TeamAllToAllTestType) &&
|
||||
(type != TeamFCollectTestType) && (type != TeamReductionTestType) &&
|
||||
(type != TeamBroadcastTestType) && (type != PingAllTestType) &&
|
||||
(type != TeamBarrierTestType)) {
|
||||
(type != TeamBarrierTestType) && (type != TeamWAVEBarrierTestType) &&
|
||||
(type != TeamWGBarrierTestType)) {
|
||||
if (numprocs != 2) {
|
||||
if (myid == 0) {
|
||||
std::cerr << "This test requires exactly two processes, we have "
|
||||
|
||||
Ссылка в новой задаче
Block a user