Use a map to host scratch buffers (#1004)

* Use a map to host scratch buffers

* Address review feedback. Deliberately keep mscclSetupScratch function.
此提交包含在:
Wen-Heng (Jack) Chung
2023-12-05 13:15:28 -06:00
提交者 GitHub
父節點 bc44e3faa7
當前提交 293f0fb752
共有 3 個檔案被更改,包括 16 行新增和 14 行刪除
+1 -2
查看文件
@@ -205,8 +205,7 @@ struct mscclStatus {
std::map<mscclAlgoHandle_t, mscclAlgo *> hostAlgos;
std::map<mscclAlgoHandle_t, mscclAlgo *> devAlgos;
struct mscclFlag* syncFlags;
void *scratchBuffer;
uint64_t scratchBufferSize;
std::map<size_t, void *> scratchBuffers;
size_t nBytes;
int stepSize;
int chunkSteps;
+4 -5
查看文件
@@ -236,8 +236,6 @@ ncclResult_t mscclInit(ncclComm_t comm) {
return ncclSuccess;
}
status.scratchBuffer = nullptr;
status.scratchBufferSize = 0;
status.workIndex = 1;
NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS));
status.lastStream = nullptr;
@@ -542,13 +540,14 @@ ncclResult_t mscclTeardown() {
for (auto &p : status.devAlgos) {
CUDACHECK(hipFree(p.second));
}
CUDACHECK(hipFree(status.scratchBuffer));
CUDACHECK(hipFree(status.syncFlags));
status.hostAlgos.clear();
status.devAlgos.clear();
status.freeAlgoHandles.clear();
status.scratchBuffer = nullptr;
status.scratchBufferSize = 0;
for (auto &p : status.scratchBuffers) {
CUDACHECK(hipFree(p.second));
}
status.scratchBuffers.clear();
status.connectedAlgos.clear();
if (status.mscclSchedulerPtr) {
NCCLCHECK(status.mscclSchedulerPtr->teardown());
+11 -7
查看文件
@@ -22,6 +22,10 @@ RCCL_PARAM(MscclEnableDoneEvent, "MSCCL_ENABLE_DONE_EVENT", 1);
// Parameter "MscclWorkFifoDepth", associated with the string
// "MSCCL_WORK_FIFO_DEPTH"; the last argument 64<<10 evaluates to 65536.
// NOTE(review): presumably RCCL_PARAM treats that string as an environment
// variable name and the last argument as the default value — confirm against
// the macro definition, which is not visible here.
RCCL_PARAM(MscclWorkFifoDepth, "MSCCL_WORK_FIFO_DEPTH", 64<<10);
// Returns nBytes scaled by the ratio nScratchChunks / nChunksPerLoop.
// The product is formed first and then truncated by integer division, so the
// result is floor(nBytes * nScratchChunks / nChunksPerLoop) in size_t
// arithmetic. Both chunk counts are converted to size_t before use.
// The division is unchecked: nChunksPerLoop must not be zero.
static inline size_t computeSizeNeeded(size_t nBytes, int nScratchChunks, int nChunksPerLoop) {
  const size_t scratchChunks = static_cast<size_t>(nScratchChunks);
  const size_t chunksPerLoop = static_cast<size_t>(nChunksPerLoop);
  const size_t scaledBytes = nBytes * scratchChunks;
  return scaledBytes / chunksPerLoop;
}
ncclResult_t mscclGetCaptureStatus(hipStream_t stream) {
mscclStatus& status = mscclGetStatus();
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
@@ -70,12 +74,6 @@ ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t
// Ensures the single scratch buffer held in the global MSCCL status is at
// least as large as the given algorithm requires.
//
// hostAlgo: algorithm description; nScratchChunks and nChunksPerLoop are read.
// stream:   not referenced by the statements shown here.
// Returns ncclSuccess when no NCCLCHECK'd call fails.
// NOTE(review): presumably NCCLCHECK returns early with the failing
// ncclResult_t — confirm against the macro definition.
//
// NOTE(review): this text comes from a diff view whose +/- markers are
// missing. The hunk header (12 lines -> 6) suggests the sizing and
// reallocation statements below were removed by the commit, leaving only the
// status lookup and the return — confirm against the actual file.
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) {
mscclStatus& status = mscclGetStatus();
// Required bytes: nBytes * nScratchChunks / nChunksPerLoop (integer division;
// nChunksPerLoop is not checked for zero).
size_t sizeNeeded = (status.nBytes * (size_t)(hostAlgo->nScratchChunks)) / (size_t)(hostAlgo->nChunksPerLoop);
// Grow-only policy: the buffer is replaced only when the request exceeds the
// recorded size; a smaller or equal request keeps the existing buffer.
if (sizeNeeded > status.scratchBufferSize){
// The old buffer is freed before the new one is allocated; its contents are
// not copied over.
NCCLCHECK(ncclCudaFree(status.scratchBuffer));
NCCLCHECK(ncclCudaMalloc((char**)&status.scratchBuffer, sizeNeeded, true));
status.scratchBufferSize = sizeNeeded;
}
return ncclSuccess;
}
@@ -422,7 +420,13 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
mscclWork work;
work.syncFlags = status.syncFlags;
work.scratchBuffer = status.scratchBuffer;
size_t sizeNeeded = computeSizeNeeded(status.nBytes, hostAlgo->nScratchChunks, hostAlgo->nChunksPerLoop);
if (status.scratchBuffers.find(sizeNeeded) == status.scratchBuffers.end()) {
void *scratchBuffer = nullptr;
NCCLCHECK(ncclCudaMalloc((char**)&scratchBuffer, sizeNeeded, true));
status.scratchBuffers[sizeNeeded] = scratchBuffer;
}
work.scratchBuffer = status.scratchBuffers[sizeNeeded];
work.sendBuff = sendBuff;
work.recvBuff = recvBuff;
work.sizePerMscclChunk = count * hostAlgo->sizeMultiplier / hostAlgo->nChunksPerLoop; // count is sum of all ranks in MSCCL kernel