diff --git a/scripts/patch_strix.py b/scripts/patch_strix.py index ec8b5cb..94c8c33 100644 --- a/scripts/patch_strix.py +++ b/scripts/patch_strix.py @@ -39,7 +39,6 @@ def patch_vllm(): #ifndef C10_CUDA_CHECK #define C10_CUDA_CHECK(error) do { if (error != cudaSuccess) { abort(); } } while(0) #endif -#define getCurrentHIPStreamMasqueradingAsCUDA getCurrentCUDAStream """ # Apply to all .cu and .hip files in csrc csrc_files = glob.glob('csrc/**/*.cu', recursive=True) + glob.glob('csrc/**/*.hip', recursive=True) @@ -49,7 +48,7 @@ def patch_vllm(): if p_f.exists(): txt = p_f.read_text() # Only prepend if not already patched to avoid duplicate macros - if "getCurrentHIPStreamMasqueradingAsCUDA getCurrentCUDAStream" not in txt: + if "C10_CUDA_CHECK" not in txt: p_f.write_text(macro_def + '\n' + txt) patched_csrc_count += 1