531 rivejä
23 KiB
Diff
531 rivejä
23 KiB
Diff
diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
|
|
index 4134241..76674ba 100644
|
|
--- a/apps/nccl/src/allreduce.hpp
|
|
+++ b/apps/nccl/src/allreduce.hpp
|
|
@@ -495,24 +495,345 @@ __global__ void __launch_bounds__(512, 1)
|
|
}
|
|
}
|
|
|
|
+template <typename T>  // allreduce8Read: each rank sums its own 1/worldSize slice of `buff` with the values read from every peer channel, stores the sum in `resultBuff`, and writes it to every peer through smOutChannels.
|
|
+__global__ void __launch_bounds__(512, 1)
|
|
+ allreduce8Read(T* buff, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
|
|
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
|
|
+ int rank, int nRanksPerNode, int worldSize, size_t nelems) {
|
|
+ const int nPeer = nRanksPerNode - 1;  // every rank of the node except this one
|
|
+ const size_t chanOffset = nPeer * blockIdx.x;  // each block uses its own run of nPeer channel handles
|
|
+ // assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
|
|
+ const size_t nInt4 = nelems * sizeof(T) / sizeof(int4);  // payload size counted in int4 vectors
|
|
+ const size_t nInt4PerRank = nInt4 / worldSize;  // slice reduced by each rank
|
|
+ auto smChans = smChannels + chanOffset;
|
|
+ auto smOutChans = smOutChannels + chanOffset;
|
|
+
|
|
+ int4* buff4 = reinterpret_cast<int4*>(buff);
|
|
+ int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);
|
|
+
|
|
+ // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
|
|
+ constexpr size_t unitNInt4 = 512;
|
|
+ const size_t maxNInt4PerBlock =
|
|
+ (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;  // ceil(slice / gridDim.x), rounded up to a multiple of unitNInt4
|
|
+ size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
|
|
+ size_t nInt4OfThisBlock = maxNInt4PerBlock;
|
|
+ size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;  // blocks that actually receive data
|
|
+ constexpr size_t nInt4PerChunk = 1024 * 512 / sizeof(int4); // 512KB
|
|
+ if (blockIdx.x >= nNeededBlocks) {
|
|
+ nInt4OfThisBlock = 0;  // surplus block: no data, but it still runs the closing handshake at the end of the kernel
|
|
+ } else if (blockIdx.x == nNeededBlocks - 1) {
|
|
+ nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1);  // last needed block takes whatever is left
|
|
+ }
|
|
+
|
|
+ const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk;  // number of full chunks
|
|
+ const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk;  // tail shorter than one chunk
|
|
+
|
|
+ __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];  // per-block shared copies of the channel handles
|
|
+ __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> outChannels[NRANKS_PER_NODE - 1];
|
|
+ const int lid = threadIdx.x % WARP_SIZE;  // lane index inside the warp
|
|
+ if (lid < nPeer) {  // the first nPeer lanes of every warp store the same handles
|
|
+ channels[lid] = smChans[lid];
|
|
+ outChannels[lid] = smOutChans[lid];
|
|
+ }
|
|
+ __syncwarp();
|
|
+
|
|
+ // we can use double buffering to hide synchronization overhead
|
|
+ for (size_t itr = 0; itr < nItrs; itr++) {
|
|
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {  // one thread per peer performs the signal/wait handshake
|
|
+ channels[threadIdx.x].signal();
|
|
+ channels[threadIdx.x].wait();
|
|
+ }
|
|
+ __syncthreads();  // block-wide barrier: no thread reads before the handshake threads return
|
|
+
|
|
+ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
|
|
+ int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];  // this rank's own contribution
|
|
+ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
|
|
+ int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
|
|
+ data = add_vectors<T>(val, data);
|
|
+ }
|
|
+ resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
|
|
+
|
|
+ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
|
|
+ outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
|
|
+ data);  // channelOutDataOffset is divided by sizeof(int4), i.e. it is treated as a byte offset
|
|
+ }
|
|
+ }
|
|
+ __syncthreads();
|
|
+
|
|
+ offsetOfThisBlock += nInt4PerChunk;  // advance to the next chunk
|
|
+ }
|
|
+
|
|
+ if (restNInt4 > 0) {  // tail: same steps as one loop iteration, bounded by restNInt4
|
|
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
|
|
+ channels[threadIdx.x].signal();
|
|
+ channels[threadIdx.x].wait();
|
|
+
|
|
+ }
|
|
+ __syncthreads();
|
|
+
|
|
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
|
|
+ int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
|
|
+ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
|
|
+ int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
|
|
+ data = add_vectors<T>(val, data);
|
|
+ }
|
|
+ resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
|
|
+ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
|
|
+ outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
|
|
+ data);
|
|
+ }
|
|
+ }
|
|
+ __syncthreads();
|
|
+ }
|
|
+
|
|
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {  // closing handshake on the output channels, executed by every block
|
|
+ outChannels[threadIdx.x].signal();
|
|
+ outChannels[threadIdx.x].wait();
|
|
+ }
|
|
+ __syncthreads();
|
|
+
|
|
+}
|
|
+
|
|
+template <typename T>  // allreduce10: two-level all-reduce. Step 1 sums a slice over the ranks of one NRANKS1_PER_NODE-sized group into scratch, step 2 adds values read through smScrChannels from same-local-index ranks of other groups, step 3 stores the result and writes it to the group's peers through smOutChannels.
|
|
+__global__ void __launch_bounds__(1024, 1)
|
|
+ allreduce10(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
|
|
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
|
|
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
|
|
+ size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
|
|
+ const int nPeer = nRanksPerNode - 1;
|
|
+ const size_t chanOffset = nPeer * blockIdx.x;  // each block uses its own run of nPeer channel handles
|
|
+ // assume (nelems * sizeof(T)) is divisible by (16 * NRANKS1_PER_NODE)
|
|
+ const size_t nInt4 = nelems * sizeof(T) / sizeof(int4);
|
|
+ const size_t nInt4PerRank = nInt4 / NRANKS1_PER_NODE;  // the data is split over the ranks of one group, not over worldSize
|
|
+
|
|
+ auto smChans = smChannels + chanOffset;
|
|
+ auto smOutChans = smOutChannels + chanOffset;
|
|
+ auto smScrChans = smScrChannels + chanOffset;
|
|
+
|
|
+ int4* buff4 = reinterpret_cast<int4*>(buff);
|
|
+ int4* scratch4 = reinterpret_cast<int4*>((char*)scratch + channelScratchOffset);  // channelScratchOffset is applied as a byte offset
|
|
+ int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);
|
|
+
|
|
+ // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
|
|
+ constexpr size_t unitNInt4 = 512;
|
|
+ const size_t maxNInt4PerBlock =
|
|
+ (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;  // ceil(slice / gridDim.x), rounded up to a multiple of unitNInt4
|
|
+ size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
|
|
+ size_t nInt4OfThisBlock = maxNInt4PerBlock;
|
|
+ size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;  // blocks that actually receive data
|
|
+
|
|
+ constexpr size_t nInt4PerChunk = 1024 * 1024 / sizeof(int4); // 1MB
|
|
+ int num_nodes = worldSize/NRANKS1_PER_NODE;  // number of NRANKS1_PER_NODE-sized groups
|
|
+
|
|
+ if (blockIdx.x >= nNeededBlocks) {
|
|
+ nInt4OfThisBlock = 0;  // surplus block: nothing to process
|
|
+ } else if (blockIdx.x == nNeededBlocks - 1) {
|
|
+ nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1);  // last needed block takes whatever is left
|
|
+ }
|
|
+
|
|
+ const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk;  // number of full chunks
|
|
+ const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk;  // tail shorter than one chunk
|
|
+
|
|
+ const size_t blockOffset = nInt4PerChunk * blockIdx.x;  // each block uses its own chunk-sized region of scratch
|
|
+
|
|
+ int localRank = rank % NRANKS1_PER_NODE;  // position of this rank inside its group
|
|
+
|
|
+ __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];  // per-block shared copies of the channel handles
|
|
+ __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> outChannels[NRANKS_PER_NODE - 1];
|
|
+ __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> scrChannels[NRANKS_PER_NODE - 1];
|
|
+
|
|
+ const int lid = threadIdx.x % WARP_SIZE;  // lane index inside the warp
|
|
+ if (lid < nPeer) {  // the first nPeer lanes of every warp store the same handles
|
|
+ channels[lid] = smChans[lid];
|
|
+ outChannels[lid] = smOutChans[lid];
|
|
+ scrChannels[lid] = smScrChans[lid];
|
|
+ }
|
|
+ __syncwarp();
|
|
+
|
|
+ // we can use double buffering to hide synchronization overhead
|
|
+ for (size_t itr = 0; itr < nItrs; itr++) {
|
|
+ if (threadIdx.x < (NRANKS1_PER_NODE-1)) {  // one thread per other rank of this group
|
|
+ int myNode = rank/NRANKS1_PER_NODE;
|
|
+ int remote = (threadIdx.x + 1 + rank);
|
|
+ int remoteNode = remote/NRANKS1_PER_NODE;
|
|
+
|
|
+ if (remoteNode > myNode) {  // stepped past the end of the group: wrap around inside it
|
|
+ remote = remote - NRANKS1_PER_NODE;
|
|
+ }
|
|
+ int peerIdx = remote < rank ? remote : remote - 1;  // channel index skips this rank, so higher ranks shift down by one
|
|
+ outChannels[peerIdx].signal();
|
|
+ outChannels[peerIdx].wait();
|
|
+ }
|
|
+ __syncthreads();
|
|
+
|
|
+ int myNode = rank/NRANKS1_PER_NODE;
|
|
+
|
|
+ //Reduce within an OAM
|
|
+ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
|
|
+ int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock];  // this rank's own contribution
|
|
+ for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode +
|
|
+ NRANKS1_PER_NODE - 1); peerIdx++) {  // NOTE(review): assumes the group's peers sit at channel indices [NRANKS1_PER_NODE*myNode, NRANKS1_PER_NODE*myNode + NRANKS1_PER_NODE - 1) - confirm
|
|
+ int4 val = channels[peerIdx].read<int4>(nInt4PerRank * localRank + offsetOfThisBlock + idx);
|
|
+ data = add_vectors<T>(val, data);
|
|
+ }
|
|
+ scratch4[idx + blockOffset] = data;  // group-level partial sum
|
|
+ }
|
|
+
|
|
+ if (threadIdx.x < static_cast<uint32_t>(num_nodes-1)) {  // one thread per other group. NOTE(review): no __syncthreads() between the scratch4 stores above and this signal - confirm peers cannot read a partially written chunk
|
|
+ int remote = (NRANKS1_PER_NODE * (threadIdx.x + 1) + rank) % worldSize;  // rank with the same local index in another group
|
|
+ int peerIdx = remote < rank ? remote : remote - 1;
|
|
+ scrChannels[peerIdx].signal();
|
|
+ scrChannels[peerIdx].wait();
|
|
+ }
|
|
+ __syncthreads();
|
|
+
|
|
+ //Reduce across OAMs
|
|
+
|
|
+ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
|
|
+ int4 data = scratch4[idx + blockOffset];
|
|
+
|
|
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
|
|
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;  // inverse of the rank -> channel index mapping
|
|
+ int myLocal = rank % NRANKS1_PER_NODE;
|
|
+ int remoteLocal = remoteRank % NRANKS1_PER_NODE;
|
|
+
|
|
+ if (myLocal == remoteLocal) {  // only peers with the same local index take part in this step
|
|
+ int4 val = scrChannels[peerIdx].read<int4>(blockOffset + idx +
|
|
+ channelScratchOffset/sizeof(int4));
|
|
+ data = add_vectors<T>(val, data);
|
|
+ }
|
|
+ }
|
|
+
|
|
+ resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data;
|
|
+
|
|
+ for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode + NRANKS1_PER_NODE - 1); peerIdx++) {
|
|
+ outChannels[peerIdx].write(nInt4PerRank * localRank + idx + offsetOfThisBlock +
|
|
+ channelOutDataOffset / sizeof(int4), data);
|
|
+ }
|
|
+ }
|
|
+
|
|
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {  // end-of-chunk handshake with all nPeer peers. NOTE(review): also issued without a preceding __syncthreads() - confirm
|
|
+ outChannels[threadIdx.x].signal();
|
|
+ outChannels[threadIdx.x].wait();
|
|
+ }
|
|
+ __syncthreads();
|
|
+
|
|
+ offsetOfThisBlock += nInt4PerChunk;  // advance to the next chunk
|
|
+ }
|
|
+
|
|
+ if (restNInt4 > 0) {  // tail: same three steps as one loop iteration, bounded by restNInt4
|
|
+ if (threadIdx.x < (NRANKS1_PER_NODE-1)) {
|
|
+ int myNode = rank/NRANKS1_PER_NODE;
|
|
+ int remote = (threadIdx.x + 1 + rank);
|
|
+ int remoteNode = remote/NRANKS1_PER_NODE;
|
|
+
|
|
+ if (remoteNode > myNode) {
|
|
+ remote = remote - NRANKS1_PER_NODE;
|
|
+ }
|
|
+ int peerIdx = remote < rank ? remote : remote - 1;
|
|
+
|
|
+ outChannels[peerIdx].signal();
|
|
+ outChannels[peerIdx].wait();
|
|
+ }
|
|
+ __syncthreads();
|
|
+
|
|
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
|
|
+ int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock];
|
|
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
|
|
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
|
|
+
|
|
+ int myNode = rank/NRANKS1_PER_NODE;
|
|
+ int remoteNode = remoteRank/NRANKS1_PER_NODE;
|
|
+
|
|
+ if (myNode == remoteNode) {  // here group peers are picked by comparing group indices instead of by an index range
|
|
+ int4 val = channels[peerIdx].read<int4>(nInt4PerRank * localRank + offsetOfThisBlock + idx);
|
|
+ data = add_vectors<T>(val, data);
|
|
+ }
|
|
+ }
|
|
+ scratch4[idx + blockOffset] = data;
|
|
+ }
|
|
+
|
|
+ if (threadIdx.x < static_cast<uint32_t>(num_nodes-1)) {  // NOTE(review): as in the loop, no __syncthreads() separates the scratch4 stores from this signal - confirm
|
|
+ int remote = (NRANKS1_PER_NODE * (threadIdx.x + 1) + rank) % worldSize;
|
|
+ int peerIdx = remote < rank ? remote : remote - 1;
|
|
+ scrChannels[peerIdx].signal();
|
|
+ scrChannels[peerIdx].wait();
|
|
+ }
|
|
+ __syncthreads();
|
|
+
|
|
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
|
|
+ int4 data = scratch4[idx + blockOffset];
|
|
+
|
|
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
|
|
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
|
|
+ int myLocal = rank % NRANKS1_PER_NODE;
|
|
+ int remoteLocal = remoteRank % NRANKS1_PER_NODE;
|
|
+
|
|
+ if (myLocal == remoteLocal) {
|
|
+ int4 val = scrChannels[peerIdx].read<int4>(blockOffset + idx +
|
|
+ channelScratchOffset/sizeof(int4));
|
|
+ data = add_vectors<T>(val, data);
|
|
+ }
|
|
+ }
|
|
+
|
|
+ resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data;
|
|
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
|
|
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
|
|
+ int myNode = rank/NRANKS1_PER_NODE;
|
|
+ int remoteNode = remoteRank/NRANKS1_PER_NODE;
|
|
+
|
|
+ if (myNode == remoteNode) {
|
|
+ outChannels[peerIdx].write(nInt4PerRank * localRank + idx + offsetOfThisBlock +
|
|
+ channelOutDataOffset / sizeof(int4), data);
|
|
+ }
|
|
+ }
|
|
+ }
|
|
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {  // closing handshake with all nPeer peers
|
|
+ outChannels[threadIdx.x].signal();
|
|
+ outChannels[threadIdx.x].wait();
|
|
+ }
|
|
+ __syncthreads();
|
|
+
|
|
+ }
|
|
+
|
|
+}
|
|
+
|
|
template <typename T>
|
|
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
|
|
- mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
|
|
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
|
|
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
|
|
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
|
|
size_t nelems, cudaStream_t stream) {
|
|
static uint32_t flag = 1;
|
|
+ int readAllred = 0, hieAllred = 0;
|
|
+ char* envValue = nullptr;
|
|
+ char* envValue1 = nullptr;
|
|
|
|
+ nRanksPerNode = (worldSize < nRanksPerNode) ? worldSize : nRanksPerNode;
|
|
+
|
|
+ envValue = std::getenv("MSCCLPP_READ_ALLRED");
|
|
+ envValue1 = std::getenv("MSCCLPP_HIERARCHICAL_ALLRED");
|
|
+
|
|
+ if (envValue != nullptr) {
|
|
+ if (atoi(envValue) == 1) {
|
|
+ readAllred = 1;
|
|
+ }
|
|
+ }
|
|
+ if (envValue1 != nullptr) {
|
|
+ if (atoi(envValue1) == 1) {
|
|
+ hieAllred = 1;
|
|
+ }
|
|
+ }
|
|
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
|
|
- int nBlocks = 7;
|
|
+ int nBlocks = nRanksPerNode - 1;
|
|
int nThreadsPerBlock = 32;
|
|
- allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
|
|
- channelScratchOffset, rank, nRanksPerNode, worldSize,
|
|
- nelems, flag++);
|
|
+ allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels,
|
|
+ channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++);
|
|
} else if (sizeof(T) * nelems <= (1 << 20)) {
|
|
- int nBlocks = 28;
|
|
+ int nBlocks = 4*(nRanksPerNode - 1);
|
|
int nThreadsPerBlock = 1024;
|
|
if (nelems >= 8192) {
|
|
- nBlocks = 56;
|
|
+ nBlocks = 8*(nRanksPerNode - 1);
|
|
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
|
|
}
|
|
#if defined(ENABLE_NPKIT)
|
|
@@ -526,11 +847,23 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
|
|
flag++);
|
|
#endif
|
|
} else {
|
|
- int nBlocks = 35;
|
|
+ int nBlocks = 5*(nRanksPerNode - 1);
|
|
int nThreadsPerBlock = 512;
|
|
- allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
|
|
- channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
|
|
- worldSize, nelems);
|
|
+ if (hieAllred && worldSize >= 8) {
|
|
+ nBlocks = 20;
|
|
+ allreduce10<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smScrChannels,
|
|
+ smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
|
|
+ worldSize, nelems);
|
|
+ } else {
|
|
+ if (!readAllred) {
|
|
+ allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
|
|
+ smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
|
|
+ worldSize, nelems);
|
|
+ } else {
|
|
+ allreduce8Read<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, resultBuff, smChannels, smOutChannels,
|
|
+ channelOutOffset, rank, nRanksPerNode, worldSize, nelems);
|
|
+ }
|
|
+ }
|
|
}
|
|
|
|
return cudaGetLastError();
|
|
diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp
|
|
index 015e0a2..ca2c272 100644
|
|
--- a/apps/nccl/src/common.hpp
|
|
+++ b/apps/nccl/src/common.hpp
|
|
@@ -13,6 +13,7 @@
|
|
#define WARP_SIZE 32
|
|
#endif
|
|
|
|
+constexpr int NRANKS1_PER_NODE = 4;  // size of one rank group in the hierarchical allreduce10 kernel (rank / NRANKS1_PER_NODE selects the group, rank % NRANKS1_PER_NODE the position inside it)
|
|
constexpr int NRANKS_PER_NODE = 8;
|
|
constexpr int NPEERS = 7;
|
|
|
|
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
|
|
index f91d15e..022d398 100644
|
|
--- a/apps/nccl/src/nccl.cu
|
|
+++ b/apps/nccl/src/nccl.cu
|
|
@@ -70,7 +70,9 @@ struct hash<channelKey> {
|
|
|
|
struct ChannelInfo {
|
|
std::vector<mscclpp::SmChannel> smChannels;
|
|
+ std::vector<mscclpp::SmChannel> smChannels1;  // second channel set; in the call sites of this patch it is either a copy of smChannels or channels built over remote memories registered for the send buffer
|
|
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
|
|
+ std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles1;  // device handles created from smChannels1
|
|
};
|
|
|
|
struct ncclComm {
|
|
@@ -213,6 +215,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
|
|
channelKey recvKey{(void*)recvBasePtr, recvBytes};
|
|
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
|
|
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;
|
|
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels = nullptr;
|
|
|
|
// Creating the channels
|
|
if (count * ncclTypeSize(datatype) <= (1 << 20)) {
|
|
@@ -220,19 +223,24 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
|
|
if (sendIt == comm->channelScratchInfos.end()) {
|
|
std::vector<mscclpp::SmChannel> channels =
|
|
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
|
|
- ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
|
|
+ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
|
|
sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
|
|
}
|
|
|
|
smChannels = sendIt->second.smChannelDeviceHandles.get();
|
|
} else {
|
|
std::vector<mscclpp::RegisteredMemory> remoteMemories;
|
|
-
|
|
+ std::vector<mscclpp::RegisteredMemory> remoteMemories1;
|
|
auto sendIt = comm->channelInInfos.find(sendKey);
|
|
if (sendIt == comm->channelInInfos.end()) {
|
|
std::vector<mscclpp::SmChannel> channels =
|
|
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
|
|
- ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
|
|
+ remoteMemories1 =
|
|
+ setupRemoteMemories(comm->comm, rank, (void*)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc);
|
|
+ std::vector<mscclpp::SmChannel> channels1 =
|
|
+ setupSmChannels(comm, remoteMemories1, const_cast<void*>((void*)sendBasePtr));
|
|
+
|
|
+ ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)};
|
|
sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
|
|
}
|
|
|
|
@@ -242,35 +250,36 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
|
|
setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
|
|
std::vector<mscclpp::SmChannel> outChannels =
|
|
setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
|
|
- ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
|
|
+ ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)};
|
|
recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
|
|
}
|
|
|
|
- smChannels = sendIt->second.smChannelDeviceHandles.get();
|
|
+ smChannels = sendIt->second.smChannelDeviceHandles1.get();
|
|
smOutChannels = recvIt->second.smChannelDeviceHandles.get();
|
|
+ smScrChannels = sendIt->second.smChannelDeviceHandles.get();
|
|
}
|
|
|
|
switch (datatype) {
|
|
case ncclFloat16:
|
|
- CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
|
|
- offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
|
|
- comm->comm->bootstrap()->getNranks(), count, stream));
|
|
+ CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels,
|
|
+ smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
|
break;
|
|
case ncclFloat32:
|
|
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
|
|
- smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
|
|
+ smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch,
|
|
+ comm->comm->bootstrap()->getRank(),
|
|
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
|
break;
|
|
case ncclBfloat16:
|
|
CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
|
|
- smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
|
|
- comm->comm->bootstrap()->getNranks(), count, stream));
|
|
+ smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank,
|
|
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
|
break;
|
|
case ncclInt32:
|
|
case ncclUint32:
|
|
- CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
|
|
- offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
|
|
- comm->comm->bootstrap()->getNranks(), count, stream));
|
|
+ CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels,
|
|
+ smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
|
|
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
|
break;
|
|
default:
|
|
WARN("datatype is invalid");
|
|
@@ -315,7 +324,7 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
|
|
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
|
|
std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
|
|
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
|
|
- ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
|
|
+ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
|
|
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
|
|
}
|
|
|
|
@@ -597,7 +606,7 @@ NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff
|
|
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
|
|
std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
|
|
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
|
|
- ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
|
|
+ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
|
|
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
|
|
}
|
|
|
|
@@ -805,16 +814,6 @@ NCCL_API ncclResult_t ncclGroupEnd() {
|
|
return ncclSuccess;
|
|
}
|
|
|
|
-NCCL_API ncclResult_t ncclCommRegister(const ncclComm_t, void*, size_t, void**) {
|
|
- // TODO: Implementation
|
|
- return ncclSuccess;
|
|
-}
|
|
-
|
|
-NCCL_API ncclResult_t ncclCommDeregister(const ncclComm_t, void*) {
|
|
- // TODO: Implementation
|
|
- return ncclSuccess;
|
|
-}
|
|
-
|
|
ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
|
|
if (ptr == nullptr || size == 0) {
|
|
WARN("ptr is nullptr or size is 0");
|