acb0d614a5
* Add functional tests for CSV Tuner Plugin
* Updated directory structure
* Updated and renamed directories
* Updated csv conf files
* Added tests for ext-profiler
* Updated readme
* Updated readme
[ROCm/rccl commit: c7f400dbff]
224 строки
8.1 KiB
Python
224 строки
8.1 KiB
Python
# *************************************************************************
|
|
# * Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
|
# *
|
|
# * See LICENSE.txt for license information
|
|
# ************************************************************************
|
|
|
|
import os
|
|
import pytest
|
|
import subprocess
|
|
import re
|
|
from types import SimpleNamespace
|
|
import json
|
|
import glob
|
|
|
|
# Current working directory at import time; asset and output paths hang off it
WORKDIR = os.getcwd()

# Install locations (presumably placeholders to edit for the local setup)
RCCL_INSTALL_DIR = "path/to/rccl"
OMPI_INSTALL_DIR = "path/to/ompi/install"
RCCL_TESTS_DIR = "path/to/rccl-tests"

# Ext-Tuner Paths
PLUGIN_DIR = RCCL_INSTALL_DIR + "/ext-tuner/example"
PLUGIN_SO = PLUGIN_DIR + "/libnccl-tuner-example.so"

# Ext-Profiler Paths
PROFILER_DIR = RCCL_INSTALL_DIR + "/ext-profiler/example"
PROFILER_SO = PROFILER_DIR + "/librccl-profiler.so"


def _csv_conf(filename):
    """Return the path of a CSV tuner config under WORKDIR/assets/csv_confs."""
    return os.path.join(WORKDIR, f"assets/csv_confs/{filename}")


# CSV Configs
VALID_CONFIG_WITH_WILDCARDS = _csv_conf("valid_config_with_wildcards.conf")
VALID_CONFIG_WITHOUT_WILDCARDS = _csv_conf("valid_config_without_wildcards.conf")
NO_MATCHING_CONFIG = _csv_conf("no_matching_config.conf")
INCORRECT_VALUES_CONFIG = _csv_conf("incorrect_values_config.conf")
UNSUPPORTED_ALGO_PROTO_CONFIG = _csv_conf("unsupported_algo_proto_config.conf")
SINGLENODE_CONFIG = _csv_conf("singlenode_config.conf")
MULTINODE_CONFIG = _csv_conf("multinode_config.conf")

# Output directories are created at import time so they exist before use
LOGDIR = os.path.join(WORKDIR, "logs")
os.makedirs(LOGDIR, exist_ok=True)

PROFILER_DUMP_DIR = os.path.join(WORKDIR, "profiler_dumps")
os.makedirs(PROFILER_DUMP_DIR, exist_ok=True)
|
|
|
|
# Helper Functions
|
|
def get_avg_bus_bandwidth(log_content: str):
    """Return the average bus bandwidth reported in an RCCL test log.

    Searches ``log_content`` case-insensitively for text of the form
    ``# Avg bus bandwidth : <number>`` and converts the number to a float.

    Returns:
        The bandwidth as a float, or None when no such text is present.
    """
    found = re.search(r'#\s*Avg bus bandwidth\s*:\s*([\d.]+)', log_content, re.IGNORECASE)
    if found is None:
        return None
    return float(found.group(1))
|
|
|
|
def check_node_interface(node: str, interface: str) -> bool:
    """Check if a node has the specified interface with an IP address.

    Runs ``ip addr show <interface> | grep 'inet ' | wc -l`` on ``node`` over
    ssh and inspects the resulting count.

    Args:
        node: Host name handed to ssh.
        interface: Name of the network interface to query on that host.

    Returns:
        True when the ssh command exits with status 0 and reports at least one
        ``inet `` line. False otherwise, including when ssh cannot be started,
        the command exceeds the 10 second timeout, or the output is not an
        integer.
    """
    remote_cmd = f"ip addr show {interface} | grep 'inet ' | wc -l"
    cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no",
           node, remote_cmd]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
        return result.returncode == 0 and int(result.stdout.strip()) > 0
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError: ssh binary missing or not executable.
        # SubprocessError: covers TimeoutExpired raised by the timeout above.
        # ValueError: stdout was empty or otherwise not an integer.
        return False
|
|
|
|
def find_common_interface(nodelist, interfaces=None):
    """Find a common network interface across all nodes.

    Candidates are tried in order and the first one that
    ``check_node_interface`` accepts on every node is returned.

    Args:
        nodelist: Host names to probe. An empty list vacuously satisfies the
            check, so the first candidate is returned.
        interfaces: Optional iterable of interface names to try, in priority
            order. Defaults to ``("eth0", "eth1")``.

    Returns:
        The first interface present on every node, or None when no candidate
        qualifies.
    """
    candidates = ("eth0", "eth1") if interfaces is None else interfaces
    for interface in candidates:
        # all() stops probing at the first node that lacks the interface
        if all(check_node_interface(node, interface) for node in nodelist):
            return interface
    return None
|
|
|
|
def get_available_nodes():
    """Get available nodes from the SLURM environment.

    Runs ``scontrol show hostnames`` and returns one entry per non-blank
    output line, with surrounding whitespace removed.

    Returns:
        A list of host names, or an empty list when ``scontrol`` is not
        installed or exits with a non-zero status.
    """
    try:
        proc = subprocess.run(
            ["scontrol", "show", "hostnames"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        return []

    hostnames = []
    for raw_line in proc.stdout.strip().split('\n'):
        hostname = raw_line.strip()
        if hostname:
            hostnames.append(hostname)
    return hostnames
|
|
|
|
def validate_json_trace(trace_file):
    """Validate that a trace file is valid JSON in chrome trace format.

    The file must parse as a JSON array. When the array holds more than one
    entry, at least one entry must be an object with both ``name`` and ``ph``
    keys. Arrays with zero or one entry are accepted without inspection.

    Returns:
        A ``(ok, message)`` tuple; ``message`` describes the outcome.
    """
    if not os.path.exists(trace_file):
        return False, "File does not exist"

    try:
        with open(trace_file, 'r') as handle:
            entries = json.load(handle)

        if not isinstance(entries, list):
            return False, "Trace must be a JSON array"

        # A lone entry may be just the closing empty object, so only longer
        # traces are required to contain a recognisable event.
        if len(entries) > 1:
            has_valid_event = any(
                isinstance(entry, dict) and 'name' in entry and 'ph' in entry
                for entry in entries
            )
            if not has_valid_event:
                return False, "No valid trace events found"

        return True, f"Valid JSON trace with {len(entries)} entries"
    except json.JSONDecodeError as err:
        return False, f"Invalid JSON: {err}"
    except Exception as err:
        return False, f"Error reading file: {err}"
|
|
|
|
def check_event_in_log(log_file, event_string):
    """Report whether ``event_string`` occurs anywhere in ``log_file``.

    The whole file is read at once, so the string may span several lines.

    Returns:
        True when the text is found; False when it is absent or the file does
        not exist.
    """
    if os.path.exists(log_file):
        with open(log_file, 'r') as handle:
            return event_string in handle.read()
    return False
|
|
|
|
def count_events_in_trace(trace_file, event_name=None, category=None):
    """Count events in a trace file, optionally filtered by name or category.

    Only JSON objects that carry a ``name`` key are counted as events.

    Args:
        trace_file: Path to a JSON trace file.
        event_name: When truthy, only events whose ``name`` equals this value
            are counted.
        category: When truthy, only events whose ``cat`` equals this value are
            counted.

    Returns:
        The number of matching events. 0 is returned when the file does not
        exist, cannot be read, is not valid JSON, or its top-level value is
        not iterable.
    """
    if not os.path.exists(trace_file):
        return 0

    try:
        with open(trace_file, 'r') as f:
            data = json.load(f)

        return sum(
            1
            for event in data
            if isinstance(event, dict)
            and 'name' in event
            and (not event_name or event.get('name') == event_name)
            and (not category or event.get('cat') == category)
        )
    except (OSError, ValueError, TypeError, RecursionError):
        # OSError: unreadable file. ValueError: malformed JSON or bad encoding.
        # TypeError: top-level JSON value is a scalar or null.
        # RecursionError: pathologically nested JSON.
        # Deliberately narrower than a bare except so that KeyboardInterrupt
        # and SystemExit still propagate.
        return 0
|
|
|
|
# Pytest Fixture
|
|
@pytest.fixture(scope="session")
def paths():
    """Session-scoped fixture bundling this module's paths and helpers.

    Returns a ``SimpleNamespace`` whose attributes mirror the module-level
    constants and helper functions under their own names, so tests can reach
    them through a single fixture argument.
    """
    locations = {
        # Paths
        "WORKDIR": WORKDIR,
        "RCCL_INSTALL_DIR": RCCL_INSTALL_DIR,
        "OMPI_INSTALL_DIR": OMPI_INSTALL_DIR,
        "PLUGIN_DIR": PLUGIN_DIR,
        "PLUGIN_SO": PLUGIN_SO,
        "RCCL_TESTS_DIR": RCCL_TESTS_DIR,
        "PROFILER_DIR": PROFILER_DIR,
        "PROFILER_SO": PROFILER_SO,
        # CSV Configs
        "VALID_CONFIG_WITH_WILDCARDS": VALID_CONFIG_WITH_WILDCARDS,
        "VALID_CONFIG_WITHOUT_WILDCARDS": VALID_CONFIG_WITHOUT_WILDCARDS,
        "NO_MATCHING_CONFIG": NO_MATCHING_CONFIG,
        "INCORRECT_VALUES_CONFIG": INCORRECT_VALUES_CONFIG,
        "UNSUPPORTED_ALGO_PROTO_CONFIG": UNSUPPORTED_ALGO_PROTO_CONFIG,
        "SINGLENODE_CONFIG": SINGLENODE_CONFIG,
        "MULTINODE_CONFIG": MULTINODE_CONFIG,
        "LOGDIR": LOGDIR,
        "PROFILER_DUMP_DIR": PROFILER_DUMP_DIR,
    }
    helpers = {
        # Helper Functions for Ext-Tuner
        "get_avg_bus_bandwidth": get_avg_bus_bandwidth,
        "check_node_interface": check_node_interface,
        "find_common_interface": find_common_interface,
        "get_available_nodes": get_available_nodes,
        # Helper Functions for Ext-Profiler
        "validate_json_trace": validate_json_trace,
        "check_event_in_log": check_event_in_log,
        "count_events_in_trace": count_events_in_trace,
    }
    return SimpleNamespace(**locations, **helpers)
|
|
|
|
def pytest_runtest_setup(item):
    """Skip a test when the plugin library it is marked for is missing.

    Implements pytest's ``pytest_runtest_setup`` hook. Tests marked
    ``ext_tuner`` need ``PLUGIN_SO`` on disk and tests marked ``ext_profiler``
    need ``PROFILER_SO``; otherwise the test is skipped with a message naming
    the missing path.
    """
    # Tuner plugin check
    tuner_unavailable = item.get_closest_marker("ext_tuner") and not os.path.exists(PLUGIN_SO)
    if tuner_unavailable:
        pytest.skip(f"Tuner plugin library not found at: {PLUGIN_SO}")

    # Profiler plugin check
    profiler_unavailable = item.get_closest_marker("ext_profiler") and not os.path.exists(PROFILER_SO)
    if profiler_unavailable:
        pytest.skip(f"Profiler plugin library not found at: {PROFILER_SO}")
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def clear_profiler_dump(request):
    """Clear old profiler traces once per session when profiler tests exist.

    If any item in the session carries the ``ext_profiler`` marker, every
    ``*.json`` file under ``PROFILER_DUMP_DIR`` (searched recursively) is
    deleted. Sessions without such tests leave the directory untouched.
    """
    profiler_tests_present = any(
        test_item.get_closest_marker("ext_profiler")
        for test_item in request.session.items
    )
    if not profiler_tests_present:
        return

    # "**" with recursive=True also matches files directly in the dump dir
    json_glob = os.path.join(PROFILER_DUMP_DIR, "**", "*.json")
    for stale_trace in glob.glob(json_glob, recursive=True):
        try:
            os.remove(stale_trace)
        except OSError:
            # Best effort: the file may already be gone or not be removable
            pass
|