ipc: add ring_allreduce algorithms

Add the ring allreduce algorithm to the IPC conduit in order to be able
to execute slightly larger reductions.


[ROCm/rocshmem commit: 1fbb89bc73]
This commit is contained in:
Edgar Gabriel
2024-10-07 20:49:45 +00:00
rodzic 5f0f2f6e85
commit c9b5f03548
2 zmienionych plików z 115 dodań i 9 usunięć
@@ -241,9 +241,13 @@ class IPCContext : public Context {
/**
 * Declaration of the direct allreduce variant. The buffer-based to_all()
 * overload calls it when its size check finds pWrk/pSync large enough for
 * num_pes * nreduce work elements and num_pes sync entries.
 */
template <typename T, ROC_SHMEM_OP Op>
__device__ void internal_direct_allreduce(T *dst, const T *src,
int nelems, int PE_start, int
logPE_stride, int PE_size,
T *pWrk, long *pSync);
int nelems, int PE_start, int logPE_stride,
int PE_size, T *pWrk, long *pSync);
/**
 * Declaration of the ring allreduce variant.
 *
 * dst/src are the result and input buffers of nelems elements; pWrk is the
 * staging buffer chunks are pushed into and pSync holds the per-step flags.
 * The data is processed in n_seg segments of seg_size elements, each
 * segment being exchanged in pieces of chunk_size elements.
 * PE_start, logPE_stride and PE_size are marked [[maybe_unused]] in the
 * definition.
 */
template <typename T, ROC_SHMEM_OP Op>
__device__ void internal_ring_allreduce(T *dst, const T *src,
int nelems, int PE_start, int logPE_stride,
int PE_size, T *pWrk, long *pSync,
int n_seg, int seg_size, int chunk_size);
// Internal functions used by collective routines to write/read to
// work/sync buffers
@@ -219,9 +219,78 @@ __device__ void IPCContext::internal_direct_allreduce(
__syncthreads();
}
/**
 * Ring allreduce, processed one segment at a time.
 *
 * dst is first seeded with a copy of src. Every segment is then handled in
 * two phases of (num_pes - 1) steps each:
 *   1. reduce-scatter: one chunk of dst is pushed into pWrk on the next PE
 *      of the ring, (my_pe + 1) % num_pes, and the chunk found in the local
 *      pWrk is combined into dst with compute_reduce<T, Op>.
 *   2. allgather: one chunk of dst is pushed straight into dst on the next
 *      PE.
 * Each step uses its own flag pSync[step]: thread zero stores (seg + 100)
 * into that slot on the next PE and then waits until the local slot holds
 * the same value. The 2 * num_pes - 2 slots are reset to
 * ROC_SHMEM_SYNC_VALUE before returning.
 *
 * NOTE(review): the loops address dst at offsets up to
 * (n_seg - 1) * seg_size + num_pes * chunk_size, independent of nelems, so
 * the caller presumably has to guarantee that this range is valid (i.e.
 * nelems divides evenly into segments and chunks) -- TODO confirm.
 *
 * @param dst         Result buffer; also used as the send/receive buffer.
 * @param src         Input buffer of nelems elements.
 * @param nelems      Number of elements copied from src to dst.
 * @param PE_start    Unused.
 * @param logPE_stride Unused.
 * @param PE_size     Unused.
 * @param pWrk        Staging buffer that reduce-scatter chunks are put into.
 * @param pSync       Flag array; slots [0, 2 * num_pes - 2) are used.
 * @param n_seg       Number of segments to process.
 * @param seg_size    Element offset between consecutive segments.
 * @param chunk_size  Number of elements transferred per step.
 */
template <typename T, ROC_SHMEM_OP Op>
__device__ void IPCContext::internal_ring_allreduce(
    T *dst, const T *src, int nelems, [[maybe_unused]] int PE_start,
    [[maybe_unused]] int logPE_stride, [[maybe_unused]] int PE_size, T *pWrk,
    long *pSync,  // NOLINT(runtime/int)
    int n_seg, int seg_size, int chunk_size) {
  const int send_pe = (my_pe + 1) % num_pes;
  const int wg_size = get_flat_block_size();
  const int wg_id = get_flat_block_id();
  // One pSync slot per ring step: (num_pes - 1) for each of the two phases.
  const int n_steps = 2 * num_pes - 2;

  // Clamp before widening: a negative nelems must not turn into a huge
  // unsigned loop bound.
  const size_t n_copy = nelems > 0 ? static_cast<size_t>(nelems) : 0;
  for (size_t i = wg_id; i < n_copy; i += wg_size) {
    dst[i] = src[i];
  }
  __syncthreads();

  for (int seg = 0; seg < n_seg; seg++) {
    const int off_seg = seg * seg_size;
    // Flag value for this segment; it differs per segment so a value left
    // over from the previous segment cannot satisfy the wait.
    long wait_val = seg + 100;  // NOLINT(runtime/int)

    // Phase 1: reduce-scatter.
    for (int iter = 0; iter < num_pes - 1; iter++) {
      // "+ 2 * num_pes" keeps the dividend positive for every iter.
      const int off_send =
          ((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size;
      const int off_recv =
          ((my_pe - iter + 2 * num_pes) % num_pes) * chunk_size;

      putmem_nbi_wg(reinterpret_cast<void *>(&pWrk[off_send]),
                    reinterpret_cast<void *>(&dst[off_send + off_seg]),
                    chunk_size * sizeof(T), send_pe);

      if (is_thread_zero_in_block()) {
        // Presumably orders the chunk transfer before the flag store.
        fence();
        put_nbi(&pSync[iter], &wait_val, 1, send_pe);
#if defined(__gfx90a__)
        __threadfence_system();
#endif /* __gfx90a__ */
        // The local slot is presumably written by the preceding PE running
        // this same step.
        wait_until(&pSync[iter], ROC_SHMEM_CMP_EQ, wait_val);
      }
      __syncthreads();

      compute_reduce<T, Op>(&pWrk[off_recv], &dst[off_seg + off_recv],
                            chunk_size, wg_id, wg_size);
    }

    // Phase 2: allgather; chunks land directly in dst on the next PE.
    for (int iter = num_pes - 1; iter < n_steps; iter++) {
      const int off_send =
          ((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size;

      putmem_nbi_wg(reinterpret_cast<void *>(&dst[off_send + off_seg]),
                    reinterpret_cast<void *>(&dst[off_send + off_seg]),
                    chunk_size * sizeof(T), send_pe);

      if (is_thread_zero_in_block()) {
        fence();
        put_nbi(&pSync[iter], &wait_val, 1, send_pe);
#if defined(__gfx90a__)
        __threadfence_system();
#endif /* __gfx90a__ */
        wait_until(&pSync[iter], ROC_SHMEM_CMP_EQ, wait_val);
      }
      __syncthreads();
    }
  }
  __syncthreads();

  // Restore every flag that was used so the sync array can be reused.
  for (int i = wg_id; i < n_steps; i += wg_size) {
    pSync[i] = ROC_SHMEM_SYNC_VALUE;
  }
  __syncthreads();
}
template <typename T, ROC_SHMEM_OP Op>
__device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest,
const T *source, int nreduce) {
const T *source, int nreduce) {
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
/**
@@ -239,11 +308,44 @@ __device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest,
/**
 * Allreduce entry point taking explicit work (pWrk) and sync (pSync) buffers.
 *
 * Picks an algorithm by comparing the buffer sizes each one is taken to need
 * against the sizes the buffers are assumed to have:
 *   - internal_direct_allreduce when num_pes * nreduce work elements and
 *     num_pes sync entries fit,
 *   - otherwise internal_ring_allreduce when 2 * num_pes sync entries fit,
 *   - otherwise only a debug message is printed.
 */
template <typename T, ROC_SHMEM_OP Op>
__device__ void IPCContext::to_all(T *dest, const T *source, int nreduce,
int PE_start, int logPE_stride,
int PE_size, T *pWrk,
long *pSync) { // NOLINT(runtime/int)
internal_direct_allreduce<T, Op>(dest, source, nreduce, PE_start, logPE_stride,
PE_size, pWrk, pSync);
int PE_start, int logPE_stride,
int PE_size, T *pWrk,
long *pSync) { // NOLINT(runtime/int)
// Sizes each algorithm is taken to require (presumably counted in elements).
// NOTE(review): num_pes * nreduce is evaluated in int before widening.
size_t direct_pWrk = num_pes * nreduce;
size_t direct_pSync = num_pes;
size_t ring_pSync = 2 * num_pes;
// Sizes the buffers are assumed to have; these come from nreduce and the
// library constants only, not from the pWrk/pSync arguments themselves.
size_t provided_pWrk = max(nreduce / 2 + 1, ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE);
size_t provided_pSync = ROC_SHMEM_REDUCE_SYNC_SIZE;
if (provided_pWrk >= direct_pWrk && provided_pSync >= direct_pSync) {
internal_direct_allreduce<T, Op>(dest, source, nreduce, PE_start, logPE_stride,
PE_size, pWrk, pSync);
} else {
/* TODO (Edgar): some nreduce values cannot be evenly divided
** among num_pes and/or segments. The algorithm
** currently cannot handle that one segment is of
** different size than other segments.
*/
if (ring_pSync <= ROC_SHMEM_REDUCE_SYNC_SIZE) {
// The ring path sizes its staging area from the minimum work-buffer
// constant rather than from provided_pWrk.
size_t ring_pWrk = ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE;
// integer division truncating value
// NOTE(review): chunk_size is 0 when num_pes > ring_pWrk, which makes
// seg_size 0 and the division for n_seg below divide by zero -- confirm
// the pSync check above rules that out.
int chunk_size = ring_pWrk / num_pes;
// Largest multiple of num_pes that fits into ring_pWrk.
int seg_size = chunk_size * num_pes;
// integer division rounding up
int n_seg = (nreduce + (seg_size -1)) / seg_size;
// recalculate chunk_size
// (seg_size is chunk_size * num_pes, so this yields the same value again)
chunk_size = seg_size / num_pes;
// The rounding above gives 0 segments for nreduce == 0; run one anyway.
if (n_seg == 0) {
n_seg = 1;
}
internal_ring_allreduce<T, Op>(dest, source, nreduce, PE_start,
logPE_stride, PE_size, pWrk, pSync, n_seg,
seg_size, chunk_size);
} else {
// No reduction is performed on this path; only a debug message is printed.
GPU_DPRINTF("Unsupported reduction size for IPC conduit.\n");
}
}
}
template <typename T>