diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index a14dfbc..66596f3 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -17,6 +17,7 @@ #endif #include "common.hpp" +#include "debug.h" template __forceinline__ __device__ To bit_cast(const From& src) { @@ -359,6 +360,176 @@ __global__ void __launch_bounds__(1024, 1) #endif } +template +__global__ void __launch_bounds__(512, 1) + allreduce8Mod(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, + size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) { + const int nPeer = nRanksPerNode - 1; + const size_t chanOffset = nPeer * blockIdx.x; + // assume (nelems * sizeof(T)) is divisible by (16 * worldSize) + const size_t nInt4 = nelems * sizeof(T) / sizeof(int4); + size_t nInt4PerRank = nInt4 / worldSize; + if (nInt4 % worldSize) + nInt4PerRank = nInt4PerRank + 1; + + auto smChans = smChannels + chanOffset; + + size_t channelScratchResultOffset = channelScratchOffset + SCRATCH_SIZE/2; + + int4* buff4 = reinterpret_cast(buff); + int4* scratch4 = reinterpret_cast((char*)scratch + channelScratchOffset); + int4* scratch4Result = reinterpret_cast((char*)scratch + channelScratchResultOffset); + int4* resultBuff4 = reinterpret_cast(resultBuff); + + // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4` + constexpr size_t unitNInt4 = 512; + const size_t maxNInt4PerBlock = + (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4; + size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x; + size_t nInt4OfThisBlock = maxNInt4PerBlock; + size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock; + constexpr size_t nInt4PerChunk = 1024 * 256 / sizeof(int4); // 256KB + if (blockIdx.x >= nNeededBlocks) { + nInt4OfThisBlock = 0; + } else if (blockIdx.x == nNeededBlocks - 1) { + nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1); + } + const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk; + const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk; + const size_t chunkSizePerRank = nNeededBlocks * nInt4PerChunk; + const size_t blockOffset = nInt4PerChunk * blockIdx.x; + + const size_t scratchChunkRankOffset = chunkSizePerRank * rank; + const size_t scratchBaseOffsetInt4 = channelScratchOffset / sizeof(int4); + const size_t scratchResultBaseOffsetInt4 = channelScratchResultOffset / sizeof(int4); + + __shared__ mscclpp::DeviceHandle channels[NRANKS_PER_NODE - 1]; + + const int lid = threadIdx.x % WARP_SIZE; + if (lid < nPeer) { + channels[lid] = smChans[lid]; + } + __syncwarp(); + + // we can use double buffering to hide synchronization overhead + for (size_t itr = 0; itr < nItrs; itr++) { + if (threadIdx.x < static_cast(nPeer)) { + channels[threadIdx.x].signal(); + channels[threadIdx.x].wait(); + } + __syncthreads(); + // Starts allgather + for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { + for (int i = 0; i < NPEERS; i++) { + const int peerIdx = (i + blockIdx.x) % nPeer; + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; + channels[peerIdx].write(scratchBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, val); + } + } + + /// Starts reduce-scatter + // Ensure that all writes of this block have been issued before issuing the signal + __syncthreads(); + if (threadIdx.x < static_cast(nPeer)) { + channels[threadIdx.x].signal(); + channels[threadIdx.x].wait(); + } + __syncthreads(); + + for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { + int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; + data = add_vectors(val, data); + } + resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + channels[peerIdx].write(scratchResultBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, data); + } + } + __syncthreads(); + + if (threadIdx.x < static_cast(nPeer)) { + channels[threadIdx.x].signal(); + channels[threadIdx.x].wait(); + } + __syncthreads(); + + for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = scratch4Result[chunkSizePerRank * remoteRank + blockOffset + idx]; + resultBuff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock] = val; + } + } + __syncthreads(); + + offsetOfThisBlock += nInt4PerChunk; + // Ensure all threads have consumed data from scratch buffer before signaling re-use in next iteration + } + if (restNInt4 > 0) { + if (threadIdx.x < static_cast(nPeer)) { + channels[threadIdx.x].signal(); + channels[threadIdx.x].wait(); + } + __syncthreads(); + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { + for (int i = 0; i < NPEERS; i++) { + const int peerIdx = (i + blockIdx.x) % nPeer; + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; + channels[peerIdx].write(scratchBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, val); + } + } + + // Ensure that all writes of this block have been issued before issuing the signal + __syncthreads(); + if (threadIdx.x < static_cast(nPeer)) { + channels[threadIdx.x].signal(); + channels[threadIdx.x].wait(); + } + __syncthreads(); + + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { + int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; + data = add_vectors(val, data); + } + resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + channels[peerIdx].write(scratchResultBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, data); + } + } + __syncthreads(); + if (threadIdx.x < static_cast(nPeer)) { + channels[threadIdx.x].signal(); + channels[threadIdx.x].wait(); + } + __syncthreads(); + + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { + for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = scratch4Result[chunkSizePerRank * remoteRank + blockOffset + idx]; + resultBuff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock] = val; + } + } + __syncthreads(); + // Ensure all threads have issued writes to outChannel + } + // Threads are already synchronized + // So all writes to outChannel have been issued before signal is being issued + if (threadIdx.x < static_cast(nPeer)) { + channels[threadIdx.x].signal(); + channels[threadIdx.x].wait(); + } +} + + template __global__ void __launch_bounds__(512, 1) allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, @@ -808,25 +979,9 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems, cudaStream_t stream) { static uint32_t flag = 1; - int readAllred = 0, hieAllred = 0; - char* envValue = nullptr; - char* envValue1 = nullptr; nRanksPerNode = (worldSize < nRanksPerNode) ? worldSize : nRanksPerNode; - envValue = std::getenv("MSCCLPP_READ_ALLRED"); - envValue1 = std::getenv("MSCCLPP_HIERARCHICAL_ALLRED"); - - if (envValue != nullptr) { - if (atoi(envValue) == 1) { - readAllred = 1; - } - } - if (envValue1 != nullptr) { - if (atoi(envValue1) == 1) { - hieAllred = 1; - } - } if (sizeof(T) * nelems < worldSize * sizeof(int)) { int nBlocks = nRanksPerNode - 1; int nThreadsPerBlock = 32; @@ -852,16 +1007,21 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< } else { int nBlocks = 8 * (nRanksPerNode - 1); int nThreadsPerBlock = 512; - if (hieAllred && worldSize >= 8) { + if (mscclppHierarchicalAllred && worldSize >= 8) { nBlocks = 20; allreduce10<<>>(buff, scratch, resultBuff, smChannels, smScrChannels, smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems); } else { - if (!readAllred) { - allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, - smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, - worldSize, nelems); + if (!mscclppReadAllred) { + if (mscclppDisableRemoteUbr) { + allreduce8Mod<<>>(buff, scratch, resultBuff, smScrChannels, + channelScratchOffset, rank, nRanksPerNode, worldSize, nelems); + } else { + allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, + smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, + worldSize, nelems); + } } else { allreduce8Read<<>>(buff, resultBuff, smChannels, smOutChannels, channelOutOffset, rank, nRanksPerNode, worldSize, nelems); diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp index a6056ea..6b7ec02 100644 --- a/apps/nccl/src/common.hpp +++ b/apps/nccl/src/common.hpp @@ -17,7 +17,7 @@ constexpr int NRANKS1_PER_NODE = 4; constexpr int NRANKS_PER_NODE = 8; constexpr int NPEERS = 7; -constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 112; // double buffer * 56 thread-blocks * 8 ranks * 256KB = 112MB +constexpr int SCRATCH_SIZE = 4 * 1024 * 1024 * 112; // double buffer * 56 thread-blocks * 8 ranks * 256KB = 112MB __device__ mscclpp::DeviceSyncer deviceSyncer; diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 5fb99ef..1e8d739 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -84,6 +84,7 @@ struct ncclComm { std::unordered_map channelInInfos; std::unordered_map channelOutInfos; + std::vector channelInfos; std::unordered_map channelScratchInfos; std::unordered_map regHandles; std::unordered_map handleKeys; @@ -216,7 +217,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr; size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr; uint32_t scratchBuffIdx = (++(comm->buffFlag)) % comm->numScratchBuff; - size_t offsetScratch = (SCRATCH_SIZE / comm->numScratchBuff) * scratchBuffIdx; + size_t offsetScratch = (SCRATCH_SIZE / (2 * comm->numScratchBuff)) * scratchBuffIdx; int rank = comm->comm->bootstrap()->getRank(); channelKey sendKey{(void*)sendBasePtr, sendBytes}; channelKey recvKey{(void*)recvBasePtr, recvBytes}; @@ -224,6 +225,8 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, mscclpp::DeviceHandle* smOutChannels = nullptr; mscclpp::DeviceHandle* smScrChannels = nullptr; + size_t bytes = count * ncclTypeSize(datatype); + // Creating the channels if (count * ncclTypeSize(datatype) <= (1 << 20)) { auto sendIt = comm->channelScratchInfos.find(sendKey); @@ -251,25 +254,39 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first; } - auto recvIt = comm->channelOutInfos.find(recvKey); - if (recvIt == comm->channelOutInfos.end()) { - remoteMemories = - setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc); - std::vector outChannels = - setupSmChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); - ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)}; - recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + if(mscclppDisableRemoteUbr == false) { + auto recvIt = comm->channelOutInfos.find(recvKey); + if (recvIt == comm->channelOutInfos.end() || mscclppDisableChannelCache == true) { + if (mscclppDisableChannelCache == true) { + recvBytes = bytes; + recvBasePtr = (CUdeviceptr)recvbuff; + offsetOut = 0; + } + remoteMemories = + setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc); + std::vector outChannels = + setupSmChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); + ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)}; + if (mscclppDisableChannelCache == true) { + comm->channelInfos.push_back(channelInfo); + smOutChannels = comm->channelInfos.back().smChannelDeviceHandles.get(); + } else { + recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + smOutChannels = recvIt->second.smChannelDeviceHandles.get(); + } + } else { + smOutChannels = recvIt->second.smChannelDeviceHandles.get(); + } } smChannels = sendIt->second.smChannelDeviceHandles1.get(); - smOutChannels = recvIt->second.smChannelDeviceHandles.get(); smScrChannels = sendIt->second.smChannelDeviceHandles.get(); } switch (datatype) { case ncclFloat16: CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels, - smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); break; case ncclFloat32: CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels, @@ -323,7 +340,12 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, mscclpp::DeviceHandle* smChannels = nullptr; auto it = comm->channelOutInfos.find(recvKey); - if (it == comm->channelOutInfos.end()) { + if (it == comm->channelOutInfos.end() || mscclppDisableChannelCache == true) { + if (mscclppDisableChannelCache == true) { + recvBytes = bytes; + recvBasePtr = (CUdeviceptr)recvbuff; + offsetOut = 0; + } std::vector remoteMemories = setupRemoteMemories( comm->comm, rank, const_cast((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc); std::vector channels = @@ -332,10 +354,17 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles), [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)}; - it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + if (mscclppDisableChannelCache == true) { + comm->channelInfos.push_back(channelInfo); + smChannels = comm->channelInfos.back().smChannelDeviceHandles.get(); + } else { + it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; + smChannels = it->second.smChannelDeviceHandles.get(); + } + } else { + smChannels = it->second.smChannelDeviceHandles.get(); } - smChannels = it->second.smChannelDeviceHandles.get(); if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) { CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); diff --git a/src/debug.cc b/src/debug.cc index a8350fb..d49eb4a 100644 --- a/src/debug.cc +++ b/src/debug.cc @@ -15,6 +15,11 @@ #include int mscclppDebugLevel = -1; +bool mscclppDisableChannelCache = false; +bool mscclppDisableRemoteUbr = false; +bool mscclppReadAllred = false; +bool mscclppHierarchicalAllred = false; + static int pid = -1; static std::string hostname; thread_local int mscclppDebugNoWarn = 0; @@ -51,6 +56,42 @@ void mscclppDebugInit() { tempNcclDebugLevel = MSCCLPP_LOG_TRACE; } + const char* disable_channel_cache = getenv("MSCCLPP_DISABLE_CHANNEL_CACHE"); + if (disable_channel_cache == NULL) { + mscclppDisableChannelCache = false; + } else if (strcasecmp(disable_channel_cache, "TRUE") == 0) { + mscclppDisableChannelCache = true; + } else { + mscclppDisableChannelCache = false; + } + + const char* enable_read_allred = getenv("MSCCLPP_READ_ALLRED"); + if (enable_read_allred == NULL) { + mscclppReadAllred = false; + } else if (strcasecmp(enable_read_allred, "TRUE") == 0) { + mscclppReadAllred = true; + } else { + mscclppReadAllred = false; + } + + const char* enable_hierarchical_allred = getenv("MSCCLPP_HIERARCHICAL_ALLRED"); + if (enable_hierarchical_allred == NULL) { + mscclppHierarchicalAllred = false; + } else if (strcasecmp(enable_hierarchical_allred, "TRUE") == 0) { + mscclppHierarchicalAllred = true; + } else { + mscclppHierarchicalAllred = false; + } + + const char* disable_remote_ubr = getenv("MSCCLPP_DISABLE_REMOTE_UBR"); + if (disable_remote_ubr == NULL) { + mscclppDisableRemoteUbr = false; + } else if (strcasecmp(disable_remote_ubr, "TRUE") == 0) { + mscclppDisableRemoteUbr = true; + } else { + mscclppDisableRemoteUbr = false; + } + /* Parse the MSCCLPP_DEBUG_SUBSYS env var * This can be a comma separated list such as INIT,COLL * or ^INIT,COLL etc diff --git a/src/include/debug.h b/src/include/debug.h index 713371b..033e5eb 100644 --- a/src/include/debug.h +++ b/src/include/debug.h @@ -91,6 +91,12 @@ typedef enum { extern int mscclppDebugLevel; extern uint64_t mscclppDebugMask; + +extern bool mscclppDisableChannelCache; +extern bool mscclppReadAllred; +extern bool mscclppHierarchicalAllred; +extern bool mscclppDisableRemoteUbr; + extern pthread_mutex_t mscclppDebugLock; extern FILE* mscclppDebugFile;