diff --git a/Dockerfile b/Dockerfile index 63d626d..50620b9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -46,44 +46,8 @@ RUN git clone https://github.com/vllm-project/vllm.git /opt/vllm WORKDIR /opt/vllm # --- PATCHING --- -RUN echo "import sys, re" > patch_strix.py && \ - echo "from pathlib import Path" >> patch_strix.py && \ - # Patch 1: __init__.py - echo "p = Path('vllm/platforms/__init__.py')" >> patch_strix.py && \ - echo "txt = p.read_text()" >> patch_strix.py && \ - echo "txt = txt.replace('import amdsmi', '# import amdsmi')" >> patch_strix.py && \ - echo "txt = re.sub(r'is_rocm = .*', 'is_rocm = True', txt)" >> patch_strix.py && \ - echo "txt = re.sub(r'if len\(amdsmi\.amdsmi_get_processor_handles\(\)\) > 0:', 'if True:', txt)" >> patch_strix.py && \ - echo "txt = txt.replace('amdsmi.amdsmi_init()', 'pass')" >> patch_strix.py && \ - echo "txt = txt.replace('amdsmi.amdsmi_shut_down()', 'pass')" >> patch_strix.py && \ - echo "p.write_text(txt)" >> patch_strix.py && \ - # Patch 2: rocm.py - echo "p = Path('vllm/platforms/rocm.py')" >> patch_strix.py && \ - echo "txt = p.read_text()" >> patch_strix.py && \ - echo "header = 'import sys\nfrom unittest.mock import MagicMock\nsys.modules[\"amdsmi\"] = MagicMock()\n'" >> patch_strix.py && \ - echo "txt = header + txt" >> patch_strix.py && \ - echo "txt = re.sub(r'device_type = .*', 'device_type = \"rocm\"', txt)" >> patch_strix.py && \ - echo "txt = re.sub(r'device_name = .*', 'device_name = \"gfx1151\"', txt)" >> patch_strix.py && \ - echo "txt += '\n def get_device_name(self, device_id: int = 0) -> str:\n return \"AMD-gfx1151\"\n'" >> patch_strix.py && \ - echo "p.write_text(txt)" >> patch_strix.py && \ - # Patch 3: Fix C10_HIP_CHECK undeclared identifier by patching source files - echo "import glob" >> patch_strix.py && \ - echo "for f in glob.glob('csrc/mamba/mamba_ssm/selective_scan_*.cu') + glob.glob('csrc/mamba/mamba_ssm/selective_scan_*.hip'):" >> patch_strix.py && \ - echo " p_f = Path(f)" >> patch_strix.py && \ - echo " if p_f.exists():" >> patch_strix.py && \ - echo " txt = p_f.read_text()" >> patch_strix.py && \ - echo " macro_def = '''" >> patch_strix.py && \ - echo "#ifndef C10_HIP_CHECK" >> patch_strix.py && \ - echo "#define C10_HIP_CHECK(error) do { if (error != hipSuccess) { abort(); } } while(0)" >> patch_strix.py && \ - echo "#endif" >> patch_strix.py && \ - echo "#ifndef C10_CUDA_CHECK" >> patch_strix.py && \ - echo "#define C10_CUDA_CHECK(error) do { if (error != cudaSuccess) { abort(); } } while(0)" >> patch_strix.py && \ - echo "#endif" >> patch_strix.py && \ - echo "'''" >> patch_strix.py && \ - echo " p_f.write_text(macro_def + '\\n' + txt)" >> patch_strix.py && \ - # ----------- - echo "print('Successfully patched vLLM for Strix Halo')" >> patch_strix.py && \ - python patch_strix.py && \ +COPY scripts/patch_strix.py /opt/vllm/patch_strix.py +RUN python /opt/vllm/patch_strix.py && \ sed -i 's/gfx1200;gfx1201/gfx1151/' CMakeLists.txt # 7. Build vLLM (Wheel Method) with CLANG Host Compiler diff --git a/scripts/patch_strix.py b/scripts/patch_strix.py new file mode 100644 index 0000000..ec8b5cb --- /dev/null +++ b/scripts/patch_strix.py @@ -0,0 +1,60 @@ +import sys +import re +import glob +from pathlib import Path +from unittest.mock import MagicMock + +def patch_vllm(): + print("Applying Strix Halo patches to vLLM...") + + # Patch 1: vllm/platforms/__init__.py + p_init = Path('vllm/platforms/__init__.py') + if p_init.exists(): + txt = p_init.read_text() + txt = txt.replace('import amdsmi', '# import amdsmi') + txt = re.sub(r'is_rocm = .*', 'is_rocm = True', txt) + txt = re.sub(r'if len\(amdsmi\.amdsmi_get_processor_handles\(\)\) > 0:', 'if True:', txt) + txt = txt.replace('amdsmi.amdsmi_init()', 'pass') + txt = txt.replace('amdsmi.amdsmi_shut_down()', 'pass') + p_init.write_text(txt) + print(" -> Patched vllm/platforms/__init__.py") + + # Patch 2: vllm/platforms/rocm.py + p_rocm = Path('vllm/platforms/rocm.py') + if p_rocm.exists(): + txt = p_rocm.read_text() + header = 'import sys\nfrom unittest.mock import MagicMock\nsys.modules["amdsmi"] = MagicMock()\n' + txt = header + txt + txt = re.sub(r'device_type = .*', 'device_type = "rocm"', txt) + txt = re.sub(r'device_name = .*', 'device_name = "gfx1151"', txt) + txt += '\n def get_device_name(self, device_id: int = 0) -> str:\n return "AMD-gfx1151"\n' + p_rocm.write_text(txt) + print(" -> Patched vllm/platforms/rocm.py") + + # Patch 3: CUDA/HIP Macro injections for PyTorch Nightly Compatibility + macro_def = """ +#ifndef C10_HIP_CHECK +#define C10_HIP_CHECK(error) do { if (error != hipSuccess) { abort(); } } while(0) +#endif +#ifndef C10_CUDA_CHECK +#define C10_CUDA_CHECK(error) do { if (error != cudaSuccess) { abort(); } } while(0) +#endif +#define getCurrentHIPStreamMasqueradingAsCUDA getCurrentCUDAStream +""" + # Apply to all .cu and .hip files in csrc + csrc_files = glob.glob('csrc/**/*.cu', recursive=True) + glob.glob('csrc/**/*.hip', recursive=True) + patched_csrc_count = 0 + for f in csrc_files: + p_f = Path(f) + if p_f.exists(): + txt = p_f.read_text() + # Only prepend if not already patched to avoid duplicate macros + if "getCurrentHIPStreamMasqueradingAsCUDA getCurrentCUDAStream" not in txt: + p_f.write_text(macro_def + '\n' + txt) + patched_csrc_count += 1 + + print(f" -> Patched {patched_csrc_count} C/C++ source files with missing macros.") + print("Successfully patched vLLM for Strix Halo.") + +if __name__ == "__main__": + patch_vllm()