fix allreduce tester (#385)

- use the reduce_psync buffers for synchronization in allreduce, not the
  barrier_psync.
- execute a wwg barrier after the allreduce operation. After internal
  discussion it was determined that it is required for correctness.
This commit is contained in:
Edgar Gabriel
2026-01-16 08:10:25 -06:00
کامیت شده توسط GitHub
والد 885e41ec62
کامیت 6f512e92a5
3فایلهای تغییر یافته به همراه7 افزوده شده و 30 حذف شده
@@ -164,13 +164,7 @@ __device__ void IPCContext::internal_putmem(void *dest, const void *source,
size_t nelems, int pe) {
uint64_t L_offset = reinterpret_cast<char *>(dest) - wrk_sync_pool_bases_[my_pe];
memcpy_lane(wrk_sync_pool_bases_[pe] + L_offset, const_cast<void *>(source), nelems);
#if defined(__gfx90a__)
__threadfence_system();
#elif defined (__gfx1201__) || defined (__gfx1100__)
fence(pe);
#else
ipcImpl_.ipcFence();
#endif
}
__device__ void IPCContext::internal_getmem(void *dest, const void *source,
@@ -186,15 +180,7 @@ __device__ void IPCContext::internal_putmem_wg(void *dest, const void *source,
uint64_t L_offset = reinterpret_cast<char *>(dest) - wrk_sync_pool_bases_[my_pe];
memcpy_wg(wrk_sync_pool_bases_[pe] + L_offset, const_cast<void *>(source), nelems);
__syncthreads();
#if defined(__gfx90a__)
__threadfence_system();
#elif defined (__gfx1201__) || defined (__gfx1100__)
if (is_thread_zero_in_block() ) {
fence(pe);
}
#else
ipcImpl_.ipcFence();
#endif
}
__device__ void IPCContext::internal_getmem_wg(void *dest, const void *source,
@@ -210,15 +196,7 @@ __device__ void IPCContext::internal_putmem_wave(void *dest,
const void *source, size_t nelems, int pe) {
uint64_t L_offset = reinterpret_cast<char *>(dest) - wrk_sync_pool_bases_[my_pe];
memcpy_wave(wrk_sync_pool_bases_[pe] + L_offset, const_cast<void *>(source), nelems);
#if defined(__gfx90a__)
__threadfence_system();
#elif defined (__gfx1201__) || defined (__gfx1100__)
if (is_thread_zero_in_wave() ) {
fence(pe);
}
#else
ipcImpl_.ipcFence();
#endif
}
__device__ void IPCContext::internal_getmem_wave(void *dest,
@@ -174,7 +174,7 @@ __device__ void IPCContext::internal_direct_allreduce(
int stride = team_obj->tinfo_wrt_world->stride;
int PE_start = team_obj->tinfo_wrt_world->pe_start;
int PE_size = team_obj->tinfo_wrt_world->size;
long *pSync = team_obj->barrier_pSync;
long *pSync = team_obj->reduce_pSync;
T *pWrk = reinterpret_cast<T *>(team_obj->pWrk);
int finish = PE_start + stride * PE_size;
@@ -195,7 +195,7 @@ __device__ void IPCContext::internal_direct_allreduce(
nelems * sizeof(T), i);
if (is_thread_zero_in_block()) {
fence();
fence(i);
internal_putmem(&pSync[pe], &flag_val, sizeof(*pSync), i);
}
}
@@ -222,7 +222,6 @@ __device__ void IPCContext::internal_direct_allreduce(
for (int i = wg_id; i < num_pes; i += wg_size) {
pSync[i] = ROCSHMEM_SYNC_VALUE;
}
threadfence_system();
__syncthreads();
}
@@ -290,7 +289,7 @@ __device__ void IPCContext::internal_ring_allreduce(
int stride = team_obj->tinfo_wrt_world->stride;
int PE_start = team_obj->tinfo_wrt_world->pe_start;
int PE_size = team_obj->tinfo_wrt_world->size;
long *pSync = team_obj->barrier_pSync;
long *pSync = team_obj->reduce_pSync;
T *pWrk = reinterpret_cast<T *>(team_obj->pWrk);
int my_pe_in_team = team_obj->my_pe;
@@ -320,8 +319,8 @@ __device__ void IPCContext::internal_ring_allreduce(
chunk_size * sizeof(T), send_pe);
if (is_thread_zero_in_block()) {
fence();
wait_val = seg + 100;
fence(send_pe);
internal_putmem(&pSync[iter], &wait_val, sizeof(*pSync), send_pe);
wait_until(&pSync[iter], ROCSHMEM_CMP_EQ, wait_val);
}
@@ -338,8 +337,8 @@ __device__ void IPCContext::internal_ring_allreduce(
chunk_size * sizeof(T), send_pe);
if (is_thread_zero_in_block()) {
fence();
wait_val = seg + 10;
fence(send_pe);
internal_putmem(&pSync[iter], &wait_val, sizeof(*pSync), send_pe);
wait_until(&pSync[iter], ROCSHMEM_CMP_EQ, wait_val);
}
@@ -350,7 +349,6 @@ __device__ void IPCContext::internal_ring_allreduce(
for (int i = wg_id; i < 2 * num_pes - 2; i += wg_size) {
pSync[i] = ROCSHMEM_SYNC_VALUE;
}
threadfence_system();
__syncthreads();
}
@@ -409,6 +407,7 @@ __device__ int IPCContext::reduce(rocshmem_team_t team, T *dest,
return ROCSHMEM_ERROR;
}
}
barrier_wg(team);
return ROCSHMEM_SUCCESS;
}
@@ -83,7 +83,7 @@ __global__ void TeamReductionTest(int loop, int skip, long long int *start_time,
__shared__ rocshmem_ctx_t ctx;
int wg_id = get_flat_grid_id();
rocshmem_wg_ctx_create(ctx_type, &ctx);
rocshmem_wg_team_create_ctx(team, ctx_type, &ctx);
int n_pes = rocshmem_ctx_n_pes(ctx);