add ascii art for ring allredude
This commit is contained in:
@@ -219,6 +219,62 @@ __device__ void IPCContext::internal_direct_allreduce(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/*
|
||||
* Visual representation of the ring_allreduce algorithm below
|
||||
* assuming 4 PEs and a single segment.
|
||||
*
|
||||
* Initial state
|
||||
* PE# 0 1 2 3
|
||||
* [00] [10] [20] [30]
|
||||
* [01] [11] [21] [31]
|
||||
* [02] [12] [22] [32]
|
||||
* [03] [13] [23] [33]
|
||||
*
|
||||
* Loop 1:
|
||||
* iter 0
|
||||
* PE# 0 1 2 3
|
||||
* [00+30] [10] [20] [30]
|
||||
* [01] [01+11] [21] [31]
|
||||
* [02] [12] [12+22] [32]
|
||||
* [03] [13] [23] [23+33]
|
||||
*
|
||||
* iter 1
|
||||
* PE# 0 1 2 3
|
||||
* [00+30] [00+10+30] [20] [30]
|
||||
* [01] [01+11] [01+11+21] [31]
|
||||
* [02] [12] [12+22] [12+22+32]
|
||||
* [03+23+33] [13] [23] [23+33]
|
||||
*
|
||||
* iter 2
|
||||
* PE# 0 1 2 3
|
||||
* [00+30] [00+10+30] [00+10+20+30] [30]
|
||||
* [01] [01+11] [01+11+21] [01+11+21+31]
|
||||
* [02+12+22+32] [12] [12+22] [12+22+32]
|
||||
* [03+23+33] [03+13+23+33] [23] [23+33]
|
||||
*
|
||||
* Loop 2:
|
||||
*
|
||||
* iter 3
|
||||
* PE# 0 1 2 3
|
||||
* [00+30] [00+10+30] [00+10+20+30] [00+10+20+30]
|
||||
* [01+11+21+31] [01+11] [01+11+21] [01+11+21+31]
|
||||
* [02+12+22+32] [02+12+22+32] [12+22] [12+22+32]
|
||||
* [03+23+33] [03+13+23+33] [03+13+23+33] [23+33]
|
||||
*
|
||||
* iter 4
|
||||
* PE# 0 1 2 3
|
||||
* [00+10+20+30] [00+10+30] [00+10+20+30] [00+10+20+30]
|
||||
* [01+11+21+31] [01+11+21+31] [01+11+21] [01+11+21+31]
|
||||
* [02+12+22+32] [02+12+22+32] [02+12+22+32] [12+22+32]
|
||||
* [03+23+33] [03+13+23+33] [03+13+23+33] [03+13+23+33]
|
||||
*
|
||||
* iter 5
|
||||
* PE# 0 1 2 3
|
||||
* [00+10+20+30] [00+10+20+30] [00+10+20+30] [00+10+20+30]
|
||||
* [01+11+21+31] [01+11+21+31] [01+11+21+31] [01+11+21+31]
|
||||
* [02+12+22+32] [02+12+22+32] [02+12+22+32] [02+12+22+32]
|
||||
* [03+13+23+33] [03+13+23+33] [03+13+23+33] [03+13+23+33]
|
||||
*/
|
||||
template <typename T, ROC_SHMEM_OP Op>
|
||||
__device__ void IPCContext::internal_ring_allreduce(
|
||||
T *dst, const T *src, int nelems, [[maybe_unused]] int PE_start,
|
||||
@@ -239,6 +295,7 @@ __device__ void IPCContext::internal_ring_allreduce(
|
||||
|
||||
for (size_t seg = 0; seg < n_seg; seg++) {
|
||||
off_seg = seg * seg_size;
|
||||
// Loop 2 in the algorithm above
|
||||
for (int iter = 0; iter < num_pes - 1; iter++) {
|
||||
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
|
||||
off_recv = (((my_pe - iter + 2 * num_pes) % num_pes) * chunk_size);
|
||||
@@ -262,6 +319,7 @@ __device__ void IPCContext::internal_ring_allreduce(
|
||||
chunk_size, wg_id, wg_size);
|
||||
}
|
||||
|
||||
// Loop 2 in the example above
|
||||
for (size_t iter = num_pes - 1; iter < 2 * num_pes - 2; iter++) {
|
||||
off_send = (((my_pe + 1 - iter + 2 * num_pes) % num_pes) * chunk_size);
|
||||
putmem_nbi_wg(reinterpret_cast<void *>(&dst[off_send + off_seg]),
|
||||
|
||||
مرجع در شماره جدید
Block a user