diff --git a/projects/rccl/CHANGELOG.md b/projects/rccl/CHANGELOG.md index f97d4fe066..590c99945c 100644 --- a/projects/rccl/CHANGELOG.md +++ b/projects/rccl/CHANGELOG.md @@ -8,6 +8,8 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https: * Resolved an issue when using more than 64 channels when multiple collectives are used in the same `ncclGroup()` call. * Fixed unit test failures in tests ending with `ManagedMem` and `ManagedMemGraph` suffixes. +* Fixed the known issue "When splitting a communicator using `ncclCommSplit` in some GPU configurations, MSCCL initialization can cause a segmentation fault." with a design change to key `mscclStatus` by `comm` instead of `rank`. The global map from `comm` to `mscclStatus` is still not thread-safe by itself, so all reads and writes must be explicitly guarded by mutexes. This has been tested for correctness, and there is a plan to switch to a thread-safe map data structure in upcoming changes. + ### Added @@ -304,4 +306,4 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https: ### Changed - Switched to hip-clang as default compiler ### Deprecated -- Deprecated hcc build +- Deprecated hcc build \ No newline at end of file diff --git a/projects/rccl/src/collectives.cc b/projects/rccl/src/collectives.cc index efe502b431..d51394ad49 100644 --- a/projects/rccl/src/collectives.cc +++ b/projects/rccl/src/collectives.cc @@ -96,7 +96,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen NCCLCHECK(Recorder::instance().record(rrAllGather, info)); } - if (mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream); @@ -124,7 +124,7 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou NCCLCHECK(Recorder::instance().record(rrAllReduce, info)); } - if
(mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream); @@ -149,7 +149,7 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun NCCLCHECK(Recorder::instance().record(rrAllToAll, sendbuff, recvbuff, count, datatype, comm, stream)); } - if (mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream); @@ -195,7 +195,7 @@ ncclResult_t ncclAllToAllv_impl(const void *sendbuff, const size_t sendcounts[], NCCLCHECK(Recorder::instance().record(rrAllToAllv, sendbuff, recvbuff, 0, datatype, comm, stream, -1, sendcounts, sdispls, recvcounts, rdispls)); } - if (mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls, 0, datatype, 0, 0, ncclSum, mscclFuncAllToAllv, comm, stream); @@ -241,7 +241,7 @@ ncclResult_t ncclBroadcast_impl(const void* sendbuff, void* recvbuff, size_t cou NCCLCHECK(Recorder::instance().record(rrBroadcast, info)); } - if (mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream); @@ -271,7 +271,7 @@ ncclResult_t ncclGather_impl(const void* sendbuff, void* recvbuff, size_t sendco NCCLCHECK(Recorder::instance().record(rrGather, sendbuff, recvbuff, sendcount, datatype, comm, stream, root)); } - if (mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, 
nullptr, recvbuff, nullptr, nullptr, sendcount, datatype, root, 0, ncclSum, mscclFuncGather, comm, stream); @@ -310,7 +310,7 @@ ncclResult_t ncclReduce_impl(const void* sendbuff, void* recvbuff, size_t count, NCCLCHECK(Recorder::instance().record(rrReduce, info)); } - if (mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, root, 0, op, mscclFuncReduce, comm, stream); @@ -337,7 +337,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t NCCLCHECK(Recorder::instance().record(rrReduceScatter, info)); } - if (mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream); @@ -360,7 +360,7 @@ ncclResult_t ncclScatter_impl(const void* sendbuff, void* recvbuff, size_t recvc NCCLCHECK(Recorder::instance().record(rrScatter, sendbuff, recvbuff, recvcount, datatype, comm, stream, root)); } - if (mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, recvcount, datatype, root, 0, ncclSum, mscclFuncScatter, comm, stream); @@ -400,7 +400,7 @@ ncclResult_t ncclSend_impl(const void* sendbuff, size_t count, ncclDataType_t da NCCLCHECK(Recorder::instance().record(rrSend, info)); } - if (mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr, count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream); @@ -426,7 +426,7 @@ ncclResult_t ncclRecv_impl(void* recvbuff, size_t count, ncclDataType_t datatype NCCLCHECK(Recorder::instance().record(rrRecv, info)); } - if 
(mscclAvailable(comm->rank) && !mscclIsCaller()) { + if (mscclAvailable(comm) && !mscclIsCaller()) { return mscclEnqueueCheck( nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream); diff --git a/projects/rccl/src/include/api_trace.h b/projects/rccl/src/include/api_trace.h index 8329718671..a33d7e8d6a 100644 --- a/projects/rccl/src/include/api_trace.h +++ b/projects/rccl/src/include/api_trace.h @@ -139,7 +139,7 @@ typedef ncclResult_t (*ncclMemAlloc_fn_t)(void** ptr, size_t size); typedef ncclResult_t (*ncclMemFree_fn_t)(void* ptr); typedef ncclResult_t (*mscclLoadAlgo_fn_t)(const char* mscclAlgoFilePath, - mscclAlgoHandle_t* mscclAlgoHandle, int rank); + mscclAlgoHandle_t* mscclAlgoHandle, const ncclComm_t comm); typedef ncclResult_t (*mscclRunAlgo_fn_t)( const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[], diff --git a/projects/rccl/src/include/msccl/msccl_lifecycle.h b/projects/rccl/src/include/msccl/msccl_lifecycle.h index 459a046613..b92ffbd7b7 100644 --- a/projects/rccl/src/include/msccl/msccl_lifecycle.h +++ b/projects/rccl/src/include/msccl/msccl_lifecycle.h @@ -17,7 +17,15 @@ void mscclSetIsCallerFlag(); void mscclClearIsCallerFlag(); bool mscclIsCaller(); -bool mscclAvailable(int rank = -1); +/** + * @brief mscclAvailable() is used to determine if msccl functionality is available + * @param comm is an optional rccl communicator; if provided, its mscclStatus is looked up + * in a global comm-to-mscclStatus map to determine if msccl is available. If comm is not present + * in the map, this invocation inserts a new key-value pair into the global map.
+ * If comm == nullptr, the first invocation initializes a static thread-local mscclStatus, + * and subsequent calls from the same thread with comm == nullptr reuse that same object. + */ +bool mscclAvailable(const ncclComm_t comm = nullptr); ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired); @@ -33,7 +41,7 @@ ncclResult_t mscclEnqueueCheck( ncclResult_t mscclGroupEnd(); -ncclResult_t mscclTeardown(int rank); +ncclResult_t mscclTeardown(const ncclComm_t comm); size_t mscclKernMaxLocalSize(); diff --git a/projects/rccl/src/include/msccl/msccl_setup.h b/projects/rccl/src/include/msccl/msccl_setup.h index b3f88e35ec..fb3161aee7 100644 --- a/projects/rccl/src/include/msccl/msccl_setup.h +++ b/projects/rccl/src/include/msccl/msccl_setup.h @@ -11,13 +11,13 @@ #include "comm.h" #include "msccl/msccl_struct.h" -ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream); +ncclResult_t mscclGetCaptureStatus(const ncclComm_t comm, hipStream_t stream); ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream); -ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream); +ncclResult_t mscclSetupSyncFlags(const ncclComm_t comm, hipStream_t stream); -ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm); +ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo,const ncclComm_t comm); ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType); diff --git a/projects/rccl/src/include/msccl/msccl_status.h b/projects/rccl/src/include/msccl/msccl_status.h index 077709ddc8..e55557a1e8 100644 --- a/projects/rccl/src/include/msccl/msccl_status.h +++ b/projects/rccl/src/include/msccl/msccl_status.h @@ -8,15 +8,15 @@ #include "msccl/msccl_struct.h" -bool mscclInitialized(int rank); +bool mscclInitialized(const ncclComm_t comm); -void mscclSetInitialized(int rank, bool initialized = true); +void mscclSetInitialized(const ncclComm_t comm, bool
initialized = true); -void mscclRemoveRank(int rank); +void mscclRemoveRank(const ncclComm_t comm); -mscclStatus& mscclGetStatus(int rank); +mscclStatus& mscclGetStatus(const ncclComm_t comm); -mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank); +mscclSavedProxyArgs& mscclGetSavedProxyArgs(const ncclComm_t comm); mscclThreadLocalStatus& mscclGetThreadLocalStatus(); diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index 99d5d3b173..6958394c8c 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -1745,7 +1745,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) { NCCLCHECK(mscclInit(comm)); - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); status.needsProxy |= mscclNeedsProxy; } @@ -2530,7 +2530,7 @@ static ncclResult_t commCleanup(ncclComm_t comm) { NCCLCHECK(ncclTunerPluginUnload(comm)); } if (mscclEnabled() && (mscclEnabledForTopo || mscclForceEnabled())) { - NCCLCHECK(mscclTeardown(comm->rank)); + NCCLCHECK(mscclTeardown(comm)); } NCCLCHECK(commFree(comm)); diff --git a/projects/rccl/src/misc/api_trace.cc b/projects/rccl/src/misc/api_trace.cc index 92e09eb149..5ae708c0c1 100644 --- a/projects/rccl/src/misc/api_trace.cc +++ b/projects/rccl/src/misc/api_trace.cc @@ -135,7 +135,7 @@ ncclMemFree_impl(void* ptr); ncclResult_t mscclLoadAlgo_impl(const char* mscclAlgoFilePath, mscclAlgoHandle_t* mscclAlgoHandle, - int rank); + const ncclComm_t comm); ncclResult_t mscclRunAlgo_impl(const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[], @@ -380,7 +380,7 @@ NCCL_API(ncclResult_t, ncclMemAlloc, void** ptr, size_t size); NCCL_API(ncclResult_t, ncclMemFree, void* ptr); NCCL_API(ncclResult_t, mscclLoadAlgo, const char* mscclAlgoFilePath, - mscclAlgoHandle_t* mscclAlgoHandle, int rank); + mscclAlgoHandle_t* mscclAlgoHandle, const ncclComm_t comm); 
NCCL_API(ncclResult_t, mscclRunAlgo, const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[], void* recvBuff, const size_t recvCounts[], @@ -620,10 +620,10 @@ ncclMemFree(void* ptr) } ncclResult_t -mscclLoadAlgo(const char* mscclAlgoFilePath, mscclAlgoHandle_t* mscclAlgoHandle, int rank) +mscclLoadAlgo(const char* mscclAlgoFilePath, mscclAlgoHandle_t* mscclAlgoHandle, const ncclComm_t comm) { return ::rccl::RcclGetFunctionTable()->mscclLoadAlgo_fn(mscclAlgoFilePath, - mscclAlgoHandle, rank); + mscclAlgoHandle, comm); } ncclResult_t diff --git a/projects/rccl/src/misc/msccl/msccl_lifecycle.cc b/projects/rccl/src/misc/msccl/msccl_lifecycle.cc index da4ad3d14c..15062f6acf 100644 --- a/projects/rccl/src/misc/msccl/msccl_lifecycle.cc +++ b/projects/rccl/src/misc/msccl/msccl_lifecycle.cc @@ -22,6 +22,7 @@ #include "msccl/msccl_setup.h" #include "msccl/msccl_status.h" +#include "rccl/rccl.h" #ifdef ENABLE_MSCCLPP #include "mscclpp/mscclpp_nccl.h" #endif @@ -59,8 +60,8 @@ bool mscclIsCaller() { return mscclGetThreadLocalStatus().mscclIsCallerFlag; } -bool mscclAvailable(int rank) { - return mscclEnabled() && mscclInitialized(rank); +bool mscclAvailable(const ncclComm_t comm) { + return mscclEnabled() && mscclInitialized(comm); } static bool allProcessHostsUnique(ncclComm_t comm) { @@ -118,7 +119,7 @@ static const char* mscclUnitTestAlgoShareDirPath = "../share/rccl/msccl-unit-tes static ncclResult_t mscclInternalSchedulerInit(ncclComm_t comm, int* numChannelsRequired) { static thread_local bool mscclAlgoMetaLoaded = false; - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); int maxNchannels = *numChannelsRequired; *numChannelsRequired = 0; @@ -213,7 +214,7 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) { return ncclSuccess; } - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); bool useInternalScheduler = false; const char* 
mscclSchedulerPath = getenv(mscclSchedulerPathEnv); @@ -243,12 +244,14 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) { return ncclSuccess; } -ncclResult_t mscclInit(ncclComm_t comm) { +ncclResult_t mscclInit(const ncclComm_t comm) { { - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); // freeAlgoHandles and needsProxy are initialized globally once and before algorithm pre-processing and connection - if (!mscclInitialized(comm->rank)) { + if (!mscclInitialized(comm)) { + static std::mutex initMutex; + std::lock_guard lock(initMutex); status.freeAlgoHandles.resize(MSCCL_MAX_NUM_ALGOS); for (int i = 0; i < MSCCL_MAX_NUM_ALGOS; i++) { status.freeAlgoHandles[i] = MSCCL_MAX_NUM_ALGOS - i - 1; @@ -264,13 +267,16 @@ ncclResult_t mscclInit(ncclComm_t comm) { auto &m = status.algoMetas[i]; if (m.nRanks == comm->nRanks) { // Load algorithms - if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()) { + mscclAlgoHandle_t mscclAlgoHandle; + { static std::mutex loadAlgoMutex; std::lock_guard lock(loadAlgoMutex); - NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm->rank)); + if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()){ + NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm)); + } + // Connect algorithms + mscclAlgoHandle = status.rankToAlgoHandles[i][comm->rank]; } - // Connect algorithms - mscclAlgoHandle_t mscclAlgoHandle = status.rankToAlgoHandles[i][comm->rank]; if (status.connectedAlgos[comm].find(mscclAlgoHandle) == status.connectedAlgos[comm].end()) { NCCLCHECK(mscclSetupConnections(status.hostAlgos[mscclAlgoHandle], comm)); status.connectedAlgos[comm].insert(mscclAlgoHandle); @@ -278,17 +284,20 @@ ncclResult_t mscclInit(ncclComm_t comm) { } } } + { + static std::mutex mscclInitMutex; + std::lock_guard lock(mscclInitMutex); + if 
(mscclInitialized(comm)){ + return ncclSuccess; + } - if (mscclInitialized(comm->rank)) { - return ncclSuccess; + status.workIndex = 1; + NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS)); + status.lastStream = nullptr; + NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus))); + + mscclSetInitialized(comm); } - - status.workIndex = 1; - NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS)); - status.lastStream = nullptr; - NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus))); - - mscclSetInitialized(comm->rank); } INFO(NCCL_INIT, "MSCCL: Initialization finished, localSize %ld", mscclKernMaxLocalSize()); @@ -304,8 +313,8 @@ ncclResult_t mscclGroupStart() { return ncclSuccess; } -static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSchedulerParam* param) { - mscclStatus& status = mscclGetStatus(rank); +static ncclResult_t mscclInternalSchedulerSelectAlgo(const ncclComm_t comm, struct mscclSchedulerParam* param) { + mscclStatus& status = mscclGetStatus(comm); param->scheduled = false; // Current MSCCL doesn't support pre/post op @@ -350,13 +359,13 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSched } static ncclResult_t mscclSchedulerSelectAlgo(struct mscclSavedSchedulerParam* param) { - mscclStatus& status = mscclGetStatus(param->comm->rank); + mscclStatus& status = mscclGetStatus(param->comm); if (status.mscclSchedulerPtr) { NCCLCHECK(status.mscclSchedulerPtr->selectAlgo(&(param->p))); } else { // Disable MSCCL algorithms if machine type is not matching if (param->comm->topo->mscclEnabled || mscclForceEnabled()) { - NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm->rank, &(param->p))); + NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm, &(param->p))); } else { param->p.scheduled = false; } @@ -520,7 +529,7 @@ ncclResult_t mscclEnqueueCheck( #ifdef ENABLE_MSCCLPP if (comm->mscclppCompatible) { INFO(NCCL_COLL, "MSCCL++: reading 
capture status"); - NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); + NCCLCHECK(mscclGetCaptureStatus(comm, stream)); const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff); const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff); @@ -566,7 +575,7 @@ ncclResult_t mscclEnqueueCheck( #ifdef ENABLE_MSCCLPP if (comm->mscclppCompatible) { INFO(NCCL_COLL, "MSCCL++: reading capture status"); - NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); + NCCLCHECK(mscclGetCaptureStatus(comm, stream)); const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff); const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff); @@ -631,8 +640,8 @@ ncclResult_t mscclGroupEnd() { return ncclSuccess; } -static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlgoHandle) { - mscclStatus& status = mscclGetStatus(rank); +static ncclResult_t mscclInternalUnloadAlgo(const ncclComm_t comm, mscclAlgoHandle_t mscclAlgoHandle) { + mscclStatus& status = mscclGetStatus(comm); free(status.hostAlgos[mscclAlgoHandle]); status.hostAlgos.erase(mscclAlgoHandle); @@ -649,12 +658,12 @@ static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlg return ncclSuccess; } -static ncclResult_t mscclInternalSchedulerTeardown(int rank) { +static ncclResult_t mscclInternalSchedulerTeardown(const ncclComm_t comm) { ncclResult_t ret = ncclSuccess, tmpRet = ncclSuccess; - mscclStatus& status = mscclGetStatus(rank); + mscclStatus& status = mscclGetStatus(comm); for (auto &m : status.rankToAlgoHandles) { for (auto &p : m) { - tmpRet = mscclInternalUnloadAlgo(rank, p.second); + tmpRet = mscclInternalUnloadAlgo(comm, p.second); if (ret == ncclSuccess) { ret = tmpRet; } @@ -665,13 +674,13 @@ static ncclResult_t mscclInternalSchedulerTeardown(int rank) { return ret; } -ncclResult_t mscclTeardown(int rank) { +ncclResult_t mscclTeardown(const ncclComm_t comm) { { - 
if (!mscclInitialized(rank)) { - mscclRemoveRank(rank); + if (!mscclInitialized(comm)) { + mscclRemoveRank(comm); return ncclSuccess; } - mscclStatus& status = mscclGetStatus(rank); + mscclStatus& status = mscclGetStatus(comm); for (auto &p : status.hostAlgos) { free(p.second); status.freeAlgoHandles.push_back(p.first); @@ -694,14 +703,14 @@ ncclResult_t mscclTeardown(int rank) { dlclose(status.mscclSchedulerLib); status.mscclSchedulerLib = nullptr; } else { - NCCLCHECK(mscclInternalSchedulerTeardown(rank)); + NCCLCHECK(mscclInternalSchedulerTeardown(comm)); } NCCLCHECK(mscclDestroyWorkFifoStatus(&(status.defaultWorkFifoStatus))); for (auto &p : status.graphWorkFifoStatus) { NCCLCHECK(mscclDestroyWorkFifoStatus(&(p.second))); } - mscclSetInitialized(rank, false); - mscclRemoveRank(rank); + mscclSetInitialized(comm, false); + mscclRemoveRank(comm); } INFO(NCCL_INIT, "MSCCL: Teardown finished"); diff --git a/projects/rccl/src/misc/msccl/msccl_setup.cc b/projects/rccl/src/misc/msccl/msccl_setup.cc index bd5fd07aa7..ed5075d049 100644 --- a/projects/rccl/src/misc/msccl/msccl_setup.cc +++ b/projects/rccl/src/misc/msccl/msccl_setup.cc @@ -22,10 +22,10 @@ static inline size_t computeSizeNeeded(size_t nBytes, int nScratchChunks, int nC return (nBytes * (size_t)nScratchChunks) / (size_t)nChunksPerLoop; } -ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream) { - mscclStatus& status = mscclGetStatus(rank); +ncclResult_t mscclGetCaptureStatus(const ncclComm_t comm, hipStream_t stream) { + mscclStatus& status = mscclGetStatus(comm); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); - mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(rank); + mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm); cudaStreamCaptureStatus captureStatus; unsigned long long captureId; CUDACHECK(hipStreamGetCaptureInfo_v2(stream, &captureStatus, &captureId, &threadLocalStatus.graph, nullptr, nullptr)); @@ -47,7 +47,7 @@ ncclResult_t 
mscclGetCaptureStatus(int rank, hipStream_t stream) { } ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType) { - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); status.stepSize = comm->buffSizes[hostAlgo->protocol] / NCCL_STEPS; status.chunkSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->chunkSteps : 1; status.sliceSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->sliceSteps : 1; @@ -72,8 +72,8 @@ ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) { return ncclSuccess; } -ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream) { - mscclStatus& status = mscclGetStatus(rank); +ncclResult_t mscclSetupSyncFlags(const ncclComm_t comm, hipStream_t stream) { + mscclStatus& status = mscclGetStatus(comm); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); if (threadLocalStatus.captureStatus == mscclNewCapture || status.workIndex > (1ULL << (8*sizeof(status.workIndex))) - 2 * NCCL_MAX_OPS - 1) { @@ -85,7 +85,7 @@ ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream) { } ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) { - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); // Check whether there are enough channels if (hostAlgo->nChannels > comm->nChannels) { @@ -122,7 +122,7 @@ ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) } static ncclResult_t mscclSetupProxyImpl(struct mscclAlgo* hostAlgo, ncclComm_t comm) { - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); struct ncclProxyOp proxyOp = {}; proxyOp.connIndex = 0; @@ -183,12 +183,12 @@ static void HIPRT_CB mscclSetupProxyCallback(void *args) { } ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, 
ncclComm_t comm, hipStream_t stream) { - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); - mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm->rank); + mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm); if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) { INFO(NCCL_NET, "mscclSetupProxy: reading capture status"); - NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); + NCCLCHECK(mscclGetCaptureStatus(comm, stream)); } if (threadLocalStatus.captureStatus == mscclNoCapture) { INFO(NCCL_NET,"mscclSetupProxy: no capture\n"); @@ -412,7 +412,7 @@ RCCL_PARAM(MscclForceFullOps, "MSCCL_FORCE_FULLOPS", 0); ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count, ncclDataType_t dataType, ncclRedOp_t op, struct mscclAlgo* hostAlgo, struct mscclAlgo* devAlgo, ncclComm_t comm, hipStream_t stream) { - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); if (status.lastStream != stream && status.lastStream != nullptr) { @@ -481,7 +481,7 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) { INFO(NCCL_NET, "MSCCL: reading capture status"); - NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); + NCCLCHECK(mscclGetCaptureStatus(comm, stream)); } mscclWorkFifoStatus* workFifoStatus = nullptr; if (threadLocalStatus.captureStatus == mscclNoCapture) { diff --git a/projects/rccl/src/misc/msccl/msccl_status.cc b/projects/rccl/src/misc/msccl/msccl_status.cc index f2b26663b7..497ec98b70 100644 --- a/projects/rccl/src/misc/msccl/msccl_status.cc +++ b/projects/rccl/src/misc/msccl/msccl_status.cc @@ -7,61 +7,66 @@ #include "msccl/msccl_struct.h" #include "debug.h" - +#include "comm.h" #include 
#include #include + using namespace std; struct mscclRankState { - int rank; bool initialized; mscclStatus status; mscclSavedProxyArgs savedProxyArgs; - mscclRankState() : rank(-1), initialized(false), status(), savedProxyArgs() {} + mscclRankState() : initialized(false), status(), savedProxyArgs() {} explicit mscclRankState(const mscclRankState&) = default; }; static mutex rankStatesMutex; -static unordered_map> rankStates; +/* + * @brief rankStates is intended to hold the mscclRankState for each communicator in an rccl process. + * "rankStates" is not thread-safe, hence reads/writes on this data structure need to be guarded explicitly by + * the block of code that accesses the elements of this map, using a lock guard or any other mutual exclusion device. + */ +static unordered_map> rankStates; -static inline mscclRankState& mscclGetRankState(int rank) { - // In the unlikely case of negative rank, return a per-thread state - if (rank < 0) { +static inline mscclRankState& mscclGetRankState(const ncclComm_t comm) { + // comm == nullptr holds when mscclAvailable() is called with its default argument; use a per-thread state + if (comm == nullptr) { static thread_local shared_ptr threadRankState(new mscclRankState()); return *threadRankState; } lock_guard lock(rankStatesMutex); - auto rankStateIt = rankStates.find(rank); + auto rankStateIt = rankStates.find(comm); if (rankStateIt == rankStates.end()) { // Create a per rank threadRankState rather than per thread shared_ptr newthreadRankState(new mscclRankState()); - newthreadRankState->rank = rank; - rankStateIt = rankStates.insert(make_pair(rank, newthreadRankState)).first; + // newthreadRankState->rank = rank; + rankStateIt = rankStates.insert(make_pair(comm, newthreadRankState)).first; } return *(rankStateIt->second); } -bool mscclInitialized(int rank) { - return mscclGetRankState(rank).initialized; +bool mscclInitialized(const ncclComm_t comm) { + return mscclGetRankState(comm).initialized; } -void mscclSetInitialized(int rank,
bool initialized) { - auto& state = mscclGetRankState(rank); +void mscclSetInitialized(const ncclComm_t comm, bool initialized) { + auto& state = mscclGetRankState(comm); assert(!initialized || !state.initialized); state.initialized = initialized; } -void mscclRemoveRank(int rank) { +void mscclRemoveRank(const ncclComm_t comm) { lock_guard lock(rankStatesMutex); - rankStates.erase(rank); + rankStates.erase(comm); } -mscclStatus& mscclGetStatus(int rank) { - return mscclGetRankState(rank).status; +mscclStatus& mscclGetStatus(const ncclComm_t comm) { + return mscclGetRankState(comm).status; } mscclThreadLocalStatus& mscclGetThreadLocalStatus() { @@ -69,6 +74,6 @@ mscclThreadLocalStatus& mscclGetThreadLocalStatus() { return threadLocalStatus; } -mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank) { - return mscclGetRankState(rank).savedProxyArgs; +mscclSavedProxyArgs& mscclGetSavedProxyArgs(const ncclComm_t comm) { + return mscclGetRankState(comm).savedProxyArgs; } diff --git a/projects/rccl/src/msccl.cc b/projects/rccl/src/msccl.cc index f19c6f863f..706c971808 100644 --- a/projects/rccl/src/msccl.cc +++ b/projects/rccl/src/msccl.cc @@ -14,10 +14,10 @@ using namespace rccl; -NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank); -ncclResult_t mscclLoadAlgo_impl(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank) { +NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm); +ncclResult_t mscclLoadAlgo_impl(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm) { Recorder::instance().record("mscclLoadAlgo"); - mscclStatus& status = mscclGetStatus(rank); + mscclStatus& status = mscclGetStatus(comm); if (status.freeAlgoHandles.size() == 0) { WARN("MSCCL: MSCCL_MAX_NUM_ALGOS (%d) limit reached", MSCCL_MAX_NUM_ALGOS); @@ -28,7 +28,7 @@ ncclResult_t mscclLoadAlgo_impl(const char 
*mscclAlgoFilePath, mscclAlgoHandle_t struct mscclAlgo* hostAlgo; NCCLCHECK(ncclCalloc(&hostAlgo, 1)); - NCCLCHECK(mscclGetAlgoFromXmlFile(mscclAlgoFilePath, hostAlgo, rank)); + NCCLCHECK(mscclGetAlgoFromXmlFile(mscclAlgoFilePath, hostAlgo, comm->rank)); status.hostAlgos[*mscclAlgoHandle] = hostAlgo; struct mscclAlgo* devAlgo; @@ -53,7 +53,7 @@ ncclResult_t mscclRunAlgo_impl( NVTX3_FUNC_WITH_PARAMS(MSCCL, NcclNvtxParamsMSCCL, NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(dataType), op, dataType)); - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm); struct mscclAlgo* hostAlgo = status.hostAlgos[mscclAlgoHandle]; struct mscclAlgo* devAlgo = status.devAlgos[mscclAlgoHandle]; @@ -65,13 +65,13 @@ ncclResult_t mscclRunAlgo_impl( CUDACHECK(hipSetDevice(comm->cudaDev)); - NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); + NCCLCHECK(mscclGetCaptureStatus(comm, stream)); NCCLCHECK(mscclSetupCount(hostAlgo, comm, count, dataType)); NCCLCHECK(mscclSetupScratch(hostAlgo, stream)); - NCCLCHECK(mscclSetupSyncFlags(comm->rank, stream)); + NCCLCHECK(mscclSetupSyncFlags(comm, stream)); NCCLCHECK(mscclSetupProxy(hostAlgo, comm, stream)); diff --git a/projects/rccl/src/nccl.h.in b/projects/rccl/src/nccl.h.in index e6c51e05db..3700667437 100644 --- a/projects/rccl/src/nccl.h.in +++ b/projects/rccl/src/nccl.h.in @@ -5,7 +5,6 @@ * * See LICENSE.txt for license information ************************************************************************/ - #ifndef NCCL_H_ #define NCCL_H_ @@ -764,10 +763,10 @@ typedef int mscclAlgoHandle_t; @param[in] mscclAlgoFilePath Path to MSCCL algorithm file @param[out] mscclAlgoHandle Returned handle to MSCCL algorithm - @param[in] rank Current rank */ -ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank); + @param[in] comm Current rank's communicator */ +ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t 
*mscclAlgoHandle, const ncclComm_t comm); /*! @cond include_hidden */ -ncclResult_t pmscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank); +ncclResult_t pmscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm); /*! @endcond */ /*! @brief MSCCL Run Algorithm