From 1ddcb9a2027ff15d2d0a343c7213eb102916cfed Mon Sep 17 00:00:00 2001 From: Donato Capitella Date: Mon, 2 Feb 2026 15:40:16 +0000 Subject: [PATCH] feat: Configure ROCm attention via `--attention-backend` CLI argument, disable the Ray dashboard, and make eager mode configurable for cluster benchmarks. --- benchmarks/run_vllm_bench.py | 15 ++++++++++----- benchmarks/vllm_cluster_bench.py | 12 +++++------- scripts/cluster_manager.py | 4 ++-- scripts/start_vllm.py | 4 +--- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/benchmarks/run_vllm_bench.py b/benchmarks/run_vllm_bench.py index b1dfc32..cf7bc27 100644 --- a/benchmarks/run_vllm_bench.py +++ b/benchmarks/run_vllm_bench.py @@ -146,6 +146,10 @@ def run_throughput(model, tp_size, backend_name="Default", output_dir=RESULTS_DI ]) cmd.extend(dataset_args) + # Force Attention Backend via CLI if ROCm-Attn + if backend_name == "ROCm-Attn": + cmd.extend(["--attention-backend", "ROCM_ATTN"]) + # ENV Setup: Global + Model Specific env = os.environ.copy() @@ -209,10 +213,11 @@ if __name__ == "__main__": # 1. Default (Triton) run_throughput(m, tp, "Default", RESULTS_DIR) - # 2. ROCm Attention - run_throughput(m, tp, "ROCm-Attn", "benchmark_results_rocm", { - "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1", - "VLLM_USE_TRITON_FLASH_ATTN": "0" - }) + # 2. ROCm Attention + # We force this via CLI argument --attention-backend ROCM_ATTN below + # No specific env vars needed if forcing backend. + rocm_env = {} + print(f"[DEBUG] Forcing ROCm Env: {rocm_env} + CLI: --attention-backend ROCM_ATTN") + run_throughput(m, tp, "ROCm-Attn", "benchmark_results_rocm", rocm_env) print_summary(valid_tp_args) diff --git a/benchmarks/vllm_cluster_bench.py b/benchmarks/vllm_cluster_bench.py index b9876a2..f3b0c0a 100755 --- a/benchmarks/vllm_cluster_bench.py +++ b/benchmarks/vllm_cluster_bench.py @@ -158,9 +158,7 @@ def get_model_args(model): if config.get("trust_remote"): cmd.append("--trust-remote-code") - # ALWAYS Enforce Eager Mode for Cluster Benchmarks (TP=2) - # Distributed Graph Capture is unstable/prone to hangs on Strix Halo Cluster - cmd.append("--enforce-eager") + if config.get("enforce_eager"): cmd.append("--enforce-eager") return cmd @@ -194,6 +192,9 @@ def run_bench_set(model, backend_name, output_dir, extra_env=None): ]) cmd.extend(dataset_args) + if backend_name == "ROCm-Attn": + cmd.extend(["--attention-backend", "ROCM_ATTN"]) + env = get_cluster_env() # Model specific envs @@ -227,10 +228,7 @@ def run_cluster_throughput(model): model, "ROCm-Attn", "benchmark_results_rocm", - extra_env={ - "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1", - "VLLM_USE_TRITON_FLASH_ATTN": "0" - } + extra_env={} ) diff --git a/scripts/cluster_manager.py b/scripts/cluster_manager.py index db249c8..45a4b9e 100644 --- a/scripts/cluster_manager.py +++ b/scripts/cluster_manager.py @@ -59,7 +59,7 @@ def setup_worker_node(worker_ip, head_ip): export NCCL_IB_TIMEOUT=23 export NCCL_IB_RETRY_CNT=7 echo "Starting Ray Worker on {worker_ip} connecting to {head_ip}..." - ray start --address='{head_ip}:6379' --num-gpus=1 --num-cpus=8 --disable-usage-stats + ray start --address='{head_ip}:6379' --num-gpus=1 --num-cpus=8 --disable-usage-stats --include-dashboard=false """ print(f"Setting up Worker Node ({worker_ip})...") @@ -97,7 +97,7 @@ def setup_head_node(head_ip): export NCCL_IB_TIMEOUT=23 export NCCL_IB_RETRY_CNT=7 echo "Starting Ray Head on {head_ip}..." - ray start --head --port=6379 --node-ip-address={head_ip} --num-gpus=1 --num-cpus=8 --disable-usage-stats + ray start --head --port=6379 --node-ip-address={head_ip} --num-gpus=1 --num-cpus=8 --disable-usage-stats --include-dashboard=false """ try: diff --git a/scripts/start_vllm.py b/scripts/start_vllm.py index 864c95b..0c33b2f 100644 --- a/scripts/start_vllm.py +++ b/scripts/start_vllm.py @@ -309,9 +309,7 @@ def configure_and_launch(model_idx, gpu_count): env.update(config.get("env", {})) if use_rocm_attn: - env["VLLM_V1_USE_PREFILL_DECODE_ATTENTION"] = "1" - env["VLLM_USE_TRITON_FLASH_ATTN"] = "0" - # Optional: Explicitly mention these in print + cmd.extend(["--attention-backend", "ROCM_ATTN"]) print("\n" + "="*60)