[GDA] Implement internal_direct_barrier_wg (#299)
[ROCm/rocshmem commit: 5f87bb061b]
This commit is contained in:
@@ -267,6 +267,9 @@ class GDAContext : public Context {
|
||||
__device__ void internal_direct_barrier(int pe, int PE_start, int stride,
|
||||
int n_pes, int64_t *pSync);
|
||||
|
||||
__device__ void internal_direct_barrier_wg(int pe, int PE_start, int stride,
|
||||
int n_pes, int64_t *pSync);
|
||||
|
||||
__device__ void internal_atomic_barrier(int pe, int PE_start, int stride,
|
||||
int n_pes, int64_t *pSync);
|
||||
|
||||
|
||||
@@ -70,6 +70,57 @@ __device__ void GDAContext::internal_direct_barrier(int pe, int PE_start,
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void GDAContext::internal_direct_barrier_wg(int pe, int PE_start,
|
||||
int stride, int n_pes,
|
||||
int64_t *pSync) {
|
||||
int64_t flag_val{1};
|
||||
|
||||
if (pe == PE_start) {
|
||||
int wf_id = get_flat_block_id() / WF_SIZE;
|
||||
int wf_count = (int) ceil((double)get_flat_block_size() / (double)WF_SIZE);
|
||||
bool wf_leader = 0 == get_active_lane_num();
|
||||
|
||||
// Go through all PE offsets (except current offset = 0)
|
||||
// and wait until they all reach
|
||||
if (wf_leader) {
|
||||
for (int j = wf_id + 1; j < n_pes; j+= wf_count) {
|
||||
wait_until(&pSync[j], ROCSHMEM_CMP_EQ, flag_val);
|
||||
pSync[j] = ROCSHMEM_SYNC_VALUE;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Announce to other PEs that all have reached
|
||||
for (int i = wf_id + 1, j = PE_start + stride + wf_id;
|
||||
i < n_pes;
|
||||
i+= wf_count, j += (wf_count * stride)) {
|
||||
put_nbi_wave(&pSync[0], &flag_val, 1, j);
|
||||
}
|
||||
|
||||
for (int i = wf_id + 1, j = PE_start + stride + wf_id;
|
||||
i < n_pes;
|
||||
i+= wf_count, j += (wf_count * stride)) {
|
||||
pe_quiet(j);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (is_thread_zero_in_block()) {
|
||||
pSync[0] = ROCSHMEM_SYNC_VALUE;
|
||||
}
|
||||
} else {
|
||||
if (is_thread_zero_in_block()) {
|
||||
// Mark current PE offset as reached
|
||||
size_t pe_offset = (pe - PE_start) / stride;
|
||||
put(&pSync[pe_offset], &flag_val, 1, PE_start);
|
||||
wait_until(&pSync[0], ROCSHMEM_CMP_EQ, flag_val);
|
||||
pSync[0] = ROCSHMEM_SYNC_VALUE;
|
||||
__threadfence_system();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void GDAContext::internal_atomic_barrier(int pe, int PE_start,
|
||||
int stride, int n_pes,
|
||||
int64_t *pSync) {
|
||||
@@ -116,10 +167,10 @@ __device__ void GDAContext::internal_sync_wave(int pe, int PE_start, int stride,
|
||||
__device__ void GDAContext::internal_sync_wg(int pe, int PE_start, int stride,
|
||||
int PE_size, int64_t *pSync) {
|
||||
__syncthreads();
|
||||
if (is_thread_zero_in_block()) {
|
||||
if (PE_size < 64) {
|
||||
internal_direct_barrier(pe, PE_start, stride, PE_size, pSync);
|
||||
} else {
|
||||
if (PE_size < 64) {
|
||||
internal_direct_barrier_wg(pe, PE_start, stride, PE_size, pSync);
|
||||
} else {
|
||||
if (is_thread_zero_in_block()) {
|
||||
internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user