diff --git a/ext-src/read-allred.patch b/ext-src/read-allred.patch index 96135b55db..de632f6284 100644 --- a/ext-src/read-allred.patch +++ b/ext-src/read-allred.patch @@ -1,58 +1,8 @@ diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp -index 1b85136..a45345a 100644 +index 1b85136..a08f822 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp -@@ -319,7 +319,7 @@ __global__ void __launch_bounds__(512, 1) - __syncthreads(); - // Starts allgather - for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { -- for (int i = 0; i < nPeer; i++) { -+ for (int i = 0; i < NPEER; i++) { - const int peerIdx = (i + blockIdx.x) % nPeer; - const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; - int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; -@@ -336,13 +336,13 @@ __global__ void __launch_bounds__(512, 1) - - for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { - int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; -- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { - const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; - int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; - data = add_vectors(val, data); - } - resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; -- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { - outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), - data); - } -@@ -356,7 +356,7 @@ __global__ void __launch_bounds__(512, 1) - } - __syncthreads(); - for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { -- for (int i = 0; i < nPeer; i++) { -+ for (int i = 0; i < NPEER; i++) { - const int peerIdx = (i + blockIdx.x) % nPeer; - const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; - int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; -@@ -372,13 +372,13 @@ __global__ void __launch_bounds__(512, 1) - - for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { - int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; -- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { - const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; - int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; - data = add_vectors(val, data); - } - resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; -- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { - outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), - data); - } -@@ -386,19 +386,132 @@ __global__ void __launch_bounds__(512, 1) +@@ -386,24 +386,361 @@ __global__ void __launch_bounds__(512, 1) } } @@ -108,13 +58,13 @@ index 1b85136..a45345a 100644 + + for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { + int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + int4 val = channels[peerIdx].read(nInt4PerRank * rank + offsetOfThisBlock + idx);; + data = add_vectors(val, data); + } + resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; + -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), + data); + } @@ -138,12 +88,12 @@ index 1b85136..a45345a 100644 + + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { + int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + int4 val = channels[peerIdx].read(nInt4PerRank * rank + offsetOfThisBlock + idx);; + data = add_vectors(val, data); + } + resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), + data); + } @@ -158,6 +108,220 @@ index 1b85136..a45345a 100644 + +} + ++template ++__global__ void __launch_bounds__(1024, 1) ++ allreduce10(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, ++ mscclpp::DeviceHandle* smScrChannels, ++ mscclpp::DeviceHandle* smOutChannels, size_t channelOutDataOffset, ++ size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) { ++ const int nPeer = nRanksPerNode - 1; ++ const size_t chanOffset = nPeer * blockIdx.x; ++ // assume (nelems * sizeof(T)) is divisible by (16 * worldSize) ++ const size_t nInt4 = nelems * sizeof(T) / sizeof(int4); ++ const size_t nInt4PerRank = nInt4 / NRANKS1_PER_NODE; ++ ++ auto smChans = smChannels + chanOffset; ++ auto smOutChans = smOutChannels + chanOffset; ++ auto smScrChans = smScrChannels + chanOffset; ++ ++ int4* buff4 = reinterpret_cast(buff); ++ int4* scratch4 = reinterpret_cast((char*)scratch + channelScratchOffset); ++ int4* resultBuff4 = reinterpret_cast(resultBuff); ++ ++ // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4` ++ constexpr size_t unitNInt4 = 512; ++ const size_t maxNInt4PerBlock = ++ (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4; ++ size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x; ++ size_t nInt4OfThisBlock = maxNInt4PerBlock; ++ size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock; ++ ++ constexpr size_t nInt4PerChunk = 1024 * 1024 / sizeof(int4); // 256KB ++ int num_nodes = worldSize/NRANKS1_PER_NODE; ++ ++ if (blockIdx.x >= nNeededBlocks) { ++ nInt4OfThisBlock = 0; ++ } else if (blockIdx.x == nNeededBlocks - 1) { ++ nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1); ++ } ++ ++ const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk; ++ const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk; ++ const size_t chunkSizePerRank = nNeededBlocks * nInt4PerChunk; ++ ++ const size_t blockOffset = nInt4PerChunk * blockIdx.x; ++ const size_t scratchChunkRankOffset = chunkSizePerRank * rank; ++ const size_t scratchBaseOffsetInt4 = channelScratchOffset / sizeof(int4); ++ ++ int localRank = rank % NRANKS1_PER_NODE; ++ ++ __shared__ mscclpp::DeviceHandle channels[NRANKS_PER_NODE - 1]; ++ __shared__ mscclpp::DeviceHandle outChannels[NRANKS_PER_NODE - 1]; ++ __shared__ mscclpp::DeviceHandle scrChannels[NRANKS_PER_NODE - 1]; ++ ++ const int lid = threadIdx.x % WARP_SIZE; ++ if (lid < nPeer) { ++ channels[lid] = smChans[lid]; ++ outChannels[lid] = smOutChans[lid]; ++ scrChannels[lid] = smScrChans[lid]; ++ } ++ __syncwarp(); ++ ++ // we can use double buffering to hide synchronization overhead ++ for (size_t itr = 0; itr < nItrs; itr++) { ++ if (threadIdx.x < (NRANKS1_PER_NODE-1)) { ++ int myNode = rank/NRANKS1_PER_NODE; ++ int remote = (threadIdx.x + 1 + rank); ++ int remoteNode = remote/NRANKS1_PER_NODE; ++ ++ if (remoteNode > myNode) { ++ remote = remote - NRANKS1_PER_NODE; ++ } ++ int peerIdx = remote < rank ? remote : remote - 1; ++ outChannels[peerIdx].signal(); ++ outChannels[peerIdx].wait(); ++ } ++ __syncthreads(); ++ ++ int myNode = rank/NRANKS1_PER_NODE; ++ ++ //Reduce within an OAM ++ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { ++ int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock]; ++ for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode + ++ NRANKS1_PER_NODE - 1); peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ ++ int4 val = channels[peerIdx].read(nInt4PerRank * localRank + offsetOfThisBlock + idx); ++ data = add_vectors(val, data); ++ } ++ scratch4[idx + blockOffset] = data; ++ } ++ ++ if (threadIdx.x < static_cast(num_nodes-1)) { ++ int remote = (NRANKS1_PER_NODE * (threadIdx.x + 1) + rank) % worldSize; ++ int peerIdx = remote < rank ? remote : remote - 1; ++ scrChannels[peerIdx].signal(); ++ scrChannels[peerIdx].wait(); ++ } ++ __syncthreads(); ++ ++ int remoteRank, peerIdx; ++ //Reduce across OAMs ++ ++ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { ++ int4 data = scratch4[idx + blockOffset]; ++ ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ int myLocal = rank % NRANKS1_PER_NODE; ++ int remoteLocal = remoteRank % NRANKS1_PER_NODE; ++ ++ if (myLocal == remoteLocal) { ++ int4 val = scrChannels[peerIdx].read(blockOffset + idx + ++ channelScratchOffset/sizeof(int4)); ++ data = add_vectors(val, data); ++ } ++ } ++ ++ resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data; ++ ++ for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode + NRANKS1_PER_NODE - 1); peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ outChannels[peerIdx].write(nInt4PerRank * localRank + idx + offsetOfThisBlock + ++ channelOutDataOffset / sizeof(int4), data); ++ } ++ } ++ ++ if (threadIdx.x < static_cast(nPeer)) { ++ outChannels[threadIdx.x].signal(); ++ outChannels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ ++ offsetOfThisBlock += nInt4PerChunk; ++ } ++ ++ if (restNInt4 > 0) { ++ if (threadIdx.x < (NRANKS1_PER_NODE-1)) { ++ int myNode = rank/NRANKS1_PER_NODE; ++ int remote = (threadIdx.x + 1 + rank); ++ int remoteNode = remote/NRANKS1_PER_NODE; ++ ++ if (remoteNode > myNode) { ++ remote = remote - NRANKS1_PER_NODE; ++ } ++ int peerIdx = remote < rank ? remote : remote - 1; ++ ++ outChannels[peerIdx].signal(); ++ outChannels[peerIdx].wait(); ++ } ++ __syncthreads(); ++ ++ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { ++ int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock]; ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ ++ int myNode = rank/NRANKS1_PER_NODE; ++ int remoteNode = remoteRank/NRANKS1_PER_NODE; ++ ++ if (myNode == remoteNode) { ++ int4 val = channels[peerIdx].read(nInt4PerRank * localRank + offsetOfThisBlock + idx); ++ data = add_vectors(val, data); ++ } ++ } ++ scratch4[idx + blockOffset] = data; ++ } ++ ++ if (threadIdx.x < static_cast(num_nodes-1)) { ++ int remote = (NRANKS1_PER_NODE * (threadIdx.x + 1) + rank) % worldSize; ++ int peerIdx = remote < rank ? remote : remote - 1; ++ scrChannels[peerIdx].signal(); ++ scrChannels[peerIdx].wait(); ++ } ++ __syncthreads(); ++ ++ int remoteRank, peerIdx; ++ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { ++ int4 data = scratch4[idx + blockOffset]; ++ ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ int myLocal = rank % NRANKS1_PER_NODE; ++ int remoteLocal = remoteRank % NRANKS1_PER_NODE; ++ ++ if (myLocal == remoteLocal) { ++ int4 val = scrChannels[peerIdx].read(blockOffset + idx + ++ channelScratchOffset/sizeof(int4)); ++ data = add_vectors(val, data); ++ } ++ } ++ ++ resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data; ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ int myNode = rank/NRANKS1_PER_NODE; ++ int remoteNode = remoteRank/NRANKS1_PER_NODE; ++ ++ if (myNode == remoteNode) { ++ outChannels[peerIdx].write(nInt4PerRank * localRank + idx + offsetOfThisBlock + ++ channelOutDataOffset / sizeof(int4), data); ++ } ++ } ++ } ++ if (threadIdx.x < static_cast(nPeer)) { ++ outChannels[threadIdx.x].signal(); ++ outChannels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ ++ } ++ ++} ++ ++ ++ + template cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, @@ -167,16 +331,26 @@ index 1b85136..a45345a 100644 size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems, cudaStream_t stream) { static uint32_t flag = 1; -+ int readAllred = 0; ++ int readAllred = 0, hieAllred = 0; + char* envValue = nullptr; ++ char* envValue1 = nullptr; ++ ++ nRanksPerNode = (worldSize < nRanksPerNode) ? worldSize : nRanksPerNode; + + envValue = std::getenv("MSCCLPP_READ_ALLRED"); ++ envValue1 = std::getenv("MSCCLPP_HIERARCHICAL_ALLRED"); + + if (envValue != nullptr) { + if (atoi(envValue) == 1) { + readAllred = 1; + } + } ++ if (envValue1 != nullptr) { ++ if (atoi(envValue1) == 1) { ++ hieAllred = 1; ++ } ++ } ++ if (sizeof(T) * nelems < worldSize * sizeof(int)) { int nBlocks = 7; @@ -187,38 +361,55 @@ index 1b85136..a45345a 100644 + allreduceAllToAll<<>>(buff, scratch, resultBuff, smChannels, + channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++); } else if (sizeof(T) * nelems <= (1 << 20)) { - int nBlocks = 28; +- int nBlocks = 28; ++ int nBlocks = 4*(nRanksPerNode - 1); int nThreadsPerBlock = 1024; -@@ -412,9 +525,15 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< + if (nelems >= 8192) { +- nBlocks = 56; ++ nBlocks = 8*(nRanksPerNode - 1); + nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024; + } + allreduce7<<>>(buff, scratch, resultBuff, smChannels, channelInOffset, +@@ -412,9 +749,20 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< } else { int nBlocks = 35; int nThreadsPerBlock = 512; - allreduce8<<>>(buff, scratch, resultBuff, smChannels, smOutChannels, -+ if (!readAllred) { -+ allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, smOutChannels, - channelOutOffset, channelScratchOffset, rank, nRanksPerNode, - worldSize, nelems); +- channelOutOffset, channelScratchOffset, rank, nRanksPerNode, +- worldSize, nelems); ++ if (hieAllred && worldSize >= 8) { ++ allreduce10<<>>(buff, scratch, resultBuff, smChannels, smScrChannels, ++ smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, ++ worldSize, nelems); + } else { -+ allreduce8Read<<>>(buff, resultBuff, smChannels, smOutChannels, -+ channelOutOffset, rank, nRanksPerNode, -+ worldSize, nelems); ++ if (!readAllred) { ++ allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, ++ smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, ++ worldSize, nelems); ++ } else { ++ allreduce8Read<<>>(buff, resultBuff, smChannels, smOutChannels, ++ channelOutOffset, rank, nRanksPerNode, worldSize, nelems); ++ } + } } return cudaGetLastError(); diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp -index 25c74e7..32672c6 100644 +index 25c74e7..5e85468 100644 --- a/apps/nccl/src/common.hpp +++ b/apps/nccl/src/common.hpp -@@ -13,5 +13,6 @@ +@@ -11,7 +11,9 @@ + #define WARP_SIZE 32 + #endif ++constexpr int NRANKS1_PER_NODE = 4; constexpr int NRANKS_PER_NODE = 8; constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB +constexpr int NPEER = 7; #endif // NCCL_COMMON_HPP_ diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu -index ec130b0..571508d 100644 +index cb0e7d5..a697be2 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -49,7 +49,9 @@ struct hash { @@ -318,7 +509,7 @@ index ec130b0..571508d 100644 break; default: return ncclInvalidArgument; -@@ -551,7 +562,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t +@@ -550,7 +561,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t std::vector> smChannelDeviceHandles; std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles), [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });