Use a map to host scratch buffers (#1004)
* Use a map to host scratch buffers * Address review feedbacks. Deliberately keep mscclSetupScratch function.
此提交包含在:
@@ -205,8 +205,7 @@ struct mscclStatus {
|
||||
std::map<mscclAlgoHandle_t, mscclAlgo *> hostAlgos;
|
||||
std::map<mscclAlgoHandle_t, mscclAlgo *> devAlgos;
|
||||
struct mscclFlag* syncFlags;
|
||||
void *scratchBuffer;
|
||||
uint64_t scratchBufferSize;
|
||||
std::map<size_t, void *> scratchBuffers;
|
||||
size_t nBytes;
|
||||
int stepSize;
|
||||
int chunkSteps;
|
||||
|
||||
@@ -236,8 +236,6 @@ ncclResult_t mscclInit(ncclComm_t comm) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
status.scratchBuffer = nullptr;
|
||||
status.scratchBufferSize = 0;
|
||||
status.workIndex = 1;
|
||||
NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS));
|
||||
status.lastStream = nullptr;
|
||||
@@ -542,13 +540,14 @@ ncclResult_t mscclTeardown() {
|
||||
for (auto &p : status.devAlgos) {
|
||||
CUDACHECK(hipFree(p.second));
|
||||
}
|
||||
CUDACHECK(hipFree(status.scratchBuffer));
|
||||
CUDACHECK(hipFree(status.syncFlags));
|
||||
status.hostAlgos.clear();
|
||||
status.devAlgos.clear();
|
||||
status.freeAlgoHandles.clear();
|
||||
status.scratchBuffer = nullptr;
|
||||
status.scratchBufferSize = 0;
|
||||
for (auto &p : status.scratchBuffers) {
|
||||
CUDACHECK(hipFree(p.second));
|
||||
}
|
||||
status.scratchBuffers.clear();
|
||||
status.connectedAlgos.clear();
|
||||
if (status.mscclSchedulerPtr) {
|
||||
NCCLCHECK(status.mscclSchedulerPtr->teardown());
|
||||
|
||||
@@ -22,6 +22,10 @@ RCCL_PARAM(MscclEnableDoneEvent, "MSCCL_ENABLE_DONE_EVENT", 1);
|
||||
|
||||
RCCL_PARAM(MscclWorkFifoDepth, "MSCCL_WORK_FIFO_DEPTH", 64<<10);
|
||||
|
||||
static inline size_t computeSizeNeeded(size_t nBytes, int nScratchChunks, int nChunksPerLoop) {
|
||||
return (nBytes * (size_t)nScratchChunks) / (size_t)nChunksPerLoop;
|
||||
}
|
||||
|
||||
ncclResult_t mscclGetCaptureStatus(hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
@@ -70,12 +74,6 @@ ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t
|
||||
|
||||
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
size_t sizeNeeded = (status.nBytes * (size_t)(hostAlgo->nScratchChunks)) / (size_t)(hostAlgo->nChunksPerLoop);
|
||||
if (sizeNeeded > status.scratchBufferSize){
|
||||
NCCLCHECK(ncclCudaFree(status.scratchBuffer));
|
||||
NCCLCHECK(ncclCudaMalloc((char**)&status.scratchBuffer, sizeNeeded, true));
|
||||
status.scratchBufferSize = sizeNeeded;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -422,7 +420,13 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
|
||||
|
||||
mscclWork work;
|
||||
work.syncFlags = status.syncFlags;
|
||||
work.scratchBuffer = status.scratchBuffer;
|
||||
size_t sizeNeeded = computeSizeNeeded(status.nBytes, hostAlgo->nScratchChunks, hostAlgo->nChunksPerLoop);
|
||||
if (status.scratchBuffers.find(sizeNeeded) == status.scratchBuffers.end()) {
|
||||
void *scratchBuffer = nullptr;
|
||||
NCCLCHECK(ncclCudaMalloc((char**)&scratchBuffer, sizeNeeded, true));
|
||||
status.scratchBuffers[sizeNeeded] = scratchBuffer;
|
||||
}
|
||||
work.scratchBuffer = status.scratchBuffers[sizeNeeded];
|
||||
work.sendBuff = sendBuff;
|
||||
work.recvBuff = recvBuff;
|
||||
work.sizePerMscclChunk = count * hostAlgo->sizeMultiplier / hostAlgo->nChunksPerLoop; // count is sum of all ranks in MSCCL kernel
|
||||
|
||||
新增問題並參考
封鎖使用者