Fully disable MSCCL when machine is not matched (#1017)

* Disable MSCCL algorithm meta loading when machine is not matched

* fully disable init

* fix potential segfault

[ROCm/rccl commit: 655742a3a6]
This commit is contained in:
Ziyue Yang
2023-12-14 00:36:21 +08:00
committed by GitHub
parent 918ce6c2e2
commit e4b63a8ba0
4 changed files with 17 additions and 4 deletions
+2 -1
View file
@@ -579,7 +579,8 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
}
int minNchannels = ncclMinNchannels();
if (mscclEnabled()) {
if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) {
int mscclNumChannelsRequired = 0;
mscclSchedulerInit(comm, &mscclNumChannelsRequired);
minNchannels = std::max(minNchannels, mscclNumChannelsRequired);
@@ -11,6 +11,7 @@
#include "msccl/msccl_struct.h"
// NOTE(review): presumably reports whether MSCCL support is enabled; the
// definition is not fully visible here — confirm against the implementation.
bool mscclEnabled();
// Returns the value of rcclParamMscclForceEnabled() when the library is built
// with COMPILE_MSCCL_KERNEL defined; always returns false otherwise.
bool mscclForceEnabled();
// Sets mscclIsCallerFlag to true on the status object returned by
// mscclGetThreadLocalStatus().
void mscclSetIsCallerFlag();
// NOTE(review): presumably resets the flag set by mscclSetIsCallerFlag();
// the definition is not visible here — confirm.
void mscclClearIsCallerFlag();
+4 -2
View file
@@ -1568,7 +1568,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
// Call devCommSetup before the last barrier, making sure we don't have a thread running in front and starting to
// launch NCCL kernels before all cuda mem allocation is complete. That could cause a deadlock.
NCCLCHECKGOTO(devCommSetup(comm), ret, fail);
if (mscclEnabled()) {
if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) {
NCCLCHECK(mscclInit(comm));
mscclStatus& status = mscclGetStatus();
status.needsProxy |= mscclNeedsProxy;
@@ -2141,6 +2142,7 @@ fail:
static ncclResult_t commCleanup(ncclComm_t comm) {
int savedDevice;
int commDevice = comm->cudaDev;
bool mscclEnabledForTopo = comm->topo->mscclEnabled;
CUDACHECK(cudaGetDevice(&savedDevice));
if (savedDevice != commDevice) {
@@ -2164,7 +2166,7 @@ static ncclResult_t commCleanup(ncclComm_t comm) {
NCCLCHECK(NpKit::Shutdown());
#endif
if (mscclEnabled()) {
if (mscclEnabled() && (mscclEnabledForTopo || mscclForceEnabled())) {
NCCLCHECK(mscclTeardown());
}
@@ -36,6 +36,14 @@ bool mscclEnabled() {
#endif
}
// Reports the MSCCL force-enable setting.
//
// Builds without COMPILE_MSCCL_KERNEL always yield false. Builds with it
// forward the value of rcclParamMscclForceEnabled(). Callers combine this
// with comm->topo->mscclEnabled to decide whether MSCCL may be used.
bool mscclForceEnabled() {
  bool forced = false;
#ifdef COMPILE_MSCCL_KERNEL
  // Only consult the runtime parameter when MSCCL kernels are compiled in.
  forced = rcclParamMscclForceEnabled();
#endif
  return forced;
}
// Raises mscclIsCallerFlag on the status object returned by
// mscclGetThreadLocalStatus().
void mscclSetIsCallerFlag() {
  auto& threadStatus = mscclGetThreadLocalStatus();
  threadStatus.mscclIsCallerFlag = true;
}
@@ -309,7 +317,8 @@ static ncclResult_t mscclSchedulerSelectAlgo(struct mscclSavedSchedulerParam* pa
if (status.mscclSchedulerPtr) {
NCCLCHECK(status.mscclSchedulerPtr->selectAlgo(&(param->p)));
} else {
if (param->comm->topo->mscclEnabled || rcclParamMscclForceEnabled()) {
// Disable MSCCL algorithms if machine type is not matching
if (param->comm->topo->mscclEnabled || mscclForceEnabled()) {
NCCLCHECK(mscclInternalSchedulerSelectAlgo(&(param->p)));
} else {
param->p.scheduled = false;