ipc: add ring_allreduce algorithms
Add the ring allreduce algorithm to the IPC conduit so that it can
execute slightly larger reductions.
[ROCm/rocshmem commit: 1fbb89bc73]
This commit is contained in:
@@ -241,9 +241,13 @@ class IPCContext : public Context {
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void internal_direct_allreduce(T *dst, const T *src,
|
||||
int nelems, int PE_start, int
|
||||
logPE_stride, int PE_size,
|
||||
T *pWrk, long *pSync);
|
||||
int nelems, int PE_start, int logPE_stride,
|
||||
int PE_size, T *pWrk, long *pSync);
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void internal_ring_allreduce(T *dst, const T *src,
|
||||
int nelems, int PE_start, int logPE_stride,
|
||||
int PE_size, T *pWrk, long *pSync,
|
||||
int n_seg, int seg_size, int chunk_size);
|
||||
|
||||
//internal functions used by collective routines to write/read to
|
||||
//work/sync buffers
|
||||
|
||||
@@ -219,9 +219,78 @@ __device__ void IPCContext::internal_direct_allreduce(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::internal_ring_allreduce(
|
||||
T *dst, const T *src, int nelems, [[maybe_unused]] int PE_start,
|
||||
[[maybe_unused]] int logPE_stride, [[maybe_unused]] int PE_size, T *pWrk,
|
||||
long *pSync, // NOLINT(runtime/int)
|
||||
int n_seg, int seg_size, int chunk_size) {
|
||||
int off_seg, off_send, off_recv;
|
||||
int send_pe = (my_pe + 1) % num_pes;
|
||||
long wait_val; // NOLINT(runtime/int)
|
||||
|
||||
int wg_size = get_flat_block_size();
|
||||
int wg_id = get_flat_block_id();
|
||||
|
||||
for (size_t i = wg_id; i < nelems; i += wg_size) {
|
||||
dst[i] = src[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (size_t seg = 0; seg < n_seg; seg++) {
|
||||
off_seg = seg * seg_size;
|
||||
for (int iter = 0; iter < num_pes - 1; iter++) {
|
||||
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
|
||||
off_recv = (((my_pe - iter + 2 * num_pes) % num_pes) * chunk_size);
|
||||
|
||||
putmem_nbi_wg(reinterpret_cast<void *>(&pWrk[off_send]),
|
||||
reinterpret_cast<void *>(&dst[off_send + off_seg]),
|
||||
chunk_size * sizeof(T), send_pe);
|
||||
|
||||
if (is_thread_zero_in_block()) {
|
||||
fence();
|
||||
|
||||
wait_val = seg + 100;
|
||||
put_nbi(&pSync[iter], &wait_val, 1, send_pe);
|
||||
#if defined(__gfx90a__)
|
||||
__threadfence_system();
|
||||
#endif /* __gfx90a__ */
|
||||
wait_until(&pSync[iter], ROC_SHMEM_CMP_EQ, wait_val);
|
||||
}
|
||||
__syncthreads();
|
||||
compute_reduce<T, Op>(&pWrk[off_recv], &dst[off_seg + off_recv],
|
||||
chunk_size, wg_id, wg_size);
|
||||
}
|
||||
|
||||
for (size_t iter = num_pes - 1; iter < 2 * num_pes - 2; iter++) {
|
||||
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
|
||||
putmem_nbi_wg(reinterpret_cast<void *>(&dst[off_send + off_seg]),
|
||||
reinterpret_cast<void *>(&dst[off_send + off_seg]),
|
||||
chunk_size * sizeof(T), send_pe);
|
||||
|
||||
if (is_thread_zero_in_block()) {
|
||||
fence();
|
||||
wait_val = seg + 100;
|
||||
put_nbi(&pSync[iter], &wait_val, 1, send_pe);
|
||||
#if defined(__gfx90a__)
|
||||
__threadfence_system();
|
||||
#endif /* __gfx90a__ */
|
||||
wait_until(&pSync[iter], ROC_SHMEM_CMP_EQ, wait_val);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (size_t i = wg_id; i < 2 * num_pes - 2; i += wg_size) {
|
||||
pSync[i] = ROC_SHMEM_SYNC_VALUE;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest,
|
||||
const T *source, int nreduce) {
|
||||
const T *source, int nreduce) {
|
||||
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
|
||||
|
||||
/**
|
||||
@@ -239,11 +308,44 @@ __device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest,
|
||||
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::to_all(T *dest, const T *source, int nreduce,
|
||||
int PE_start, int logPE_stride,
|
||||
int PE_size, T *pWrk,
|
||||
long *pSync) { // NOLINT(runtime/int)
|
||||
internal_direct_allreduce<T, Op>(dest, source, nreduce, PE_start, logPE_stride,
|
||||
PE_size, pWrk, pSync);
|
||||
int PE_start, int logPE_stride,
|
||||
int PE_size, T *pWrk,
|
||||
long *pSync) { // NOLINT(runtime/int)
|
||||
size_t direct_pWrk = num_pes * nreduce;
|
||||
size_t direct_pSync = num_pes;
|
||||
size_t ring_pSync = 2 * num_pes;
|
||||
size_t provided_pWrk = max(nreduce / 2 + 1, ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE);
|
||||
size_t provided_pSync = ROC_SHMEM_REDUCE_SYNC_SIZE;
|
||||
|
||||
if (provided_pWrk >= direct_pWrk && provided_pSync >= direct_pSync) {
|
||||
internal_direct_allreduce<T, Op>(dest, source, nreduce, PE_start, logPE_stride,
|
||||
PE_size, pWrk, pSync);
|
||||
} else {
|
||||
/* TODO (Edgar): some nreduce values cannot be evenly divided
|
||||
** among num_pes and/or segments. THe algorithm
|
||||
** currently cannot handle that one segment is of
|
||||
** different size than other segments.
|
||||
*/
|
||||
if (ring_pSync <= ROC_SHMEM_REDUCE_SYNC_SIZE) {
|
||||
size_t ring_pWrk = ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE;
|
||||
// integer division truncating value
|
||||
int chunk_size = ring_pWrk / num_pes;
|
||||
int seg_size = chunk_size * num_pes;
|
||||
|
||||
// integer division rounding up
|
||||
int n_seg = (nreduce + (seg_size -1)) / seg_size;
|
||||
// recalculate chunk_size
|
||||
chunk_size = seg_size / num_pes;
|
||||
if (n_seg == 0) {
|
||||
n_seg = 1;
|
||||
}
|
||||
internal_ring_allreduce<T, Op>(dest, source, nreduce, PE_start,
|
||||
logPE_stride, PE_size, pWrk, pSync, n_seg,
|
||||
seg_size, chunk_size);
|
||||
} else {
|
||||
GPU_DPRINTF("Unsupported reduction size for IPC conduit.\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
Reference in New Issue
Block a user