Update Barrier_All and Sync_All APIs (#72)
* Fix deadlock in `rocshmem_ctx_wg_barrier_all` API in IPC conduit by adding per-context pSync buffers and context IDs
- Added separate pSync buffers for each device context
- Resolved deadlock when invoking barrier API (`rocshmem_ctx_wg_barrier_all`) concurrently from multiple contexts
* Update barrier_all functional tests for multi-context support
* Add thread, wavefront, and workgroup-level barrier_all APIs in IPC and RO conduits
- Implemented barrier_all APIs at thread, wavefront, and workgroup granularity
- Added support in both IPC and RO conduits
- Updated functional tests to cover all `barrier_all` APIs
* Add thread, wavefront, and workgroup-level sync_all APIs in IPC and RO conduits
- Implemented sync_all APIs for thread, wavefront, and workgroup scopes
- Added support into both IPC and RO conduits
- Added functional tests to cover all `sync_all` APIs
[ROCm/rocshmem commit: c652f58cef]
This commit is contained in:
gecommit door
GitHub
bovenliggende
0cde5f53dc
commit
426bbf525b
@@ -490,6 +490,36 @@ __device__ int rocshmem_team_translate_pe(rocshmem_team_t src_team,
|
||||
int src_pe,
|
||||
rocshmem_team_t dest_team);
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the system.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
*
|
||||
* This function must be invoked by a single thread within the PE.
|
||||
*
|
||||
* @param[in] ctx Context on which the barrier is performed.
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_barrier_all(
|
||||
rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_barrier_all();
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the system.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
*
|
||||
* This function must be called as a wave-front collective.
|
||||
*
|
||||
* @param[in] ctx Context on which the barrier is performed.
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_barrier_all(
|
||||
rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_wave_barrier_all();
|
||||
|
||||
/**
|
||||
* @brief perform a collective barrier between all PEs in the system.
|
||||
* The caller is blocked until the barrier is resolved.
|
||||
@@ -515,6 +545,42 @@ __device__ ATTR_NO_INLINE void rocshmem_wg_barrier_all();
|
||||
*/
|
||||
__device__ void rocshmem_barrier(rocshmem_team_t);
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
*
|
||||
* In contrast with the shmem_barrier_all routine, shmem_sync_all only ensures
|
||||
* completion and visibility of previously issued memory stores and does not
|
||||
* ensure completion of remote memory updates issued via OpenSHMEM routines.
|
||||
*
|
||||
* This function must be invoked by a single thread within the PE.
|
||||
*
|
||||
* @param[in] ctx Context on which the synchronization is performed.
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_sync_all();
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
*
|
||||
* In contrast with the shmem_barrier_all routine, shmem_sync_all only ensures
|
||||
* completion and visibility of previously issued memory stores and does not
|
||||
* ensure completion of remote memory updates issued via OpenSHMEM routines.
|
||||
*
|
||||
* This function must be called as a wave-front collective.
|
||||
*
|
||||
* @param[in] ctx Context on which the synchronization is performed.
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
__device__ ATTR_NO_INLINE void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx);
|
||||
|
||||
__device__ ATTR_NO_INLINE void rocshmem_wave_sync_all();
|
||||
|
||||
/**
|
||||
* @brief registers the arrival of a PE at a barrier.
|
||||
* The caller is blocked until the synchronization is resolved.
|
||||
|
||||
@@ -81,6 +81,16 @@ declare -A TEST_NUMBERS=(
|
||||
["wgsignalfetch"]="56"
|
||||
["wavesignalfetch"]="57"
|
||||
["teambarrier"]="58"
|
||||
["defaultctxget"]="59"
|
||||
["defaultctxgetnbi"]="60"
|
||||
["defaultctxput"]="61"
|
||||
["defaultctxputnbi"]="62"
|
||||
["defaultctxp"]="63"
|
||||
["defaultctxg"]="64"
|
||||
["wavebarrierall"]="65"
|
||||
["wgbarrierall"]="66"
|
||||
["wavesyncall"]="67"
|
||||
["wgsyncall"]="68"
|
||||
)
|
||||
|
||||
ExecTest() {
|
||||
@@ -303,11 +313,38 @@ TestColl() {
|
||||
# | Name | Ranks | Workgroups | Threads | Max Message Size #
|
||||
##############################################################################
|
||||
ExecTest "barrierall" 2 1 1
|
||||
ExecTest "barrierall" 2 16 64
|
||||
ExecTest "barrierall" 2 32 256
|
||||
ExecTest "barrierall" 2 64 1024
|
||||
|
||||
ExecTest "wavebarrierall" 2 1 1
|
||||
ExecTest "wavebarrierall" 2 16 64
|
||||
ExecTest "wavebarrierall" 2 32 256
|
||||
ExecTest "wavebarrierall" 2 64 1024
|
||||
|
||||
ExecTest "wgbarrierall" 2 1 1
|
||||
ExecTest "wgbarrierall" 2 16 64
|
||||
ExecTest "wgbarrierall" 2 32 256
|
||||
ExecTest "wgbarrierall" 2 64 1024
|
||||
|
||||
ExecTest "teambarrier" 2 1 1
|
||||
|
||||
ExecTest "sync" 2 1 1
|
||||
|
||||
ExecTest "syncall" 2 1 1
|
||||
ExecTest "syncall" 2 16 64
|
||||
ExecTest "syncall" 2 32 256
|
||||
ExecTest "syncall" 2 64 1024
|
||||
|
||||
ExecTest "wavesyncall" 2 1 1
|
||||
ExecTest "wavesyncall" 2 16 64
|
||||
ExecTest "wavesyncall" 2 32 256
|
||||
ExecTest "wavesyncall" 2 64 1024
|
||||
|
||||
ExecTest "wgsyncall" 2 1 1
|
||||
ExecTest "wgsyncall" 2 16 64
|
||||
ExecTest "wgsyncall" 2 32 256
|
||||
ExecTest "wgsyncall" 2 64 1024
|
||||
|
||||
ExecTest "alltoall" 2 1 1 512
|
||||
|
||||
|
||||
@@ -130,6 +130,8 @@ void Backend::dump_stats() {
|
||||
printf("Quiets %llu\n", device_stats.getStat(NUM_QUIET));
|
||||
printf("ToAll %llu\n", device_stats.getStat(NUM_TO_ALL));
|
||||
printf("BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL));
|
||||
printf("WAVE_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WAVE));
|
||||
printf("WG_BarrierAll %llu\n", device_stats.getStat(NUM_BARRIER_ALL_WG));
|
||||
printf("Wait Until %llu\n", device_stats.getStat(NUM_WAIT_UNTIL));
|
||||
printf("Wait Until Any %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ANY));
|
||||
printf("Wait Until All %llu\n", device_stats.getStat(NUM_WAIT_UNTIL_ALL));
|
||||
@@ -153,6 +155,8 @@ void Backend::dump_stats() {
|
||||
printf("Tests %llu\n", device_stats.getStat(NUM_TEST));
|
||||
printf("SHMEM_PTR %llu\n", device_stats.getStat(NUM_SHMEM_PTR));
|
||||
printf("SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL));
|
||||
printf("WAVE_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WAVE));
|
||||
printf("WG_SyncAll %llu\n", device_stats.getStat(NUM_SYNC_ALL_WG));
|
||||
|
||||
const auto& host_stats{globalHostStats};
|
||||
printf("HOST STATS\n");
|
||||
|
||||
@@ -137,11 +137,19 @@ class Context {
|
||||
|
||||
__device__ void barrier_all();
|
||||
|
||||
__device__ void barrier_all_wave();
|
||||
|
||||
__device__ void barrier_all_wg();
|
||||
|
||||
__device__ void barrier(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_all();
|
||||
|
||||
__device__ void sync(rocshmem_team_t team);
|
||||
__device__ void sync_all_wave();
|
||||
|
||||
__device__ void sync_all_wg();
|
||||
|
||||
__device__ void sync_wg(rocshmem_team_t team);
|
||||
|
||||
template <typename T>
|
||||
__device__ T amo_fetch(void* dst, T value, T cond, int pe, uint8_t atomic_op);
|
||||
|
||||
@@ -148,6 +148,18 @@ __device__ void Context::barrier_all() {
|
||||
DISPATCH(barrier_all());
|
||||
}
|
||||
|
||||
__device__ void Context::barrier_all_wave() {
|
||||
ctxStats.incStat(NUM_BARRIER_ALL_WAVE);
|
||||
|
||||
DISPATCH(barrier_all_wave());
|
||||
}
|
||||
|
||||
__device__ void Context::barrier_all_wg() {
|
||||
ctxStats.incStat(NUM_BARRIER_ALL_WG);
|
||||
|
||||
DISPATCH(barrier_all_wg());
|
||||
}
|
||||
|
||||
__device__ void Context::barrier(rocshmem_team_t team) {
|
||||
ctxStats.incStat(NUM_BARRIER_ALL);
|
||||
|
||||
@@ -160,10 +172,22 @@ __device__ void Context::sync_all() {
|
||||
DISPATCH(sync_all());
|
||||
}
|
||||
|
||||
__device__ void Context::sync(rocshmem_team_t team) {
|
||||
ctxStats.incStat(NUM_SYNC_ALL);
|
||||
__device__ void Context::sync_all_wave() {
|
||||
ctxStats.incStat(NUM_SYNC_ALL_WAVE);
|
||||
|
||||
DISPATCH(sync(team));
|
||||
DISPATCH(sync_all_wave());
|
||||
}
|
||||
|
||||
__device__ void Context::sync_all_wg() {
|
||||
ctxStats.incStat(NUM_SYNC_ALL_WG);
|
||||
|
||||
DISPATCH(sync_all_wg());
|
||||
}
|
||||
|
||||
__device__ void Context::sync_wg(rocshmem_team_t team) {
|
||||
ctxStats.incStat(NUM_SYNC_ALL_WG);
|
||||
|
||||
DISPATCH(sync_wg(team));
|
||||
}
|
||||
|
||||
__device__ void Context::putmem_wg(void* dest, const void* source,
|
||||
|
||||
@@ -114,8 +114,9 @@ IPCBackend::~IPCBackend() {
|
||||
|
||||
void IPCBackend::setup_ctxs() {
|
||||
CHECK_HIP(hipMalloc(&ctx_array, sizeof(IPCContext) * maximum_num_contexts_));
|
||||
// 0th context is default context
|
||||
for (size_t i = 0; i < maximum_num_contexts_; i++) {
|
||||
new (&ctx_array[i]) IPCContext(this);
|
||||
new (&ctx_array[i]) IPCContext(this, i + 1);
|
||||
ctx_free_list.get()->push_back(ctx_array + i);
|
||||
}
|
||||
}
|
||||
@@ -278,9 +279,10 @@ void IPCBackend::init_wrk_sync_buffer() {
|
||||
auto max_num_teams{team_tracker.get_max_num_teams()};
|
||||
|
||||
/**
|
||||
* size of barrier sync
|
||||
* size of barrier sync for all the contexts
|
||||
*/
|
||||
Wrk_Sync_buffer_size_ += sizeof(*barrier_sync) * ROCSHMEM_BARRIER_SYNC_SIZE;
|
||||
Wrk_Sync_buffer_size_ += sizeof(*barrier_sync) * ROCSHMEM_BARRIER_SYNC_SIZE *
|
||||
(maximum_num_contexts_ + 1);
|
||||
|
||||
/**
|
||||
* Size of sync arrays for the teams
|
||||
@@ -378,15 +380,18 @@ void IPCBackend::rocshmem_collective_init() {
|
||||
/*
|
||||
* Allocate heap space for barrier_sync
|
||||
*/
|
||||
size_t one_sync_size_bytes{sizeof(*barrier_sync)};
|
||||
size_t sync_size_bytes{one_sync_size_bytes * ROCSHMEM_BARRIER_SYNC_SIZE};
|
||||
size_t one_sync_size_bytes {sizeof(*barrier_sync)};
|
||||
size_t total_sync_elems {
|
||||
ROCSHMEM_BARRIER_SYNC_SIZE * (maximum_num_contexts_ + 1)};
|
||||
size_t sync_size_bytes {one_sync_size_bytes * total_sync_elems};
|
||||
|
||||
barrier_sync = reinterpret_cast<int64_t*>(temp_Wrk_Sync_buff_ptr_);
|
||||
temp_Wrk_Sync_buff_ptr_ += sync_size_bytes;
|
||||
|
||||
/*
|
||||
* Initialize the barrier synchronization array with default values.
|
||||
*/
|
||||
for (int i = 0; i < num_pes; i++) {
|
||||
for (int i = 0; i < total_sync_elems; i++) {
|
||||
barrier_sync[i] = ROCSHMEM_SYNC_VALUE;
|
||||
}
|
||||
|
||||
|
||||
@@ -36,15 +36,18 @@
|
||||
|
||||
namespace rocshmem {
|
||||
|
||||
__host__ IPCContext::IPCContext(Backend *b)
|
||||
__host__ IPCContext::IPCContext(Backend *b, unsigned int ctx_id)
|
||||
: Context(b, false) {
|
||||
IPCBackend *backend{static_cast<IPCBackend *>(b)};
|
||||
ipcImpl_.ipc_bases = b->ipcImpl.ipc_bases;
|
||||
ipcImpl_.shm_size = b->ipcImpl.shm_size;
|
||||
|
||||
barrier_sync = backend->barrier_sync;
|
||||
size_t barrier_sync_offset = ctx_id * ROCSHMEM_BARRIER_SYNC_SIZE;
|
||||
|
||||
barrier_sync = backend->barrier_sync + barrier_sync_offset;
|
||||
fence_pool = backend->fence_pool;
|
||||
Wrk_Sync_buffer_bases_ = backend->get_wrk_sync_bases();
|
||||
ctx_id_ = ctx_id;
|
||||
|
||||
orders_.store = detail::atomic::rocshmem_memory_order::memory_order_seq_cst;
|
||||
}
|
||||
|
||||
@@ -31,9 +31,9 @@ namespace rocshmem {
|
||||
|
||||
class IPCContext : public Context {
|
||||
public:
|
||||
__host__ IPCContext(Backend *b);
|
||||
__host__ IPCContext(Backend *b, unsigned int ctx_id);
|
||||
|
||||
__device__ IPCContext(Backend *b);
|
||||
__device__ IPCContext(Backend *b, unsigned int ctx_id);
|
||||
|
||||
__device__ void threadfence_system();
|
||||
|
||||
@@ -61,11 +61,19 @@ class IPCContext : public Context {
|
||||
|
||||
__device__ void barrier_all();
|
||||
|
||||
__device__ void barrier_all_wave();
|
||||
|
||||
__device__ void barrier_all_wg();
|
||||
|
||||
__device__ void barrier(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_all();
|
||||
|
||||
__device__ void sync(rocshmem_team_t team);
|
||||
__device__ void sync_all_wave();
|
||||
|
||||
__device__ void sync_all_wg();
|
||||
|
||||
__device__ void sync_wg(rocshmem_team_t team);
|
||||
|
||||
template <typename T>
|
||||
__device__ void p(T *dest, T value, int pe);
|
||||
@@ -240,6 +248,12 @@ class IPCContext : public Context {
|
||||
__device__ void internal_sync(int pe, int PE_start, int stride, int PE_size,
|
||||
int64_t *pSync);
|
||||
|
||||
__device__ void internal_sync_wave(int pe, int PE_start, int stride, int PE_size,
|
||||
int64_t *pSync);
|
||||
|
||||
__device__ void internal_sync_wg(int pe, int PE_start, int stride, int PE_size,
|
||||
int64_t *pSync);
|
||||
|
||||
__device__ void internal_direct_barrier(int pe, int PE_start, int stride,
|
||||
int n_pes, int64_t *pSync);
|
||||
|
||||
@@ -289,6 +303,11 @@ class IPCContext : public Context {
|
||||
*/
|
||||
char **Wrk_Sync_buffer_bases_{nullptr};
|
||||
|
||||
/**
|
||||
* @brief Device context ID
|
||||
*/
|
||||
unsigned int ctx_id_{};
|
||||
|
||||
public:
|
||||
//TODO(Avinash):
|
||||
//Make tinfo private variable, it requires changes to the context
|
||||
|
||||
@@ -84,9 +84,29 @@ __device__ void IPCContext::internal_atomic_barrier(int pe, int PE_start,
|
||||
}
|
||||
}
|
||||
|
||||
// Uses PE values that are relative to world
|
||||
__device__ void IPCContext::internal_sync(int pe, int PE_start, int stride,
|
||||
int PE_size, int64_t *pSync) {
|
||||
if (PE_size < 64) {
|
||||
internal_direct_barrier(pe, PE_start, stride, PE_size, pSync);
|
||||
} else {
|
||||
internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void IPCContext::internal_sync_wave(int pe, int PE_start, int stride,
|
||||
int PE_size, int64_t *pSync) {
|
||||
if (is_thread_zero_in_wave()) {
|
||||
if (PE_size < 64) {
|
||||
internal_direct_barrier(pe, PE_start, stride, PE_size, pSync);
|
||||
} else {
|
||||
internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Uses PE values that are relative to world
|
||||
__device__ void IPCContext::internal_sync_wg(int pe, int PE_start, int stride,
|
||||
int PE_size, int64_t *pSync) {
|
||||
__syncthreads();
|
||||
if (is_thread_zero_in_block()) {
|
||||
if (PE_size < 64) {
|
||||
@@ -98,7 +118,7 @@ __device__ void IPCContext::internal_sync(int pe, int PE_start, int stride,
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ void IPCContext::sync(rocshmem_team_t team) {
|
||||
__device__ void IPCContext::sync_wg(rocshmem_team_t team) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
int pe = team_obj->my_pe_in_world;
|
||||
@@ -107,18 +127,38 @@ __device__ void IPCContext::sync(rocshmem_team_t team) {
|
||||
int pe_size = team_obj->num_pes;
|
||||
long *p_sync = team_obj->barrier_pSync;
|
||||
|
||||
internal_sync(pe, pe_start, pe_stride, pe_size, p_sync);
|
||||
internal_sync_wg(pe, pe_start, pe_stride, pe_size, p_sync);
|
||||
}
|
||||
|
||||
__device__ void IPCContext::sync_all() {
|
||||
internal_sync(my_pe, 0, 1, num_pes, barrier_sync);
|
||||
}
|
||||
|
||||
__device__ void IPCContext::sync_all_wave() {
|
||||
internal_sync_wave(my_pe, 0, 1, num_pes, barrier_sync);
|
||||
}
|
||||
|
||||
__device__ void IPCContext::sync_all_wg() {
|
||||
internal_sync_wg(my_pe, 0, 1, num_pes, barrier_sync);
|
||||
}
|
||||
|
||||
__device__ void IPCContext::barrier_all() {
|
||||
quiet();
|
||||
sync_all();
|
||||
}
|
||||
|
||||
__device__ void IPCContext::barrier_all_wave() {
|
||||
if (is_thread_zero_in_wave()) {
|
||||
quiet();
|
||||
}
|
||||
sync_all_wave();
|
||||
}
|
||||
|
||||
__device__ void IPCContext::barrier_all_wg() {
|
||||
if (is_thread_zero_in_block()) {
|
||||
quiet();
|
||||
}
|
||||
sync_all();
|
||||
sync_all_wg();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
@@ -134,7 +174,7 @@ __device__ void IPCContext::barrier(rocshmem_team_t team) {
|
||||
if (is_thread_zero_in_block()) {
|
||||
quiet();
|
||||
}
|
||||
internal_sync(pe, pe_start, pe_stride, pe_size, p_sync);
|
||||
internal_sync_wg(pe, pe_start, pe_stride, pe_size, p_sync);
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
|
||||
@@ -468,7 +468,7 @@ __device__ void IPCContext::internal_broadcast(T *dst, const T *src, int nelems,
|
||||
}
|
||||
|
||||
// Synchronize on completion of broadcast
|
||||
internal_sync(my_pe, pe_start, stride, pe_size, p_sync);
|
||||
internal_sync_wg(my_pe, pe_start, stride, pe_size, p_sync);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -497,7 +497,7 @@ __device__ void IPCContext::alltoall_linear(rocshmem_team_t team, T *dst,
|
||||
quiet();
|
||||
}
|
||||
// wait until everyone has obtained their designated data
|
||||
internal_sync(my_pe, pe_start, stride, pe_size, pSync);
|
||||
internal_sync_wg(my_pe, pe_start, stride, pe_size, pSync);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -527,7 +527,7 @@ __device__ void IPCContext::fcollect_linear(rocshmem_team_t team, T *dst,
|
||||
quiet();
|
||||
}
|
||||
// wait until everyone has obtained their designated data
|
||||
internal_sync(my_pe, pe_start, stride, pe_size, pSync);
|
||||
internal_sync_wg(my_pe, pe_start, stride, pe_size, pSync);
|
||||
}
|
||||
|
||||
// Block/wave functions
|
||||
|
||||
@@ -45,7 +45,7 @@ class IPCDefaultContextProxy {
|
||||
size_t num_elems = 1)
|
||||
: constructed_{true}, proxy_{num_elems} {
|
||||
auto ctx{proxy_.get()};
|
||||
new (ctx) IPCContext(reinterpret_cast<Backend*>(backend));
|
||||
new (ctx) IPCContext(reinterpret_cast<Backend*>(backend), 0);
|
||||
ctx->tinfo = tinfo;
|
||||
rocshmem_ctx_t local{ctx, tinfo};
|
||||
set_internal_ctx(&local);
|
||||
|
||||
@@ -170,6 +170,20 @@ __device__ void *ROContext::shmem_ptr(const void *dest, int pe) {
|
||||
}
|
||||
|
||||
__device__ void ROContext::barrier_all() {
|
||||
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
block_handle, true, get_status_flag(), is_default_ctx);
|
||||
}
|
||||
|
||||
__device__ void ROContext::barrier_all_wave() {
|
||||
if (is_thread_zero_in_wave()) {
|
||||
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
block_handle, true, get_status_flag(), is_default_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void ROContext::barrier_all_wg() {
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
@@ -189,6 +203,20 @@ __device__ void ROContext::barrier(rocshmem_team_t team) {
|
||||
}
|
||||
|
||||
__device__ void ROContext::sync_all() {
|
||||
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
block_handle, true, get_status_flag(), is_default_ctx);
|
||||
}
|
||||
|
||||
__device__ void ROContext::sync_all_wave() {
|
||||
if (is_thread_zero_in_wave()) {
|
||||
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
block_handle, true, get_status_flag(), is_default_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void ROContext::sync_all_wg() {
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
|
||||
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
|
||||
@@ -197,7 +225,7 @@ __device__ void ROContext::sync_all() {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ void ROContext::sync(rocshmem_team_t team) {
|
||||
__device__ void ROContext::sync_wg(rocshmem_team_t team) {
|
||||
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
|
||||
if (is_thread_zero_in_block()) {
|
||||
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
|
||||
|
||||
@@ -65,11 +65,19 @@ class ROContext : public Context {
|
||||
|
||||
__device__ void barrier_all();
|
||||
|
||||
__device__ void barrier_all_wave();
|
||||
|
||||
__device__ void barrier_all_wg();
|
||||
|
||||
__device__ void barrier(rocshmem_team_t team);
|
||||
|
||||
__device__ void sync_all();
|
||||
|
||||
__device__ void sync(rocshmem_team_t team);
|
||||
__device__ void sync_all_wave();
|
||||
|
||||
__device__ void sync_all_wg();
|
||||
|
||||
__device__ void sync_wg(rocshmem_team_t team);
|
||||
|
||||
template <typename T>
|
||||
__device__ void p(T *dest, T value, int pe);
|
||||
|
||||
@@ -570,12 +570,24 @@ __device__ int rocshmem_test(T *ivars, int cmp, T val) {
|
||||
return ctx_internal->test(ivars, cmp, val);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) {
|
||||
__device__ void rocshmem_ctx_barrier_all(rocshmem_ctx_t ctx) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_barrier_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->barrier_all();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wave_barrier_all(rocshmem_ctx_t ctx) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_wave_barrier_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->barrier_all_wave();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_wg_barrier_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->barrier_all_wg();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_wg_barrier_all() {
|
||||
rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT);
|
||||
}
|
||||
@@ -586,12 +598,32 @@ __device__ void rocshmem_barrier(rocshmem_team_t team) {
|
||||
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier(team);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
|
||||
__device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->sync_all();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_sync_all() {
|
||||
rocshmem_ctx_sync_all(ROCSHMEM_CTX_DEFAULT);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->sync_all_wave();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_wave_sync_all() {
|
||||
rocshmem_ctx_wave_sync_all(ROCSHMEM_CTX_DEFAULT);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->sync_all_wg();
|
||||
}
|
||||
|
||||
__device__ void rocshmem_wg_sync_all() {
|
||||
rocshmem_ctx_wg_sync_all(ROCSHMEM_CTX_DEFAULT);
|
||||
}
|
||||
@@ -600,7 +632,7 @@ __device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
|
||||
rocshmem_team_t team) {
|
||||
GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n");
|
||||
|
||||
get_internal_ctx(ctx)->sync(team);
|
||||
get_internal_ctx(ctx)->sync_wg(team);
|
||||
}
|
||||
|
||||
__device__ void rocshmem_wg_team_sync(rocshmem_team_t team) {
|
||||
|
||||
@@ -43,6 +43,8 @@ enum rocshmem_stats {
|
||||
NUM_QUIET,
|
||||
NUM_TO_ALL,
|
||||
NUM_BARRIER_ALL,
|
||||
NUM_BARRIER_ALL_WAVE,
|
||||
NUM_BARRIER_ALL_WG,
|
||||
NUM_WAIT_UNTIL,
|
||||
NUM_WAIT_UNTIL_ANY,
|
||||
NUM_WAIT_UNTIL_ALL,
|
||||
@@ -70,6 +72,8 @@ enum rocshmem_stats {
|
||||
NUM_TEST,
|
||||
NUM_SHMEM_PTR,
|
||||
NUM_SYNC_ALL,
|
||||
NUM_SYNC_ALL_WAVE,
|
||||
NUM_SYNC_ALL_WG,
|
||||
NUM_BROADCAST,
|
||||
NUM_PUT_WG,
|
||||
NUM_PUT_NBI_WG,
|
||||
|
||||
@@ -30,9 +30,12 @@ using namespace rocshmem;
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
__global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
|
||||
long long int *end_time) {
|
||||
long long int *end_time, TestType type,
|
||||
int wf_size) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
int t_id = get_flat_block_id();
|
||||
int wg_id = get_flat_grid_id();
|
||||
int wf_id = t_id / wf_size;
|
||||
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_ctx_create(ROCSHMEM_CTX_WG_PRIVATE, &ctx);
|
||||
@@ -42,17 +45,25 @@ __global__ void BarrierAllTest(int loop, int skip, long long int *start_time,
|
||||
start_time[wg_id] = wall_clock64();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/**
|
||||
* The function `rocshmem_ctx_wg_barrier_all` should be called from only
|
||||
* one group within the grid to avoid unintended behavior.
|
||||
*/
|
||||
if (is_block_zero_in_grid()) {
|
||||
rocshmem_ctx_wg_barrier_all(ctx);
|
||||
switch (type) {
|
||||
case BarrierAllTestType:
|
||||
if(t_id == 0) {
|
||||
rocshmem_ctx_barrier_all(ctx);
|
||||
}
|
||||
break;
|
||||
case WAVEBarrierAllTestType:
|
||||
if(wf_id == 0) {
|
||||
rocshmem_ctx_wave_barrier_all(ctx);
|
||||
}
|
||||
break;
|
||||
case WGBarrierAllTestType:
|
||||
rocshmem_ctx_wg_barrier_all(ctx);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
end_time[wg_id] = wall_clock64();
|
||||
@@ -74,10 +85,10 @@ void BarrierAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(BarrierAllTest, gridSize, blockSize, shared_bytes, stream,
|
||||
loop, args.skip, start_time, end_time);
|
||||
loop, args.skip, start_time, end_time, _type, wf_size);
|
||||
|
||||
num_msgs = loop + args.skip;
|
||||
num_timed_msgs = loop;
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
}
|
||||
|
||||
void BarrierAllTester::resetBuffers(uint64_t size) {}
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
#include "sync_all_tester.hpp"
|
||||
|
||||
#include <rocshmem/rocshmem.hpp>
|
||||
|
||||
using namespace rocshmem;
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
*****************************************************************************/
|
||||
/**
 * @brief Device kernel that repeatedly invokes one of the sync_all variants.
 *
 * Every workgroup creates its own WG-private context, runs `loop + skip`
 * iterations of the selected sync_all call, and destroys the context again.
 * Thread 0 of each workgroup stores a start timestamp when the iteration
 * counter reaches `skip` and an end timestamp after the last iteration.
 *
 * @param[in]  loop       Number of timed iterations.
 * @param[in]  skip       Number of leading iterations run before the start
 *                        timestamp is taken.
 * @param[out] start_time Per-workgroup start timestamps, indexed by the flat
 *                        grid id of the workgroup.
 * @param[out] end_time   Per-workgroup end timestamps, same indexing.
 * @param[in]  type       Selects the thread, wavefront or workgroup variant.
 * @param[in]  wf_size    Wavefront size used to derive the wave index of a
 *                        thread inside its workgroup.
 */
__global__ void SyncAllTest(int loop, int skip, long long int *start_time,
                            long long int *end_time, TestType type,
                            int wf_size) {
  __shared__ rocshmem_ctx_t ctx;

  const int thread_in_block = get_flat_block_id();
  const int block_in_grid = get_flat_grid_id();
  const int wave_in_block = thread_in_block / wf_size;
  const int total_iters = loop + skip;

  rocshmem_wg_init();
  rocshmem_wg_ctx_create(ROCSHMEM_CTX_WG_PRIVATE, &ctx);

  for (int iter = 0; iter < total_iters; ++iter) {
    // Start the clock once the `skip` leading iterations are done.
    if (iter == skip && hipThreadIdx_x == 0) {
      start_time[block_in_grid] = wall_clock64();
    }

    if (type == SyncAllTestType) {
      // Thread-scoped variant: only the first thread of the workgroup calls.
      if (thread_in_block == 0) {
        rocshmem_ctx_sync_all(ctx);
      }
    } else if (type == WAVESyncAllTestType) {
      // Wavefront-scoped variant: only the first wavefront calls.
      if (wave_in_block == 0) {
        rocshmem_ctx_wave_sync_all(ctx);
      }
    } else if (type == WGSyncAllTestType) {
      // Workgroup-scoped variant: every thread of the workgroup calls.
      rocshmem_ctx_wg_sync_all(ctx);
    }

    // Re-converge the whole workgroup before the next iteration.
    __syncthreads();
  }

  if (hipThreadIdx_x == 0) {
    end_time[block_in_grid] = wall_clock64();
  }

  rocshmem_wg_ctx_destroy(&ctx);
  rocshmem_wg_finalize();
}
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
SyncAllTester::SyncAllTester(TesterArguments args) : Tester(args) {}
|
||||
|
||||
SyncAllTester::~SyncAllTester() {}
|
||||
|
||||
/**
 * @brief Launches the SyncAllTest kernel and records the message counts.
 *
 * The counters are scaled by gridSize.x because every workgroup in the grid
 * runs the full iteration loop.
 *
 * @param[in] gridSize  Grid dimensions for the launch.
 * @param[in] blockSize Workgroup dimensions for the launch.
 * @param[in] loop      Number of timed iterations per workgroup.
 * @param[in] size      Not used by this tester.
 */
void SyncAllTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
                                 uint64_t size) {
  // No dynamically sized shared memory is requested for the launch.
  size_t dynamic_shared_bytes{0};

  hipLaunchKernelGGL(SyncAllTest, gridSize, blockSize, dynamic_shared_bytes,
                     stream, loop, args.skip, start_time, end_time, _type,
                     wf_size);

  const auto num_blocks = gridSize.x;
  num_timed_msgs = loop * num_blocks;
  num_msgs = (loop + args.skip) * num_blocks;
}
|
||||
|
||||
void SyncAllTester::resetBuffers(uint64_t size) {}
|
||||
|
||||
void SyncAllTester::verifyResults(uint64_t size) {}
|
||||
@@ -0,0 +1,50 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to
|
||||
* deal in the Software without restriction, including without limitation the
|
||||
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
* sell copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
* IN THE SOFTWARE.
|
||||
*****************************************************************************/
|
||||
|
||||
// Guard named after this header's own tester; the previous name was the
// barrier_all tester's guard name.
#ifndef _SYNC_ALL_TESTER_HPP_
#define _SYNC_ALL_TESTER_HPP_

#include "tester.hpp"

/******************************************************************************
 * DEVICE TEST KERNEL
 *****************************************************************************/
/**
 * @brief Device kernel that repeatedly invokes one of the sync_all variants.
 *
 * The declaration matches the kernel definition in the accompanying source
 * file.
 *
 * @param[in]  loop       Number of timed iterations.
 * @param[in]  skip       Number of leading iterations run before the start
 *                        timestamp is taken.
 * @param[out] start_time Per-workgroup start timestamps.
 * @param[out] end_time   Per-workgroup end timestamps.
 * @param[in]  type       Selects the thread, wavefront or workgroup variant.
 * @param[in]  wf_size    Wavefront size used to derive a thread's wave index.
 */
__global__ void SyncAllTest(int loop, int skip, long long int *start_time,
                            long long int *end_time, TestType type,
                            int wf_size);

/******************************************************************************
 * HOST TESTER CLASS
 *****************************************************************************/
/**
 * @brief Host-side tester that drives the SyncAllTest kernel.
 */
class SyncAllTester : public Tester {
 public:
  explicit SyncAllTester(TesterArguments args);
  virtual ~SyncAllTester();

 protected:
  /// Empty in the accompanying implementation: this test has no buffers.
  virtual void resetBuffers(uint64_t size) override;

  /// Launches SyncAllTest and updates the message counters.
  virtual void launchKernel(dim3 gridSize, dim3 blockSize, int loop,
                            uint64_t size) override;

  /// Empty in the accompanying implementation: there is no payload to check.
  virtual void verifyResults(uint64_t size) override;
};

#endif  // _SYNC_ALL_TESTER_HPP_
|
||||
@@ -41,15 +41,6 @@ __global__ void SyncTest(int loop, int skip, long long int *start_time,
|
||||
|
||||
__syncthreads();
|
||||
switch (type) {
|
||||
case SyncAllTestType:
|
||||
/**
|
||||
* The function `rocshmem_ctx_wg_sync_all` should be called from only
|
||||
* one group within the grid to avoid unintended behavior.
|
||||
*/
|
||||
if (is_block_zero_in_grid()) {
|
||||
rocshmem_ctx_wg_sync_all(ctx);
|
||||
}
|
||||
break;
|
||||
case SyncTestType:
|
||||
rocshmem_ctx_wg_team_sync(ctx, teams[wg_id]);
|
||||
break;
|
||||
|
||||
@@ -43,10 +43,11 @@
|
||||
#include "random_access_tester.hpp"
|
||||
#include "shmem_ptr_tester.hpp"
|
||||
#include "signaling_operations_tester.hpp"
|
||||
#include "sync_all_tester.hpp"
|
||||
#include "sync_tester.hpp"
|
||||
#include "team_alltoall_tester.hpp"
|
||||
#include "team_broadcast_tester.hpp"
|
||||
#include "team_barrier_tester.hpp"
|
||||
#include "team_broadcast_tester.hpp"
|
||||
#include "team_ctx_infra_tester.hpp"
|
||||
#include "team_ctx_primitive_tester.hpp"
|
||||
#include "team_fcollect_tester.hpp"
|
||||
@@ -317,6 +318,14 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
if (rank == 0) std::cout << "Barrier_All ###" << std::endl;
|
||||
testers.push_back(new BarrierAllTester(args));
|
||||
return testers;
|
||||
case WAVEBarrierAllTestType:
|
||||
if (rank == 0) std::cout << "WAVE Barrier_All ###" << std::endl;
|
||||
testers.push_back(new BarrierAllTester(args));
|
||||
return testers;
|
||||
case WGBarrierAllTestType:
|
||||
if (rank == 0) std::cout << "WG Barrier_All ###" << std::endl;
|
||||
testers.push_back(new BarrierAllTester(args));
|
||||
return testers;
|
||||
case TeamBarrierTestType:
|
||||
if (rank == 0) std::cout << "Team Barrier Test ###" << std::endl;
|
||||
testers.push_back(new TeamBarrierTester(args));
|
||||
@@ -325,6 +334,14 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
if (rank == 0) std::cout << "SyncAll ###" << std::endl;
|
||||
testers.push_back(new SyncTester(args));
|
||||
return testers;
|
||||
case WAVESyncAllTestType:
|
||||
if (rank == 0) std::cout << "WAVE SyncAll ###" << std::endl;
|
||||
testers.push_back(new SyncTester(args));
|
||||
return testers;
|
||||
case WGSyncAllTestType:
|
||||
if (rank == 0) std::cout << "WG SyncAll ###" << std::endl;
|
||||
testers.push_back(new SyncTester(args));
|
||||
return testers;
|
||||
case SyncTestType:
|
||||
if (rank == 0) std::cout << "Sync ###" << std::endl;
|
||||
testers.push_back(new SyncTester(args));
|
||||
@@ -510,7 +527,9 @@ bool Tester::peLaunchesKernel() {
|
||||
(_type == TeamBroadcastTestType) || (_type == TeamCtxInfraTestType) ||
|
||||
(_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) ||
|
||||
(_type == PingPongTestType) || (_type == BarrierAllTestType) ||
|
||||
(_type == WAVEBarrierAllTestType) || (_type == WGBarrierAllTestType) ||
|
||||
(_type == SyncTestType) || (_type == SyncAllTestType) ||
|
||||
(_type == WAVESyncAllTestType) || (_type == WGSyncAllTestType) ||
|
||||
(_type == RandomAccessTestType) || (_type == PingAllTestType) ||
|
||||
(_type == TeamBarrierTestType);
|
||||
|
||||
|
||||
@@ -99,6 +99,10 @@ enum TestType {
|
||||
DefaultCTXPutNBITestType = 62,
|
||||
DefaultCTXPTestType = 63,
|
||||
DefaultCTXGTestType = 64,
|
||||
WAVEBarrierAllTestType = 65,
|
||||
WGBarrierAllTestType = 66,
|
||||
WAVESyncAllTestType = 67,
|
||||
WGSyncAllTestType = 68,
|
||||
};
|
||||
|
||||
enum OpType { PutType = 0, GetType = 1 };
|
||||
|
||||
@@ -85,8 +85,12 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
|
||||
case AMO_IncTestType:
|
||||
case AMO_FetchTestType:
|
||||
case BarrierAllTestType:
|
||||
case WAVEBarrierAllTestType:
|
||||
case WGBarrierAllTestType:
|
||||
case TeamBarrierTestType:
|
||||
case SyncAllTestType:
|
||||
case WAVESyncAllTestType:
|
||||
case WGSyncAllTestType:
|
||||
case SyncTestType:
|
||||
case ShmemPtrTestType:
|
||||
min_msg_size = 8;
|
||||
@@ -133,7 +137,9 @@ void TesterArguments::get_rocshmem_arguments() {
|
||||
myid = rocshmem_my_pe();
|
||||
|
||||
TestType type = (TestType)algorithm;
|
||||
if ((type != BarrierAllTestType) && (type != SyncAllTestType) &&
|
||||
if ((type != BarrierAllTestType) && (type != WAVEBarrierAllTestType) &&
|
||||
(type != WGBarrierAllTestType) && (type != SyncAllTestType) &&
|
||||
(type != WAVESyncAllTestType) && (type != WGSyncAllTestType) &&
|
||||
(type != SyncTestType) && (type != TeamAllToAllTestType) &&
|
||||
(type != TeamFCollectTestType) && (type != TeamReductionTestType) &&
|
||||
(type != TeamBroadcastTestType) && (type != PingAllTestType) &&
|
||||
|
||||
Verwijs in nieuw issue
Block a user