Tune allreduce performance in CPX mode (single OAM) (#1508)
[ROCm/rccl commit: 7ac82248de]
이 커밋은 다음에 포함됨:
@@ -1,8 +1,8 @@
|
||||
diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
|
||||
index 4134241..d65be4b 100644
|
||||
index 4134241..76674ba 100644
|
||||
--- a/apps/nccl/src/allreduce.hpp
|
||||
+++ b/apps/nccl/src/allreduce.hpp
|
||||
@@ -495,24 +495,348 @@ __global__ void __launch_bounds__(512, 1)
|
||||
@@ -495,24 +495,345 @@ __global__ void __launch_bounds__(512, 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ index 4134241..d65be4b 100644
|
||||
+ size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
|
||||
+ size_t nInt4OfThisBlock = maxNInt4PerBlock;
|
||||
+ size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
|
||||
+ constexpr size_t nInt4PerChunk = 1024 * 256 / sizeof(int4); // 256KB
|
||||
+ constexpr size_t nInt4PerChunk = 1024 * 512 / sizeof(int4); // 512KB
|
||||
+ if (blockIdx.x >= nNeededBlocks) {
|
||||
+ nInt4OfThisBlock = 0;
|
||||
+ } else if (blockIdx.x == nNeededBlocks - 1) {
|
||||
@@ -69,10 +69,6 @@ index 4134241..d65be4b 100644
|
||||
+ data);
|
||||
+ }
|
||||
+ }
|
||||
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
|
||||
+ outChannels[threadIdx.x].signal();
|
||||
+ outChannels[threadIdx.x].wait();
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+
|
||||
+ offsetOfThisBlock += nInt4PerChunk;
|
||||
@@ -98,14 +94,15 @@ index 4134241..d65be4b 100644
|
||||
+ data);
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
|
||||
+ outChannels[threadIdx.x].signal();
|
||||
+ outChannels[threadIdx.x].wait();
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+ }
|
||||
+
|
||||
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
|
||||
+ outChannels[threadIdx.x].signal();
|
||||
+ outChannels[threadIdx.x].wait();
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+
|
||||
+}
|
||||
+
|
||||
+template <typename T>
|
||||
@@ -206,7 +203,7 @@ index 4134241..d65be4b 100644
|
||||
+ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
|
||||
+ int4 data = scratch4[idx + blockOffset];
|
||||
+
|
||||
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
|
||||
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
|
||||
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
|
||||
+ int myLocal = rank % NRANKS1_PER_NODE;
|
||||
+ int remoteLocal = remoteRank % NRANKS1_PER_NODE;
|
||||
@@ -253,7 +250,7 @@ index 4134241..d65be4b 100644
|
||||
+
|
||||
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
|
||||
+ int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock];
|
||||
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
|
||||
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
|
||||
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
|
||||
+
|
||||
+ int myNode = rank/NRANKS1_PER_NODE;
|
||||
@@ -278,7 +275,7 @@ index 4134241..d65be4b 100644
|
||||
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
|
||||
+ int4 data = scratch4[idx + blockOffset];
|
||||
+
|
||||
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
|
||||
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
|
||||
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
|
||||
+ int myLocal = rank % NRANKS1_PER_NODE;
|
||||
+ int remoteLocal = remoteRank % NRANKS1_PER_NODE;
|
||||
@@ -291,7 +288,7 @@ index 4134241..d65be4b 100644
|
||||
+ }
|
||||
+
|
||||
+ resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data;
|
||||
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
|
||||
+ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
|
||||
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
|
||||
+ int myNode = rank/NRANKS1_PER_NODE;
|
||||
+ int remoteNode = remoteRank/NRANKS1_PER_NODE;
|
||||
@@ -340,7 +337,8 @@ index 4134241..d65be4b 100644
|
||||
+ }
|
||||
+ }
|
||||
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
|
||||
int nBlocks = 7;
|
||||
- int nBlocks = 7;
|
||||
+ int nBlocks = nRanksPerNode - 1;
|
||||
int nThreadsPerBlock = 32;
|
||||
- allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
|
||||
- channelScratchOffset, rank, nRanksPerNode, worldSize,
|
||||
@@ -357,9 +355,12 @@ index 4134241..d65be4b 100644
|
||||
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
|
||||
}
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -528,9 +852,21 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
|
||||
@@ -526,11 +847,23 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
|
||||
flag++);
|
||||
#endif
|
||||
} else {
|
||||
int nBlocks = 35;
|
||||
- int nBlocks = 35;
|
||||
+ int nBlocks = 5*(nRanksPerNode - 1);
|
||||
int nThreadsPerBlock = 512;
|
||||
- allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
|
||||
- channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
|
||||
@@ -383,19 +384,16 @@ index 4134241..d65be4b 100644
|
||||
|
||||
return cudaGetLastError();
|
||||
diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp
|
||||
index 015e0a2..f8ba6d6 100644
|
||||
index 015e0a2..ca2c272 100644
|
||||
--- a/apps/nccl/src/common.hpp
|
||||
+++ b/apps/nccl/src/common.hpp
|
||||
@@ -13,8 +13,10 @@
|
||||
@@ -13,6 +13,7 @@
|
||||
#define WARP_SIZE 32
|
||||
#endif
|
||||
|
||||
+constexpr int NRANKS1_PER_NODE = 4;
|
||||
constexpr int NRANKS_PER_NODE = 8;
|
||||
constexpr int NPEERS = 7;
|
||||
+constexpr int NPEER = 7;
|
||||
|
||||
constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB
|
||||
|
||||
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
|
||||
index f91d15e..022d398 100644
|
||||
|
||||
새 이슈에서 참조
사용자 차단