diff --git a/src/utils/specs.py b/src/utils/specs.py index bf2d49a7c8..696f850c0e 100644 --- a/src/utils/specs.py +++ b/src/utils/specs.py @@ -186,14 +186,23 @@ def get_machine_specs(devicenum): break if not rocmFound: - _rocm_path = os.getenv("ROCM_PATH", "/opt/rocm") - print("Error: Unable to detect a complete local ROCm installation.") - print( - "\nThe expected %s/.info/ versioning directory is missing. Please" - % _rocm_path - ) - print("ensure you have valid ROCm installation.") - sys.exit(1) + # check if ROCM_VER is supplied externally + ROCM_VER_USER = os.getenv("ROCM_VER") + if ROCM_VER_USER is not None: + print( + "Overriding missing ROCm version detection with ROCM_VER = %s" + % ROCM_VER_USER + ) + rocm_ver = ROCM_VER_USER + else: + _rocm_path = os.getenv("ROCM_PATH", "/opt/rocm") + print("Error: Unable to detect a complete local ROCm installation.") + print( + "\nThe expected %s/.info/ versioning directory is missing. Please" + % _rocm_path + ) + print("ensure you have valid ROCm installation.") + sys.exit(1) ( gpu_id,