diff --git a/scripts/patch_strix.py b/scripts/patch_strix.py index 94c8c33..056d23f 100644 --- a/scripts/patch_strix.py +++ b/scripts/patch_strix.py @@ -25,6 +25,7 @@ def patch_vllm(): txt = p_rocm.read_text() header = 'import sys\nfrom unittest.mock import MagicMock\nsys.modules["amdsmi"] = MagicMock()\n' txt = header + txt + txt = txt.replace('def _get_gcn_arch() -> str:', 'def _get_gcn_arch() -> str:\n return "gfx1151"\n\ndef _old_get_gcn_arch() -> str:') txt = re.sub(r'device_type = .*', 'device_type = "rocm"', txt) txt = re.sub(r'device_name = .*', 'device_name = "gfx1151"', txt) txt += '\n def get_device_name(self, device_id: int = 0) -> str:\n return "AMD-gfx1151"\n'