diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 5c19dc6..5fb99ef 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -85,6 +85,7 @@ struct ncclComm { std::unordered_map channelInInfos; std::unordered_map channelOutInfos; std::unordered_map channelScratchInfos; + std::unordered_map regHandles; std::unordered_map handleKeys; std::shared_ptr scratchBuff; std::vector remoteScratchRegMemories; @@ -616,6 +617,11 @@ NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, p->ipcHandle = ipcHandle; *handle = p; + auto regIt = comm->regHandles.find(buffKey); + if (regIt == comm->regHandles.end()) { + comm->regHandles[buffKey] = ipcHandle; + } + auto it = comm->handleKeys.find(*handle); if (it == comm->handleKeys.end()) { comm->handleKeys[*handle] = buffKey; @@ -642,6 +648,7 @@ NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) { if (outIt != comm->channelOutInfos.end()) { comm->channelOutInfos.erase(outIt); } + comm->regHandles.erase(buffKey); comm->handleKeys.erase(handle); free(handle); } @@ -655,8 +662,8 @@ bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff){ CUdeviceptr buffBasePtr; MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); channelKey buffKey{(void*)buffBasePtr, buffBytes}; - auto buffIt = comm->channelScratchInfos.find(buffKey); - bool registered = buffIt != comm->channelScratchInfos.end(); + auto buffIt = comm->regHandles.find(buffKey); + bool registered = buffIt != comm->regHandles.end(); return registered; } size_t