File
Nusrat Islam f70f406463 misc/msccl: Read graph capture status for every collective call (#1576)
* misc/msccl: read graphCaptureStatus for every collective call

* fix a bug in checking whether user buffer registration (UBR) is enabled in MSCCL++

* cmake: Fix patch reversal order

* misc/msccl: add logging

[ROCm/rccl commit: 23c0b7bd84]
2025-02-28 17:16:07 -06:00

44 lines
1.8 KiB
Diff

diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 5c19dc6..5fb99ef 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -85,6 +85,7 @@ struct ncclComm {
std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
+ std::unordered_map<channelKey, cudaIpcMemHandle_t> regHandles;
std::unordered_map<void*, channelKey> handleKeys;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
@@ -616,6 +617,11 @@ NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size,
p->ipcHandle = ipcHandle;
*handle = p;
+ auto regIt = comm->regHandles.find(buffKey);
+ if (regIt == comm->regHandles.end()) {
+ comm->regHandles[buffKey] = ipcHandle;
+ }
+
auto it = comm->handleKeys.find(*handle);
if (it == comm->handleKeys.end()) {
comm->handleKeys[*handle] = buffKey;
@@ -642,6 +648,7 @@ NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
if (outIt != comm->channelOutInfos.end()) {
comm->channelOutInfos.erase(outIt);
}
+ comm->regHandles.erase(buffKey);
comm->handleKeys.erase(handle);
free(handle);
}
@@ -655,8 +662,8 @@ bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff){
CUdeviceptr buffBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
channelKey buffKey{(void*)buffBasePtr, buffBytes};
- auto buffIt = comm->channelScratchInfos.find(buffKey);
- bool registered = buffIt != comm->channelScratchInfos.end();
+ auto buffIt = comm->regHandles.find(buffKey);
+ bool registered = buffIt != comm->regHandles.end();
return registered;
}
size_t