diff --git a/projects/rccl/cmake/MSCCLPP.cmake b/projects/rccl/cmake/MSCCLPP.cmake index 491cdadc2f..075cc3e30a 100644 --- a/projects/rccl/cmake/MSCCLPP.cmake +++ b/projects/rccl/cmake/MSCCLPP.cmake @@ -95,6 +95,11 @@ if(ENABLE_MSCCLPP) WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + message(STATUS "Building mscclpp only for supported variants:gfx942,gfx950") mscclpp_cmake_arg(CMAKE_PREFIX_PATH) mscclpp_cmake_arg(CMAKE_INSTALL_RPATH_USE_LINK_PATH) @@ -121,23 +126,14 @@ if(ENABLE_MSCCLPP) find_package(mscclpp_nccl REQUIRED) - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) - - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) - - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) @@ -147,10 +143,25 @@ if(ENABLE_MSCCLPP) ) execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + #endif() execute_process(COMMAND objcopy diff --git a/projects/rccl/ext-src/reg-fix.patch b/projects/rccl/ext-src/reg-fix.patch new file mode 100644 index 0000000000..5eda54bbc7 --- /dev/null +++ b/projects/rccl/ext-src/reg-fix.patch @@ -0,0 +1,43 @@ +diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu +index 5c19dc6..5fb99ef 100644 +--- a/apps/nccl/src/nccl.cu ++++ b/apps/nccl/src/nccl.cu +@@ -85,6 +85,7 @@ struct ncclComm { + std::unordered_map channelInInfos; + std::unordered_map channelOutInfos; + std::unordered_map channelScratchInfos; ++ std::unordered_map regHandles; + std::unordered_map handleKeys; + std::shared_ptr scratchBuff; + std::vector remoteScratchRegMemories; +@@ -616,6 +617,11 @@ NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, + p->ipcHandle = ipcHandle; + *handle = p; + ++ auto regIt = comm->regHandles.find(buffKey); ++ if (regIt == comm->regHandles.end()) { ++ comm->regHandles[buffKey] = ipcHandle; ++ } ++ + auto it = comm->handleKeys.find(*handle); + if (it == comm->handleKeys.end()) { + comm->handleKeys[*handle] = buffKey; +@@ -642,6 +648,7 @@ NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) { + if (outIt != comm->channelOutInfos.end()) { + comm->channelOutInfos.erase(outIt); + } ++ comm->regHandles.erase(buffKey); + comm->handleKeys.erase(handle); + free(handle); + } +@@ -655,8 +662,8 @@ bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff){ + CUdeviceptr buffBasePtr; + MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); + channelKey buffKey{(void*)buffBasePtr, buffBytes}; +- auto buffIt = comm->channelScratchInfos.find(buffKey); +- bool registered = buffIt != comm->channelScratchInfos.end(); ++ auto buffIt = comm->regHandles.find(buffKey); ++ bool registered = buffIt != comm->regHandles.end(); + return registered; + } + size_t diff --git a/projects/rccl/src/misc/msccl/msccl_lifecycle.cc b/projects/rccl/src/misc/msccl/msccl_lifecycle.cc index 225ab8bfe0..e73d68b86b 100644 --- a/projects/rccl/src/misc/msccl/msccl_lifecycle.cc +++ b/projects/rccl/src/misc/msccl/msccl_lifecycle.cc @@ -519,33 +519,31 @@ ncclResult_t mscclEnqueueCheck( case mscclNoGroup: #ifdef ENABLE_MSCCLPP if (comm->mscclppCompatible) { - if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) { - INFO(NCCL_COLL, "MSCCL++: reading capture status"); - NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); - } + INFO(NCCL_COLL, "MSCCL++: reading capture status"); + NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff); const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff); const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture; - const bool buffsRegisteredNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered; + const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered; /* check if one rank per GPU and graph mode is enabled */ - if ((graphMode || buffsRegisteredNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { + if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { bool isManagedBuffer = false; if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(sendBuff))); if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(recvBuff))); if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ } else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) { - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", - "mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream); + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d", + "mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode); NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream)); threadLocalStatus.savedSchedulerParams.clear(); break; } else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) { - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", - "mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream); + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d", + "mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode); NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream)); threadLocalStatus.savedSchedulerParams.clear(); break; @@ -565,33 +563,31 @@ ncclResult_t mscclEnqueueCheck( case mscclGroupSupportedOp: #ifdef ENABLE_MSCCLPP if (comm->mscclppCompatible) { - if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) { - INFO(NCCL_COLL, "MSCCL++: reading capture status"); - NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); - } + INFO(NCCL_COLL, "MSCCL++: reading capture status"); + NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff); const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff); const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture; - const bool buffsRegisteredNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered; + const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered; /* check if one rank per GPU and graph mode is enabled */ - if ((graphMode || buffsRegisteredNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { + if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { bool isManagedBuffer = false; if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(sendBuff))); if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(recvBuff))); if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ } else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) { - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", - "mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream); + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d", + "mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode); NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream)); threadLocalStatus.savedSchedulerParams.clear(); break; } else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) { - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", - "mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream); + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d" , + "mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode); NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream)); threadLocalStatus.savedSchedulerParams.clear(); break;