diff --git a/projects/rccl/src/include/rccl_common.h b/projects/rccl/src/include/rccl_common.h index 0e6b19b107..f7110690dc 100644 --- a/projects/rccl/src/include/rccl_common.h +++ b/projects/rccl/src/include/rccl_common.h @@ -91,4 +91,6 @@ void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable); void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize); ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount); ncclResult_t commSetUnrollFactor(struct ncclComm* comm); +bool validHsaScratchEnvSetting(const char*hsaScratchEnv, int hipRuntimeVersion, int firmwareVersion, const char* archName); +int parseFirmwareVersion(const char* command); #endif diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index 795105ee7d..c428488294 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -130,7 +130,24 @@ ncclResult_t initGdrCopy() { static ncclResult_t initResult = ncclSuccess; static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT; +ncclResult_t checkHsaEnvSetting() { + const char* hsaScratchEnv = getenv("HSA_NO_SCRATCH_RECLAIM"); + int hipRuntimeVersion = 0; + // hipVer is an integer e.g., 6.2.41133 -> 60241133 + CUDACHECK(hipRuntimeGetVersion(&hipRuntimeVersion)); + const int firmwareVersion = parseFirmwareVersion("amd-smi firmware"); + hipDeviceProp_t devProp; + // use GPU0 should be good enough + CUDACHECK(hipGetDeviceProperties(&devProp, 0)); + INFO(NCCL_INIT, "Hipruntime version: %d, firmware version: %d", hipRuntimeVersion, firmwareVersion); + if (!validHsaScratchEnvSetting(hsaScratchEnv, hipRuntimeVersion, firmwareVersion, devProp.gcnArchName)) { + WARN("HSA_NO_SCRATCH_RECLAIM=1 must be set to avoid RCCL perf hit, rocm ver:%d", hipRuntimeVersion); + return ncclSystemError; + } + return ncclSuccess; +} static void initOnceFunc() { + NCCLCHECKGOTO(checkHsaEnvSetting(), initResult, exit); initEnv(); initGdrCopy(); // Always initialize bootstrap network @@ -175,13 +192,6 @@ static ncclResult_t ncclInit() { WARN("Missing \"HSA_FORCE_FINE_GRAIN_PCIE=1\" from environment which can lead to low RCCL performance, system instablity or hang!"); #endif } - const char* hsaScratchEnv = getenv("HSA_NO_SCRATCH_RECLAIM"); - int hipRuntimeVersion = 0; - // hipVer is an integer e.g., 6.2.41133 -> 60241133 - CUDACHECK(hipRuntimeGetVersion(&hipRuntimeVersion)); - if ((!hsaScratchEnv || strcmp(hsaScratchEnv,"1") != 0) && hipRuntimeVersion < 60400000){ - WARN("HSA_NO_SCRATCH_RECLAIM=1 must be set to avoid RCCL perf hit for rocm older than 6.4,, rocm ver:%d", hipRuntimeVersion); - } pthread_once(&initOnceControl, initOnceFunc); return initResult; } diff --git a/projects/rccl/src/rccl_wrap.cc b/projects/rccl/src/rccl_wrap.cc index 716b564d84..ece9f62f86 100644 --- a/projects/rccl/src/rccl_wrap.cc +++ b/projects/rccl/src/rccl_wrap.cc @@ -183,7 +183,7 @@ void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable) { const char *inputStr = getenv("NCCL_PXN_DISABLE"); const bool archGfx942 = IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942"); const bool archGfx950 = IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950"); - comm->enableCustColl = (archGfx942 || archGfx950) && (inputStr && !atoi(inputStr)); + comm->enableCustColl = (archGfx942 || archGfx950) && (inputStr && !atoi(inputStr)); if((!archGfx942 && !archGfx950) || inputStr) { rcclPxnDisable = pxnDisable = RCCL_VALUE_INVALID; @@ -242,3 +242,82 @@ ncclResult_t commSetUnrollFactor(struct ncclComm* comm) { INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll+1); return ncclSuccess; } + +std::string trimString(const std::string& s) { + int sz = s.size(); + int b = 0; + int e = sz - 1; + while (b < sz && isspace(s[b])) { + b++; + } + if (b >= sz) { + return ""; + } + + while (e >= b && e < sz && isspace(s[e])) { + e--; + } + if (b > e) { + return ""; + } + return s.substr(b, e - b + 1); +} + +std::vector splitString(const std::string& s, char delimiter) { + std::vector tokens; + std::stringstream ss(s); + std::string token; + + while (std::getline(ss, token, delimiter)) { + tokens.push_back(trimString(token)); + } + return tokens; +} + +int parseFirmwareVersionImpl(FILE* file) { + constexpr std::size_t MAX_LINE_SZ = 1024; + char line[MAX_LINE_SZ]; + bool found_pattern = false; + while (fgets(line, MAX_LINE_SZ, file)) { + auto parts = splitString(line, ':'); + if (parts == std::vector{"FW_ID", "CP_MEC1"}) { + if (!found_pattern) { + found_pattern = true; + } + continue; + } + + if (found_pattern && (parts[0] == "FW_VERSION")) { + return stoi(parts[1]) & 0x7ff; + } + } + return -1; +} + +int parseFirmwareVersion(const char* command) { + auto file = popen(command, "r"); + if (file == nullptr) { + return -1; + } + int version = -1; + try { + version = parseFirmwareVersionImpl(file); + } catch (const std::exception& ex) { + } + pclose(file); + return version; +} + +bool validHsaScratchEnvSetting(const char*hsaScratchEnv, int hipRuntimeVersion, int firmwareVersion, char const* archName) { + bool hsaScratchEnvSet = (hsaScratchEnv && strcmp(hsaScratchEnv,"1") == 0); + if (hsaScratchEnvSet) { + return true; + } + if (IsArchMatch(archName, "gfx950")) { + return (hipRuntimeVersion >= 60443484 && firmwareVersion >= 24); + } + if (IsArchMatch(archName, "gfx942")) { + return (hipRuntimeVersion >= 60443484 && firmwareVersion >= 177); + } + return true; +} diff --git a/projects/rccl/test/RcclWrapTests.cpp b/projects/rccl/test/RcclWrapTests.cpp index cdb1662667..31780c9b07 100644 --- a/projects/rccl/test/RcclWrapTests.cpp +++ b/projects/rccl/test/RcclWrapTests.cpp @@ -333,6 +333,30 @@ TEST(Rcclwrap, RcclUpdateCollectiveProtocol_SimpleFallbackWhenNoRanges) { delete comm; } +TEST(Rcclwrap, validHsaScratchEnvSettingTest) { + // When HSA_NO_SCRATCH_RECLAIM is set, it is always valid + EXPECT_TRUE(validHsaScratchEnvSetting("1", 0, 0, "gfx950")); + + EXPECT_TRUE(validHsaScratchEnvSetting("1", 0, 0, "gfx942")); + + // When HSA_NO_SCRATCH_RECLAIM is not set, looking at hip version and firmware version + EXPECT_TRUE(validHsaScratchEnvSetting(nullptr, 60443484, 24, "gfx950")); + + EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443483, 24, "gfx950")); + + EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443484, 23, "gfx950")); + + EXPECT_TRUE(validHsaScratchEnvSetting(nullptr, 60443484, 177, "gfx942")); + + EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443484, 176, "gfx942")); + + EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443483, 177, "gfx942")); + + EXPECT_TRUE(validHsaScratchEnvSetting(nullptr, 60443483, 0, "gfx000")); + + EXPECT_TRUE(validHsaScratchEnvSetting(nullptr, 60300000, 0, "gfx000")); +} + TEST(Rcclwrap, RcclUpdateThreadThreshold_UserEnvSet) { const char *value = getenv("NCCL_THREAD_THRESHOLDS"); @@ -1687,4 +1711,4 @@ TEST(Rcclwrap, PXN_ZeroRanks_GFX950) { CleanupMockComm(mockComm); } -} // namespace RcclUnitTesting \ No newline at end of file +} // namespace RcclUnitTesting