Moved mscclpp_ncclGetUniqueId call into ncclCommInitRankFunc (#1332)
* Moved call to `mscclpp_ncclGetUniqueId` into `ncclCommInitRankFunc` to avoid setting up transport early in environments where MSCCL++ isn't valid. * Checking `mscclEnabled` for the process and the topology to gate MSCCL++. * Allowed `mscclForceEnable` to enable MSCCL++.
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
ad94c651ad
Коммит
853a0586b4
@@ -49,4 +49,6 @@ namespace std {
|
||||
|
||||
bool operator ==(const ncclUniqueId& a, const ncclUniqueId& b);
|
||||
|
||||
bool mscclppCommCompatible(ncclComm_t comm);
|
||||
|
||||
#endif
|
||||
|
||||
+30
-33
@@ -189,25 +189,6 @@ ncclResult_t ncclGetUniqueId_impl(ncclUniqueId* out) {
|
||||
NCCLCHECK(PtrCheck(out, "GetUniqueId", "out"));
|
||||
ncclResult_t res = bootstrapGetUniqueId((struct ncclBootstrapHandle*)out);
|
||||
TRACE_CALL("ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out));
|
||||
if (rcclParamMscclppEnabled()) {
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
NCCLCHECK(res);
|
||||
int dev;
|
||||
CUDACHECK(cudaGetDevice(&dev));
|
||||
hipDeviceProp_t devProp;
|
||||
CUDACHECK(hipGetDeviceProperties(&devProp, dev));
|
||||
if (IsArchMatch(devProp.gcnArchName, "gfx94")) {
|
||||
auto& mscclppUniqueId = mscclpp_uniqueIdMap[*out];
|
||||
res = mscclpp_ncclGetUniqueId(&mscclppUniqueId);
|
||||
TRACE_CALL("mscclpp_ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(mscclppUniqueId));
|
||||
mscclpp_uniqueIdReverseMap[mscclppUniqueId].insert(*out);
|
||||
} else {
|
||||
WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName);
|
||||
}
|
||||
#else
|
||||
WARN("MSCCL++: Feature not enabled. ENABLE_MSCCLPP must be defined at compile-time to enable this feature.");
|
||||
#endif
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -1982,27 +1963,43 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
|
||||
auto& mscclppUniqueId = mscclpp_uniqueIdMap[origUniqueId];
|
||||
mscclpp_uniqueIdMap[job->commId] = mscclppUniqueId;
|
||||
mscclpp_uniqueIdReverseMap[mscclppUniqueId].insert(job->commId);
|
||||
ncclCommToUniqueIdMap[comm] = job->commId;
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
if (rcclParamMscclppEnabled()) {
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
hipDeviceProp_t devProp;
|
||||
CUDACHECK(hipGetDeviceProperties(&devProp, cudaDev));
|
||||
comm->mscclppCompatible = IsArchMatch(devProp.gcnArchName, "gfx94");
|
||||
if (comm->mscclppCompatible) {
|
||||
auto& mscclppUniqueId = mscclpp_uniqueIdMap[job->commId];
|
||||
NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, &mscclppUniqueId, sizeof(mscclppUniqueId)), res, fail);
|
||||
unsigned long long mscclppUniqueIdHash; (void)mscclppUniqueIdHash;
|
||||
TRACE_CALL("bootstrapIntraNodeBroadcast(rank=%d, nranks=%d, root=%d, bcastData=hash:0x%llx)", comm->localRank, comm->localRanks, 0, (mscclppUniqueIdHash = (unsigned long long)hashUniqueId(mscclppUniqueId)));
|
||||
comm->mscclpp_threshold = rcclParamMscclppThreshold();
|
||||
INFO(NCCL_INIT, "MSCCL++: Enabled! Msg size threshold=%zu", comm->mscclpp_threshold);
|
||||
NCCLCHECKGOTO(mscclpp_ncclCommInitRank(&(comm->mscclpp_comm), job->nranks, mscclppUniqueId, job->myrank), res, fail);
|
||||
TRACE_CALL("mscclpp_ncclCommInitRank (*comm=%p, nranks=%d, commId=hash:0x%llx, myrank=%d)", comm->mscclpp_comm, job->nranks, mscclppUniqueIdHash, job->myrank);
|
||||
mscclpp_commToUniqueIdMap[comm->mscclpp_comm] = mscclppUniqueId;
|
||||
if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled()) && mscclppCommCompatible(comm)) {
|
||||
hipDeviceProp_t devProp;
|
||||
CUDACHECK(hipGetDeviceProperties(&devProp, cudaDev));
|
||||
comm->mscclppCompatible = IsArchMatch(devProp.gcnArchName, "gfx94");
|
||||
if (comm->mscclppCompatible) {
|
||||
bool mapContainsId = (mscclpp_uniqueIdMap.count(job->commId) > 0);
|
||||
auto& mscclppUniqueId = mscclpp_uniqueIdMap[job->commId];
|
||||
if (comm->localRank == 0 && !mapContainsId) {
|
||||
NCCLCHECKGOTO(mscclpp_ncclGetUniqueId(&mscclppUniqueId), res, fail);
|
||||
TRACE_CALL("mscclpp_ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(mscclppUniqueId));
|
||||
}
|
||||
|
||||
NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, &mscclppUniqueId, sizeof(mscclppUniqueId)), res, fail);
|
||||
unsigned long long mscclppUniqueIdHash; (void)mscclppUniqueIdHash;
|
||||
TRACE_CALL("bootstrapIntraNodeBroadcast(rank=%d, nranks=%d, root=%d, bcastData=hash:0x%llx)", comm->localRank, comm->localRanks, 0, (mscclppUniqueIdHash = (unsigned long long)hashUniqueId(mscclppUniqueId)));
|
||||
mscclpp_uniqueIdReverseMap[mscclppUniqueId].insert(job->commId);
|
||||
|
||||
comm->mscclpp_threshold = rcclParamMscclppThreshold();
|
||||
INFO(NCCL_INIT, "MSCCL++: Enabled! Msg size threshold=%zu", comm->mscclpp_threshold);
|
||||
|
||||
NCCLCHECKGOTO(mscclpp_ncclCommInitRank(&(comm->mscclpp_comm), job->nranks, mscclppUniqueId, job->myrank), res, fail);
|
||||
TRACE_CALL("mscclpp_ncclCommInitRank (*comm=%p, nranks=%d, commId=hash:0x%llx, myrank=%d)", comm->mscclpp_comm, job->nranks, mscclppUniqueIdHash, job->myrank);
|
||||
mscclpp_commToUniqueIdMap[comm->mscclpp_comm] = mscclppUniqueId;
|
||||
ncclCommToUniqueIdMap[comm] = job->commId;
|
||||
} else {
|
||||
WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName);
|
||||
}
|
||||
} else {
|
||||
WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName);
|
||||
comm->mscclppCompatible = false;
|
||||
WARN("MSCCL++: Cannot enable MSCCL++; environment is not MSCCL compatible");
|
||||
}
|
||||
#else
|
||||
WARN("MSCCL++: Feature not enabled. ENABLE_MSCCLPP must be defined at compile-time to enable this feature.");
|
||||
|
||||
@@ -84,6 +84,12 @@ static bool mscclCommCompatible(ncclComm_t comm) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
bool mscclppCommCompatible(ncclComm_t comm) {
|
||||
return mscclCommCompatible(comm);
|
||||
}
|
||||
#endif
|
||||
|
||||
const char *mscclFuncNames[] = {
|
||||
"mscclFuncReduce",
|
||||
"mscclFuncBroadcast",
|
||||
|
||||
Ссылка в новой задаче
Block a user