add ascii art for ring allreduce

This commit is contained in:
Edgar Gabriel
2024-10-23 19:23:53 +00:00
والد a4b4281f50
کامیت 11df5427a6
@@ -219,6 +219,62 @@ __device__ void IPCContext::internal_direct_allreduce(
__syncthreads();
}
/*
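* Visual representation of the ring_allreduce algorithm below,
* assuming 4 PEs and a single segment split into 4 chunks.
* [pc] denotes chunk c of the data originally held by PE p;
* each column shows the buffer contents of one PE.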
*
* Initial state
* PE# 0 1 2 3
* [00] [10] [20] [30]
* [01] [11] [21] [31]
* [02] [12] [22] [32]
* [03] [13] [23] [33]
*
* Loop 1:
* iter 0
* PE# 0 1 2 3
* [00+30] [10] [20] [30]
* [01] [01+11] [21] [31]
* [02] [12] [12+22] [32]
* [03] [13] [23] [23+33]
*
* iter 1
* PE# 0 1 2 3
* [00+30] [00+10+30] [20] [30]
* [01] [01+11] [01+11+21] [31]
* [02] [12] [12+22] [12+22+32]
* [03+23+33] [13] [23] [23+33]
*
* iter 2
* PE# 0 1 2 3
* [00+30] [00+10+30] [00+10+20+30] [30]
* [01] [01+11] [01+11+21] [01+11+21+31]
* [02+12+22+32] [12] [12+22] [12+22+32]
* [03+23+33] [03+13+23+33] [23] [23+33]
*
* Loop 2:
*
* iter 3
* PE# 0 1 2 3
* [00+30] [00+10+30] [00+10+20+30] [00+10+20+30]
* [01+11+21+31] [01+11] [01+11+21] [01+11+21+31]
* [02+12+22+32] [02+12+22+32] [12+22] [12+22+32]
* [03+23+33] [03+13+23+33] [03+13+23+33] [23+33]
*
* iter 4
* PE# 0 1 2 3
* [00+10+20+30] [00+10+30] [00+10+20+30] [00+10+20+30]
* [01+11+21+31] [01+11+21+31] [01+11+21] [01+11+21+31]
* [02+12+22+32] [02+12+22+32] [02+12+22+32] [12+22+32]
* [03+23+33] [03+13+23+33] [03+13+23+33] [03+13+23+33]
*
* iter 5
* PE# 0 1 2 3
* [00+10+20+30] [00+10+20+30] [00+10+20+30] [00+10+20+30]
* [01+11+21+31] [01+11+21+31] [01+11+21+31] [01+11+21+31]
* [02+12+22+32] [02+12+22+32] [02+12+22+32] [02+12+22+32]
* [03+13+23+33] [03+13+23+33] [03+13+23+33] [03+13+23+33]
*/
template <typename T, ROC_SHMEM_OP Op>
__device__ void IPCContext::internal_ring_allreduce(
T *dst, const T *src, int nelems, [[maybe_unused]] int PE_start,
@@ -239,6 +295,7 @@ __device__ void IPCContext::internal_ring_allreduce(
for (size_t seg = 0; seg < n_seg; seg++) {
off_seg = seg * seg_size;
// Loop 1 in the algorithm above
for (int iter = 0; iter < num_pes - 1; iter++) {
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
off_recv = (((my_pe - iter + 2 * num_pes) % num_pes) * chunk_size);
@@ -262,6 +319,7 @@ __device__ void IPCContext::internal_ring_allreduce(
chunk_size, wg_id, wg_size);
}
// Loop 2 in the example above
for (size_t iter = num_pes - 1; iter < 2 * num_pes - 2; iter++) {
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
putmem_nbi_wg(reinterpret_cast<void *>(&dst[off_send + off_seg]),