# Review notes (preamble text ahead of the "diff --git" header; patch tools skip leading text).
#
# Scope: one hunk in Dockerfile (old lines 62-71 -> new lines 62-79) inside a RUN instruction
# that assembles patch_strix.py via successive `echo ... >> patch_strix.py` and then runs it.
#
# Changes carried by this hunk:
#  * The two re.sub lines for device_type / device_name are removed and re-added with the same
#    visible text -- presumably a whitespace-only change; confirm it is intentional.
#  * "Patch 3" appends code to patch_strix.py that, when
#    csrc/quantization/w8a8/fp8/amd/quant_utils.cuh exists, rewrites the literal conditions
#    `if (KV_DTYPE == "auto")` and `if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3")` so that
#    each comparison goes through std::string(KV_DTYPE) (stated reason: Clang 22+ string
#    comparison error).
#
# NOTE(review): if that file is absent, or the literal search strings do not occur in it, the
#   generated script changes nothing and still prints 'Successfully patched vLLM for Strix Halo';
#   consider reporting the skipped/no-match case so a later compile failure is easier to trace.
# NOTE(review): the generated code calls Path(...); the import is not visible in this hunk --
#   confirm patch_strix.py imports it earlier (the hunk heading only shows "import sys, re").
# NOTE(review): the "# Patch 3" / "# End of Patch 3" lines sit inside the `&& \` chain and rely on
#   the Dockerfile parser dropping comment lines within a continued RUN -- verify with the
#   builder version in use.
# NOTE(review): this copy of the patch has its line breaks and indentation collapsed; restore the
#   original formatting before attempting to apply it.
diff --git a/Dockerfile b/Dockerfile index 80c22bb..d45331d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -62,10 +62,18 @@ RUN echo "import sys, re" > patch_strix.py && \ echo "txt = p.read_text()" >> patch_strix.py && \ echo "header = 'import sys\nfrom unittest.mock import MagicMock\nsys.modules[\"amdsmi\"] = MagicMock()\n'" >> patch_strix.py && \ echo "txt = header + txt" >> patch_strix.py && \ - echo "txt = re.sub(r'device_type = .*', 'device_type = \"rocm\"', txt)" >> patch_strix.py && \ - echo "txt = re.sub(r'device_name = .*', 'device_name = \"gfx1151\"', txt)" >> patch_strix.py && \ + echo "txt = re.sub(r'device_type = .*', 'device_type = \"rocm\"', txt)" >> patch_strix.py && \ + echo "txt = re.sub(r'device_name = .*', 'device_name = \"gfx1151\"', txt)" >> patch_strix.py && \ echo "txt += '\n def get_device_name(self, device_id: int = 0) -> str:\n return \"AMD-gfx1151\"\n'" >> patch_strix.py && \ echo "p.write_text(txt)" >> patch_strix.py && \ + # Patch 3: quant_utils.cuh (Fix for Clang 22+ string comparison error) + echo "p = Path('csrc/quantization/w8a8/fp8/amd/quant_utils.cuh')" >> patch_strix.py && \ + echo "if p.exists():" >> patch_strix.py && \ + echo " txt = p.read_text()" >> patch_strix.py && \ + echo " txt = txt.replace('if (KV_DTYPE == \"auto\")', 'if (std::string(KV_DTYPE) == \"auto\")')" >> patch_strix.py && \ + echo " txt = txt.replace('if (KV_DTYPE == \"fp8\" || KV_DTYPE == \"fp8_e4m3\")', 'if (std::string(KV_DTYPE) == \"fp8\" || std::string(KV_DTYPE) == \"fp8_e4m3\")')" >> patch_strix.py && \ + echo " p.write_text(txt)" >> patch_strix.py && \ + # End of Patch 3 echo "print('Successfully patched vLLM for Strix Halo')" >> patch_strix.py && \ python patch_strix.py && \ sed -i 's/gfx1200;gfx1201/gfx1151/' CMakeLists.txt