Fix MSCCL scratch allocation (#1010)

Este commit está contenido en:
Ziyue Yang
2023-12-09 07:47:10 +08:00
cometido por GitHub
padre bb144dcd50
commit c002f20029
+23 -5
Ver fichero
@@ -419,12 +419,30 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
mscclWork work;
work.syncFlags = status.syncFlags;
size_t sizeNeeded = computeSizeNeeded(status.nBytes, hostAlgo->nScratchChunks, hostAlgo->nChunksPerLoop);
if (status.scratchBuffers.find(sizeNeeded) == status.scratchBuffers.end()) {
void *scratchBuffer = nullptr;
NCCLCHECK(ncclCudaMalloc((char**)&scratchBuffer, sizeNeeded, true));
status.scratchBuffers[sizeNeeded] = scratchBuffer;
if (sizeNeeded > 0) {
auto itr = status.scratchBuffers.lower_bound(sizeNeeded);
if (itr == status.scratchBuffers.end()) {
void *scratchBuffer = nullptr;
size_t sizeRounded = 1;
if (status.scratchBuffers.size() > 0) {
sizeRounded = status.scratchBuffers.rbegin()->first;
}
while (sizeRounded < sizeNeeded) {
if (sizeRounded >= sizeRounded * 2) {
WARN("MSCCL: Size of allocation for scratch buffer (%lu * 2) will wrap around", sizeRounded);
return ncclInvalidUsage;
}
sizeRounded *= 2;
}
NCCLCHECK(ncclCudaMalloc((char**)&scratchBuffer, sizeRounded, true));
work.scratchBuffer = status.scratchBuffers[sizeRounded] = scratchBuffer;
INFO(NCCL_INIT, "MSCCL: Allocated scratch buffer of size %lu on request (%lu)", sizeRounded, sizeNeeded);
} else {
work.scratchBuffer = itr->second;
}
} else {
work.scratchBuffer = nullptr;
}
work.scratchBuffer = status.scratchBuffers[sizeNeeded];
work.sendBuff = sendBuff;
work.recvBuff = recvBuff;
work.sizePerMscclChunk = count * hostAlgo->sizeMultiplier / hostAlgo->nChunksPerLoop; // count is sum of all ranks in MSCCL kernel