ext-src: Fix compiler warnings for MSCCLPP integration (#1368)

This commit is contained in:
Nusrat Islam
2024-10-10 08:20:02 -05:00
committed by GitHub
parent 364a6c2130
commit 6160603d4c
+15 -15
View file
@@ -162,7 +162,7 @@ index 1b85136..a45345a 100644
template <typename T>
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
- mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, cudaStream_t stream) {
@@ -173,9 +173,9 @@ index 1b85136..a45345a 100644
+ envValue = std::getenv("MSCCLPP_READ_ALLRED");
+
+ if (envValue != nullptr) {
+ if (atoi(envValue) == 1) {
+ readAllred = 1;
+ }
+ if (atoi(envValue) == 1) {
+ readAllred = 1;
+ }
+ }
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
@@ -184,7 +184,7 @@ index 1b85136..a45345a 100644
- allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
- channelScratchOffset, rank, nRanksPerNode, worldSize,
- nelems, flag++);
+ allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels,
+ allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels,
+ channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++);
} else if (sizeof(T) * nelems <= (1 << 20)) {
int nBlocks = 28;
@@ -195,13 +195,13 @@ index 1b85136..a45345a 100644
int nThreadsPerBlock = 512;
- allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
+ if (!readAllred) {
+ allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels, smOutChannels,
+ allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels, smOutChannels,
channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
worldSize, nelems);
+ } else {
+ allreduce8Read<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, resultBuff, smChannels, smOutChannels,
+ channelOutOffset, rank, nRanksPerNode,
+ worldSize, nelems);
+ worldSize, nelems);
+ }
}
@@ -235,7 +235,7 @@ index ec130b0..571508d 100644
int rank = comm->comm->bootstrap()->getRank();
channelKey sendKey{(void*)sendBasePtr, sendBytes};
channelKey recvKey{(void*)recvBasePtr, recvBytes};
+
+
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels = nullptr;
@@ -290,30 +290,30 @@ index ec130b0..571508d 100644
- CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
- offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
- comm->comm->bootstrap()->getNranks(), count, stream));
+ CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels,
+ CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels,
+ smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclFloat32:
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
- smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
+ smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch,
+ comm->comm->bootstrap()->getRank(),
+ smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch,
+ comm->comm->bootstrap()->getRank(),
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclBfloat16:
CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
- smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
- comm->comm->bootstrap()->getNranks(), count, stream));
+ smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank,
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+ smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank,
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclInt32:
case ncclUint32:
- CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
- offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
- comm->comm->bootstrap()->getNranks(), count, stream));
+ CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels,
+ smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
+ CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels,
+ smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
default: