diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index 4134241..76674ba 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -495,24 +495,384 @@ __global__ void __launch_bounds__(512, 1)
   }
 }
 
+/// Read-based all-reduce: every rank reduces one 1/worldSize slice of the buffer and pushes the result to its peers.
+///
+/// For each int4 of its slice the rank adds its own value from `buff` to the values returned by `read()` on each
+/// of the `nRanksPerNode - 1` channels in `smChannels`, stores the sum in `resultBuff`, and writes the same sum
+/// through each channel in `smOutChannels` at the element's index plus `channelOutDataOffset / sizeof(int4)`.
+/// The slice is split across blocks in multiples of 512 int4 and each block works in 512KB chunks, doing a
+/// signal/wait handshake with every peer before each chunk and once more on the output channels at the end.
+///
+/// Assumes `nelems * sizeof(T)` is divisible by `16 * worldSize` and `nRanksPerNode <= NRANKS_PER_NODE`.
+/// Block `b` uses the `nRanksPerNode - 1` handles that start at index `(nRanksPerNode - 1) * b` of each array.
+///
+/// NOTE(review): the remote reads use the same int4 index as the local `buff` access and apply no input offset,
+/// whereas the writes add `channelOutDataOffset`; confirm the peers' input registrations start exactly at `buff`.
+template <typename T>
+__global__ void __launch_bounds__(512, 1)
+    allreduce8Read(T* buff, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
+                   mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
+                   int rank, int nRanksPerNode, int worldSize, size_t nelems) {
+  const int nPeer = nRanksPerNode - 1;
+  const size_t chanOffset = nPeer * blockIdx.x;
+  // assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
+  const size_t nInt4 = nelems * sizeof(T) / sizeof(int4);
+  const size_t nInt4PerRank = nInt4 / worldSize;
+  auto smChans = smChannels + chanOffset;
+  auto smOutChans = smOutChannels + chanOffset;
+
+  int4* buff4 = reinterpret_cast<int4*>(buff);
+  int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);
+
+  // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
+  constexpr size_t unitNInt4 = 512;
+  const size_t maxNInt4PerBlock =
+      (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
+  size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
+  size_t nInt4OfThisBlock = maxNInt4PerBlock;
+  size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
+  constexpr size_t nInt4PerChunk = 1024 * 512 / sizeof(int4);  // 512KB
+  if (blockIdx.x >= nNeededBlocks) {
+    nInt4OfThisBlock = 0;
+  } else if (blockIdx.x == nNeededBlocks - 1) {
+    nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1);
+  }
+
+  const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk;
+  const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk;
+
+  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
+  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> outChannels[NRANKS_PER_NODE - 1];
+  const int lid = threadIdx.x % WARP_SIZE;
+  if (lid < nPeer) {
+    channels[lid] = smChans[lid];
+    outChannels[lid] = smOutChans[lid];
+  }
+  __syncwarp();
+
+  // we can use double buffering to hide synchronization overhead
+  for (size_t itr = 0; itr < nItrs; itr++) {
+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+      channels[threadIdx.x].signal();
+      channels[threadIdx.x].wait();
+    }
+    __syncthreads();
+
+    for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+      int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);
+        data = add_vectors<T>(val, data);
+      }
+      resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
+
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
+                                   data);
+      }
+    }
+    __syncthreads();
+
+    offsetOfThisBlock += nInt4PerChunk;
+  }
+
+  if (restNInt4 > 0) {
+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+      channels[threadIdx.x].signal();
+      channels[threadIdx.x].wait();
+
+    }
+    __syncthreads();
+
+    for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+      int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);
+        data = add_vectors<T>(val, data);
+      }
+      resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
+      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
+                                   data);
+      }
+    }
+    __syncthreads();
+  }
+
+  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+    outChannels[threadIdx.x].signal();
+    outChannels[threadIdx.x].wait();
+  }
+  __syncthreads();
+
+}
+
+/// Two-level all-reduce for ranks arranged in groups of NRANKS1_PER_NODE.
+///
+/// Per chunk, each rank (1) adds its slice of `buff` to the values read over `smChannels` from the peers with the
+/// same `rank / NRANKS1_PER_NODE` and stages the partial sum in `scratch`, (2) adds the staged partial sums read
+/// over `smScrChannels` from the peers with the same `rank % NRANKS1_PER_NODE`, and (3) stores the total in
+/// `resultBuff` and writes it over `smOutChannels` to the peers of step (1). The slice owned by a rank is
+/// `1 / NRANKS1_PER_NODE` of the buffer, selected by `rank % NRANKS1_PER_NODE`.
+///
+/// `channelScratchOffset` is a byte offset into `scratch`, also applied to the remote scratch reads;
+/// `channelOutDataOffset` is a byte offset applied to the remote output writes. Block `b` stages its chunk at int4
+/// index `nInt4PerChunk * b`, so `scratch` must hold one 1MB chunk per working block past that offset.
+///
+/// A block-wide barrier precedes each signal() that publishes data (the staged partial sums and the remote output
+/// writes) so that the stores of every thread in the block are ordered before the peer is released.
+///
+/// NOTE(review): peer indexing relies on the compile-time NPEERS and on shared handle arrays of
+/// NRANKS_PER_NODE - 1 entries; with more than NRANKS_PER_NODE ranks the computed peer index would exceed them.
+/// NOTE(review): the remote reads of `buff` apply no input offset; confirm the peers' registrations start at `buff`.
+template <typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduce10(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
+                mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
+                mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
+                size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
+  const int nPeer = nRanksPerNode - 1;
+  const size_t chanOffset = nPeer * blockIdx.x;
+  // assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
+  const size_t nInt4 = nelems * sizeof(T) / sizeof(int4);
+  const size_t nInt4PerRank = nInt4 / NRANKS1_PER_NODE;
+
+  auto smChans = smChannels + chanOffset;
+  auto smOutChans = smOutChannels + chanOffset;
+  auto smScrChans = smScrChannels + chanOffset;
+
+  int4* buff4 = reinterpret_cast<int4*>(buff);
+  int4* scratch4 = reinterpret_cast<int4*>((char*)scratch + channelScratchOffset);
+  int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);
+
+  // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
+  constexpr size_t unitNInt4 = 512;
+  const size_t maxNInt4PerBlock =
+      (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
+  size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
+  size_t nInt4OfThisBlock = maxNInt4PerBlock;
+  size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
+
+  constexpr size_t nInt4PerChunk = 1024 * 1024 / sizeof(int4);  // 1MB
+  int num_nodes = worldSize/NRANKS1_PER_NODE;
+
+  if (blockIdx.x >= nNeededBlocks) {
+    nInt4OfThisBlock = 0;
+  } else if (blockIdx.x == nNeededBlocks - 1) {
+    nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1);
+  }
+
+  const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk;
+  const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk;
+
+  const size_t blockOffset = nInt4PerChunk * blockIdx.x;
+
+  int localRank = rank % NRANKS1_PER_NODE;
+
+  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
+  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> outChannels[NRANKS_PER_NODE - 1];
+  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> scrChannels[NRANKS_PER_NODE - 1];
+
+  const int lid = threadIdx.x % WARP_SIZE;
+  if (lid < nPeer) {
+    channels[lid] = smChans[lid];
+    outChannels[lid] = smOutChans[lid];
+    scrChannels[lid] = smScrChans[lid];
+  }
+  __syncwarp();
+
+  // we can use double buffering to hide synchronization overhead
+  for (size_t itr = 0; itr < nItrs; itr++) {
+    if (threadIdx.x < (NRANKS1_PER_NODE-1)) {
+      int myNode = rank/NRANKS1_PER_NODE;
+      int remote = (threadIdx.x + 1 + rank);
+      int remoteNode = remote/NRANKS1_PER_NODE;
+
+      if (remoteNode > myNode) {
+        remote = remote - NRANKS1_PER_NODE;
+      }
+      int peerIdx = remote < rank ? remote : remote - 1;
+      outChannels[peerIdx].signal();
+      outChannels[peerIdx].wait();
+    }
+    __syncthreads();
+
+    int myNode = rank/NRANKS1_PER_NODE;
+
+    //Reduce within an OAM
+    for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+      int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock];
+      for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode +
+           NRANKS1_PER_NODE - 1); peerIdx++) {
+        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * localRank + offsetOfThisBlock + idx);
+        data = add_vectors<T>(val, data);
+      }
+      scratch4[idx + blockOffset] = data;
+    }
+    // every thread must have staged its partial sums before a peer is signalled to read them
+    __syncthreads();
+
+    if (threadIdx.x < static_cast<uint32_t>(num_nodes-1)) {
+      int remote = (NRANKS1_PER_NODE * (threadIdx.x + 1) + rank) % worldSize;
+      int peerIdx = remote < rank ? remote : remote - 1;
+      scrChannels[peerIdx].signal();
+      scrChannels[peerIdx].wait();
+    }
+    __syncthreads();
+
+    //Reduce across OAMs
+
+    for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+      int4 data = scratch4[idx + blockOffset];
+
+      for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
+        const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+        int myLocal = rank % NRANKS1_PER_NODE;
+        int remoteLocal = remoteRank % NRANKS1_PER_NODE;
+
+        if (myLocal == remoteLocal) {
+          int4 val = scrChannels[peerIdx].read<int4>(blockOffset + idx +
+                                                     channelScratchOffset/sizeof(int4));
+          data = add_vectors<T>(val, data);
+        }
+      }
+
+      resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data;
+
+      for (int peerIdx = NRANKS1_PER_NODE*myNode; peerIdx < (NRANKS1_PER_NODE*myNode + NRANKS1_PER_NODE - 1); peerIdx++) {
+        outChannels[peerIdx].write(nInt4PerRank * localRank + idx + offsetOfThisBlock +
+                                   channelOutDataOffset / sizeof(int4), data);
+      }
+    }
+    // every thread must have issued its remote writes before completion is signalled
+    __syncthreads();
+
+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+      outChannels[threadIdx.x].signal();
+      outChannels[threadIdx.x].wait();
+    }
+    __syncthreads();
+
+    offsetOfThisBlock += nInt4PerChunk;
+  }
+
+  if (restNInt4 > 0) {
+    if (threadIdx.x < (NRANKS1_PER_NODE-1)) {
+      int myNode = rank/NRANKS1_PER_NODE;
+      int remote = (threadIdx.x + 1 + rank);
+      int remoteNode = remote/NRANKS1_PER_NODE;
+
+      if (remoteNode > myNode) {
+        remote = remote - NRANKS1_PER_NODE;
+      }
+      int peerIdx = remote < rank ? remote : remote - 1;
+
+      outChannels[peerIdx].signal();
+      outChannels[peerIdx].wait();
+    }
+    __syncthreads();
+
+    for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+      int4 data = buff4[nInt4PerRank * localRank + idx + offsetOfThisBlock];
+      for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
+        const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+
+        int myNode = rank/NRANKS1_PER_NODE;
+        int remoteNode = remoteRank/NRANKS1_PER_NODE;
+
+        if (myNode == remoteNode) {
+          int4 val = channels[peerIdx].read<int4>(nInt4PerRank * localRank + offsetOfThisBlock + idx);
+          data = add_vectors<T>(val, data);
+        }
+      }
+      scratch4[idx + blockOffset] = data;
+    }
+    // every thread must have staged its partial sums before a peer is signalled to read them
+    __syncthreads();
+
+    if (threadIdx.x < static_cast<uint32_t>(num_nodes-1)) {
+      int remote = (NRANKS1_PER_NODE * (threadIdx.x + 1) + rank) % worldSize;
+      int peerIdx = remote < rank ? remote : remote - 1;
+      scrChannels[peerIdx].signal();
+      scrChannels[peerIdx].wait();
+    }
+    __syncthreads();
+
+    for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
+      int4 data = scratch4[idx + blockOffset];
+
+      for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
+        const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+        int myLocal = rank % NRANKS1_PER_NODE;
+        int remoteLocal = remoteRank % NRANKS1_PER_NODE;
+
+        if (myLocal == remoteLocal) {
+          int4 val = scrChannels[peerIdx].read<int4>(blockOffset + idx +
+                                                     channelScratchOffset/sizeof(int4));
+          data = add_vectors<T>(val, data);
+        }
+      }
+
+      resultBuff4[nInt4PerRank * localRank + idx + offsetOfThisBlock] = data;
+      for (int peerIdx = 0; peerIdx < NPEERS; peerIdx++) {
+        const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
+        int myNode = rank/NRANKS1_PER_NODE;
+        int remoteNode = remoteRank/NRANKS1_PER_NODE;
+
+        if (myNode == remoteNode) {
+          outChannels[peerIdx].write(nInt4PerRank * localRank + idx + offsetOfThisBlock +
+                                     channelOutDataOffset / sizeof(int4), data);
+        }
+      }
+    }
+    // every thread must have issued its remote writes before completion is signalled
+    __syncthreads();
+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
+      outChannels[threadIdx.x].signal();
+      outChannels[threadIdx.x].wait();
+    }
+    __syncthreads();
+
+  }
+
+}
+
 template <typename T>
 cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
-                      mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
+                      mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
+                      mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
                       size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
                       size_t nelems, cudaStream_t stream) {
   static uint32_t flag = 1;
+  int readAllred = 0, hieAllred = 0;
+  char* envValue = nullptr;
+  char* envValue1 = nullptr;
+  nRanksPerNode = (worldSize < nRanksPerNode) ? worldSize : nRanksPerNode;
+
+  envValue = std::getenv("MSCCLPP_READ_ALLRED");
+  envValue1 = std::getenv("MSCCLPP_HIERARCHICAL_ALLRED");
+
+  if (envValue != nullptr) {
+    if (atoi(envValue) == 1) {
+      readAllred = 1;
+    }
+  }
+  if (envValue1 != nullptr) {
+    if (atoi(envValue1) == 1) {
+      hieAllred = 1;
+    }
+  }
 
   if (sizeof(T) * nelems < worldSize * sizeof(int)) {
-    int nBlocks = 7;
+    int nBlocks = nRanksPerNode - 1;
     int nThreadsPerBlock = 32;
-    allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
-                                                                channelScratchOffset, rank, nRanksPerNode, worldSize,
-                                                                nelems, flag++);
+    allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels,
+        channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++);
   } else if (sizeof(T) * nelems <= (1 << 20)) {
-    int nBlocks = 28;
+    int nBlocks = 4*(nRanksPerNode - 1);
     int nThreadsPerBlock = 1024;
     if (nelems >= 8192) {
-      nBlocks = 56;
+      nBlocks = 8*(nRanksPerNode - 1);
       nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
     }
 #if defined(ENABLE_NPKIT)
@@ -526,11 +886,23 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
                                                       flag++);
 #endif
   } else {
-    int nBlocks = 35;
+    int nBlocks = 5*(nRanksPerNode - 1);
     int nThreadsPerBlock = 512;
-    allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
-                                                         channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
-                                                         worldSize, nelems);
+    if (hieAllred && worldSize >= 8) {
+      nBlocks = 20;
+      allreduce10<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smScrChannels,
+          smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
+          worldSize, nelems);
+    } else {
+      if (!readAllred) {
+        allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
+            smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
+            worldSize, nelems);
+      } else {
+        allreduce8Read<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, resultBuff, smChannels, smOutChannels,
+            channelOutOffset, rank, nRanksPerNode, worldSize, nelems);
+      }
+    }
   }
 
   return cudaGetLastError();
diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp
index 015e0a2..ca2c272 100644
--- a/apps/nccl/src/common.hpp
+++ b/apps/nccl/src/common.hpp
@@ -13,6 +13,7 @@
 #define WARP_SIZE 32
 #endif
 
+constexpr int NRANKS1_PER_NODE = 4;  // group size used by the two-level allreduce10 kernel
 constexpr int NRANKS_PER_NODE = 8;
 constexpr int NPEERS = 7;
 
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index f91d15e..022d398 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -70,7 +70,9 @@ struct hash<channelKey> {
 
 struct ChannelInfo {
   std::vector<mscclpp::SmChannel> smChannels;
+  std::vector<mscclpp::SmChannel> smChannels1;  // second channel set; only filled for channelInInfos entries
   std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
+  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles1;  // handles of smChannels1
 };
 
 struct ncclComm {
@@ -213,6 +215,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
   channelKey recvKey{(void*)recvBasePtr, recvBytes};
   mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
   mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;
+  mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels = nullptr;
 
   // Creating the channels
   if (count * ncclTypeSize(datatype) <= (1 << 20)) {
@@ -220,19 +223,24 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
     if (sendIt == comm->channelScratchInfos.end()) {
       std::vector<mscclpp::SmChannel> channels =
           setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
-      ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+      ChannelInfo channelInfo{channels, {}, setupSmChannelDeviceHandles(channels), nullptr};
       sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
     }
 
     smChannels = sendIt->second.smChannelDeviceHandles.get();
   } else {
     std::vector<mscclpp::RegisteredMemory> remoteMemories;
-
+    std::vector<mscclpp::RegisteredMemory> remoteMemories1;
     auto sendIt = comm->channelInInfos.find(sendKey);
     if (sendIt == comm->channelInInfos.end()) {
       std::vector<mscclpp::SmChannel> channels =
           setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
-      ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+      remoteMemories1 =
+          setupRemoteMemories(comm->comm, rank, (void*)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc);
+      std::vector<mscclpp::SmChannel> channels1 =
+          setupSmChannels(comm, remoteMemories1, const_cast<void*>((void*)sendBasePtr));
+
+      ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)};
       sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
     }
 
@@ -242,35 +250,36 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
           setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
       std::vector<mscclpp::SmChannel> outChannels =
           setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
-      ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
+      ChannelInfo channelInfo{outChannels, {}, setupSmChannelDeviceHandles(outChannels), nullptr};
       recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
     }
 
-    smChannels = sendIt->second.smChannelDeviceHandles.get();
+    smChannels = sendIt->second.smChannelDeviceHandles1.get();
     smOutChannels = recvIt->second.smChannelDeviceHandles.get();
+    smScrChannels = sendIt->second.smChannelDeviceHandles.get();
   }
 
   switch (datatype) {
     case ncclFloat16:
-      CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
-                          offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), count, stream));
+      CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels,
+                          smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
       break;
     case ncclFloat32:
       CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
-                          smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
+                          smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch,
+                          comm->comm->bootstrap()->getRank(),
                           NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
       break;
     case ncclBfloat16:
       CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
-                          smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), count, stream));
+                          smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank,
+                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
       break;
     case ncclInt32:
     case ncclUint32:
-      CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
-                          offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), count, stream));
+      CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels,
+                          smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
+                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
       break;
     default:
       WARN("datatype is invalid");
@@ -315,7 +324,7 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
     std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
     std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
                    [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
-    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+    ChannelInfo channelInfo{channels, {}, setupSmChannelDeviceHandles(channels), nullptr};
     it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
   }
 
@@ -597,7 +606,7 @@ NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff
     std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
     std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
                    [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
-    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+    ChannelInfo channelInfo{channels, {}, setupSmChannelDeviceHandles(channels), nullptr};
     it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
   }
 
@@ -805,16 +814,6 @@ NCCL_API ncclResult_t ncclGroupEnd() {
   return ncclSuccess;
 }
 
-NCCL_API ncclResult_t ncclCommRegister(const ncclComm_t, void*, size_t, void**) {
-  // TODO: Implementation
-  return ncclSuccess;
-}
-
-NCCL_API ncclResult_t ncclCommDeregister(const ncclComm_t, void*) {
-  // TODO: Implementation
-  return ncclSuccess;
-}
-
 ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
   if (ptr == nullptr || size == 0) {
     WARN("ptr is nullptr or size is 0");