Comhaid
rocm-systems/src/device/reduce_scatter.h
T
Sylvain Jeaugey b6475625fb 2.20.3-1
Add support for alternating rings, allow for cross-nic rings without
cross-rail communication.
Add support for user buffer registration for network send/recv.
Optimize aggregated operations to better utilize all channels.
Add flattening for BCM PCI gen5 switches.
Add support for inter-node NVLink communication
Add support for port fusion in NET/IB.
Add support for ReduceScatter and AllGather using Collnet.
Update net API to v8.
Fix hang during A2A connection.
2024-02-13 04:22:38 -08:00

299 línte
12 KiB
C++

/*************************************************************************
* Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "device.h"
#include "collectives.h"
#include "primitives.h"
namespace {
  // Ring ReduceScatter for a single work element.
  //
  // The element range [args->workOffset, args->workOffset + args->workCount)
  // is processed in pieces of at most args->chunkCount elements. For each
  // piece the primitives are issued in this order:
  //   1. send            using the slot of ring->userRanks[nranks-1]
  //   2. recvReduceSend  using the slots of userRanks[nranks-2] .. userRanks[1]
  //   3. recvReduceCopy  using the slot of userRanks[0], writing the result at
  //                      the piece's own offset, with postOp enabled
  // where "slot of rank r" is the input offset (piece offset + r * args->count).
  template<typename T, typename RedOp, typename Proto>
  __device__ __forceinline__ void runRing(ncclWorkElem *args) {
    const int tid = threadIdx.x;
    const uint32_t nthreads = (uint32_t)args->nWarps * WARP_SIZE;
    ncclRing *ring = &ncclShmem.channel.ring;
    int const *userRanks = ring->userRanks;
    const int nranks = ncclShmem.comm.nRanks;
    const size_t chunkCount = args->chunkCount;
    const size_t workCount = args->workCount;
    const size_t workOffset = args->workOffset;
    const size_t perRankCount = args->count;

    Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
      prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg);

    for (size_t done = 0; done < workCount; done += chunkCount) {
      // The last piece may be shorter than a full chunk.
      const uint32_t nelem = min(chunkCount, workCount - done);
      const size_t pieceOffset = workOffset + done;
      size_t slotOffset;

      // First step: push our contribution for the last rank in ring order.
      int ringIdx = nranks - 1;
      slotOffset = pieceOffset + userRanks[ringIdx] * perRankCount;
      prims.send(slotOffset, nelem);

      // Middle steps: accumulate the incoming data and forward it, walking
      // the ring order backwards down to (and including) index 1.
      for (--ringIdx; ringIdx > 0; --ringIdx) {
        slotOffset = pieceOffset + userRanks[ringIdx] * perRankCount;
        prims.recvReduceSend(slotOffset, nelem);
      }

      // Final step: reduce the incoming data with the slot of userRanks[0]
      // and store the finished result at the piece offset.
      slotOffset = pieceOffset + userRanks[0] * perRankCount;
      prims.recvReduceCopy(slotOffset, pieceOffset, nelem, /*postOp=*/true);
    }
  }
}
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
  // Ring ReduceScatter using ProtoSimple; the protocol's two leading template
  // arguments are derived from the REDUCESCATTER_* step constants.
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    constexpr int kStepsPerSlice = REDUCESCATTER_SLICESTEPS;
    constexpr int kSlicesPerChunk = REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS;
    runRing<T, RedOp, ProtoSimple<kSlicesPerChunk, kStepsPerSlice>>(args);
  }
};
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
  // Ring ReduceScatter using the LL protocol; forwards to the shared ring loop.
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    using Proto = ProtoLL;
    runRing<T, RedOp, Proto>(args);
  }
};
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
  // Ring ReduceScatter using the LL128 protocol; forwards to the shared ring loop.
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    using Proto = ProtoLL128;
    runRing<T, RedOp, Proto>(args);
  }
};
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPLE> {
  // NVLS ReduceScatter.
  //
  // The block's threads are divided into a "scatter" group (lowest thread
  // ids) followed by a "reduce" group; threads past both groups do nothing.
  // Two variants exist, selected by args->regUsed:
  //   - regUsed set:   the scatter group is one warp issuing zero-sized
  //                    scatter calls plus a final gather used as a sync; the
  //                    reduce group uses directRecvCopy and a final
  //                    zero-sized send used as a sync.
  //   - regUsed clear: the scatter group scatters from args->sendbuff to
  //                    nvls->up; the reduce group receives from nvls->down
  //                    into args->recvbuff.
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    const int tid = threadIdx.x;
    struct ncclNvls* nvls = &ncclShmem.channel.nvls;
    const int rank = ncclShmem.comm.rank;
    const int nranks = ncclShmem.comm.nRanks;
    const size_t chunkCount = args->chunkCount;
    const size_t count = args->count;
    const size_t workOffset = args->workOffset;
    const size_t workCount = args->workCount;

    // Thread budget for each group. With registered buffers only one warp is
    // kept for the scatter side; otherwise 7 warps (nranks <= 6) or 5 warps
    // reduce and the remainder of NCCL_MAX_NTHREADS scatters.
    int nThreadsReduce;
    int nThreadsScatter;
    if (args->regUsed) {
      nThreadsScatter = WARP_SIZE;
      nThreadsReduce = NCCL_MAX_NTHREADS - WARP_SIZE;
    } else {
      nThreadsReduce = nranks > 6 ? 5 * WARP_SIZE : 7 * WARP_SIZE;
      nThreadsScatter = NCCL_MAX_NTHREADS - nThreadsReduce;
    }
    const int scatterEnd = nThreadsScatter;
    const int reduceEnd = scatterEnd + nThreadsReduce;

    if (args->regUsed) {
      if (tid < scatterEnd) {
        // Scatter group: no payload is moved, the calls only pace the peers.
        using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
        Primitives<T, RedOp, FanSymmetric<NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
          prims(tid, nThreadsScatter, nvls->up, nvls->up, NULL, NULL,
            args->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
        for (size_t done = 0; done < workCount; done += chunkCount) {
          prims.scatter(0, 0, 0, 0, -1, 0);
        }
        // Trailing gather acts as a synchronization point.
        prims.gather(0, 0, 0, 0, -1, 0);
      } else if (tid < reduceEnd) {
        // Reduce group: pull this rank's slice through NVLS into recvbuff.
        using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
        Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/1, Proto, 0>
          prims(tid - scatterEnd, nThreadsReduce, &nvls->down, &nvls->down, NULL, args->recvbuff,
            args->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0, args);
        for (size_t done = 0; done < workCount; done += chunkCount) {
          const size_t outOffset = workOffset + done;
          const size_t inpOffset = outOffset + rank * count;
          const int nelem = min(chunkCount, workCount - done);
          prims.directRecvCopy(inpOffset, outOffset, nelem);
        }
        // Trailing zero-sized send acts as a synchronization point.
        prims.send(0, 0);
      }
    } else {
      if (tid < scatterEnd) {
        // Scatter group: distribute the send buffer to the nvls->up peers.
        using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
        Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
          prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL,
            args->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
        for (size_t done = 0; done < workCount; done += chunkCount) {
          const size_t offset = workOffset + done;
          const int nelem = min(chunkCount, workCount - done);
          prims.scatter(offset, nvls->nHeads * count, nelem, count, -1, 0);
        }
      } else if (tid < reduceEnd) {
        // Reduce group: receive the reduced data from nvls->down.
        using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
        Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0>
          prims(tid - scatterEnd, nThreadsReduce, &nvls->down, NULL, NULL, args->recvbuff,
            args->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0);
        for (size_t done = 0; done < workCount; done += chunkCount) {
          const size_t offset = workOffset + done;
          const int nelem = min(chunkCount, workCount - done);
          prims.recv(offset, nelem);
        }
      }
    }
  }
};
// ReduceScatter specialization for the CollNet-direct algorithm with the
// Simple protocol. run() splits the block's warps into three groups, each
// driving its own Primitives instance (see the "Phase" comments below).
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCCL_PROTO_SIMPLE> {
// Functor handed to prims.process() by phases 1 and 2 of run().
// ReduceSendNotRecv=true  : starts at the rail after direct->headRank and
//                           loops over nRails-1 destinations.
// ReduceSendNotRecv=false : handles only rail direct->headRank, with a
//                           single destination (dst stays 0).
template<bool ReduceSendNotRecv>
struct Scatterer {
// Work element being executed; supplies bid, sendbuff, count and redOpArg.
struct ncclWorkElem* args;
// Number of elements one block covers per call (set from args->chunkCount in run()).
int chunkSize;
// Start of the current grid-wide step; run() advances it by nChannels*chunkSize.
ssize_t railGridOffset;
// Called with the source/destination pointer arrays for one slice.
// NOTE(review): the exact meaning of tid/tn/slice/maxSliceSize and of the
// srcPtrs/dstPtrs/dstSizes arrays is defined by Primitives::process, which
// is not visible here — confirm against primitives.h.
template<int SlicePerChunk, int MinSrcs, int MaxSrcs, int MinDsts, int MaxDsts>
__device__ __forceinline__ void operator()(
int tid, int tn, int slice, int maxSliceSize,
int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs, int32_t* dstSizes
) {
static_assert(SlicePerChunk==1, "require: SlicePerChunk==1");
static_assert(MaxDsts<=1 || MaxSrcs<=1, "require: MaxDsts<=1 || MaxSrcs<=1");
struct ncclDirect* direct = &ncclShmem.channel.collnetDirect;
int nNodes = ncclShmem.comm.nNodes;
int nRails = direct->nHeads;
int bid = args->bid;
void* inbuf = (void*)args->sendbuff;
ssize_t sizePerRank = args->count;
// [railAllBeg, railAllEnd) is this block's window, clamped to the total
// of nNodes*sizePerRank elements.
ssize_t railAllBeg = min(railGridOffset + bid*chunkSize, nNodes*sizePerRank);
ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*sizePerRank);
int railAllSize = railAllEnd - railAllBeg;
// The first nDsts threads each publish the window size for one destination.
if (tid < nDsts) dstSizes[tid] = railAllSize;
int dst = 0;
int rail;
// Pick the first rail to handle (wraps around at nRails).
if (!ReduceSendNotRecv) {
rail = direct->headRank;
} else {
rail = direct->headRank+1;
if (rail == nRails) rail = 0;
}
do {
// Walk the window node by node: each node owns a contiguous segment of
// sizePerRank elements, and the window may straddle several segments.
int node = railAllBeg/sizePerRank;
int railAllOffset = 0;
while (railAllOffset < railAllSize) {
ssize_t railOneBeg = node*sizePerRank;
ssize_t railOneEnd = railOneBeg + sizePerRank;
ssize_t railOneOffset = (railAllBeg+railAllOffset) - railOneBeg;
// Elements left in the window that still belong to this node's segment.
int delta = min(railAllEnd, railOneEnd) - (railAllBeg+railAllOffset);
// Translate (node, rail) to a user rank to locate that rank's part of the input.
int rank = ncclShmem.comm.collNetDenseToUserRank[node*nRails + rail];
ssize_t userOneBeg = rank*sizePerRank + railOneOffset;
// Source 0 is the local input at userOneBeg; sources 1..nSrcs come from
// srcPtrs at the window-relative offset. The single destination is
// dstPtrs[dst] at the same window-relative offset.
reduceCopy<ncclCollUnroll(), RedOp, T,
/*MultimemSrcs=*/0, 1+MinSrcs, 1+MaxSrcs,
/*MultimemDsts,MinDsts,MaxDsts=*/0,1,1,
/*PreOpSrcs=*/1>
(tid, tn, args->redOpArg, &args->redOpArg, false,
/*nSrcs=*/1+nSrcs, [=]__device__(int s) {
return s==0 ? (T*)inbuf + userOneBeg
: (T*)srcPtrs[s-1] + railAllOffset;
},
/*nDsts=*/1, [=]__device__(int d/*==0*/) {
return (T*)dstPtrs[dst] + railAllOffset;
},
delta);
railAllOffset += delta;
node += 1;
}
// Advance to the next destination and the next rail (wrapping at nRails).
dst += 1;
rail += 1;
if (rail == nRails) rail = 0;
// Only the ReduceSendNotRecv variant iterates over further destinations.
} while (ReduceSendNotRecv && dst < nRails-1);
}
};
// Entry point: partitions the warps into three consecutive thread groups and
// runs one phase per group. Each group returns after finishing its loop.
__device__ __forceinline__ void run(ncclWorkElem *args) {
int tid = threadIdx.x;
const int nChannels = args->nChannels;
struct ncclDirect* direct = &ncclShmem.channel.collnetDirect;
int const &nNodes = ncclShmem.comm.nNodes;
ssize_t chunkSize = int(args->chunkCount);
ssize_t sizePerRank = args->count;
// Abort the kernel when direct->out is -1 (no usable network peer index).
if (direct->out == -1) __trap();
bool isMultiRail = (direct->nHeads > 1);
// Initial group weights: 2:2:1 with multiple rails, 0:1:1 with one rail.
int nWarps1 = (isMultiRail ? 2 : 0);
int nWarps2 = (isMultiRail ? 2 : 1);
int nWarps3 = 1;
// Scale the weights to args->nWarps; groups 2 and 3 are truncated and
// group 1 takes whatever is left.
// NOTE(review): because of the truncation, group 1 can end up with a
// leftover warp even when its weight is 0 (single rail, odd nWarps) —
// confirm that configuration cannot occur or is harmless.
float denom = float(args->nWarps)/float(nWarps1+nWarps2+nWarps3);
nWarps3 = int(denom*nWarps3);
nWarps2 = int(denom*nWarps2);
nWarps1 = args->nWarps - (nWarps2+nWarps3);
using Proto = ProtoSimple<1, 1>;
int tn = nWarps1*WARP_SIZE;
if (tid < tn) {
// Phase 1: Scatter inputs to peers
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, tn, nullptr, direct->heads+1, nullptr, nullptr,
args->redOpArg, 0*Proto::MaxGroupWidth, 1, 1);
// Each iteration covers nChannels*chunkSize elements across all blocks.
for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) {
Scatterer</*ReduceSendNotRecv=*/true> scat;
scat.args = args;
scat.chunkSize = chunkSize;
scat.railGridOffset = railGridOffset;
prims.process</*Recv=*/0, /*Send=*/1>(scat);
}
return;
}
// Re-base tid so the next group is numbered from 0.
tid -= tn;
tn = nWarps2*WARP_SIZE;
if (tid < tn) {
// Phase 2: Reduce from peers + local input -> send to network
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 1>, /*Direct=*/0, Proto, 0>
prims(tid, tn, direct->heads+1, &direct->out, nullptr, nullptr,
args->redOpArg, 1*Proto::MaxGroupWidth, 1, 1);
for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) {
Scatterer</*ReduceSendNotRecv=*/false> scat;
scat.args = args;
scat.chunkSize = chunkSize;
scat.railGridOffset = railGridOffset;
prims.process</*Recv=*/1, /*Send=*/1>(scat);
}
return;
}
// Re-base tid again for the last group.
tid -= tn;
tn = nWarps3*WARP_SIZE;
if (tid < tn) {
// Phase 3: recv from network
Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0>
prims(tid, tn, &direct->out, nullptr, nullptr, args->recvbuff,
args->redOpArg, 2*Proto::MaxGroupWidth, 0, 0);
for (ssize_t railGridOffset=0; railGridOffset < nNodes*sizePerRank; railGridOffset += nChannels*chunkSize) {
// Intersect this block's window [railAllBeg, railAllEnd) with the local
// node's segment [railOneBeg, railOneEnd); only that overlap is received.
ssize_t railAllBeg = railGridOffset + args->bid*chunkSize;
ssize_t railAllEnd = min(railAllBeg + chunkSize, nNodes*sizePerRank);
ssize_t railOneBeg = ncclShmem.comm.node*sizePerRank;
ssize_t railOneEnd = railOneBeg + sizePerRank;
ssize_t beg = max(railAllBeg, railOneBeg);
ssize_t end = min(railAllEnd, railOneEnd);
// recv is still called when the overlap is empty; the count is clamped to 0.
prims.recv(beg-railOneBeg, max(ssize_t(0), end-beg), /*postOp=*/true);
}
return;
}
}
};