diff --git a/src/misc/msccl/msccl_setup.cc b/src/misc/msccl/msccl_setup.cc index 3e49aeac2e..c00db63d6d 100644 --- a/src/misc/msccl/msccl_setup.cc +++ b/src/misc/msccl/msccl_setup.cc @@ -419,12 +419,30 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count mscclWork work; work.syncFlags = status.syncFlags; size_t sizeNeeded = computeSizeNeeded(status.nBytes, hostAlgo->nScratchChunks, hostAlgo->nChunksPerLoop); - if (status.scratchBuffers.find(sizeNeeded) == status.scratchBuffers.end()) { - void *scratchBuffer = nullptr; - NCCLCHECK(ncclCudaMalloc((char**)&scratchBuffer, sizeNeeded, true)); - status.scratchBuffers[sizeNeeded] = scratchBuffer; + if (sizeNeeded > 0) { + auto itr = status.scratchBuffers.lower_bound(sizeNeeded); + if (itr == status.scratchBuffers.end()) { + void *scratchBuffer = nullptr; + size_t sizeRounded = 1; + if (status.scratchBuffers.size() > 0) { + sizeRounded = status.scratchBuffers.rbegin()->first; + } + while (sizeRounded < sizeNeeded) { + if (sizeRounded >= sizeRounded * 2) { + WARN("MSCCL: Size of allocation for scratch buffer (%lu * 2) will wrap around", sizeRounded); + return ncclInvalidUsage; + } + sizeRounded *= 2; + } + NCCLCHECK(ncclCudaMalloc((char**)&scratchBuffer, sizeRounded, true)); + work.scratchBuffer = status.scratchBuffers[sizeRounded] = scratchBuffer; + INFO(NCCL_INIT, "MSCCL: Allocated scratch buffer of size %lu on request (%lu)", sizeRounded, sizeNeeded); + } else { + work.scratchBuffer = itr->second; + } + } else { + work.scratchBuffer = nullptr; } - work.scratchBuffer = status.scratchBuffers[sizeNeeded]; work.sendBuff = sendBuff; work.recvBuff = recvBuff; work.sizePerMscclChunk = count * hostAlgo->sizeMultiplier / hostAlgo->nChunksPerLoop; // count is sum of all ranks in MSCCL kernel