From c9b5f03548146a5bf397fd160fbd1f0286a487cd Mon Sep 17 00:00:00 2001 From: Edgar Gabriel Date: Mon, 7 Oct 2024 20:49:45 +0000 Subject: [PATCH] ipc: add ring_allreduce algorithms add the ring allreduce algorithm to the ipc conduit in order to be able to execute slightly largers reductions. [ROCm/rocshmem commit: 1fbb89bc73a06fe4bc268539c34f36de5cdc3994] --- .../rocshmem/src/ipc/context_ipc_device.hpp | 10 +- .../src/ipc/context_ipc_tmpl_device.hpp | 114 +++++++++++++++++- 2 files changed, 115 insertions(+), 9 deletions(-) diff --git a/projects/rocshmem/src/ipc/context_ipc_device.hpp b/projects/rocshmem/src/ipc/context_ipc_device.hpp index 6bf8885d6c..dcb9571833 100644 --- a/projects/rocshmem/src/ipc/context_ipc_device.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_device.hpp @@ -241,9 +241,13 @@ class IPCContext : public Context { template __device__ void internal_direct_allreduce(T *dst, const T *src, - int nelems, int PE_start, int - logPE_stride, int PE_size, - T *pWrk, long *pSync); + int nelems, int PE_start, int logPE_stride, + int PE_size, T *pWrk, long *pSync); + template + __device__ void internal_ring_allreduce(T *dst, const T *src, + int nelems, int PE_start, int logPE_stride, + int PE_size, T *pWrk, long *pSync, + int n_seg, int seg_size, int chunk_size); //internal functions used by collectives routines to write/read to //work/sync buffers diff --git a/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp b/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp index 33bb91a1f1..81e669d942 100644 --- a/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp +++ b/projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp @@ -219,9 +219,78 @@ __device__ void IPCContext::internal_direct_allreduce( __syncthreads(); } +template +__device__ void IPCContext::internal_ring_allreduce( + T *dst, const T *src, int nelems, [[maybe_unused]] int PE_start, + [[maybe_unused]] int logPE_stride, [[maybe_unused]] int PE_size, T *pWrk, + long *pSync, // NOLINT(runtime/int) + int n_seg, int seg_size, int chunk_size) { + int off_seg, off_send, off_recv; + int send_pe = (my_pe + 1) % num_pes; + long wait_val; // NOLINT(runtime/int) + + int wg_size = get_flat_block_size(); + int wg_id = get_flat_block_id(); + + for (size_t i = wg_id; i < nelems; i += wg_size) { + dst[i] = src[i]; + } + __syncthreads(); + + for (size_t seg = 0; seg < n_seg; seg++) { + off_seg = seg * seg_size; + for (int iter = 0; iter < num_pes - 1; iter++) { + off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size); + off_recv = (((my_pe - iter + 2 * num_pes) % num_pes) * chunk_size); + + putmem_nbi_wg(reinterpret_cast(&pWrk[off_send]), + reinterpret_cast(&dst[off_send + off_seg]), + chunk_size * sizeof(T), send_pe); + + if (is_thread_zero_in_block()) { + fence(); + + wait_val = seg + 100; + put_nbi(&pSync[iter], &wait_val, 1, send_pe); +#if defined(__gfx90a__) + __threadfence_system(); +#endif /* __gfx90a__ */ + wait_until(&pSync[iter], ROC_SHMEM_CMP_EQ, wait_val); + } + __syncthreads(); + compute_reduce(&pWrk[off_recv], &dst[off_seg + off_recv], + chunk_size, wg_id, wg_size); + } + + for (size_t iter = num_pes - 1; iter < 2 * num_pes - 2; iter++) { + off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size); + putmem_nbi_wg(reinterpret_cast(&dst[off_send + off_seg]), + reinterpret_cast(&dst[off_send + off_seg]), + chunk_size * sizeof(T), send_pe); + + if (is_thread_zero_in_block()) { + fence(); + wait_val = seg + 100; + put_nbi(&pSync[iter], &wait_val, 1, send_pe); +#if defined(__gfx90a__) + __threadfence_system(); +#endif /* __gfx90a__ */ + wait_until(&pSync[iter], ROC_SHMEM_CMP_EQ, wait_val); + } + __syncthreads(); + } + } + __syncthreads(); + + for (size_t i = wg_id; i < 2 * num_pes - 2; i += wg_size) { + pSync[i] = ROC_SHMEM_SYNC_VALUE; + } + __syncthreads(); +} + template __device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest, - const T *source, int nreduce) { + const T *source, int nreduce) { IPCTeam *team_obj = reinterpret_cast(team); /** @@ -239,11 +308,44 @@ __device__ void IPCContext::to_all(roc_shmem_team_t team, T *dest, template __device__ void IPCContext::to_all(T *dest, const T *source, int nreduce, - int PE_start, int logPE_stride, - int PE_size, T *pWrk, - long *pSync) { // NOLINT(runtime/int) - internal_direct_allreduce(dest, source, nreduce, PE_start, logPE_stride, - PE_size, pWrk, pSync); + int PE_start, int logPE_stride, + int PE_size, T *pWrk, + long *pSync) { // NOLINT(runtime/int) + size_t direct_pWrk = num_pes * nreduce; + size_t direct_pSync = num_pes; + size_t ring_pSync = 2 * num_pes; + size_t provided_pWrk = max(nreduce / 2 + 1, ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE); + size_t provided_pSync = ROC_SHMEM_REDUCE_SYNC_SIZE; + + if (provided_pWrk >= direct_pWrk && provided_pSync >= direct_pSync) { + internal_direct_allreduce(dest, source, nreduce, PE_start, logPE_stride, + PE_size, pWrk, pSync); + } else { + /* TODO (Edgar): some nreduce values cannot be evenly divided + ** among num_pes and/or segments. THe algorithm + ** currently cannot handle that one segment is of + ** different size than other segments. + */ + if (ring_pSync <= ROC_SHMEM_REDUCE_SYNC_SIZE) { + size_t ring_pWrk = ROC_SHMEM_REDUCE_MIN_WRKDATA_SIZE; + // integer division truncating value + int chunk_size = ring_pWrk / num_pes; + int seg_size = chunk_size * num_pes; + + // integer division rounding up + int n_seg = (nreduce + (seg_size -1)) / seg_size; + // recalculate chunk_size + chunk_size = seg_size / num_pes; + if (n_seg == 0) { + n_seg = 1; + } + internal_ring_allreduce(dest, source, nreduce, PE_start, + logPE_stride, PE_size, pWrk, pSync, n_seg, + seg_size, chunk_size); + } else { + GPU_DPRINTF("Unsupported reduction size for IPC conduit.\n"); + } + } } template