[GDA] Implement internal_direct_barrier_wg (#299)

[ROCm/rocshmem commit: 5f87bb061b]
This commit is contained in:
Yiltan
2025-10-31 10:26:24 -04:00
committed by GitHub
parent fa7841f0d4
commit 2f8a1c02a4
2 changed files with 58 additions and 4 deletions
@@ -267,6 +267,9 @@ class GDAContext : public Context {
// Barrier over the PEs {PE_start + k * stride : 0 <= k < n_pes}, using pSync
// as the flag array.
// NOTE(review): presumably the variant meant to be driven by a single
// thread — confirm against the definition.
__device__ void internal_direct_barrier(int pe, int PE_start, int stride,
int n_pes, int64_t *pSync);
// Work-group variant of the direct barrier. On the root PE (pe == PE_start)
// the work is spread over the wavefronts of the block and the implementation
// calls __syncthreads(), so every thread of the block must enter the call.
// On the other PEs only thread zero of the block does any work.
__device__ void internal_direct_barrier_wg(int pe, int PE_start, int stride,
int n_pes, int64_t *pSync);
// Barrier with the same parameter list as the direct variants.
// NOTE(review): presumably based on remote atomics rather than per-PE flag
// slots — confirm against the definition.
__device__ void internal_atomic_barrier(int pe, int PE_start, int stride,
int n_pes, int64_t *pSync);
@@ -70,6 +70,57 @@ __device__ void GDAContext::internal_direct_barrier(int pe, int PE_start,
}
}
/**
 * Work-group flavour of the direct (flag based) barrier.
 *
 * The participating PEs are PE_start + k * stride for k in [0, n_pes).
 * Slot k of pSync (k >= 1) is the arrival flag of the PE at offset k and
 * lives on the root PE; slot 0 is the release flag that the root writes on
 * every non-root PE.
 *
 * Root PE (pe == PE_start): every thread of the block must call this
 * function because the root path contains __syncthreads(). The offsets
 * 1..n_pes-1 are distributed round-robin over the wavefronts of the block.
 *
 * Non-root PEs: only thread zero of the block does any work; the other
 * threads return immediately, so the caller is responsible for any
 * block-level synchronization after the call.
 *
 * @param pe        Calling PE.
 * @param PE_start  First PE of the set; acts as the barrier root.
 * @param stride    Distance between consecutive PEs of the set.
 * @param n_pes     Number of PEs in the set.
 * @param pSync     Flag array with at least n_pes entries; every slot used
 *                  here is reset to ROCSHMEM_SYNC_VALUE before returning.
 */
__device__ void GDAContext::internal_direct_barrier_wg(int pe, int PE_start,
                                                       int stride, int n_pes,
                                                       int64_t *pSync) {
  int64_t flag_val{1};

  if (pe == PE_start) {
    int wf_id = get_flat_block_id() / WF_SIZE;
    // Number of wavefronts in the block (integer ceil-division; no need for
    // double-precision math in device code).
    int wf_count =
        static_cast<int>((get_flat_block_size() + WF_SIZE - 1) / WF_SIZE);
    bool wf_leader = 0 == get_active_lane_num();

    // Go through all PE offsets (except current offset = 0) and wait until
    // they have all arrived. One lane per wavefront polls; wavefront w
    // handles offsets w + 1, w + 1 + wf_count, ...
    if (wf_leader) {
      for (int j = wf_id + 1; j < n_pes; j += wf_count) {
        wait_until(&pSync[j], ROCSHMEM_CMP_EQ, flag_val);
        // Re-arm the slot for the next barrier.
        pSync[j] = ROCSHMEM_SYNC_VALUE;
      }
    }

    // No wavefront may release a PE before every arrival has been observed.
    __syncthreads();

    // Announce to the other PEs that all have reached the barrier. The
    // target PE for offset i is PE_start + i * stride; deriving it from i
    // keeps the mapping correct for stride != 1 and for wavefronts other
    // than the first one.
    for (int i = wf_id + 1; i < n_pes; i += wf_count) {
      put_nbi_wave(&pSync[0], &flag_val, 1, PE_start + i * stride);
    }

    // Complete the non-blocking puts issued above, per target PE.
    for (int i = wf_id + 1; i < n_pes; i += wf_count) {
      pe_quiet(PE_start + i * stride);
    }

    __syncthreads();

    if (is_thread_zero_in_block()) {
      pSync[0] = ROCSHMEM_SYNC_VALUE;
    }
  } else {
    if (is_thread_zero_in_block()) {
      // Mark current PE offset as reached on the root PE.
      size_t pe_offset = (pe - PE_start) / stride;
      put(&pSync[pe_offset], &flag_val, 1, PE_start);

      // Wait for the release flag from the root, then re-arm it.
      wait_until(&pSync[0], ROCSHMEM_CMP_EQ, flag_val);
      pSync[0] = ROCSHMEM_SYNC_VALUE;
      __threadfence_system();
    }
  }
}
__device__ void GDAContext::internal_atomic_barrier(int pe, int PE_start,
int stride, int n_pes,
int64_t *pSync) {
@@ -116,10 +167,10 @@ __device__ void GDAContext::internal_sync_wave(int pe, int PE_start, int stride,
__device__ void GDAContext::internal_sync_wg(int pe, int PE_start, int stride,
int PE_size, int64_t *pSync) {
__syncthreads();
if (is_thread_zero_in_block()) {
if (PE_size < 64) {
internal_direct_barrier(pe, PE_start, stride, PE_size, pSync);
} else {
if (PE_size < 64) {
internal_direct_barrier_wg(pe, PE_start, stride, PE_size, pSync);
} else {
if (is_thread_zero_in_block()) {
internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync);
}
}