// Patch summary: moves MSCCL++ unique-id creation out of ncclGetUniqueId_impl
// and into communicator initialization, and gates MSCCL++ on MSCCL compatibility.
//
// src/include/mscclpp/mscclpp_nccl.h
//   Adds the declaration `bool mscclppCommCompatible(ncclComm_t comm);` just
//   before an `#endif`.
//
// src/init.cc, ncclGetUniqueId_impl
//   Removes the rcclParamMscclppEnabled() block that, on "gfx94" devices, called
//   mscclpp_ncclGetUniqueId and recorded the result in mscclpp_uniqueIdMap and
//   mscclpp_uniqueIdReverseMap. After the patch the function returns the result
//   of bootstrapGetUniqueId once the trace call has been issued.
//
// src/init.cc, ncclCommInitRankFunc
//   - The pre-existing branch that copies the id from origUniqueId now also
//     stores ncclCommToUniqueIdMap[comm] = job->commId.
//   - The rcclParamMscclppEnabled() branch is now guarded by
//     mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())
//     && mscclppCommCompatible(comm). When that guard fails,
//     comm->mscclppCompatible is set to false and a warning is logged.
//   - When the guard passes and the device arch matches "gfx94": local rank 0
//     calls mscclpp_ncclGetUniqueId if mscclpp_uniqueIdMap had no entry for
//     job->commId; the id is then sent from local root 0 with
//     bootstrapIntraNodeBroadcast, recorded in mscclpp_uniqueIdReverseMap, the
//     threshold is read from rcclParamMscclppThreshold(), and
//     mscclpp_ncclCommInitRank is called. Both mscclpp_commToUniqueIdMap and
//     ncclCommToUniqueIdMap are updated afterwards.
//
// src/misc/msccl/msccl_lifecycle.cc
//   Adds mscclppCommCompatible() inside #ifdef ENABLE_MSCCLPP; it forwards to the
//   file-local static mscclCommCompatible().
//
// NOTE(review): mapContainsId is sampled before `mscclpp_uniqueIdMap[job->commId]`,
//   presumably because indexing inserts a default entry -- the container type is
//   not visible in this patch; confirm.
// NOTE(review): the id is distributed with an intra-node broadcast only, yet
//   mscclpp_ncclCommInitRank receives job->nranks -- confirm that
//   mscclppCommCompatible() limits this path to communicators for which an
//   intra-node broadcast reaches every rank.
// NOTE(review): CUDACHECK(hipGetDeviceProperties(...)) sits next to
//   NCCLCHECKGOTO(..., res, fail) calls and takes no res/fail arguments --
//   confirm its failure behaviour is acceptable with respect to the `fail` label.
// NOTE(review): the header declaration is added without a visible
//   ENABLE_MSCCLPP guard, while the definition is compiled only under
//   #ifdef ENABLE_MSCCLPP -- confirm the header is only used in such builds.
diff --git a/src/include/mscclpp/mscclpp_nccl.h b/src/include/mscclpp/mscclpp_nccl.h index fbc1b87e04..c4dc7dfa0c 100644 --- a/src/include/mscclpp/mscclpp_nccl.h +++ b/src/include/mscclpp/mscclpp_nccl.h @@ -49,4 +49,6 @@ namespace std { bool operator ==(const ncclUniqueId& a, const ncclUniqueId& b); +bool mscclppCommCompatible(ncclComm_t comm); + #endif diff --git a/src/init.cc b/src/init.cc index 0c3297f0d3..c5beab45dd 100644 --- a/src/init.cc +++ b/src/init.cc @@ -189,25 +189,6 @@ ncclResult_t ncclGetUniqueId_impl(ncclUniqueId* out) { NCCLCHECK(PtrCheck(out, "GetUniqueId", "out")); ncclResult_t res = bootstrapGetUniqueId((struct ncclBootstrapHandle*)out); TRACE_CALL("ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out)); - if (rcclParamMscclppEnabled()) { -#ifdef ENABLE_MSCCLPP - NCCLCHECK(res); - int dev; - CUDACHECK(cudaGetDevice(&dev)); - hipDeviceProp_t devProp; - CUDACHECK(hipGetDeviceProperties(&devProp, dev)); - if (IsArchMatch(devProp.gcnArchName, "gfx94")) { - auto& mscclppUniqueId = mscclpp_uniqueIdMap[*out]; - res = mscclpp_ncclGetUniqueId(&mscclppUniqueId); - TRACE_CALL("mscclpp_ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(mscclppUniqueId)); - mscclpp_uniqueIdReverseMap[mscclppUniqueId].insert(*out); - } else { - WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName); - } -#else - WARN("MSCCL++: Feature not enabled. 
ENABLE_MSCCLPP must be defined at compile-time to enable this feature."); -#endif - } return res; } @@ -1982,27 +1963,43 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { auto& mscclppUniqueId = mscclpp_uniqueIdMap[origUniqueId]; mscclpp_uniqueIdMap[job->commId] = mscclppUniqueId; mscclpp_uniqueIdReverseMap[mscclppUniqueId].insert(job->commId); + ncclCommToUniqueIdMap[comm] = job->commId; } } else #endif if (rcclParamMscclppEnabled()) { #ifdef ENABLE_MSCCLPP - hipDeviceProp_t devProp; - CUDACHECK(hipGetDeviceProperties(&devProp, cudaDev)); - comm->mscclppCompatible = IsArchMatch(devProp.gcnArchName, "gfx94"); - if (comm->mscclppCompatible) { - auto& mscclppUniqueId = mscclpp_uniqueIdMap[job->commId]; - NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, &mscclppUniqueId, sizeof(mscclppUniqueId)), res, fail); - unsigned long long mscclppUniqueIdHash; (void)mscclppUniqueIdHash; - TRACE_CALL("bootstrapIntraNodeBroadcast(rank=%d, nranks=%d, root=%d, bcastData=hash:0x%llx)", comm->localRank, comm->localRanks, 0, (mscclppUniqueIdHash = (unsigned long long)hashUniqueId(mscclppUniqueId))); - comm->mscclpp_threshold = rcclParamMscclppThreshold(); - INFO(NCCL_INIT, "MSCCL++: Enabled! 
Msg size threshold=%zu", comm->mscclpp_threshold); - NCCLCHECKGOTO(mscclpp_ncclCommInitRank(&(comm->mscclpp_comm), job->nranks, mscclppUniqueId, job->myrank), res, fail); - TRACE_CALL("mscclpp_ncclCommInitRank (*comm=%p, nranks=%d, commId=hash:0x%llx, myrank=%d)", comm->mscclpp_comm, job->nranks, mscclppUniqueIdHash, job->myrank); - mscclpp_commToUniqueIdMap[comm->mscclpp_comm] = mscclppUniqueId; + if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled()) && mscclppCommCompatible(comm)) { + hipDeviceProp_t devProp; + CUDACHECK(hipGetDeviceProperties(&devProp, cudaDev)); + comm->mscclppCompatible = IsArchMatch(devProp.gcnArchName, "gfx94"); + if (comm->mscclppCompatible) { + bool mapContainsId = (mscclpp_uniqueIdMap.count(job->commId) > 0); + auto& mscclppUniqueId = mscclpp_uniqueIdMap[job->commId]; + if (comm->localRank == 0 && !mapContainsId) { + NCCLCHECKGOTO(mscclpp_ncclGetUniqueId(&mscclppUniqueId), res, fail); + TRACE_CALL("mscclpp_ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(mscclppUniqueId)); + } + + NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, &mscclppUniqueId, sizeof(mscclppUniqueId)), res, fail); + unsigned long long mscclppUniqueIdHash; (void)mscclppUniqueIdHash; + TRACE_CALL("bootstrapIntraNodeBroadcast(rank=%d, nranks=%d, root=%d, bcastData=hash:0x%llx)", comm->localRank, comm->localRanks, 0, (mscclppUniqueIdHash = (unsigned long long)hashUniqueId(mscclppUniqueId))); + mscclpp_uniqueIdReverseMap[mscclppUniqueId].insert(job->commId); + + comm->mscclpp_threshold = rcclParamMscclppThreshold(); + INFO(NCCL_INIT, "MSCCL++: Enabled! 
Msg size threshold=%zu", comm->mscclpp_threshold); + + NCCLCHECKGOTO(mscclpp_ncclCommInitRank(&(comm->mscclpp_comm), job->nranks, mscclppUniqueId, job->myrank), res, fail); + TRACE_CALL("mscclpp_ncclCommInitRank (*comm=%p, nranks=%d, commId=hash:0x%llx, myrank=%d)", comm->mscclpp_comm, job->nranks, mscclppUniqueIdHash, job->myrank); + mscclpp_commToUniqueIdMap[comm->mscclpp_comm] = mscclppUniqueId; + ncclCommToUniqueIdMap[comm] = job->commId; + } else { + WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName); + } } else { - WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName); + comm->mscclppCompatible = false; + WARN("MSCCL++: Cannot enable MSCCL++; environment is not MSCCL compatible"); } #else WARN("MSCCL++: Feature not enabled. ENABLE_MSCCLPP must be defined at compile-time to enable this feature."); diff --git a/src/misc/msccl/msccl_lifecycle.cc b/src/misc/msccl/msccl_lifecycle.cc index 2b179fccaf..d186e0c47f 100644 --- a/src/misc/msccl/msccl_lifecycle.cc +++ b/src/misc/msccl/msccl_lifecycle.cc @@ -84,6 +84,12 @@ static bool mscclCommCompatible(ncclComm_t comm) { return true; } +#ifdef ENABLE_MSCCLPP +bool mscclppCommCompatible(ncclComm_t comm) { + return mscclCommCompatible(comm); +} +#endif + const char *mscclFuncNames[] = { "mscclFuncReduce", "mscclFuncBroadcast",