diff --git a/Dockerfile b/Dockerfile index a6714d3..777abff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -129,6 +129,7 @@ COPY scripts/start_vllm.py /opt/start-vllm COPY scripts/start_vllm_cluster.py /opt/start-vllm-cluster COPY scripts/cluster_manager.py /opt/cluster_manager.py COPY scripts/models.py /opt/models.py + COPY benchmarks/max_context_results.json /opt/max_context_results.json COPY benchmarks/run_vllm_bench.py /opt/run_vllm_bench.py COPY benchmarks/vllm_cluster_bench.py /opt/vllm_cluster_bench.py diff --git a/benchmarks/run_vllm_bench.py b/benchmarks/run_vllm_bench.py index a2e2f7d..b1dfc32 100644 --- a/benchmarks/run_vllm_bench.py +++ b/benchmarks/run_vllm_bench.py @@ -181,7 +181,7 @@ def print_summary(tps): # ROCm try: - p2 = Path("benchmark_results_rocm_attn/benchmark_results") / f"{msafe}_tp{tp}_throughput.json" + p2 = Path("benchmark_results_rocm") / f"{msafe}_tp{tp}_throughput.json" d2 = json.loads(p2.read_text()) val2 = f"{d2.get('tokens_per_second', 0):.1f}" except: val2 = "N/A" @@ -210,7 +210,7 @@ if __name__ == "__main__": run_throughput(m, tp, "Default", RESULTS_DIR) # 2. ROCm Attention - run_throughput(m, tp, "ROCm-Attn", "benchmark_results_rocm_attn/benchmark_results", { + run_throughput(m, tp, "ROCm-Attn", "benchmark_results_rocm", { "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1", "VLLM_USE_TRITON_FLASH_ATTN": "0" }) diff --git a/benchmarks/vllm_cluster_bench.py b/benchmarks/vllm_cluster_bench.py index 755bb6c..be18a37 100755 --- a/benchmarks/vllm_cluster_bench.py +++ b/benchmarks/vllm_cluster_bench.py @@ -16,7 +16,7 @@ OFF_NUM_PROMPTS = 200 OFF_FORCED_OUTPUT = "512" DEFAULT_BATCH_TOKENS = "8192" -RESULTS_DIR = Path("cluster_benchmark_results") +RESULTS_DIR = Path("benchmark_results") RESULTS_DIR.mkdir(exist_ok=True) # Reuse the model table from the main benchmark script @@ -93,7 +93,8 @@ def get_local_ip(iface): return cluster_manager.get_local_ip(iface) def nuke_vllm_cache(): - cluster_manager.nuke_vllm_cache_cluster() + # We use explicit IPs because ray status might return Hex IDs which we can't SSH to. + cluster_manager.nuke_vllm_cache_cluster(nodes=[HEAD_IP, WORKER_IP]) def get_dataset(): @@ -223,7 +224,7 @@ def run_cluster_throughput(model): run_bench_set( model, "ROCm-Attn", - "benchmark_results_rocm_attn/benchmark_results", + "benchmark_results_rocm", extra_env={ "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1", "VLLM_USE_TRITON_FLASH_ATTN": "0" @@ -247,7 +248,7 @@ def print_summary(): # ROCm try: - p2 = Path("benchmark_results_rocm_attn/benchmark_results") / f"{msafe}_cluster_tp{CLUSTER_TP}_throughput.json" + p2 = Path("benchmark_results_rocm") / f"{msafe}_cluster_tp{CLUSTER_TP}_throughput.json" d2 = json.loads(p2.read_text()) val2 = f"{d2.get('tokens_per_second', 0):.1f}" except: val2 = "N/A" diff --git a/docs/parse_results.py b/docs/parse_results.py index 771c65f..ed5a138 100644 --- a/docs/parse_results.py +++ b/docs/parse_results.py @@ -10,7 +10,7 @@ from pathlib import Path SCRIPT_DIR = Path(__file__).parent.resolve() BENCHMARK_SOURCES = { "Triton": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results", - "ROCm": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results_rocm_attn" / "benchmark_results" + "ROCm": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results_rocm" } OUTPUT_FILE = SCRIPT_DIR / "results.json" diff --git a/scripts/cluster_manager.py b/scripts/cluster_manager.py index a8af722..db249c8 100644 --- a/scripts/cluster_manager.py +++ b/scripts/cluster_manager.py @@ -126,9 +126,12 @@ def get_ray_nodes(): in_active_section = False if in_active_section: - match = re.search(r"node_(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})", line) + # Match "1 node_" + # We relax regex to accept hex IDs or IPs + match = re.search(r"node_([a-zA-Z0-9\.\-_]+)", line) if match: nodes.append(match.group(1)) + return nodes except: @@ -179,15 +182,19 @@ def nuke_vllm_cache_on_node(ip, is_local=False): except Exception as e: print(f" Failed ({e}).") -def nuke_vllm_cache_cluster(): - """Clears vLLM cache on ALL cluster nodes.""" - nodes = get_ray_nodes() - # Assuming we are running on Head, which is one of the nodes. - # We need to detect which IP is "local" - # Or just run 'ray stop' first? - # The requirement is often to clear cache BEFORE start or between runs. - # If ray is down, 'get_ray_nodes' returns empty. - # So this is best used when cluster is UP. +def nuke_vllm_cache_cluster(nodes=None): + """ + Clears vLLM cache on cluster nodes. + If 'nodes' (list of IPs) is provided, uses those. + Otherwise attempts to discover from ray status (which may fail if status shows Hex IDs and not IPs). + """ + if nodes is None: + nodes = get_ray_nodes() + + # Check if nodes look like IPs before trying SSH + # If we only have Hex IDs, we can't SSH unless we map them. + # For now, we filter for things that look like IPs if we are relying on discovery + # But if user passed explicit list, we assume they are IPs. rdma_iface = get_net_iface() local_ip = get_local_ip(rdma_iface) @@ -197,8 +204,22 @@ def nuke_vllm_cache_cluster(): nuke_vllm_cache_on_node(local_ip, is_local=True) return + import re + ip_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$") + for node_ip in nodes: + # If discovered node is NOT an IP (e.g. Hex ID), we warn and skip remote nuke + # unless it is '127.0.0.1' or we can determine it is local. + + is_ip = ip_pattern.match(node_ip) or node_ip == "localhost" + + if not is_ip: + # Maybe it's a Hex ID. We can't SSH to a Hex ID. + print(f"Skipping cache clear on '{node_ip}' (Not an IP address).") + continue + is_local = (node_ip == local_ip) or (node_ip == "127.0.0.1") nuke_vllm_cache_on_node(node_ip, is_local) time.sleep(2) +