f70f406463
* misc/msccl: read graphCaptureStatus for every collective call
* fix a bug in checking whether UBR (user buffer registration) is enabled in MSCCLPP
* cmake: Fix patch reversal order
* misc/msccl: add logging
[ROCm/rccl commit: 23c0b7bd84]
44 lines
1.8 KiB
Diff
44 lines
1.8 KiB
Diff
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
|
|
index 5c19dc6..5fb99ef 100644
|
|
--- a/apps/nccl/src/nccl.cu
|
|
+++ b/apps/nccl/src/nccl.cu
|
|
@@ -85,6 +85,7 @@ struct ncclComm {
|
|
std::unordered_map<channelKey, ChannelInfo> channelInInfos;
|
|
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
|
|
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
|
|
+ std::unordered_map<channelKey, cudaIpcMemHandle_t> regHandles;
|
|
std::unordered_map<void*, channelKey> handleKeys;
|
|
std::shared_ptr<char> scratchBuff;
|
|
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
|
|
@@ -616,6 +617,11 @@ NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size,
|
|
p->ipcHandle = ipcHandle;
|
|
*handle = p;
|
|
|
|
+ auto regIt = comm->regHandles.find(buffKey);
|
|
+ if (regIt == comm->regHandles.end()) {
|
|
+ comm->regHandles[buffKey] = ipcHandle;
|
|
+ }
|
|
+
|
|
auto it = comm->handleKeys.find(*handle);
|
|
if (it == comm->handleKeys.end()) {
|
|
comm->handleKeys[*handle] = buffKey;
|
|
@@ -642,6 +648,7 @@ NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
|
|
if (outIt != comm->channelOutInfos.end()) {
|
|
comm->channelOutInfos.erase(outIt);
|
|
}
|
|
+ comm->regHandles.erase(buffKey);
|
|
comm->handleKeys.erase(handle);
|
|
free(handle);
|
|
}
|
|
@@ -655,8 +662,8 @@ bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff){
|
|
CUdeviceptr buffBasePtr;
|
|
MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
|
|
channelKey buffKey{(void*)buffBasePtr, buffBytes};
|
|
- auto buffIt = comm->channelScratchInfos.find(buffKey);
|
|
- bool registered = buffIt != comm->channelScratchInfos.end();
|
|
+ auto buffIt = comm->regHandles.find(buffKey);
|
|
+ bool registered = buffIt != comm->regHandles.end();
|
|
return registered;
|
|
}
|
|
size_t
|