From 293f0fb752f524ee8e757b1f71062ab6d12640ca Mon Sep 17 00:00:00 2001 From: "Wen-Heng (Jack) Chung" Date: Tue, 5 Dec 2023 13:15:28 -0600 Subject: [PATCH] Use a map to host scratch buffers (#1004) * Use a map to host scratch buffers * Address review feedbacks. Deliberately keep mscclSetupScratch function. --- src/include/msccl/msccl_struct.h | 3 +-- src/misc/msccl/msccl_lifecycle.cc | 9 ++++----- src/misc/msccl/msccl_setup.cc | 18 +++++++++++------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/include/msccl/msccl_struct.h b/src/include/msccl/msccl_struct.h index 5e5560736b..b96f4d9b69 100644 --- a/src/include/msccl/msccl_struct.h +++ b/src/include/msccl/msccl_struct.h @@ -205,8 +205,7 @@ struct mscclStatus { std::map hostAlgos; std::map devAlgos; struct mscclFlag* syncFlags; - void *scratchBuffer; - uint64_t scratchBufferSize; + std::map scratchBuffers; size_t nBytes; int stepSize; int chunkSteps; diff --git a/src/misc/msccl/msccl_lifecycle.cc b/src/misc/msccl/msccl_lifecycle.cc index 8105dafc40..932fa2efed 100644 --- a/src/misc/msccl/msccl_lifecycle.cc +++ b/src/misc/msccl/msccl_lifecycle.cc @@ -236,8 +236,6 @@ ncclResult_t mscclInit(ncclComm_t comm) { return ncclSuccess; } - status.scratchBuffer = nullptr; - status.scratchBufferSize = 0; status.workIndex = 1; NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS)); status.lastStream = nullptr; @@ -542,13 +540,14 @@ ncclResult_t mscclTeardown() { for (auto &p : status.devAlgos) { CUDACHECK(hipFree(p.second)); } - CUDACHECK(hipFree(status.scratchBuffer)); CUDACHECK(hipFree(status.syncFlags)); status.hostAlgos.clear(); status.devAlgos.clear(); status.freeAlgoHandles.clear(); - status.scratchBuffer = nullptr; - status.scratchBufferSize = 0; + for (auto &p : status.scratchBuffers) { + CUDACHECK(hipFree(p.second)); + } + status.scratchBuffers.clear(); status.connectedAlgos.clear(); if (status.mscclSchedulerPtr) { NCCLCHECK(status.mscclSchedulerPtr->teardown()); diff --git a/src/misc/msccl/msccl_setup.cc b/src/misc/msccl/msccl_setup.cc index 0f25a84e83..d999cb1804 100644 --- a/src/misc/msccl/msccl_setup.cc +++ b/src/misc/msccl/msccl_setup.cc @@ -22,6 +22,10 @@ RCCL_PARAM(MscclEnableDoneEvent, "MSCCL_ENABLE_DONE_EVENT", 1); RCCL_PARAM(MscclWorkFifoDepth, "MSCCL_WORK_FIFO_DEPTH", 64<<10); +static inline size_t computeSizeNeeded(size_t nBytes, int nScratchChunks, int nChunksPerLoop) { + return (nBytes * (size_t)nScratchChunks) / (size_t)nChunksPerLoop; +} + ncclResult_t mscclGetCaptureStatus(hipStream_t stream) { mscclStatus& status = mscclGetStatus(); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); @@ -70,12 +74,6 @@ ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) { mscclStatus& status = mscclGetStatus(); - size_t sizeNeeded = (status.nBytes * (size_t)(hostAlgo->nScratchChunks)) / (size_t)(hostAlgo->nChunksPerLoop); - if (sizeNeeded > status.scratchBufferSize){ - NCCLCHECK(ncclCudaFree(status.scratchBuffer)); - NCCLCHECK(ncclCudaMalloc((char**)&status.scratchBuffer, sizeNeeded, true)); - status.scratchBufferSize = sizeNeeded; - } return ncclSuccess; } @@ -422,7 +420,13 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count mscclWork work; work.syncFlags = status.syncFlags; - work.scratchBuffer = status.scratchBuffer; + size_t sizeNeeded = computeSizeNeeded(status.nBytes, hostAlgo->nScratchChunks, hostAlgo->nChunksPerLoop); + if (status.scratchBuffers.find(sizeNeeded) == status.scratchBuffers.end()) { + void *scratchBuffer = nullptr; + NCCLCHECK(ncclCudaMalloc((char**)&scratchBuffer, sizeNeeded, true)); + status.scratchBuffers[sizeNeeded] = scratchBuffer; + } + work.scratchBuffer = status.scratchBuffers[sizeNeeded]; work.sendBuff = sendBuff; work.recvBuff = recvBuff; work.sizePerMscclChunk = count * hostAlgo->sizeMultiplier / hostAlgo->nChunksPerLoop; // count is sum of all ranks in MSCCL kernel