Update Barrier and Sync APIs (#73)

* Add thread, wavefront, and workgroup-level `barrier` APIs in IPC and RO conduits; remove collectives on default context
 - Implemented `barrier` APIs for thread, wavefront, and workgroup scopes
 - Added support into both IPC and RO conduits
 - Added functional tests to cover all `barrier` APIs
 - Removed collective operations on default context

* Add thread, wavefront, and workgroup-level `sync` APIs in IPC and RO conduits.
  - Implemented `sync` APIs for thread, wavefront, and workgroup scopes
  - Added support into both IPC and RO conduits
  - Added functional tests to cover all `sync` APIs

* update naming convention for context-based `barrier` APIs
이 커밋은 다음에 포함됨:
Avinash Kethineedi
2025-04-08 11:25:31 -05:00
커밋한 사람 GitHub
부모 c652f58cef
커밋 dc61bca066
16개의 변경된 파일347개의 추가작업 그리고 67개의 파일을 삭제
+27 -21
파일 보기
@@ -588,14 +588,22 @@ __device__ void rocshmem_ctx_wg_barrier_all(rocshmem_ctx_t ctx) {
get_internal_ctx(ctx)->barrier_all_wg();
}
__device__ void rocshmem_wg_barrier_all() {
rocshmem_ctx_wg_barrier_all(ROCSHMEM_CTX_DEFAULT);
}
__device__ void rocshmem_barrier(rocshmem_team_t team) {
__device__ void rocshmem_ctx_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
GPU_DPRINTF("Function: rocshmem_barrier\n");
get_internal_ctx(ROCSHMEM_CTX_DEFAULT)->barrier(team);
get_internal_ctx(ctx)->barrier(team);
}
__device__ void rocshmem_ctx_wave_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
GPU_DPRINTF("Function: rocshmem_wave_barrier\n");
get_internal_ctx(ctx)->barrier_wave(team);
}
__device__ void rocshmem_ctx_wg_barrier(rocshmem_ctx_t ctx, rocshmem_team_t team) {
GPU_DPRINTF("Function: rocshmem_wg_barrier\n");
get_internal_ctx(ctx)->barrier_wg(team);
}
__device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
@@ -604,39 +612,37 @@ __device__ void rocshmem_ctx_sync_all(rocshmem_ctx_t ctx) {
get_internal_ctx(ctx)->sync_all();
}
__device__ void rocshmem_sync_all() {
rocshmem_ctx_sync_all(ROCSHMEM_CTX_DEFAULT);
}
__device__ void rocshmem_ctx_wave_sync_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n");
get_internal_ctx(ctx)->sync_all_wave();
}
__device__ void rocshmem_wave_sync_all() {
rocshmem_ctx_wave_sync_all(ROCSHMEM_CTX_DEFAULT);
}
__device__ void rocshmem_ctx_wg_sync_all(rocshmem_ctx_t ctx) {
GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n");
get_internal_ctx(ctx)->sync_all_wg();
}
__device__ void rocshmem_wg_sync_all() {
rocshmem_ctx_wg_sync_all(ROCSHMEM_CTX_DEFAULT);
}
__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
__device__ void rocshmem_ctx_team_sync(rocshmem_ctx_t ctx,
rocshmem_team_t team) {
GPU_DPRINTF("Function: rocshmem_ctx_sync_all\n");
get_internal_ctx(ctx)->sync_wg(team);
}
__device__ void rocshmem_wg_team_sync(rocshmem_team_t team) {
rocshmem_ctx_wg_team_sync(ROCSHMEM_CTX_DEFAULT, team);
__device__ void rocshmem_ctx_wave_team_sync(rocshmem_ctx_t ctx,
rocshmem_team_t team) {
GPU_DPRINTF("Function: rocshmem_ctx_wave_sync_all\n");
get_internal_ctx(ctx)->sync_wg(team);
}
__device__ void rocshmem_ctx_wg_team_sync(rocshmem_ctx_t ctx,
rocshmem_team_t team) {
GPU_DPRINTF("Function: rocshmem_ctx_wg_sync_all\n");
get_internal_ctx(ctx)->sync_wg(team);
}
__device__ int rocshmem_ctx_n_pes(rocshmem_ctx_t ctx) {