From 11df5427a6dfd5648d67226ad3757d0d709dcbf5 Mon Sep 17 00:00:00 2001 From: Edgar Gabriel Date: Wed, 23 Oct 2024 19:23:53 +0000 Subject: [PATCH] add ascii art for ring allredude --- src/ipc/context_ipc_tmpl_device.hpp | 58 +++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/src/ipc/context_ipc_tmpl_device.hpp b/src/ipc/context_ipc_tmpl_device.hpp index 57a9177c04..406367c72e 100644 --- a/src/ipc/context_ipc_tmpl_device.hpp +++ b/src/ipc/context_ipc_tmpl_device.hpp @@ -219,6 +219,62 @@ __device__ void IPCContext::internal_direct_allreduce( __syncthreads(); } +/* + * Visual representation of the ring_allreduce algorithm below + * assuming 4 PEs and a single segment. + * + * Initial state + * PE# 0 1 2 3 + * [00] [10] [20] [30] + * [01] [11] [21] [31] + * [02] [12] [22] [32] + * [03] [13] [23] [33] + * + * Loop 1: + * iter 0 + * PE# 0 1 2 3 + * [00+30] [10] [20] [30] + * [01] [01+11] [21] [31] + * [02] [12] [12+22] [32] + * [03] [13] [23] [23+33] + * + * iter 1 + * PE# 0 1 2 3 + * [00+30] [00+10+30] [20] [30] + * [01] [01+11] [01+11+21] [31] + * [02] [12] [12+22] [12+22+32] + * [03+23+33] [13] [23] [23+33] + * + * iter 2 + * PE# 0 1 2 3 + * [00+30] [00+10+30] [00+10+20+30] [30] + * [01] [01+11] [01+11+21] [01+11+21+31] + * [02+12+22+32] [12] [12+22] [12+22+32] + * [03+23+33] [03+13+23+33] [23] [23+33] + * + * Loop 2: + * + * iter 3 + * PE# 0 1 2 3 + * [00+30] [00+10+30] [00+10+20+30] [00+10+20+30] + * [01+11+21+31] [01+11] [01+11+21] [01+11+21+31] + * [02+12+22+32] [02+12+22+32] [12+22] [12+22+32] + * [03+23+33] [03+13+23+33] [03+13+23+33] [23+33] + * + * iter 4 + * PE# 0 1 2 3 + * [00+10+20+30] [00+10+30] [00+10+20+30] [00+10+20+30] + * [01+11+21+31] [01+11+21+31] [01+11+21] [01+11+21+31] + * [02+12+22+32] [02+12+22+32] [02+12+22+32] [12+22+32] + * [03+23+33] [03+13+23+33] [03+13+23+33] [03+13+23+33] + * + * iter 5 + * PE# 0 1 2 3 + * [00+10+20+30] [00+10+20+30] [00+10+20+30] [00+10+20+30] + * [01+11+21+31] [01+11+21+31] [01+11+21+31] [01+11+21+31] + * [02+12+22+32] [02+12+22+32] [02+12+22+32] [02+12+22+32] + * [03+13+23+33] [03+13+23+33] [03+13+23+33] [03+13+23+33] + */ template __device__ void IPCContext::internal_ring_allreduce( T *dst, const T *src, int nelems, [[maybe_unused]] int PE_start, @@ -239,6 +295,7 @@ __device__ void IPCContext::internal_ring_allreduce( for (size_t seg = 0; seg < n_seg; seg++) { off_seg = seg * seg_size; + // Loop 2 in the algorithm above for (int iter = 0; iter < num_pes - 1; iter++) { off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size); off_recv = (((my_pe - iter + 2 * num_pes) % num_pes) * chunk_size); @@ -262,6 +319,7 @@ __device__ void IPCContext::internal_ring_allreduce( chunk_size, wg_id, wg_size); } + // Loop 2 in the example above for (size_t iter = num_pes - 1; iter < 2 * num_pes - 2; iter++) { off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size); putmem_nbi_wg(reinterpret_cast(&dst[off_send + off_seg]),