ext-src: Fix compiler warnings for MSCCLPP integration (#1368)
This commit is contained in:
committed by
GitHub
parent
364a6c2130
commit
6160603d4c
@@ -162,7 +162,7 @@ index 1b85136..a45345a 100644
|
||||
template <typename T>
|
||||
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
|
||||
- mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
|
||||
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
|
||||
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
|
||||
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
|
||||
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
|
||||
size_t nelems, cudaStream_t stream) {
|
||||
@@ -173,9 +173,9 @@ index 1b85136..a45345a 100644
|
||||
+ envValue = std::getenv("MSCCLPP_READ_ALLRED");
|
||||
+
|
||||
+ if (envValue != nullptr) {
|
||||
+ if (atoi(envValue) == 1) {
|
||||
+ readAllred = 1;
|
||||
+ }
|
||||
+ if (atoi(envValue) == 1) {
|
||||
+ readAllred = 1;
|
||||
+ }
|
||||
+ }
|
||||
|
||||
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
|
||||
@@ -184,7 +184,7 @@ index 1b85136..a45345a 100644
|
||||
- allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
|
||||
- channelScratchOffset, rank, nRanksPerNode, worldSize,
|
||||
- nelems, flag++);
|
||||
+ allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels,
|
||||
+ allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels,
|
||||
+ channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++);
|
||||
} else if (sizeof(T) * nelems <= (1 << 20)) {
|
||||
int nBlocks = 28;
|
||||
@@ -195,13 +195,13 @@ index 1b85136..a45345a 100644
|
||||
int nThreadsPerBlock = 512;
|
||||
- allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
|
||||
+ if (!readAllred) {
|
||||
+ allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels, smOutChannels,
|
||||
+ allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels, smOutChannels,
|
||||
channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
|
||||
worldSize, nelems);
|
||||
+ } else {
|
||||
+ allreduce8Read<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, resultBuff, smChannels, smOutChannels,
|
||||
+ channelOutOffset, rank, nRanksPerNode,
|
||||
+ worldSize, nelems);
|
||||
+ worldSize, nelems);
|
||||
+ }
|
||||
}
|
||||
|
||||
@@ -235,7 +235,7 @@ index ec130b0..571508d 100644
|
||||
int rank = comm->comm->bootstrap()->getRank();
|
||||
channelKey sendKey{(void*)sendBasePtr, sendBytes};
|
||||
channelKey recvKey{(void*)recvBasePtr, recvBytes};
|
||||
+
|
||||
+
|
||||
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
|
||||
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;
|
||||
+ mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels = nullptr;
|
||||
@@ -290,30 +290,30 @@ index ec130b0..571508d 100644
|
||||
- CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
|
||||
- offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
|
||||
- comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
+ CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels,
|
||||
+ CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels,
|
||||
+ smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
break;
|
||||
case ncclFloat32:
|
||||
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
|
||||
- smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
|
||||
+ smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch,
|
||||
+ comm->comm->bootstrap()->getRank(),
|
||||
+ smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch,
|
||||
+ comm->comm->bootstrap()->getRank(),
|
||||
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
break;
|
||||
case ncclBfloat16:
|
||||
CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
|
||||
- smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
|
||||
- comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
+ smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank,
|
||||
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
+ smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank,
|
||||
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
break;
|
||||
case ncclInt32:
|
||||
case ncclUint32:
|
||||
- CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
|
||||
- offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
|
||||
- comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
+ CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels,
|
||||
+ smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
|
||||
+ CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels,
|
||||
+ smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
|
||||
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
break;
|
||||
default:
|
||||
|
||||
Reference in new issue
Block a user