[rocm_regression] Return errors when HSA_NO_SCRATCH_RECLAIM=1 even for rocm>=6.4.0 (#1867)

* [rocm_regression] Return errors when HSA_NO_SCRATCH_RECLAIM=1 even for rocm >= 6.4.0
* [rocm_regression] Check firmware version
* [rocm_regression] Resolve review comments
* [rocm_regression] Move hsa env checking into init once func
* [rocm_regression] Prevent hot fix version in firmware
* [rocm_regression] Improve unit tests

[ROCm/rccl commit: 361d596229]
Tento commit je obsažen v:
ycui1984
2025-08-29 09:18:23 -07:00
odevzdal GitHub
rodič 21278d2073
revize 1999f2eba8
4 změnil soubory, kde provedl 124 přidání a 9 odebrání
+2
Zobrazit soubor
@@ -91,4 +91,6 @@ void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable);
void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize);
ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount);
ncclResult_t commSetUnrollFactor(struct ncclComm* comm);
// Decides whether the current HSA_NO_SCRATCH_RECLAIM configuration is acceptable.
//   hsaScratchEnv     - value of HSA_NO_SCRATCH_RECLAIM, or nullptr when unset
//   hipRuntimeVersion - HIP runtime version as a packed integer (e.g. 6.2.41133 -> 60241133)
//   firmwareVersion   - value produced by parseFirmwareVersion (-1 when it could not be determined)
//   archName          - GPU architecture name, matched against "gfx950" / "gfx942"
// Returns true when the variable equals "1", when the architecture is neither gfx950 nor gfx942,
// or when both the runtime and firmware versions meet the per-architecture minimums.
bool validHsaScratchEnvSetting(const char*hsaScratchEnv, int hipRuntimeVersion, int firmwareVersion, const char* archName);
// Runs `command` through popen() and parses its output for the CP_MEC1 firmware version.
// Returns the parsed version, or -1 when the command cannot be started or the output cannot be parsed.
int parseFirmwareVersion(const char* command);
#endif
+17 -7
Zobrazit soubor
@@ -130,7 +130,24 @@ ncclResult_t initGdrCopy() {
static ncclResult_t initResult = ncclSuccess;
static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT;
// Verifies that HSA_NO_SCRATCH_RECLAIM is configured acceptably for the
// installed HIP runtime, firmware and GPU architecture.
// Returns ncclSuccess when the setting is acceptable and ncclSystemError
// otherwise; a failing HIP call returns early through CUDACHECK.
ncclResult_t checkHsaEnvSetting() {
  // The HIP runtime reports its version as one integer, e.g. 6.2.41133 -> 60241133
  int runtimeVer = 0;
  CUDACHECK(hipRuntimeGetVersion(&runtimeVer));

  // -1 when the firmware version could not be obtained
  const int fwVer = parseFirmwareVersion("amd-smi firmware");

  // Only device 0 is queried for the architecture name
  hipDeviceProp_t props;
  CUDACHECK(hipGetDeviceProperties(&props, 0));

  INFO(NCCL_INIT, "Hipruntime version: %d, firmware version: %d", runtimeVer, fwVer);

  const char* scratchEnv = getenv("HSA_NO_SCRATCH_RECLAIM");
  const bool settingOk = validHsaScratchEnvSetting(scratchEnv, runtimeVer, fwVer, props.gcnArchName);
  if (settingOk) {
    return ncclSuccess;
  }

  WARN("HSA_NO_SCRATCH_RECLAIM=1 must be set to avoid RCCL perf hit, rocm ver:%d", runtimeVer);
  return ncclSystemError;
}
static void initOnceFunc() {
NCCLCHECKGOTO(checkHsaEnvSetting(), initResult, exit);
initEnv();
initGdrCopy();
// Always initialize bootstrap network
@@ -175,13 +192,6 @@ static ncclResult_t ncclInit() {
WARN("Missing \"HSA_FORCE_FINE_GRAIN_PCIE=1\" from environment which can lead to low RCCL performance, system instablity or hang!");
#endif
}
const char* hsaScratchEnv = getenv("HSA_NO_SCRATCH_RECLAIM");
int hipRuntimeVersion = 0;
// hipVer is an integer e.g., 6.2.41133 -> 60241133
CUDACHECK(hipRuntimeGetVersion(&hipRuntimeVersion));
if ((!hsaScratchEnv || strcmp(hsaScratchEnv,"1") != 0) && hipRuntimeVersion < 60400000){
WARN("HSA_NO_SCRATCH_RECLAIM=1 must be set to avoid RCCL perf hit for rocm older than 6.4,, rocm ver:%d", hipRuntimeVersion);
}
pthread_once(&initOnceControl, initOnceFunc);
return initResult;
}
+80 -1
Zobrazit soubor
@@ -183,7 +183,7 @@ void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable) {
const char *inputStr = getenv("NCCL_PXN_DISABLE");
const bool archGfx942 = IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942");
const bool archGfx950 = IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950");
comm->enableCustColl = (archGfx942 || archGfx950) && (inputStr && !atoi(inputStr));
comm->enableCustColl = (archGfx942 || archGfx950) && (inputStr && !atoi(inputStr));
if((!archGfx942 && !archGfx950) || inputStr) {
rcclPxnDisable = pxnDisable = RCCL_VALUE_INVALID;
@@ -242,3 +242,82 @@ ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll+1);
return ncclSuccess;
}
// Returns a copy of `s` with leading and trailing whitespace removed.
// Whitespace is whatever isspace() reports; a string that is empty or
// consists only of whitespace yields "".
std::string trimString(const std::string& s) {
  // isspace() has undefined behaviour for negative arguments, which a plain
  // `char` holding a non-ASCII byte can be, so convert via unsigned char first.
  const auto isSpace = [](char c) {
    return isspace(static_cast<unsigned char>(c)) != 0;
  };

  std::size_t b = 0;
  std::size_t e = s.size();  // one past the last kept character
  while (b < e && isSpace(s[b])) {
    ++b;
  }
  while (e > b && isSpace(s[e - 1])) {
    --e;
  }
  return s.substr(b, e - b);
}
// Splits `s` on `delimiter` and returns the pieces, each passed through
// trimString. Empty pieces between adjacent delimiters are kept; a delimiter
// at the very end of `s` does not produce a trailing empty piece, and an
// empty `s` yields an empty vector.
std::vector<std::string> splitString(const std::string& s, char delimiter) {
  std::vector<std::string> pieces;
  std::string::size_type start = 0;
  while (start < s.size()) {
    const std::string::size_type cut = s.find(delimiter, start);
    if (cut == std::string::npos) {
      // Last piece: everything up to the end of the string
      pieces.push_back(trimString(s.substr(start)));
      break;
    }
    pieces.push_back(trimString(s.substr(start, cut - start)));
    start = cut + 1;
  }
  return pieces;
}
// Scans `file` line by line for a "FW_ID: CP_MEC1" entry and returns the value
// of the first "FW_VERSION: <n>" line that appears after it, masked to its
// low 11 bits. Returns -1 when no such line is found or when the FW_VERSION
// line carries no value field.
// May throw std::invalid_argument / std::out_of_range from std::stoi when the
// value is not a parsable integer; the caller is expected to handle that.
int parseFirmwareVersionImpl(FILE* file) {
  constexpr std::size_t MAX_LINE_SZ = 1024;
  char line[MAX_LINE_SZ];
  bool sawMec1 = false;
  while (fgets(line, MAX_LINE_SZ, file)) {
    const auto parts = splitString(line, ':');
    if (parts.empty()) {
      // Nothing to inspect (e.g. the line starts with a NUL byte)
      continue;
    }
    if (parts.size() == 2 && parts[0] == "FW_ID" && parts[1] == "CP_MEC1") {
      sawMec1 = true;
      continue;
    }
    if (sawMec1 && parts[0] == "FW_VERSION") {
      if (parts.size() < 2) {
        // "FW_VERSION" without a ':' value field: treat as unparsable
        // instead of indexing past the end of the vector.
        return -1;
      }
      // Keep the low 11 bits only; NOTE(review): presumably the upper bits
      // carry a hot-fix revision that must not affect comparisons - confirm.
      return std::stoi(parts[1]) & 0x7ff;
    }
  }
  return -1;
}
// Runs `command` via popen() and parses its output with
// parseFirmwareVersionImpl. Returns the parsed firmware version, or -1 when
// the command cannot be started or parsing throws a std::exception.
int parseFirmwareVersion(const char* command) {
  int result = -1;
  FILE* pipe = popen(command, "r");
  if (pipe != nullptr) {
    try {
      result = parseFirmwareVersionImpl(pipe);
    } catch (const std::exception&) {
      // Deliberately best-effort: unparsable output keeps result at -1
    }
    pclose(pipe);
  }
  return result;
}
// Returns whether the HSA_NO_SCRATCH_RECLAIM configuration is acceptable.
//   hsaScratchEnv     - value of the environment variable, nullptr when unset
//   hipRuntimeVersion - packed HIP runtime version
//   firmwareVersion   - firmware version to compare against the per-arch minimum
//   archName          - GPU architecture name
// An explicit "1" is always accepted. Otherwise gfx950 and gfx942 need both a
// new enough runtime and a new enough firmware; all other architectures are
// accepted unconditionally.
bool validHsaScratchEnvSetting(const char*hsaScratchEnv, int hipRuntimeVersion, int firmwareVersion, char const* archName) {
  constexpr int kMinHipRuntime = 60443484;
  constexpr int kMinFirmwareGfx950 = 24;
  constexpr int kMinFirmwareGfx942 = 177;

  if (hsaScratchEnv != nullptr && strcmp(hsaScratchEnv, "1") == 0) {
    return true;
  }

  int minFirmware = 0;
  if (IsArchMatch(archName, "gfx950")) {
    minFirmware = kMinFirmwareGfx950;
  } else if (IsArchMatch(archName, "gfx942")) {
    minFirmware = kMinFirmwareGfx942;
  } else {
    // No requirement for other architectures
    return true;
  }
  return hipRuntimeVersion >= kMinHipRuntime && firmwareVersion >= minFirmware;
}
+25 -1
Zobrazit soubor
@@ -333,6 +333,30 @@ TEST(Rcclwrap, RcclUpdateCollectiveProtocol_SimpleFallbackWhenNoRanges) {
delete comm;
}
// Exercises validHsaScratchEnvSetting across env values, runtime/firmware
// thresholds and architectures.
TEST(Rcclwrap, validHsaScratchEnvSettingTest) {
  // When HSA_NO_SCRATCH_RECLAIM is "1", the setting is always valid
  EXPECT_TRUE(validHsaScratchEnvSetting("1", 0, 0, "gfx950"));
  EXPECT_TRUE(validHsaScratchEnvSetting("1", 0, 0, "gfx942"));

  // Any value other than exactly "1" is handled like an unset variable
  EXPECT_FALSE(validHsaScratchEnvSetting("0", 0, 0, "gfx950"));
  EXPECT_FALSE(validHsaScratchEnvSetting("", 0, 0, "gfx942"));
  EXPECT_TRUE(validHsaScratchEnvSetting("0", 60443484, 24, "gfx950"));

  // When HSA_NO_SCRATCH_RECLAIM is not set, hip version and firmware version decide
  EXPECT_TRUE(validHsaScratchEnvSetting(nullptr, 60443484, 24, "gfx950"));
  EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443483, 24, "gfx950"));
  EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443484, 23, "gfx950"));

  EXPECT_TRUE(validHsaScratchEnvSetting(nullptr, 60443484, 177, "gfx942"));
  EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443484, 176, "gfx942"));
  EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443483, 177, "gfx942"));

  // A firmware version of -1 (the value reported when parsing fails) is rejected
  // on the architectures that have a firmware requirement
  EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443484, -1, "gfx950"));
  EXPECT_FALSE(validHsaScratchEnvSetting(nullptr, 60443484, -1, "gfx942"));

  // Other architectures carry no requirement
  EXPECT_TRUE(validHsaScratchEnvSetting(nullptr, 60443483, 0, "gfx000"));
  EXPECT_TRUE(validHsaScratchEnvSetting(nullptr, 60300000, 0, "gfx000"));
  EXPECT_TRUE(validHsaScratchEnvSetting(nullptr, 60443484, -1, "gfx000"));
}
TEST(Rcclwrap, RcclUpdateThreadThreshold_UserEnvSet) {
const char *value = getenv("NCCL_THREAD_THRESHOLDS");
@@ -1687,4 +1711,4 @@ TEST(Rcclwrap, PXN_ZeroRanks_GFX950) {
CleanupMockComm(mockComm);
}
} // namespace RcclUnitTesting
} // namespace RcclUnitTesting