From a3dcfa8cb2ad5a407defb5436f7daad60eec0052 Mon Sep 17 00:00:00 2001 From: JoseSantosAMD Date: Wed, 25 Oct 2023 15:26:22 -0500 Subject: [PATCH] Fix dispatch notation Signed-off-by: JoseSantosAMD --- tests/test_profile_general.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/test_profile_general.py b/tests/test_profile_general.py index e41a94715f..7e4f540582 100644 --- a/tests/test_profile_general.py +++ b/tests/test_profile_general.py @@ -1378,8 +1378,8 @@ def test_dispatch_0(): for file in files_in_workload: if file.endswith(".csv"): file_dict[file] = pd.read_csv(workload_1 + "/" + file) - if not "sysinfo" in file: - assert len(file_dict[file].index) > 3 + if not "sysinfo" in file and not "roofline" in file: + assert len(file_dict[file].index) == 1 if soc == "mi200": print(sorted(list(file_dict.keys()))) assert sorted(list(file_dict.keys())) == ALL_CSVS_MI200 @@ -1405,8 +1405,7 @@ def test_dispatch_0_1(): "--path", workload_1, "--dispatch", - "0", - "1", + "0:2", "--", ] + app_1, @@ -1423,8 +1422,8 @@ def test_dispatch_0_1(): for file in files_in_workload: if file.endswith(".csv"): file_dict[file] = pd.read_csv(workload_1 + "/" + file) - if not "sysinfo" in file: - assert len(file_dict[file].index) > 3 + if not "sysinfo" in file and not "roofline" in file: + assert len(file_dict[file].index) == 2 if soc == "mi200": print(sorted(list(file_dict.keys()))) assert sorted(list(file_dict.keys())) == ALL_CSVS_MI200 @@ -1467,8 +1466,8 @@ def test_dispatch_2(): for file in files_in_workload: if file.endswith(".csv"): file_dict[file] = pd.read_csv(workload_1 + "/" + file) - if not "sysinfo" in file: - assert len(file_dict[file].index) > 3 + if not "sysinfo" in file and not "roofline" in file: + assert len(file_dict[file].index) == 1 if soc == "mi200": print(sorted(list(file_dict.keys()))) assert sorted(list(file_dict.keys())) == ALL_CSVS_MI200 @@ -1864,7 +1863,10 @@ def test_device_0(): if file.endswith(".csv"): file_dict[file] = pd.read_csv(workload_1 + "/" + file) if not "sysinfo" in file: - assert len(file_dict[file].index) > 3 + if "roofline" in file: + assert len(file_dict[file].index) + else: + assert len(file_dict[file].index) > 3 if soc == "mi200": print(sorted(list(file_dict.keys()))) assert sorted(list(file_dict.keys())) == ALL_CSVS_MI200