diff --git a/benchmarks/vllm_cluster_bench.py b/benchmarks/vllm_cluster_bench.py index 555c070..26c2f62 100755 --- a/benchmarks/vllm_cluster_bench.py +++ b/benchmarks/vllm_cluster_bench.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -import subprocess, time, json, sys, os, requests, argparse +import subprocess, time, json, sys, os, requests, argparse, re from pathlib import Path # ========================= @@ -47,15 +47,14 @@ MODELS_TO_RUN = models.MODELS_TO_RUN def log(msg): print(f"\n[CLUSTER-BENCH] {msg}") -def check_ray_status(): - """Checks if Ray cluster is active with at least 2 nodes.""" +def get_ray_nodes(): + """Returns a list of active Ray node IPs.""" try: res = subprocess.run(["ray", "status"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) if res.returncode != 0: - return False + return [] - # Basic check for 2 nodes - active_nodes = 0 + nodes = [] in_active_section = False for line in res.stdout.splitlines(): if "Active:" in line: @@ -64,12 +63,22 @@ def check_ray_status(): if "Pending:" in line or "Recent failures:" in line: in_active_section = False - if in_active_section and line.strip().startswith("1 node_"): - active_nodes += 1 + if in_active_section: + # Look for "1 node_" pattern + # Existing logic checked for startswith("1 node_") + # We use regex to be robust and capture the IP + match = re.search(r"node_(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})", line) + if match: + nodes.append(match.group(1)) - return active_nodes >= 2 + return nodes except: - return False + return [] + +def check_ray_status(): + """Checks if Ray cluster is active with at least 2 nodes.""" + nodes = get_ray_nodes() + return len(nodes) >= 2 def get_net_iface(ip_prefix="192.168.100"): """ @@ -95,16 +104,48 @@ def get_local_ip(iface): except: return "127.0.0.1" +def nuke_vllm_cache_on_node(ip, is_local=False): + """Clears vLLM cache on a specific node.""" + cmd_str = f"Locally" if is_local else f"on {ip}" + print(f"Clearing vLLM cache {cmd_str}...", end="", flush=True) + + try: + if is_local: + cache = Path.home() / ".cache" / "vllm" + if cache.exists(): + subprocess.run(["rm", "-rf", str(cache)], check=True) + cache.mkdir(parents=True, exist_ok=True) + else: + # Remote SSH + ssh_cmd = [ + "ssh", "-o", "StrictHostKeyChecking=no", ip, + "rm -rf ~/.cache/vllm && mkdir -p ~/.cache/vllm" + ] + subprocess.run(ssh_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + print(" Done.") + except Exception as e: + print(f" Failed ({e}).") + def nuke_vllm_cache(): - cache = Path.home() / ".cache" / "vllm" - if cache.exists(): - try: - print(f"Clearing vLLM cache...", end="", flush=True) - subprocess.run(["rm", "-rf", str(cache)], check=True) - cache.mkdir(parents=True, exist_ok=True) - print(" Done.") - time.sleep(2) - except: pass + """Clears vLLM cache on ALL cluster nodes.""" + nodes = get_ray_nodes() + rdma_iface = get_net_iface() + local_ip = get_local_ip(rdma_iface) + + # If no nodes found (unexpected if we are running bench), try just local + if not nodes: + nuke_vllm_cache_on_node(local_ip, is_local=True) + return + + for node_ip in nodes: + # Check if node is local + # Simple string match, but IPs might vary (localhost vs 192.168...) + # We trust get_local_ip returns the IP used in the cluster (192.168.100.x) + is_local = (node_ip == local_ip) or (node_ip == "127.0.0.1") + nuke_vllm_cache_on_node(node_ip, is_local) + + time.sleep(2) def get_dataset(): # Same as original @@ -161,9 +202,9 @@ def get_model_args(model): if config.get("trust_remote"): cmd.append("--trust-remote-code") - # Respect config for Eager Mode (Apple-to-Apples with TP=1) - if config.get("enforce_eager"): - cmd.append("--enforce-eager") + # ALWAYS Enforce Eager Mode for Cluster Benchmarks (TP=2) + # Distributed Graph Capture is unstable/prone to hangs on Strix Halo Cluster + cmd.append("--enforce-eager") return cmd