diff --git a/Dockerfile b/Dockerfile index df1c427..3db7668 100644 --- a/Dockerfile +++ b/Dockerfile @@ -46,8 +46,11 @@ RUN git clone https://github.com/vllm-project/vllm.git /opt/vllm WORKDIR /opt/vllm # --- PATCHING --- +# https://github.com/kyuz0/amd-strix-halo-vllm-toolboxes/issues/28 COPY scripts/patch_strix.py /opt/vllm/patch_strix.py +COPY scripts/fix_block_size.py /opt/vllm/fix_block_size.py RUN python /opt/vllm/patch_strix.py && \ + python /opt/vllm/fix_block_size.py && \ sed -i 's/gfx1200;gfx1201/gfx1151/' CMakeLists.txt # 7. Build vLLM (Wheel Method) with CLANG Host Compiler diff --git a/scripts/fix_block_size.py b/scripts/fix_block_size.py new file mode 100644 index 0000000..fd82761 --- /dev/null +++ b/scripts/fix_block_size.py @@ -0,0 +1,50 @@ +import sys +import re +from pathlib import Path + +def patch_block_size(): + # Fix for: https://github.com/kyuz0/amd-strix-halo-vllm-toolboxes/issues/28 + + # If a command line argument is provided, use that as the file path, + # otherwise default to the relative path in the source tree. + target = sys.argv[1] if len(sys.argv) > 1 else 'vllm/v1/attention/backend.py' + p = Path(target) + + # Some installations might have it in site-packages, search for it if absolute path isn't given + if not p.exists(): + # Let's see if we can find it in site-packages if we just run it inside a container + try: + import vllm + p = Path(vllm.__path__[0]) / 'v1' / 'attention' / 'backend.py' + except ImportError: + pass + + if not p.exists(): + print(f"File not found: {p}") + sys.exit(1) + + txt, count = re.subn( + r'(\s*valid_sizes\s*=\s*get_args\(BlockSize\)\s*\n\s*if\s*block_size\s*not\s*in\s*valid_sizes:\s*\n\s*return\s*False\s*\n*)', + '\n', + p.read_text(), + flags=re.MULTILINE + ) + + if count > 0: + p.write_text(txt) + print(f" -> Patched {p} (removed {count} block size checks)") + else: + print(f" -> Warning: Block size check not found in {p}. Regex might be outdated or already patched.") + # Try a simpler replace just in case get_args is imported differently + txt, count2 = re.subn( + r'(\s*if\s*block_size\s*not\s*in\s*valid_sizes:\s*\n\s*return\s*False\s*\n*)', + '\n', + txt, + flags=re.MULTILINE + ) + if count2 > 0: + p.write_text(txt) + print(f" -> Patched {p} via fallback regex") + +if __name__ == "__main__": + patch_block_size()