Fix MSCCL scratch allocation (#1010)
Este commit está contenido en:
@@ -419,12 +419,30 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
   mscclWork work;
   work.syncFlags = status.syncFlags;
   size_t sizeNeeded = computeSizeNeeded(status.nBytes, hostAlgo->nScratchChunks, hostAlgo->nChunksPerLoop);
-  if (status.scratchBuffers.find(sizeNeeded) == status.scratchBuffers.end()) {
-    void *scratchBuffer = nullptr;
-    NCCLCHECK(ncclCudaMalloc((char**)&scratchBuffer, sizeNeeded, true));
-    status.scratchBuffers[sizeNeeded] = scratchBuffer;
+  if (sizeNeeded > 0) {
+    auto itr = status.scratchBuffers.lower_bound(sizeNeeded);
+    if (itr == status.scratchBuffers.end()) {
+      void *scratchBuffer = nullptr;
+      size_t sizeRounded = 1;
+      if (status.scratchBuffers.size() > 0) {
+        sizeRounded = status.scratchBuffers.rbegin()->first;
+      }
+      while (sizeRounded < sizeNeeded) {
+        if (sizeRounded >= sizeRounded * 2) {
+          WARN("MSCCL: Size of allocation for scratch buffer (%lu * 2) will wrap around", sizeRounded);
+          return ncclInvalidUsage;
+        }
+        sizeRounded *= 2;
+      }
+      NCCLCHECK(ncclCudaMalloc((char**)&scratchBuffer, sizeRounded, true));
+      work.scratchBuffer = status.scratchBuffers[sizeRounded] = scratchBuffer;
+      INFO(NCCL_INIT, "MSCCL: Allocated scratch buffer of size %lu on request (%lu)", sizeRounded, sizeNeeded);
+    } else {
+      work.scratchBuffer = itr->second;
+    }
+  } else {
+    work.scratchBuffer = nullptr;
   }
-  work.scratchBuffer = status.scratchBuffers[sizeNeeded];
   work.sendBuff = sendBuff;
   work.recvBuff = recvBuff;
   work.sizePerMscclChunk = count * hostAlgo->sizeMultiplier / hostAlgo->nChunksPerLoop; // count is sum of all ranks in MSCCL kernel
Referencia en una nueva incidencia
Block a user