misc/msccl: Read graph capture status for every collective call (#1576)

* misc/msccl: read graphCaptureStatus for every collective call

* fix a bug in checking whether UBR is enabled in MSCCLPP

* cmake: Fix patch reversal order

* misc/msccl: add logging

[ROCm/rccl commit: 23c0b7bd84]
Этот коммит содержится в:
Nusrat Islam
2025-02-28 17:16:07 -06:00
коммит произвёл GitHub
родитель 3be905ca83
Коммит f70f406463
3 изменённых файлов: 86 добавлений и 36 удалений
+27 -16
Просмотреть файл
@@ -95,6 +95,11 @@ if(ENABLE_MSCCLPP)
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
message(STATUS "Building mscclpp only for supported variants:gfx942,gfx950")
mscclpp_cmake_arg(CMAKE_PREFIX_PATH)
mscclpp_cmake_arg(CMAKE_INSTALL_RPATH_USE_LINK_PATH)
@@ -121,23 +126,14 @@ if(ENABLE_MSCCLPP)
find_package(mscclpp_nccl REQUIRED)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
@@ -147,10 +143,25 @@ if(ENABLE_MSCCLPP)
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
execute_process(
COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch
WORKING_DIRECTORY ${MSCCLPP_SOURCE}
)
#endif()
execute_process(COMMAND objcopy
+43
Просмотреть файл
@@ -0,0 +1,43 @@
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 5c19dc6..5fb99ef 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -85,6 +85,7 @@ struct ncclComm {
std::unordered_map<channelKey, ChannelInfo> channelInInfos;
std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
+ std::unordered_map<channelKey, cudaIpcMemHandle_t> regHandles;
std::unordered_map<void*, channelKey> handleKeys;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
@@ -616,6 +617,11 @@ NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size,
p->ipcHandle = ipcHandle;
*handle = p;
+ auto regIt = comm->regHandles.find(buffKey);
+ if (regIt == comm->regHandles.end()) {
+ comm->regHandles[buffKey] = ipcHandle;
+ }
+
auto it = comm->handleKeys.find(*handle);
if (it == comm->handleKeys.end()) {
comm->handleKeys[*handle] = buffKey;
@@ -642,6 +648,7 @@ NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) {
if (outIt != comm->channelOutInfos.end()) {
comm->channelOutInfos.erase(outIt);
}
+ comm->regHandles.erase(buffKey);
comm->handleKeys.erase(handle);
free(handle);
}
@@ -655,8 +662,8 @@ bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff){
CUdeviceptr buffBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff));
channelKey buffKey{(void*)buffBasePtr, buffBytes};
- auto buffIt = comm->channelScratchInfos.find(buffKey);
- bool registered = buffIt != comm->channelScratchInfos.end();
+ auto buffIt = comm->regHandles.find(buffKey);
+ bool registered = buffIt != comm->regHandles.end();
return registered;
}
size_t
+16 -20
Просмотреть файл
@@ -519,33 +519,31 @@ ncclResult_t mscclEnqueueCheck(
case mscclNoGroup:
#ifdef ENABLE_MSCCLPP
if (comm->mscclppCompatible) {
if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) {
INFO(NCCL_COLL, "MSCCL++: reading capture status");
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
}
INFO(NCCL_COLL, "MSCCL++: reading capture status");
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff);
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff);
const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture;
const bool buffsRegisteredNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered;
const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered;
/* check if one rank per GPU and graph mode is enabled */
if ((graphMode || buffsRegisteredNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
bool isManagedBuffer = false;
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ }
else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) {
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream));
threadLocalStatus.savedSchedulerParams.clear();
break;
}
else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) {
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream));
threadLocalStatus.savedSchedulerParams.clear();
break;
@@ -565,33 +563,31 @@ ncclResult_t mscclEnqueueCheck(
case mscclGroupSupportedOp:
#ifdef ENABLE_MSCCLPP
if (comm->mscclppCompatible) {
if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) {
INFO(NCCL_COLL, "MSCCL++: reading capture status");
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
}
INFO(NCCL_COLL, "MSCCL++: reading capture status");
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff);
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff);
const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture;
const bool buffsRegisteredNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered;
const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered;
/* check if one rank per GPU and graph mode is enabled */
if ((graphMode || buffsRegisteredNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
bool isManagedBuffer = false;
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ }
else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) {
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream));
threadLocalStatus.savedSchedulerParams.clear();
break;
}
else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) {
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d" ,
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream));
threadLocalStatus.savedSchedulerParams.clear();
break;