ext-src: tune TP=8 case on MI308 CPX mode (#1446)

Tune the number of blocks for hierarchical mscclpp allreduce.
Este commit está contenido en:
Nusrat Islam
2024-12-06 08:16:39 -06:00
cometido por GitHub
padre a05329bd0d
commit 42b6831a39
+4 -11
Ver fichero
@@ -1,8 +1,8 @@
diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index 1b85136..a08f822 100644
index 1b85136..ee90c2f 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -386,24 +386,361 @@ __global__ void __launch_bounds__(512, 1)
@@ -386,24 +386,353 @@ __global__ void __launch_bounds__(512, 1)
}
}
@@ -147,11 +147,8 @@ index 1b85136..a08f822 100644
+
+ const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk;
+ const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk;
+ const size_t chunkSizePerRank = nNeededBlocks * nInt4PerChunk;
+
+ const size_t blockOffset = nInt4PerChunk * blockIdx.x;
+ const size_t scratchChunkRankOffset = chunkSizePerRank * rank;
+ const size_t scratchBaseOffsetInt4 = channelScratchOffset / sizeof(int4);
+
+ int localRank = rank % NRANKS1_PER_NODE;
+
@@ -190,8 +187,6 @@ index 1b85136..a08f822 100644
+ int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock];
+ for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode +
+ NRANKS1_PER_NODE - 1); peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+
+ int4 val = channels[peerIdx].read<int4>(nInt4PerRank * localRank + offsetOfThisBlock + idx);
+ data = add_vectors<T>(val, data);
+ }
@@ -206,7 +201,6 @@ index 1b85136..a08f822 100644
+ }
+ __syncthreads();
+
+ int remoteRank, peerIdx;
+ //Reduce across OAMs
+
+ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
@@ -227,7 +221,6 @@ index 1b85136..a08f822 100644
+ resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data;
+
+ for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode + NRANKS1_PER_NODE - 1); peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+ outChannels[peerIdx].write(nInt4PerRank * localRank + idx + offsetOfThisBlock +
+ channelOutDataOffset / sizeof(int4), data);
+ }
@@ -282,7 +275,6 @@ index 1b85136..a08f822 100644
+ }
+ __syncthreads();
+
+ int remoteRank, peerIdx;
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+ int4 data = scratch4[idx + blockOffset];
+
@@ -370,7 +362,7 @@ index 1b85136..a08f822 100644
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
}
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
@@ -412,9 +749,20 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
@@ -412,9 +741,21 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
} else {
int nBlocks = 35;
int nThreadsPerBlock = 512;
@@ -378,6 +370,7 @@ index 1b85136..a08f822 100644
- channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
- worldSize, nelems);
+ if (hieAllred && worldSize >= 8) {
+ nBlocks = 20;
+ allreduce10<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smScrChannels,
+ smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
+ worldSize, nelems);