ext-src: Improved allreduce performance in cpx mode for MI308 (#1393)

To get the improved performance for TP=4, the user needs to use
RCCL_MSCCL_FORCE_ENABLE=1 and MSCCLPP_READ_ALLRED=1. For TP=8, the
user should use MSCCLPP_HIERARCHICAL_ALLRED=1.
This commit is contained in:
Nusrat Islam
2024-10-30 08:30:15 -05:00
zatwierdzone przez GitHub
rodzic ea20af698e
commit 0fb3b5eba9
+261 -70
Wyświetl plik
@@ -1,58 +1,8 @@
diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index 1b85136..a45345a 100644
index 1b85136..a08f822 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -319,7 +319,7 @@ __global__ void __launch_bounds__(512, 1)
__syncthreads();
// Starts allgather
for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
- for (int i = 0; i < nPeer; i++) {
+ for (int i = 0; i < NPEER; i++) {
const int peerIdx = (i + blockIdx.x) % nPeer;
const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock];
@@ -336,13 +336,13 @@ __global__ void __launch_bounds__(512, 1)
for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx];
data = add_vectors<T>(val, data);
}
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
data);
}
@@ -356,7 +356,7 @@ __global__ void __launch_bounds__(512, 1)
}
__syncthreads();
for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
- for (int i = 0; i < nPeer; i++) {
+ for (int i = 0; i < NPEER; i++) {
const int peerIdx = (i + blockIdx.x) % nPeer;
const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock];
@@ -372,13 +372,13 @@ __global__ void __launch_bounds__(512, 1)
for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx];
data = add_vectors<T>(val, data);
}
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
data);
}
@@ -386,19 +386,132 @@ __global__ void __launch_bounds__(512, 1)
@@ -386,24 +386,361 @@ __global__ void __launch_bounds__(512, 1)
}
}
@@ -108,13 +58,13 @@ index 1b85136..a45345a 100644
+
+ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+ int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+ int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
+ data = add_vectors<T>(val, data);
+ }
+ resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
+
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+ outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
+ data);
+ }
@@ -138,12 +88,12 @@ index 1b85136..a45345a 100644
+
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+ int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+ int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
+ data = add_vectors<T>(val, data);
+ }
+ resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+ outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
+ data);
+ }
@@ -158,6 +108,220 @@ index 1b85136..a45345a 100644
+
+}
+
+template <typename T>
+__global__ void __launch_bounds__(1024, 1)
+ allreduce10(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
+ size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
+ const int nPeer = nRanksPerNode - 1;
+ const size_t chanOffset = nPeer * blockIdx.x;
+ // assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
+ const size_t nInt4 = nelems * sizeof(T) / sizeof(int4);
+ const size_t nInt4PerRank = nInt4 / NRANKS1_PER_NODE;
+
+ auto smChans = smChannels + chanOffset;
+ auto smOutChans = smOutChannels + chanOffset;
+ auto smScrChans = smScrChannels + chanOffset;
+
+ int4* buff4 = reinterpret_cast<int4*>(buff);
+ int4* scratch4 = reinterpret_cast<int4*>((char*)scratch + channelScratchOffset);
+ int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);
+
+ // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
+ constexpr size_t unitNInt4 = 512;
+ const size_t maxNInt4PerBlock =
+ (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
+ size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
+ size_t nInt4OfThisBlock = maxNInt4PerBlock;
+ size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
+
+ constexpr size_t nInt4PerChunk = 1024 * 1024 / sizeof(int4); // 256KB
+ int num_nodes = worldSize/NRANKS1_PER_NODE;
+
+ if (blockIdx.x >= nNeededBlocks) {
+ nInt4OfThisBlock = 0;
+ } else if (blockIdx.x == nNeededBlocks - 1) {
+ nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1);
+ }
+
+ const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk;
+ const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk;
+ const size_t chunkSizePerRank = nNeededBlocks * nInt4PerChunk;
+
+ const size_t blockOffset = nInt4PerChunk * blockIdx.x;
+ const size_t scratchChunkRankOffset = chunkSizePerRank * rank;
+ const size_t scratchBaseOffsetInt4 = channelScratchOffset / sizeof(int4);
+
+ int localRank = rank % NRANKS1_PER_NODE;
+
+ __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
+ __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> outChannels[NRANKS_PER_NODE - 1];
+ __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> scrChannels[NRANKS_PER_NODE - 1];
+
+ const int lid = threadIdx.x % WARP_SIZE;
+ if (lid < nPeer) {
+ channels[lid] = smChans[lid];
+ outChannels[lid] = smOutChans[lid];
+ scrChannels[lid] = smScrChans[lid];
+ }
+ __syncwarp();
+
+ // we can use double buffering to hide synchronization overhead
+ for (size_t itr = 0; itr < nItrs; itr++) {
+ if (threadIdx.x < (NRANKS1_PER_NODE-1)) {
+ int myNode = rank/NRANKS1_PER_NODE;
+ int remote = (threadIdx.x + 1 + rank);
+ int remoteNode = remote/NRANKS1_PER_NODE;
+
+ if (remoteNode > myNode) {
+ remote = remote - NRANKS1_PER_NODE;
+ }
+ int peerIdx = remote < rank ? remote : remote - 1;
+ outChannels[peerIdx].signal();
+ outChannels[peerIdx].wait();
+ }
+ __syncthreads();
+
+ int myNode = rank/NRANKS1_PER_NODE;
+
+ //Reduce within an OAM
+ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+ int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock];
+ for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode +
+ NRANKS1_PER_NODE - 1); peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+
+ int4 val = channels[peerIdx].read<int4>(nInt4PerRank * localRank + offsetOfThisBlock + idx);
+ data = add_vectors<T>(val, data);
+ }
+ scratch4[idx + blockOffset] = data;
+ }
+
+ if (threadIdx.x < static_cast<uint32_t>(num_nodes-1)) {
+ int remote = (NRANKS1_PER_NODE * (threadIdx.x + 1) + rank) % worldSize;
+ int peerIdx = remote < rank ? remote : remote - 1;
+ scrChannels[peerIdx].signal();
+ scrChannels[peerIdx].wait();
+ }
+ __syncthreads();
+
+ int remoteRank, peerIdx;
+ //Reduce across OAMs
+
+ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+ int4 data = scratch4[idx + blockOffset];
+
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+ int myLocal = rank % NRANKS1_PER_NODE;
+ int remoteLocal = remoteRank % NRANKS1_PER_NODE;
+
+ if (myLocal == remoteLocal) {
+ int4 val = scrChannels[peerIdx].read<int4>(blockOffset + idx +
+ channelScratchOffset/sizeof(int4));
+ data = add_vectors<T>(val, data);
+ }
+ }
+
+ resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data;
+
+ for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode + NRANKS1_PER_NODE - 1); peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+ outChannels[peerIdx].write(nInt4PerRank * localRank + idx + offsetOfThisBlock +
+ channelOutDataOffset / sizeof(int4), data);
+ }
+ }
+
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+ outChannels[threadIdx.x].signal();
+ outChannels[threadIdx.x].wait();
+ }
+ __syncthreads();
+
+ offsetOfThisBlock += nInt4PerChunk;
+ }
+
+ if (restNInt4 > 0) {
+ if (threadIdx.x < (NRANKS1_PER_NODE-1)) {
+ int myNode = rank/NRANKS1_PER_NODE;
+ int remote = (threadIdx.x + 1 + rank);
+ int remoteNode = remote/NRANKS1_PER_NODE;
+
+ if (remoteNode > myNode) {
+ remote = remote - NRANKS1_PER_NODE;
+ }
+ int peerIdx = remote < rank ? remote : remote - 1;
+
+ outChannels[peerIdx].signal();
+ outChannels[peerIdx].wait();
+ }
+ __syncthreads();
+
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+ int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock];
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+
+ int myNode = rank/NRANKS1_PER_NODE;
+ int remoteNode = remoteRank/NRANKS1_PER_NODE;
+
+ if (myNode == remoteNode) {
+ int4 val = channels[peerIdx].read<int4>(nInt4PerRank * localRank + offsetOfThisBlock + idx);
+ data = add_vectors<T>(val, data);
+ }
+ }
+ scratch4[idx + blockOffset] = data;
+ }
+
+ if (threadIdx.x < static_cast<uint32_t>(num_nodes-1)) {
+ int remote = (NRANKS1_PER_NODE * (threadIdx.x + 1) + rank) % worldSize;
+ int peerIdx = remote < rank ? remote : remote - 1;
+ scrChannels[peerIdx].signal();
+ scrChannels[peerIdx].wait();
+ }
+ __syncthreads();
+
+ int remoteRank, peerIdx;
+ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+ int4 data = scratch4[idx + blockOffset];
+
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+ int myLocal = rank % NRANKS1_PER_NODE;
+ int remoteLocal = remoteRank % NRANKS1_PER_NODE;
+
+ if (myLocal == remoteLocal) {
+ int4 val = scrChannels[peerIdx].read<int4>(blockOffset + idx +
+ channelScratchOffset/sizeof(int4));
+ data = add_vectors<T>(val, data);
+ }
+ }
+
+ resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data;
+ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) {
+ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+ int myNode = rank/NRANKS1_PER_NODE;
+ int remoteNode = remoteRank/NRANKS1_PER_NODE;
+
+ if (myNode == remoteNode) {
+ outChannels[peerIdx].write(nInt4PerRank * localRank + idx + offsetOfThisBlock +
+ channelOutDataOffset / sizeof(int4), data);
+ }
+ }
+ }
+ if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+ outChannels[threadIdx.x].signal();
+ outChannels[threadIdx.x].wait();
+ }
+ __syncthreads();
+
+ }
+
+}
+
+
+
+
template <typename T>
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
@@ -167,16 +331,26 @@ index 1b85136..a45345a 100644
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, cudaStream_t stream) {
static uint32_t flag = 1;
+ int readAllred = 0;
+ int readAllred = 0, hieAllred = 0;
+ char* envValue = nullptr;
+ char* envValue1 = nullptr;
+
+ nRanksPerNode = (worldSize < nRanksPerNode) ? worldSize : nRanksPerNode;
+
+ envValue = std::getenv("MSCCLPP_READ_ALLRED");
+ envValue1 = std::getenv("MSCCLPP_HIERARCHICAL_ALLRED");
+
+ if (envValue != nullptr) {
+ if (atoi(envValue) == 1) {
+ readAllred = 1;
+ }
+ }
+ if (envValue1 != nullptr) {
+ if (atoi(envValue1) == 1) {
+ hieAllred = 1;
+ }
+ }
+
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
int nBlocks = 7;
@@ -187,38 +361,55 @@ index 1b85136..a45345a 100644
+ allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels,
+ channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++);
} else if (sizeof(T) * nelems <= (1 << 20)) {
int nBlocks = 28;
- int nBlocks = 28;
+ int nBlocks = 4*(nRanksPerNode - 1);
int nThreadsPerBlock = 1024;
@@ -412,9 +525,15 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
if (nelems >= 8192) {
- nBlocks = 56;
+ nBlocks = 8*(nRanksPerNode - 1);
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
}
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
@@ -412,9 +749,20 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
} else {
int nBlocks = 35;
int nThreadsPerBlock = 512;
- allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
+ if (!readAllred) {
+ allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels, smOutChannels,
channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
worldSize, nelems);
- channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
- worldSize, nelems);
+ if (hieAllred && worldSize >= 8) {
+ allreduce10<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smScrChannels,
+ smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
+ worldSize, nelems);
+ } else {
+ allreduce8Read<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, resultBuff, smChannels, smOutChannels,
+ channelOutOffset, rank, nRanksPerNode,
+ worldSize, nelems);
+ if (!readAllred) {
+ allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
+ smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
+ worldSize, nelems);
+ } else {
+ allreduce8Read<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, resultBuff, smChannels, smOutChannels,
+ channelOutOffset, rank, nRanksPerNode, worldSize, nelems);
+ }
+ }
}
return cudaGetLastError();
diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp
index 25c74e7..32672c6 100644
index 25c74e7..5e85468 100644
--- a/apps/nccl/src/common.hpp
+++ b/apps/nccl/src/common.hpp
@@ -13,5 +13,6 @@
@@ -11,7 +11,9 @@
#define WARP_SIZE 32
#endif
+constexpr int NRANKS1_PER_NODE = 4;
constexpr int NRANKS_PER_NODE = 8;
constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB
+constexpr int NPEER = 7;
#endif // NCCL_COMMON_HPP_
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index ec130b0..571508d 100644
index cb0e7d5..a697be2 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -49,7 +49,9 @@ struct hash<channelKey> {
@@ -318,7 +509,7 @@ index ec130b0..571508d 100644
break;
default:
return ncclInvalidArgument;
@@ -551,7 +562,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
@@ -550,7 +561,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });