From 655742a3a6b1ba1430372f1c96fa1ad99abdf904 Mon Sep 17 00:00:00 2001 From: Ziyue Yang Date: Thu, 14 Dec 2023 00:36:21 +0800 Subject: [PATCH] Fully disable MSCCL when machine is not matched (#1017) * Disable MSCCL algorithm meta loading when machine is not matched * fully disable init * fix potential segfault --- src/graph/connect.cc | 3 ++- src/include/msccl/msccl_lifecycle.h | 1 + src/init.cc | 6 ++++-- src/misc/msccl/msccl_lifecycle.cc | 11 ++++++++++- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/graph/connect.cc b/src/graph/connect.cc index 6ebd801379..80c84fe811 100644 --- a/src/graph/connect.cc +++ b/src/graph/connect.cc @@ -579,7 +579,8 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa } int minNchannels = ncclMinNchannels(); - if (mscclEnabled()) { + + if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) { int mscclNumChannelsRequired = 0; mscclSchedulerInit(comm, &mscclNumChannelsRequired); minNchannels = std::max(minNchannels, mscclNumChannelsRequired); diff --git a/src/include/msccl/msccl_lifecycle.h b/src/include/msccl/msccl_lifecycle.h index f8a20ac7b8..eac4f3a7ac 100644 --- a/src/include/msccl/msccl_lifecycle.h +++ b/src/include/msccl/msccl_lifecycle.h @@ -11,6 +11,7 @@ #include "msccl/msccl_struct.h" bool mscclEnabled(); +bool mscclForceEnabled(); void mscclSetIsCallerFlag(); void mscclClearIsCallerFlag(); diff --git a/src/init.cc b/src/init.cc index 36a8db183b..cdf0389b2d 100644 --- a/src/init.cc +++ b/src/init.cc @@ -1568,7 +1568,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p // Call devCommSetup before the last barrier, making sure we don't have a thread running in front and starting to // launch NCCL kernels before all cuda mem allocation is complete. That could cause a deadlock. NCCLCHECKGOTO(devCommSetup(comm), ret, fail); - if (mscclEnabled()) { + + if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) { NCCLCHECK(mscclInit(comm)); mscclStatus& status = mscclGetStatus(); status.needsProxy |= mscclNeedsProxy; @@ -2141,6 +2142,7 @@ fail: static ncclResult_t commCleanup(ncclComm_t comm) { int savedDevice; int commDevice = comm->cudaDev; + bool mscclEnabledForTopo = comm->topo->mscclEnabled; CUDACHECK(cudaGetDevice(&savedDevice)); if (savedDevice != commDevice) { @@ -2164,7 +2166,7 @@ static ncclResult_t commCleanup(ncclComm_t comm) { NCCLCHECK(NpKit::Shutdown()); #endif - if (mscclEnabled()) { + if (mscclEnabled() && (mscclEnabledForTopo || mscclForceEnabled())) { NCCLCHECK(mscclTeardown()); } diff --git a/src/misc/msccl/msccl_lifecycle.cc b/src/misc/msccl/msccl_lifecycle.cc index 3085403249..3ec91e77f5 100644 --- a/src/misc/msccl/msccl_lifecycle.cc +++ b/src/misc/msccl/msccl_lifecycle.cc @@ -36,6 +36,14 @@ bool mscclEnabled() { #endif } +bool mscclForceEnabled() { +#ifdef COMPILE_MSCCL_KERNEL + return rcclParamMscclForceEnabled(); +#else + return false; +#endif +} + void mscclSetIsCallerFlag() { mscclGetThreadLocalStatus().mscclIsCallerFlag = true; } @@ -309,7 +317,8 @@ static ncclResult_t mscclSchedulerSelectAlgo(struct mscclSavedSchedulerParam* pa if (status.mscclSchedulerPtr) { NCCLCHECK(status.mscclSchedulerPtr->selectAlgo(&(param->p))); } else { - if (param->comm->topo->mscclEnabled || rcclParamMscclForceEnabled()) { + // Disable MSCCL algorithms if machine type is not matching + if (param->comm->topo->mscclEnabled || mscclForceEnabled()) { NCCLCHECK(mscclInternalSchedulerSelectAlgo(&(param->p))); } else { param->p.scheduled = false;