feat: Update ROCm benchmark result paths, improve cluster node discovery and cache clearing, and refine cluster benchmark result directory.
This commit is contained in:
@@ -129,6 +129,7 @@ COPY scripts/start_vllm.py /opt/start-vllm
|
||||
COPY scripts/start_vllm_cluster.py /opt/start-vllm-cluster
|
||||
COPY scripts/cluster_manager.py /opt/cluster_manager.py
|
||||
COPY scripts/models.py /opt/models.py
|
||||
|
||||
COPY benchmarks/max_context_results.json /opt/max_context_results.json
|
||||
COPY benchmarks/run_vllm_bench.py /opt/run_vllm_bench.py
|
||||
COPY benchmarks/vllm_cluster_bench.py /opt/vllm_cluster_bench.py
|
||||
|
||||
@@ -181,7 +181,7 @@ def print_summary(tps):
|
||||
|
||||
# ROCm
|
||||
try:
|
||||
p2 = Path("benchmark_results_rocm_attn/benchmark_results") / f"{msafe}_tp{tp}_throughput.json"
|
||||
p2 = Path("benchmark_results_rocm") / f"{msafe}_tp{tp}_throughput.json"
|
||||
d2 = json.loads(p2.read_text())
|
||||
val2 = f"{d2.get('tokens_per_second', 0):.1f}"
|
||||
except: val2 = "N/A"
|
||||
@@ -210,7 +210,7 @@ if __name__ == "__main__":
|
||||
run_throughput(m, tp, "Default", RESULTS_DIR)
|
||||
|
||||
# 2. ROCm Attention
|
||||
run_throughput(m, tp, "ROCm-Attn", "benchmark_results_rocm_attn/benchmark_results", {
|
||||
run_throughput(m, tp, "ROCm-Attn", "benchmark_results_rocm", {
|
||||
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1",
|
||||
"VLLM_USE_TRITON_FLASH_ATTN": "0"
|
||||
})
|
||||
|
||||
@@ -16,7 +16,7 @@ OFF_NUM_PROMPTS = 200
|
||||
OFF_FORCED_OUTPUT = "512"
|
||||
DEFAULT_BATCH_TOKENS = "8192"
|
||||
|
||||
RESULTS_DIR = Path("cluster_benchmark_results")
|
||||
RESULTS_DIR = Path("benchmark_results")
|
||||
RESULTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Reuse the model table from the main benchmark script
|
||||
@@ -93,7 +93,8 @@ def get_local_ip(iface):
|
||||
return cluster_manager.get_local_ip(iface)
|
||||
|
||||
def nuke_vllm_cache():
|
||||
cluster_manager.nuke_vllm_cache_cluster()
|
||||
# We use explicit IPs because ray status might return Hex IDs which we can't SSH to.
|
||||
cluster_manager.nuke_vllm_cache_cluster(nodes=[HEAD_IP, WORKER_IP])
|
||||
|
||||
|
||||
def get_dataset():
|
||||
@@ -223,7 +224,7 @@ def run_cluster_throughput(model):
|
||||
run_bench_set(
|
||||
model,
|
||||
"ROCm-Attn",
|
||||
"benchmark_results_rocm_attn/benchmark_results",
|
||||
"benchmark_results_rocm",
|
||||
extra_env={
|
||||
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1",
|
||||
"VLLM_USE_TRITON_FLASH_ATTN": "0"
|
||||
@@ -247,7 +248,7 @@ def print_summary():
|
||||
|
||||
# ROCm
|
||||
try:
|
||||
p2 = Path("benchmark_results_rocm_attn/benchmark_results") / f"{msafe}_cluster_tp{CLUSTER_TP}_throughput.json"
|
||||
p2 = Path("benchmark_results_rocm") / f"{msafe}_cluster_tp{CLUSTER_TP}_throughput.json"
|
||||
d2 = json.loads(p2.read_text())
|
||||
val2 = f"{d2.get('tokens_per_second', 0):.1f}"
|
||||
except: val2 = "N/A"
|
||||
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
SCRIPT_DIR = Path(__file__).parent.resolve()
|
||||
BENCHMARK_SOURCES = {
|
||||
"Triton": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results",
|
||||
"ROCm": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results_rocm_attn" / "benchmark_results"
|
||||
"ROCm": SCRIPT_DIR.parent / "benchmarks" / "benchmark_results_rocm"
|
||||
}
|
||||
OUTPUT_FILE = SCRIPT_DIR / "results.json"
|
||||
|
||||
|
||||
+31
-10
@@ -126,9 +126,12 @@ def get_ray_nodes():
|
||||
in_active_section = False
|
||||
|
||||
if in_active_section:
|
||||
match = re.search(r"node_(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})", line)
|
||||
# Match "1 node_<ID_OR_IP>"
|
||||
# We relax regex to accept hex IDs or IPs
|
||||
match = re.search(r"node_([a-zA-Z0-9\.\-_]+)", line)
|
||||
if match:
|
||||
nodes.append(match.group(1))
|
||||
|
||||
|
||||
return nodes
|
||||
except:
|
||||
@@ -179,15 +182,19 @@ def nuke_vllm_cache_on_node(ip, is_local=False):
|
||||
except Exception as e:
|
||||
print(f" Failed ({e}).")
|
||||
|
||||
def nuke_vllm_cache_cluster():
|
||||
"""Clears vLLM cache on ALL cluster nodes."""
|
||||
nodes = get_ray_nodes()
|
||||
# Assuming we are running on Head, which is one of the nodes.
|
||||
# We need to detect which IP is "local"
|
||||
# Or just run 'ray stop' first?
|
||||
# The requirement is often to clear cache BEFORE start or between runs.
|
||||
# If ray is down, 'get_ray_nodes' returns empty.
|
||||
# So this is best used when cluster is UP.
|
||||
def nuke_vllm_cache_cluster(nodes=None):
|
||||
"""
|
||||
Clears vLLM cache on cluster nodes.
|
||||
If 'nodes' (list of IPs) is provided, uses those.
|
||||
Otherwise attempts to discover from ray status (which may fail if status shows Hex IDs and not IPs).
|
||||
"""
|
||||
if nodes is None:
|
||||
nodes = get_ray_nodes()
|
||||
|
||||
# Check if nodes look like IPs before trying SSH
|
||||
# If we only have Hex IDs, we can't SSH unless we map them.
|
||||
# For now, we filter for things that look like IPs if we are relying on discovery
|
||||
# But if user passed explicit list, we assume they are IPs.
|
||||
|
||||
rdma_iface = get_net_iface()
|
||||
local_ip = get_local_ip(rdma_iface)
|
||||
@@ -197,8 +204,22 @@ def nuke_vllm_cache_cluster():
|
||||
nuke_vllm_cache_on_node(local_ip, is_local=True)
|
||||
return
|
||||
|
||||
import re
|
||||
ip_pattern = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
|
||||
|
||||
for node_ip in nodes:
|
||||
# If discovered node is NOT an IP (e.g. Hex ID), we warn and skip remote nuke
|
||||
# unless it is '127.0.0.1' or we can determine it is local.
|
||||
|
||||
is_ip = ip_pattern.match(node_ip) or node_ip == "localhost"
|
||||
|
||||
if not is_ip:
|
||||
# Maybe it's a Hex ID. We can't SSH to a Hex ID.
|
||||
print(f"Skipping cache clear on '{node_ip}' (Not an IP address).")
|
||||
continue
|
||||
|
||||
is_local = (node_ip == local_ip) or (node_ip == "127.0.0.1")
|
||||
nuke_vllm_cache_on_node(node_ip, is_local)
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user