Enable MPI support to execute MPI specific unit/functional tests (#1996)
* Added MPI support to execute unit/functional tests
Update node and process validation
Updated node detection count and modified validation method
Update validation logic to include max procs and nodes
* Address review comments
* Fix warnings
* Added a new NET transport test and clean up
* Added MPI test logging mechanism
* Decoupled GTest framework
* Added Net IB functional tests
* Updated with resource guards
* Added NET IB tests and refactored code
* Update P2pWorkflow test
* Update documentation
* Add MPI_TESTS_ENABLED guard to the file
* Fix Shm and NetIB tests
* Applied refactoring and cleanup
* Replaced BufferGuard with AutoGuard
* Modified test debug logging
* Use macro to reduce NcclTypeTraits code duplication
- Replace repetitive template specializations with a single
DEFINE_NCCL_TYPE_TRAIT macro
- Use stringification operator (#) to auto-generate type name strings
- Add #undef to keep macro from polluting namespace
- Makes adding new type mappings trivial
* Unify buffer initialization with generic pattern function
- Remove initializeBufferWithCustomPattern
- Make initializeBufferWithPattern generic with PatternFunc template param
- Now single function handles all patterns via lambda injection
- Updated all test files to use lambdas for pattern generation
- Pattern logic now visible at call site (self-documenting)
* Unify buffer verification with pluggable pattern function
- Remove verifyBufferWithCustomCheck
- Make verifyBufferData generic with PatternFunc template param
- Single function handles all verification patterns via lambda injection
- Updated all test files to use lambdas
- Better defaults: num_samples=0 means verify all elements
- Pattern logic now visible at call site (self-documenting)
* Docs: Add DeviceBufferHelpers section to MPITestRunner.md
- Document new refactored buffer initialization/verification API
- Explain pluggable pattern functions with lambda examples
- Show type mapping and automatic float/int comparison
- Include migration guide from old API to new unified functions
- Demonstrate best practices with real-world examples
- Reference recent refactoring commits (macro-based type traits)
* Docs: Update documentation and examples
- Update on DeviceBufferHelpers
- Update examples using DeviceBufferHelpers methods, e.g. data verification
* Address review comment.
- Replace manual pattern generation loop with initializeBufferWithPattern call
- Use downloadBuffer to get host copy instead of manual hipMemcpy
* Remove non-existent dependency
* Remove duplicate testcase
* Code cleanup in test files
* Moved common constants to base class
[ROCm/rccl commit: 29e1567b95]
This commit is contained in:
@@ -6,8 +6,12 @@ cmake_minimum_required(VERSION 3.16)
|
||||
if(BUILD_TESTS)
|
||||
|
||||
option(OPENMP_TESTS_ENABLED "Enable OpenMP for unit tests" OFF)
|
||||
option(ENABLE_MPI_TESTS "Enable MPI-based tests" OFF)
|
||||
|
||||
message("Building rccl unit tests (Installed in /test/rccl-UnitTests)")
|
||||
if(ENABLE_MPI_TESTS)
|
||||
message("MPI-based tests are enabled")
|
||||
endif()
|
||||
|
||||
if (ENABLE_CODE_COVERAGE)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-instr-generate -fcoverage-mapping")
|
||||
@@ -34,6 +38,48 @@ if(BUILD_TESTS)
|
||||
find_package(OpenMP REQUIRED)
|
||||
endif()
|
||||
|
||||
# MPI configuration
|
||||
if(ENABLE_MPI_TESTS)
|
||||
# Set default MPI path, allow user to override
|
||||
if(NOT DEFINED MPI_PATH)
|
||||
set(MPI_PATH "/opt/ompi" CACHE PATH "Path to MPI installation")
|
||||
endif()
|
||||
|
||||
# Verify MPI path exists
|
||||
if(NOT EXISTS ${MPI_PATH})
|
||||
message(WARNING "MPI_PATH does not exist: ${MPI_PATH}")
|
||||
message(WARNING "Please set MPI_PATH to your MPI installation directory")
|
||||
message(FATAL_ERROR "MPI installation not found")
|
||||
endif()
|
||||
|
||||
message(STATUS "Using MPI installation at: ${MPI_PATH}")
|
||||
|
||||
# Find required MPI library
|
||||
find_library(MPI_LIBRARY
|
||||
NAMES mpi
|
||||
PATHS ${MPI_PATH}/lib ${MPI_PATH}/lib64
|
||||
NO_DEFAULT_PATH
|
||||
REQUIRED
|
||||
)
|
||||
|
||||
if(NOT MPI_LIBRARY)
|
||||
message(FATAL_ERROR "Could not find MPI library (libmpi.so) in ${MPI_PATH}/lib or ${MPI_PATH}/lib64")
|
||||
endif()
|
||||
|
||||
# Set up MPI variables
|
||||
set(MPI_CXX_LIBRARIES ${MPI_LIBRARY})
|
||||
set(MPI_CXX_INCLUDE_DIRS ${MPI_PATH}/include)
|
||||
set(MPI_CXX_LINK_FLAGS "-L${MPI_PATH}/lib -Wl,-rpath,${MPI_PATH}/lib")
|
||||
set(MPIEXEC_EXECUTABLE ${MPI_PATH}/bin/mpirun CACHE FILEPATH "MPI executable")
|
||||
|
||||
# Add link directories for MPI
|
||||
link_directories(${MPI_PATH}/lib)
|
||||
|
||||
message(STATUS "MPI library: ${MPI_CXX_LIBRARIES}")
|
||||
message(STATUS "MPI include: ${MPI_CXX_INCLUDE_DIRS}")
|
||||
message(STATUS "MPI executable: ${MPIEXEC_EXECUTABLE}")
|
||||
endif()
|
||||
|
||||
include_directories(${GTEST_INCLUDE_DIRS} ./common)
|
||||
|
||||
# Common include directories
|
||||
@@ -48,6 +94,11 @@ if(BUILD_TESTS)
|
||||
${ROCM_PATH}
|
||||
)
|
||||
|
||||
# Add MPI include directories if MPI tests are enabled
|
||||
if(ENABLE_MPI_TESTS AND MPI_CXX_INCLUDE_DIRS)
|
||||
list(APPEND RCCL_COMMON_INCLUDE_DIRS ${MPI_CXX_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
||||
# Common compile definitions
|
||||
set(RCCL_COMMON_COMPILE_DEFS ROCM_PATH="${ROCM_PATH}")
|
||||
if(LL128_ENABLED)
|
||||
@@ -56,6 +107,9 @@ if(BUILD_TESTS)
|
||||
if(OPENMP_TESTS_ENABLED)
|
||||
list(APPEND RCCL_COMMON_COMPILE_DEFS ENABLE_OPENMP)
|
||||
endif()
|
||||
if(ENABLE_MPI_TESTS)
|
||||
list(APPEND RCCL_COMMON_COMPILE_DEFS MPI_TESTS_ENABLED)
|
||||
endif()
|
||||
list(APPEND RCCL_COMMON_COMPILE_DEFS __HIP_PLATFORM_AMD__)
|
||||
|
||||
# Common link libraries
|
||||
@@ -69,6 +123,9 @@ if(BUILD_TESTS)
|
||||
if(OPENMP_TESTS_ENABLED)
|
||||
list(APPEND RCCL_COMMON_LINK_LIBS "${OpenMP_CXX_FLAGS}")
|
||||
endif()
|
||||
if(ENABLE_MPI_TESTS AND MPI_CXX_LIBRARIES)
|
||||
list(APPEND RCCL_COMMON_LINK_LIBS ${MPI_CXX_LIBRARIES})
|
||||
endif()
|
||||
|
||||
# Get the compile definitions from the main rccl target
|
||||
# These helps to keep the test compile definitions in sync with the main rccl target
|
||||
@@ -129,7 +186,7 @@ if(BUILD_TESTS)
|
||||
|
||||
# Create rccl-UnitTests binary
|
||||
add_executable(rccl-UnitTests ${TEST_SOURCE_FILES})
|
||||
|
||||
|
||||
# Create rccl-UnitTestsFixtures binary if ROCm version is 4.6.0 or greater
|
||||
# and build type is Debug
|
||||
if (ROCM_VERSION VERSION_GREATER_EQUAL "60400" AND CMAKE_BUILD_TYPE MATCHES "Debug")
|
||||
@@ -154,12 +211,46 @@ if(BUILD_TESTS)
|
||||
)
|
||||
|
||||
add_executable(rccl-UnitTestsFixtures ${TEST_FIXTURE_SOURCE_FILES})
|
||||
|
||||
# Create separate MPI test binary if MPI tests are enabled
|
||||
if(ENABLE_MPI_TESTS)
|
||||
# Define MPI test source files
|
||||
set(MPI_TEST_SOURCE_FILES
|
||||
common/main_mpi.cpp
|
||||
common/MPIHelpers.cpp
|
||||
common/MPITestCore.cpp
|
||||
common/MPIEnvironment.cpp
|
||||
common/TestChecks.cpp
|
||||
transport/TransportMPIBase.cpp
|
||||
transport/P2pMPITests.cpp
|
||||
transport/NetMPITests.cpp
|
||||
transport/ShmMPITests.cpp
|
||||
transport/NetIbMPITests.cpp
|
||||
)
|
||||
|
||||
# Create the MPI test executable
|
||||
add_executable(rccl-UnitTestsMPI ${MPI_TEST_SOURCE_FILES})
|
||||
|
||||
# Add to test executables list for proper linking
|
||||
list(APPEND RCCL_TEST_EXECUTABLES rccl-UnitTestsMPI)
|
||||
|
||||
endif()
|
||||
endif()
|
||||
|
||||
foreach(test_executable IN LISTS RCCL_TEST_EXECUTABLES)
|
||||
target_include_directories(${test_executable} PRIVATE ${RCCL_COMMON_INCLUDE_DIRS})
|
||||
target_compile_definitions(${test_executable} PRIVATE ${RCCL_COMMON_COMPILE_DEFS})
|
||||
target_link_libraries(${test_executable} PRIVATE ${RCCL_COMMON_LINK_LIBS})
|
||||
|
||||
# Add MPI-specific configuration if MPI tests are enabled
|
||||
if(ENABLE_MPI_TESTS)
|
||||
if(MPI_CXX_COMPILE_FLAGS)
|
||||
target_compile_options(${test_executable} PRIVATE ${MPI_CXX_COMPILE_FLAGS})
|
||||
endif()
|
||||
if(MPI_CXX_LINK_FLAGS)
|
||||
set_target_properties(${test_executable} PROPERTIES LINK_FLAGS "${MPI_CXX_LINK_FLAGS}")
|
||||
endif()
|
||||
endif()
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_libraries(${test_executable} PRIVATE rccl)
|
||||
if(${HOST_OS_ID} STREQUAL "debian")
|
||||
@@ -176,4 +267,3 @@ if(BUILD_TESTS)
|
||||
endforeach()
|
||||
|
||||
endif()
|
||||
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# RCCL Test Suite
|
||||
|
||||
Testing infrastructure for ROCm Communication Collectives Library (RCCL).
|
||||
|
||||
## Table of Contents
|
||||
- [Overview](#overview)
|
||||
- [Testing Frameworks](#testing-frameworks)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
The RCCL test suite provides following frameworks along with the existing rccl-UnitTests TestBed framework:
|
||||
|
||||
## Testing Frameworks
|
||||
|
||||
Following are two new complementary testing frameworks for different testing needs:
|
||||
|
||||
### 1. Process Isolated Test Runner
|
||||
Run tests in isolated processes with clean environment settings.
|
||||
|
||||
📄 **[Full Documentation](common/ProcessIsolatedTestRunner.md)**
|
||||
|
||||
### 2. MPI Test Runner
|
||||
Base class for multi-process distributed tests using MPI.
|
||||
|
||||
📄 **[Full Documentation](common/MPITestRunner.md)**
|
||||
|
||||
@@ -0,0 +1,381 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "nccl.h"
|
||||
#include <cmath>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
/**
|
||||
* @file DeviceBufferHelpers.hpp
|
||||
* @brief Template-based device buffer utilities for RCCL tests
|
||||
*
|
||||
* Provides type-safe, reusable functions for device buffer operations:
|
||||
* - Initialization with test patterns (Host -> Device)
|
||||
* - Host <-> Device transfers
|
||||
* - Data verification (Device -> Host)
|
||||
* - NCCL datatype mapping
|
||||
*
|
||||
* NOTE: All functions expect DEVICE memory pointers allocated with hipMalloc().
|
||||
* For host memory operations, use direct CPU operations instead.
|
||||
*/
|
||||
|
||||
namespace RCCLTestHelpers
|
||||
{
|
||||
|
||||
// ============================================================================
|
||||
// NCCL Datatype Mapping
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* @brief Maps C++ types to NCCL data types at compile time
|
||||
* @tparam T C++ data type
|
||||
*/
|
||||
template<typename T>
|
||||
struct NcclTypeTraits;
|
||||
|
||||
/**
|
||||
* @brief Macro to define NcclTypeTraits specializations
|
||||
*
|
||||
* ncclDataType_t mapping and the string name using the stringification
|
||||
* operator (#) for each supported type.
|
||||
*
|
||||
* @param cpp_type The C++ type (e.g., uint64_t, float)
|
||||
* @param nccl_type The corresponding NCCL type (e.g., ncclUint64, ncclFloat)
|
||||
*/
|
||||
#define DEFINE_NCCL_TYPE_TRAIT(cpp_type, nccl_type) \
|
||||
template<> \
|
||||
struct NcclTypeTraits<cpp_type> \
|
||||
{ \
|
||||
static constexpr ncclDataType_t value = nccl_type; \
|
||||
static constexpr const char* name = #cpp_type; \
|
||||
}
|
||||
|
||||
// Define all supported type mappings
|
||||
DEFINE_NCCL_TYPE_TRAIT(float, ncclFloat);
|
||||
DEFINE_NCCL_TYPE_TRAIT(double, ncclDouble);
|
||||
DEFINE_NCCL_TYPE_TRAIT(int8_t, ncclInt8);
|
||||
DEFINE_NCCL_TYPE_TRAIT(uint8_t, ncclUint8);
|
||||
DEFINE_NCCL_TYPE_TRAIT(int32_t, ncclInt32);
|
||||
DEFINE_NCCL_TYPE_TRAIT(uint32_t, ncclUint32);
|
||||
DEFINE_NCCL_TYPE_TRAIT(int64_t, ncclInt64);
|
||||
DEFINE_NCCL_TYPE_TRAIT(uint64_t, ncclUint64);
|
||||
|
||||
// Undefine macro to avoid polluting namespace
|
||||
#undef DEFINE_NCCL_TYPE_TRAIT
|
||||
|
||||
/**
|
||||
* @brief Helper function to get NCCL datatype for a C++ type
|
||||
* @tparam T C++ data type
|
||||
* @return Corresponding ncclDataType_t
|
||||
*/
|
||||
template<typename T>
|
||||
constexpr ncclDataType_t getNcclDataType()
|
||||
{
|
||||
return NcclTypeTraits<T>::value;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to get type name string
|
||||
* @tparam T C++ data type
|
||||
* @return Type name as string
|
||||
*/
|
||||
template<typename T>
|
||||
constexpr const char* getTypeName()
|
||||
{
|
||||
return NcclTypeTraits<T>::name;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Device Buffer Initialization
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* @brief Initialize device buffer with pattern function
|
||||
*
|
||||
* Generic function that allows any pattern generation via lambda or function pointer.
|
||||
*
|
||||
* Example usage:
|
||||
* @code
|
||||
* // Rank-based pattern: rank * multiplier + index
|
||||
* initializeBufferWithPattern<float>(buffer, size,
|
||||
* [rank, multiplier](size_t i) { return rank * multiplier + i; });
|
||||
*
|
||||
* // Constant value pattern
|
||||
* initializeBufferWithPattern<int>(buffer, size,
|
||||
* [](size_t i) { return 42; });
|
||||
*
|
||||
* // Custom pattern
|
||||
* initializeBufferWithPattern<double>(buffer, size,
|
||||
* [](size_t i) { return std::sin(i * 0.1); });
|
||||
* @endcode
|
||||
*
|
||||
* @tparam T Element type (float, int, etc.)
|
||||
* @tparam PatternFunc Callable type (lambda, function pointer, functor)
|
||||
* @param device_buffer Device memory pointer (from hipMalloc)
|
||||
* @param num_elements Number of elements
|
||||
* @param pattern_func Function that generates value for each index: T pattern_func(size_t index)
|
||||
* @return hipError_t from hipMemcpy, or hipSuccess
|
||||
*/
|
||||
template<typename T, typename PatternFunc>
|
||||
hipError_t initializeBufferWithPattern(void* device_buffer,
|
||||
size_t num_elements,
|
||||
PatternFunc pattern_func)
|
||||
{
|
||||
if(!device_buffer || num_elements == 0)
|
||||
{
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
std::vector<T> host_data(num_elements);
|
||||
for(size_t i = 0; i < num_elements; i++)
|
||||
{
|
||||
host_data[i] = pattern_func(i);
|
||||
}
|
||||
|
||||
return hipMemcpy(device_buffer,
|
||||
host_data.data(),
|
||||
num_elements * sizeof(T),
|
||||
hipMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Zero-initialize device buffer
|
||||
*
|
||||
* @tparam T Element type
|
||||
* @param device_buffer Device memory pointer (from hipMalloc)
|
||||
* @param num_elements Number of elements
|
||||
* @return hipError_t from hipMemset
|
||||
*/
|
||||
template<typename T>
|
||||
hipError_t zeroInitializeBuffer(void* device_buffer, size_t num_elements)
|
||||
{
|
||||
if(!device_buffer || num_elements == 0)
|
||||
{
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
return hipMemset(device_buffer, 0, num_elements * sizeof(T));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Device Buffer Verification
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* @brief Verify device buffer data with pattern function
|
||||
*
|
||||
* Generic function that allows any verification pattern via lambda or function pointer.
|
||||
* Downloads data from device and verifies elements against expected values.
|
||||
* Uses appropriate comparison for floating-point vs integer types.
|
||||
*
|
||||
* Example usage:
|
||||
* @code
|
||||
* // Rank-based pattern verification: rank * multiplier + index
|
||||
* verifyBufferData<float>(buffer, size,
|
||||
* [rank, multiplier](size_t i) { return rank * multiplier + i; },
|
||||
* num_samples, tolerance);
|
||||
*
|
||||
* // Constant value verification
|
||||
* verifyBufferData<int>(buffer, size,
|
||||
* [](size_t i) { return 42; });
|
||||
*
|
||||
* // Custom pattern verification
|
||||
* verifyBufferData<double>(buffer, size,
|
||||
* [](size_t i) { return std::sin(i * 0.1); },
|
||||
* size, 1e-6); // verify all elements with tighter tolerance
|
||||
* @endcode
|
||||
*
|
||||
* @tparam T Element type
|
||||
* @tparam PatternFunc Callable type (lambda, function pointer, functor)
|
||||
* @param device_buffer Device memory pointer (from hipMalloc)
|
||||
* @param num_elements Total number of elements in buffer
|
||||
* @param pattern_func Function that generates expected value for each index: T pattern_func(size_t index)
|
||||
* @param num_samples Number of elements to verify (default: all, capped at num_elements)
|
||||
* @param tolerance Tolerance for floating-point comparison (default: 1e-5, ignored for integer types)
|
||||
* @param[out] first_error_index If verification fails, set to index of first mismatch
|
||||
* @param[out] expected_value If verification fails, set to expected value
|
||||
* @param[out] actual_value If verification fails, set to actual value
|
||||
* @return true if all samples match, false otherwise
|
||||
*/
|
||||
template<typename T, typename PatternFunc>
|
||||
bool verifyBufferData(const void* device_buffer,
|
||||
size_t num_elements,
|
||||
PatternFunc pattern_func,
|
||||
size_t num_samples = 0, // 0 means verify all
|
||||
double tolerance = 1e-5,
|
||||
size_t* first_error_index = nullptr,
|
||||
T* expected_value = nullptr,
|
||||
T* actual_value = nullptr)
|
||||
{
|
||||
if(!device_buffer || num_elements == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Default to verifying all elements if num_samples is 0
|
||||
if(num_samples == 0)
|
||||
{
|
||||
num_samples = num_elements;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Cap num_samples at num_elements
|
||||
num_samples = std::min(num_samples, num_elements);
|
||||
}
|
||||
|
||||
// Download data from device
|
||||
std::vector<T> host_data(num_elements);
|
||||
hipError_t err = hipMemcpy(host_data.data(),
|
||||
device_buffer,
|
||||
num_elements * sizeof(T),
|
||||
hipMemcpyDeviceToHost);
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify samples
|
||||
for(size_t i = 0; i < num_samples; i++)
|
||||
{
|
||||
T expected = pattern_func(i);
|
||||
T actual = host_data[i];
|
||||
|
||||
bool matches = false;
|
||||
|
||||
// Use appropriate comparison based on type
|
||||
if constexpr(std::is_floating_point_v<T>)
|
||||
{
|
||||
// Floating-point: use tolerance-based comparison
|
||||
matches = (std::abs(actual - expected) <= tolerance);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Integer: exact comparison
|
||||
matches = (actual == expected);
|
||||
}
|
||||
|
||||
if(!matches)
|
||||
{
|
||||
// Record error details
|
||||
if(first_error_index)
|
||||
*first_error_index = i;
|
||||
if(expected_value)
|
||||
*expected_value = expected;
|
||||
if(actual_value)
|
||||
*actual_value = actual;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Combined Operations
|
||||
// ============================================================================
|
||||
|
||||
// Forward declaration for downloadBuffer (used in allocateAndInitialize)
|
||||
template<typename T>
|
||||
std::pair<hipError_t, std::vector<T>> downloadBuffer(const void* device_buffer, size_t num_elements);
|
||||
|
||||
/**
|
||||
* @brief Allocate, initialize, and return RAII-guarded device buffers
|
||||
*
|
||||
* Convenience function that combines allocation and initialization.
|
||||
* Returns host vector for later verification if needed.
|
||||
*
|
||||
* @tparam T Element type
|
||||
* @param[out] device_buffer Pointer to receive device buffer address
|
||||
* @param num_elements Number of elements
|
||||
* @param rank MPI rank for pattern generation
|
||||
* @param multiplier Pattern multiplier
|
||||
* @return std::pair<hipError_t, std::vector<T>> - error code and host data copy
|
||||
*/
|
||||
template<typename T>
|
||||
std::pair<hipError_t, std::vector<T>> allocateAndInitialize(void** device_buffer,
|
||||
size_t num_elements,
|
||||
int rank,
|
||||
int multiplier = 1000)
|
||||
{
|
||||
if(!device_buffer)
|
||||
{
|
||||
return {hipErrorInvalidValue, {}};
|
||||
}
|
||||
|
||||
// Allocate device memory
|
||||
hipError_t err = hipMalloc(device_buffer, num_elements * sizeof(T));
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
return {err, {}};
|
||||
}
|
||||
|
||||
// Initialize using generic pattern function
|
||||
err = initializeBufferWithPattern<T>(
|
||||
*device_buffer, num_elements,
|
||||
[rank, multiplier](size_t i) { return static_cast<T>(rank * multiplier + i); });
|
||||
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
return {err, {}};
|
||||
}
|
||||
|
||||
// Download and return host copy for verification
|
||||
return downloadBuffer<T>(*device_buffer, num_elements);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Copy data from one device buffer to another
|
||||
*
|
||||
* @tparam T Element type (used for size calculation)
|
||||
* @param dst Destination device buffer (from hipMalloc)
|
||||
* @param src Source device buffer (from hipMalloc)
|
||||
* @param num_elements Number of elements to copy
|
||||
* @return hipError_t from hipMemcpy
|
||||
*/
|
||||
template<typename T>
|
||||
hipError_t copyDeviceBuffer(void* dst, const void* src, size_t num_elements)
|
||||
{
|
||||
if(!dst || !src || num_elements == 0)
|
||||
{
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
return hipMemcpy(dst, src, num_elements * sizeof(T), hipMemcpyDeviceToDevice);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Download device buffer to host vector
|
||||
*
|
||||
* @tparam T Element type
|
||||
* @param device_buffer Device memory pointer (from hipMalloc)
|
||||
* @param num_elements Number of elements
|
||||
* @return std::pair<hipError_t, std::vector<T>> - error code and host data
|
||||
*/
|
||||
template<typename T>
|
||||
std::pair<hipError_t, std::vector<T>> downloadBuffer(const void* device_buffer, size_t num_elements)
|
||||
{
|
||||
std::vector<T> host_data(num_elements);
|
||||
|
||||
if(!device_buffer || num_elements == 0)
|
||||
{
|
||||
return {hipErrorInvalidValue, {}};
|
||||
}
|
||||
|
||||
hipError_t err = hipMemcpy(host_data.data(),
|
||||
device_buffer,
|
||||
num_elements * sizeof(T),
|
||||
hipMemcpyDeviceToHost);
|
||||
|
||||
return {err, std::move(host_data)};
|
||||
}
|
||||
|
||||
} // namespace RCCLTestHelpers
|
||||
|
||||
|
||||
@@ -0,0 +1,361 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/**
|
||||
* @file MPIEnvironment.cpp
|
||||
* @brief Implementation of global MPI environment for RCCL testing
|
||||
*/
|
||||
|
||||
#include "MPIEnvironment.hpp"
|
||||
#include "MPITestBase.hpp"
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
/**
|
||||
* @brief Initialize the global test environment
|
||||
*
|
||||
* Performs one-time setup for the entire test suite:
|
||||
* - Initializes MPI with thread support
|
||||
* - Sets up GPU devices for each rank
|
||||
*
|
||||
* @note Called automatically by Google Test framework before any tests run
|
||||
*/
|
||||
void MPIEnvironment::SetUp()
|
||||
{
|
||||
// One-time initialization (MPI_Init can only be called once)
|
||||
initialize_mpi();
|
||||
initialize_devices();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Initialize MPI with multi-threading support
|
||||
*
|
||||
* Calls MPI_Init_thread() with MPI_THREAD_MULTIPLE to support concurrent
|
||||
* MPI operations. Sets world_rank and world_size for use by all tests.
|
||||
*
|
||||
* Idempotent - safe to call multiple times (uses mpi_initialized flag).
|
||||
* Typically called from main_mpi.cpp, but provides fallback initialization.
|
||||
*/
|
||||
void MPIEnvironment::initialize_mpi()
|
||||
{
|
||||
if(mpi_initialized)
|
||||
{
|
||||
// Already initialized in main_mpi.cpp
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("MPI already initialized - skipping re-initialization");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// This path should not be reached when using main_mpi.cpp
|
||||
// but kept for compatibility with other test mains
|
||||
auto provided = int{};
|
||||
MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);
|
||||
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &world_rank));
|
||||
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size));
|
||||
|
||||
mpi_initialized = true;
|
||||
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("MPI initialized - World size: %d, Thread support: %d", world_size, provided);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Initialize GPU devices and assign one GPU per MPI rank
|
||||
*
|
||||
* Performs comprehensive GPU setup:
|
||||
* 1. Queries number of available GPUs
|
||||
* 2. Validates sufficient GPUs for world_size
|
||||
* 3. Assigns GPU ID = rank (rank-based assignment)
|
||||
* 4. Resets HIP context for clean state
|
||||
* 5. Sets active device
|
||||
* 6. Verifies device assignment
|
||||
* 7. Synchronizes all ranks
|
||||
*
|
||||
* @note Requires at least world_size GPUs
|
||||
* @note Sets retCode=1 on error (insufficient GPUs, assignment failure)
|
||||
* @note Idempotent - safe to call multiple times (uses devices_initialized flag)
|
||||
*/
|
||||
void MPIEnvironment::initialize_devices()
|
||||
{
|
||||
if(devices_initialized)
|
||||
{
|
||||
return; // Already initialized
|
||||
}
|
||||
|
||||
auto numDevices = int{};
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipGetDeviceCount(&numDevices));
|
||||
|
||||
// Calculate local rank (rank within this node) for multi-node support
|
||||
// Split MPI_COMM_WORLD by node using MPI_Comm_split_type
|
||||
MPI_Comm node_comm;
|
||||
MPI_Comm_split_type(MPI_COMM_WORLD,
|
||||
MPI_COMM_TYPE_SHARED,
|
||||
world_rank,
|
||||
MPI_INFO_NULL,
|
||||
&node_comm);
|
||||
|
||||
int local_rank, local_size;
|
||||
MPI_Comm_rank(node_comm, &local_rank);
|
||||
MPI_Comm_size(node_comm, &local_size);
|
||||
|
||||
// Cache multi-node detection result ONCE during initialization
|
||||
// local_size < world_size means we have multiple nodes
|
||||
cached_multi_node_result = (local_size < world_size) ? 1 : 0;
|
||||
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Detected %d GPU(s) for %d MPI rank(s)", numDevices, world_size);
|
||||
TEST_INFO("Local configuration: %d ranks per node", local_size);
|
||||
TEST_INFO("Multi-node configuration: %s",
|
||||
cached_multi_node_result ? "YES (multiple nodes)" : "NO (single node)");
|
||||
}
|
||||
|
||||
// Check if we have enough GPUs for ranks on THIS node
|
||||
if(numDevices < local_size)
|
||||
{
|
||||
TEST_ABORT(
|
||||
"ERROR: (local rank %d): Only %d GPUs available on this node for %d local ranks. "
|
||||
"RCCL requires unique GPUs per rank on each node. "
|
||||
"Please run with fewer ranks per node (e.g., --ntasks-per-node=%d) "
|
||||
"or ensure more GPUs are available.",
|
||||
local_rank,
|
||||
numDevices,
|
||||
local_size,
|
||||
numDevices);
|
||||
retCode = 1;
|
||||
devices_initialized = true;
|
||||
MPI_Comm_free(&node_comm);
|
||||
return;
|
||||
}
|
||||
|
||||
// Use LOCAL rank for device assignment (not global rank)
|
||||
// This ensures ranks 0-7 on each node use GPUs 0-7
|
||||
const auto assigned_device = local_rank;
|
||||
|
||||
// Validate device assignment
|
||||
if(assigned_device < 0 || assigned_device >= numDevices)
|
||||
{
|
||||
TEST_ABORT(
|
||||
"ERROR: (local rank %d): Invalid device assignment! assigned_device=%d, numDevices=%d",
|
||||
local_rank,
|
||||
assigned_device,
|
||||
numDevices);
|
||||
retCode = 1;
|
||||
devices_initialized = true;
|
||||
MPI_Comm_free(&node_comm);
|
||||
return;
|
||||
}
|
||||
|
||||
// Complete HIP context reset and isolation
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipDeviceReset());
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipSetDevice(assigned_device));
|
||||
|
||||
// Force HIP context creation and synchronization
|
||||
auto prop = hipDeviceProp_t{};
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipGetDeviceProperties(&prop, assigned_device));
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipDeviceSynchronize());
|
||||
|
||||
// Verify device assignment
|
||||
auto current_device = int{};
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipGetDevice(¤t_device));
|
||||
if(current_device != assigned_device)
|
||||
{
|
||||
TEST_ABORT("ERROR: (local rank %d) device assignment failed! Expected %d, got %d",
|
||||
local_rank,
|
||||
assigned_device,
|
||||
current_device);
|
||||
retCode = 1;
|
||||
MPI_Comm_free(&node_comm);
|
||||
return;
|
||||
}
|
||||
|
||||
// Print device info (only from rank 0 to reduce output)
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("(local rank %d): Device assignment: global rank %d -> GPU %d",
|
||||
local_rank,
|
||||
world_rank,
|
||||
assigned_device);
|
||||
TEST_INFO("PCI Bus ID = 0x%x, Device Name = %s", prop.pciBusID, prop.name);
|
||||
TEST_INFO("Total GPUs available per node: %d", numDevices);
|
||||
TEST_INFO("Multi-node: Each node's local ranks (0-%d) mapped to GPUs (0-%d)",
|
||||
local_size - 1,
|
||||
numDevices - 1);
|
||||
}
|
||||
|
||||
// Clean up node communicator
|
||||
MPI_Comm_free(&node_comm);
|
||||
|
||||
// Ensure all ranks have set their devices before proceeding
|
||||
MPICHECK(MPI_Barrier(MPI_COMM_WORLD));
|
||||
|
||||
devices_initialized = true;
|
||||
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Device initialization completed");
|
||||
TEST_INFO("Each test will create its own NCCL communicator for isolation");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Tear down the global test environment
|
||||
*
|
||||
* Ensures all ranks have completed their tests before cleanup:
|
||||
* 1. Synchronizes all ranks with MPI_Barrier
|
||||
* 2. Calls cleanup_mpi() to finalize MPI
|
||||
*
|
||||
* @note Critical synchronization point - ensures all test cleanup is complete
|
||||
* @note Called automatically by Google Test framework after all tests complete
|
||||
*/
|
||||
void MPIEnvironment::TearDown()
|
||||
{
|
||||
// CRITICAL: Handle the case where ranks are out of sync due to test failures
|
||||
//
|
||||
// Problem: If rank 0 fails with ASSERT/FAIL, it immediately goes to TearDown()
|
||||
// while rank 1 is still in the test body. This causes deadlock when rank 0
|
||||
// tries to do MPI collectives (like Allreduce) while rank 1 is doing different
|
||||
// MPI collectives (like Bcast in createTestCommunicator).
|
||||
//
|
||||
// Use MPI_Ibarrier (non-blocking) with a timeout to detect if ranks
|
||||
// are out of sync, then force cleanup with MPI_Abort if necessary.
|
||||
|
||||
// Try a non-blocking barrier to check if all ranks are ready
|
||||
MPI_Request barrier_req;
|
||||
int barrier_result = MPI_Ibarrier(MPI_COMM_WORLD, &barrier_req);
|
||||
|
||||
if(barrier_result == MPI_SUCCESS)
|
||||
{
|
||||
// Wait for barrier with a timeout (1 second)
|
||||
int flag = 0;
|
||||
auto timeout_start = std::chrono::steady_clock::now();
|
||||
const auto timeout_duration = std::chrono::seconds(1);
|
||||
|
||||
while(!flag)
|
||||
{
|
||||
MPI_Test(&barrier_req, &flag, MPI_STATUS_IGNORE);
|
||||
|
||||
if(!flag)
|
||||
{
|
||||
// Check if timeout exceeded
|
||||
auto elapsed = std::chrono::steady_clock::now() - timeout_start;
|
||||
if(elapsed > timeout_duration)
|
||||
{
|
||||
// Timeout - ranks are out of sync!
|
||||
std::fprintf(
|
||||
stderr,
|
||||
"Rank %d: TIMEOUT in TearDown barrier - ranks out of sync, forcing abort\n",
|
||||
world_rank);
|
||||
std::fflush(stderr);
|
||||
|
||||
// Cancel the barrier request
|
||||
MPI_Cancel(&barrier_req);
|
||||
MPI_Request_free(&barrier_req);
|
||||
|
||||
// Force abort - can't safely continue
|
||||
MPI_Abort(MPI_COMM_WORLD, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
// Sleep briefly to avoid busy-waiting
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
}
|
||||
}
|
||||
|
||||
// Barrier completed - all ranks are synchronized
|
||||
// Now safe to do collective operations
|
||||
|
||||
// Check if ANY rank had a failure
|
||||
int local_failed = (retCode != 0) ? 1 : 0;
|
||||
int global_failed = 0;
|
||||
MPI_Allreduce(&local_failed, &global_failed, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD);
|
||||
|
||||
// Update retCode to reflect global failure status
|
||||
if(global_failed > 0)
|
||||
{
|
||||
retCode = 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// MPI_Ibarrier failed - something is very wrong
|
||||
std::fprintf(stderr,
|
||||
"Rank %d: MPI_Ibarrier failed in TearDown, forcing abort\n",
|
||||
world_rank);
|
||||
std::fflush(stderr);
|
||||
MPI_Abort(MPI_COMM_WORLD, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
cleanup_mpi();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Clean up MPI resources and finalize
|
||||
*
|
||||
* Performs coordinated cleanup across all ranks:
|
||||
* 1. Guards against multiple cleanup attempts
|
||||
* 2. Synchronizes all ranks
|
||||
* 3. Aggregates test results using MPI_Allreduce
|
||||
* 4. Prints final results from rank 0
|
||||
* 5. Calls MPI_Finalize()
|
||||
* 6. Resets initialization flags
|
||||
*
|
||||
* Uses context-aware error handling:
|
||||
* - MPI_Barrier/Allreduce: MPICHECK with rank (aborts on error)
|
||||
* - MPI_Finalize: MPICHECK with rank and true flag (exits on error)
|
||||
*
|
||||
* @note Uses static guard to prevent multiple cleanup attempts
|
||||
* @note Safe to call from signal handlers or error paths
|
||||
* @note All ranks must call this function for proper finalization
|
||||
*/
|
||||
void MPIEnvironment::cleanup_mpi()
|
||||
{
|
||||
// Use static guard to prevent multiple cleanup attempts
|
||||
static bool cleanup_in_progress_or_done = false;
|
||||
|
||||
if(cleanup_in_progress_or_done)
|
||||
{
|
||||
return; // Already cleaned up or currently cleaning up
|
||||
}
|
||||
|
||||
if(!mpi_initialized)
|
||||
{
|
||||
return; // Never initialized
|
||||
}
|
||||
|
||||
cleanup_in_progress_or_done = true;
|
||||
|
||||
// Synchronize all ranks before MPI finalization
|
||||
MPICHECK(MPI_Barrier(MPI_COMM_WORLD), world_rank);
|
||||
|
||||
MPICHECK(MPI_Finalize(), world_rank, true);
|
||||
|
||||
mpi_initialized = false;
|
||||
devices_initialized = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Accessor function to get cached multi-node detection result
|
||||
*
|
||||
* This function is defined here to avoid circular dependency between
|
||||
* TestChecks.hpp and MPIEnvironment.hpp.
|
||||
*
|
||||
* @return The cached multi-node result: -1 (not computed), 0 (single node), 1 (multi-node)
|
||||
*/
|
||||
int getMPIEnvironmentCachedMultiNodeResult()
|
||||
{
|
||||
return MPIEnvironment::cached_multi_node_result;
|
||||
}
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
@@ -0,0 +1,149 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/**
|
||||
* @file MPIEnvironment.hpp
|
||||
* @brief Global MPI environment and error checking macros for RCCL testing
|
||||
*
|
||||
* Provides a Google Test Environment for managing MPI initialization/finalization
|
||||
* and error checking macros for MPI, NCCL, and HIP operations in tests.
|
||||
*/
|
||||
|
||||
#ifndef RCCL_MPI_ENVIRONMENT_HPP
|
||||
#define RCCL_MPI_ENVIRONMENT_HPP
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
// Conditionally include MPI headers for MPI-based tests
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
#include "rccl/rccl.h"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mpi.h>
|
||||
|
||||
#include "TestChecks.hpp"
|
||||
#include "ResourceGuards.hpp"
|
||||
|
||||
/**
|
||||
* @class MPIEnvironment
|
||||
* @brief Google Test Environment for global MPI setup and teardown
|
||||
*
|
||||
* Manages the global MPI state for all MPI-based tests:
|
||||
* - One-time MPI initialization (MPI_Init_thread)
|
||||
* - GPU device initialization and assignment
|
||||
* - MPI finalization and result aggregation across ranks
|
||||
*
|
||||
* @note MPI_Init can only be called once, so this uses static flags
|
||||
* @note Each MPI rank is assigned to a unique GPU
|
||||
* @see MPITestBase for test-level functionality
|
||||
*/
|
||||
class MPIEnvironment : public ::testing::Environment
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* @brief Current MPI rank in MPI_COMM_WORLD
|
||||
*
|
||||
* Valid after MPI initialization. Each rank corresponds to one GPU.
|
||||
*/
|
||||
inline static int world_rank{0};
|
||||
|
||||
/**
|
||||
* @brief Total number of MPI processes in MPI_COMM_WORLD
|
||||
*
|
||||
* Valid after MPI initialization. Must not exceed number of available GPUs.
|
||||
*/
|
||||
inline static int world_size{0};
|
||||
|
||||
/**
|
||||
* @brief Aggregated return code for test results
|
||||
*
|
||||
* Set to non-zero on test failure. Aggregated across all ranks during cleanup.
|
||||
*/
|
||||
inline static int retCode{0};
|
||||
|
||||
/**
|
||||
* @brief Flag indicating MPI has been initialized
|
||||
*
|
||||
* Prevents multiple MPI_Init calls (only allowed once per process).
|
||||
*/
|
||||
inline static bool mpi_initialized{false};
|
||||
|
||||
/**
|
||||
* @brief Cached result of multi-node detection
|
||||
*
|
||||
* Computed once during SetUp() using MPI_Comm_split_type().
|
||||
* -1 = not computed, 0 = single node, 1 = multi-node
|
||||
*
|
||||
* @note MUST be initialized before any TEST_* macros are called
|
||||
* @note Prevents nested MPI collective operations in isMultiNodeTest()
|
||||
*/
|
||||
inline static int cached_multi_node_result{-1};
|
||||
|
||||
/**
|
||||
* @brief Flag indicating GPU devices have been initialized
|
||||
*
|
||||
* Prevents redundant device setup across multiple test runs.
|
||||
*/
|
||||
inline static bool devices_initialized{false};
|
||||
|
||||
/**
|
||||
* @brief Initialize MPI with thread support
|
||||
*
|
||||
* Calls MPI_Init_thread() with MPI_THREAD_MULTIPLE support and sets
|
||||
* world_rank and world_size. Safe to call multiple times (idempotent).
|
||||
*
|
||||
* @note Should be called before any MPI operations
|
||||
* @see mpi_initialized flag
|
||||
*/
|
||||
static void initialize_mpi();
|
||||
|
||||
/**
|
||||
* @brief Initialize and assign GPU devices to MPI ranks
|
||||
*
|
||||
* Performs the following:
|
||||
* 1. Queries available GPU count
|
||||
* 2. Validates sufficient GPUs for all ranks
|
||||
* 3. Assigns one GPU per rank (rank N → GPU N)
|
||||
* 4. Resets and sets HIP device context
|
||||
* 5. Synchronizes all ranks
|
||||
*
|
||||
* @note Requires world_size ≤ number of available GPUs
|
||||
* @see devices_initialized flag
|
||||
*/
|
||||
static void initialize_devices();
|
||||
|
||||
/**
|
||||
* @brief Clean up MPI resources and finalize
|
||||
*
|
||||
* Performs the following cleanup:
|
||||
* 1. Synchronizes all ranks with MPI_Barrier
|
||||
* 2. Aggregates test results across ranks with MPI_Allreduce
|
||||
* 3. Prints final results from rank 0
|
||||
* 4. Calls MPI_Finalize()
|
||||
*
|
||||
* @note Uses static guard to prevent multiple cleanup attempts
|
||||
* @note Safe to call from signal handlers or error paths
|
||||
*/
|
||||
static void cleanup_mpi();
|
||||
|
||||
/**
|
||||
* @brief Google Test SetUp hook - called once before all tests
|
||||
*
|
||||
* Initializes MPI and GPU devices for the entire test suite.
|
||||
*/
|
||||
void SetUp() override;
|
||||
|
||||
/**
|
||||
* @brief Google Test TearDown hook - called once after all tests
|
||||
*
|
||||
* Synchronizes all ranks and calls cleanup_mpi() to finalize MPI.
|
||||
*/
|
||||
void TearDown() override;
|
||||
};
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
|
||||
#endif // RCCL_MPI_ENVIRONMENT_HPP
|
||||
@@ -0,0 +1,371 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "MPIHelpers.hpp"
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
#include "MPITestCore.hpp"
|
||||
#include "MPIEnvironment.hpp"
|
||||
#include <cerrno>
|
||||
#include <cstring>
|
||||
#include <fcntl.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <mpi.h>
|
||||
#include <unistd.h>
|
||||
|
||||
namespace MPIHelpers
|
||||
{
|
||||
|
||||
// ============================================================================
|
||||
// FileDescriptor Implementation
|
||||
// ============================================================================
|
||||
|
||||
FileDescriptor::FileDescriptor(int fd) noexcept : fd_(fd) {}
|
||||
|
||||
FileDescriptor::~FileDescriptor()
|
||||
{
|
||||
if(fd_ >= 0)
|
||||
{
|
||||
::close(fd_);
|
||||
}
|
||||
}
|
||||
|
||||
FileDescriptor::FileDescriptor(FileDescriptor&& other) noexcept : fd_(other.fd_)
|
||||
{
|
||||
other.fd_ = -1;
|
||||
}
|
||||
|
||||
FileDescriptor& FileDescriptor::operator=(FileDescriptor&& other) noexcept
|
||||
{
|
||||
if(this != &other)
|
||||
{
|
||||
if(fd_ >= 0)
|
||||
{
|
||||
::close(fd_);
|
||||
}
|
||||
fd_ = other.fd_;
|
||||
other.fd_ = -1;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
int FileDescriptor::get() const noexcept
|
||||
{
|
||||
return fd_;
|
||||
}
|
||||
|
||||
bool FileDescriptor::is_valid() const noexcept
|
||||
{
|
||||
return fd_ >= 0;
|
||||
}
|
||||
|
||||
int FileDescriptor::release() noexcept
|
||||
{
|
||||
const auto fd = fd_;
|
||||
fd_ = -1;
|
||||
return fd;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TeeThread Implementation
|
||||
// ============================================================================
|
||||
|
||||
TeeThread::TeeThread(int read_fd, int console_fd, int log_fd)
|
||||
: read_fd_(read_fd), console_fd_(console_fd), log_fd_(log_fd), running_(true)
|
||||
{
|
||||
thread_ = std::thread([this]() { this->tee_loop(); });
|
||||
}
|
||||
|
||||
TeeThread::~TeeThread()
|
||||
{
|
||||
running_ = false;
|
||||
if(thread_.joinable())
|
||||
{
|
||||
thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
void TeeThread::tee_loop()
|
||||
{
|
||||
std::array<char, 4096> buffer;
|
||||
while(running_)
|
||||
{
|
||||
const auto bytes_read = ::read(read_fd_, buffer.data(), buffer.size());
|
||||
if(bytes_read <= 0)
|
||||
{
|
||||
if(bytes_read == 0 || errno != EINTR)
|
||||
{
|
||||
break; // EOF or error
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Write to console
|
||||
[[maybe_unused]] auto console_written = ::write(console_fd_, buffer.data(), bytes_read);
|
||||
|
||||
// Write to log file
|
||||
[[maybe_unused]] auto log_written = ::write(log_fd_, buffer.data(), bytes_read);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MPI Initialization
|
||||
// ============================================================================
|
||||
|
||||
MPIContext initializeMPI(int* argc, char*** argv)
|
||||
{
|
||||
MPIContext ctx;
|
||||
|
||||
// Initialize MPI with thread support
|
||||
MPI_Init_thread(argc, argv, MPI_THREAD_MULTIPLE, &ctx.thread_support);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &ctx.world_rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &ctx.world_size);
|
||||
|
||||
// Update global environment
|
||||
MPIEnvironment::world_rank = ctx.world_rank;
|
||||
MPIEnvironment::world_size = ctx.world_size;
|
||||
MPIEnvironment::mpi_initialized = true;
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// GPU Setup
|
||||
// ============================================================================
|
||||
|
||||
void setupGPU(int world_rank)
|
||||
{
|
||||
int device_count = 0;
|
||||
hipGetDeviceCount(&device_count);
|
||||
|
||||
if(device_count > 0)
|
||||
{
|
||||
// Use MPI_COMM_TYPE_SHARED to detect local ranks on same node
|
||||
MPI_Comm node_comm;
|
||||
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &node_comm);
|
||||
|
||||
int local_rank, local_size;
|
||||
MPI_Comm_rank(node_comm, &local_rank);
|
||||
MPI_Comm_size(node_comm, &local_size);
|
||||
|
||||
// Cache multi-node detection result for isMultiNodeTest()
|
||||
int world_size;
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||
MPIEnvironment::cached_multi_node_result = (local_size < world_size) ? 1 : 0;
|
||||
|
||||
// Assign GPU in round-robin fashion
|
||||
int device_id = local_rank % device_count;
|
||||
hipSetDevice(device_id);
|
||||
|
||||
MPI_Comm_free(&node_comm);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Per-Rank Logging
|
||||
// ============================================================================
|
||||
|
||||
std::optional<RankLogConfig> setupRankLogging(int rank)
|
||||
{
|
||||
const auto* env_value = std::getenv("RCCL_MPI_LOG_ALL_RANKS");
|
||||
const bool per_rank_logging_enabled = (env_value && std::string(env_value) == "1");
|
||||
|
||||
RankLogConfig config;
|
||||
config.logging_enabled = per_rank_logging_enabled;
|
||||
config.is_rank_zero = (rank == 0);
|
||||
|
||||
// Non-zero ranks: Always redirect output (either to log file or /dev/null)
|
||||
if(rank != 0)
|
||||
{
|
||||
// Save original stdout/stderr
|
||||
config.saved_stdout = FileDescriptor{::dup(STDOUT_FILENO)};
|
||||
config.saved_stderr = FileDescriptor{::dup(STDERR_FILENO)};
|
||||
|
||||
if(!config.saved_stdout->is_valid() || !config.saved_stderr->is_valid())
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to duplicate stdout/stderr", rank);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if(per_rank_logging_enabled)
|
||||
{
|
||||
// Per-rank logging enabled: Redirect to log file
|
||||
const auto log_filename
|
||||
= std::string{"rccl_test_rank_"} + std::to_string(rank) + ".log";
|
||||
|
||||
const auto log_fd = ::open(log_filename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
|
||||
|
||||
if(log_fd < 0)
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to create log file: %s", rank, log_filename.c_str());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
config.log_fd = FileDescriptor{log_fd};
|
||||
|
||||
// Redirect stdout/stderr to log file
|
||||
if(::dup2(log_fd, STDOUT_FILENO) < 0 || ::dup2(log_fd, STDERR_FILENO) < 0)
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to redirect to log file", rank);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Debug: Write initial marker to log file (AFTER redirection)
|
||||
TEST_INFO("===== LOG FILE FOR RANK %d =====", rank);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Default: Suppress all output by redirecting to /dev/null
|
||||
const auto null_fd = ::open("/dev/null", O_WRONLY);
|
||||
if(null_fd < 0)
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to open /dev/null", rank);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Redirect stdout/stderr to /dev/null
|
||||
if(::dup2(null_fd, STDOUT_FILENO) < 0 || ::dup2(null_fd, STDERR_FILENO) < 0)
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to redirect to /dev/null", rank);
|
||||
::close(null_fd);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
::close(null_fd);
|
||||
}
|
||||
|
||||
// Disable buffering for immediate output
|
||||
std::setvbuf(stdout, nullptr, _IONBF, 0);
|
||||
std::setvbuf(stderr, nullptr, _IONBF, 0);
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
// Rank 0: Only redirect if per-rank logging is enabled (for tee functionality)
|
||||
if(!per_rank_logging_enabled)
|
||||
{
|
||||
return std::nullopt; // Rank 0 outputs to console normally
|
||||
}
|
||||
|
||||
// Create log file for rank 0
|
||||
const auto log_filename = std::string{"rccl_test_rank_"} + std::to_string(rank) + ".log";
|
||||
|
||||
// Debug: Print to stderr BEFORE creating log file
|
||||
TEST_TRACE("Rank %d (rank 0 tee mode) opening log file: %s", rank, log_filename.c_str());
|
||||
|
||||
const auto log_fd = ::open(log_filename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
|
||||
|
||||
if(log_fd < 0)
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to create log file: %s", rank, log_filename.c_str());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
config.log_fd = FileDescriptor{log_fd};
|
||||
|
||||
// Debug: Write initial marker directly to log file (BEFORE redirection)
|
||||
const char* marker = "===== LOG FILE FOR RANK 0 (TEE MODE) =====\n";
|
||||
[[maybe_unused]] auto written = ::write(log_fd, marker, std::strlen(marker));
|
||||
|
||||
// Rank 0 with per-rank logging: Output to BOTH console AND log file (tee behavior)
|
||||
// Print banner before redirection
|
||||
TEST_INFO("Per-Rank Logging ENABLED (RCCL_MPI_LOG_ALL_RANKS=1)");
|
||||
TEST_INFO("Rank 0 : Output to BOTH console AND %s", log_filename.c_str());
|
||||
TEST_INFO("Ranks 1-N : Output redirected to rccl_test_rank_<N>.log");
|
||||
TEST_INFO("Location : Log files created in current working directory");
|
||||
|
||||
// Save original stdout/stderr for tee thread
|
||||
config.saved_stdout = FileDescriptor{::dup(STDOUT_FILENO)};
|
||||
config.saved_stderr = FileDescriptor{::dup(STDERR_FILENO)};
|
||||
|
||||
if(!config.saved_stdout->is_valid() || !config.saved_stderr->is_valid())
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to duplicate stdout/stderr", rank);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Create pipes for tee functionality
|
||||
int pipe_fds[2];
|
||||
if(::pipe(pipe_fds) < 0)
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to create pipe", rank);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
config.pipe_read_fd = FileDescriptor{pipe_fds[0]};
|
||||
config.pipe_write_fd = FileDescriptor{pipe_fds[1]};
|
||||
|
||||
// Start tee thread to duplicate output to both console and log file
|
||||
try
|
||||
{
|
||||
config.tee_thread = std::make_unique<TeeThread>(config.pipe_read_fd->get(),
|
||||
config.saved_stdout->get(),
|
||||
log_fd);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to start tee thread: %s", rank, e.what());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Redirect stdout/stderr to the pipe write end
|
||||
if(::dup2(config.pipe_write_fd->get(), STDOUT_FILENO) < 0
|
||||
|| ::dup2(config.pipe_write_fd->get(), STDERR_FILENO) < 0)
|
||||
{
|
||||
TEST_WARN("Rank %d: Failed to redirect to pipe", rank);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Disable buffering for immediate output
|
||||
std::setvbuf(stdout, nullptr, _IONBF, 0);
|
||||
std::setvbuf(stderr, nullptr, _IONBF, 0);
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
void restoreRankLogging(RankLogConfig& config)
|
||||
{
|
||||
// Only restore if we actually redirected (have saved stdout/stderr)
|
||||
if(!config.saved_stdout || !config.saved_stdout->is_valid())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Flush any pending output
|
||||
std::fflush(stdout);
|
||||
std::fflush(stderr);
|
||||
|
||||
// CRITICAL: Restore stdout/stderr BEFORE closing pipe
|
||||
// The tee thread won't get EOF until ALL write ends are closed
|
||||
if(config.saved_stdout && config.saved_stdout->is_valid())
|
||||
{
|
||||
::dup2(config.saved_stdout->get(), STDOUT_FILENO);
|
||||
}
|
||||
|
||||
if(config.saved_stderr && config.saved_stderr->is_valid())
|
||||
{
|
||||
::dup2(config.saved_stderr->get(), STDERR_FILENO);
|
||||
}
|
||||
|
||||
if(config.is_rank_zero && config.tee_thread)
|
||||
{
|
||||
// For rank 0 with per-rank logging: Stop the tee thread
|
||||
// Close the pipe write end to signal EOF to the tee thread
|
||||
config.pipe_write_fd.reset();
|
||||
|
||||
// Wait for tee thread to finish processing
|
||||
config.tee_thread.reset();
|
||||
|
||||
// Close pipe read end
|
||||
config.pipe_read_fd.reset();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MPIHelpers
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
@@ -0,0 +1,187 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/**
|
||||
* @file MPIHelpers.hpp
|
||||
* @brief Shared MPI utility functions for both GTest and standalone tests
|
||||
*
|
||||
* Provides common functionality for MPI test initialization, GPU setup,
|
||||
* and per-rank logging that can be used by both GTest-based tests and
|
||||
* standalone tests (performance benchmarks, etc.).
|
||||
*/
|
||||
|
||||
#ifndef MPI_HELPERS_HPP
|
||||
#define MPI_HELPERS_HPP
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
/**
|
||||
* @namespace MPIHelpers
|
||||
* @brief Shared MPI utilities for test infrastructure
|
||||
*/
|
||||
namespace MPIHelpers
|
||||
{
|
||||
|
||||
/**
|
||||
* @struct MPIContext
|
||||
* @brief MPI environment context information
|
||||
*/
|
||||
struct MPIContext
|
||||
{
|
||||
int world_rank; ///< MPI rank in MPI_COMM_WORLD
|
||||
int world_size; ///< Total number of MPI processes
|
||||
int thread_support; ///< MPI thread support level provided
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Initialize MPI with thread support
|
||||
*
|
||||
* Initializes MPI with MPI_THREAD_MULTIPLE support and returns context info.
|
||||
*
|
||||
* @param argc Pointer to argc from main()
|
||||
* @param argv Pointer to argv from main()
|
||||
* @return MPIContext with rank, size, and thread support info
|
||||
*
|
||||
* @note Must be called before any other MPI operations
|
||||
* @note Automatically sets MPIEnvironment static variables
|
||||
*/
|
||||
MPIContext initializeMPI(int* argc, char*** argv);
|
||||
|
||||
/**
|
||||
* @brief Setup GPU device for this MPI rank
|
||||
*
|
||||
* Assigns GPU device based on local rank (ranks on same node).
|
||||
* Uses MPI_COMM_TYPE_SHARED to detect node topology and assigns
|
||||
* GPUs in round-robin fashion.
|
||||
*
|
||||
* @param world_rank MPI rank in MPI_COMM_WORLD
|
||||
*
|
||||
* @note Handles multiple ranks per node automatically
|
||||
* @note Uses hipSetDevice() to assign GPU
|
||||
*/
|
||||
void setupGPU(int world_rank);
|
||||
|
||||
/**
|
||||
* @class FileDescriptor
|
||||
* @brief RAII wrapper for POSIX file descriptors
|
||||
*
|
||||
* Automatically closes file descriptor on destruction.
|
||||
* Move-only semantics prevent accidental duplication.
|
||||
*/
|
||||
class FileDescriptor
|
||||
{
|
||||
public:
|
||||
explicit FileDescriptor(int fd = -1) noexcept;
|
||||
~FileDescriptor();
|
||||
|
||||
// Move-only semantics
|
||||
FileDescriptor(FileDescriptor&& other) noexcept;
|
||||
FileDescriptor& operator=(FileDescriptor&& other) noexcept;
|
||||
|
||||
// Delete copy operations
|
||||
FileDescriptor(const FileDescriptor&) = delete;
|
||||
FileDescriptor& operator=(const FileDescriptor&) = delete;
|
||||
|
||||
[[nodiscard]] int get() const noexcept;
|
||||
[[nodiscard]] bool is_valid() const noexcept;
|
||||
int release() noexcept;
|
||||
|
||||
private:
|
||||
int fd_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @class TeeThread
|
||||
* @brief Thread for duplicating output to console and log file
|
||||
*
|
||||
* Used by rank 0 when per-rank logging is enabled to send output
|
||||
* to both console and log file simultaneously.
|
||||
*/
|
||||
class TeeThread
|
||||
{
|
||||
public:
|
||||
TeeThread(int read_fd, int console_fd, int log_fd);
|
||||
~TeeThread();
|
||||
|
||||
// Delete copy/move operations
|
||||
TeeThread(const TeeThread&) = delete;
|
||||
TeeThread& operator=(const TeeThread&) = delete;
|
||||
TeeThread(TeeThread&&) = delete;
|
||||
TeeThread& operator=(TeeThread&&) = delete;
|
||||
|
||||
private:
|
||||
void tee_loop();
|
||||
|
||||
int read_fd_;
|
||||
int console_fd_;
|
||||
int log_fd_;
|
||||
std::atomic<bool> running_;
|
||||
std::thread thread_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct RankLogConfig
|
||||
* @brief Per-rank logging configuration and state
|
||||
*
|
||||
* Manages file descriptors and threads for per-rank logging when
|
||||
* RCCL_MPI_LOG_ALL_RANKS=1 environment variable is set.
|
||||
*/
|
||||
struct RankLogConfig
|
||||
{
|
||||
std::optional<FileDescriptor> log_fd; ///< Log file descriptor
|
||||
std::optional<FileDescriptor> saved_stdout; ///< Saved stdout for restoration
|
||||
std::optional<FileDescriptor> saved_stderr; ///< Saved stderr for restoration
|
||||
std::optional<FileDescriptor> pipe_read_fd; ///< Pipe read end (rank 0 only)
|
||||
std::optional<FileDescriptor> pipe_write_fd; ///< Pipe write end (rank 0 only)
|
||||
std::unique_ptr<TeeThread> tee_thread; ///< Tee thread (rank 0 only)
|
||||
bool logging_enabled{false}; ///< Is per-rank logging enabled?
|
||||
bool is_rank_zero{false}; ///< Is this rank 0?
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Setup per-rank logging if RCCL_MPI_LOG_ALL_RANKS=1
|
||||
*
|
||||
* Configures output redirection for MPI ranks:
|
||||
* - Rank 0: Output to BOTH console AND log file (tee behavior)
|
||||
* - Rank 1-N: Output redirected to rccl_test_rank_<N>.log
|
||||
*
|
||||
* If RCCL_MPI_LOG_ALL_RANKS is not set:
|
||||
* - Rank 0: Normal console output
|
||||
* - Rank 1-N: Output suppressed (redirected to /dev/null)
|
||||
*
|
||||
* @param rank MPI rank in MPI_COMM_WORLD
|
||||
* @return Optional RankLogConfig if logging was configured, std::nullopt otherwise
|
||||
*
|
||||
* @note Call before any test output
|
||||
* @note Must call restoreRankLogging() at end to cleanup
|
||||
*/
|
||||
std::optional<RankLogConfig> setupRankLogging(int rank);
|
||||
|
||||
/**
|
||||
* @brief Restore original stdout/stderr after per-rank logging
|
||||
*
|
||||
* Cleans up per-rank logging configuration and restores original
|
||||
* stdout/stderr file descriptors.
|
||||
*
|
||||
* @param config RankLogConfig to cleanup
|
||||
*
|
||||
* @note Safe to call multiple times
|
||||
* @note Flushes pending output before restoration
|
||||
*/
|
||||
void restoreRankLogging(RankLogConfig& config);
|
||||
|
||||
} // namespace MPIHelpers
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
|
||||
#endif // MPI_HELPERS_HPP
|
||||
@@ -0,0 +1,226 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/**
|
||||
* @file MPIStandaloneTest.hpp
|
||||
* @brief Standalone (non-GTest) adapter for MPI tests
|
||||
*
|
||||
* Provides infrastructure for writing standalone MPI tests without Google Test.
|
||||
* Ideal for performance benchmarks, low-level API tests, and production utilities.
|
||||
*
|
||||
* @see MPITestCore for the base framework-agnostic functionality
|
||||
* @see MPITestBase for GTest integration
|
||||
*/
|
||||
|
||||
#ifndef MPI_STANDALONE_TEST_HPP
|
||||
#define MPI_STANDALONE_TEST_HPP
|
||||
|
||||
#include "MPITestCore.hpp"
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
/**
|
||||
* @class MPIStandaloneTest
|
||||
* @brief Standalone test adapter for MPI tests (no GTest dependency)
|
||||
*
|
||||
* Provides a simple base class for standalone MPI tests that don't require
|
||||
* Google Test framework. Useful for:
|
||||
* - Performance benchmarks (bandwidth, latency)
|
||||
* - Low-level API testing
|
||||
* - Production utilities
|
||||
* - Custom test harnesses
|
||||
*
|
||||
* **Key Features:**
|
||||
* - No GTest dependency
|
||||
* - Simple run() interface
|
||||
* - Automatic resource cleanup via RAII
|
||||
* - Same validation and setup as GTest tests
|
||||
* - Return code-based error reporting
|
||||
*
|
||||
* **Usage Pattern:**
|
||||
* @code
|
||||
* class MyBandwidthTest : public MPIStandaloneTest {
|
||||
* public:
|
||||
* int run() override {
|
||||
* // Validate prerequisites
|
||||
* if (!validateTestPrerequisites(2)) {
|
||||
* if (MPIEnvironment::world_rank == 0) {
|
||||
* printf("SKIP: Need at least 2 processes\n");
|
||||
* }
|
||||
* return 0; // Skip (not an error)
|
||||
* }
|
||||
*
|
||||
* // Setup communicator
|
||||
* if (createTestCommunicator() != ncclSuccess) {
|
||||
* if (MPIEnvironment::world_rank == 0) {
|
||||
* fprintf(stderr, "ERROR: Failed to create communicator\n");
|
||||
* }
|
||||
* return 1; // Error
|
||||
* }
|
||||
*
|
||||
* // Run test logic
|
||||
* ncclComm_t comm = getActiveCommunicator();
|
||||
* hipStream_t stream = getActiveStream();
|
||||
*
|
||||
* // Your test code here...
|
||||
*
|
||||
* return 0; // Success
|
||||
* }
|
||||
* };
|
||||
*
|
||||
* int main(int argc, char** argv) {
|
||||
* MPI_Init(&argc, &argv);
|
||||
*
|
||||
* MyBandwidthTest test;
|
||||
* int result = test.run();
|
||||
* test.cleanup(); // Explicit cleanup
|
||||
*
|
||||
* MPI_Finalize();
|
||||
* return result;
|
||||
* }
|
||||
* @endcode
|
||||
*
|
||||
* **RAII Wrapper Alternative:**
|
||||
* @code
|
||||
* int main(int argc, char** argv) {
|
||||
* MPI_Init(&argc, &argv);
|
||||
*
|
||||
* int result = 0;
|
||||
* {
|
||||
* MPIStandaloneTestRAII test;
|
||||
* MyBandwidthTest bandwidth_test;
|
||||
* result = bandwidth_test.run();
|
||||
* // Automatic cleanup when test goes out of scope
|
||||
* }
|
||||
*
|
||||
* MPI_Finalize();
|
||||
* return result;
|
||||
* }
|
||||
* @endcode
|
||||
*
|
||||
* @note For GTest-based tests, use MPITestBase instead
|
||||
*/
|
||||
class MPIStandaloneTest : public MPITestCore
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* @brief Virtual destructor for proper cleanup
|
||||
*/
|
||||
virtual ~MPIStandaloneTest() = default;
|
||||
|
||||
/**
|
||||
* @brief Main test execution method - override this
|
||||
*
|
||||
* Override this method to implement your test logic.
|
||||
*
|
||||
* @return 0 for success/skip, non-zero for error
|
||||
*
|
||||
* @par Return Codes:
|
||||
* - 0: Success or test skipped (validation failed)
|
||||
* - 1: Generic error
|
||||
* - Other: Custom error codes
|
||||
*/
|
||||
virtual int run() = 0;
|
||||
|
||||
/**
|
||||
* @brief Explicit cleanup method
|
||||
*
|
||||
* Call this after run() completes to ensure proper resource cleanup.
|
||||
* Alternatively, use MPIStandaloneTestRAII for automatic cleanup.
|
||||
*/
|
||||
void cleanup()
|
||||
{
|
||||
cleanupTestCommunicator();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Setup hook (optional)
|
||||
*
|
||||
* Override this to perform custom setup before run().
|
||||
* Default implementation does nothing.
|
||||
*/
|
||||
void setUp() override
|
||||
{
|
||||
SetUp();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Teardown hook (optional)
|
||||
*
|
||||
* Override this to perform custom cleanup after run().
|
||||
* Default implementation calls cleanupTestCommunicator().
|
||||
*/
|
||||
void tearDown() override
|
||||
{
|
||||
TearDown();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @class MPIStandaloneTestRAII
|
||||
* @brief RAII wrapper for automatic MPIStandaloneTest cleanup
|
||||
*
|
||||
* Provides scope-based automatic cleanup for MPIStandaloneTest.
|
||||
* Useful for ensuring cleanup even with early returns or exceptions.
|
||||
*
|
||||
* @par Example:
|
||||
* @code
|
||||
* int main(int argc, char** argv) {
|
||||
* MPI_Init(&argc, &argv);
|
||||
*
|
||||
* int result = 0;
|
||||
* {
|
||||
* MPIStandaloneTestRAII raii_wrapper;
|
||||
* MyTest test;
|
||||
* result = test.run();
|
||||
* // Automatic cleanup when raii_wrapper goes out of scope
|
||||
* }
|
||||
*
|
||||
* MPI_Finalize();
|
||||
* return result;
|
||||
* }
|
||||
* @endcode
|
||||
*/
|
||||
class MPIStandaloneTestRAII
|
||||
{
|
||||
private:
|
||||
MPIStandaloneTest* test_ = nullptr;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor - registers test for cleanup
|
||||
* @param test Pointer to test instance (optional)
|
||||
*/
|
||||
explicit MPIStandaloneTestRAII(MPIStandaloneTest* test = nullptr) : test_(test) {}
|
||||
|
||||
/**
|
||||
* @brief Destructor - performs automatic cleanup
|
||||
*/
|
||||
~MPIStandaloneTestRAII()
|
||||
{
|
||||
if(test_)
|
||||
{
|
||||
test_->cleanup();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set test instance to manage
|
||||
* @param test Pointer to test instance
|
||||
*/
|
||||
void setTest(MPIStandaloneTest* test)
|
||||
{
|
||||
test_ = test;
|
||||
}
|
||||
|
||||
// Delete copy constructor and assignment operator
|
||||
MPIStandaloneTestRAII(const MPIStandaloneTestRAII&) = delete;
|
||||
MPIStandaloneTestRAII& operator=(const MPIStandaloneTestRAII&) = delete;
|
||||
};
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
|
||||
#endif // MPI_STANDALONE_TEST_HPP
|
||||
@@ -0,0 +1,114 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/**
|
||||
* @file MPITestBase.hpp
|
||||
* @brief Base class infrastructure for MPI-based RCCL testing
|
||||
*
|
||||
* Provides a common test base class for writing multi-process distributed tests
|
||||
* using MPI and RCCL. Handles communicator creation, process validation, and
|
||||
* resource cleanup automatically.
|
||||
*
|
||||
* @see MPITestBase for the main base class
|
||||
* @see MPIEnvironment for global MPI setup
|
||||
*/
|
||||
|
||||
#ifndef MPI_TEST_BASE_HPP
|
||||
#define MPI_TEST_BASE_HPP
|
||||
|
||||
#include "MPITestCore.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
#include "MPIEnvironment.hpp"
|
||||
#include "TestChecks.hpp"
|
||||
#include "rccl/rccl.h"
|
||||
#include "utils.h" // For getHostName() from RCCL
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mpi.h>
|
||||
#include <string>
|
||||
|
||||
/**
|
||||
* @class MPITestBase
|
||||
* @brief Google Test adapter for MPI tests
|
||||
*
|
||||
* Integrates MPITestCore with Google Test framework for seamless MPI testing.
|
||||
* Inherits from both ::testing::Test (for GTest integration) and MPITestCore
|
||||
* (for MPI/RCCL functionality).
|
||||
*
|
||||
* **Features:**
|
||||
* - Process count validation (minimum processes, power-of-two requirements)
|
||||
* - Node count validation (single-node vs multi-node)
|
||||
* - Test-specific RCCL communicator creation and management
|
||||
* - HIP stream management for each test
|
||||
* - Automatic resource cleanup via GTest TearDown
|
||||
*
|
||||
* **Usage Example:**
|
||||
* @code
|
||||
* class MyMPITest : public MPITestBase {};
|
||||
*
|
||||
* TEST_F(MyMPITest, BasicAllReduce) {
|
||||
* if (!validateTestPrerequisites(2)) {
|
||||
* GTEST_SKIP() << "Need at least 2 processes";
|
||||
* }
|
||||
* ASSERT_EQ(ncclSuccess, createTestCommunicator());
|
||||
*
|
||||
* ncclComm_t comm = getActiveCommunicator();
|
||||
* hipStream_t stream = getActiveStream();
|
||||
*
|
||||
* // Your test logic here...
|
||||
* // Cleanup happens automatically in TearDown()
|
||||
* }
|
||||
* @endcode
|
||||
*
|
||||
* @note For standalone tests without GTest, use MPIStandaloneTest instead
|
||||
* @see MPITestCore for the base framework-agnostic functionality
|
||||
* @see MPIEnvironment for global MPI initialization
|
||||
*/
|
||||
/**
|
||||
* @brief Google Test adapter for MPI tests
|
||||
*
|
||||
* Integrates MPITestCore with Google Test framework by inheriting from both
|
||||
* ::testing::Test and MPITestCore.
|
||||
*
|
||||
* @note For standalone tests (without GTest), use MPIStandaloneTest instead
|
||||
*/
|
||||
class MPITestBase
|
||||
: public ::testing::Test
|
||||
, public MPITestCore
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* @brief Google Test SetUp hook - initializes test resources
|
||||
*
|
||||
* Automatically called before each test runs. Calls initializeTest()
|
||||
* from MPITestCore for any custom initialization.
|
||||
*
|
||||
* @note No ambiguity with MPITestCore::initializeTest() - different names
|
||||
*/
|
||||
void SetUp() override
|
||||
{
|
||||
initializeTest();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Google Test TearDown hook - ensures cleanup of test resources
|
||||
*
|
||||
* Automatically called after each test completes. Calls cleanupTest()
|
||||
* from MPITestCore to ensure proper resource cleanup.
|
||||
*/
|
||||
void TearDown() override
|
||||
{
|
||||
cleanupTest();
|
||||
}
|
||||
};
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
|
||||
#endif // MPI_TEST_BASE_HPP
|
||||
@@ -0,0 +1,412 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "MPITestCore.hpp"
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
#include "ResourceGuards.hpp"
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
// Import commonly used guards into local scope
|
||||
using RCCLTestGuards::makeScopeGuard;
|
||||
|
||||
// Detect the number of unique nodes
|
||||
int MPITestConstants::detectNodeCount()
|
||||
{
|
||||
int world_rank = 0;
|
||||
int world_size = 0;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("=== MPI Process Distribution ===");
|
||||
TEST_INFO("Total processes: %d", world_size);
|
||||
}
|
||||
|
||||
// Special case: single process is always single node
|
||||
if(world_size == 1)
|
||||
{
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Detected nodes: 1");
|
||||
TEST_INFO("================================");
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Use MPI_COMM_TYPE_SHARED to split by node
|
||||
MPI_Comm node_comm;
|
||||
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &node_comm);
|
||||
|
||||
int node_rank = 0;
|
||||
int node_size = 0;
|
||||
MPI_Comm_rank(node_comm, &node_rank);
|
||||
MPI_Comm_size(node_comm, &node_size);
|
||||
|
||||
// Gather node sizes to rank 0
|
||||
std::vector<int> all_node_sizes;
|
||||
if(world_rank == 0)
|
||||
{
|
||||
all_node_sizes.resize(world_size);
|
||||
}
|
||||
MPI_Gather(&node_size, 1, MPI_INT, all_node_sizes.data(), 1, MPI_INT, 0, MPI_COMM_WORLD);
|
||||
|
||||
// Rank 0 analyzes distribution
|
||||
int num_nodes = 0;
|
||||
if(world_rank == 0)
|
||||
{
|
||||
std::vector<int> node_counts; // ranks per node
|
||||
std::vector<int> node_first_rank; // first rank on each node
|
||||
|
||||
for(int r = 0; r < world_size; r++)
|
||||
{
|
||||
bool found = false;
|
||||
for(size_t n = 0; n < node_counts.size(); n++)
|
||||
{
|
||||
// Same node if same node_size and rank is within that node
|
||||
if(all_node_sizes[r] == all_node_sizes[node_first_rank[n]])
|
||||
{
|
||||
// Check if this rank belongs to this node group
|
||||
int local_rank = r - node_first_rank[n];
|
||||
if(local_rank >= 0 && local_rank < node_counts[n])
|
||||
{
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(!found)
|
||||
{
|
||||
node_first_rank.push_back(r);
|
||||
node_counts.push_back(all_node_sizes[r]);
|
||||
}
|
||||
}
|
||||
|
||||
num_nodes = static_cast<int>(node_counts.size());
|
||||
|
||||
TEST_INFO("Detected nodes: %d", num_nodes);
|
||||
TEST_INFO("");
|
||||
|
||||
// Get hostnames for display
|
||||
char hostname[MPI_MAX_PROCESSOR_NAME];
|
||||
int hostname_len;
|
||||
MPI_Get_processor_name(hostname, &hostname_len);
|
||||
|
||||
for(size_t n = 0; n < node_counts.size(); n++)
|
||||
{
|
||||
int first_rank = node_first_rank[n];
|
||||
TEST_INFO("Node %zu: %d rank(s) starting at rank %d", n, node_counts[n], first_rank);
|
||||
|
||||
// Print ranks on this node - build string first for TEST_INFO
|
||||
std::string ranks_str = " Ranks: ";
|
||||
for(int r = first_rank; r < first_rank + node_counts[n]; r++)
|
||||
{
|
||||
ranks_str += std::to_string(r);
|
||||
if(r < first_rank + node_counts[n] - 1)
|
||||
ranks_str += ", ";
|
||||
}
|
||||
TEST_INFO("%s", ranks_str.c_str());
|
||||
}
|
||||
TEST_INFO("================================");
|
||||
}
|
||||
|
||||
// Broadcast node count to all ranks
|
||||
MPI_Bcast(&num_nodes, 1, MPI_INT, 0, MPI_COMM_WORLD);
|
||||
|
||||
MPI_Comm_free(&node_comm);
|
||||
|
||||
return num_nodes;
|
||||
}
|
||||
|
||||
// Validate test prerequisites
|
||||
bool MPITestCore::validateTestPrerequisites(
|
||||
int min_processes, int max_processes, bool require_power_of_two, int min_nodes, int max_nodes)
|
||||
{
|
||||
int world_rank = MPIEnvironment::world_rank;
|
||||
int world_size = MPIEnvironment::world_size;
|
||||
|
||||
// Always detect nodes and display process distribution
|
||||
// This provides valuable information for all tests
|
||||
int actual_nodes = MPITestConstants::detectNodeCount();
|
||||
|
||||
bool validation_passed = true;
|
||||
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("=== Test Requirements ===");
|
||||
if(min_processes == max_processes)
|
||||
{
|
||||
TEST_INFO("Processes: exactly %d", min_processes);
|
||||
}
|
||||
else if(max_processes == MPITestConstants::kNoProcessLimit)
|
||||
{
|
||||
TEST_INFO("Processes: at least %d", min_processes);
|
||||
}
|
||||
else
|
||||
{
|
||||
TEST_INFO("Processes: between %d and %d", min_processes, max_processes);
|
||||
}
|
||||
|
||||
if(require_power_of_two)
|
||||
{
|
||||
TEST_INFO("Power-of-two: required");
|
||||
}
|
||||
|
||||
if(min_nodes > 1 || max_nodes > 0)
|
||||
{
|
||||
if(min_nodes == max_nodes)
|
||||
{
|
||||
TEST_INFO("Nodes: exactly %d", min_nodes);
|
||||
}
|
||||
else if(max_nodes == MPITestConstants::kNoNodeLimit)
|
||||
{
|
||||
TEST_INFO("Nodes: at least %d", min_nodes);
|
||||
}
|
||||
else
|
||||
{
|
||||
TEST_INFO("Nodes: between %d and %d", min_nodes, max_nodes);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_INFO("");
|
||||
TEST_INFO("=== Current Environment ===");
|
||||
TEST_INFO("Processes: %d", world_size);
|
||||
TEST_INFO("Nodes: %d", actual_nodes);
|
||||
if(require_power_of_two)
|
||||
{
|
||||
TEST_INFO("Power-of-two: %s",
|
||||
MPITestConstants::isPowerOfTwo(world_size) ? "yes" : "no");
|
||||
}
|
||||
TEST_INFO("===========================");
|
||||
TEST_INFO("");
|
||||
}
|
||||
|
||||
// Validate process count
|
||||
if(world_size < min_processes)
|
||||
{
|
||||
validation_passed = false;
|
||||
if(world_rank == 0)
|
||||
{
|
||||
printf("Error: REQUIREMENT NOT MET: Need at least %d processes, got %d\n",
|
||||
min_processes,
|
||||
world_size);
|
||||
printf(" For test details, set: NCCL_DEBUG=INFO\n");
|
||||
}
|
||||
}
|
||||
|
||||
if(max_processes != MPITestConstants::kNoProcessLimit && world_size > max_processes)
|
||||
{
|
||||
validation_passed = false;
|
||||
if(world_rank == 0)
|
||||
{
|
||||
printf("Error: REQUIREMENT NOT MET: Need at most %d processes, got %d\n",
|
||||
max_processes,
|
||||
world_size);
|
||||
printf(" For test details, set: NCCL_DEBUG=INFO\n");
|
||||
}
|
||||
}
|
||||
|
||||
if(require_power_of_two && !MPITestConstants::isPowerOfTwo(world_size))
|
||||
{
|
||||
validation_passed = false;
|
||||
if(world_rank == 0)
|
||||
{
|
||||
printf("Error: REQUIREMENT NOT MET: Need power-of-two processes, got %d (not power of "
|
||||
"2)\n",
|
||||
world_size);
|
||||
printf(" For test details, set: NCCL_DEBUG=INFO\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Validate node count
|
||||
if(min_nodes > 1 || max_nodes > 0)
|
||||
{
|
||||
if(actual_nodes < min_nodes)
|
||||
{
|
||||
validation_passed = false;
|
||||
if(world_rank == 0)
|
||||
{
|
||||
printf("Error: REQUIREMENT NOT MET: Need at least %d node(s), detected %d nodes\n",
|
||||
min_nodes,
|
||||
actual_nodes);
|
||||
printf(" For test details, set: NCCL_DEBUG=INFO\n");
|
||||
}
|
||||
}
|
||||
|
||||
if(max_nodes != MPITestConstants::kNoNodeLimit && actual_nodes > max_nodes)
|
||||
{
|
||||
validation_passed = false;
|
||||
if(world_rank == 0)
|
||||
{
|
||||
printf("Error: REQUIREMENT NOT MET: Need at most %d node(s), detected %d nodes\n",
|
||||
max_nodes,
|
||||
actual_nodes);
|
||||
printf(" For test details, set: NCCL_DEBUG=INFO\n");
|
||||
if(max_nodes == 1)
|
||||
{
|
||||
printf(" This test requires single-node execution\n");
|
||||
printf(" To run on single node, allocate all processes on the same host\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(world_rank == 0)
|
||||
{
|
||||
if(validation_passed)
|
||||
{
|
||||
TEST_INFO("All requirements met - test will run");
|
||||
}
|
||||
else
|
||||
{
|
||||
TEST_INFO("===========================");
|
||||
TEST_INFO("");
|
||||
}
|
||||
}
|
||||
|
||||
return validation_passed;
|
||||
}
|
||||
|
||||
// Create test communicator
|
||||
ncclResult_t MPITestCore::createTestCommunicator()
|
||||
{
|
||||
int world_rank = MPIEnvironment::world_rank;
|
||||
int world_size = MPIEnvironment::world_size;
|
||||
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Creating test-specific communicator");
|
||||
}
|
||||
|
||||
// Rank 0 generates unique ID
|
||||
if(world_rank == 0)
|
||||
{
|
||||
RCCL_TEST_CHECK(ncclGetUniqueId(&nccl_id_));
|
||||
}
|
||||
|
||||
// Broadcast ID to all ranks
|
||||
MPI_Bcast(&nccl_id_, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
|
||||
// Initialize NCCL communicator with automatic cleanup on error
|
||||
RCCL_TEST_CHECK(ncclGroupStart());
|
||||
|
||||
// RAII guard: Automatically calls ncclGroupEnd() if subsequent operations fail
|
||||
auto group_guard = makeScopeGuard([]() { (void)ncclGroupEnd(); });
|
||||
|
||||
RCCL_TEST_CHECK(ncclCommInitRank(&test_comm_, world_size, nccl_id_, world_rank));
|
||||
|
||||
// RAII guard: Automatically destroys test_comm_ if subsequent operations fail
|
||||
auto comm_guard = makeScopeGuard(
|
||||
[this]()
|
||||
{
|
||||
if(test_comm_)
|
||||
{
|
||||
(void)ncclCommDestroy(test_comm_);
|
||||
test_comm_ = nullptr;
|
||||
}
|
||||
});
|
||||
|
||||
RCCL_TEST_CHECK(ncclGroupEnd());
|
||||
group_guard.dismiss(); // ncclGroupEnd succeeded, don't call it again
|
||||
|
||||
// Create HIP stream - if this fails, comm_guard automatically cleans up test_comm_
|
||||
HIP_TEST_CHECK(hipStreamCreate(&test_stream_));
|
||||
|
||||
// RAII guard: Automatically destroys test_stream_ if subsequent operations fail
|
||||
auto stream_guard = makeScopeGuard(
|
||||
[this]()
|
||||
{
|
||||
if(test_stream_)
|
||||
{
|
||||
(void)hipStreamDestroy(test_stream_);
|
||||
test_stream_ = nullptr;
|
||||
}
|
||||
});
|
||||
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
if(world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Test-specific communicator created successfully");
|
||||
}
|
||||
|
||||
// Success! Dismiss guards so resources aren't destroyed
|
||||
comm_guard.dismiss();
|
||||
stream_guard.dismiss();
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Get active communicator
|
||||
ncclComm_t MPITestCore::getActiveCommunicator()
|
||||
{
|
||||
return test_comm_;
|
||||
}
|
||||
|
||||
// Get active stream
|
||||
hipStream_t MPITestCore::getActiveStream()
|
||||
{
|
||||
return test_stream_;
|
||||
}
|
||||
|
||||
// Cleanup test communicator
|
||||
ncclResult_t MPITestCore::cleanupTestCommunicator()
|
||||
{
|
||||
if(!test_comm_ && !test_stream_)
|
||||
{
|
||||
return ncclSuccess; // Already cleaned up or never created
|
||||
}
|
||||
|
||||
int world_rank = MPIEnvironment::world_rank;
|
||||
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
// RAII guard: Ensure test_comm_ is destroyed even if stream cleanup fails
|
||||
auto comm_guard = makeScopeGuard(
|
||||
[this, world_rank]()
|
||||
{
|
||||
if(test_comm_)
|
||||
{
|
||||
ncclResult_t result = ncclCommDestroy(test_comm_);
|
||||
if(result != ncclSuccess)
|
||||
{
|
||||
TEST_WARN("Rank %d: ncclCommDestroy failed during cleanup: %s",
|
||||
world_rank,
|
||||
ncclGetErrorString(result));
|
||||
}
|
||||
test_comm_ = nullptr;
|
||||
}
|
||||
});
|
||||
|
||||
// RAII guard: Ensure test_stream_ is destroyed
|
||||
auto stream_guard = makeScopeGuard(
|
||||
[this, world_rank]()
|
||||
{
|
||||
if(test_stream_)
|
||||
{
|
||||
hipError_t hip_result = hipStreamDestroy(test_stream_);
|
||||
if(hip_result != hipSuccess)
|
||||
{
|
||||
TEST_WARN("Rank %d: hipStreamDestroy failed during cleanup: %s",
|
||||
world_rank,
|
||||
hipGetErrorString(hip_result));
|
||||
}
|
||||
test_stream_ = nullptr;
|
||||
}
|
||||
});
|
||||
|
||||
// Guards will automatically clean up when going out of scope
|
||||
// Even if an exception were thrown (though we don't use exceptions)
|
||||
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
@@ -0,0 +1,260 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/**
|
||||
* @file MPITestCore.hpp
|
||||
* @brief Framework-agnostic MPI test infrastructure
|
||||
*
|
||||
* Provides core MPI test functionality independent of any testing framework.
|
||||
* Can be used with Google Test, standalone tests, performance benchmarks, etc.
|
||||
*
|
||||
* @see MPITestBase for GTest integration
|
||||
* @see MPIStandaloneTest for standalone usage
|
||||
*/
|
||||
|
||||
#ifndef MPI_TEST_CORE_HPP
|
||||
#define MPI_TEST_CORE_HPP
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
#include "MPIEnvironment.hpp"
|
||||
#include "rccl/rccl.h"
|
||||
#include "utils.h" // For getHostName() from RCCL
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mpi.h>
|
||||
#include <string>
|
||||
|
||||
/**
|
||||
* @namespace MPITestConstants
|
||||
* @brief Constants and helper functions for MPI test configuration
|
||||
*/
|
||||
namespace MPITestConstants
|
||||
{
|
||||
/**
|
||||
* @brief Minimum number of processes typically required for MPI tests
|
||||
*/
|
||||
constexpr int kMinProcessesForMPI = 2;
|
||||
|
||||
/**
|
||||
* @brief Flag to indicate power-of-two process count is required
|
||||
*/
|
||||
constexpr bool kRequirePowerOfTwo = true;
|
||||
|
||||
/**
|
||||
* @brief Flag to indicate power-of-two process count is not required
|
||||
*/
|
||||
constexpr bool kNoPowerOfTwoRequired = false;
|
||||
|
||||
/**
|
||||
* @brief Value indicating no upper limit on process count
|
||||
*/
|
||||
constexpr int kNoProcessLimit = 0;
|
||||
|
||||
/**
|
||||
* @brief Value indicating single-node only execution required
|
||||
*/
|
||||
constexpr int kRequireSingleNode = 1;
|
||||
|
||||
/**
|
||||
* @brief Value indicating no node limit (multi-node capable)
|
||||
*/
|
||||
constexpr int kNoNodeLimit = 0;
|
||||
|
||||
/**
|
||||
* @brief Check if a number is a power of two
|
||||
* @param n The number to check
|
||||
* @return true if n is a power of two, false otherwise
|
||||
*/
|
||||
inline bool isPowerOfTwo(int n)
|
||||
{
|
||||
return n > 0 && (n & (n - 1)) == 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Detect the number of unique nodes in the MPI configuration
|
||||
* @return Number of unique nodes
|
||||
*/
|
||||
int detectNodeCount();
|
||||
|
||||
} // namespace MPITestConstants
|
||||
|
||||
/**
|
||||
* @class MPITestCore
|
||||
* @brief Framework-agnostic base class for MPI tests
|
||||
*
|
||||
* Provides core MPI test infrastructure without dependency on any testing framework.
|
||||
* Supports both GTest-based tests (via MPITestBase) and standalone tests.
|
||||
*
|
||||
* **Key Features:**
|
||||
* - Framework-agnostic design
|
||||
* - Automatic RCCL communicator management
|
||||
* - Process and node count validation
|
||||
* - HIP stream lifecycle management
|
||||
* - Clean resource cleanup
|
||||
*
|
||||
* **Usage:**
|
||||
* - For GTest: Use MPITestBase (inherits from MPITestCore)
|
||||
* - For Standalone: Use MPIStandaloneTest (inherits from MPITestCore)
|
||||
* - For Custom: Inherit from MPITestCore directly
|
||||
*
|
||||
* @par Example (Standalone):
|
||||
* @code
|
||||
* class MyPerfTest : public MPITestCore {
|
||||
* public:
|
||||
* int run() {
|
||||
* if (!validateTestPrerequisites(2)) {
|
||||
* return 1; // Skip
|
||||
* }
|
||||
* if (createTestCommunicator() != ncclSuccess) {
|
||||
* return 1; // Error
|
||||
* }
|
||||
* // Test logic...
|
||||
* return 0; // Success
|
||||
* }
|
||||
* };
|
||||
* @endcode
|
||||
*/
|
||||
class MPITestCore
|
||||
{
|
||||
protected:
|
||||
/**
|
||||
* @brief Test-specific NCCL communicator handle
|
||||
*
|
||||
* Created by createTestCommunicator(), destroyed in cleanup.
|
||||
* Access via getActiveCommunicator().
|
||||
*/
|
||||
ncclComm_t test_comm_ = nullptr;
|
||||
|
||||
/**
|
||||
* @brief Test-specific HIP stream handle
|
||||
*
|
||||
* Created with the communicator, destroyed in cleanup.
|
||||
* Access via getActiveStream().
|
||||
*/
|
||||
hipStream_t test_stream_ = nullptr;
|
||||
|
||||
/**
|
||||
* @brief NCCL unique ID for communicator initialization
|
||||
*
|
||||
* Generated on rank 0 and broadcast to all ranks.
|
||||
*/
|
||||
ncclUniqueId nccl_id_ = {};
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Virtual destructor for proper cleanup
|
||||
*/
|
||||
virtual ~MPITestCore() = default;
|
||||
|
||||
/**
|
||||
* @brief Validate test prerequisites (process count, node count)
|
||||
*
|
||||
* Checks if the current MPI environment meets the test's requirements.
|
||||
* Displays what the test requires and whether the environment satisfies those requirements.
|
||||
* Returns true if all requirements met, false otherwise.
|
||||
*
|
||||
* Parameters are organized by category:
|
||||
* - Process requirements: min_processes, max_processes, require_power_of_two
|
||||
* - Node requirements: min_nodes, max_nodes
|
||||
*
|
||||
* @param min_processes Minimum number of MPI processes required (default: 1)
|
||||
* @param max_processes Maximum number of MPI processes allowed (0 = no limit) (default: 0)
|
||||
* @param require_power_of_two If true, world size must be a power of 2 (default: false)
|
||||
* @param min_nodes Minimum number of nodes required (default: 1)
|
||||
* @param max_nodes Maximum number of nodes allowed (0 = no limit) (default: 0)
|
||||
*
|
||||
* @return true if all requirements are met, false otherwise
|
||||
*/
|
||||
bool validateTestPrerequisites(int min_processes = 1,
|
||||
int max_processes = MPITestConstants::kNoProcessLimit,
|
||||
bool require_power_of_two = false,
|
||||
int min_nodes = 1,
|
||||
int max_nodes = MPITestConstants::kNoNodeLimit);
|
||||
|
||||
/**
|
||||
* @brief Create a test-specific RCCL communicator and HIP stream
|
||||
*
|
||||
* Creates isolated RCCL communicator and HIP stream for this test.
|
||||
* Uses ncclGroupStart/End for proper initialization and MPI barriers
|
||||
* for synchronization across all ranks.
|
||||
*
|
||||
* @return ncclSuccess on success, or NCCL error code on failure
|
||||
*
|
||||
* @note This function is idempotent - calling it multiple times is safe
|
||||
* @note Communicator is automatically destroyed in cleanup
|
||||
*/
|
||||
virtual ncclResult_t createTestCommunicator();
|
||||
|
||||
/**
|
||||
* @brief Get the active NCCL communicator for this test
|
||||
*
|
||||
* Returns the test-specific communicator. Returns nullptr if createTestCommunicator()
|
||||
* has not been called first.
|
||||
*
|
||||
* @return The active NCCL communicator handle, or nullptr if not created
|
||||
*
|
||||
* @note Always call createTestCommunicator() before this method
|
||||
*/
|
||||
virtual ncclComm_t getActiveCommunicator();
|
||||
|
||||
/**
|
||||
* @brief Get the active HIP stream for this test
|
||||
*
|
||||
* Returns the test-specific HIP stream. Returns nullptr if createTestCommunicator()
|
||||
* has not been called first.
|
||||
*
|
||||
* @return The active HIP stream handle, or nullptr if not created
|
||||
*
|
||||
* @note Always call createTestCommunicator() before this method
|
||||
*/
|
||||
virtual hipStream_t getActiveStream();
|
||||
|
||||
/**
|
||||
* @brief Cleanup test-specific NCCL communicator and HIP stream
|
||||
*
|
||||
* Destroys the test communicator and stream with proper MPI synchronization.
|
||||
* Safe to call multiple times or if resources were never created.
|
||||
*
|
||||
* @return ncclResult_t - ncclSuccess on success, error code on failure
|
||||
* Returns ncclUnhandledCudaError if HIP cleanup fails
|
||||
*
|
||||
* @note For GTest: This is automatically called by cleanupTest()
|
||||
* @note For Standalone: Call this explicitly or use RAII wrapper
|
||||
* @note Errors are logged but cleanup continues for all resources
|
||||
*/
|
||||
virtual ncclResult_t cleanupTestCommunicator();
|
||||
|
||||
/**
|
||||
* @brief Initialize test resources before test execution
|
||||
*
|
||||
* Override this to perform custom initialization. Default implementation does nothing.
|
||||
* This method is framework-agnostic and not tied to GTest's lifecycle.
|
||||
* For standalone tests, call this explicitly if needed.
|
||||
*/
|
||||
virtual void initializeTest() {}
|
||||
|
||||
/**
|
||||
* @brief Cleanup test resources after test execution
|
||||
*
|
||||
* Override this to perform custom cleanup. Default implementation calls
|
||||
* cleanupTestCommunicator() to destroy NCCL communicator and HIP stream.
|
||||
* This method is framework-agnostic and not tied to GTest's lifecycle.
|
||||
* For standalone tests, call this explicitly if needed.
|
||||
*
|
||||
* @note For GTest tests: Errors are logged but don't fail the test (cleanup phase)
|
||||
* @note For standalone tests: Check return value and handle errors appropriately
|
||||
*/
|
||||
virtual void cleanupTest()
|
||||
{
|
||||
(void)cleanupTestCommunicator();
|
||||
}
|
||||
};
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
|
||||
#endif // MPI_TEST_CORE_HPP
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,455 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "nccl.h"
|
||||
#include "net.h"
|
||||
#include "transport.h"
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <utility>
|
||||
|
||||
/**
|
||||
* @file ResourceGuards.hpp
|
||||
* @brief Comprehensive RAII resource guards for automatic cleanup in tests
|
||||
*
|
||||
* Provides all RAII guard types for automatic resource management:
|
||||
* - ScopeGuard: Generic cleanup for any action (with lambdas)
|
||||
* - AutoGuard: Typed guards for resources with simple cleanup functions
|
||||
* - ResourceGuard: Typed guards for resources with stateful deleters
|
||||
* - Specialized guards: NcclRegHandleGuard, etc.
|
||||
*
|
||||
* Guards ensure cleanup even when ASSERT_* fails in tests.
|
||||
* See MPITestRunner.md for detailed usage documentation.
|
||||
*/
|
||||
|
||||
namespace RCCLTestGuards
|
||||
{
|
||||
|
||||
// ============================================================================
|
||||
// ScopeGuard - Generic cleanup for arbitrary actions
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* @class ScopeGuard
|
||||
* @brief Generic RAII scope guard for custom cleanup logic
|
||||
*
|
||||
* Executes a cleanup function on scope exit (normal return, early return, or exception).
|
||||
* Useful for resources that don't have dedicated RAII guards or for one-off cleanup needs.
|
||||
*
|
||||
* @par Example:
|
||||
* @code
|
||||
* void* buffer = nullptr;
|
||||
* hipMalloc(&buffer, size);
|
||||
* auto guard = makeScopeGuard([&]() { if(buffer) hipFree(buffer); });
|
||||
* // Automatic cleanup on scope exit
|
||||
* @endcode
|
||||
*
|
||||
* @tparam Func Callable type (lambda, function pointer, functor)
|
||||
*/
|
||||
template<typename Func>
|
||||
class ScopeGuard
|
||||
{
|
||||
Func cleanup_; ///< Cleanup function to execute on scope exit
|
||||
bool dismissed_; ///< If true, skip cleanup (for ownership transfer)
|
||||
|
||||
public:
|
||||
explicit ScopeGuard(Func f) noexcept : cleanup_(std::move(f)), dismissed_(false) {}
|
||||
|
||||
~ScopeGuard() noexcept
|
||||
{
|
||||
if(!dismissed_)
|
||||
{
|
||||
cleanup_();
|
||||
}
|
||||
}
|
||||
|
||||
void dismiss() noexcept { dismissed_ = true; }
|
||||
void restore() noexcept { dismissed_ = false; }
|
||||
|
||||
ScopeGuard(ScopeGuard&& other) noexcept
|
||||
: cleanup_(std::move(other.cleanup_)), dismissed_(other.dismissed_)
|
||||
{
|
||||
other.dismissed_ = true;
|
||||
}
|
||||
|
||||
ScopeGuard& operator=(ScopeGuard&& other) noexcept
|
||||
{
|
||||
if(this != &other)
|
||||
{
|
||||
if(!dismissed_)
|
||||
{
|
||||
cleanup_();
|
||||
}
|
||||
cleanup_ = std::move(other.cleanup_);
|
||||
dismissed_ = other.dismissed_;
|
||||
other.dismissed_ = true;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
ScopeGuard(const ScopeGuard&) = delete;
|
||||
ScopeGuard& operator=(const ScopeGuard&) = delete;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Factory function to create ScopeGuard with type deduction
|
||||
*
|
||||
* @par Example:
|
||||
* @code
|
||||
* auto guard = makeScopeGuard([&]() { cleanup(); });
|
||||
* @endcode
|
||||
*/
|
||||
template<typename Func>
|
||||
ScopeGuard<Func> makeScopeGuard(Func f)
|
||||
{
|
||||
return ScopeGuard<Func>(std::move(f));
|
||||
}
|
||||
|
||||
/**
|
||||
* @def SCOPE_EXIT
|
||||
* @brief Convenience macro for creating anonymous scope guards
|
||||
*
|
||||
* @par Example:
|
||||
* @code
|
||||
* void* buffer = nullptr;
|
||||
* hipMalloc(&buffer, size);
|
||||
* SCOPE_EXIT(if(buffer) hipFree(buffer));
|
||||
* @endcode
|
||||
*/
|
||||
#define SCOPE_EXIT_CONCAT_IMPL(a, b) a##b
|
||||
#define SCOPE_EXIT_CONCAT(a, b) SCOPE_EXIT_CONCAT_IMPL(a, b)
|
||||
#define SCOPE_EXIT(code) \
|
||||
auto SCOPE_EXIT_CONCAT(scope_guard_, __LINE__) = RCCLTestGuards::makeScopeGuard([&]() { code; })
|
||||
|
||||
// ============================================================================
|
||||
// AutoGuard & ResourceGuard - Typed resource management
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* @class AutoGuard
|
||||
* @brief Modern RAII guard using non-type template parameter for deleter
|
||||
*
|
||||
* Uses C++17's auto template parameters to directly reference cleanup functions,
|
||||
* eliminating the need for deleter functors in simple cases.
|
||||
*
|
||||
* @tparam T Resource handle type
|
||||
* @tparam DeleterFunc Function pointer for cleanup (auto-deduced)
|
||||
*/
|
||||
template<typename T, auto DeleterFunc>
|
||||
class AutoGuard
|
||||
{
|
||||
private:
|
||||
T resource_;
|
||||
bool dismissed_;
|
||||
|
||||
public:
|
||||
explicit AutoGuard(T resource = T{}) : resource_(resource), dismissed_(false) {}
|
||||
|
||||
~AutoGuard()
|
||||
{
|
||||
if(!dismissed_ && resource_)
|
||||
{
|
||||
DeleterFunc(resource_);
|
||||
}
|
||||
}
|
||||
|
||||
// Get the resource handle
|
||||
T get() const
|
||||
{
|
||||
return resource_;
|
||||
}
|
||||
// Get pointer to resource handle (for API calls)
|
||||
T* ptr()
|
||||
{
|
||||
return &resource_;
|
||||
}
|
||||
// Set the resource handle
|
||||
void set(T resource)
|
||||
{
|
||||
resource_ = resource;
|
||||
}
|
||||
// Dismiss the guard (prevent cleanup)
|
||||
void dismiss()
|
||||
{
|
||||
dismissed_ = true;
|
||||
}
|
||||
|
||||
// Release ownership (prevent cleanup)
|
||||
T release()
|
||||
{
|
||||
dismissed_ = true;
|
||||
return resource_;
|
||||
}
|
||||
|
||||
AutoGuard(const AutoGuard&) = delete;
|
||||
AutoGuard& operator=(const AutoGuard&) = delete;
|
||||
|
||||
AutoGuard(AutoGuard&& other) noexcept : resource_(other.resource_), dismissed_(other.dismissed_)
|
||||
{
|
||||
other.dismissed_ = true;
|
||||
}
|
||||
|
||||
AutoGuard& operator=(AutoGuard&& other) noexcept
|
||||
{
|
||||
if(this != &other)
|
||||
{
|
||||
if(!dismissed_ && resource_)
|
||||
{
|
||||
DeleterFunc(resource_);
|
||||
}
|
||||
resource_ = other.resource_;
|
||||
dismissed_ = other.dismissed_;
|
||||
other.dismissed_ = true;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @class ResourceGuard
|
||||
* @brief Generic RAII guard template for resources with complex cleanup
|
||||
*
|
||||
* Uses a functor-based deleter for stateful deleters requiring additional context.
|
||||
* For simple cleanup functions, prefer AutoGuard<T, func> instead.
|
||||
*
|
||||
* @tparam T Resource handle type
|
||||
* @tparam Deleter Functor type for cleanup
|
||||
*/
|
||||
template<typename T, typename Deleter>
|
||||
class ResourceGuard
|
||||
{
|
||||
private:
|
||||
T resource_;
|
||||
Deleter deleter_;
|
||||
bool owns_;
|
||||
|
||||
public:
|
||||
// Construct a resource guard
|
||||
// @param resource Resource handle (can be nullptr/0)
|
||||
// @param deleter Cleanup function/functor
|
||||
explicit ResourceGuard(T resource = T{}, Deleter deleter = Deleter{})
|
||||
: resource_(resource), deleter_(std::move(deleter)), owns_(true)
|
||||
{}
|
||||
|
||||
// Destructor - automatically cleans up resource
|
||||
~ResourceGuard()
|
||||
{
|
||||
if(owns_ && resource_)
|
||||
{
|
||||
deleter_(resource_);
|
||||
}
|
||||
}
|
||||
|
||||
// Get the resource handle
|
||||
T get() const
|
||||
{
|
||||
return resource_;
|
||||
}
|
||||
// Get pointer to resource handle (for API calls)
|
||||
T* ptr()
|
||||
{
|
||||
return &resource_;
|
||||
}
|
||||
// Set the resource handle
|
||||
void set(T resource)
|
||||
{
|
||||
resource_ = resource;
|
||||
}
|
||||
|
||||
// Reset the resource handle
|
||||
// @param resource New resource handle (can be nullptr/0)
|
||||
void reset(T resource = T{})
|
||||
{
|
||||
if(owns_ && resource_ && resource_ != resource)
|
||||
{
|
||||
deleter_(resource_);
|
||||
}
|
||||
resource_ = resource;
|
||||
owns_ = true;
|
||||
}
|
||||
|
||||
T release()
|
||||
{
|
||||
owns_ = false;
|
||||
return resource_;
|
||||
}
|
||||
|
||||
ResourceGuard(const ResourceGuard&) = delete;
|
||||
ResourceGuard& operator=(const ResourceGuard&) = delete;
|
||||
|
||||
ResourceGuard(ResourceGuard&& other) noexcept
|
||||
: resource_(other.resource_), deleter_(std::move(other.deleter_)), owns_(other.owns_)
|
||||
{
|
||||
other.owns_ = false;
|
||||
}
|
||||
|
||||
ResourceGuard& operator=(ResourceGuard&& other) noexcept
|
||||
{
|
||||
if(this != &other)
|
||||
{
|
||||
// Clean up current resource
|
||||
if(owns_ && resource_)
|
||||
{
|
||||
deleter_(resource_);
|
||||
}
|
||||
// Take ownership of other's resource
|
||||
resource_ = other.resource_;
|
||||
deleter_ = std::move(other.deleter_);
|
||||
owns_ = other.owns_;
|
||||
other.owns_ = false;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
// Note: Simple stateless deleters are replaced by wrapper functions + AutoGuard.
|
||||
// Only stateful deleters that need additional context are kept here.
|
||||
// Common deleters (NCCL-specific, used across many tests)
|
||||
struct NcclRegHandleDeleter
|
||||
{
|
||||
ncclComm_t comm;
|
||||
explicit NcclRegHandleDeleter(ncclComm_t c = nullptr) : comm(c) {}
|
||||
void operator()(void* reg_handle) const
|
||||
{
|
||||
if(reg_handle && comm)
|
||||
{
|
||||
ncclCommDeregister(comm, reg_handle);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Wrapper functions for AutoGuard (void-returning cleanup functions)
|
||||
inline void hipFreeWrapper(void* ptr)
|
||||
{
|
||||
if(ptr)
|
||||
{
|
||||
hipError_t err = hipFree(ptr);
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
fprintf(stderr,
|
||||
"WARNING: hipFree failed in destructor: %s (ptr=%p)\n",
|
||||
hipGetErrorString(err),
|
||||
ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void hipStreamDestroyWrapper(hipStream_t stream)
|
||||
{
|
||||
if(stream)
|
||||
{
|
||||
hipError_t err = hipStreamDestroy(stream);
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
fprintf(stderr,
|
||||
"WARNING: hipStreamDestroy failed in destructor: %s (stream=%p)\n",
|
||||
hipGetErrorString(err),
|
||||
static_cast<void*>(stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void hipEventDestroyWrapper(hipEvent_t event)
|
||||
{
|
||||
if(event)
|
||||
{
|
||||
hipError_t err = hipEventDestroy(event);
|
||||
if(err != hipSuccess)
|
||||
{
|
||||
fprintf(stderr,
|
||||
"WARNING: hipEventDestroy failed in destructor: %s (event=%p)\n",
|
||||
hipGetErrorString(err),
|
||||
static_cast<void*>(event));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void ncclCommDestroyWrapper(ncclComm_t comm)
|
||||
{
|
||||
if(comm)
|
||||
{
|
||||
ncclResult_t result = ncclCommDestroy(comm);
|
||||
if(result != ncclSuccess)
|
||||
{
|
||||
fprintf(stderr,
|
||||
"WARNING: ncclCommDestroy failed in destructor: %s (comm=%p)\n",
|
||||
ncclGetErrorString(result),
|
||||
static_cast<void*>(comm));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void freeWrapper(void* ptr)
|
||||
{
|
||||
if(ptr)
|
||||
free(ptr);
|
||||
}
|
||||
|
||||
// Type aliases for AutoGuard-based guards
|
||||
using HostBufferAutoGuard = AutoGuard<void*, freeWrapper>;
|
||||
using DeviceBufferAutoGuard = AutoGuard<void*, hipFreeWrapper>;
|
||||
using HipStreamAutoGuard = AutoGuard<hipStream_t, hipStreamDestroyWrapper>;
|
||||
using HipEventAutoGuard = AutoGuard<hipEvent_t, hipEventDestroyWrapper>;
|
||||
using NcclCommAutoGuard = AutoGuard<ncclComm_t, ncclCommDestroyWrapper>;
|
||||
|
||||
// Type aliases for ResourceGuard-based guards (common/NCCL-specific)
|
||||
using NcclRegHandleGuard = ResourceGuard<void*, NcclRegHandleDeleter>;
|
||||
|
||||
// Factory methods for ResourceGuard
|
||||
template<typename T, typename Deleter>
|
||||
inline auto makeGuard(T resource, Deleter deleter) -> ResourceGuard<T, Deleter>
|
||||
{
|
||||
return ResourceGuard<T, Deleter>(resource, std::move(deleter));
|
||||
}
|
||||
|
||||
inline NcclRegHandleGuard makeRegHandleGuard(void* handle, ncclComm_t comm)
|
||||
{
|
||||
return NcclRegHandleGuard(handle, NcclRegHandleDeleter(comm));
|
||||
}
|
||||
|
||||
template<typename T, typename Deleter>
|
||||
inline auto makeCustomGuard(T resource, Deleter deleter) -> ResourceGuard<T, Deleter>
|
||||
{
|
||||
return ResourceGuard<T, Deleter>(resource, std::move(deleter));
|
||||
}
|
||||
|
||||
// Factory methods for AutoGuard
|
||||
template<typename T, auto DeleterFunc>
|
||||
inline AutoGuard<T, DeleterFunc> makeAutoGuard(T resource)
|
||||
{
|
||||
return AutoGuard<T, DeleterFunc>(resource);
|
||||
}
|
||||
|
||||
inline HostBufferAutoGuard makeHostBufferAutoGuard(void* buffer)
|
||||
{
|
||||
return HostBufferAutoGuard(buffer);
|
||||
}
|
||||
|
||||
inline DeviceBufferAutoGuard makeDeviceBufferAutoGuard(void* buffer)
|
||||
{
|
||||
return DeviceBufferAutoGuard(buffer);
|
||||
}
|
||||
|
||||
inline HipStreamAutoGuard makeStreamAutoGuard(hipStream_t stream)
|
||||
{
|
||||
return HipStreamAutoGuard(stream);
|
||||
}
|
||||
|
||||
inline HipEventAutoGuard makeEventAutoGuard(hipEvent_t event)
|
||||
{
|
||||
return HipEventAutoGuard(event);
|
||||
}
|
||||
|
||||
inline NcclCommAutoGuard makeCommAutoGuard(ncclComm_t comm)
|
||||
{
|
||||
return NcclCommAutoGuard(comm);
|
||||
}
|
||||
|
||||
} // namespace RCCLTestGuards
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/**
|
||||
* @file TestChecks.cpp
|
||||
* @brief Implementation file for TestChecks.hpp
|
||||
*
|
||||
* Provides definitions for variables used by test logging macros.
|
||||
*/
|
||||
|
||||
#include "TestChecks.hpp"
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
// Define and initialize rcclTestDebugLevel for TEST_* macros
|
||||
|
||||
// This matches RCCL's debug level parsing logic from src/debug.cc
|
||||
// Values correspond to ncclDebugLogLevel enum in nccl_common.h:
|
||||
// - -1 = Uninitialized (treated as ERROR level)
|
||||
// - 0 = NCCL_LOG_NONE
|
||||
// - 1 = NCCL_LOG_VERSION
|
||||
// - 2 = NCCL_LOG_WARN
|
||||
// - 3 = NCCL_LOG_INFO
|
||||
// - 4 = NCCL_LOG_ABORT
|
||||
// - 5 = NCCL_LOG_TRACE
|
||||
int rcclTestDebugLevel = []() -> int {
|
||||
const char* env = std::getenv("NCCL_DEBUG");
|
||||
|
||||
// Default to ERROR level if not set (matches RCCL behavior)
|
||||
if (!env) return -1;
|
||||
|
||||
// Match RCCL's case-insensitive string comparison
|
||||
if (strcasecmp(env, "NONE") == 0) return 0;
|
||||
if (strcasecmp(env, "VERSION") == 0) return 1;
|
||||
if (strcasecmp(env, "WARN") == 0) return 2;
|
||||
if (strcasecmp(env, "INFO") == 0) return 3;
|
||||
if (strcasecmp(env, "ABORT") == 0) return 4;
|
||||
if (strcasecmp(env, "TRACE") == 0) return 5;
|
||||
|
||||
// Unknown value, default to ERROR level
|
||||
return -1;
|
||||
}();
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
|
||||
|
||||
@@ -0,0 +1,604 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/**
|
||||
* @file TestChecks.hpp
|
||||
* @brief Centralized test error checking and logging macros
|
||||
*
|
||||
* Provides all test-related macros for error checking, logging, and assertions:
|
||||
* - MPI error checking (MPICHECK with 3 overload variants)
|
||||
* - NCCL error checking (RCCL_TEST_CHECK, RCCL_TEST_CHECK_GTEST_FAIL)
|
||||
* - HIP error checking (HIPCHECK, HIP_TEST_CHECK, HIP_TEST_CHECK_GTEST_FAIL)
|
||||
* - MPI-aware assertions (ASSERT_MPI_*)
|
||||
* - Debug logging (TEST_WARN, TEST_INFO, TEST_ABORT, TEST_TRACE)
|
||||
*/
|
||||
|
||||
#ifndef RCCL_TEST_CHECKS_HPP
|
||||
#define RCCL_TEST_CHECKS_HPP
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <mpi.h>
|
||||
#include <rccl/rccl.h>
|
||||
#include "utils.h" // For RCCL's getHostName utility
|
||||
|
||||
// Forward declaration of MPIEnvironment class (defined in MPIEnvironment.hpp)
|
||||
class MPIEnvironment;
|
||||
|
||||
// Forward declarations for helper functions
|
||||
extern int rcclTestDebugLevel;
|
||||
inline int getTestDebugLevel();
|
||||
inline int getTestMpiRank();
|
||||
inline const char* getTestHostname();
|
||||
inline bool isMultiNodeTest();
|
||||
|
||||
// Helper function implementations
|
||||
inline int getTestDebugLevel()
|
||||
{
|
||||
return rcclTestDebugLevel;
|
||||
}
|
||||
|
||||
inline int getTestMpiRank()
|
||||
{
|
||||
int rank = -1;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
return rank;
|
||||
}
|
||||
|
||||
inline const char* getTestHostname()
|
||||
{
|
||||
static char hostname[256] = {0};
|
||||
static bool initialized = false;
|
||||
|
||||
if(!initialized)
|
||||
{
|
||||
// Use RCCL's getHostName utility to get short hostname (delimited by '.')
|
||||
if(getHostName(hostname, sizeof(hostname), '.') != ncclSuccess)
|
||||
{
|
||||
strncpy(hostname, "unknown", sizeof(hostname) - 1);
|
||||
}
|
||||
initialized = true;
|
||||
}
|
||||
return hostname;
|
||||
}
|
||||
|
||||
// Forward declaration of helper function to access MPIEnvironment state
|
||||
// (Defined in MPIEnvironment.cpp to avoid circular dependency)
|
||||
int getMPIEnvironmentCachedMultiNodeResult();
|
||||
|
||||
inline bool isMultiNodeTest()
|
||||
{
|
||||
// Return cached result from global environment
|
||||
// If not yet computed (== -1), assume single node to be safe
|
||||
return getMPIEnvironmentCachedMultiNodeResult() == 1;
|
||||
}
|
||||
|
||||
// NCCL Error Checking Macros
|
||||
|
||||
/**
|
||||
* @def RCCL_TEST_CHECK
|
||||
* @brief NCCL error checking macro for test infrastructure code
|
||||
*
|
||||
* Checks NCCL function calls and returns error code if failed.
|
||||
* Use in test setup/teardown and infrastructure code that returns ncclResult_t.
|
||||
*
|
||||
* Behavior:
|
||||
* - Checks NCCL function result
|
||||
* - Logs error to stderr
|
||||
* - Returns the error code to caller
|
||||
*
|
||||
* @note For GTest test bodies, use RCCL_TEST_CHECK_GTEST_FAIL instead
|
||||
*/
|
||||
#define RCCL_TEST_CHECK(cmd) \
|
||||
do \
|
||||
{ \
|
||||
ncclResult_t res = cmd; \
|
||||
if(res != ncclSuccess && res != ncclInProgress) \
|
||||
{ \
|
||||
fprintf(stderr, \
|
||||
"RCCL Error at %s:%d - %s\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
ncclGetErrorString(res)); \
|
||||
return res; \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def RCCL_TEST_CHECK_GTEST_FAIL
|
||||
* @brief RCCL error checking macro for GTest test bodies
|
||||
*
|
||||
* Checks NCCL function calls and fails the test if an error occurs.
|
||||
* Use in TEST_F/TEST_P test bodies.
|
||||
*
|
||||
* Behavior:
|
||||
* - Checks NCCL function result
|
||||
* - Prints error to stdout
|
||||
* - Calls FAIL() to mark test as failed
|
||||
*
|
||||
* @note For infrastructure code (setup/teardown), use RCCL_TEST_CHECK instead
|
||||
*/
|
||||
#define RCCL_TEST_CHECK_GTEST_FAIL(cmd) \
|
||||
do \
|
||||
{ \
|
||||
ncclResult_t res = cmd; \
|
||||
if(res != ncclSuccess) \
|
||||
{ \
|
||||
printf("RCCL Error at %s:%d - %s\n", __FILE__, __LINE__, ncclGetErrorString(res)); \
|
||||
FAIL() << "RCCL Error: " << ncclGetErrorString(res); \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
// HIP Error Checking Macros
|
||||
|
||||
/**
|
||||
* @def HIP_TEST_CHECK
|
||||
* @brief HIP error checking macro for test infrastructure code
|
||||
*
|
||||
* Checks HIP function calls and returns ncclUnhandledCudaError if failed.
|
||||
* Use in test setup/teardown and infrastructure code that returns ncclResult_t.
|
||||
*
|
||||
* Behavior:
|
||||
* - Checks HIP function result
|
||||
* - Logs error to stderr
|
||||
* - Returns ncclUnhandledCudaError to caller
|
||||
*
|
||||
* @note Requires enclosing function to return ncclResult_t
|
||||
* @note For test bodies, use HIP_TEST_CHECK_GTEST_FAIL instead
|
||||
*/
|
||||
#define HIP_TEST_CHECK(cmd) \
|
||||
do \
|
||||
{ \
|
||||
hipError_t err = cmd; \
|
||||
if(err != hipSuccess) \
|
||||
{ \
|
||||
fprintf(stderr, \
|
||||
"HIP Error at %s:%d - %s (hipError_t=%d)\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
hipGetErrorString(err), \
|
||||
static_cast<int>(err)); \
|
||||
return ncclUnhandledCudaError; \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def HIPCHECK
|
||||
* @brief HIP error checking macro (library-style)
|
||||
*
|
||||
* Similar to RCCL library's CUDACHECK macro. Returns ncclUnhandledCudaError on error.
|
||||
* Use in any code that returns ncclResult_t.
|
||||
*
|
||||
* Behavior:
|
||||
* - Checks HIP function result
|
||||
* - Logs error to stderr
|
||||
* - Returns ncclUnhandledCudaError to caller
|
||||
*
|
||||
* @note Requires enclosing function to return ncclResult_t
|
||||
* @note For GTest test bodies, use HIP_TEST_CHECK_GTEST_FAIL instead
|
||||
* @note This mirrors the library's CUDACHECK behavior
|
||||
*/
|
||||
#ifndef HIPCHECK
|
||||
#define HIPCHECK(cmd) \
|
||||
do \
|
||||
{ \
|
||||
hipError_t err = cmd; \
|
||||
if(err != hipSuccess) \
|
||||
{ \
|
||||
fprintf(stderr, \
|
||||
"HIP Error at %s:%d - %s (hipError_t=%d)\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
hipGetErrorString(err), \
|
||||
static_cast<int>(err)); \
|
||||
return ncclUnhandledCudaError; \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
#endif // HIPCHECK
|
||||
|
||||
/**
|
||||
* @def HIP_TEST_CHECK_GTEST_FAIL
|
||||
* @brief HIP error checking for GTest test bodies
|
||||
*
|
||||
* Checks HIP function calls and fails the test if an error occurs.
|
||||
* Use in TEST_F/TEST_P test bodies.
|
||||
*
|
||||
* Behavior:
|
||||
* - Checks HIP function result
|
||||
* - Prints error to stdout
|
||||
* - Calls FAIL() to mark test as failed
|
||||
*
|
||||
* @note For infrastructure code, use HIPCHECK or HIP_TEST_CHECK instead
|
||||
*/
|
||||
#define HIP_TEST_CHECK_GTEST_FAIL(cmd) \
|
||||
do \
|
||||
{ \
|
||||
hipError_t err = cmd; \
|
||||
if(err != hipSuccess) \
|
||||
{ \
|
||||
printf("HIP Error at %s:%d - %s\n", __FILE__, __LINE__, hipGetErrorString(err)); \
|
||||
FAIL() << "HIP Error: " << hipGetErrorString(err); \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
// Debug Logging Macros (TEST_*)
|
||||
|
||||
/**
|
||||
* @def TEST_WARN
|
||||
* @brief Warning-level logging macro
|
||||
*
|
||||
* Prints warning messages when NCCL_DEBUG=WARN or higher.
|
||||
* Automatically includes rank and hostname prefixes.
|
||||
*/
|
||||
#define TEST_WARN(...) \
|
||||
do \
|
||||
{ \
|
||||
if(getTestDebugLevel() >= 2) \
|
||||
{ \
|
||||
int rank = getTestMpiRank(); \
|
||||
if(isMultiNodeTest()) \
|
||||
{ \
|
||||
printf("%s:[%d] TEST WARN ", getTestHostname(), rank); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
printf("[%d] TEST WARN ", rank); \
|
||||
} \
|
||||
printf(__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
fflush(stdout); \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def TEST_INFO
|
||||
* @brief Info-level logging macro
|
||||
*
|
||||
* Prints informational messages when NCCL_DEBUG=INFO or higher.
|
||||
* Automatically includes rank and hostname prefixes.
|
||||
*/
|
||||
#define TEST_INFO(...) \
|
||||
do \
|
||||
{ \
|
||||
if(getTestDebugLevel() >= 3) \
|
||||
{ \
|
||||
int rank = getTestMpiRank(); \
|
||||
if(isMultiNodeTest()) \
|
||||
{ \
|
||||
printf("%s:[%d] TEST INFO ", getTestHostname(), rank); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
printf("[%d] TEST INFO ", rank); \
|
||||
} \
|
||||
printf(__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
fflush(stdout); \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def TEST_ABORT
|
||||
* @brief Abort-level logging macro
|
||||
*
|
||||
* Prints abort-level messages when NCCL_DEBUG=ABORT or higher.
|
||||
* Automatically includes rank and hostname prefixes.
|
||||
*/
|
||||
#define TEST_ABORT(...) \
|
||||
do \
|
||||
{ \
|
||||
if(getTestDebugLevel() >= 4) \
|
||||
{ \
|
||||
int rank = getTestMpiRank(); \
|
||||
if(isMultiNodeTest()) \
|
||||
{ \
|
||||
printf("%s:[%d] TEST ABORT ", getTestHostname(), rank); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
printf("[%d] TEST ABORT ", rank); \
|
||||
} \
|
||||
printf(__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
fflush(stdout); \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def TEST_TRACE
|
||||
* @brief Trace-level logging macro
|
||||
*
|
||||
* Prints trace messages when NCCL_DEBUG=TRACE.
|
||||
* Automatically includes rank and hostname prefixes.
|
||||
*/
|
||||
#define TEST_TRACE(...) \
|
||||
do \
|
||||
{ \
|
||||
if(getTestDebugLevel() >= 5) \
|
||||
{ \
|
||||
int rank = getTestMpiRank(); \
|
||||
if(isMultiNodeTest()) \
|
||||
{ \
|
||||
printf("%s:[%d] TEST TRACE ", getTestHostname(), rank); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
printf("[%d] TEST TRACE ", rank); \
|
||||
} \
|
||||
printf(__VA_ARGS__); \
|
||||
printf("\n"); \
|
||||
fflush(stdout); \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
// MPI-Aware Assertion Macros (ASSERT_MPI_*)
|
||||
|
||||
/**
|
||||
* @def ASSERT_MPI_TRUE
|
||||
* @brief MPI-aware version of ASSERT_TRUE
|
||||
*
|
||||
* Checks condition on all ranks. If ANY rank fails, ALL ranks skip together
|
||||
* to prevent deadlock. This is critical for MPI tests where collective
|
||||
* operations require all ranks to participate.
|
||||
*
|
||||
* Behavior:
|
||||
* - Evaluates condition on each rank
|
||||
* - Uses MPI_Allreduce to check if all ranks passed
|
||||
* - If any rank fails, all ranks call GTEST_SKIP() together
|
||||
*
|
||||
* @param condition The condition to test
|
||||
*/
|
||||
#define ASSERT_MPI_TRUE(condition) \
|
||||
do \
|
||||
{ \
|
||||
bool _local_pass = static_cast<bool>(condition); \
|
||||
int _local_status = _local_pass ? 1 : 0; \
|
||||
int _global_status = 0; \
|
||||
MPI_Allreduce(&_local_status, &_global_status, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD); \
|
||||
\
|
||||
if(_global_status == 0) \
|
||||
{ \
|
||||
/* At least one rank failed */ \
|
||||
if(!_local_pass) \
|
||||
{ \
|
||||
/* This rank failed - show the actual error */ \
|
||||
EXPECT_TRUE(condition) \
|
||||
<< "Rank " << MPIEnvironment::world_rank << " failed assertion"; \
|
||||
} \
|
||||
/* All ranks skip together */ \
|
||||
GTEST_SKIP() \
|
||||
<< "Rank " << MPIEnvironment::world_rank \
|
||||
<< ": Skipping test due to failure on at least one rank (synchronized exit)"; \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def ASSERT_MPI_FALSE
|
||||
* @brief MPI-aware version of ASSERT_FALSE
|
||||
*/
|
||||
#define ASSERT_MPI_FALSE(condition) ASSERT_MPI_TRUE(!(condition))
|
||||
|
||||
/**
|
||||
* @def ASSERT_MPI_EQ
|
||||
* @brief MPI-aware version of ASSERT_EQ
|
||||
*
|
||||
* Checks if val1 == val2 on all ranks. If ANY rank fails,
|
||||
* ALL ranks skip together to prevent deadlock.
|
||||
*
|
||||
* @param val1 First value
|
||||
* @param val2 Second value
|
||||
*/
|
||||
#define ASSERT_MPI_EQ(val1, val2) \
|
||||
do \
|
||||
{ \
|
||||
auto _v1 = (val1); \
|
||||
auto _v2 = (val2); \
|
||||
bool _local_pass = (_v1 == _v2); \
|
||||
int _local_status = _local_pass ? 1 : 0; \
|
||||
int _global_status = 0; \
|
||||
MPI_Allreduce(&_local_status, &_global_status, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD); \
|
||||
\
|
||||
if(_global_status == 0) \
|
||||
{ \
|
||||
if(!_local_pass) \
|
||||
{ \
|
||||
EXPECT_EQ(_v1, _v2) \
|
||||
<< "Rank " << MPIEnvironment::world_rank << " failed assertion"; \
|
||||
} \
|
||||
GTEST_SKIP() \
|
||||
<< "Rank " << MPIEnvironment::world_rank \
|
||||
<< ": Skipping test due to failure on at least one rank (synchronized exit)"; \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def ASSERT_MPI_NE
|
||||
* @brief MPI-aware version of ASSERT_NE
|
||||
*
|
||||
* @param val1 First value
|
||||
* @param val2 Second value
|
||||
*/
|
||||
#define ASSERT_MPI_NE(val1, val2) \
|
||||
do \
|
||||
{ \
|
||||
auto _v1 = (val1); \
|
||||
auto _v2 = (val2); \
|
||||
bool _local_pass = (_v1 != _v2); \
|
||||
int _local_status = _local_pass ? 1 : 0; \
|
||||
int _global_status = 0; \
|
||||
MPI_Allreduce(&_local_status, &_global_status, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD); \
|
||||
\
|
||||
if(_global_status == 0) \
|
||||
{ \
|
||||
if(!_local_pass) \
|
||||
{ \
|
||||
EXPECT_NE(_v1, _v2) \
|
||||
<< "Rank " << MPIEnvironment::world_rank << " failed assertion"; \
|
||||
} \
|
||||
GTEST_SKIP() \
|
||||
<< "Rank " << MPIEnvironment::world_rank \
|
||||
<< ": Skipping test due to failure on at least one rank (synchronized exit)"; \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def ASSERT_MPI_SUCCESS
|
||||
* @brief MPI-aware assertion for MPI operations
|
||||
*
|
||||
* Checks if MPI operation succeeded on all ranks. If ANY rank fails,
|
||||
* ALL ranks skip together. Provides better error messages for MPI operations.
|
||||
*
|
||||
* @param expr Expression that returns an MPI error code
|
||||
*/
|
||||
#define ASSERT_MPI_SUCCESS(expr) \
|
||||
do \
|
||||
{ \
|
||||
int _result = (expr); \
|
||||
bool _local_pass = (_result == MPI_SUCCESS); \
|
||||
int _local_status = _local_pass ? 1 : 0; \
|
||||
int _global_status = 0; \
|
||||
MPI_Allreduce(&_local_status, &_global_status, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD); \
|
||||
\
|
||||
if(_global_status == 0) \
|
||||
{ \
|
||||
if(!_local_pass) \
|
||||
{ \
|
||||
char _error_string[MPI_MAX_ERROR_STRING]; \
|
||||
int _len; \
|
||||
MPI_Error_string(_result, _error_string, &_len); \
|
||||
EXPECT_EQ(_result, MPI_SUCCESS) << "Rank " << MPIEnvironment::world_rank \
|
||||
<< " failed MPI operation: " << _error_string; \
|
||||
} \
|
||||
GTEST_SKIP() << "Rank " << MPIEnvironment::world_rank \
|
||||
<< ": Skipping test due to MPI failure on at least one rank"; \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
// MPI Error Checking Macros (MPICHECK)
|
||||
|
||||
/**
|
||||
* @def MPICHECK
|
||||
* @brief Context-aware MPI error checking macro with overloaded behavior
|
||||
*
|
||||
* Provides three usage modes depending on context:
|
||||
*
|
||||
* @par Usage Modes:
|
||||
* - `MPICHECK(cmd)` - Normal test code: Fails test with FAIL() on error
|
||||
* - `MPICHECK(cmd, rank)` - Cleanup code: Calls MPI_Abort() on error
|
||||
* - `MPICHECK(cmd, rank, true)` - MPI_Finalize: Calls std::exit() on error
|
||||
*
|
||||
* @par Example:
|
||||
* @code
|
||||
* // In test body
|
||||
* MPICHECK(MPI_Barrier(MPI_COMM_WORLD));
|
||||
*
|
||||
* // In cleanup code
|
||||
* MPICHECK(MPI_Barrier(MPI_COMM_WORLD), world_rank);
|
||||
*
|
||||
* // During finalization
|
||||
* MPICHECK(MPI_Finalize(), world_rank, true);
|
||||
* @endcode
|
||||
*
|
||||
* @note Prints detailed error message including file, line, and MPI error string
|
||||
*/
|
||||
|
||||
// Helper macros for argument counting
|
||||
#define MPICHECK_GET_MACRO(_1, _2, _3, NAME, ...) NAME
|
||||
#define MPICHECK(...) \
|
||||
MPICHECK_GET_MACRO(__VA_ARGS__, MPICHECK_3, MPICHECK_2, MPICHECK_1)(__VA_ARGS__)
|
||||
|
||||
/**
|
||||
* @def MPICHECK_1
|
||||
* @brief 1-argument version: Normal test code (uses FAIL())
|
||||
* @hideinitializer
|
||||
*/
|
||||
#define MPICHECK_1(cmd) \
|
||||
do \
|
||||
{ \
|
||||
int err = cmd; \
|
||||
if(err != MPI_SUCCESS) \
|
||||
{ \
|
||||
char error_string[MPI_MAX_ERROR_STRING]; \
|
||||
int length; \
|
||||
MPI_Error_string(err, error_string, &length); \
|
||||
printf("MPI Error at %s:%d - %s\n", __FILE__, __LINE__, error_string); \
|
||||
FAIL() << "MPI Error: " << error_string; \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def MPICHECK_2
|
||||
* @brief 2-argument version: Cleanup code (uses MPI_Abort())
|
||||
* @hideinitializer
|
||||
*/
|
||||
#define MPICHECK_2(cmd, rank) \
|
||||
do \
|
||||
{ \
|
||||
int err = cmd; \
|
||||
if(err != MPI_SUCCESS) \
|
||||
{ \
|
||||
char error_string[MPI_MAX_ERROR_STRING]; \
|
||||
int length; \
|
||||
MPI_Error_string(err, error_string, &length); \
|
||||
std::fprintf(stderr, \
|
||||
"Rank %d: MPI Error at %s:%d - %s\n", \
|
||||
rank, \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
error_string); \
|
||||
std::fflush(stderr); \
|
||||
MPI_Abort(MPI_COMM_WORLD, err); \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
/**
|
||||
* @def MPICHECK_3
|
||||
* @brief 3-argument version: MPI_Finalize (uses std::exit())
|
||||
* @hideinitializer
|
||||
*/
|
||||
#define MPICHECK_3(cmd, rank, is_finalize) \
|
||||
do \
|
||||
{ \
|
||||
int err = cmd; \
|
||||
if(err != MPI_SUCCESS) \
|
||||
{ \
|
||||
char error_string[MPI_MAX_ERROR_STRING]; \
|
||||
int length; \
|
||||
MPI_Error_string(err, error_string, &length); \
|
||||
std::fprintf(stderr, \
|
||||
"Rank %d: MPI_Finalize Error at %s:%d - %s\n", \
|
||||
rank, \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
error_string); \
|
||||
std::fflush(stderr); \
|
||||
std::exit(err); \
|
||||
} \
|
||||
} \
|
||||
while(0)
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
|
||||
#endif // RCCL_TEST_CHECKS_HPP
|
||||
@@ -0,0 +1,85 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
/**
|
||||
* @file main_mpi.cpp
|
||||
* @brief Main entry point for Google Test-based MPI tests
|
||||
*
|
||||
* This file provides the main() function for running GTest-based MPI tests.
|
||||
* For standalone tests (performance benchmarks, etc.), each test should have
|
||||
* its own main() function and use MPIHelpers for common functionality.
|
||||
*/
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
#include "MPIHelpers.hpp"
|
||||
#include "MPITestBase.hpp"
|
||||
#include "MPIEnvironment.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// Initialize MPI using shared helper
|
||||
auto mpi_ctx = MPIHelpers::initializeMPI(&argc, &argv);
|
||||
|
||||
const auto world_rank = mpi_ctx.world_rank;
|
||||
const auto world_size = mpi_ctx.world_size;
|
||||
|
||||
// Setup per-rank logging using shared helper
|
||||
auto rank_log_config = MPIHelpers::setupRankLogging(world_rank);
|
||||
const auto per_rank_logging_enabled = rank_log_config && rank_log_config->logging_enabled;
|
||||
|
||||
// Print initialization message
|
||||
if(world_rank == 0 && !per_rank_logging_enabled)
|
||||
{
|
||||
TEST_INFO("MPI initialized - World size: %d, Thread support: %d",
|
||||
world_size,
|
||||
mpi_ctx.thread_support);
|
||||
}
|
||||
|
||||
// Initialize Google Test
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
|
||||
// Suppress GTest output for non-zero ranks (unless per-rank logging is enabled)
|
||||
// This is done by deleting GTest listeners for non-zero ranks
|
||||
// Note: stdout/stderr are already redirected for non-zero ranks by setupRankLogging
|
||||
if(world_rank != 0 && !per_rank_logging_enabled)
|
||||
{
|
||||
auto& listeners = ::testing::UnitTest::GetInstance()->listeners();
|
||||
delete listeners.Release(listeners.default_result_printer());
|
||||
delete listeners.Release(listeners.default_xml_generator());
|
||||
}
|
||||
|
||||
// Set up the RCCL MPI environment for all tests
|
||||
::testing::AddGlobalTestEnvironment(new MPIEnvironment());
|
||||
|
||||
// Run all tests
|
||||
const auto ret_code = RUN_ALL_TESTS();
|
||||
|
||||
// Restore original output if per-rank logging was enabled
|
||||
if(rank_log_config)
|
||||
{
|
||||
MPIHelpers::restoreRankLogging(*rank_log_config);
|
||||
}
|
||||
|
||||
// MPI_Finalize called by MPIEnvironment destructor
|
||||
return ret_code;
|
||||
}
|
||||
|
||||
#else // MPI_TESTS_ENABLED not defined
|
||||
|
||||
int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
|
||||
{
|
||||
std::fprintf(stderr,
|
||||
"ERROR: MPI tests are not enabled. Please build with ENABLE_MPI_TESTS=ON\n");
|
||||
std::fprintf(stderr, "Usage: cmake -DENABLE_MPI_TESTS=ON -DMPI_PATH=/path/to/mpi ..\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,522 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "DeviceBufferHelpers.hpp"
|
||||
#include "TestChecks.hpp"
|
||||
#include "TransportMPIBase.hpp"
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
// Import MPI test constants
|
||||
using namespace MPITestConstants;
|
||||
using namespace RCCLTestHelpers;
|
||||
|
||||
// NET-specific RAII deleters
|
||||
namespace RCCLTestGuards
|
||||
{
|
||||
|
||||
struct NetMHandleDeleter
|
||||
{
|
||||
ncclNet_t* net;
|
||||
void* comm;
|
||||
NetMHandleDeleter(ncclNet_t* n = nullptr, void* c = nullptr) : net(n), comm(c) {}
|
||||
void operator()(void* mhandle) const
|
||||
{
|
||||
if(mhandle && comm && net)
|
||||
{
|
||||
net->deregMr(comm, mhandle);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct NetSendCommDeleter
|
||||
{
|
||||
ncclNet_t* net;
|
||||
explicit NetSendCommDeleter(ncclNet_t* n = nullptr) : net(n) {}
|
||||
void operator()(void* comm) const
|
||||
{
|
||||
if(comm && net)
|
||||
net->closeSend(comm);
|
||||
}
|
||||
};
|
||||
|
||||
struct NetRecvCommDeleter
|
||||
{
|
||||
ncclNet_t* net;
|
||||
explicit NetRecvCommDeleter(ncclNet_t* n = nullptr) : net(n) {}
|
||||
void operator()(void* comm) const
|
||||
{
|
||||
if(comm && net)
|
||||
net->closeRecv(comm);
|
||||
}
|
||||
};
|
||||
|
||||
struct NetListenCommDeleter
|
||||
{
|
||||
ncclNet_t* net;
|
||||
explicit NetListenCommDeleter(ncclNet_t* n = nullptr) : net(n) {}
|
||||
void operator()(void* comm) const
|
||||
{
|
||||
if(comm && net)
|
||||
net->closeListen(comm);
|
||||
}
|
||||
};
|
||||
|
||||
using NetMHandleGuard = ResourceGuard<void*, NetMHandleDeleter>;
|
||||
using NetSendCommGuard = ResourceGuard<void*, NetSendCommDeleter>;
|
||||
using NetRecvCommGuard = ResourceGuard<void*, NetRecvCommDeleter>;
|
||||
using NetListenCommGuard = ResourceGuard<void*, NetListenCommDeleter>;
|
||||
|
||||
class NetConnectionGuard
|
||||
{
|
||||
private:
|
||||
ncclNet_t* net_;
|
||||
void* listen_comm_;
|
||||
void* send_comm_;
|
||||
void* recv_comm_;
|
||||
|
||||
public:
|
||||
explicit NetConnectionGuard(ncclNet_t* net)
|
||||
: net_(net), listen_comm_(nullptr), send_comm_(nullptr), recv_comm_(nullptr)
|
||||
{}
|
||||
|
||||
~NetConnectionGuard()
|
||||
{
|
||||
if(recv_comm_ && net_)
|
||||
net_->closeRecv(recv_comm_);
|
||||
if(send_comm_ && net_)
|
||||
net_->closeSend(send_comm_);
|
||||
if(listen_comm_ && net_)
|
||||
net_->closeListen(listen_comm_);
|
||||
}
|
||||
|
||||
void setListenComm(void* comm)
|
||||
{
|
||||
listen_comm_ = comm;
|
||||
}
|
||||
void setSendComm(void* comm)
|
||||
{
|
||||
send_comm_ = comm;
|
||||
}
|
||||
void setRecvComm(void* comm)
|
||||
{
|
||||
recv_comm_ = comm;
|
||||
}
|
||||
|
||||
void* getListenComm() const
|
||||
{
|
||||
return listen_comm_;
|
||||
}
|
||||
void* getSendComm() const
|
||||
{
|
||||
return send_comm_;
|
||||
}
|
||||
void* getRecvComm() const
|
||||
{
|
||||
return recv_comm_;
|
||||
}
|
||||
|
||||
NetConnectionGuard(const NetConnectionGuard&) = delete;
|
||||
NetConnectionGuard& operator=(const NetConnectionGuard&) = delete;
|
||||
NetConnectionGuard(NetConnectionGuard&&) = delete;
|
||||
NetConnectionGuard& operator=(NetConnectionGuard&&) = delete;
|
||||
};
|
||||
|
||||
inline NetMHandleGuard makeNetMHandleGuard(void* mhandle, ncclNet_t* net, void* comm)
|
||||
{
|
||||
return NetMHandleGuard(mhandle, NetMHandleDeleter(net, comm));
|
||||
}
|
||||
|
||||
inline NetSendCommGuard makeNetSendCommGuard(void* comm, ncclNet_t* net)
|
||||
{
|
||||
return NetSendCommGuard(comm, NetSendCommDeleter(net));
|
||||
}
|
||||
|
||||
inline NetRecvCommGuard makeNetRecvCommGuard(void* comm, ncclNet_t* net)
|
||||
{
|
||||
return NetRecvCommGuard(comm, NetRecvCommDeleter(net));
|
||||
}
|
||||
|
||||
inline NetListenCommGuard makeNetListenCommGuard(void* comm, ncclNet_t* net)
|
||||
{
|
||||
return NetListenCommGuard(comm, NetListenCommDeleter(net));
|
||||
}
|
||||
|
||||
} // namespace RCCLTestGuards
|
||||
|
||||
namespace
|
||||
{
|
||||
// Buffer size constants
|
||||
inline constexpr size_t kTestBufferSize = 16384;
|
||||
|
||||
// NET transport test requirements
|
||||
inline constexpr int kMinNodesForNET = 2; // NET transport requires at least 2 nodes
|
||||
inline constexpr int kExactRanksForNET = 2; // NET transport tests use exactly 2 ranks (1 per node)
|
||||
|
||||
// Test pattern generation constants
|
||||
inline constexpr int kDefaultPatternMultiplier = 100; // For NET transport patterns
|
||||
inline constexpr int kByteValueModulo = 256; // For uint8_t wraparound
|
||||
|
||||
} // namespace
|
||||
|
||||
class NetTransportMPITest : public TransportTestBase
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
TransportTestBase::SetUp();
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("NetTransport SetUp completed");
|
||||
}
|
||||
}
|
||||
|
||||
void TearDown() override
|
||||
{
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("NetTransport TearDown completed");
|
||||
}
|
||||
TransportTestBase::TearDown();
|
||||
}
|
||||
|
||||
public:
|
||||
// Test ncclNetGraphRegisterBuffer
|
||||
void testNetGraphRegisterBuffer()
|
||||
{
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Testing ncclNetGraphRegisterBuffer...");
|
||||
}
|
||||
|
||||
// Verify communicator is ready
|
||||
ASSERT_NE(comm_handle, nullptr) << "Rank " << config.world_rank << ": comm_handle is null";
|
||||
|
||||
// Allocate and automatically guard buffers
|
||||
void* send_buffer = nullptr;
|
||||
void* recv_buffer = nullptr;
|
||||
allocateAndInitBuffersGuarded(&send_buffer, &recv_buffer, kTestBufferSize, kTestBufferSize);
|
||||
|
||||
// Register and automatically guard handles
|
||||
void* send_reg_handle = nullptr;
|
||||
void* recv_reg_handle = nullptr;
|
||||
preRegisterBuffersGuarded(send_buffer,
|
||||
recv_buffer,
|
||||
kTestBufferSize,
|
||||
kTestBufferSize,
|
||||
&send_reg_handle,
|
||||
&recv_reg_handle);
|
||||
|
||||
// Test ncclNetGraphRegisterBuffer
|
||||
int net_reg_flag{};
|
||||
void* net_handle{};
|
||||
ncclIntruQueue<ncclCommCallback, &ncclCommCallback::next> cleanup_queue{};
|
||||
int n_cleanup_elts{};
|
||||
|
||||
ncclConnector* send_conn_array[1] = {&send_connector};
|
||||
|
||||
auto nccl_result
|
||||
= ncclNetGraphRegisterBuffer(reinterpret_cast<ncclComm*>(getActiveCommunicator()),
|
||||
send_buffer,
|
||||
kTestBufferSize,
|
||||
send_conn_array,
|
||||
1,
|
||||
&net_reg_flag,
|
||||
&net_handle,
|
||||
&cleanup_queue,
|
||||
&n_cleanup_elts);
|
||||
|
||||
EXPECT_EQ(ncclSuccess, nccl_result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": ncclNetGraphRegisterBuffer failed: " << ncclGetErrorString(nccl_result);
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO(" ncclNetGraphRegisterBuffer returned: %s",
|
||||
ncclGetErrorString(nccl_result));
|
||||
TEST_INFO(" Registration flag: %d", net_reg_flag);
|
||||
TEST_INFO(" Handle: %p", net_handle);
|
||||
TEST_INFO(" Cleanup queue elements: %d", n_cleanup_elts);
|
||||
}
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("ncclNetGraphRegisterBuffer test completed");
|
||||
}
|
||||
}
|
||||
|
||||
// Test ncclNetLocalRegisterBuffer
|
||||
void testNetLocalRegisterBuffer()
|
||||
{
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Testing ncclNetLocalRegisterBuffer...");
|
||||
TEST_INFO("This API internally calls ncclNetLocalRegisterBuffer "
|
||||
"and ncclNetLocalRegisterBuffer");
|
||||
}
|
||||
|
||||
// Verify communicator is ready (NCCL has already initialized NET transport)
|
||||
ASSERT_NE(comm_handle, nullptr) << "Rank " << config.world_rank << ": comm_handle is null";
|
||||
|
||||
// Allocate and automatically guard buffers
|
||||
void* send_buffer = nullptr;
|
||||
void* recv_buffer = nullptr;
|
||||
allocateAndInitBuffersGuarded(&send_buffer, &recv_buffer, kTestBufferSize, kTestBufferSize);
|
||||
|
||||
// Register and automatically guard handles
|
||||
void* send_reg_handle = nullptr;
|
||||
void* recv_reg_handle = nullptr;
|
||||
preRegisterBuffersGuarded(send_buffer,
|
||||
recv_buffer,
|
||||
kTestBufferSize,
|
||||
kTestBufferSize,
|
||||
&send_reg_handle,
|
||||
&recv_reg_handle);
|
||||
|
||||
// Test ncclNetLocalRegisterBuffer
|
||||
int net_reg_flag{};
|
||||
void* net_handle{};
|
||||
|
||||
ncclConnector* send_conn_array[1] = {&send_connector};
|
||||
|
||||
auto nccl_result
|
||||
= ncclNetLocalRegisterBuffer(reinterpret_cast<ncclComm*>(getActiveCommunicator()),
|
||||
send_buffer,
|
||||
kTestBufferSize,
|
||||
send_conn_array,
|
||||
1, // nPeers
|
||||
&net_reg_flag,
|
||||
&net_handle);
|
||||
|
||||
EXPECT_EQ(ncclSuccess, nccl_result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": ncclNetLocalRegisterBuffer failed: " << ncclGetErrorString(nccl_result);
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO(" ncclNetLocalRegisterBuffer returned: %s",
|
||||
ncclGetErrorString(nccl_result));
|
||||
TEST_INFO(" Registration flag: %d", net_reg_flag);
|
||||
TEST_INFO(" Handle: %p", net_handle);
|
||||
}
|
||||
}
|
||||
|
||||
// Test multiple buffer sizes with actual data transfer
|
||||
void testMultipleBufferSizes()
|
||||
{
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Testing multiple buffer sizes (aligned and unaligned) with NET "
|
||||
"transport and data transfer...");
|
||||
}
|
||||
|
||||
// Verify communicator is ready
|
||||
ASSERT_NE(comm_handle, nullptr) << "Rank " << config.world_rank << ": comm_handle is null";
|
||||
|
||||
// Test both aligned and unaligned buffer sizes to validate edge cases
|
||||
std::vector<size_t> sizes = {
|
||||
// Small sizes (including unaligned)
|
||||
1, // Minimum size
|
||||
3, // Unaligned (not power of 2)
|
||||
7, // Unaligned
|
||||
15, // Unaligned
|
||||
63, // Unaligned
|
||||
|
||||
// Medium sizes (mix of aligned and unaligned)
|
||||
1024, // 1KB (aligned)
|
||||
1025, // 1KB + 1 (unaligned)
|
||||
1536, // 1.5KB (unaligned)
|
||||
4096, // 4KB (aligned)
|
||||
4097, // 4KB + 1 (unaligned)
|
||||
5000, // Unaligned
|
||||
16384, // 16KB (aligned)
|
||||
16385, // 16KB + 1 (unaligned)
|
||||
|
||||
// Large sizes (mix of aligned and unaligned)
|
||||
65536, // 64KB (aligned)
|
||||
65537, // 64KB + 1 (unaligned)
|
||||
100000, // ~98KB (unaligned)
|
||||
262144, // 256KB (aligned)
|
||||
262145, // 256KB + 1 (unaligned)
|
||||
500000, // ~488KB (unaligned)
|
||||
1048576, // 1MB (aligned)
|
||||
1048577, // 1MB + 1 (unaligned)
|
||||
4 * 1024 * 1024, // 4MB (aligned)
|
||||
4 * 1024 * 1024 + 1 // 4MB + 1 (unaligned)
|
||||
};
|
||||
|
||||
int peer_rank = (config.world_rank == 0) ? 1 : 0;
|
||||
hipStream_t stream = getActiveStream();
|
||||
ASSERT_NE(stream, nullptr) << "Rank " << config.world_rank << ": Stream is null";
|
||||
|
||||
for(size_t size : sizes)
|
||||
{
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO(" Testing size: %zu bytes with data transfer", size);
|
||||
}
|
||||
|
||||
// Allocate buffers with local guards (per-iteration cleanup)
|
||||
void* send_buffer = nullptr;
|
||||
void* recv_buffer = nullptr;
|
||||
auto [sendGuard, recvGuard]
|
||||
= allocateAndInitBuffersGuarded(&send_buffer, &recv_buffer, size, size, false);
|
||||
|
||||
ASSERT_NE(send_buffer, nullptr) << "Rank " << config.world_rank
|
||||
<< ": Send buffer allocation failed for size " << size;
|
||||
ASSERT_NE(recv_buffer, nullptr) << "Rank " << config.world_rank
|
||||
<< ": Recv buffer allocation failed for size " << size;
|
||||
|
||||
// Initialize send buffer with rank and size-specific pattern
|
||||
uint8_t* send_data = static_cast<uint8_t*>(send_buffer);
|
||||
for(size_t i = 0; i < size; i++)
|
||||
{
|
||||
send_data[i] = static_cast<uint8_t>(
|
||||
(config.world_rank * kDefaultPatternMultiplier + i) % kByteValueModulo);
|
||||
}
|
||||
|
||||
// Initialize recv buffer with invalid pattern
|
||||
uint8_t* recv_data = static_cast<uint8_t*>(recv_buffer);
|
||||
for(size_t i = 0; i < size; i++)
|
||||
{
|
||||
recv_data[i] = 0xFF; // Invalid pattern to detect transfer
|
||||
}
|
||||
|
||||
// Perform actual data transfer using NCCL Send/Recv
|
||||
// Use ASSERT_MPI_SUCCESS to ensure both ranks synchronize on NCCL errors
|
||||
ASSERT_MPI_SUCCESS(ncclGroupStart());
|
||||
|
||||
ASSERT_MPI_SUCCESS(
|
||||
ncclSend(send_buffer, size, ncclInt8, peer_rank, getActiveCommunicator(), stream));
|
||||
|
||||
ASSERT_MPI_SUCCESS(
|
||||
ncclRecv(recv_buffer, size, ncclInt8, peer_rank, getActiveCommunicator(), stream));
|
||||
|
||||
ASSERT_MPI_SUCCESS(ncclGroupEnd());
|
||||
|
||||
// Wait for transfer to complete
|
||||
// Use ASSERT_MPI_EQ to ensure both ranks synchronize on HIP errors
|
||||
ASSERT_MPI_EQ(hipSuccess, hipStreamSynchronize(stream));
|
||||
|
||||
// Verify received data matches peer's send pattern
|
||||
int errors = 0;
|
||||
const int max_errors_to_print = 5;
|
||||
for(size_t i = 0; i < size && errors < max_errors_to_print; i++)
|
||||
{
|
||||
uint8_t expected = static_cast<uint8_t>((peer_rank * kDefaultPatternMultiplier + i)
|
||||
% kByteValueModulo);
|
||||
if(recv_data[i] != expected)
|
||||
{
|
||||
TEST_WARN("Size %zu - Data mismatch at index %zu: expected %u, got %u",
|
||||
size,
|
||||
i,
|
||||
expected,
|
||||
recv_data[i]);
|
||||
errors++;
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ(0, errors) << "Rank " << config.world_rank
|
||||
<< ": Found data mismatches for buffer size " << size;
|
||||
|
||||
if(config.world_rank == 0 && errors == 0)
|
||||
{
|
||||
TEST_INFO(" Size %zu - Data transfer successful and verified", size);
|
||||
}
|
||||
|
||||
// Resource Guards will automatically cleanup at end of loop iteration
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Test cases
|
||||
TEST_F(NetTransportMPITest, NetGraphRegisterBufferTest)
|
||||
{
|
||||
// NET transport tests require exactly 2 ranks on 2 nodes (1 rank per node)
|
||||
if(!validateTestPrerequisites(kExactRanksForNET,
|
||||
kExactRanksForNET,
|
||||
kNoPowerOfTwoRequired,
|
||||
kMinNodesForNET,
|
||||
kMinNodesForNET))
|
||||
{
|
||||
GTEST_SKIP() << "NET transport test requires exactly " << kExactRanksForNET << " ranks on "
|
||||
<< kMinNodesForNET << " nodes (1 rank per node)";
|
||||
}
|
||||
|
||||
// Create test-specific communicator
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Starting ncclNetGraphRegisterBuffer test (multi-node)");
|
||||
}
|
||||
|
||||
testNetGraphRegisterBuffer();
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("ncclNetGraphRegisterBuffer test completed successfully");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(NetTransportMPITest, NetLocalRegisterBufferTest)
|
||||
{
|
||||
// NET transport tests require exactly 2 ranks on 2 nodes (1 rank per node)
|
||||
if(!validateTestPrerequisites(kExactRanksForNET,
|
||||
kExactRanksForNET,
|
||||
kNoPowerOfTwoRequired,
|
||||
kMinNodesForNET,
|
||||
kMinNodesForNET))
|
||||
{
|
||||
GTEST_SKIP() << "NET transport test requires exactly " << kExactRanksForNET << " ranks on "
|
||||
<< kMinNodesForNET << " nodes (1 rank per node)";
|
||||
}
|
||||
|
||||
// Create test-specific communicator
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Starting ncclNetLocalRegisterBuffer test (multi-node)");
|
||||
}
|
||||
|
||||
testNetLocalRegisterBuffer();
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("ncclNetLocalRegisterBuffer test completed successfully");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(NetTransportMPITest, MultipleBufferSizesTest)
|
||||
{
|
||||
// NET transport tests require exactly 2 ranks on 2 nodes (1 rank per node)
|
||||
if(!validateTestPrerequisites(kExactRanksForNET,
|
||||
kExactRanksForNET,
|
||||
kNoPowerOfTwoRequired,
|
||||
kMinNodesForNET,
|
||||
kMinNodesForNET))
|
||||
{
|
||||
GTEST_SKIP() << "NET transport test requires exactly " << kExactRanksForNET << " ranks on "
|
||||
<< kMinNodesForNET << " nodes (1 rank per node)";
|
||||
}
|
||||
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Starting multiple buffer sizes test (multi-node)");
|
||||
}
|
||||
|
||||
testMultipleBufferSizes();
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Multiple buffer sizes test completed successfully");
|
||||
}
|
||||
}
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,979 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "DeviceBufferHelpers.hpp"
|
||||
#include "TestChecks.hpp"
|
||||
#include "ResourceGuards.hpp"
|
||||
#include "TransportMPIBase.hpp"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
// Import MPI test constants
|
||||
using namespace MPITestConstants;
|
||||
using namespace RCCLTestGuards;
|
||||
using namespace RCCLTestHelpers;
|
||||
using namespace TransportTestConstants;
|
||||
|
||||
// SHM-specific test configuration
|
||||
struct ShmTestConfig
|
||||
{
|
||||
bool is_sender{false};
|
||||
void* send_buffer{nullptr};
|
||||
void* recv_buffer{nullptr};
|
||||
size_t buffer_size{0};
|
||||
};
|
||||
|
||||
class ShmMPITest : public TransportTestBase
|
||||
{
|
||||
protected:
|
||||
ShmTestConfig shm_config;
|
||||
|
||||
// Test data buffers
|
||||
std::vector<uint32_t> host_send_data;
|
||||
std::vector<uint32_t> host_recv_data;
|
||||
|
||||
// Connection info structures for setup/connect phases
|
||||
ncclConnect send_connect_info{};
|
||||
ncclConnect recv_connect_info{};
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
// Call base class SetUp first
|
||||
TransportTestBase::SetUp();
|
||||
|
||||
// Switch to SHM transport
|
||||
setTransportType(TransportType::SHM);
|
||||
|
||||
// Set up SHM-specific test configuration
|
||||
shm_config.is_sender = (config.world_rank == 0);
|
||||
shm_config.buffer_size = kDefaultBufferSize;
|
||||
|
||||
// Allocate and initialize send buffer with test pattern
|
||||
constexpr size_t num_elements = kDefaultBufferSize / sizeof(float);
|
||||
auto [send_err, _] = allocateAndInitialize<float>(&shm_config.send_buffer,
|
||||
num_elements,
|
||||
config.world_rank);
|
||||
EXPECT_EQ(hipSuccess, send_err)
|
||||
<< "Rank " << config.world_rank << ": Failed to allocate/initialize send buffer";
|
||||
|
||||
// Allocate and zero-initialize receive buffer
|
||||
hipError_t hip_result = hipMalloc(&shm_config.recv_buffer, shm_config.buffer_size);
|
||||
EXPECT_EQ(hipSuccess, hip_result)
|
||||
<< "Rank " << config.world_rank << ": Failed to allocate recv buffer";
|
||||
|
||||
hip_result = zeroInitializeBuffer<float>(shm_config.recv_buffer, num_elements);
|
||||
EXPECT_EQ(hipSuccess, hip_result)
|
||||
<< "Rank " << config.world_rank << ": Failed to zero-initialize recv buffer";
|
||||
|
||||
// Synchronize default stream to ensure all buffer operations complete
|
||||
// Note: This is called in SetUp() before test starts, so we use the default stream (0)
|
||||
// Using config.stream here causes "invalid resource handle" as it's not yet initialized
|
||||
EXPECT_EQ(hipSuccess, hipStreamSynchronize(0))
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": Failed to synchronize default stream after buffer initialization";
|
||||
}
|
||||
|
||||
void TearDown() override
|
||||
{
|
||||
// Cleanup SHM-specific test resources
|
||||
if(shm_config.send_buffer)
|
||||
{
|
||||
(void)hipFree(shm_config.send_buffer);
|
||||
shm_config.send_buffer = nullptr;
|
||||
}
|
||||
if(shm_config.recv_buffer)
|
||||
{
|
||||
(void)hipFree(shm_config.recv_buffer);
|
||||
shm_config.recv_buffer = nullptr;
|
||||
}
|
||||
|
||||
// Call base class TearDown
|
||||
TransportTestBase::TearDown();
|
||||
}
|
||||
|
||||
public:
|
||||
// Test SHM capability detection (same-host communication)
|
||||
void testShmCanConnect()
|
||||
{
|
||||
// Validate preconditions
|
||||
ASSERT_NE(nullptr, comm_handle)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": comm_handle is null - NCCL communicator not initialized";
|
||||
ASSERT_NE(nullptr, local_peer_info)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": local_peer_info is null - peer information not initialized";
|
||||
ASSERT_NE(nullptr, remote_peer_info)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": remote_peer_info is null - peer information not initialized";
|
||||
|
||||
int can_connect = 0;
|
||||
const auto result = shmTransport.canConnect(&can_connect,
|
||||
comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result) << "Rank " << config.world_rank
|
||||
<< ": shmCanConnect failed: " << ncclGetErrorString(result);
|
||||
|
||||
// Synchronize the stream to ensure all operations complete
|
||||
ASSERT_EQ(hipSuccess, syncStream(config.stream, config.world_rank))
|
||||
<< "Rank " << config.world_rank << ": Stream synchronization failed";
|
||||
}
|
||||
|
||||
// Test SHM setup phase
|
||||
void testShmSetup()
|
||||
{
|
||||
// Call setup() and save the connect_info to class members for later MPI exchange
|
||||
const auto result = shm_config.is_sender
|
||||
? shmTransport.send.setup(comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info,
|
||||
&send_connect_info, // Save to class member
|
||||
&send_connector,
|
||||
0,
|
||||
0)
|
||||
: shmTransport.recv.setup(comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info,
|
||||
&recv_connect_info, // Save to class member
|
||||
&recv_connector,
|
||||
0,
|
||||
0);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank << ": " << (shm_config.is_sender ? "Send" : "Recv")
|
||||
<< " setup failed: " << ncclGetErrorString(result);
|
||||
}
|
||||
|
||||
// Test SHM connection phase
|
||||
void testShmConnect()
|
||||
{
|
||||
// Validate preconditions
|
||||
ASSERT_NE(nullptr, comm_handle) << "Rank " << config.world_rank << ": comm_handle is null";
|
||||
ASSERT_NE(nullptr, local_peer_info)
|
||||
<< "Rank " << config.world_rank << ": local_peer_info is null";
|
||||
ASSERT_NE(nullptr, remote_peer_info)
|
||||
<< "Rank " << config.world_rank << ": remote_peer_info is null";
|
||||
|
||||
// NOTE: setup() was already called in testShmSetup() and saved connect_info to class members
|
||||
// This method only does MPI exchange of connect_info and then calls connect()
|
||||
|
||||
if(shm_config.is_sender)
|
||||
{
|
||||
// Exchange connect info with receiver using MPI
|
||||
ASSERT_EQ(MPI_SUCCESS,
|
||||
MPI_Send(&send_connect_info, // Use class member from testShmSetup()
|
||||
sizeof(ncclConnect),
|
||||
MPI_BYTE,
|
||||
config.peer_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD))
|
||||
<< "Rank " << config.world_rank << ": MPI_Send failed";
|
||||
|
||||
ASSERT_EQ(MPI_SUCCESS,
|
||||
MPI_Recv(&recv_connect_info, // Receive into class member
|
||||
sizeof(ncclConnect),
|
||||
MPI_BYTE,
|
||||
config.peer_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD,
|
||||
MPI_STATUS_IGNORE))
|
||||
<< "Rank " << config.world_rank << ": MPI_Recv failed";
|
||||
|
||||
// Perform the actual connection using the received info
|
||||
auto result = shmTransport.send.connect(comm_handle,
|
||||
&recv_connect_info,
|
||||
config.world_size,
|
||||
config.world_rank,
|
||||
&send_connector);
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": Send connect failed: " << ncclGetErrorString(result);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Exchange connect info with sender using MPI
|
||||
ASSERT_EQ(MPI_SUCCESS,
|
||||
MPI_Recv(&send_connect_info, // Receive into class member
|
||||
sizeof(ncclConnect),
|
||||
MPI_BYTE,
|
||||
config.peer_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD,
|
||||
MPI_STATUS_IGNORE))
|
||||
<< "Rank " << config.world_rank << ": MPI_Recv failed";
|
||||
|
||||
ASSERT_EQ(MPI_SUCCESS,
|
||||
MPI_Send(&recv_connect_info, // Use class member from testShmSetup()
|
||||
sizeof(ncclConnect),
|
||||
MPI_BYTE,
|
||||
config.peer_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD))
|
||||
<< "Rank " << config.world_rank << ": MPI_Send failed";
|
||||
|
||||
// Perform the actual connection using the received info
|
||||
auto result = shmTransport.recv.connect(comm_handle,
|
||||
&send_connect_info,
|
||||
config.world_size,
|
||||
config.world_rank,
|
||||
&recv_connector);
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": Recv connect failed: " << ncclGetErrorString(result);
|
||||
}
|
||||
|
||||
// Synchronize the stream to ensure all RCCL operations complete
|
||||
ASSERT_EQ(hipSuccess, syncStream(config.stream, config.world_rank))
|
||||
<< "Rank " << config.world_rank << ": Stream synchronization failed";
|
||||
}
|
||||
|
||||
// Test actual data transfer through SHM
|
||||
void testShmDataTransfer()
|
||||
{
|
||||
// Initialize host data vectors
|
||||
const size_t num_elements = shm_config.buffer_size / sizeof(uint32_t);
|
||||
host_recv_data.resize(num_elements);
|
||||
host_send_data.resize(num_elements);
|
||||
|
||||
// Use RCCL point-to-point operations to validate SHM transport
|
||||
const size_t count = shm_config.buffer_size / sizeof(float);
|
||||
const auto result = shm_config.is_sender ? ncclSend(shm_config.send_buffer,
|
||||
count,
|
||||
ncclFloat,
|
||||
config.peer_rank,
|
||||
config.nccl_comm,
|
||||
config.stream)
|
||||
: ncclRecv(shm_config.recv_buffer,
|
||||
count,
|
||||
ncclFloat,
|
||||
config.peer_rank,
|
||||
config.nccl_comm,
|
||||
config.stream);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank << ": RCCL " << (shm_config.is_sender ? "Send" : "Recv")
|
||||
<< " failed: " << ncclGetErrorString(result);
|
||||
|
||||
ASSERT_EQ(hipSuccess, syncStream(config.stream, config.world_rank))
|
||||
<< "Rank " << config.world_rank << ": Stream synchronization failed";
|
||||
|
||||
// Only validate data on the receiver side
|
||||
if(!shm_config.is_sender)
|
||||
{
|
||||
ASSERT_FALSE(host_recv_data.empty())
|
||||
<< "Rank " << config.world_rank << ": host_recv_data is empty";
|
||||
ASSERT_NE(nullptr, shm_config.recv_buffer)
|
||||
<< "Rank " << config.world_rank << ": recv_buffer is null";
|
||||
|
||||
ASSERT_EQ(hipSuccess,
|
||||
hipMemcpy(host_recv_data.data(),
|
||||
shm_config.recv_buffer,
|
||||
shm_config.buffer_size,
|
||||
hipMemcpyDeviceToHost))
|
||||
<< "Rank " << config.world_rank << ": hipMemcpy DeviceToHost failed";
|
||||
|
||||
// Validate received data - should match sender's original pattern
|
||||
const size_t validation_count = std::min(num_elements, kMaxValidationElements);
|
||||
for(size_t i = 0; i < validation_count; i++)
|
||||
{
|
||||
const float expected_float
|
||||
= static_cast<float>(config.peer_rank * kDefaultPatternMultiplier + i);
|
||||
const uint32_t expected_value = *reinterpret_cast<const uint32_t*>(&expected_float);
|
||||
|
||||
EXPECT_EQ(expected_value, host_recv_data[i])
|
||||
<< "Rank " << config.world_rank << ": Data mismatch at index " << i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test resource cleanup
|
||||
void testShmCleanup()
|
||||
{
|
||||
// Ensure all stream operations complete before validation
|
||||
[[maybe_unused]] auto err = syncStream(config.stream, config.world_rank);
|
||||
// Don't return error on sync failure - continue with validation
|
||||
|
||||
// Validate that connector resources are still valid at this point
|
||||
// The actual cleanup will be handled by base class TearDown()
|
||||
auto* connector = shm_config.is_sender ? &send_connector : &recv_connector;
|
||||
|
||||
EXPECT_NE(nullptr, connector)
|
||||
<< "Rank " << config.world_rank << ": Connector pointer is null";
|
||||
|
||||
if(connector)
|
||||
{
|
||||
EXPECT_NE(nullptr, connector->transportResources)
|
||||
<< "Rank " << config.world_rank << ": " << (shm_config.is_sender ? "Send" : "Recv")
|
||||
<< " connector transport resources are null (premature cleanup)";
|
||||
|
||||
if(config.world_rank == 0 && connector->transportResources)
|
||||
{
|
||||
TEST_INFO("Connector resources validated - still active (will be freed by base class)");
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: Connectors will be automatically freed by base class TearDown()
|
||||
// Device sync + connector cleanup happens BEFORE buffers are freed, which is critical for CE memcpy
|
||||
}
|
||||
|
||||
// Test SHM with memcpy mode enabled (CE - Copy Engine)
|
||||
// This test uses the transport API directly to ensure SHM methods are called
|
||||
void testShmWithMemcpy()
|
||||
{
|
||||
// Check if NCCL_SHM_USE_CUDA_MEMCPY is set externally
|
||||
const char* shm_memcpy_env = getenv("NCCL_SHM_USE_CUDA_MEMCPY");
|
||||
if(!shm_memcpy_env || strcmp(shm_memcpy_env, "1") != 0)
|
||||
{
|
||||
if(MPIEnvironment::world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Skipping CE memcpy test - NCCL_SHM_USE_CUDA_MEMCPY not set to '1'");
|
||||
TEST_INFO("To enable this test, set: export NCCL_SHM_USE_CUDA_MEMCPY=1");
|
||||
} // Skip test gracefully
|
||||
}
|
||||
|
||||
// Validate preconditions
|
||||
ASSERT_NE(nullptr, comm_handle) << "Rank " << config.world_rank << ": comm_handle is null";
|
||||
ASSERT_NE(nullptr, local_peer_info)
|
||||
<< "Rank " << config.world_rank << ": local_peer_info is null";
|
||||
ASSERT_NE(nullptr, remote_peer_info)
|
||||
<< "Rank " << config.world_rank << ": remote_peer_info is null";
|
||||
|
||||
// Step 1: Test shmCanConnect with CE memcpy enabled
|
||||
int can_connect = 0;
|
||||
ncclResult_t result = shmTransport.canConnect(&can_connect,
|
||||
comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result) << "Rank " << config.world_rank
|
||||
<< ": shmCanConnect failed: " << ncclGetErrorString(result);
|
||||
|
||||
ASSERT_EQ(1, can_connect)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": SHM cannot connect - test skipped but connection was expected";
|
||||
|
||||
// Step 2: Test SHM setup with CE memcpy enabled
|
||||
|
||||
ncclConnect send_connect_info{};
|
||||
ncclConnect recv_connect_info{};
|
||||
|
||||
if(shm_config.is_sender)
|
||||
{
|
||||
result = shmTransport.send.setup(comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info,
|
||||
&send_connect_info,
|
||||
&send_connector,
|
||||
0,
|
||||
0);
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": SHM send setup with CE memcpy failed: " << ncclGetErrorString(result);
|
||||
|
||||
// Exchange connect info with receiver
|
||||
ASSERT_EQ(MPI_SUCCESS,
|
||||
MPI_Send(&send_connect_info,
|
||||
sizeof(ncclConnect),
|
||||
MPI_BYTE,
|
||||
config.peer_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD));
|
||||
ASSERT_EQ(MPI_SUCCESS,
|
||||
MPI_Recv(&recv_connect_info,
|
||||
sizeof(ncclConnect),
|
||||
MPI_BYTE,
|
||||
config.peer_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD,
|
||||
MPI_STATUS_IGNORE));
|
||||
}
|
||||
else
|
||||
{
|
||||
result = shmTransport.recv.setup(comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info,
|
||||
&recv_connect_info,
|
||||
&recv_connector,
|
||||
0,
|
||||
0);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": SHM recv setup with CE memcpy failed: " << ncclGetErrorString(result);
|
||||
|
||||
// Exchange connect info with sender
|
||||
ASSERT_EQ(MPI_SUCCESS,
|
||||
MPI_Recv(&send_connect_info,
|
||||
sizeof(ncclConnect),
|
||||
MPI_BYTE,
|
||||
config.peer_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD,
|
||||
MPI_STATUS_IGNORE));
|
||||
ASSERT_EQ(MPI_SUCCESS,
|
||||
MPI_Send(&recv_connect_info,
|
||||
sizeof(ncclConnect),
|
||||
MPI_BYTE,
|
||||
config.peer_rank,
|
||||
0,
|
||||
MPI_COMM_WORLD));
|
||||
}
|
||||
|
||||
// Step 3: Test SHM connect with CE memcpy
|
||||
|
||||
if(shm_config.is_sender)
|
||||
{
|
||||
result = shmTransport.send.connect(comm_handle,
|
||||
&recv_connect_info,
|
||||
config.world_size,
|
||||
config.world_rank,
|
||||
&send_connector);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": SHM send connect with CE memcpy failed: " << ncclGetErrorString(result);
|
||||
}
|
||||
else
|
||||
{
|
||||
result = shmTransport.recv.connect(comm_handle,
|
||||
&send_connect_info,
|
||||
config.world_size,
|
||||
config.world_rank,
|
||||
&recv_connector);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": SHM recv connect with CE memcpy failed: " << ncclGetErrorString(result);
|
||||
}
|
||||
|
||||
// Step 4: Send large buffer with CE memcpy and validate
|
||||
const size_t buffer_size = kCEMemcpyBufferSize;
|
||||
const size_t num_elements = buffer_size / sizeof(float);
|
||||
void* send_buffer = nullptr;
|
||||
void* recv_buffer = nullptr;
|
||||
|
||||
hipError_t hip_result = hipMalloc(&send_buffer, buffer_size);
|
||||
ASSERT_EQ(hipSuccess, hip_result)
|
||||
<< "Rank " << config.world_rank << ": Failed to allocate send buffer";
|
||||
auto sendBufferGuard = makeDeviceBufferAutoGuard(send_buffer);
|
||||
|
||||
hip_result = hipMalloc(&recv_buffer, buffer_size);
|
||||
ASSERT_EQ(hipSuccess, hip_result)
|
||||
<< "Rank " << config.world_rank << ": Failed to allocate recv buffer";
|
||||
auto recvBufferGuard = makeDeviceBufferAutoGuard(recv_buffer);
|
||||
|
||||
// Initialize send buffer with unique pattern
|
||||
hip_result = initializeBufferWithPattern<float>(
|
||||
send_buffer,
|
||||
num_elements,
|
||||
[rank = config.world_rank](size_t i)
|
||||
{ return static_cast<float>(rank * kLargePatternMultiplier + (i % kPatternModulo)); });
|
||||
ASSERT_EQ(hipSuccess, hip_result)
|
||||
<< "Rank " << config.world_rank << ": Failed to initialize send buffer";
|
||||
|
||||
hip_result = hipMemset(recv_buffer, 0, buffer_size);
|
||||
ASSERT_EQ(hipSuccess, hip_result)
|
||||
<< "Rank " << config.world_rank << ": Failed to zero recv buffer";
|
||||
|
||||
// Synchronize stream before transfer
|
||||
hip_result = syncStream(config.stream, config.world_rank);
|
||||
ASSERT_EQ(hipSuccess, hip_result)
|
||||
<< "Rank " << config.world_rank << ": Stream sync failed before transfer";
|
||||
|
||||
// Perform the actual data transfer using NCCL
|
||||
const size_t count = buffer_size / sizeof(float);
|
||||
result = shm_config.is_sender ? ncclSend(send_buffer,
|
||||
count,
|
||||
ncclFloat,
|
||||
config.peer_rank,
|
||||
config.nccl_comm,
|
||||
config.stream)
|
||||
: ncclRecv(recv_buffer,
|
||||
count,
|
||||
ncclFloat,
|
||||
config.peer_rank,
|
||||
config.nccl_comm,
|
||||
config.stream);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result) << "Rank " << config.world_rank << ": Large buffer "
|
||||
<< (shm_config.is_sender ? "Send" : "Recv")
|
||||
<< " with CE memcpy failed: " << ncclGetErrorString(result);
|
||||
|
||||
// Synchronize to ensure transfer completes
|
||||
hip_result = syncStream(config.stream, config.world_rank);
|
||||
ASSERT_EQ(hipSuccess, hip_result)
|
||||
<< "Rank " << config.world_rank << ": Stream sync failed after transfer";
|
||||
|
||||
// Step 5: Validate received data (on receiver only)
|
||||
if(!shm_config.is_sender)
|
||||
{
|
||||
// Verify with custom pattern check (matching initialization pattern)
|
||||
size_t error_idx;
|
||||
float expected_val, actual_val;
|
||||
bool data_correct = verifyBufferData<float>(
|
||||
recv_buffer,
|
||||
num_elements,
|
||||
[peer_rank = config.peer_rank](size_t i) {
|
||||
return static_cast<float>(peer_rank * kLargePatternMultiplier
|
||||
+ (i % kPatternModulo));
|
||||
},
|
||||
0, // verify all elements
|
||||
1e-5,
|
||||
&error_idx,
|
||||
&expected_val,
|
||||
&actual_val);
|
||||
|
||||
EXPECT_TRUE(data_correct) << "Rank " << config.world_rank
|
||||
<< ": Data validation failed at index " << error_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Test SHM buffer allocation and sharing
|
||||
void testShmBufferAllocation()
|
||||
{
|
||||
// Test buffer allocation with various sizes
|
||||
const std::vector<size_t> test_sizes
|
||||
= {kSmallBufferSize, kMediumBufferSize, kLargeBufferSize};
|
||||
|
||||
for(const auto size : test_sizes)
|
||||
{
|
||||
void* send_buff = nullptr;
|
||||
void* recv_buff = nullptr;
|
||||
|
||||
// Allocate with local guards (store_in_base=false)
|
||||
// Guards will cleanup at end of loop iteration
|
||||
auto [sendGuard, recvGuard]
|
||||
= allocateAndInitBuffersGuarded(&send_buff, &recv_buff, size, size, false);
|
||||
|
||||
// Verify buffers are accessible
|
||||
EXPECT_NE(send_buff, nullptr) << "Rank " << config.world_rank << ": send_buff is null";
|
||||
EXPECT_NE(recv_buff, nullptr) << "Rank " << config.world_rank << ": recv_buff is null";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ShmMPITest, ShmWorkflow)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
// Create test-specific communicator for isolation
|
||||
// Use ASSERT_MPI_SUCCESS to prevent deadlock if creation fails on some ranks
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Starting comprehensive SHM transport workflow test with %d processes", config.world_size);
|
||||
TEST_INFO("This test exercises the low-level SHM transport API");
|
||||
}
|
||||
|
||||
// Test 1: SHM Capability Detection
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Step 1: Testing SHM canConnect capability");
|
||||
}
|
||||
testShmCanConnect();
|
||||
|
||||
// Test 2: SHM Setup
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Step 2: Setting up SHM transport connectors");
|
||||
}
|
||||
testShmSetup();
|
||||
|
||||
// Test 3: SHM Connection
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Step 3: Connecting SHM transport");
|
||||
}
|
||||
testShmConnect();
|
||||
|
||||
// Test 4: Data Transfer through SHM
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Step 4: Performing SHM data transfer");
|
||||
}
|
||||
testShmDataTransfer();
|
||||
|
||||
// Test 5: Resource Cleanup
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Step 5: Validating resource cleanup");
|
||||
}
|
||||
testShmCleanup();
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("SHM transport workflow test completed successfully");
|
||||
TEST_INFO("NOTE: Base class TearDown() handles connector cleanup automatically");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ShmMPITest, ShmWithMemcpyTest)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
// Create test-specific communicator for isolation
|
||||
// Use ASSERT_MPI_SUCCESS to prevent deadlock if creation fails on some ranks
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
testShmWithMemcpy();
|
||||
}
|
||||
|
||||
TEST_F(ShmMPITest, ShmBufferAllocationTest)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
// Use ASSERT_MPI_SUCCESS to prevent deadlock if creation fails on some ranks
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
testShmBufferAllocation();
|
||||
}
|
||||
|
||||
TEST_F(ShmMPITest, ShmTransfer_ZeroSizeBuffer)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
// Use ASSERT_MPI_SUCCESS to prevent deadlock if creation fails on some ranks
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
// Allocate minimal buffer
|
||||
void* buffer = nullptr;
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipMalloc(&buffer, 1)); // Allocate 1 byte
|
||||
auto bufferGuard = makeDeviceBufferAutoGuard(buffer); // Device memory
|
||||
|
||||
const bool is_sender = (config.world_rank == 0);
|
||||
const int peer = is_sender ? 1 : 0;
|
||||
|
||||
// Try to send/recv 0 elements
|
||||
const auto result = is_sender
|
||||
? ncclSend(buffer, 0, ncclFloat, peer, config.nccl_comm, config.stream)
|
||||
: ncclRecv(buffer, 0, ncclFloat, peer, config.nccl_comm, config.stream);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank << ": Zero-size transfer should succeed";
|
||||
|
||||
HIP_TEST_CHECK_GTEST_FAIL(syncStream(config.stream, config.world_rank));
|
||||
}
|
||||
|
||||
TEST_F(ShmMPITest, ShmTransfer_VeryLargeBuffer)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
// Use ASSERT_MPI_SUCCESS to prevent deadlock if creation fails on some ranks
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
// Try to allocate a very large buffer
|
||||
const size_t large_size = kCEMemcpyBufferSize;
|
||||
void* send_buffer = nullptr;
|
||||
void* recv_buffer = nullptr;
|
||||
|
||||
hipError_t hip_result = hipMalloc(&send_buffer, large_size);
|
||||
auto sendBufferGuard = makeDeviceBufferAutoGuard(send_buffer);
|
||||
|
||||
hip_result = hipMalloc(&recv_buffer, large_size);
|
||||
auto recvBufferGuard = makeDeviceBufferAutoGuard(recv_buffer);
|
||||
|
||||
// Initialize buffer
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipMemset(send_buffer, 0x42, large_size));
|
||||
|
||||
const bool is_sender = (config.world_rank == 0);
|
||||
const int peer = is_sender ? 1 : 0;
|
||||
const size_t count = large_size / sizeof(float);
|
||||
|
||||
// Perform send/recv with large buffer
|
||||
const auto result
|
||||
= is_sender
|
||||
? ncclSend(send_buffer, count, ncclFloat, peer, config.nccl_comm, config.stream)
|
||||
: ncclRecv(recv_buffer, count, ncclFloat, peer, config.nccl_comm, config.stream);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank << ": Large buffer transfer failed";
|
||||
|
||||
HIP_TEST_CHECK_GTEST_FAIL(syncStream(config.stream, config.world_rank));
|
||||
}
|
||||
|
||||
TEST_F(ShmMPITest, ShmTransfer_UnalignedBufferAddress)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
// Allocate aligned buffer
|
||||
const size_t buffer_size = 4096;
|
||||
void* aligned_buffer = nullptr;
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipMalloc(&aligned_buffer, buffer_size));
|
||||
auto bufferGuard = makeDeviceBufferAutoGuard(aligned_buffer); // Device memory
|
||||
|
||||
// Create unaligned pointer (offset by 1 byte)
|
||||
void* unaligned_buffer = static_cast<char*>(aligned_buffer) + 1;
|
||||
|
||||
const bool is_sender = (config.world_rank == 0);
|
||||
const int peer = is_sender ? 1 : 0;
|
||||
|
||||
const auto result
|
||||
= is_sender
|
||||
? ncclSend(unaligned_buffer, 1024, ncclChar, peer, config.nccl_comm, config.stream)
|
||||
: ncclRecv(unaligned_buffer, 1024, ncclChar, peer, config.nccl_comm, config.stream);
|
||||
|
||||
// Don't fail the test - just report the result
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipStreamSynchronize(config.stream));
|
||||
}
|
||||
|
||||
TEST_F(ShmMPITest, ShmMultipleConsecutiveTransfers)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
const size_t buffer_size = kMediumBufferSize;
|
||||
void* send_buffer = nullptr;
|
||||
void* recv_buffer = nullptr;
|
||||
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipMalloc(&send_buffer, buffer_size));
|
||||
auto sendBufferGuard = makeDeviceBufferAutoGuard(send_buffer);
|
||||
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipMalloc(&recv_buffer, buffer_size));
|
||||
auto recvBufferGuard = makeDeviceBufferAutoGuard(recv_buffer);
|
||||
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipMemset(send_buffer, 0xAB, buffer_size));
|
||||
|
||||
const bool is_sender = (config.world_rank == 0);
|
||||
const int peer = is_sender ? 1 : 0;
|
||||
const size_t count = buffer_size / sizeof(float);
|
||||
|
||||
for(int i = 0; i < kMultipleTransferCount; i++)
|
||||
{
|
||||
const auto result
|
||||
= is_sender
|
||||
? ncclSend(send_buffer, count, ncclFloat, peer, config.nccl_comm, config.stream)
|
||||
: ncclRecv(recv_buffer, count, ncclFloat, peer, config.nccl_comm, config.stream);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank << ": Transfer " << i << " failed";
|
||||
|
||||
// Ensure both ranks have posted their NCCL operations before synchronizing
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
HIP_TEST_CHECK_GTEST_FAIL(hipStreamSynchronize(config.stream));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ShmMPITest, ShmCleanup_DoubleCleanup)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
const bool is_sender = (config.world_rank == 0);
|
||||
auto* connector = is_sender ? &send_connector : &recv_connector;
|
||||
|
||||
// Setup connector
|
||||
ncclConnect connect_info{};
|
||||
const auto setup_result = is_sender ? shmTransport.send.setup(comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info,
|
||||
&connect_info,
|
||||
connector,
|
||||
0,
|
||||
0)
|
||||
: shmTransport.recv.setup(comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info,
|
||||
&connect_info,
|
||||
connector,
|
||||
0,
|
||||
0);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, setup_result) << "Rank " << config.world_rank << ": Setup failed";
|
||||
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
// First cleanup
|
||||
if(connector->transportResources)
|
||||
{
|
||||
const auto result1
|
||||
= is_sender ? shmTransport.send.free(connector) : shmTransport.recv.free(connector);
|
||||
EXPECT_EQ(ncclSuccess, result1) << "Rank " << config.world_rank << ": First cleanup failed";
|
||||
}
|
||||
|
||||
// Second cleanup (should handle gracefully since resources are already freed)
|
||||
[[maybe_unused]] const auto result2
|
||||
= is_sender ? shmTransport.send.free(connector) : shmTransport.recv.free(connector);
|
||||
|
||||
// Mark as cleaned up
|
||||
connector->transportResources = nullptr;
|
||||
}
|
||||
|
||||
TEST_F(ShmMPITest, ShmConnect_WithoutSetup)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Testing SHM connect without prior setup (%d processes)", config.world_size);
|
||||
}
|
||||
|
||||
const bool is_sender = (config.world_rank == 0);
|
||||
auto* connector = is_sender ? &send_connector : &recv_connector;
|
||||
|
||||
// Create empty/uninitialized connect info (simulates invalid state)
|
||||
ncclConnect invalid_connect_info{};
|
||||
memset(&invalid_connect_info, 0, sizeof(ncclConnect));
|
||||
|
||||
// Try to connect without calling setup first - this should fail or handle gracefully
|
||||
const auto result = is_sender ? shmTransport.send.connect(comm_handle,
|
||||
&invalid_connect_info,
|
||||
config.world_size,
|
||||
config.world_rank,
|
||||
connector)
|
||||
: shmTransport.recv.connect(comm_handle,
|
||||
&invalid_connect_info,
|
||||
config.world_size,
|
||||
config.world_rank,
|
||||
connector);
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Connect without setup result: %s", ncclGetErrorString(result));
|
||||
TEST_INFO("Note: This tests invalid state handling");
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ShmMPITest, ShmConnect_CorruptedConnectInfo)
|
||||
{
|
||||
ASSERT_TRUE(validateTestPrerequisites(kMinProcessesForMPI,
|
||||
kNoProcessLimit,
|
||||
kRequirePowerOfTwo,
|
||||
1,
|
||||
kRequireSingleNode))
|
||||
<< "Test requirements not met - all ranks must meet requirements";
|
||||
|
||||
ASSERT_MPI_SUCCESS(createTestCommunicator());
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Testing SHM connect with corrupted connect info (%d processes)",
|
||||
config.world_size);
|
||||
}
|
||||
|
||||
const bool is_sender = (config.world_rank == 0);
|
||||
auto* connector = is_sender ? &send_connector : &recv_connector;
|
||||
|
||||
// First, do valid setup
|
||||
ncclConnect valid_connect_info{};
|
||||
const auto setup_result = is_sender ? shmTransport.send.setup(comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info,
|
||||
&valid_connect_info,
|
||||
connector,
|
||||
0,
|
||||
0)
|
||||
: shmTransport.recv.setup(comm_handle,
|
||||
topology_graph,
|
||||
local_peer_info,
|
||||
remote_peer_info,
|
||||
&valid_connect_info,
|
||||
connector,
|
||||
0,
|
||||
0);
|
||||
|
||||
ASSERT_EQ(ncclSuccess, setup_result) << "Rank " << config.world_rank << ": Setup failed";
|
||||
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
// Create corrupted connect info (fill with invalid data)
|
||||
ncclConnect corrupted_info{};
|
||||
memset(&corrupted_info, 0xFF, sizeof(ncclConnect)); // Fill with 0xFF
|
||||
|
||||
// Try to connect with corrupted info
|
||||
// This tests internal validation of connect info structures
|
||||
const auto result = is_sender ? shmTransport.send.connect(comm_handle,
|
||||
&corrupted_info,
|
||||
config.world_size,
|
||||
config.world_rank,
|
||||
connector)
|
||||
: shmTransport.recv.connect(comm_handle,
|
||||
&corrupted_info,
|
||||
config.world_size,
|
||||
config.world_rank,
|
||||
connector);
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Connect with corrupted info result: %s", ncclGetErrorString(result));
|
||||
TEST_INFO("Note: Tests connect info validation similar to proxy function validation");
|
||||
}
|
||||
|
||||
// Cleanup properly allocated resources
|
||||
if(connector->transportResources)
|
||||
{
|
||||
const auto cleanup_result
|
||||
= is_sender ? shmTransport.send.free(connector) : shmTransport.recv.free(connector);
|
||||
(void)cleanup_result; // Ignore result as we're in error path
|
||||
connector->transportResources = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
@@ -0,0 +1,313 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "TransportMPIBase.hpp"
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
|
||||
namespace
|
||||
{
|
||||
// Test pattern generation constants for TransportTestBase
|
||||
inline constexpr int kDefaultPatternMultiplier = 100; // For transport base patterns
|
||||
inline constexpr int kByteValueModulo = 256; // For uint8_t wraparound
|
||||
} // namespace
|
||||
|
||||
// Override createTestCommunicator to also update config and transport components
|
||||
ncclResult_t TransportTestBase::createTestCommunicator()
|
||||
{
|
||||
// Call base class implementation
|
||||
ncclResult_t result = MPITestBase::createTestCommunicator();
|
||||
|
||||
if(result == ncclSuccess)
|
||||
{
|
||||
// Update config with the new communicator and stream
|
||||
config.nccl_comm = getActiveCommunicator();
|
||||
config.stream = getActiveStream();
|
||||
|
||||
// Initialize transport components now that we have a valid communicator
|
||||
comm_handle = config.nccl_comm;
|
||||
local_peer_info = &comm_handle->peerInfo[config.world_rank];
|
||||
remote_peer_info = &comm_handle->peerInfo[config.peer_rank];
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("TransportTestBase config and transport components updated with per-test "
|
||||
"communicator");
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Set transport type and initialize connectors accordingly
|
||||
void TransportTestBase::setTransportType(TransportType type)
|
||||
{
|
||||
initialized_transport = type;
|
||||
|
||||
switch(type)
|
||||
{
|
||||
case TransportType::P2P:
|
||||
send_connector.transportComm = &p2pTransport.send;
|
||||
recv_connector.transportComm = &p2pTransport.recv;
|
||||
break;
|
||||
case TransportType::Network:
|
||||
send_connector.transportComm = &netTransport.send;
|
||||
recv_connector.transportComm = &netTransport.recv;
|
||||
break;
|
||||
case TransportType::SHM:
|
||||
send_connector.transportComm = &shmTransport.send;
|
||||
recv_connector.transportComm = &shmTransport.recv;
|
||||
break;
|
||||
case TransportType::None:
|
||||
send_connector.transportComm = nullptr;
|
||||
recv_connector.transportComm = nullptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// SetUp: Initialize common transport test components
|
||||
void TransportTestBase::SetUp()
|
||||
{
|
||||
// Call GTest's SetUp (which will call MPITestCore::initializeTest())
|
||||
MPITestBase::SetUp();
|
||||
|
||||
// Initialize test configuration using aggregate initialization
|
||||
// Note: rccl_comm and stream are set to nullptr initially; tests must call createTestCommunicator()
|
||||
config = {.world_rank = MPIEnvironment::world_rank,
|
||||
.world_size = MPIEnvironment::world_size,
|
||||
.peer_rank = (MPIEnvironment::world_rank == 0) ? 1 : 0,
|
||||
.nccl_comm = nullptr,
|
||||
.stream = nullptr};
|
||||
|
||||
// Require at least 2 MPI processes for testing
|
||||
if(config.world_size < 2)
|
||||
{
|
||||
GTEST_SKIP() << "Transport testing requires at least 2 MPI processes";
|
||||
}
|
||||
|
||||
// Check if MPIEnvironment was properly initialized
|
||||
if(MPIEnvironment::retCode != 0)
|
||||
{
|
||||
GTEST_FAIL() << "MPIEnvironment initialization failed";
|
||||
}
|
||||
|
||||
// Initialize transport component pointers to nullptr
|
||||
// They will be set in createTestCommunicator() after the communicator is created
|
||||
comm_handle = nullptr;
|
||||
local_peer_info = nullptr;
|
||||
remote_peer_info = nullptr;
|
||||
|
||||
// Create and initialize topology graph
|
||||
topology_graph = static_cast<ncclTopoGraph*>(malloc(sizeof(ncclTopoGraph)));
|
||||
if(topology_graph)
|
||||
{
|
||||
*topology_graph = {.id = 0,
|
||||
.pattern = NCCL_TOPO_PATTERN_RING,
|
||||
.nChannels = 1,
|
||||
.bwIntra = 0.0f,
|
||||
.bwInter = 0.0f,
|
||||
.typeIntra = PATH_SYS,
|
||||
.typeInter = PATH_NET};
|
||||
}
|
||||
|
||||
// Initialize with P2P transport by default
|
||||
// Tests can call setTransportType() to switch to SHM or Network
|
||||
setTransportType(TransportType::P2P);
|
||||
}
|
||||
|
||||
// TearDown: Cleanup common transport test components
|
||||
void TransportTestBase::TearDown()
|
||||
{
|
||||
// CRITICAL: Synchronize device before freeing connectors
|
||||
// The transport proxy may have its own internal stream for CE memcpy operations
|
||||
// that must be idle before we can destroy it
|
||||
// Note: We ignore errors here as we're in cleanup path
|
||||
(void)hipDeviceSynchronize();
|
||||
|
||||
// Cleanup topology graph
|
||||
if(topology_graph)
|
||||
{
|
||||
free(topology_graph);
|
||||
topology_graph = nullptr;
|
||||
}
|
||||
|
||||
// Cleanup transport resources based on initialized transport type
|
||||
if(send_connector.transportResources)
|
||||
{
|
||||
if(initialized_transport == TransportType::P2P)
|
||||
{
|
||||
p2pTransport.send.free(&send_connector);
|
||||
}
|
||||
else if(initialized_transport == TransportType::SHM)
|
||||
{
|
||||
shmTransport.send.free(&send_connector);
|
||||
}
|
||||
else if(initialized_transport == TransportType::Network)
|
||||
{
|
||||
netTransport.send.free(&send_connector);
|
||||
}
|
||||
send_connector.transportResources = nullptr;
|
||||
}
|
||||
if(recv_connector.transportResources)
|
||||
{
|
||||
if(initialized_transport == TransportType::P2P)
|
||||
{
|
||||
p2pTransport.recv.free(&recv_connector);
|
||||
}
|
||||
else if(initialized_transport == TransportType::SHM)
|
||||
{
|
||||
shmTransport.recv.free(&recv_connector);
|
||||
}
|
||||
else if(initialized_transport == TransportType::Network)
|
||||
{
|
||||
netTransport.recv.free(&recv_connector);
|
||||
}
|
||||
recv_connector.transportResources = nullptr;
|
||||
}
|
||||
|
||||
// Reset transport type
|
||||
initialized_transport = TransportType::None;
|
||||
|
||||
// Nullify peer info pointers
|
||||
local_peer_info = nullptr;
|
||||
remote_peer_info = nullptr;
|
||||
comm_handle = nullptr;
|
||||
|
||||
// Note: Clear RAII guard vectors BEFORE destroying communicator
|
||||
// The guards (especially NcclRegHandleGuard) need the communicator to be valid
|
||||
// when they call ncclCommDeregister() in their destructors
|
||||
reg_handle_guards_.clear();
|
||||
buffer_guards_.clear();
|
||||
|
||||
// Call base class TearDown to cleanup test communicator
|
||||
// This calls MPITestBase::TearDown() -> MPITestCore::cleanupTest() -> cleanupTestCommunicator()
|
||||
MPITestBase::TearDown();
|
||||
}
|
||||
|
||||
// Allocate and initialize test buffers
|
||||
void TransportTestBase::allocateAndInitBuffers(void** send_buffer,
|
||||
void** recv_buffer,
|
||||
size_t send_bytes,
|
||||
size_t recv_bytes)
|
||||
{
|
||||
// Allocate send buffer
|
||||
ASSERT_EQ(hipSuccess, hipMalloc(send_buffer, send_bytes))
|
||||
<< "Rank " << config.world_rank << ": Failed to allocate send buffer";
|
||||
|
||||
// Allocate recv buffer
|
||||
ASSERT_EQ(hipSuccess, hipMalloc(recv_buffer, recv_bytes))
|
||||
<< "Rank " << config.world_rank << ": Failed to allocate recv buffer";
|
||||
|
||||
std::vector<uint8_t> host_data(send_bytes);
|
||||
for(size_t i = 0; i < host_data.size(); i++)
|
||||
{
|
||||
host_data[i] = static_cast<uint8_t>((config.world_rank * kDefaultPatternMultiplier + i)
|
||||
% kByteValueModulo);
|
||||
}
|
||||
|
||||
ASSERT_EQ(hipSuccess,
|
||||
hipMemcpy(*send_buffer, host_data.data(), send_bytes, hipMemcpyHostToDevice))
|
||||
<< "Rank " << config.world_rank << ": Failed to initialize send buffer";
|
||||
|
||||
if(config.world_rank == 0)
|
||||
{
|
||||
TEST_INFO("Allocated and initialized buffers (%zu bytes each)", send_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-register buffers with ncclCommRegister
|
||||
void TransportTestBase::preRegisterBuffers(void* send_buffer,
|
||||
void* recv_buffer,
|
||||
size_t send_bytes,
|
||||
size_t recv_bytes,
|
||||
void** send_reg_handle,
|
||||
void** recv_reg_handle)
|
||||
{
|
||||
ncclComm_t comm = getActiveCommunicator();
|
||||
|
||||
// Register send buffer
|
||||
ncclResult_t result = ncclCommRegister(comm, send_buffer, send_bytes, send_reg_handle);
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": Failed to pre-register send buffer: " << ncclGetErrorString(result);
|
||||
|
||||
// Register recv buffer
|
||||
result = ncclCommRegister(comm, recv_buffer, recv_bytes, recv_reg_handle);
|
||||
ASSERT_EQ(ncclSuccess, result)
|
||||
<< "Rank " << config.world_rank
|
||||
<< ": Failed to pre-register recv buffer: " << ncclGetErrorString(result);
|
||||
}
|
||||
|
||||
// Buffer allocation with automatic RAII guards
|
||||
std::pair<DeviceBufferAutoGuard, DeviceBufferAutoGuard>
|
||||
TransportTestBase::allocateAndInitBuffersGuarded(void** send_buffer,
|
||||
void** recv_buffer,
|
||||
size_t send_bytes,
|
||||
size_t recv_bytes,
|
||||
bool store_in_base)
|
||||
{
|
||||
// Allocate buffers using existing method
|
||||
allocateAndInitBuffers(send_buffer, recv_buffer, send_bytes, recv_bytes);
|
||||
|
||||
// Create guards
|
||||
auto sendGuard = makeDeviceBufferAutoGuard(*send_buffer); // Device memory
|
||||
auto recvGuard = makeDeviceBufferAutoGuard(*recv_buffer); // Device memory
|
||||
|
||||
if(store_in_base)
|
||||
{
|
||||
// Store guards in base class for cleanup at test end
|
||||
buffer_guards_.push_back(std::move(sendGuard));
|
||||
buffer_guards_.push_back(std::move(recvGuard));
|
||||
|
||||
// Return empty guards (resources now managed by base class)
|
||||
return {makeDeviceBufferAutoGuard(nullptr), makeDeviceBufferAutoGuard(nullptr)};
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return guards for caller to manage (cleanup at caller's scope exit)
|
||||
return {std::move(sendGuard), std::move(recvGuard)};
|
||||
}
|
||||
}
|
||||
|
||||
// Buffer registration with automatic RAII guards
|
||||
std::pair<NcclRegHandleGuard, NcclRegHandleGuard>
|
||||
TransportTestBase::preRegisterBuffersGuarded(void* send_buffer,
|
||||
void* recv_buffer,
|
||||
size_t send_bytes,
|
||||
size_t recv_bytes,
|
||||
void** send_reg_handle,
|
||||
void** recv_reg_handle,
|
||||
bool store_in_base)
|
||||
{
|
||||
// Register buffers using existing method
|
||||
preRegisterBuffers(send_buffer,
|
||||
recv_buffer,
|
||||
send_bytes,
|
||||
recv_bytes,
|
||||
send_reg_handle,
|
||||
recv_reg_handle);
|
||||
|
||||
// Create guards (handles may be nullptr if registration is not needed)
|
||||
NcclRegHandleGuard sendGuard(*send_reg_handle, NcclRegHandleDeleter(getActiveCommunicator()));
|
||||
NcclRegHandleGuard recvGuard(*recv_reg_handle, NcclRegHandleDeleter(getActiveCommunicator()));
|
||||
|
||||
if(store_in_base)
|
||||
{
|
||||
// Store guards in base class for cleanup at test end
|
||||
reg_handle_guards_.push_back(std::move(sendGuard));
|
||||
reg_handle_guards_.push_back(std::move(recvGuard));
|
||||
|
||||
// Return empty guards (resources now managed by base class)
|
||||
return {makeRegHandleGuard(nullptr, nullptr), makeRegHandleGuard(nullptr, nullptr)};
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return guards for caller to manage (cleanup at caller's scope exit)
|
||||
return {std::move(sendGuard), std::move(recvGuard)};
|
||||
}
|
||||
}
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
@@ -0,0 +1,295 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#ifndef TRANSPORT_MPI_BASE_HPP
|
||||
#define TRANSPORT_MPI_BASE_HPP
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
|
||||
#include "rccl/rccl.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#ifdef MPI_TESTS_ENABLED
|
||||
#include "MPITestBase.hpp"
|
||||
#include "MPIEnvironment.hpp"
|
||||
#include "TestChecks.hpp"
|
||||
#include "ResourceGuards.hpp"
|
||||
#include "comm.h"
|
||||
#include "core.h"
|
||||
#include "device.h"
|
||||
#include "graph.h"
|
||||
#include "graph/topo.h"
|
||||
#include "nccl_common.h"
|
||||
#include "transport.h"
|
||||
|
||||
using namespace RCCLTestGuards;
|
||||
|
||||
// Transport-specific RAII deleters
|
||||
namespace RCCLTestGuards
|
||||
{
|
||||
|
||||
struct TransportSendResourceDeleter
|
||||
{
|
||||
ncclTransport* transport;
|
||||
explicit TransportSendResourceDeleter(ncclTransport* t = nullptr) : transport(t) {}
|
||||
void operator()(ncclConnector* connector) const
|
||||
{
|
||||
if(connector && transport)
|
||||
{
|
||||
transport->send.free(connector);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct TransportRecvResourceDeleter
|
||||
{
|
||||
ncclTransport* transport;
|
||||
explicit TransportRecvResourceDeleter(ncclTransport* t = nullptr) : transport(t) {}
|
||||
void operator()(ncclConnector* connector) const
|
||||
{
|
||||
if(connector && transport)
|
||||
{
|
||||
transport->recv.free(connector);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using TransportSendResourceGuard = ResourceGuard<ncclConnector*, TransportSendResourceDeleter>;
|
||||
using TransportRecvResourceGuard = ResourceGuard<ncclConnector*, TransportRecvResourceDeleter>;
|
||||
|
||||
class TransportResourceGuard
|
||||
{
|
||||
private:
|
||||
ncclConnector* send_connector_;
|
||||
ncclConnector* recv_connector_;
|
||||
ncclTransport* transport_;
|
||||
|
||||
public:
|
||||
TransportResourceGuard(ncclConnector* send, ncclConnector* recv, ncclTransport* transport)
|
||||
: send_connector_(send), recv_connector_(recv), transport_(transport)
|
||||
{}
|
||||
|
||||
~TransportResourceGuard()
|
||||
{
|
||||
if(recv_connector_ && transport_)
|
||||
{
|
||||
transport_->recv.free(recv_connector_);
|
||||
}
|
||||
if(send_connector_ && transport_)
|
||||
{
|
||||
transport_->send.free(send_connector_);
|
||||
}
|
||||
}
|
||||
|
||||
TransportResourceGuard(const TransportResourceGuard&) = delete;
|
||||
TransportResourceGuard& operator=(const TransportResourceGuard&) = delete;
|
||||
TransportResourceGuard(TransportResourceGuard&&) = delete;
|
||||
TransportResourceGuard& operator=(TransportResourceGuard&&) = delete;
|
||||
};
|
||||
|
||||
inline TransportSendResourceGuard makeTransportSendGuard(ncclConnector* connector,
|
||||
ncclTransport* transport)
|
||||
{
|
||||
return TransportSendResourceGuard(connector, TransportSendResourceDeleter(transport));
|
||||
}
|
||||
|
||||
inline TransportRecvResourceGuard makeTransportRecvGuard(ncclConnector* connector,
|
||||
ncclTransport* transport)
|
||||
{
|
||||
return TransportRecvResourceGuard(connector, TransportRecvResourceDeleter(transport));
|
||||
}
|
||||
|
||||
} // namespace RCCLTestGuards
|
||||
|
||||
extern struct ncclTransport p2pTransport;
|
||||
extern struct ncclTransport netTransport;
|
||||
extern struct ncclTransport shmTransport;
|
||||
|
||||
// ============================================================================
|
||||
// Transport Test Constants
|
||||
// ============================================================================
|
||||
|
||||
namespace TransportTestConstants
|
||||
{
|
||||
|
||||
// Buffer size constants (common across P2P, SHM, NET tests)
|
||||
inline constexpr size_t kDefaultBufferSize = 1024 * sizeof(float); // 4096 bytes
|
||||
inline constexpr size_t kSmallBufferSize = 256;
|
||||
inline constexpr size_t kMediumBufferSize = 16384; // 16 KB
|
||||
inline constexpr size_t kLargeBufferSize = 135168; // ~132 KB
|
||||
inline constexpr size_t kVeryLargeBufferSize = 256 * 1024 * 1024; // 256 MB
|
||||
inline constexpr size_t kCEMemcpyBufferSize = 256 * 1024 * 1024; // 256 MB (for CE tests)
|
||||
|
||||
// Pattern generation constants
|
||||
inline constexpr int kDefaultPatternMultiplier = 1000; // Standard rank-based patterns
|
||||
inline constexpr int kSmallPatternMultiplier = 100; // Smaller patterns (memcpy tests)
|
||||
inline constexpr int kLargePatternMultiplier = 1000000; // Large buffer patterns
|
||||
inline constexpr int kPatternModulo = 10000; // Wraparound patterns
|
||||
inline constexpr int kBytePatternModulo = 256; // uint8_t wraparound
|
||||
|
||||
// Validation constants
|
||||
inline constexpr size_t kMaxValidationElements = 100; // Number of elements to validate
|
||||
inline constexpr size_t kMinValidationSamples = 100; // Minimum samples for validation
|
||||
inline constexpr size_t kValidationStride = 1000; // Stride for sampling validation
|
||||
inline constexpr int kMaxErrorsToReport = 10; // Max errors to display
|
||||
|
||||
// Test iteration constants
|
||||
inline constexpr int kMultipleTransferCount = 5; // Number of sequential transfers
|
||||
|
||||
} // namespace TransportTestConstants
|
||||
|
||||
// Common test configuration
|
||||
struct TransportTestConfig
|
||||
{
|
||||
int world_rank{0};
|
||||
int world_size{0};
|
||||
int peer_rank{0};
|
||||
ncclComm_t nccl_comm{nullptr};
|
||||
hipStream_t stream{nullptr};
|
||||
};
|
||||
|
||||
// Base class for transport tests with common functionality
|
||||
// Inherits from MPITestBase to get validation capabilities
|
||||
class TransportTestBase : public MPITestBase
|
||||
{
|
||||
protected:
|
||||
TransportTestConfig config;
|
||||
|
||||
// Transport connectors (can be used for P2P or NET)
|
||||
ncclConnector send_connector = {};
|
||||
ncclConnector recv_connector = {};
|
||||
|
||||
// Track which transport type is initialized
|
||||
enum class TransportType
|
||||
{
|
||||
None,
|
||||
P2P,
|
||||
SHM,
|
||||
Network
|
||||
};
|
||||
TransportType initialized_transport = TransportType::None;
|
||||
|
||||
// Core NCCL components
|
||||
struct ncclComm* comm_handle = nullptr;
|
||||
ncclPeerInfo* local_peer_info = nullptr;
|
||||
ncclPeerInfo* remote_peer_info = nullptr;
|
||||
ncclTopoGraph* topology_graph = nullptr;
|
||||
|
||||
// RAII guards for automatic resource cleanup
|
||||
// These are managed by helper methods and cleaned up automatically
|
||||
std::vector<DeviceBufferAutoGuard> buffer_guards_;
|
||||
std::vector<NcclRegHandleGuard> reg_handle_guards_;
|
||||
|
||||
// Setup and teardown
|
||||
void SetUp() override;
|
||||
void TearDown() override;
|
||||
|
||||
// Override createTestCommunicator to also update config
|
||||
ncclResult_t createTestCommunicator() override;
|
||||
|
||||
// Set transport type and initialize connectors
|
||||
void setTransportType(TransportType type);
|
||||
|
||||
// Buffer allocation (unguarded - for manual management)
|
||||
void allocateAndInitBuffers(void** send_buffer,
|
||||
void** recv_buffer,
|
||||
size_t send_bytes,
|
||||
size_t recv_bytes);
|
||||
|
||||
// Buffer allocation with automatic RAII guards
|
||||
// store_in_base=true: Guards stored in base class, cleanup at test end
|
||||
// store_in_base=false: Guards returned, caller controls cleanup scope
|
||||
std::pair<DeviceBufferAutoGuard, DeviceBufferAutoGuard> allocateAndInitBuffersGuarded(void** send_buffer,
|
||||
void** recv_buffer,
|
||||
size_t send_bytes,
|
||||
size_t recv_bytes,
|
||||
bool store_in_base = true);
|
||||
|
||||
// Buffer registration (unguarded - for manual management)
|
||||
void preRegisterBuffers(void* send_buffer,
|
||||
void* recv_buffer,
|
||||
size_t send_bytes,
|
||||
size_t recv_bytes,
|
||||
void** send_reg_handle,
|
||||
void** recv_reg_handle);
|
||||
|
||||
// Buffer registration with automatic RAII guards
|
||||
// store_in_base=true: Guards stored in base class, cleanup at test end
|
||||
// store_in_base=false: Guards returned, caller controls cleanup scope
|
||||
std::pair<NcclRegHandleGuard, NcclRegHandleGuard>
|
||||
preRegisterBuffersGuarded(void* send_buffer,
|
||||
void* recv_buffer,
|
||||
size_t send_bytes,
|
||||
size_t recv_bytes,
|
||||
void** send_reg_handle,
|
||||
void** recv_reg_handle,
|
||||
bool store_in_base = true);
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Generic Stream Synchronization Helpers
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* @brief Generic stream synchronization helper
|
||||
*
|
||||
* Synchronizes a HIP stream and returns the error code. This function is
|
||||
* marked [[nodiscard]] to ensure callers check the return value.
|
||||
*
|
||||
* @param stream HIP stream to synchronize
|
||||
* @param rank MPI rank (for error reporting, currently unused but allows
|
||||
* future enhancement with rank-specific error messages)
|
||||
* @return hipError_t Result of hipStreamSynchronize
|
||||
*
|
||||
* Usage examples:
|
||||
* - Manual error checking: hipError_t err = syncStream(stream, rank);
|
||||
* - With HIPCHECK macro: HIPCHECK(syncStream(stream, rank));
|
||||
* - With assertion macro: ASSERT_STREAM_SYNC(stream, rank);
|
||||
*/
|
||||
[[nodiscard]] inline hipError_t syncStream(hipStream_t stream, int rank = 0)
|
||||
{
|
||||
return hipStreamSynchronize(stream);
|
||||
}
|
||||
|
||||
/**
|
||||
* @def ASSERT_STREAM_SYNC
|
||||
* @brief Macro to assert stream synchronization succeeds
|
||||
*
|
||||
* Convenience macro that combines syncStream() with ASSERT_EQ to provide
|
||||
* clean, consistent stream synchronization checks in tests.
|
||||
*
|
||||
* @param stream HIP stream to synchronize
|
||||
* @param rank MPI rank for error reporting
|
||||
*
|
||||
* Example: ASSERT_STREAM_SYNC(config.stream, config.world_rank);
|
||||
*/
|
||||
#define ASSERT_STREAM_SYNC(stream, rank) \
|
||||
ASSERT_EQ(hipSuccess, syncStream(stream, rank)) \
|
||||
<< "Rank " << rank << ": Stream synchronization failed"
|
||||
|
||||
/**
|
||||
* @def ASSERT_STREAM_SYNC_MPI
|
||||
* @brief MPI-aware stream synchronization assertion
|
||||
*
|
||||
* Uses ASSERT_MPI_EQ to ensure all ranks synchronize before failing.
|
||||
* This prevents deadlocks when one rank fails while others are waiting
|
||||
* in collective operations.
|
||||
*
|
||||
* @param stream HIP stream to synchronize
|
||||
* @param rank MPI rank for error reporting
|
||||
*
|
||||
* Example: ASSERT_STREAM_SYNC_MPI(config.stream, config.world_rank);
|
||||
*
|
||||
* @note Prefer this version in multi-rank tests to avoid hangs
|
||||
*/
|
||||
#define ASSERT_STREAM_SYNC_MPI(stream, rank) ASSERT_MPI_EQ(hipSuccess, syncStream(stream, rank))
|
||||
|
||||
#endif // MPI_TESTS_ENABLED
|
||||
|
||||
#endif // TRANSPORT_MPI_BASE_HPP
|
||||
Reference in New Issue
Block a user