From f61053dcba681041c3a3d83e895feaed443ee762 Mon Sep 17 00:00:00 2001 From: Nusrat Islam Date: Tue, 8 Oct 2024 14:42:12 -0500 Subject: [PATCH] Add a custom allreduce algorithm in MSCCLPP for cpx mode (#1362) * cmake: remove mscclpp patch after build is complete To enable mscclpp in cpx mode, the patch cpx.patch needs to be applied. This patch can be reverted once the build is done, which helps the next build. * Use read-based mscclpp allreduce from rccl By default, MSCCLPP uses remote writes in the allreduce kernel for large (> 1MB) messages. This PR adds an allreduce kernel that uses remote reads instead. To enable it, users must set the environment variable MSCCLPP_READ_ALLRED=1. [ROCm/rccl commit: 4d68751ce193675ff9ee6635e3caa392b161bb95] --- projects/rccl/cmake/MSCCLPP.cmake | 16 +- projects/rccl/ext-src/read-allred.patch | 329 ++++++++++++++++++++++++ 2 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 projects/rccl/ext-src/read-allred.patch diff --git a/projects/rccl/cmake/MSCCLPP.cmake b/projects/rccl/cmake/MSCCLPP.cmake index 54e3fb9654..32104440fa 100644 --- a/projects/rccl/cmake/MSCCLPP.cmake +++ b/projects/rccl/cmake/MSCCLPP.cmake @@ -66,7 +66,12 @@ if(ENABLE_MSCCLPP) execute_process( COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) + ) + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + message(STATUS "Building mscclpp only for gfx942.") mscclpp_cmake_arg(CMAKE_PREFIX_PATH) @@ -89,6 +94,14 @@ if(ENABLE_MSCCLPP) ) find_package(mscclpp_nccl REQUIRED) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) endif() execute_process(COMMAND objcopy @@ -98,4 +111,5 @@ 
if(ENABLE_MSCCLPP) ) add_library(mscclpp_nccl STATIC IMPORTED) set_target_properties(mscclpp_nccl PROPERTIES IMPORTED_LOCATION ${PROJECT_BINARY_DIR}/libmscclpp_nccl.a) + endif() diff --git a/projects/rccl/ext-src/read-allred.patch b/projects/rccl/ext-src/read-allred.patch new file mode 100644 index 0000000000..e386c64b74 --- /dev/null +++ b/projects/rccl/ext-src/read-allred.patch @@ -0,0 +1,329 @@ +diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp +index 1b85136..a45345a 100644 +--- a/apps/nccl/src/allreduce.hpp ++++ b/apps/nccl/src/allreduce.hpp +@@ -319,7 +319,7 @@ __global__ void __launch_bounds__(512, 1) + __syncthreads(); + // Starts allgather + for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { +- for (int i = 0; i < nPeer; i++) { ++ for (int i = 0; i < NPEER; i++) { + const int peerIdx = (i + blockIdx.x) % nPeer; + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; +@@ -336,13 +336,13 @@ __global__ void __launch_bounds__(512, 1) + + for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { + int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; +- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? 
peerIdx : peerIdx + 1; + int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; + data = add_vectors(val, data); + } + resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; +- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { + outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), + data); + } +@@ -356,7 +356,7 @@ __global__ void __launch_bounds__(512, 1) + } + __syncthreads(); + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { +- for (int i = 0; i < nPeer; i++) { ++ for (int i = 0; i < NPEER; i++) { + const int peerIdx = (i + blockIdx.x) % nPeer; + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; +@@ -372,13 +372,13 @@ __global__ void __launch_bounds__(512, 1) + + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { + int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; +- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? 
peerIdx : peerIdx + 1; + int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; + data = add_vectors(val, data); + } + resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; +- for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { + outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), + data); + } +@@ -386,19 +386,132 @@ __global__ void __launch_bounds__(512, 1) + } + } + ++template ++__global__ void __launch_bounds__(512, 1) ++ allreduce8Read(T* buff, T* resultBuff, mscclpp::DeviceHandle* smChannels, ++ mscclpp::DeviceHandle* smOutChannels, size_t channelOutDataOffset, ++ int rank, int nRanksPerNode, int worldSize, size_t nelems) { ++ const int nPeer = nRanksPerNode - 1; ++ const size_t chanOffset = nPeer * blockIdx.x; ++ // assume (nelems * sizeof(T)) is divisible by (16 * worldSize) ++ const size_t nInt4 = nelems * sizeof(T) / sizeof(int4); ++ const size_t nInt4PerRank = nInt4 / worldSize; ++ auto smChans = smChannels + chanOffset; ++ auto smOutChans = smOutChannels + chanOffset; ++ ++ int4* buff4 = reinterpret_cast(buff); ++ int4* resultBuff4 = reinterpret_cast(resultBuff); ++ ++ // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4` ++ constexpr size_t unitNInt4 = 512; ++ const size_t maxNInt4PerBlock = ++ (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4; ++ size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x; ++ size_t nInt4OfThisBlock = maxNInt4PerBlock; ++ size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock; ++ constexpr size_t nInt4PerChunk = 1024 * 256 / sizeof(int4); // 256KB ++ if (blockIdx.x >= nNeededBlocks) { ++ nInt4OfThisBlock = 0; ++ } else if (blockIdx.x == nNeededBlocks - 1) { ++ nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1); ++ } ++ ++ const size_t nItrs = nInt4OfThisBlock / 
nInt4PerChunk; ++ const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk; ++ ++ __shared__ mscclpp::DeviceHandle channels[NRANKS_PER_NODE - 1]; ++ __shared__ mscclpp::DeviceHandle outChannels[NRANKS_PER_NODE - 1]; ++ const int lid = threadIdx.x % WARP_SIZE; ++ if (lid < nPeer) { ++ channels[lid] = smChans[lid]; ++ outChannels[lid] = smOutChans[lid]; ++ } ++ __syncwarp(); ++ ++ // we can use double buffering to hide synchronization overhead ++ for (size_t itr = 0; itr < nItrs; itr++) { ++ if (threadIdx.x < static_cast(nPeer)) { ++ channels[threadIdx.x].signal(); ++ channels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ ++ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { ++ int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ int4 val = channels[peerIdx].read(nInt4PerRank * rank + offsetOfThisBlock + idx);; ++ data = add_vectors(val, data); ++ } ++ resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; ++ ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), ++ data); ++ } ++ } ++ if (threadIdx.x < static_cast(nPeer)) { ++ outChannels[threadIdx.x].signal(); ++ outChannels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ ++ offsetOfThisBlock += nInt4PerChunk; ++ } ++ ++ if (restNInt4 > 0) { ++ if (threadIdx.x < static_cast(nPeer)) { ++ channels[threadIdx.x].signal(); ++ channels[threadIdx.x].wait(); ++ ++ } ++ __syncthreads(); ++ ++ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { ++ int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; ++ for (int peerIdx = 0; peerIdx < NPEER; peerIdx++) { ++ int4 val = channels[peerIdx].read(nInt4PerRank * rank + offsetOfThisBlock + idx);; ++ data = add_vectors(val, data); ++ } ++ resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; ++ for (int peerIdx = 0; 
peerIdx < NPEER; peerIdx++) { ++ outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), ++ data); ++ } ++ } ++ ++ if (threadIdx.x < static_cast(nPeer)) { ++ outChannels[threadIdx.x].signal(); ++ outChannels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ } ++ ++} ++ ++ + template + cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, +- mscclpp::DeviceHandle* smOutChannels, size_t channelInOffset, ++ mscclpp::DeviceHandle* smScrChannels, ++ mscclpp::DeviceHandle* smOutChannels, size_t channelInOffset, + size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, + size_t nelems, cudaStream_t stream) { + static uint32_t flag = 1; ++ int readAllred = 0; ++ char* envValue = nullptr; ++ ++ envValue = std::getenv("MSCCLPP_READ_ALLRED"); ++ ++ if (envValue != nullptr) { ++ if (atoi(envValue) == 1) { ++ readAllred = 1; ++ } ++ } + + if (sizeof(T) * nelems < worldSize * sizeof(int)) { + int nBlocks = 7; + int nThreadsPerBlock = 32; +- allreduceAllToAll<<>>(buff, scratch, resultBuff, smChannels, channelInOffset, +- channelScratchOffset, rank, nRanksPerNode, worldSize, +- nelems, flag++); ++ allreduceAllToAll<<>>(buff, scratch, resultBuff, smChannels, ++ channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++); + } else if (sizeof(T) * nelems <= (1 << 20)) { + int nBlocks = 28; + int nThreadsPerBlock = 1024; +@@ -412,9 +525,15 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< + } else { + int nBlocks = 35; + int nThreadsPerBlock = 512; +- allreduce8<<>>(buff, scratch, resultBuff, smChannels, smOutChannels, ++ if (!readAllred) { ++ allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, smOutChannels, + channelOutOffset, channelScratchOffset, rank, nRanksPerNode, + worldSize, nelems); ++ } else { ++ allreduce8Read<<>>(buff, resultBuff, smChannels, smOutChannels, ++ 
channelOutOffset, rank, nRanksPerNode, ++ worldSize, nelems); ++ } + } + + return cudaGetLastError(); +diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp +index 25c74e7..32672c6 100644 +--- a/apps/nccl/src/common.hpp ++++ b/apps/nccl/src/common.hpp +@@ -13,5 +13,6 @@ + + constexpr int NRANKS_PER_NODE = 8; + constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB ++constexpr int NPEER = 7; + + #endif // NCCL_COMMON_HPP_ +diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu +index ec130b0..571508d 100644 +--- a/apps/nccl/src/nccl.cu ++++ b/apps/nccl/src/nccl.cu +@@ -49,7 +49,9 @@ struct hash { + + struct ChannelInfo { + std::vector smChannels; ++ std::vector smChannels1; + std::shared_ptr> smChannelDeviceHandles; ++ std::shared_ptr> smChannelDeviceHandles1; + }; + + struct ncclComm { +@@ -212,8 +214,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, + int rank = comm->comm->bootstrap()->getRank(); + channelKey sendKey{(void*)sendBasePtr, sendBytes}; + channelKey recvKey{(void*)recvBasePtr, recvBytes}; ++ + mscclpp::DeviceHandle* smChannels = nullptr; + mscclpp::DeviceHandle* smOutChannels = nullptr; ++ mscclpp::DeviceHandle* smScrChannels = nullptr; + + // Creating the channels + if (count * ncclTypeSize(datatype) <= comm->largeMessageSizeBoundary) { +@@ -221,19 +225,25 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, + if (sendIt == comm->channelScratchInfos.end()) { + std::vector channels = + setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast((void*)sendBasePtr)); +- ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)}; ++ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)}; + sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first; + } + + smChannels = sendIt->second.smChannelDeviceHandles.get(); + } else 
{ + std::vector remoteMemories; ++ std::vector remoteMemories1; + + auto sendIt = comm->channelInInfos.find(sendKey); + if (sendIt == comm->channelInInfos.end()) { + std::vector channels = + setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast((void*)sendBasePtr)); +- ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)}; ++ remoteMemories1 = ++ setupRemoteMemories(comm->comm, rank, (void*)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc); ++ std::vector channels1 = ++ setupSmChannels(comm, remoteMemories1, const_cast((void*)sendBasePtr)); ++ ++ ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)}; + sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first; + } + +@@ -243,35 +253,36 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, + setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc); + std::vector outChannels = + setupSmChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); +- ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)}; ++ ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)}; + recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + } + +- smChannels = sendIt->second.smChannelDeviceHandles.get(); ++ smChannels = sendIt->second.smChannelDeviceHandles1.get(); + smOutChannels = recvIt->second.smChannelDeviceHandles.get(); ++ smScrChannels = sendIt->second.smChannelDeviceHandles.get(); + } + + switch (datatype) { + case ncclFloat16: +- CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels, +- offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, +- comm->comm->bootstrap()->getNranks(), count, stream)); ++ CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), 
(half*)recvbuff, smChannels, smScrChannels, ++ smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + break; + case ncclFloat32: + CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels, +- smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), ++ smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, ++ comm->comm->bootstrap()->getRank(), + NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + break; + case ncclBfloat16: + CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff, +- smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, +- comm->comm->bootstrap()->getNranks(), count, stream)); ++ smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, ++ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + break; + case ncclInt32: + case ncclUint32: +- CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels, +- offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE, +- comm->comm->bootstrap()->getNranks(), count, stream)); ++ CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels, ++ smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), ++ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + break; + default: + return ncclInvalidArgument; +@@ -551,7 +562,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t + std::vector> smChannelDeviceHandles; + std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles), + [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); +- 
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)}; ++ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)}; + it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + } +