feat: Update ROCm benchmark result paths, improve cluster node discovery and cache clearing, and refine cluster benchmark result directory.

This commit is contained in:
Donato Capitella
2026-02-02 07:35:50 +00:00
parent c587981d73
commit 6f118ff936
5 changed files with 40 additions and 17 deletions
+1
View File
@@ -129,6 +129,7 @@ COPY scripts/start_vllm.py /opt/start-vllm
COPY scripts/start_vllm_cluster.py /opt/start-vllm-cluster COPY scripts/start_vllm_cluster.py /opt/start-vllm-cluster
COPY scripts/cluster_manager.py /opt/cluster_manager.py COPY scripts/cluster_manager.py /opt/cluster_manager.py
COPY scripts/models.py /opt/models.py COPY scripts/models.py /opt/models.py
COPY benchmarks/max_context_results.json /opt/max_context_results.json COPY benchmarks/max_context_results.json /opt/max_context_results.json
COPY benchmarks/run_vllm_bench.py /opt/run_vllm_bench.py COPY benchmarks/run_vllm_bench.py /opt/run_vllm_bench.py
COPY benchmarks/vllm_cluster_bench.py /opt/vllm_cluster_bench.py COPY benchmarks/vllm_cluster_bench.py /opt/vllm_cluster_bench.py
+2 -2
View File
@@ -181,7 +181,7 @@ def print_summary(tps):
# ROCm # ROCm
try: try:
p2 = Path("benchmark_results_rocm_attn/benchmark_results") / f"{msafe}_tp{tp}_throughput.json" p2 = Path("benchmark_results_rocm") / f"{msafe}_tp{tp}_throughput.json"
d2 = json.loads(p2.read_text()) d2 = json.loads(p2.read_text())
val2 = f"{d2.get('tokens_per_second', 0):.1f}" val2 = f"{d2.get('tokens_per_second', 0):.1f}"
except: val2 = "N/A" except: val2 = "N/A"
@@ -210,7 +210,7 @@ if __name__ == "__main__":
run_throughput(m, tp, "Default", RESULTS_DIR) run_throughput(m, tp, "Default", RESULTS_DIR)
# 2. ROCm Attention # 2. ROCm Attention
run_throughput(m, tp, "ROCm-Attn", "benchmark_results_rocm_attn/benchmark_results", { run_throughput(m, tp, "ROCm-Attn", "benchmark_results_rocm", {
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1",
"VLLM_USE_TRITON_FLASH_ATTN": "0" "VLLM_USE_TRITON_FLASH_ATTN": "0"
}) })
+5 -4
View File
@@ -16,7 +16,7 @@ OFF_NUM_PROMPTS = 200
OFF_FORCED_OUTPUT = "512" OFF_FORCED_OUTPUT = "512"
DEFAULT_BATCH_TOKENS = "8192" DEFAULT_BATCH_TOKENS = "8192"
RESULTS_DIR = Path("cluster_benchmark_results") RESULTS_DIR = Path("benchmark_results")
RESULTS_DIR.mkdir(exist_ok=True) RESULTS_DIR.mkdir(exist_ok=True)
# Reuse the model table from the main benchmark script # Reuse the model table from the main benchmark script
@@ -93,7 +93,8 @@ def get_local_ip(iface):
return cluster_manager.get_local_ip(iface) return cluster_manager.get_local_ip(iface)
def nuke_vllm_cache(): def nuke_vllm_cache():
cluster_manager.nuke_vllm_cache_cluster() # We use explicit IPs because ray status might return Hex IDs which we can't SSH to.
cluster_manager.nuke_vllm_cache_cluster(nodes=[HEAD_IP, WORKER_IP])
def get_dataset(): def get_dataset():
@@ -223,7 +224,7 @@ def run_cluster_throughput(model):
run_bench_set( run_bench_set(
model, model,
"ROCm-Attn", "ROCm-Attn",
"benchmark_results_rocm_attn/benchmark_results", "benchmark_results_rocm",
extra_env={ extra_env={
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1",
"VLLM_USE_TRITON_FLASH_ATTN": "0" "VLLM_USE_TRITON_FLASH_ATTN": "0"
@@ -247,7 +248,7 @@ def print_summary():
# ROCm # ROCm
try: try:
p2 = Path("benchmark_results_rocm_attn/benchmark_results") / f"{msafe}_cluster_tp{CLUSTER_TP}_throughput.json" p2 = Path("benchmark_results_rocm") / f"{msafe}_cluster_tp{CLUSTER_TP}_throughput.json"
d2 = json.loads(p2.read_text()) d2 = json.loads(p2.read_text())
val2 = f"{d2.get('tokens_per_second', 0):.1f}" val2 = f"{d2.get('tokens_per_second', 0):.1f}"
except: val2 = "N/A" except: val2 = "N/A"
+1 -1
View File
@@ -10,7 +10,7 @@ from pathlib import Path
SCRIPT_DIR = Path(__file__).parent.resolve() SCRIPT_DIR = Path(__file__).parent.resolve()
BENCHMARK_SOURCES = { BENCHMARK_SOURCES = {
"Triton": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results", "Triton": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results",
"ROCm": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results_rocm_attn" / "benchmark_results" "ROCm": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results_rocm"
} }
OUTPUT_FILE = SCRIPT_DIR / "results.json" OUTPUT_FILE = SCRIPT_DIR / "results.json"
+31 -10
View File
@@ -126,9 +126,12 @@ def get_ray_nodes():
in_active_section = False in_active_section = False
if in_active_section: if in_active_section:
match = re.search(r"node_(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})", line) # Match "1 node_<ID_OR_IP>"
# We relax regex to accept hex IDs or IPs
match = re.search(r"node_([a-zA-Z0-9\.\-_]+)", line)
if match: if match:
nodes.append(match.group(1)) nodes.append(match.group(1))
return nodes return nodes
except: except:
@@ -179,15 +182,19 @@ def nuke_vllm_cache_on_node(ip, is_local=False):
except Exception as e: except Exception as e:
print(f" Failed ({e}).") print(f" Failed ({e}).")
def nuke_vllm_cache_cluster(): def nuke_vllm_cache_cluster(nodes=None):
"""Clears vLLM cache on ALL cluster nodes.""" """
nodes = get_ray_nodes() Clears vLLM cache on cluster nodes.
# Assuming we are running on Head, which is one of the nodes. If 'nodes' (list of IPs) is provided, uses those.
# We need to detect which IP is "local" Otherwise attempts to discover from ray status (which may fail if status shows Hex IDs and not IPs).
# Or just run 'ray stop' first? """
# The requirement is often to clear cache BEFORE start or between runs. if nodes is None:
# If ray is down, 'get_ray_nodes' returns empty. nodes = get_ray_nodes()
# So this is best used when cluster is UP.
# Check if nodes look like IPs before trying SSH
# If we only have Hex IDs, we can't SSH unless we map them.
# For now, we filter for things that look like IPs if we are relying on discovery
# But if user passed explicit list, we assume they are IPs.
rdma_iface = get_net_iface() rdma_iface = get_net_iface()
local_ip = get_local_ip(rdma_iface) local_ip = get_local_ip(rdma_iface)
@@ -197,8 +204,22 @@ def nuke_vllm_cache_cluster():
nuke_vllm_cache_on_node(local_ip, is_local=True) nuke_vllm_cache_on_node(local_ip, is_local=True)
return return
import re
ip_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
for node_ip in nodes: for node_ip in nodes:
# If discovered node is NOT an IP (e.g. Hex ID), we warn and skip remote nuke
# unless it is '127.0.0.1' or we can determine it is local.
is_ip = ip_pattern.match(node_ip) or node_ip == "localhost"
if not is_ip:
# Maybe it's a Hex ID. We can't SSH to a Hex ID.
print(f"Skipping cache clear on '{node_ip}' (Not an IP address).")
continue
is_local = (node_ip == local_ip) or (node_ip == "127.0.0.1") is_local = (node_ip == local_ip) or (node_ip == "127.0.0.1")
nuke_vllm_cache_on_node(node_ip, is_local) nuke_vllm_cache_on_node(node_ip, is_local)
time.sleep(2) time.sleep(2)