diff --git a/src/omniperf_soc/soc_base.py b/src/omniperf_soc/soc_base.py index 197b107fb0..a576c4b789 100644 --- a/src/omniperf_soc/soc_base.py +++ b/src/omniperf_soc/soc_base.py @@ -29,7 +29,7 @@ import shutil import glob import re import numpy as np -from utils.utils import demarcate, console_debug, console_log +from utils.utils import demarcate, console_debug, console_log, console_error from pathlib import Path from omniperf_base import SUPPORTED_ARCHS @@ -178,14 +178,13 @@ class OmniSoC_Base: self._mspec.gpu_model = list(SUPPORTED_ARCHS[self._mspec.gpu_arch].values())[ 0 ][0] - if (self._mspec.gpu_arch == "gfx942") and ( - "MI300A" in "\n".join(self._mspec._rocminfo) - ): - self._mspec.gpu_model = "MI300A_A1" - if (self._mspec.gpu_arch == "gfx942") and ( - "MI300A" not in "\n".join(self._mspec._rocminfo) - ): - self._mspec.gpu_model = "MI300X_A1" + if (self._mspec.gpu_arch == "gfx942"): + if "MI300A" in "\n".join(self._mspec._rocminfo): + self._mspec.gpu_model = "MI300A_A1" + elif "MI300X" in "\n".join(self._mspec._rocminfo): + self._mspec.gpu_model = "MI300X_A1" + else: + console_error("Cannot parse MI300 details from rocminfo. Please verify output.") self._mspec.num_xcd = str( total_xcds(self._mspec.gpu_model, self._mspec.compute_partition)