Update Barrier_All and Sync_All APIs (#72)

* Fix deadlock in `rocshmem_ctx_wg_barrier_all` API in IPC conduit by adding per-context pSync buffers and context IDs
  - Added separate pSync buffers for each device context
  - Resolved deadlock when invoking barrier API (`rocshmem_ctx_wg_barrier_all`) concurrently from multiple contexts

* Update barrier_all functional tests for multi-context support

* Add thread, wavefront, and workgroup-level barrier_all APIs in IPC and RO conduits
  - Implemented barrier_all APIs at thread, wavefront, and workgroup granularity
  - Added support in both IPC and RO conduits
  - Updated functional tests to cover all `barrier_all` APIs

* Add thread, wavefront, and workgroup-level sync_all APIs in IPC and RO conduits
  - Implemented sync_all APIs for thread, wavefront, and workgroup scopes
  - Added support into both IPC and RO conduits
  - Added functional tests to cover all `sync_all` APIs

[ROCm/rocshmem commit: c652f58cef]
This commit is contained in:
Avinash Kethineedi
2025-04-02 11:58:55 -05:00
gecommit door GitHub
bovenliggende 0cde5f53dc
commit 426bbf525b
22 gewijzigde bestanden met toevoegingen van 508 en 53 verwijderingen
@@ -490,6 +490,36 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team,
int src_pe,
rocshmem_team_t dest_team);
/**
* @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved.
*
* This function must be invoked by a single thread within the PE.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all(
rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_barrier_all();
/**
* @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved.
*
* This function must be called as a wave-front collective.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_barrier_all(
rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_wave_barrier_all();
/**
* @brief perform a collective barrier between all PEs in the system.
* The caller is blocked until the barrier is resolved.
@@ -515,6 +545,42 @@ __device__ ATTR_NO_INLINE void rocshmem_wg_barrier_all();
*/
__device__ void rocshmem_barrier(rocshmem_team_t);
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
*
* In contrast with the shmem_barrier_all routine, shmem_sync_all only ensures
* completion and visibility of previously issued memory stores and does not
* ensure completion of remote memory updates issued via OpenSHMEM routines.
*
* This function must be invoked by a single thread within the PE.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_sync_all();
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
*
* In contrast with the shmem_barrier_all routine, shmem_sync_all only ensures
* completion and visibility of previously issued memory stores and does not
* ensure completion of remote memory updates issued via OpenSHMEM routines.
*
* This function must be called as a wave-front collective.
*
* @param[in] handle GPU side handle.
*
* @return void
*/
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx);
__device__ ATTR_NO_INLINE void rocshmem_wave_sync_all();
/**
* @brief registers the arrival of a PE at a barrier.
* The caller is blocked until the synchronization is resolved.
@@ -81,6 +81,16 @@ declare -A TEST_NUMBERS=(
["wgsignalfetch"]="56"
["wavesignalfetch"]="57"
["teambarrier"]="58"
["defaultctxget"]="59"
["defaultctxgetnbi"]="60"
["defaultctxput"]="61"
["defaultctxputnbi"]="62"
["defaultctxp"]="63"
["defaultctxg"]="64"
["wavebarrierall"]="65"
["wgbarrierall"]="66"
["wavesyncall"]="67"
["wgsyncall"]="68"
)
ExecTest() {
@@ -303,11 +313,38 @@ TestColl() {
# | Name | Ranks | Workgroups | Threads | Max Message Size #
##############################################################################
ExecTest "barrierall" 2 1 1
ExecTest "barrierall" 2 16 64
ExecTest "barrierall" 2 32 256
ExecTest "barrierall" 2 64 1024
ExecTest "wavebarrierall" 2 1 1
ExecTest "wavebarrierall" 2 16 64
ExecTest "wavebarrierall" 2 32 256
ExecTest "wavebarrierall" 2 64 1024
ExecTest "wgbarrierall" 2 1 1
ExecTest "wgbarrierall" 2 16 64
ExecTest "wgbarrierall" 2 32 256
ExecTest "wgbarrierall" 2 64 1024
ExecTest "teambarrier" 2 1 1
ExecTest "sync" 2 1 1
ExecTest "syncall" 2 1 1
ExecTest "syncall" 2 16 64
ExecTest "syncall" 2 32 256
ExecTest "syncall" 2 64 1024
ExecTest "wavesyncall" 2 1 1
ExecTest "wavesyncall" 2 16 64
ExecTest "wavesyncall" 2 32 256
ExecTest "wavesyncall" 2 64 1024
ExecTest "wgsyncall" 2 1 1
ExecTest "wgsyncall" 2 16 64
ExecTest "wgsyncall" 2 32 256
ExecTest "wgsyncall" 2 64 1024
ExecTest "alltoall" 2 1 1 512
@@ -130,6 +130,8 @@ void Backend::dump_stats() {
printf("Quiets %llu\n", device_stats.getStat(NUM_QUIET));
printf("ToAll %llu\n", device_stats.getStat(NUM_TO_ALL));
printf("BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL));
printf("WAVE_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WAVE));
printf("WG_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WG));
printf("Wait Until %llu\n", device_stats.getStat(NUM_WAIT_UNTIL));
printf("Wait Until Any %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ANY));
printf("Wait Until All %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ALL));
@@ -153,6 +155,8 @@ void Backend::dump_stats() {
printf("Tests %llu\n", device_stats.getStat(NUM_TEST));
printf("SHMEM_PTR %llu\n", device_stats.getStat(NUM_SHMEM_PTR));
printf("SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL));
printf("WAVE_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WAVE));
printf("WG_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WG));
const auto& host_stats{globalHostStats};
printf("HOST STATS\n");
+9 -1
Bestand weergeven
@@ -137,11 +137,19 @@ class Context {
__device__ void barrier_all();
__device__ void barrier_all_wave();
__device__ void barrier_all_wg();
__device__ void barrier(rocshmem_team_t team);
__device__ void sync_all();
__device__ void sync(rocshmem_team_t team);
__device__ void sync_all_wave();
__device__ void sync_all_wg();
__device__ void sync_wg(rocshmem_team_t team);
template <typename T>
__device__ T amo_fetch(void* dst, T value, T cond, int pe, uint8_t atomic_op);
@@ -148,6 +148,18 @@ __device__ void Context::barrier_all() {
DISPATCH(barrier_all());
}
__device__ void Context::barrier_all_wave() {
ctxStats.incStat(NUM_BARRIER_ALL_WAVE);
DISPATCH(barrier_all_wave());
}
__device__ void Context::barrier_all_wg() {
ctxStats.incStat(NUM_BARRIER_ALL_WG);
DISPATCH(barrier_all_wg());
}
__device__ void Context::barrier(rocshmem_team_t team) {
ctxStats.incStat(NUM_BARRIER_ALL);
@@ -160,10 +172,22 @@ __device__ void Context::sync_all() {
DISPATCH(sync_all());
}
__device__ void Context::sync(rocshmem_team_t team) {
ctxStats.incStat(NUM_SYNC_ALL);
__device__ void Context::sync_all_wave() {
ctxStats.incStat(NUM_SYNC_ALL_WAVE);
DISPATCH(sync(team));
DISPATCH(sync_all_wave());
}
__device__ void Context::sync_all_wg() {
ctxStats.incStat(NUM_SYNC_ALL_WG);
DISPATCH(sync_all_wg());
}
__device__ void Context::sync_wg(rocshmem_team_t team) {
ctxStats.incStat(NUM_SYNC_ALL_WG);
DISPATCH(sync_wg(team));
}
__device__ void Context::putmem_wg(void* dest, const void* source,
@@ -114,8 +114,9 @@ IPCBackend::~IPCBackend() {
void IPCBackend::setup_ctxs() {
CHECK_HIP(hipMalloc(&ctx_array, sizeof(IPCContext) * maximum_num_contexts_));
// 0th context is default context
for (size_t i = 0; i < maximum_num_contexts_; i++) {
new (&ctx_array[i]) IPCContext(this);
new (&ctx_array[i]) IPCContext(this, i + 1);
ctx_free_list.get()->push_back(ctx_array + i);
}
}
@@ -278,9 +279,10 @@ void IPCBackend::init_wrk_sync_buffer() {
auto max_num_teams{team_tracker.get_max_num_teams()};
/**
* size of barrier sync
* size of barrier sync for all the contexts
*/
Wrk_Sync_buffer_size_ += sizeof(*barrier_sync) * ROCSHMEM_BARRIER_SYNC_SIZE;
Wrk_Sync_buffer_size_ += sizeof(*barrier_sync) * ROCSHMEM_BARRIER_SYNC_SIZE *
(maximum_num_contexts_ + 1);
/**
* Size of sync arrays for the teams
@@ -378,15 +380,18 @@ void IPCBackend::rocshmem_collective_init() {
/*
* Allocate heap space for barrier_sync
*/
size_t one_sync_size_bytes{sizeof(*barrier_sync)};
size_t sync_size_bytes{one_sync_size_bytes * ROCSHMEM_BARRIER_SYNC_SIZE};
size_t one_sync_size_bytes {sizeof(*barrier_sync)};
size_t total_sync_elems {
ROCSHMEM_BARRIER_SYNC_SIZE * (maximum_num_contexts_ + 1)};
size_t sync_size_bytes {one_sync_size_bytes * total_sync_elems};
barrier_sync = reinterpret_cast<int64_t*>(temp_Wrk_Sync_buff_ptr_);
temp_Wrk_Sync_buff_ptr_ += sync_size_bytes;
/*
* Initialize the barrier synchronization array with default values.
*/
for (int i = 0; i < num_pes; i++) {
for (int i = 0; i < total_sync_elems; i++) {
barrier_sync[i] = ROCSHMEM_SYNC_VALUE;
}
@@ -36,15 +36,18 @@
namespace rocshmem {
__host__ IPCContext::IPCContext(Backend *b)
__host__ IPCContext::IPCContext(Backend *b, unsigned int ctx_id)
: Context(b, false) {
IPCBackend *backend{static_cast<IPCBackend *>(b)};
ipcImpl_.ipc_bases = b->ipcImpl.ipc_bases;
ipcImpl_.shm_size = b->ipcImpl.shm_size;
barrier_sync = backend->barrier_sync;
size_t barrier_sync_offset = ctx_id * ROCSHMEM_BARRIER_SYNC_SIZE;
barrier_sync = backend->barrier_sync + barrier_sync_offset;
fence_pool = backend->fence_pool;
Wrk_Sync_buffer_bases_ = backend->get_wrk_sync_bases();
ctx_id_ = ctx_id;
orders_.store = detail::atomic::rocshmem_memory_order::memory_order_seq_cst;
}
@@ -31,9 +31,9 @@ namespace rocshmem {
class IPCContext : public Context {
public:
__host__ IPCContext(Backend *b);
__host__ IPCContext(Backend *b, unsigned int ctx_id);
__device__ IPCContext(Backend *b);
__device__ IPCContext(Backend *b, unsigned int ctx_id);
__device__ void threadfence_system();
@@ -61,11 +61,19 @@ class IPCContext : public Context {
__device__ void barrier_all();
__device__ void barrier_all_wave();
__device__ void barrier_all_wg();
__device__ void barrier(rocshmem_team_t team);
__device__ void sync_all();
__device__ void sync(rocshmem_team_t team);
__device__ void sync_all_wave();
__device__ void sync_all_wg();
__device__ void sync_wg(rocshmem_team_t team);
template <typename T>
__device__ void p(T *dest, T value, int pe);
@@ -240,6 +248,12 @@ class IPCContext : public Context {
__device__ void internal_sync(int pe, int PE_start, int stride, int PE_size,
int64_t *pSync);
__device__ void internal_sync_wave(int pe, int PE_start, int stride, int PE_size,
int64_t *pSync);
__device__ void internal_sync_wg(int pe, int PE_start, int stride, int PE_size,
int64_t *pSync);
__device__ void internal_direct_barrier(int pe, int PE_start, int stride,
int n_pes, int64_t *pSync);
@@ -289,6 +303,11 @@ class IPCContext : public Context {
*/
char **Wrk_Sync_buffer_bases_{nullptr};
/**
* @brief Decive context Id
*/
unsigned int ctx_id_{};
public:
//TODO(Avinash):
//Make tinfo private variable, it requires changes to the context
@@ -84,9 +84,29 @@ __device__ void IPCContext::internal_atomic_barrier(int pe, int PE_start,
}
}
// Uses PE values that are relative to world
__device__ void IPCContext::internal_sync(int pe, int PE_start, int stride,
int PE_size, int64_t *pSync) {
if (PE_size < 64) {
internal_direct_barrier(pe, PE_start, stride, PE_size, pSync);
} else {
internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync);
}
}
__device__ void IPCContext::internal_sync_wave(int pe, int PE_start, int stride,
int PE_size, int64_t *pSync) {
if (is_thread_zero_in_wave()) {
if (PE_size < 64) {
internal_direct_barrier(pe, PE_start, stride, PE_size, pSync);
} else {
internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync);
}
}
}
// Uses PE values that are relative to world
__device__ void IPCContext::internal_sync_wg(int pe, int PE_start, int stride,
int PE_size, int64_t *pSync) {
__syncthreads();
if (is_thread_zero_in_block()) {
if (PE_size < 64) {
@@ -98,7 +118,7 @@ __device__ void IPCContext::internal_sync(int pe, int PE_start, int stride,
__syncthreads();
}
__device__ void IPCContext::sync(rocshmem_team_t team) {
__device__ void IPCContext::sync_wg(rocshmem_team_t team) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
int pe = team_obj->my_pe_in_world;
@@ -107,18 +127,38 @@ __device__ void IPCContext::sync(rocshmem_team_t team) {
int pe_size = team_obj->num_pes;
long *p_sync = team_obj->barrier_pSync;
internal_sync(pe, pe_start, pe_stride, pe_size, p_sync);
internal_sync_wg(pe, pe_start, pe_stride, pe_size, p_sync);
}
__device__ void IPCContext::sync_all() {
internal_sync(my_pe, 0, 1, num_pes, barrier_sync);
}
__device__ void IPCContext::sync_all_wave() {
internal_sync_wave(my_pe, 0, 1, num_pes, barrier_sync);
}
__device__ void IPCContext::sync_all_wg() {
internal_sync_wg(my_pe, 0, 1, num_pes, barrier_sync);
}
__device__ void IPCContext::barrier_all() {
quiet();
sync_all();
}
__device__ void IPCContext::barrier_all_wave() {
if (is_thread_zero_in_wave()) {
quiet();
}
sync_all_wave();
}
__device__ void IPCContext::barrier_all_wg() {
if (is_thread_zero_in_block()) {
quiet();
}
sync_all();
sync_all_wg();
__syncthreads();
}
@@ -134,7 +174,7 @@ __device__ void IPCContext::barrier(rocshmem_team_t team) {
if (is_thread_zero_in_block()) {
quiet();
}
internal_sync(pe, pe_start, pe_stride, pe_size, p_sync);
internal_sync_wg(pe, pe_start, pe_stride, pe_size, p_sync);
__syncthreads();
}
@@ -468,7 +468,7 @@ __device__ void IPCContext::internal_broadcast(T *dst, const T *src, int nelems,
}
// Synchronize on completion of broadcast
internal_sync(my_pe, pe_start, stride, pe_size, p_sync);
internal_sync_wg(my_pe, pe_start, stride, pe_size, p_sync);
}
template <typename T>
@@ -497,7 +497,7 @@ __device__ void IPCContext::alltoall_linear(rocshmem_team_t team, T *dst,
quiet();
}
// wait until everyone has obtained their designated data
internal_sync(my_pe, pe_start, stride, pe_size, pSync);
internal_sync_wg(my_pe, pe_start, stride, pe_size, pSync);
}
template <typename T>
@@ -527,7 +527,7 @@ __device__ void IPCContext::fcollect_linear(rocshmem_team_t team, T *dst,
quiet();
}
// wait until everyone has obtained their designated data
internal_sync(my_pe, pe_start, stride, pe_size, pSync);
internal_sync_wg(my_pe, pe_start, stride, pe_size, pSync);
}
// Block/wave functions
@@ -45,7 +45,7 @@ class IPCDefaultContextProxy {
size_t num_elems = 1)
: constructed_{true}, proxy_{num_elems} {
auto ctx{proxy_.get()};
new (ctx) IPCContext(reinterpret_cast<Backend*>(backend));
new (ctx) IPCContext(reinterpret_cast<Backend*>(backend), 0);
ctx->tinfo = tinfo;
rocshmem_ctx_t local{ctx, tinfo};
set_internal_ctx(&local);
@@ -170,6 +170,20 @@ __device__ void *ROContext::shmem_ptr(const void *dest, int pe) {
}
__device__ void ROContext::barrier_all() {
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true, get_status_flag(), is_default_ctx);
}
__device__ void ROContext::barrier_all_wave() {
if (is_thread_zero_in_wave()) {
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true, get_status_flag(), is_default_ctx);
}
}
__device__ void ROContext::barrier_all_wg() {
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
@@ -189,6 +203,20 @@ __device__ void ROContext::barrier(rocshmem_team_t team) {
}
__device__ void ROContext::sync_all() {
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true, get_status_flag(), is_default_ctx);
}
__device__ void ROContext::sync_all_wave() {
if (is_thread_zero_in_wave()) {
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
block_handle, true, get_status_flag(), is_default_ctx);
}
}
__device__ void ROContext::sync_all_wg() {
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
@@ -197,7 +225,7 @@ __device__ void ROContext::sync_all() {
__syncthreads();
}
__device__ void ROContext::sync(rocshmem_team_t team) {
__device__ void ROContext::sync_wg(rocshmem_team_t team) {
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
if (is_thread_zero_in_block()) {
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
@@ -65,11 +65,19 @@ class ROContext : public Context {
__device__ void barrier_all();
__device__ void barrier_all_wave();
__device__ void barrier_all_wg();
__device__ void barrier(rocshmem_team_t team);
__device__ void sync_all();
__device__ void sync(rocshmem_team_t team);
__device__ void sync_all_wave();
__device__ void sync_all_wg();
__device__ void sync_wg(rocshmem_team_t team);
template <typename T>
__device__ void p(T *dest, T value, int pe);
+35 -3
Bestand weergeven
@@ -570,12 +570,24 @@ __device__ int rocshmem_test(T *ivars, int cmp, T val) {
return ctx_internal->test(ivars, cmp, val);
}
__device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) {
__device__ void rocshmem_ctx_barrier_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_barrier_all\n");
get_internal_ctx(ctx)->barrier_all();
}
__device__ void rocshmem_ctx_wave_barrier_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_wave_barrier_all\n");
get_internal_ctx(ctx)->barrier_all_wave();
}
__device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_wg_barrier_all\n");
get_internal_ctx(ctx)->barrier_all_wg();
}
__device__ void rocshmem_wg_barrier_all() {
rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT);
}
@@ -586,12 +598,32 @@ __device__ void rocshmem_barrier(rocshmem_team_t team) {
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier(team);
}
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
__device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n");
get_internal_ctx(ctx)->sync_all();
}
__device__ void rocshmem_sync_all() {
rocshmem_ctx_sync_all(ROCSHMEM_CTX_DEFAULT);
}
__device__ void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n");
get_internal_ctx(ctx)->sync_all_wave();
}
__device__ void rocshmem_wave_sync_all() {
rocshmem_ctx_wave_sync_all(ROCSHMEM_CTX_DEFAULT);
}
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n");
get_internal_ctx(ctx)->sync_all_wg();
}
__device__ void rocshmem_wg_sync_all() {
rocshmem_ctx_wg_sync_all(ROCSHMEM_CTX_DEFAULT);
}
@@ -600,7 +632,7 @@ __device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
rocshmem_team_t team) {
GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n");
get_internal_ctx(ctx)->sync(team);
get_internal_ctx(ctx)->sync_wg(team);
}
__device__ void rocshmem_wg_team_sync(rocshmem_team_t team) {
+4
Bestand weergeven
@@ -43,6 +43,8 @@ enum rocshmem_stats {
NUM_QUIET,
NUM_TO_ALL,
NUM_BARRIER_ALL,
NUM_BARRIER_ALL_WAVE,
NUM_BARRIER_ALL_WG,
NUM_WAIT_UNTIL,
NUM_WAIT_UNTIL_ANY,
NUM_WAIT_UNTIL_ALL,
@@ -70,6 +72,8 @@ enum rocshmem_stats {
NUM_TEST,
NUM_SHMEM_PTR,
NUM_SYNC_ALL,
NUM_SYNC_ALL_WAVE,
NUM_SYNC_ALL_WG,
NUM_BROADCAST,
NUM_PUT_WG,
NUM_PUT_NBI_WG,
@@ -30,9 +30,12 @@ using namespace rocshmem;
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
long long int *end_time) {
long long int *end_time, TestType type,
int wf_size) {
__shared__ rocshmem_ctx_t ctx;
int t_id = get_flat_block_id();
int wg_id = get_flat_grid_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ROCSHMEM_CTX_WG_PRIVATE, &ctx);
@@ -42,17 +45,25 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
start_time[wg_id] = wall_clock64();
}
__syncthreads();
/**
* The function `rocshmem_ctx_wg_barrier_all` should be called from only
* one group within the grid to avoid unintended behavior.
*/
if (is_block_zero_in_grid()) {
rocshmem_ctx_wg_barrier_all(ctx);
switch (type) {
case BarrierAllTestType:
if(t_id == 0) {
rocshmem_ctx_barrier_all(ctx);
}
break;
case WAVEBarrierAllTestType:
if(wf_id == 0) {
rocshmem_ctx_wave_barrier_all(ctx);
}
break;
case WGBarrierAllTestType:
rocshmem_ctx_wg_barrier_all(ctx);
break;
default:
break;
}
__syncthreads();
}
__syncthreads();
if (hipThreadIdx_x == 0) {
end_time[wg_id] = wall_clock64();
@@ -74,10 +85,10 @@ void BarrierAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
size_t shared_bytes = 0;
hipLaunchKernelGGL(BarrierAllTest, gridSize, blockSize, shared_bytes, stream,
loop, args.skip, start_time, end_time);
loop, args.skip, start_time, end_time, _type, wf_size);
num_msgs = loop + args.skip;
num_timed_msgs = loop;
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
}
void BarrierAllTester::resetBuffers(uint64_t size) {}
@@ -0,0 +1,96 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#include "sync_all_tester.hpp"
#include <rocshmem/rocshmem.hpp>
using namespace rocshmem;
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void SyncAllTest(int loop, int skip, long long int *start_time,
long long int *end_time, TestType type,
int wf_size) {
__shared__ rocshmem_ctx_t ctx;
int t_id = get_flat_block_id();
int wg_id = get_flat_grid_id();
int wf_id = t_id / wf_size;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ROCSHMEM_CTX_WG_PRIVATE, &ctx);
for (int i = 0; i < loop + skip; i++) {
if (hipThreadIdx_x == 0 && i == skip) {
start_time[wg_id] = wall_clock64();
}
switch (type) {
case SyncAllTestType:
if(t_id == 0) {
rocshmem_ctx_sync_all(ctx);
}
break;
case WAVESyncAllTestType:
if(wf_id == 0) {
rocshmem_ctx_wave_sync_all(ctx);
}
break;
case WGSyncAllTestType:
rocshmem_ctx_wg_sync_all(ctx);
break;
default:
break;
}
__syncthreads();
}
if (hipThreadIdx_x == 0) {
end_time[wg_id] = wall_clock64();
}
rocshmem_wg_ctx_destroy(&ctx);
rocshmem_wg_finalize();
}
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
SyncAllTester::SyncAllTester(TesterArguments args) : Tester(args) {}
SyncAllTester::~SyncAllTester() {}
void SyncAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
uint64_t size) {
size_t shared_bytes = 0;
hipLaunchKernelGGL(SyncAllTest, gridSize, blockSize, shared_bytes, stream,
loop, args.skip, start_time, end_time, _type, wf_size);
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
}
void SyncAllTester::resetBuffers(uint64_t size) {}
void SyncAllTester::verifyResults(uint64_t size) {}
@@ -0,0 +1,50 @@
/******************************************************************************
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*****************************************************************************/
#ifndef _BARRIER_ALL_TESTER_HPP_
#define _BARRIER_ALL_TESTER_HPP_
#include "tester.hpp"
/******************************************************************************
* DEVICE TEST KERNEL
*****************************************************************************/
__global__ void SyncAllTest(TestType type);
/******************************************************************************
* HOST TESTER CLASS
*****************************************************************************/
class SyncAllTester : public Tester {
public:
explicit SyncAllTester(TesterArguments args);
virtual ~SyncAllTester();
protected:
virtual void resetBuffers(uint64_t size) override;
virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
uint64_t size) override;
virtual void verifyResults(uint64_t size) override;
};
#endif
@@ -41,15 +41,6 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
__syncthreads();
switch (type) {
case SyncAllTestType:
/**
* The function `rocshmem_ctx_wg_sync_all` should be called from only
* one group within the grid to avoid unintended behavior.
*/
if (is_block_zero_in_grid()) {
rocshmem_ctx_wg_sync_all(ctx);
}
break;
case SyncTestType:
rocshmem_ctx_wg_team_sync(ctx, teams[wg_id]);
break;
@@ -43,10 +43,11 @@
#include "random_access_tester.hpp"
#include "shmem_ptr_tester.hpp"
#include "signaling_operations_tester.hpp"
#include "sync_all_tester.hpp"
#include "sync_tester.hpp"
#include "team_alltoall_tester.hpp"
#include "team_broadcast_tester.hpp"
#include "team_barrier_tester.hpp"
#include "team_broadcast_tester.hpp"
#include "team_ctx_infra_tester.hpp"
#include "team_ctx_primitive_tester.hpp"
#include "team_fcollect_tester.hpp"
@@ -317,6 +318,14 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
if (rank == 0) std::cout << "Barrier_All ###" << std::endl;
testers.push_back(new BarrierAllTester(args));
return testers;
case WAVEBarrierAllTestType:
if (rank == 0) std::cout << "WAVE Barrier_All ###" << std::endl;
testers.push_back(new BarrierAllTester(args));
return testers;
case WGBarrierAllTestType:
if (rank == 0) std::cout << "WG Barrier_All ###" << std::endl;
testers.push_back(new BarrierAllTester(args));
return testers;
case TeamBarrierTestType:
if (rank == 0) std::cout << "Team Barrier Test ###" << std::endl;
testers.push_back(new TeamBarrierTester(args));
@@ -325,6 +334,14 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
if (rank == 0) std::cout << "SyncAll ###" << std::endl;
testers.push_back(new SyncTester(args));
return testers;
case WAVESyncAllTestType:
if (rank == 0) std::cout << "WAVE SyncAll ###" << std::endl;
testers.push_back(new SyncTester(args));
return testers;
case WGSyncAllTestType:
if (rank == 0) std::cout << "WG SyncAll ###" << std::endl;
testers.push_back(new SyncTester(args));
return testers;
case SyncTestType:
if (rank == 0) std::cout << "Sync ###" << std::endl;
testers.push_back(new SyncTester(args));
@@ -510,7 +527,9 @@ bool Tester::peLaunchesKernel() {
(_type == TeamBroadcastTestType) || (_type == TeamCtxInfraTestType) ||
(_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) ||
(_type == PingPongTestType) || (_type == BarrierAllTestType) ||
(_type == WAVEBarrierAllTestType) || (_type == WGBarrierAllTestType) ||
(_type == SyncTestType) || (_type == SyncAllTestType) ||
(_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) ||
(_type == RandomAccessTestType) || (_type == PingAllTestType) ||
(_type == TeamBarrierTestType);
@@ -99,6 +99,10 @@ enum TestType {
DefaultCTXPutNBITestType = 62,
DefaultCTXPTestType = 63,
DefaultCTXGTestType = 64,
WAVEBarrierAllTestType = 65,
WGBarrierAllTestType = 66,
WAVESyncAllTestType = 67,
WGSyncAllTestType = 68,
};
enum OpType { PutType = 0, GetType = 1 };
@@ -85,8 +85,12 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
case AMO_IncTestType:
case AMO_FetchTestType:
case BarrierAllTestType:
case WAVEBarrierAllTestType:
case WGBarrierAllTestType:
case TeamBarrierTestType:
case SyncAllTestType:
case WAVESyncAllTestType:
case WGSyncAllTestType:
case SyncTestType:
case ShmemPtrTestType:
min_msg_size = 8;
@@ -133,7 +137,9 @@ void TesterArguments::get_rocshmem_arguments() {
myid = rocshmem_my_pe();
TestType type = (TestType)algorithm;
if ((type != BarrierAllTestType) && (type != SyncAllTestType) &&
if ((type != BarrierAllTestType) && (type != WAVEBarrierAllTestType) &&
(type != WGBarrierAllTestType) && (type != SyncAllTestType) &&
(type != WAVESyncAllTestType) && (type != WGSyncAllTestType) &&
(type != SyncTestType) && (type != TeamAllToAllTestType) &&
(type != TeamFCollectTestType) && (type != TeamReductionTestType) &&
(type != TeamBroadcastTestType) && (type != PingAllTestType) &&