Update Barrier and Sync APIs (#73)
* Add thread, wavefront, and workgroup-level `barrier` APIs in IPC and RO conduits; remove collectives on default context
  - Implemented `barrier` APIs for thread, wavefront, and workgroup scopes
  - Added support in both IPC and RO conduits
  - Added functional tests to cover all `barrier` APIs
  - Removed collective operations on default context
* Add thread, wavefront, and workgroup-level `sync` APIs in IPC and RO conduits
  - Implemented `sync` APIs for thread, wavefront, and workgroup scopes
  - Added support in both IPC and RO conduits
  - Added functional tests to cover all `sync` APIs
* Update naming convention for context-based `barrier` APIs
이 커밋은 다음에 포함됨:
+27
-21
@@ -588,14 +588,22 @@ __device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) {
|
||||
get_internal_ctx(ctx)->barrier_all_wg();
|
||||
}
|
||||
|
||||
/**
 * Context-free convenience form of rocshmem_ctx_wg_barrier_all().
 *
 * Takes no arguments and delegates to the context-based entry point,
 * supplying ROCSHMEM_CTX_DEFAULT as the context.
 */
__device__ void rocshmem_wg_barrier_all() {
  // Hand the default context to the ctx-based variant by value.
  rocshmem_ctx_t default_ctx = ROCSHMEM_CTX_DEFAULT;
  rocshmem_ctx_wg_barrier_all(default_ctx);
}
|
||||
|
||||
__device__ void rocshmem_barrier(rocshmem_team_t team) {
|
||||
/**
 * Team barrier issued on a caller-supplied context.
 *
 * Emits a debug trace, resolves `ctx` to its internal context object via
 * get_internal_ctx(), and invokes that object's barrier() with `team`.
 * Presumably the thread-scope member of the barrier / barrier_wave /
 * barrier_wg family — TODO confirm against the context class.
 *
 * @param ctx  Context handle whose internal context performs the barrier.
 * @param team Team handle forwarded unchanged to barrier().
 */
__device__ void rocshmem_ctx_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
  GPU_DPRINTF("Function: rocshmem_ctx_barrier\n");

  // Barrier only on the context the caller passed in, not additionally on
  // ROCSHMEM_CTX_DEFAULT.
  get_internal_ctx(ctx)->barrier(team);
}
|
||||
|
||||
/**
 * Wavefront-level team barrier issued on a caller-supplied context.
 *
 * Emits a debug trace, resolves `ctx` to its internal context object via
 * get_internal_ctx(), and invokes that object's barrier_wave() with `team`.
 *
 * @param ctx  Context handle whose internal context performs the barrier.
 * @param team Team handle forwarded unchanged to barrier_wave().
 */
__device__ void rocshmem_ctx_wave_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
  // Trace text names this exact entry point (including the ctx_ prefix).
  GPU_DPRINTF("Function: rocshmem_ctx_wave_barrier\n");

  get_internal_ctx(ctx)->barrier_wave(team);
}
|
||||
|
||||
/**
 * Workgroup-level team barrier issued on a caller-supplied context.
 *
 * Emits a debug trace, resolves `ctx` to its internal context object via
 * get_internal_ctx(), and invokes that object's barrier_wg() with `team`.
 *
 * @param ctx  Context handle whose internal context performs the barrier.
 * @param team Team handle forwarded unchanged to barrier_wg().
 */
__device__ void rocshmem_ctx_wg_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
  // Trace text names this exact entry point (including the ctx_ prefix).
  GPU_DPRINTF("Function: rocshmem_ctx_wg_barrier\n");

  get_internal_ctx(ctx)->barrier_wg(team);
}
|
||||
|
||||
__device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
|
||||
@@ -604,39 +612,37 @@ __device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
|
||||
get_internal_ctx(ctx)->sync_all();
|
||||
}
|
||||
|
||||
/**
 * Context-free convenience form of rocshmem_ctx_sync_all().
 *
 * Takes no arguments and delegates to the context-based entry point,
 * supplying ROCSHMEM_CTX_DEFAULT as the context.
 */
__device__ void rocshmem_sync_all() {
  // Hand the default context to the ctx-based variant by value.
  rocshmem_ctx_t default_ctx = ROCSHMEM_CTX_DEFAULT;
  rocshmem_ctx_sync_all(default_ctx);
}
|
||||
|
||||
/**
 * Wavefront-level sync-all issued on a caller-supplied context.
 *
 * Emits a debug trace, then resolves `ctx` through get_internal_ctx() and
 * calls sync_all_wave() on the result.
 *
 * @param ctx Context handle whose internal context performs the sync.
 */
__device__ void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx) {
  GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n");

  // Look up the internal context first, then dispatch on it.
  auto *internal_ctx = get_internal_ctx(ctx);
  internal_ctx->sync_all_wave();
}
|
||||
|
||||
/**
 * Context-free convenience form of rocshmem_ctx_wave_sync_all().
 *
 * Takes no arguments and delegates to the context-based entry point,
 * supplying ROCSHMEM_CTX_DEFAULT as the context.
 */
__device__ void rocshmem_wave_sync_all() {
  // Hand the default context to the ctx-based variant by value.
  rocshmem_ctx_t default_ctx = ROCSHMEM_CTX_DEFAULT;
  rocshmem_ctx_wave_sync_all(default_ctx);
}
|
||||
|
||||
/**
 * Workgroup-level sync-all issued on a caller-supplied context.
 *
 * Emits a debug trace, then resolves `ctx` through get_internal_ctx() and
 * calls sync_all_wg() on the result.
 *
 * @param ctx Context handle whose internal context performs the sync.
 */
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
  GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n");

  // Look up the internal context first, then dispatch on it.
  auto *internal_ctx = get_internal_ctx(ctx);
  internal_ctx->sync_all_wg();
}
|
||||
|
||||
/**
 * Context-free convenience form of rocshmem_ctx_wg_sync_all().
 *
 * Takes no arguments and delegates to the context-based entry point,
 * supplying ROCSHMEM_CTX_DEFAULT as the context.
 */
__device__ void rocshmem_wg_sync_all() {
  // Hand the default context to the ctx-based variant by value.
  rocshmem_ctx_t default_ctx = ROCSHMEM_CTX_DEFAULT;
  rocshmem_ctx_wg_sync_all(default_ctx);
}
|
||||
|
||||
__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
|
||||
/**
 * Team sync issued on a caller-supplied context.
 *
 * Emits a debug trace, resolves `ctx` to its internal context object via
 * get_internal_ctx(), and invokes sync_wg() on it with `team`.
 *
 * @param ctx  Context handle whose internal context performs the sync.
 * @param team Team handle forwarded unchanged to the internal sync call.
 */
__device__ void rocshmem_ctx_team_sync(rocshmem_ctx_t ctx,
                                       rocshmem_team_t team) {
  // Trace text names this entry point rather than rocshmem_ctx_sync_all.
  GPU_DPRINTF("Function: rocshmem_ctx_team_sync\n");

  // NOTE(review): this entry point has no _wg_ in its name yet dispatches to
  // sync_wg(), the same method rocshmem_ctx_wg_team_sync() uses. The barrier
  // family dispatches to a distinct method per scope (barrier / barrier_wave /
  // barrier_wg) — confirm whether a non-workgroup sync method was intended.
  get_internal_ctx(ctx)->sync_wg(team);
}
|
||||
|
||||
__device__ void rocshmem_wg_team_sync(rocshmem_team_t team) {
|
||||
rocshmem_ctx_wg_team_sync(ROCSHMEM_CTX_DEFAULT, team);
|
||||
/**
 * Wavefront-named team sync issued on a caller-supplied context.
 *
 * Emits a debug trace, resolves `ctx` to its internal context object via
 * get_internal_ctx(), and invokes sync_wg() on it with `team`.
 *
 * @param ctx  Context handle whose internal context performs the sync.
 * @param team Team handle forwarded unchanged to the internal sync call.
 */
__device__ void rocshmem_ctx_wave_team_sync(rocshmem_ctx_t ctx,
                                            rocshmem_team_t team) {
  // Trace text names this entry point rather than rocshmem_ctx_wave_sync_all.
  GPU_DPRINTF("Function: rocshmem_ctx_wave_team_sync\n");

  // NOTE(review): the wave variant dispatches to sync_wg(), the same method
  // rocshmem_ctx_wg_team_sync() uses. The barrier family uses barrier_wave()
  // for its wave variant — confirm whether a wavefront-scope sync method was
  // intended here.
  get_internal_ctx(ctx)->sync_wg(team);
}
|
||||
|
||||
/**
 * Workgroup-level team sync issued on a caller-supplied context.
 *
 * Emits a debug trace, resolves `ctx` to its internal context object via
 * get_internal_ctx(), and invokes that object's sync_wg() with `team`.
 *
 * @param ctx  Context handle whose internal context performs the sync.
 * @param team Team handle forwarded unchanged to sync_wg().
 */
__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
                                          rocshmem_team_t team) {
  // Trace text names this entry point rather than rocshmem_ctx_wg_sync_all.
  GPU_DPRINTF("Function: rocshmem_ctx_wg_team_sync\n");

  get_internal_ctx(ctx)->sync_wg(team);
}
|
||||
|
||||
__device__ int rocshmem_ctx_n_pes(rocshmem_ctx_t ctx) {
|
||||
|
||||
새 이슈에서 참조
사용자 차단