diff --git a/ext-src/read-allred.patch b/ext-src/read-allred.patch index e386c64b74..96135b55db 100644 --- a/ext-src/read-allred.patch +++ b/ext-src/read-allred.patch @@ -162,7 +162,7 @@ index 1b85136..a45345a 100644 template cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, - mscclpp::DeviceHandle* smOutChannels, size_t channelInOffset, -+ mscclpp::DeviceHandle* smScrChannels, ++ mscclpp::DeviceHandle* smScrChannels, + mscclpp::DeviceHandle* smOutChannels, size_t channelInOffset, size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems, cudaStream_t stream) { @@ -173,9 +173,9 @@ index 1b85136..a45345a 100644 + envValue = std::getenv("MSCCLPP_READ_ALLRED"); + + if (envValue != nullptr) { -+ if (atoi(envValue) == 1) { -+ readAllred = 1; -+ } ++ if (atoi(envValue) == 1) { ++ readAllred = 1; ++ } + } if (sizeof(T) * nelems < worldSize * sizeof(int)) { @@ -184,7 +184,7 @@ index 1b85136..a45345a 100644 - allreduceAllToAll<<>>(buff, scratch, resultBuff, smChannels, channelInOffset, - channelScratchOffset, rank, nRanksPerNode, worldSize, - nelems, flag++); -+ allreduceAllToAll<<>>(buff, scratch, resultBuff, smChannels, ++ allreduceAllToAll<<>>(buff, scratch, resultBuff, smChannels, + channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++); } else if (sizeof(T) * nelems <= (1 << 20)) { int nBlocks = 28; @@ -195,13 +195,13 @@ index 1b85136..a45345a 100644 int nThreadsPerBlock = 512; - allreduce8<<>>(buff, scratch, resultBuff, smChannels, smOutChannels, + if (!readAllred) { -+ allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, smOutChannels, ++ allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems); + } else { + allreduce8Read<<>>(buff, resultBuff, smChannels, smOutChannels, + channelOutOffset, rank, nRanksPerNode, -+ worldSize, nelems); ++ worldSize, nelems); + } } @@ -235,7 +235,7 @@ index ec130b0..571508d 100644 int rank = comm->comm->bootstrap()->getRank(); channelKey sendKey{(void*)sendBasePtr, sendBytes}; channelKey recvKey{(void*)recvBasePtr, recvBytes}; -+ ++ mscclpp::DeviceHandle* smChannels = nullptr; mscclpp::DeviceHandle* smOutChannels = nullptr; + mscclpp::DeviceHandle* smScrChannels = nullptr; @@ -290,30 +290,30 @@ index ec130b0..571508d 100644 - CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels, - offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, - comm->comm->bootstrap()->getNranks(), count, stream)); -+ CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels, ++ CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels, + smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); break; case ncclFloat32: CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels, - smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), -+ smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, -+ comm->comm->bootstrap()->getRank(), ++ smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, ++ comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); break; case ncclBfloat16: CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff, - smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, - comm->comm->bootstrap()->getNranks(), count, stream)); -+ smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, -+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); ++ smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, ++ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); break; case ncclInt32: case ncclUint32: - CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels, - offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE, - comm->comm->bootstrap()->getNranks(), count, stream)); -+ CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels, -+ smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), ++ CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels, ++ smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), + NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); break; default: