feat: Improve Ray node detection, enable cluster-wide vLLM cache clearing, and enforce eager mode for benchmarks.
Tá an tiomantas seo le fáil i:
@@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import subprocess, time, json, sys, os, requests, argparse
|
import subprocess, time, json, sys, os, requests, argparse, re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# =========================
|
# =========================
|
||||||
@@ -47,15 +47,14 @@ MODELS_TO_RUN = models.MODELS_TO_RUN
|
|||||||
|
|
||||||
def log(msg): print(f"\n[CLUSTER-BENCH] {msg}")
|
def log(msg): print(f"\n[CLUSTER-BENCH] {msg}")
|
||||||
|
|
||||||
def check_ray_status():
|
def get_ray_nodes():
|
||||||
"""Checks if Ray cluster is active with at least 2 nodes."""
|
"""Returns a list of active Ray node IPs."""
|
||||||
try:
|
try:
|
||||||
res = subprocess.run(["ray", "status"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
res = subprocess.run(["ray", "status"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
if res.returncode != 0:
|
if res.returncode != 0:
|
||||||
return False
|
return []
|
||||||
|
|
||||||
# Basic check for 2 nodes
|
nodes = []
|
||||||
active_nodes = 0
|
|
||||||
in_active_section = False
|
in_active_section = False
|
||||||
for line in res.stdout.splitlines():
|
for line in res.stdout.splitlines():
|
||||||
if "Active:" in line:
|
if "Active:" in line:
|
||||||
@@ -64,12 +63,22 @@ def check_ray_status():
|
|||||||
if "Pending:" in line or "Recent failures:" in line:
|
if "Pending:" in line or "Recent failures:" in line:
|
||||||
in_active_section = False
|
in_active_section = False
|
||||||
|
|
||||||
if in_active_section and line.strip().startswith("1 node_"):
|
if in_active_section:
|
||||||
active_nodes += 1
|
# Look for "1 node_<IP>" pattern
|
||||||
|
# Existing logic checked for startswith("1 node_")
|
||||||
|
# We use regex to be robust and capture the IP
|
||||||
|
match = re.search(r"node_(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})", line)
|
||||||
|
if match:
|
||||||
|
nodes.append(match.group(1))
|
||||||
|
|
||||||
return active_nodes >= 2
|
return nodes
|
||||||
except:
|
except:
|
||||||
return False
|
return []
|
||||||
|
|
||||||
|
def check_ray_status():
|
||||||
|
"""Checks if Ray cluster is active with at least 2 nodes."""
|
||||||
|
nodes = get_ray_nodes()
|
||||||
|
return len(nodes) >= 2
|
||||||
|
|
||||||
def get_net_iface(ip_prefix="192.168.100"):
|
def get_net_iface(ip_prefix="192.168.100"):
|
||||||
"""
|
"""
|
||||||
@@ -95,16 +104,48 @@ def get_local_ip(iface):
|
|||||||
except:
|
except:
|
||||||
return "127.0.0.1"
|
return "127.0.0.1"
|
||||||
|
|
||||||
|
def nuke_vllm_cache_on_node(ip, is_local=False):
|
||||||
|
"""Clears vLLM cache on a specific node."""
|
||||||
|
cmd_str = f"Locally" if is_local else f"on {ip}"
|
||||||
|
print(f"Clearing vLLM cache {cmd_str}...", end="", flush=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_local:
|
||||||
|
cache = Path.home() / ".cache" / "vllm"
|
||||||
|
if cache.exists():
|
||||||
|
subprocess.run(["rm", "-rf", str(cache)], check=True)
|
||||||
|
cache.mkdir(parents=True, exist_ok=True)
|
||||||
|
else:
|
||||||
|
# Remote SSH
|
||||||
|
ssh_cmd = [
|
||||||
|
"ssh", "-o", "StrictHostKeyChecking=no", ip,
|
||||||
|
"rm -rf ~/.cache/vllm && mkdir -p ~/.cache/vllm"
|
||||||
|
]
|
||||||
|
subprocess.run(ssh_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||||
|
|
||||||
|
print(" Done.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Failed ({e}).")
|
||||||
|
|
||||||
def nuke_vllm_cache():
|
def nuke_vllm_cache():
|
||||||
cache = Path.home() / ".cache" / "vllm"
|
"""Clears vLLM cache on ALL cluster nodes."""
|
||||||
if cache.exists():
|
nodes = get_ray_nodes()
|
||||||
try:
|
rdma_iface = get_net_iface()
|
||||||
print(f"Clearing vLLM cache...", end="", flush=True)
|
local_ip = get_local_ip(rdma_iface)
|
||||||
subprocess.run(["rm", "-rf", str(cache)], check=True)
|
|
||||||
cache.mkdir(parents=True, exist_ok=True)
|
# If no nodes found (unexpected if we are running bench), try just local
|
||||||
print(" Done.")
|
if not nodes:
|
||||||
time.sleep(2)
|
nuke_vllm_cache_on_node(local_ip, is_local=True)
|
||||||
except: pass
|
return
|
||||||
|
|
||||||
|
for node_ip in nodes:
|
||||||
|
# Check if node is local
|
||||||
|
# Simple string match, but IPs might vary (localhost vs 192.168...)
|
||||||
|
# We trust get_local_ip returns the IP used in the cluster (192.168.100.x)
|
||||||
|
is_local = (node_ip == local_ip) or (node_ip == "127.0.0.1")
|
||||||
|
nuke_vllm_cache_on_node(node_ip, is_local)
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
def get_dataset():
|
def get_dataset():
|
||||||
# Same as original
|
# Same as original
|
||||||
@@ -161,9 +202,9 @@ def get_model_args(model):
|
|||||||
|
|
||||||
if config.get("trust_remote"): cmd.append("--trust-remote-code")
|
if config.get("trust_remote"): cmd.append("--trust-remote-code")
|
||||||
|
|
||||||
# Respect config for Eager Mode (Apple-to-Apples with TP=1)
|
# ALWAYS Enforce Eager Mode for Cluster Benchmarks (TP=2)
|
||||||
if config.get("enforce_eager"):
|
# Distributed Graph Capture is unstable/prone to hangs on Strix Halo Cluster
|
||||||
cmd.append("--enforce-eager")
|
cmd.append("--enforce-eager")
|
||||||
|
|
||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
|
|||||||
Tagairt in Eagrán Nua
Cuir bac ar úsáideoir