From e9b6bbca8a806e33ef07fc3217de6811525cce3d Mon Sep 17 00:00:00 2001 From: Nusrat Islam Date: Tue, 14 Jan 2025 08:20:24 -0600 Subject: [PATCH] Add MSCCLPP user buffer registration APIs and integrate with RCCL (#1477) * ext-src: add MSCCLPP memory registration APIs * update mem-reg patch with mscclpp helper routine to check if buffer is registered * RCCL integration of MSCCL++ user-buffer registration APIs * only include mscclpp_nccl header if ENABLE_MSCCLPP is defined * ext-src: update mscclpp mem-reg patch * add helper routine to patch * check handle before MSCCL++ deregister * fix typo to replace send buff with recv buff * in case of no mscclpp registration, dduring deRegister call, ont fall back to rccl deRegister which will return an error * Apply suggestions from code review Whitespace suggestions and reducing diffs to avoid future merge conflicts Co-authored-by: corey-derochie-amd <161367113+corey-derochie-amd@users.noreply.github.com> * rename helper functions and change their return type * set RCCL user-buffer registration to occur if attempting MSCCL++ registration with a buffer in managed memory --------- Co-authored-by: isaki001 Co-authored-by: isaki001 <36317038+isaki001@users.noreply.github.com> Co-authored-by: corey-derochie-amd <161367113+corey-derochie-amd@users.noreply.github.com> --- cmake/MSCCLPP.cmake | 26 +++-- ext-src/mem-reg.patch | 147 +++++++++++++++++++++++++ src/include/mscclpp/mscclpp_nccl.h | 8 ++ src/misc/msccl/msccl_lifecycle.cc | 14 ++- src/misc/mscclpp/mscclpp_nccl_syms.txt | 2 + src/register.cc | 27 +++++ 6 files changed, 215 insertions(+), 9 deletions(-) create mode 100644 ext-src/mem-reg.patch diff --git a/cmake/MSCCLPP.cmake b/cmake/MSCCLPP.cmake index e38922d5e7..48c1e22842 100644 --- a/cmake/MSCCLPP.cmake +++ b/cmake/MSCCLPP.cmake @@ -68,14 +68,20 @@ if(ENABLE_MSCCLPP) WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) execute_process( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) + execute_process( COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + message(STATUS "Building mscclpp only for gfx942.") mscclpp_cmake_arg(CMAKE_PREFIX_PATH) @@ -102,17 +108,23 @@ if(ENABLE_MSCCLPP) find_package(mscclpp_nccl REQUIRED) execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) - execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + execute_process( COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) + + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) endif() execute_process(COMMAND objcopy diff --git a/ext-src/mem-reg.patch b/ext-src/mem-reg.patch new file mode 100644 index 0000000000..f95b116b9a --- /dev/null +++ b/ext-src/mem-reg.patch @@ -0,0 +1,147 @@ +diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h +index 7f50792..b8b146d 100644 +--- a/apps/nccl/include/nccl.h ++++ b/apps/nccl/include/nccl.h +@@ -344,6 +344,13 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun + ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, + ncclComm_t comm, cudaStream_t stream); + ++/* ++ * Register/Deregister ++ */ ++ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle); ++ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle); ++bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count); ++size_t mscclpp_BufferSize(ncclComm_t comm, void* handle); + /* + * Send + * +diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu +index a697be2..1d4af61 100644 +--- a/apps/nccl/src/nccl.cu ++++ b/apps/nccl/src/nccl.cu +@@ -65,6 +65,7 @@ struct ncclComm { + std::unordered_map channelInInfos; + std::unordered_map channelOutInfos; + std::unordered_map channelScratchInfos; ++ std::unordered_map handleKeys; + std::shared_ptr scratchBuff; + std::vector remoteScratchRegMemories; + +@@ -73,6 +74,11 @@ struct ncclComm { + uint32_t buffFlag; + }; + ++struct handleInfo { ++ void * buff; ++ cudaIpcMemHandle_t ipcHandle; ++}; ++ + static size_t ncclTypeSize(ncclDataType_t type) { + switch (type) { + case ncclInt8: +@@ -577,6 +583,104 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t + return ncclSuccess; + } + ++NCCL_API ncclResult_t ncclCommRegister(ncclComm_t comm, void* buff, size_t size, void** handle) { ++ size_t buffBytes = size; ++ CUdeviceptr buffBasePtr; ++ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); ++ ++ int rank = comm->comm->bootstrap()->getRank(); ++ channelKey buffKey{(void*)buffBasePtr, buffBytes}; ++ ++ std::vector remoteMemories; ++ ++ // Creating the channels ++ auto buffIt = comm->channelScratchInfos.find(buffKey); ++ if (buffIt == comm->channelScratchInfos.end()) { ++ std::vector channels = ++ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast((void*)buffBasePtr)); ++ ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)}; ++ buffIt = comm->channelScratchInfos.emplace(buffKey, channelInfo).first; ++ } ++ auto sendIt = comm->channelInInfos.find(buffKey); ++ if (sendIt == comm->channelInInfos.end()) { ++ std::vector channels = ++ setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast((void*)buffBasePtr)); ++ ++ remoteMemories = ++ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc); ++ std::vector channels1 = ++ setupSmChannels(comm, remoteMemories, const_cast((void*)buffBasePtr)); ++ ++ ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)}; ++ sendIt = comm->channelInInfos.emplace(buffKey, channelInfo).first; ++ } ++ auto recvIt = comm->channelOutInfos.find(buffKey); ++ if (recvIt == comm->channelOutInfos.end()) { ++ remoteMemories = ++ setupRemoteMemories(comm->comm, rank, (void*)buffBasePtr, buffBytes, mscclpp::Transport::CudaIpc); ++ std::vector outChannels = ++ setupSmChannels(comm, remoteMemories, const_cast((void*)buffBasePtr)); ++ ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)}; ++ recvIt = comm->channelOutInfos.emplace(buffKey, channelInfo).first; ++ } ++ ++ cudaIpcMemHandle_t ipcHandle; ++ MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&ipcHandle, buffBasePtr)); ++ ++ struct handleInfo *p = (struct handleInfo *) malloc(sizeof(struct handleInfo)); ++ p->buff = buffBasePtr; ++ p->ipcHandle = ipcHandle; ++ *handle = p; ++ ++ auto it = comm->handleKeys.find(*handle); ++ if (it == comm->handleKeys.end()) { ++ comm->handleKeys[*handle] = buffKey; ++ } ++ ++ return ncclSuccess; ++} ++ ++NCCL_API ncclResult_t ncclCommDeregister(ncclComm_t comm, void* handle) { ++ if (comm && handle) { ++ channelKey buffKey = comm->handleKeys[handle]; ++ ++ auto scratchIt = comm->channelScratchInfos.find(buffKey); ++ if (scratchIt != comm->channelScratchInfos.end()) { ++ comm->channelScratchInfos.erase(scratchIt); ++ } ++ ++ auto inIt = comm->channelInInfos.find(buffKey); ++ if (inIt != comm->channelInInfos.end()) { ++ comm->channelInInfos.erase(inIt); ++ } ++ ++ auto outIt = comm->channelOutInfos.find(buffKey); ++ if (outIt != comm->channelOutInfos.end()) { ++ comm->channelOutInfos.erase(outIt); ++ } ++ ++ free(handle); ++ } ++ return ncclSuccess; ++} ++ ++bool mscclpp_BuffIsRegistered(ncclComm_t comm, const void* buff, size_t count){ ++ size_t buffBytes; ++ CUdeviceptr buffBasePtr; ++ MSCCLPP_CUTHROW(cuMemGetAddressRange(&buffBasePtr, &buffBytes, (CUdeviceptr)buff)); ++ channelKey buffKey{(void*)buffBasePtr, buffBytes}; ++ auto buffIt = comm->channelScratchInfos.find(buffKey); ++ bool registered = buffIt != comm->channelScratchInfos.end(); ++ return registered; ++} ++size_t ++mscclpp_BufferSize(ncclComm_t comm, void* handle){ ++ if (!(comm && handle)){ ++ return 0; ++ } ++ auto buffKeyIt = comm->handleKeys.find(handle); ++ return buffKeyIt != comm->handleKeys.end() ? buffKeyIt->second.bytes : 0; ++} + NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) { + // TODO: implement this function + return ncclInternalError; diff --git a/src/include/mscclpp/mscclpp_nccl.h b/src/include/mscclpp/mscclpp_nccl.h index c4dc7dfa0c..760405e499 100644 --- a/src/include/mscclpp/mscclpp_nccl.h +++ b/src/include/mscclpp/mscclpp_nccl.h @@ -38,6 +38,14 @@ extern "C" { /* See ncclAllGather. */ ncclResult_t mscclpp_ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, mscclppComm_t comm, hipStream_t stream); + + ncclResult_t mscclpp_ncclCommRegister(mscclppComm_t comm, void* buff, size_t size, void** handle); + + ncclResult_t mscclpp_ncclCommDeregister(mscclppComm_t comm, void* handle); + + bool mscclpp_BuffIsRegistered(mscclppComm_t comm, const void* buff, size_t count); + + size_t mscclpp_BufferSize(mscclppComm_t comm, void* handle); } namespace std { diff --git a/src/misc/msccl/msccl_lifecycle.cc b/src/misc/msccl/msccl_lifecycle.cc index 532171bf0a..66048767a2 100644 --- a/src/misc/msccl/msccl_lifecycle.cc +++ b/src/misc/msccl/msccl_lifecycle.cc @@ -524,8 +524,13 @@ ncclResult_t mscclEnqueueCheck( NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); } + const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff, count); + const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff, count); + const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture; + const bool buffsRegistedNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered; + /* check if one rank per GPU and graph mode is enabled */ - if ((threadLocalStatus.captureStatus != mscclNoCapture) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { + if ((graphMode || buffsRegistedNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { bool isManagedBuffer = false; if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(sendBuff))); if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(recvBuff))); @@ -565,8 +570,13 @@ ncclResult_t mscclEnqueueCheck( NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); } + const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff, count); + const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff, count); + const bool graphMode = threadLocalStatus.captureStatus != mscclNoCapture; + const bool buffsRegistedNonGraphMode = !graphMode && sendBuffRegistered && recvBuffRegistered; + /* check if one rank per GPU and graph mode is enabled */ - if ((threadLocalStatus.captureStatus != mscclNoCapture) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { + if ((graphMode || buffsRegistedNonGraphMode) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { bool isManagedBuffer = false; if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(sendBuff))); if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(recvBuff))); diff --git a/src/misc/mscclpp/mscclpp_nccl_syms.txt b/src/misc/mscclpp/mscclpp_nccl_syms.txt index dd1cfa8fe0..bb2bd858dc 100644 --- a/src/misc/mscclpp/mscclpp_nccl_syms.txt +++ b/src/misc/mscclpp/mscclpp_nccl_syms.txt @@ -30,3 +30,5 @@ ncclRedOpDestroy mscclpp_ncclRedOpDestroy ncclReduce mscclpp_ncclReduce ncclReduceScatter mscclpp_ncclReduceScatter ncclSend mscclpp_ncclSend +ncclCommRegister mscclpp_ncclCommRegister +ncclCommDeregister mscclpp_ncclCommDeregister diff --git a/src/register.cc b/src/register.cc index 1020e9dde2..c5c6afd73e 100644 --- a/src/register.cc +++ b/src/register.cc @@ -10,6 +10,9 @@ #include "net.h" #include "register.h" #include "api_trace.h" +#ifdef ENABLE_MSCCLPP +#include "mscclpp/mscclpp_nccl.h" +#endif ncclResult_t ncclNetDeregister(struct ncclComm* comm, struct ncclReg* reg) { struct ncclRegCache* cache = &comm->regCache; @@ -155,12 +158,36 @@ NCCL_API(ncclResult_t, ncclCommRegister, const ncclComm_t comm, void* buff, size ncclResult_t ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** handle) { NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); if (comm->checkPointers) NCCLCHECK(CudaPtrCheck(buff, comm, "buff", "ncclCommRegister")); + #ifdef ENABLE_MSCCLPP + if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold){ + bool isManagedBuffer = false; + CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(buff))); + if(!isManagedBuffer){ + INFO(NCCL_INIT, "MSCCL++: ncclCommRegister"); + NCCLCHECK(mscclpp_ncclCommRegister(comm->mscclpp_comm, buff, size, handle)); + return ncclSuccess; + } + else{ + WARN("MSCCL++: Cannot register user-buffers on managed memory. RCCL user-buffer registration will occur."); + } + } + #endif + INFO(NCCL_INIT, "RCCL: ncclCommRegister"); NCCLCHECK(ncclRegister(comm, buff, size, handle)); return ncclSuccess; } NCCL_API(ncclResult_t, ncclCommDeregister, const ncclComm_t comm, void* handle); ncclResult_t ncclCommDeregister_impl(const ncclComm_t comm, void* handle) { + + #ifdef ENABLE_MSCCLPP + const size_t size = mscclpp_BufferSize(comm->mscclpp_comm, handle); + if (comm->mscclCompatible && size > 0 && (size & 31) == 0 && size <= comm->mscclpp_threshold) { + NCCLCHECK(mscclpp_ncclCommDeregister(comm->mscclpp_comm, handle)); + return ncclSuccess; + } + #endif + NCCLCHECK(CommCheck(comm, "ncclCommRegister", "comm")); struct ncclReg* reg = (struct ncclReg*)handle; struct ncclRegCache* cache = &comm->regCache;