feat: Override _get_gcn_arch function to return "gfx1151" and rename the original implementation to _old_get_gcn_arch.
This commit is contained in:
@@ -25,6 +25,7 @@ def patch_vllm():
|
|||||||
txt = p_rocm.read_text()
|
txt = p_rocm.read_text()
|
||||||
header = 'import sys\nfrom unittest.mock import MagicMock\nsys.modules["amdsmi"] = MagicMock()\n'
|
header = 'import sys\nfrom unittest.mock import MagicMock\nsys.modules["amdsmi"] = MagicMock()\n'
|
||||||
txt = header + txt
|
txt = header + txt
|
||||||
|
txt = txt.replace('def _get_gcn_arch() -> str:', 'def _get_gcn_arch() -> str:\n return "gfx1151"\n\ndef _old_get_gcn_arch() -> str:')
|
||||||
txt = re.sub(r'device_type = .*', 'device_type = "rocm"', txt)
|
txt = re.sub(r'device_type = .*', 'device_type = "rocm"', txt)
|
||||||
txt = re.sub(r'device_name = .*', 'device_name = "gfx1151"', txt)
|
txt = re.sub(r'device_name = .*', 'device_name = "gfx1151"', txt)
|
||||||
txt += '\n def get_device_name(self, device_id: int = 0) -> str:\n return "AMD-gfx1151"\n'
|
txt += '\n def get_device_name(self, device_id: int = 0) -> str:\n return "AMD-gfx1151"\n'
|
||||||
|
|||||||
Reference in New Issue
Block a user