diff --git a/src/include/rccl_common.h b/src/include/rccl_common.h index 542149f6b0..029180df6e 100644 --- a/src/include/rccl_common.h +++ b/src/include/rccl_common.h @@ -95,5 +95,5 @@ void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize); ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount); ncclResult_t commSetUnrollFactor(struct ncclComm* comm); bool validHsaScratchEnvSetting(const char*hsaScratchEnv, int hipRuntimeVersion, int firmwareVersion, const char* archName); -int parseFirmwareVersion(const char* command); +int parseFirmwareVersion(); #endif diff --git a/src/init.cc b/src/init.cc index c428488294..727af9dbcd 100644 --- a/src/init.cc +++ b/src/init.cc @@ -131,14 +131,21 @@ static ncclResult_t initResult = ncclSuccess; static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT; ncclResult_t checkHsaEnvSetting() { + // get user-specified value for `HSA_NO_SCRATCH_RECLAIM` const char* hsaScratchEnv = getenv("HSA_NO_SCRATCH_RECLAIM"); + int hipRuntimeVersion = 0; // hipVer is an integer e.g., 6.2.41133 -> 60241133 CUDACHECK(hipRuntimeGetVersion(&hipRuntimeVersion)); - const int firmwareVersion = parseFirmwareVersion("amd-smi firmware"); + + // using rocm-smi API to query FW version, instead of parsing CLI output + // will switch to amd-smi API soon + const int firmwareVersion = parseFirmwareVersion(); + hipDeviceProp_t devProp; // use GPU0 should be good enough CUDACHECK(hipGetDeviceProperties(&devProp, 0)); + INFO(NCCL_INIT, "Hipruntime version: %d, firmware version: %d", hipRuntimeVersion, firmwareVersion); if (!validHsaScratchEnvSetting(hsaScratchEnv, hipRuntimeVersion, firmwareVersion, devProp.gcnArchName)) { WARN("HSA_NO_SCRATCH_RECLAIM=1 must be set to avoid RCCL perf hit, rocm ver:%d", hipRuntimeVersion); diff --git a/src/rccl_wrap.cc b/src/rccl_wrap.cc index 55f6e40dc3..05a9cb3aed 100644 --- a/src/rccl_wrap.cc +++ b/src/rccl_wrap.cc @@ -24,6 +24,7 @@ THE SOFTWARE. #include "comm.h" #include "graph/topo.h" #include "enqueue.h" +#include "rocm_smi/rocm_smi.h" // Use this param to experiment pipelining new data types besides bfloat16 // Make sure you generate the device code with the new data type (i.e. in generate.py) @@ -342,37 +343,26 @@ std::vector splitString(const std::string& s, char delimiter) { return tokens; } -int parseFirmwareVersionImpl(FILE* file) { - constexpr std::size_t MAX_LINE_SZ = 1024; - char line[MAX_LINE_SZ]; - bool found_pattern = false; - while (fgets(line, MAX_LINE_SZ, file)) { - auto parts = splitString(line, ':'); - if (parts == std::vector{"FW_ID", "CP_MEC1"}) { - if (!found_pattern) { - found_pattern = true; - } - continue; - } +int parseFirmwareVersionImpl() { + uint64_t fw_version = -1; - if (found_pattern && (parts[0] == "FW_VERSION")) { - return stoi(parts[1]) & 0x7ff; - } - } - return -1; + // using rocm-smi APIs for now to query MEC FW version + // will switch to amd-smi APIs soon + rsmi_status_t ret; + ret = rsmi_init(0); + if (ret != RSMI_STATUS_SUCCESS) return -1; + ret = rsmi_dev_firmware_version_get(0, RSMI_FW_BLOCK_MEC, &fw_version); + if (ret != RSMI_STATUS_SUCCESS) return -1; + + return fw_version; } -int parseFirmwareVersion(const char* command) { - auto file = popen(command, "r"); - if (file == nullptr) { - return -1; - } +int parseFirmwareVersion() { int version = -1; try { - version = parseFirmwareVersionImpl(file); + version = parseFirmwareVersionImpl(); } catch (const std::exception& ex) { } - pclose(file); return version; }