Moved mscclpp_ncclGetUniqueId call into ncclCommInitRankFunc (#1332)

* Moved call to `mscclpp_ncclGetUniqueId` into `ncclCommInitRankFunc` to avoid setting up transport early in environments where MSCCL++ isn't valid.

* Checking `mscclEnabled` for the process and the topology to gate MSCCL++.

* Allowed `mscclForceEnable` to enable MSCCL++.
Этот коммит содержится в:
corey-derochie-amd
2024-09-16 16:41:40 -06:00
коммит произвёл GitHub
родитель ad94c651ad
Коммит 853a0586b4
3 изменённых файлов: 38 добавлений и 33 удалений
+2
Просмотреть файл
@@ -49,4 +49,6 @@ namespace std {
bool operator ==(const ncclUniqueId& a, const ncclUniqueId& b);
bool mscclppCommCompatible(ncclComm_t comm);
#endif
+30 -33
Просмотреть файл
@@ -189,25 +189,6 @@ ncclResult_t ncclGetUniqueId_impl(ncclUniqueId* out) {
NCCLCHECK(PtrCheck(out, "GetUniqueId", "out"));
ncclResult_t res = bootstrapGetUniqueId((struct ncclBootstrapHandle*)out);
TRACE_CALL("ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(*out));
if (rcclParamMscclppEnabled()) {
#ifdef ENABLE_MSCCLPP
NCCLCHECK(res);
int dev;
CUDACHECK(cudaGetDevice(&dev));
hipDeviceProp_t devProp;
CUDACHECK(hipGetDeviceProperties(&devProp, dev));
if (IsArchMatch(devProp.gcnArchName, "gfx94")) {
auto& mscclppUniqueId = mscclpp_uniqueIdMap[*out];
res = mscclpp_ncclGetUniqueId(&mscclppUniqueId);
TRACE_CALL("mscclpp_ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(mscclppUniqueId));
mscclpp_uniqueIdReverseMap[mscclppUniqueId].insert(*out);
} else {
WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName);
}
#else
WARN("MSCCL++: Feature not enabled. ENABLE_MSCCLPP must be defined at compile-time to enable this feature.");
#endif
}
return res;
}
@@ -1982,27 +1963,43 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
auto& mscclppUniqueId = mscclpp_uniqueIdMap[origUniqueId];
mscclpp_uniqueIdMap[job->commId] = mscclppUniqueId;
mscclpp_uniqueIdReverseMap[mscclppUniqueId].insert(job->commId);
ncclCommToUniqueIdMap[comm] = job->commId;
}
}
else
#endif
if (rcclParamMscclppEnabled()) {
#ifdef ENABLE_MSCCLPP
hipDeviceProp_t devProp;
CUDACHECK(hipGetDeviceProperties(&devProp, cudaDev));
comm->mscclppCompatible = IsArchMatch(devProp.gcnArchName, "gfx94");
if (comm->mscclppCompatible) {
auto& mscclppUniqueId = mscclpp_uniqueIdMap[job->commId];
NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, &mscclppUniqueId, sizeof(mscclppUniqueId)), res, fail);
unsigned long long mscclppUniqueIdHash; (void)mscclppUniqueIdHash;
TRACE_CALL("bootstrapIntraNodeBroadcast(rank=%d, nranks=%d, root=%d, bcastData=hash:0x%llx)", comm->localRank, comm->localRanks, 0, (mscclppUniqueIdHash = (unsigned long long)hashUniqueId(mscclppUniqueId)));
comm->mscclpp_threshold = rcclParamMscclppThreshold();
INFO(NCCL_INIT, "MSCCL++: Enabled! Msg size threshold=%zu", comm->mscclpp_threshold);
NCCLCHECKGOTO(mscclpp_ncclCommInitRank(&(comm->mscclpp_comm), job->nranks, mscclppUniqueId, job->myrank), res, fail);
TRACE_CALL("mscclpp_ncclCommInitRank (*comm=%p, nranks=%d, commId=hash:0x%llx, myrank=%d)", comm->mscclpp_comm, job->nranks, mscclppUniqueIdHash, job->myrank);
mscclpp_commToUniqueIdMap[comm->mscclpp_comm] = mscclppUniqueId;
if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled()) && mscclppCommCompatible(comm)) {
hipDeviceProp_t devProp;
CUDACHECK(hipGetDeviceProperties(&devProp, cudaDev));
comm->mscclppCompatible = IsArchMatch(devProp.gcnArchName, "gfx94");
if (comm->mscclppCompatible) {
bool mapContainsId = (mscclpp_uniqueIdMap.count(job->commId) > 0);
auto& mscclppUniqueId = mscclpp_uniqueIdMap[job->commId];
if (comm->localRank == 0 && !mapContainsId) {
NCCLCHECKGOTO(mscclpp_ncclGetUniqueId(&mscclppUniqueId), res, fail);
TRACE_CALL("mscclpp_ncclGetUniqueId(0x%llx)", (unsigned long long)hashUniqueId(mscclppUniqueId));
}
NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, 0, &mscclppUniqueId, sizeof(mscclppUniqueId)), res, fail);
unsigned long long mscclppUniqueIdHash; (void)mscclppUniqueIdHash;
TRACE_CALL("bootstrapIntraNodeBroadcast(rank=%d, nranks=%d, root=%d, bcastData=hash:0x%llx)", comm->localRank, comm->localRanks, 0, (mscclppUniqueIdHash = (unsigned long long)hashUniqueId(mscclppUniqueId)));
mscclpp_uniqueIdReverseMap[mscclppUniqueId].insert(job->commId);
comm->mscclpp_threshold = rcclParamMscclppThreshold();
INFO(NCCL_INIT, "MSCCL++: Enabled! Msg size threshold=%zu", comm->mscclpp_threshold);
NCCLCHECKGOTO(mscclpp_ncclCommInitRank(&(comm->mscclpp_comm), job->nranks, mscclppUniqueId, job->myrank), res, fail);
TRACE_CALL("mscclpp_ncclCommInitRank (*comm=%p, nranks=%d, commId=hash:0x%llx, myrank=%d)", comm->mscclpp_comm, job->nranks, mscclppUniqueIdHash, job->myrank);
mscclpp_commToUniqueIdMap[comm->mscclpp_comm] = mscclppUniqueId;
ncclCommToUniqueIdMap[comm] = job->commId;
} else {
WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName);
}
} else {
WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName);
comm->mscclppCompatible = false;
WARN("MSCCL++: Cannot enable MSCCL++; environment is not MSCCL compatible");
}
#else
WARN("MSCCL++: Feature not enabled. ENABLE_MSCCLPP must be defined at compile-time to enable this feature.");
+6
Просмотреть файл
@@ -84,6 +84,12 @@ static bool mscclCommCompatible(ncclComm_t comm) {
return true;
}
#ifdef ENABLE_MSCCLPP
bool mscclppCommCompatible(ncclComm_t comm) {
return mscclCommCompatible(comm);
}
#endif
const char *mscclFuncNames[] = {
"mscclFuncReduce",
"mscclFuncBroadcast",