diff --git a/projects/rccl/cmake/MSCCLPP.cmake b/projects/rccl/cmake/MSCCLPP.cmake index 0a9ec655fd..c23ce59c80 100644 --- a/projects/rccl/cmake/MSCCLPP.cmake +++ b/projects/rccl/cmake/MSCCLPP.cmake @@ -88,6 +88,12 @@ if(ENABLE_MSCCLPP) WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/no-cache.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + set(CMAKE_INHERITED_ARGS "") set(CMAKE_ARGS_LIST "CMAKE_PREFIX_PATH;CMAKE_INSTALL_RPATH_USE_LINK_PATH;HIP_COMPILER") foreach(arg IN LISTS CMAKE_ARGS_LIST) @@ -131,40 +137,45 @@ if(ENABLE_MSCCLPP) find_package(mscclpp_nccl REQUIRED) - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/no-cache.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/non-multiple-128-fix.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/non-multiple-128-fix.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) #endif() diff --git a/projects/rccl/ext-src/no-cache.patch b/projects/rccl/ext-src/no-cache.patch new file mode 100644 index 0000000000..1ff63ab52d --- /dev/null +++ b/projects/rccl/ext-src/no-cache.patch @@ -0,0 +1,445 @@ +diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp +index a14dfbc..66596f3 100644 +--- a/apps/nccl/src/allreduce.hpp ++++ b/apps/nccl/src/allreduce.hpp +@@ -17,6 +17,7 @@ + #endif + + #include "common.hpp" ++#include "debug.h" + + template + __forceinline__ __device__ To bit_cast(const From& src) { +@@ -359,6 +360,176 @@ __global__ void __launch_bounds__(1024, 1) + #endif + } + ++template ++__global__ void __launch_bounds__(512, 1) ++ allreduce8Mod(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, ++ size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) { ++ const int nPeer = nRanksPerNode - 1; ++ const size_t chanOffset = nPeer * blockIdx.x; ++ // assume (nelems * sizeof(T)) is divisible by (16 * worldSize) ++ const size_t nInt4 = nelems * sizeof(T) / sizeof(int4); ++ size_t nInt4PerRank = nInt4 / worldSize; ++ if (nInt4 % worldSize) ++ nInt4PerRank = nInt4PerRank + 1; ++ ++ auto smChans = smChannels + chanOffset; ++ ++ size_t channelScratchResultOffset = channelScratchOffset + SCRATCH_SIZE/2; ++ ++ int4* buff4 = reinterpret_cast(buff); ++ int4* scratch4 = reinterpret_cast((char*)scratch + channelScratchOffset); ++ int4* scratch4Result = reinterpret_cast((char*)scratch + channelScratchResultOffset); ++ int4* resultBuff4 = reinterpret_cast(resultBuff); ++ ++ // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4` ++ constexpr size_t unitNInt4 = 512; ++ const size_t maxNInt4PerBlock = ++ (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4; ++ size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x; ++ size_t nInt4OfThisBlock = maxNInt4PerBlock; ++ size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock; ++ constexpr size_t nInt4PerChunk = 1024 * 256 / sizeof(int4); // 256KB ++ if (blockIdx.x >= nNeededBlocks) { ++ nInt4OfThisBlock = 0; ++ } else if (blockIdx.x == nNeededBlocks - 1) { ++ nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1); ++ } ++ const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk; ++ const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk; ++ const size_t chunkSizePerRank = nNeededBlocks * nInt4PerChunk; ++ const size_t blockOffset = nInt4PerChunk * blockIdx.x; ++ ++ const size_t scratchChunkRankOffset = chunkSizePerRank * rank; ++ const size_t scratchBaseOffsetInt4 = channelScratchOffset / sizeof(int4); ++ const size_t scratchResultBaseOffsetInt4 = channelScratchResultOffset / sizeof(int4); ++ ++ __shared__ mscclpp::DeviceHandle channels[NRANKS_PER_NODE - 1]; ++ ++ const int lid = threadIdx.x % WARP_SIZE; ++ if (lid < nPeer) { ++ channels[lid] = smChans[lid]; ++ } ++ __syncwarp(); ++ ++ // we can use double buffering to hide synchronization overhead ++ for (size_t itr = 0; itr < nItrs; itr++) { ++ if (threadIdx.x < static_cast(nPeer)) { ++ channels[threadIdx.x].signal(); ++ channels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ // Starts allgather ++ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { ++ for (int i = 0; i < NPEERS; i++) { ++ const int peerIdx = (i + blockIdx.x) % nPeer; ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; ++ channels[peerIdx].write(scratchBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, val); ++ } ++ } ++ ++ /// Starts reduce-scatter ++ // Ensure that all writes of this block have been issued before issuing the signal ++ __syncthreads(); ++ if (threadIdx.x < static_cast(nPeer)) { ++ channels[threadIdx.x].signal(); ++ channels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ ++ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { ++ int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; ++ data = add_vectors(val, data); ++ } ++ resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { ++ channels[peerIdx].write(scratchResultBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, data); ++ } ++ } ++ __syncthreads(); ++ ++ if (threadIdx.x < static_cast(nPeer)) { ++ channels[threadIdx.x].signal(); ++ channels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ ++ for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ int4 val = scratch4Result[chunkSizePerRank * remoteRank + blockOffset + idx]; ++ resultBuff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock] = val; ++ } ++ } ++ __syncthreads(); ++ ++ offsetOfThisBlock += nInt4PerChunk; ++ // Ensure all threads have consumed data from scratch buffer before signaling re-use in next iteration ++ } ++ if (restNInt4 > 0) { ++ if (threadIdx.x < static_cast(nPeer)) { ++ channels[threadIdx.x].signal(); ++ channels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { ++ for (int i = 0; i < NPEERS; i++) { ++ const int peerIdx = (i + blockIdx.x) % nPeer; ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; ++ channels[peerIdx].write(scratchBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, val); ++ } ++ } ++ ++ // Ensure that all writes of this block have been issued before issuing the signal ++ __syncthreads(); ++ if (threadIdx.x < static_cast(nPeer)) { ++ channels[threadIdx.x].signal(); ++ channels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ ++ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { ++ int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; ++ data = add_vectors(val, data); ++ } ++ resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { ++ channels[peerIdx].write(scratchResultBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, data); ++ } ++ } ++ __syncthreads(); ++ if (threadIdx.x < static_cast(nPeer)) { ++ channels[threadIdx.x].signal(); ++ channels[threadIdx.x].wait(); ++ } ++ __syncthreads(); ++ ++ for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { ++ for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) { ++ const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; ++ int4 val = scratch4Result[chunkSizePerRank * remoteRank + blockOffset + idx]; ++ resultBuff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock] = val; ++ } ++ } ++ __syncthreads(); ++ // Ensure all threads have issued writes to outChannel ++ } ++ // Threads are already synchronized ++ // So all writes to outChannel have been issued before signal is being issued ++ if (threadIdx.x < static_cast(nPeer)) { ++ channels[threadIdx.x].signal(); ++ channels[threadIdx.x].wait(); ++ } ++} ++ ++ + template + __global__ void __launch_bounds__(512, 1) + allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, +@@ -808,25 +979,9 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< + size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, + size_t nelems, cudaStream_t stream) { + static uint32_t flag = 1; +- int readAllred = 0, hieAllred = 0; +- char* envValue = nullptr; +- char* envValue1 = nullptr; + + nRanksPerNode = (worldSize < nRanksPerNode) ? worldSize : nRanksPerNode; + +- envValue = std::getenv("MSCCLPP_READ_ALLRED"); +- envValue1 = std::getenv("MSCCLPP_HIERARCHICAL_ALLRED"); +- +- if (envValue != nullptr) { +- if (atoi(envValue) == 1) { +- readAllred = 1; +- } +- } +- if (envValue1 != nullptr) { +- if (atoi(envValue1) == 1) { +- hieAllred = 1; +- } +- } + if (sizeof(T) * nelems < worldSize * sizeof(int)) { + int nBlocks = nRanksPerNode - 1; + int nThreadsPerBlock = 32; +@@ -852,16 +1007,21 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< + } else { + int nBlocks = 8 * (nRanksPerNode - 1); + int nThreadsPerBlock = 512; +- if (hieAllred && worldSize >= 8) { ++ if (mscclppHierarchicalAllred && worldSize >= 8) { + nBlocks = 20; + allreduce10<<>>(buff, scratch, resultBuff, smChannels, smScrChannels, + smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, + worldSize, nelems); + } else { +- if (!readAllred) { +- allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, +- smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, +- worldSize, nelems); ++ if (!mscclppReadAllred) { ++ if (mscclppDisableRemoteUbr) { ++ allreduce8Mod<<>>(buff, scratch, resultBuff, smScrChannels, ++ channelScratchOffset, rank, nRanksPerNode, worldSize, nelems); ++ } else { ++ allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, ++ smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, ++ worldSize, nelems); ++ } + } else { + allreduce8Read<<>>(buff, resultBuff, smChannels, smOutChannels, + channelOutOffset, rank, nRanksPerNode, worldSize, nelems); +diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp +index a6056ea..6b7ec02 100644 +--- a/apps/nccl/src/common.hpp ++++ b/apps/nccl/src/common.hpp +@@ -17,7 +17,7 @@ constexpr int NRANKS1_PER_NODE = 4; + constexpr int NRANKS_PER_NODE = 8; + constexpr int NPEERS = 7; + +-constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 112; // double buffer * 56 thread-blocks * 8 ranks * 256KB = 112MB ++constexpr int SCRATCH_SIZE = 4 * 1024 * 1024 * 112; // double buffer * 56 thread-blocks * 8 ranks * 256KB = 112MB + + __device__ mscclpp::DeviceSyncer deviceSyncer; + +diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu +index 5fb99ef..1e8d739 100644 +--- a/apps/nccl/src/nccl.cu ++++ b/apps/nccl/src/nccl.cu +@@ -84,6 +84,7 @@ struct ncclComm { + + std::unordered_map channelInInfos; + std::unordered_map channelOutInfos; ++ std::vector channelInfos; + std::unordered_map channelScratchInfos; + std::unordered_map regHandles; + std::unordered_map handleKeys; +@@ -216,7 +217,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, + size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr; + size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr; + uint32_t scratchBuffIdx = (++(comm->buffFlag)) % comm->numScratchBuff; +- size_t offsetScratch = (SCRATCH_SIZE / comm->numScratchBuff) * scratchBuffIdx; ++ size_t offsetScratch = (SCRATCH_SIZE / (2 * comm->numScratchBuff)) * scratchBuffIdx; + int rank = comm->comm->bootstrap()->getRank(); + channelKey sendKey{(void*)sendBasePtr, sendBytes}; + channelKey recvKey{(void*)recvBasePtr, recvBytes}; +@@ -224,6 +225,8 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, + mscclpp::DeviceHandle* smOutChannels = nullptr; + mscclpp::DeviceHandle* smScrChannels = nullptr; + ++ size_t bytes = count * ncclTypeSize(datatype); ++ + // Creating the channels + if (count * ncclTypeSize(datatype) <= (1 << 20)) { + auto sendIt = comm->channelScratchInfos.find(sendKey); +@@ -251,25 +254,39 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, + sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first; + } + +- auto recvIt = comm->channelOutInfos.find(recvKey); +- if (recvIt == comm->channelOutInfos.end()) { +- remoteMemories = +- setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc); +- std::vector outChannels = +- setupSmChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); +- ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)}; +- recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first; ++ if(mscclppDisableRemoteUbr == false) { ++ auto recvIt = comm->channelOutInfos.find(recvKey); ++ if (recvIt == comm->channelOutInfos.end() || mscclppDisableChannelCache == true) { ++ if (mscclppDisableChannelCache == true) { ++ recvBytes = bytes; ++ recvBasePtr = (CUdeviceptr)recvbuff; ++ offsetOut = 0; ++ } ++ remoteMemories = ++ setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc); ++ std::vector outChannels = ++ setupSmChannels(comm, remoteMemories, const_cast((void*)recvBasePtr)); ++ ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)}; ++ if (mscclppDisableChannelCache == true) { ++ comm->channelInfos.push_back(channelInfo); ++ smOutChannels = comm->channelInfos.back().smChannelDeviceHandles.get(); ++ } else { ++ recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first; ++ smOutChannels = recvIt->second.smChannelDeviceHandles.get(); ++ } ++ } else { ++ smOutChannels = recvIt->second.smChannelDeviceHandles.get(); ++ } + } + + smChannels = sendIt->second.smChannelDeviceHandles1.get(); +- smOutChannels = recvIt->second.smChannelDeviceHandles.get(); + smScrChannels = sendIt->second.smChannelDeviceHandles.get(); + } + + switch (datatype) { + case ncclFloat16: + CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels, +- smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); ++ smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + break; + case ncclFloat32: + CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels, +@@ -323,7 +340,12 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, + mscclpp::DeviceHandle* smChannels = nullptr; + + auto it = comm->channelOutInfos.find(recvKey); +- if (it == comm->channelOutInfos.end()) { ++ if (it == comm->channelOutInfos.end() || mscclppDisableChannelCache == true) { ++ if (mscclppDisableChannelCache == true) { ++ recvBytes = bytes; ++ recvBasePtr = (CUdeviceptr)recvbuff; ++ offsetOut = 0; ++ } + std::vector remoteMemories = setupRemoteMemories( + comm->comm, rank, const_cast((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc); + std::vector channels = +@@ -332,10 +354,17 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, + std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles), + [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); + ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)}; +- it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; ++ if (mscclppDisableChannelCache == true) { ++ comm->channelInfos.push_back(channelInfo); ++ smChannels = comm->channelInfos.back().smChannelDeviceHandles.get(); ++ } else { ++ it = comm->channelOutInfos.emplace(recvKey, channelInfo).first; ++ smChannels = it->second.smChannelDeviceHandles.get(); ++ } ++ } else { ++ smChannels = it->second.smChannelDeviceHandles.get(); + } + +- smChannels = it->second.smChannelDeviceHandles.get(); + if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) { + CUDACHECK(allgather((int*)sendbuff, (int*)nullptr, (int*)recvbuff, smChannels, offsetOut, rank, + NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream)); +diff --git a/src/debug.cc b/src/debug.cc +index a8350fb..d49eb4a 100644 +--- a/src/debug.cc ++++ b/src/debug.cc +@@ -15,6 +15,11 @@ + #include + + int mscclppDebugLevel = -1; ++bool mscclppDisableChannelCache = false; ++bool mscclppDisableRemoteUbr = false; ++bool mscclppReadAllred = false; ++bool mscclppHierarchicalAllred = false; ++ + static int pid = -1; + static std::string hostname; + thread_local int mscclppDebugNoWarn = 0; +@@ -51,6 +56,42 @@ void mscclppDebugInit() { + tempNcclDebugLevel = MSCCLPP_LOG_TRACE; + } + ++ const char* disable_channel_cache = getenv("MSCCLPP_DISABLE_CHANNEL_CACHE"); ++ if (disable_channel_cache == NULL) { ++ mscclppDisableChannelCache = false; ++ } else if (strcasecmp(disable_channel_cache, "TRUE") == 0) { ++ mscclppDisableChannelCache = true; ++ } else { ++ mscclppDisableChannelCache = false; ++ } ++ ++ const char* enable_read_allred = getenv("MSCCLPP_READ_ALLRED"); ++ if (enable_read_allred == NULL) { ++ mscclppReadAllred = false; ++ } else if (strcasecmp(enable_read_allred, "TRUE") == 0) { ++ mscclppReadAllred = true; ++ } else { ++ mscclppReadAllred = false; ++ } ++ ++ const char* enable_hierarchical_allred = getenv("MSCCLPP_HIERARCHICAL_ALLRED"); ++ if (enable_hierarchical_allred == NULL) { ++ mscclppHierarchicalAllred = false; ++ } else if (strcasecmp(enable_hierarchical_allred, "TRUE") == 0) { ++ mscclppHierarchicalAllred = true; ++ } else { ++ mscclppHierarchicalAllred = false; ++ } ++ ++ const char* disable_remote_ubr = getenv("MSCCLPP_DISABLE_REMOTE_UBR"); ++ if (disable_remote_ubr == NULL) { ++ mscclppDisableRemoteUbr = false; ++ } else if (strcasecmp(disable_remote_ubr, "TRUE") == 0) { ++ mscclppDisableRemoteUbr = true; ++ } else { ++ mscclppDisableRemoteUbr = false; ++ } ++ + /* Parse the MSCCLPP_DEBUG_SUBSYS env var + * This can be a comma separated list such as INIT,COLL + * or ^INIT,COLL etc +diff --git a/src/include/debug.h b/src/include/debug.h +index 713371b..033e5eb 100644 +--- a/src/include/debug.h ++++ b/src/include/debug.h +@@ -91,6 +91,12 @@ typedef enum { + + extern int mscclppDebugLevel; + extern uint64_t mscclppDebugMask; ++ ++extern bool mscclppDisableChannelCache; ++extern bool mscclppReadAllred; ++extern bool mscclppHierarchicalAllred; ++extern bool mscclppDisableRemoteUbr; ++ + extern pthread_mutex_t mscclppDebugLock; + extern FILE* mscclppDebugFile; +