From 7ac82248deea241d5f40b80a9e3a5836d2c506bc Mon Sep 17 00:00:00 2001 From: Nusrat Islam Date: Wed, 29 Jan 2025 08:58:48 -0600 Subject: [PATCH] Tune allreduce performance in CPX mode (single OAM) (#1508) --- ext-src/read-allred.patch | 46 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/ext-src/read-allred.patch b/ext-src/read-allred.patch index 09f81f1a13..a7c9d7fbf6 100644 --- a/ext-src/read-allred.patch +++ b/ext-src/read-allred.patch @@ -1,8 +1,8 @@ diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp -index 4134241..d65be4b 100644 +index 4134241..76674ba 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp -@@ -495,24 +495,348 @@ __global__ void __launch_bounds__(512, 1) +@@ -495,24 +495,345 @@ __global__ void __launch_bounds__(512, 1) } } @@ -29,7 +29,7 @@ index 4134241..d65be4b 100644 + size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x; + size_t nInt4OfThisBlock = maxNInt4PerBlock; + size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock; -+ constexpr size_t nInt4PerChunk = 1024 * 256 / sizeof(int4); // 256KB ++ constexpr size_t nInt4PerChunk = 1024 * 512 / sizeof(int4); // 512KB + if (blockIdx.x >= nNeededBlocks) { + nInt4OfThisBlock = 0; + } else if (blockIdx.x == nNeededBlocks - 1) { @@ -69,10 +69,6 @@ index 4134241..d65be4b 100644 + data); + } + } -+ if (threadIdx.x < static_cast(nPeer)) { -+ outChannels[threadIdx.x].signal(); -+ outChannels[threadIdx.x].wait(); -+ } + __syncthreads(); + + offsetOfThisBlock += nInt4PerChunk; @@ -98,14 +94,15 @@ index 4134241..d65be4b 100644 + data); + } + } -+ -+ if (threadIdx.x < static_cast(nPeer)) { -+ outChannels[threadIdx.x].signal(); -+ outChannels[threadIdx.x].wait(); -+ } + __syncthreads(); + } + ++ if (threadIdx.x < static_cast(nPeer)) { ++ outChannels[threadIdx.x].signal(); ++ outChannels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ +} + +template @@ -206,7 +203,7 @@ index 4134241..d65be4b 100644 + for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { + int4 data = scratch4[idx + blockOffset]; + -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int myLocal = rank % NRANKS1_PER_NODE; + int remoteLocal = remoteRank % NRANKS1_PER_NODE; @@ -253,7 +250,7 @@ index 4134241..d65be4b 100644 + + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { + int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock]; -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + + int myNode = rank/NRANKS1_PER_NODE; @@ -278,7 +275,7 @@ index 4134241..d65be4b 100644 + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { + int4 data = scratch4[idx + blockOffset]; + -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int myLocal = rank % NRANKS1_PER_NODE; + int remoteLocal = remoteRank % NRANKS1_PER_NODE; @@ -291,7 +288,7 @@ index 4134241..d65be4b 100644 + } + + resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data; -+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int myNode = rank/NRANKS1_PER_NODE; + int remoteNode = remoteRank/NRANKS1_PER_NODE; @@ -340,7 +337,8 @@ index 4134241..d65be4b 100644 + } + } if (sizeof(T) * nelems < worldSize * sizeof(int)) { - int nBlocks = 7; +- int nBlocks = 7; ++ int nBlocks = nRanksPerNode - 1; int nThreadsPerBlock = 32; - allreduceAllToAll<<>>(buff, scratch, resultBuff, smChannels, channelInOffset, - channelScratchOffset, rank, nRanksPerNode, worldSize, @@ -357,9 +355,12 @@ index 4134241..d65be4b 100644 nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024; } #if defined(ENABLE_NPKIT) -@@ -528,9 +852,21 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< +@@ -526,11 +847,23 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< + flag++); + #endif } else { - int nBlocks = 35; +- int nBlocks = 35; ++ int nBlocks = 5*(nRanksPerNode - 1); int nThreadsPerBlock = 512; - allreduce8<<>>(buff, scratch, resultBuff, smChannels, smOutChannels, - channelOutOffset, channelScratchOffset, rank, nRanksPerNode, @@ -383,19 +384,16 @@ index 4134241..d65be4b 100644 return cudaGetLastError(); diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp -index 015e0a2..f8ba6d6 100644 +index 015e0a2..ca2c272 100644 --- a/apps/nccl/src/common.hpp +++ b/apps/nccl/src/common.hpp -@@ -13,8 +13,10 @@ +@@ -13,6 +13,7 @@ #define WARP_SIZE 32 #endif +constexpr int NRANKS1_PER_NODE = 4; constexpr int NRANKS_PER_NODE = 8; constexpr int NPEERS = 7; -+constexpr int NPEER = 7; - - constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index f91d15e..022d398 100644