Files
Kapil S. Pawar acb0d614a5 Functional Tests for Ext-Profiler Plugin (#2007)
* Add functional tests for CSV Tuner Plugin

* Updated directory structure

* Updated and renamed directories

* Updated csv conf files

* Added tests for ext-profiler

* Updated readme

* Updated readme

[ROCm/rccl commit: c7f400dbff]
2025-11-18 11:20:39 -06:00

224 строки
8.1 KiB
Python

# *************************************************************************
# * Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
# *
# * See LICENSE.txt for license information
# ************************************************************************
import os
import pytest
import subprocess
import re
from types import SimpleNamespace
import json
import glob
# Directory pytest was launched from; asset and output paths hang off it.
WORKDIR = os.getcwd()


def _in_workdir(relative_path):
    """Return *relative_path* joined onto WORKDIR."""
    return os.path.join(WORKDIR, relative_path)


# Installation roots (the values look like placeholders to be edited locally).
RCCL_INSTALL_DIR = "path/to/rccl"
OMPI_INSTALL_DIR = "path/to/ompi/install"
RCCL_TESTS_DIR = "path/to/rccl-tests"

# Ext-Tuner example plugin: directory and shared library
PLUGIN_DIR = f"{RCCL_INSTALL_DIR}/ext-tuner/example"
PLUGIN_SO = f"{PLUGIN_DIR}/libnccl-tuner-example.so"

# Ext-Profiler example plugin: directory and shared library
PROFILER_DIR = f"{RCCL_INSTALL_DIR}/ext-profiler/example"
PROFILER_SO = f"{PROFILER_DIR}/librccl-profiler.so"

# CSV configuration files used by the tuner tests
VALID_CONFIG_WITH_WILDCARDS = _in_workdir("assets/csv_confs/valid_config_with_wildcards.conf")
VALID_CONFIG_WITHOUT_WILDCARDS = _in_workdir("assets/csv_confs/valid_config_without_wildcards.conf")
NO_MATCHING_CONFIG = _in_workdir("assets/csv_confs/no_matching_config.conf")
INCORRECT_VALUES_CONFIG = _in_workdir("assets/csv_confs/incorrect_values_config.conf")
UNSUPPORTED_ALGO_PROTO_CONFIG = _in_workdir("assets/csv_confs/unsupported_algo_proto_config.conf")
SINGLENODE_CONFIG = _in_workdir("assets/csv_confs/singlenode_config.conf")
MULTINODE_CONFIG = _in_workdir("assets/csv_confs/multinode_config.conf")

# Output directories; both are created at import time if they are missing.
LOGDIR = _in_workdir("logs")
os.makedirs(LOGDIR, exist_ok=True)
PROFILER_DUMP_DIR = _in_workdir("profiler_dumps")
os.makedirs(PROFILER_DUMP_DIR, exist_ok=True)
# Helper Functions
def get_avg_bus_bandwidth(log_content: str):
    """Extract the average bus bandwidth from an RCCL test log.

    Searches *log_content* (case-insensitively) for a line of the form
    ``# Avg bus bandwidth : <number>`` and returns the number.

    Args:
        log_content: Full text of the test output.

    Returns:
        The bandwidth as a ``float``, or ``None`` when no such line with a
        parseable number is present.
    """
    # The number is matched as a real decimal literal with an optional
    # exponent. A loose ``[\d.]+`` capture would cut "1.5e-05" down to "1.5"
    # (values may be printed in scientific notation, e.g. by a printf-style
    # %g) and would hand dot-only captures such as "..." to float().
    pattern = (
        r'#\s*Avg bus bandwidth\s*:\s*'
        r'((?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)'
    )
    match = re.search(pattern, log_content, re.IGNORECASE)
    return float(match.group(1)) if match else None
def check_node_interface(node: str, interface: str) -> bool:
    """Check whether *node* has *interface* configured with an IP address.

    Runs ``ip addr show <interface>`` on the node over ssh and counts the
    output lines containing ``'inet '``.

    Args:
        node: Host name passed to ssh.
        interface: Network interface name to query on that host.

    Returns:
        ``True`` when the remote command succeeds and reports at least one
        matching line. ``False`` in every other case, including an
        unreachable host, a timeout, unparsable output, or an ssh client that
        cannot be launched locally.
    """
    # The pipeline is executed by the remote shell and prints a single count.
    remote_cmd = f"ip addr show {interface} | grep 'inet ' | wc -l"
    cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no",
           node, remote_cmd]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
        return result.returncode == 0 and int(result.stdout.strip()) > 0
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError covers a missing or non-executable ssh binary, so this
        # best-effort probe reports "no" instead of aborting the test run.
        # SubprocessError covers the timeout; ValueError covers non-numeric
        # output.
        return False
def find_common_interface(nodelist):
    """Return the first candidate interface that every node reports as usable.

    The candidates ``eth0`` and ``eth1`` are tried in that order. For each
    one the nodes are probed in list order, stopping at the first node that
    fails the check.

    Args:
        nodelist: Host names to probe.

    Returns:
        The name of the first interface accepted by all nodes, or ``None``
        when no candidate is accepted by every node.
    """
    candidates = ("eth0", "eth1")
    return next(
        (
            iface
            for iface in candidates
            if all(check_node_interface(host, iface) for host in nodelist)
        ),
        None,
    )
def get_available_nodes():
    """Return the host names printed by ``scontrol show hostnames``.

    Returns:
        A list of non-empty, whitespace-stripped host names, one per output
        line. An empty list when ``scontrol`` is not installed or exits with
        a non-zero status.
    """
    try:
        completed = subprocess.run(
            ["scontrol", "show", "hostnames"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        return []

    hostnames = []
    for raw_line in completed.stdout.strip().split('\n'):
        name = raw_line.strip()
        if name:
            hostnames.append(name)
    return hostnames
def validate_json_trace(trace_file):
    """Check that *trace_file* holds a JSON array shaped like a chrome trace.

    Args:
        trace_file: Path of the trace file to inspect.

    Returns:
        A ``(ok, message)`` tuple. ``ok`` is ``False`` when the file is
        missing, is not valid JSON, is not a JSON array, or has more than one
        entry yet none of them is an object carrying both ``name`` and
        ``ph``. Otherwise ``ok`` is ``True`` and the message reports the
        number of entries.
    """
    if not os.path.exists(trace_file):
        return False, "File does not exist"
    try:
        with open(trace_file, 'r') as handle:
            entries = json.load(handle)

        if not isinstance(entries, list):
            return False, "Trace must be a JSON array"

        # Arrays with zero or one entry are accepted without inspection
        # (a lone entry may be just the closing empty object).
        if len(entries) > 1 and not any(
            isinstance(entry, dict) and 'name' in entry and 'ph' in entry
            for entry in entries
        ):
            return False, "No valid trace events found"

        return True, f"Valid JSON trace with {len(entries)} entries"
    except json.JSONDecodeError as err:
        return False, f"Invalid JSON: {err}"
    except Exception as err:
        return False, f"Error reading file: {err}"
def check_event_in_log(log_file, event_string):
    """Report whether *event_string* occurs anywhere in *log_file*.

    Args:
        log_file: Path of the log file to read.
        event_string: Substring to look for in the file's text.

    Returns:
        ``True`` when the file exists and contains the substring, ``False``
        when the file is missing or the substring is absent.
    """
    if os.path.exists(log_file):
        with open(log_file, 'r') as handle:
            return event_string in handle.read()
    return False
def count_events_in_trace(trace_file, event_name=None, category=None):
    """Count events in a JSON trace file, optionally filtered.

    Only array entries that are JSON objects with a ``name`` key are counted.

    Args:
        trace_file: Path of the trace file.
        event_name: When truthy, count only events whose ``name`` equals it.
        category: When truthy, count only events whose ``cat`` equals it.

    Returns:
        The number of matching events. ``0`` when the file is missing,
        unreadable, not valid JSON, or does not hold a JSON array.
    """
    if not os.path.exists(trace_file):
        return 0
    try:
        with open(trace_file, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError, RecursionError):
        # Only I/O and parse failures are treated as "no events". Unlike a
        # bare ``except``, this lets KeyboardInterrupt and SystemExit
        # propagate. ValueError covers json.JSONDecodeError and
        # UnicodeDecodeError.
        return 0

    if not isinstance(data, list):
        # A non-array document holds no countable events.
        return 0

    return sum(
        1
        for event in data
        if isinstance(event, dict)
        and 'name' in event
        and (not event_name or event.get('name') == event_name)
        and (not category or event.get('cat') == category)
    )
# Pytest Fixture
@pytest.fixture(scope="session")
def paths():
    """Session fixture bundling this module's paths and helper functions.

    Returns:
        A ``SimpleNamespace`` whose attributes mirror the module-level
        constants (install and plugin paths, CSV configs, output
        directories) and the helper functions, under their own names.
    """
    locations = {
        "WORKDIR": WORKDIR,
        "RCCL_INSTALL_DIR": RCCL_INSTALL_DIR,
        "OMPI_INSTALL_DIR": OMPI_INSTALL_DIR,
        "PLUGIN_DIR": PLUGIN_DIR,
        "PLUGIN_SO": PLUGIN_SO,
        "RCCL_TESTS_DIR": RCCL_TESTS_DIR,
        "PROFILER_DIR": PROFILER_DIR,
        "PROFILER_SO": PROFILER_SO,
    }
    csv_configs = {
        "VALID_CONFIG_WITH_WILDCARDS": VALID_CONFIG_WITH_WILDCARDS,
        "VALID_CONFIG_WITHOUT_WILDCARDS": VALID_CONFIG_WITHOUT_WILDCARDS,
        "NO_MATCHING_CONFIG": NO_MATCHING_CONFIG,
        "INCORRECT_VALUES_CONFIG": INCORRECT_VALUES_CONFIG,
        "UNSUPPORTED_ALGO_PROTO_CONFIG": UNSUPPORTED_ALGO_PROTO_CONFIG,
        "SINGLENODE_CONFIG": SINGLENODE_CONFIG,
        "MULTINODE_CONFIG": MULTINODE_CONFIG,
        "LOGDIR": LOGDIR,
        "PROFILER_DUMP_DIR": PROFILER_DUMP_DIR,
    }
    # Helpers for the Ext-Tuner tests
    tuner_helpers = {
        "get_avg_bus_bandwidth": get_avg_bus_bandwidth,
        "check_node_interface": check_node_interface,
        "find_common_interface": find_common_interface,
        "get_available_nodes": get_available_nodes,
    }
    # Helpers for the Ext-Profiler tests
    profiler_helpers = {
        "validate_json_trace": validate_json_trace,
        "check_event_in_log": check_event_in_log,
        "count_events_in_trace": count_events_in_trace,
    }
    return SimpleNamespace(
        **locations, **csv_configs, **tuner_helpers, **profiler_helpers
    )
def pytest_runtest_setup(item):
    """Skip a marked test when the plugin library it needs is missing.

    Tests marked ``ext_tuner`` are skipped if ``PLUGIN_SO`` does not exist,
    and tests marked ``ext_profiler`` are skipped if ``PROFILER_SO`` does not
    exist. The tuner check runs first.

    Args:
        item: The test item about to run; it is queried through
            ``get_closest_marker``.
    """
    # Each library path is only looked at when its marker is present.
    wants_tuner = item.get_closest_marker("ext_tuner")
    if wants_tuner and not os.path.exists(PLUGIN_SO):
        pytest.skip(f"Tuner plugin library not found at: {PLUGIN_SO}")

    wants_profiler = item.get_closest_marker("ext_profiler")
    if wants_profiler and not os.path.exists(PROFILER_SO):
        pytest.skip(f"Profiler plugin library not found at: {PROFILER_SO}")
@pytest.fixture(scope="session", autouse=True)
def clear_profiler_dump(request):
    """Empty the profiler dump folder once per session when it is needed.

    Nothing is touched unless at least one test in the session carries the
    ``ext_profiler`` marker. In that case every ``*.json`` file under
    ``PROFILER_DUMP_DIR``, subdirectories included, is deleted.

    Args:
        request: The pytest fixture request; its ``session.items`` are
            scanned for the marker.
    """
    for collected in request.session.items:
        if collected.get_closest_marker("ext_profiler"):
            break
    else:
        # No profiler test in the session: leave existing dumps alone.
        return

    stale_traces = glob.glob(
        os.path.join(PROFILER_DUMP_DIR, "**", "*.json"), recursive=True
    )
    for stale in stale_traces:
        try:
            os.remove(stale)
        except OSError:
            # Best effort: the file may already be gone or not removable.
            continue