Tune allreduce performance in CPX mode (single OAM) (#1508)

[ROCm/rccl commit: 7ac82248de]
This commit is contained in:
Nusrat Islam
2025-01-29 08:58:48 -06:00
committed by GitHub
parent f1260602f7
commit 53c927678b
@@ -1,8 +1,8 @@
diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index 4134241..d65be4b 100644
index 4134241..76674ba 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -495,24 +495,348 @@ __global__ void __launch_bounds__(512, 1)
@@ -495,24 +495,345 @@ __global__ void __launch_bounds__(512, 1)
}
}
@@ -29,7 +29,7 @@ index 4134241..d65be4b 100644
+ size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
+ size_t nInt4OfThisBlock = maxNInt4PerBlock;
+ size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
+ constexpr size_t nInt4PerChunk = 1024 * 256 / sizeof(int4); // 256KB
+ constexpr size_t nInt4PerChunk = 1024 * 512 / sizeof(int4); // 512KB
+ if (blockIdx.x >= nNeededBlocks) {
+ nInt4OfThisBlock = 0;
+ } else if (blockIdx.x == nNeededBlocks - 1) {
@@ -69,10 +69,6 @@ index 4134241..d65be4b 100644
+ data);
+ }
+ }
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+ outChannels[threadIdx.x].signal();
+ outChannels[threadIdx.x].wait();
+ }
+ __syncthreads();
+
+ offsetOfThisBlock += nInt4PerChunk;
@@ -98,14 +94,15 @@ index 4134241..d65be4b 100644
+ data);
+ }
+ }
+
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+ outChannels[threadIdx.x].signal();
+ outChannels[threadIdx.x].wait();
+ }
+ __syncthreads();
+ }
+
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+ outChannels[threadIdx.x].signal();
+ outChannels[threadIdx.x].wait();
+ }
+ __syncthreads();
+
+}
+
+template <typename T>
@@ -206,7 +203,7 @@ index 4134241..d65be4b 100644
+ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+ int4 data = scratch4[idx + blockOffset];
+
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+ int myLocal = rank % NRANKS1_PER_NODE;
+ int remoteLocal = remoteRank % NRANKS1_PER_NODE;
@@ -253,7 +250,7 @@ index 4134241..d65be4b 100644
+
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+ int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock];
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+
+ int myNode = rank/NRANKS1_PER_NODE;
@@ -278,7 +275,7 @@ index 4134241..d65be4b 100644
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+ int4 data = scratch4[idx + blockOffset];
+
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+ int myLocal = rank % NRANKS1_PER_NODE;
+ int remoteLocal = remoteRank % NRANKS1_PER_NODE;
@@ -291,7 +288,7 @@ index 4134241..d65be4b 100644
+ }
+
+ resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data;
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+ int myNode = rank/NRANKS1_PER_NODE;
+ int remoteNode = remoteRank/NRANKS1_PER_NODE;
@@ -340,7 +337,8 @@ index 4134241..d65be4b 100644
+ }
+ }
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
int nBlocks = 7;
- int nBlocks = 7;
+ int nBlocks = nRanksPerNode - 1;
int nThreadsPerBlock = 32;
- allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
- channelScratchOffset, rank, nRanksPerNode, worldSize,
@@ -357,9 +355,12 @@ index 4134241..d65be4b 100644
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
}
#if defined(ENABLE_NPKIT)
@@ -528,9 +852,21 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
@@ -526,11 +847,23 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
flag++);
#endif
} else {
int nBlocks = 35;
- int nBlocks = 35;
+ int nBlocks = 5*(nRanksPerNode - 1);
int nThreadsPerBlock = 512;
- allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
- channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
@@ -383,19 +384,16 @@ index 4134241..d65be4b 100644
return cudaGetLastError();
diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp
index 015e0a2..f8ba6d6 100644
index 015e0a2..ca2c272 100644
--- a/apps/nccl/src/common.hpp
+++ b/apps/nccl/src/common.hpp
@@ -13,8 +13,10 @@
@@ -13,6 +13,7 @@
#define WARP_SIZE 32
#endif
+constexpr int NRANKS1_PER_NODE = 4;
constexpr int NRANKS_PER_NODE = 8;
constexpr int NPEERS = 7;
+constexpr int NPEER = 7;
constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index f91d15e..022d398 100644