/************************************************************************* * Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ #include "net.h" #include "common/ProcessIsolatedTestRunner.hpp" #include "gtest/gtest.h" #include #include #include extern ncclNet_t ncclNetSocket; namespace RcclUnitTesting { /** * @brief Establishes a reliable connection pair (send and receive communicators) using the provided handle and listen communicator. * * This function attempts to create a pair of connected communicators (sendComm and recvComm) using the ncclNetSocket API. * It uses internal RAII guards to ensure proper cleanup in case of partial failures. The function coordinates accept and connect * operations in parallel threads, with extended timeouts and retries for robustness. * * @param handle Pointer to the device handle used for connection. * @param listenComm Pointer to the listen communicator, used for accepting connections. * @param[out] sendComm Reference to a pointer that will receive the newly created send communicator on success. * WARNING: The ownership of the communicator is transferred to the caller. The caller is responsible for * closing and cleaning up sendComm. If sendComm was previously pointing to a resource (e.g., a unique_ptr or * other managed pointer), it will be overwritten and the previous resource may be leaked. Ensure sendComm is * either nullptr or properly released before calling this function. * @param[out] recvComm Reference to a pointer that will receive the newly created receive communicator on success. * WARNING: The ownership of the communicator is transferred to the caller. The caller is responsible for * closing and cleaning up recvComm. If recvComm was previously pointing to a resource (e.g., a unique_ptr or * other managed pointer), it will be overwritten and the previous resource may be leaked. Ensure recvComm is * either nullptr or properly released before calling this function. * * @return true if both send and receive communicators were successfully established and ownership transferred to the caller; * false otherwise (in which case all resources are cleaned up internally). */ class NetSocketTests : public ::testing::Test { private: // RAII wrapper for send communicator class SendCommGuard { private: void *comm_ = nullptr; public: explicit SendCommGuard(void *comm = nullptr) : comm_(comm) {} // default constructor // Move constructor SendCommGuard(SendCommGuard &&other) noexcept : comm_(other.comm_) { other.comm_ = nullptr; } // Move assignment SendCommGuard &operator=(SendCommGuard &&other) noexcept { if (this != &other) { reset(); comm_ = other.comm_; other.comm_ = nullptr; } return *this; } // Disable copy constructor and assignment SendCommGuard(const SendCommGuard &other) = delete; SendCommGuard &operator=(const SendCommGuard &other) = delete; ~SendCommGuard() { reset(); } void reset(void *comm = nullptr) { if (comm_ && comm_ != comm) { ncclResult_t result = ncclNetSocket.closeSend(comm_); ASSERT_EQ(result, ncclSuccess) << "SendCommGuard failed to close send communicator (comm_=" << comm_ << "). ncclNetSocket.closeSend() returned error code: " << result << ". This indicates a potential resource leak or " << "invalid communicator state during RAII cleanup."; } comm_ = comm; } void *get() const { return comm_; } void *release() { void *temp = comm_; comm_ = nullptr; return temp; } explicit operator bool() const { return comm_ != nullptr; } }; // RAII wrapper for receive communicator class RecvCommGuard { private: void *comm_; public: explicit RecvCommGuard(void *comm = nullptr) : comm_(comm) {} // Move constructor RecvCommGuard(RecvCommGuard &&other) noexcept : comm_(other.comm_) { other.comm_ = nullptr; } // Move assignment RecvCommGuard &operator=(RecvCommGuard &&other) noexcept { if (this != &other) { reset(); comm_ = other.comm_; other.comm_ = nullptr; } return *this; } // Disable copy RecvCommGuard(const RecvCommGuard &) = delete; RecvCommGuard &operator=(const RecvCommGuard &) = delete; ~RecvCommGuard() { reset(); } void reset(void *comm = nullptr) { if (comm_ && comm_ != comm) { ncclResult_t result = ncclNetSocket.closeRecv(comm_); ASSERT_EQ(result, ncclSuccess) << "RecvCommGuard failed to close receive communicator (comm_=" << comm_ << "). ncclNetSocket.closeRecv() returned error code: " << result << ". This indicates a potential resource leak or " << "invalid communicator state during RAII cleanup."; } comm_ = comm; } void *get() const { return comm_; } void *release() { void *temp = comm_; comm_ = nullptr; return temp; } explicit operator bool() const { return comm_ != nullptr; } }; protected: void SetUp() override { ncclResult_t result = ncclNetSocket.init(nullptr, nullptr); ASSERT_EQ(result, ncclSuccess) << "Failed to initialize ncclNetSocket. " << "Error code: " << result << ". Ensure RCCL networking is properly configured."; result = ncclNetSocket.devices(&ndev); ASSERT_EQ(result, ncclSuccess) << "Failed to query network devices. " << "Error code: " << result << ". Check if network devices are available and accessible."; if (ndev == 0) { GTEST_SKIP() << "No network devices available for testing. " << "Ensure network hardware is present and properly configured."; } } int ndev = 0; // Common function to test socket properties void TestSocketProperties() { INFO(NCCL_LOG_INFO, "\n=== Testing socket properties ==="); // Test ncclNetSocketGetProperties for each device for (int dev = 0; dev < ndev; dev++) { ncclNetProperties_t props = {}; ncclResult_t propsResult = ncclNetSocket.getProperties(dev, &props); INFO(NCCL_LOG_INFO, "Device %d - getProperties result: %d", dev, propsResult); if (propsResult == ncclSuccess) { INFO(NCCL_LOG_INFO, " Device %d properties: name='%s', pciPath='%s', guid=%llu, " "speed=%d, port=%d, maxComms=%d", dev, props.name, props.pciPath, (unsigned long long)props.guid, props.speed, props.port, props.maxComms); } EXPECT_EQ(propsResult, ncclSuccess) << "getProperties failed for device " << dev << ". ncclNetSocket.getProperties() returned error code: " << propsResult << ". Verify device " << dev << " is available and properly configured."; } } // Common function to establish a connection pair with improved reliability bool EstablishConnectionPair(void *handle, void *listenComm, void *&sendComm, void *&recvComm) { // Allow overriding max attempts via environment variable for flexibility int maxAttempts = 100; const char* maxAttemptsEnv = getenv("RCCL_TEST_NETSOCKET_MAX_ATTEMPTS"); if (maxAttemptsEnv) { maxAttempts = ParseEnvVar(maxAttemptsEnv, "RCCL_TEST_NETSOCKET_MAX_ATTEMPTS", 100, 1); } // Allow overriding sleep duration via environment variable for flexibility int sleepMs = 100; const char* sleepMsEnv = getenv("RCCL_TEST_NETSOCKET_SLEEP_MS"); if (sleepMsEnv) { sleepMs = ParseEnvVar(sleepMsEnv, "RCCL_TEST_NETSOCKET_SLEEP_MS", 100, 1); } // Initialize output parameters sendComm = nullptr; recvComm = nullptr; // RAII guards for automatic cleanup SendCommGuard sendGuard; RecvCommGuard recvGuard; std::atomic connectionEstablished{false}; std::atomic acceptCompleted{false}; std::atomic connectCompleted{false}; std::atomic shouldStop{false}; INFO(NCCL_LOG_INFO, "Establishing connection pair with enhanced reliability"); std::thread connectAcceptThread([&]() { // Accept thread with longer timeout and better coordination std::thread acceptThread([&]() { ncclNetDeviceHandle_t *recvDevComm = nullptr; void *tempRecvComm = nullptr; // Increased attempts and longer total timeout for reliability for (int attempt = 0; attempt < maxAttempts && !shouldStop.load(); attempt++) { ncclResult_t acceptResult = ncclNetSocket.accept(listenComm, &tempRecvComm, &recvDevComm); if (acceptResult == ncclSuccess && tempRecvComm != nullptr) { recvGuard.reset(tempRecvComm); acceptCompleted.store(true); INFO(NCCL_LOG_INFO, "Accept completed successfully on attempt %d", attempt + 1); break; } // Longer sleep for network stability std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); } if (!acceptCompleted.load()) { INFO(NCCL_LOG_INFO, "Accept thread timed out after %d attempts", maxAttempts); } }); // Connect thread with longer timeout and better coordination std::thread connectThread([&]() { ncclNetCommConfig_t config = {}; ncclNetDeviceHandle_t *sendDevComm = nullptr; void *tempSendComm = nullptr; // Give accept thread more time to start listening std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); // Increased attempts and longer total timeout for reliability for (int attempt = 0; attempt < 100 && !shouldStop.load(); attempt++) { ncclResult_t connectResult = ncclNetSocket.connect( 0, &config, handle, &tempSendComm, &sendDevComm); if (connectResult == ncclSuccess && tempSendComm != nullptr) { sendGuard.reset(tempSendComm); connectCompleted = true; INFO(NCCL_LOG_INFO, "Connect completed successfully on attempt %d", attempt + 1); break; } // Longer sleep for network stability std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); } if (!connectCompleted.load()) { INFO(NCCL_LOG_INFO, "Connect thread timed out after %d attempts", maxAttempts); } }); // Wait for both threads with overall timeout auto startTime = std::chrono::steady_clock::now(); const auto maxWaitTime = std::chrono::seconds(10); // 10 second overall timeout while (!acceptCompleted.load() || !connectCompleted.load()) { auto currentTime = std::chrono::steady_clock::now(); if (currentTime - startTime > maxWaitTime) { INFO(NCCL_LOG_INFO, "Overall connection timeout reached, stopping threads"); shouldStop.store(true); break; } // Check every 100ms std::this_thread::sleep_for(std::chrono::milliseconds(100)); } acceptThread.join(); connectThread.join(); // Check if both operations completed successfully connectionEstablished = acceptCompleted.load() && connectCompleted.load() && sendGuard && recvGuard; }); connectAcceptThread.join(); if (connectionEstablished) { // Transfer ownership to output parameters sendComm = sendGuard.release(); recvComm = recvGuard.release(); INFO(NCCL_LOG_INFO, "Successfully established connection pair"); return true; } else { INFO(NCCL_LOG_INFO, "Failed to establish connection pair - accept: %s, connect: %s", acceptCompleted.load() ? "success" : "failed", connectCompleted.load() ? "success" : "failed"); // RAII guards will automatically clean up any partial connections return false; } } // Common function to setup memory and operations for a test size bool SetupOperationsForSize(void *sendComm, void *recvComm, size_t testSize, std::vector> &sendBuffers, std::vector> &recvBuffers, std::vector &sendMhandles, std::vector &recvMhandles, std::vector &sendRequests, std::vector &recvRequests, uint8_t fillPattern = 0xCD) { // Create buffers sendBuffers.emplace_back(testSize, fillPattern); recvBuffers.emplace_back(testSize, 0x00); void *sendMhandle = nullptr; void *recvMhandle = nullptr; // Register memory ncclResult_t sendRegResult = ncclNetSocket.regMr(sendComm, sendBuffers.back().data(), testSize, NCCL_PTR_HOST, &sendMhandle); ncclResult_t recvRegResult = ncclNetSocket.regMr(recvComm, recvBuffers.back().data(), testSize, NCCL_PTR_HOST, &recvMhandle); // Always add handles to vectors (even if nullptr) // to maintain consistency with buffer vectors for proper cleanup sendMhandles.push_back(sendMhandle); recvMhandles.push_back(recvMhandle); if (sendRegResult == ncclSuccess && recvRegResult == ncclSuccess) { INFO(NCCL_LOG_INFO, "Memory registration successful for size %zu", testSize); // Start send operation void *sendRequest = nullptr; ncclResult_t sendResult = ncclNetSocket.isend(sendComm, sendBuffers.back().data(), testSize, 0, sendMhandle, nullptr, &sendRequest); // Start receive operation void *recvRequest = nullptr; void *recvDataPtr = recvBuffers.back().data(); size_t recvSize = testSize; int tag = 0; ncclResult_t recvResult = ncclNetSocket.irecv(recvComm, 1, &recvDataPtr, &recvSize, &tag, &recvMhandle, nullptr, &recvRequest); if (sendResult == ncclSuccess && recvResult == ncclSuccess && sendRequest && recvRequest) { sendRequests.push_back(sendRequest); recvRequests.push_back(recvRequest); INFO(NCCL_LOG_INFO, "Successfully started operations for size %zu", testSize); return true; } else { INFO(NCCL_LOG_INFO, "Failed to start operations - send result: %d, recv result: %d", sendResult, recvResult); sendRequests.push_back(nullptr); recvRequests.push_back(nullptr); // NOTE: Memory handles are already in vectors and will be cleaned up by DeregisterMemory return false; } } else { INFO(NCCL_LOG_INFO, "Failed to register memory - send result: %d, recv result: %d", sendRegResult, recvRegResult); // NOTE: Even if only one registration succeeded, the handle is in the vector // and will be properly cleaned up by DeregisterMemory (it handles nullptr gracefully) sendRequests.push_back(nullptr); recvRequests.push_back(nullptr); return false; } } // Common function to progress operations and test ncclNetSocketGetTask bool ProgressOperations(void *sendRequest, void *recvRequest, size_t testSize, const std::string &testContext = "") { const int maxTestIterations = 10; bool taskCreationExercised = false; INFO(NCCL_LOG_INFO, "Starting progress testing - this exercises ncclNetSocketGetTask%s", testContext.c_str()); for (int testIter = 0; testIter < maxTestIterations; testIter++) { INFO(NCCL_LOG_INFO, " Progress test iteration %d/%d", testIter + 1, maxTestIterations); if (sendRequest && recvRequest) { int sendDone = 0, recvDone = 0; int sendSize = 0, recvSize_out = 0; ncclResult_t sendTestResult = ncclNetSocket.test(sendRequest, &sendDone, &sendSize); ncclResult_t recvTestResult = ncclNetSocket.test(recvRequest, &recvDone, &recvSize_out); INFO(NCCL_LOG_INFO, " Send test: result=%d, done=%d", sendTestResult, sendDone); INFO(NCCL_LOG_INFO, " Recv test: result=%d, done=%d", recvTestResult, recvDone); // If we reach this point with successful or in-progress results, // ncclNetSocketGetTask was exercised if ((sendTestResult == ncclSuccess || sendTestResult == ncclInProgress) && (recvTestResult == ncclSuccess || recvTestResult == ncclInProgress)) { taskCreationExercised = true; INFO(NCCL_LOG_INFO, " *** SUCCESS: ncclNetSocketGetTask was exercised! ***"); INFO(NCCL_LOG_INFO, " Task exercised with sendTestResult=%d (%s), recvTestResult=%d (%s)", sendTestResult, (sendTestResult == ncclSuccess) ? "ncclSuccess" : "ncclInProgress", recvTestResult, (recvTestResult == ncclSuccess) ? "ncclSuccess" : "ncclInProgress"); } // Count completed operations if (sendDone && recvDone) { INFO(NCCL_LOG_INFO, " Operations completed successfully!"); break; } // If operations fail, that's okay - we still exercised the code path if (sendTestResult != ncclSuccess && sendTestResult != ncclInProgress) { INFO(NCCL_LOG_INFO, " Send operation failed, but " "ncclNetSocketGetTask was still exercised"); break; } if (recvTestResult != ncclSuccess && recvTestResult != ncclInProgress) { INFO(NCCL_LOG_INFO, " Recv operation failed, but " "ncclNetSocketGetTask was still exercised"); break; } } // Give time between tests std::this_thread::sleep_for(std::chrono::milliseconds(50)); } if (taskCreationExercised) { INFO(NCCL_LOG_INFO, "*** VERIFICATION: ncclNetSocketGetTask was successfully exercised " "for buffer size %zu ***", testSize); } return taskCreationExercised; } // Common function to deregister memory and test ncclNetSocketDeregMr void DeregisterMemory(void *sendComm, void *recvComm, const std::vector &sendMhandles, const std::vector &recvMhandles, size_t testSize) { INFO(NCCL_LOG_INFO, "\n=== Testing ncclNetSocketDeregMr for size %zu ===", testSize); // Deregister send memory handles for (size_t j = 0; j < sendMhandles.size(); j++) { if (sendComm) { INFO(NCCL_LOG_INFO, "Deregistering send memory handle %zu for size %zu", j, testSize); ncclResult_t deregResult = ncclNetSocket.deregMr(sendComm, sendMhandles[j]); INFO(NCCL_LOG_INFO, "Send memory deregMr result: %d", deregResult); EXPECT_EQ(deregResult, ncclSuccess) << "Failed to deregister send memory handle " << j << " for buffer size " << testSize << ". " << "ncclNetSocket.deregMr() returned error code: " << deregResult << ". This may indicate memory registration/deregistration mismatch."; } } // Deregister receive memory handles for (size_t j = 0; j < recvMhandles.size(); j++) { if (recvComm) { INFO(NCCL_LOG_INFO, "Deregistering recv memory handle %zu for size %zu", j, testSize); ncclResult_t deregResult = ncclNetSocket.deregMr(recvComm, recvMhandles[j]); INFO(NCCL_LOG_INFO, "Recv memory deregMr result: %d", deregResult); EXPECT_EQ(deregResult, ncclSuccess) << "Failed to deregister send memory handle " << j << " for buffer size " << testSize << ". " << "ncclNetSocket.deregMr() returned error code: " << deregResult << ". This may indicate memory registration/deregistration mismatch."; } } } // Common function to cleanup communicators void CleanupCommunicators(const std::vector &sendComms, const std::vector &recvComms, void *listenComm) { INFO(NCCL_LOG_INFO, "\nCleaning up communicators..."); for (size_t i = 0; i < sendComms.size(); i++) { if (sendComms[i]) { INFO(NCCL_LOG_INFO, "Closing send communicator %zu", i); ncclResult_t closeResult = ncclNetSocket.closeSend(sendComms[i]); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator " << i << ". ncclNetSocket.closeSend() returned error code: " << closeResult << ". This may indicate communicator state corruption or resource cleanup issues."; } } for (size_t i = 0; i < recvComms.size(); i++) { if (recvComms[i]) { INFO(NCCL_LOG_INFO, "Closing recv communicator %zu", i); ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComms[i]); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator " << i << ". ncclNetSocket.closeRecv() returned error code: " << closeResult << ". This may indicate communicator state corruption or resource cleanup issues."; } } if (listenComm) { INFO(NCCL_LOG_INFO, "Closing listen communicator"); ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator. " << "ncclNetSocket.closeListen() returned error code: " << closeResult << ". This may indicate listen socket state corruption or resource cleanup issues."; listenComm = nullptr; } } // Common function to get test buffer sizes std::vector GetTestSizes() { return { 1024, // Small - basic test 64 * 1024, // MIN_CHUNKSIZE - boundary case 128 * 1024, // 2x MIN_CHUNKSIZE - will exercise subdivision 256 * 1024, // 4x MIN_CHUNKSIZE - multiple chunks 512 * 1024, // 8x MIN_CHUNKSIZE - many chunks 1024 * 1024, // Large - stress test 2 * 1024 * 1024 // Very large - comprehensive test }; } // Helper function to safely parse environment variables int ParseEnvVar(const char* envVar, const char* envName, int defaultValue = 0, int minValue = 0) { if (!envVar || strlen(envVar) == 0) { return defaultValue; } char* endPtr = nullptr; errno = 0; long result = std::strtol(envVar, &endPtr, 10); // Check for various error conditions - using ADD_FAILURE instead of GTEST_FAIL if (errno == ERANGE) { ADD_FAILURE() << "Environment variable " << envName << "='" << envVar << "' is out of range for integer conversion. " << "Please provide a valid integer value."; return defaultValue; } if (endPtr == envVar) { ADD_FAILURE() << "Environment variable " << envName << "='" << envVar << "' is not a valid number. " << "Please provide a valid integer value (e.g., " << envName << "=8)."; return defaultValue; } if (*endPtr != '\0') { ADD_FAILURE() << "Environment variable " << envName << "='" << envVar << "' contains non-numeric characters. " << "Please provide a valid integer value (e.g., " << envName << "=8)."; return defaultValue; } if (result < minValue) { ADD_FAILURE() << "Environment variable " << envName << "='" << envVar << "' must be >= " << minValue << ". " << "Current value: " << result << ". Please provide a valid positive integer."; return defaultValue; } if (result > INT_MAX) { ADD_FAILURE() << "Environment variable " << envName << "='" << envVar << "' is too large (> " << INT_MAX << "). " << "Please provide a smaller integer value."; return defaultValue; } return static_cast(result); } void RunConcurrentOperationsTaskCreationWithEnvVars() { INFO(NCCL_LOG_INFO, "Checking socket configuration environment variables"); // Check if the required environment variables are set const char *nThreadsEnv = getenv("NCCL_SOCKET_NTHREADS"); const char *nSocksPerThreadEnv = getenv("NCCL_NSOCKS_PERTHREAD"); if (!nThreadsEnv || !nSocksPerThreadEnv) { GTEST_SKIP() << "SKIPPING TEST: Required environment variables not set. " << "Please set the following environment variables to run this test: " << "export NCCL_SOCKET_NTHREADS=1 and export NCCL_NSOCKS_PERTHREAD=2. " << "This ensures nSocks > 0 so that ncclNetSocketGetTask gets called. " << "Environment variables NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD must be set"; return; } int nThreads = ParseEnvVar(nThreadsEnv, "NCCL_SOCKET_NTHREADS", 0, 1); int nSocksPerThread = ParseEnvVar(nSocksPerThreadEnv, "NCCL_NSOCKS_PERTHREAD", 0, 1); // Additional validation for reasonable upper bounds const int MAX_THREADS = 16; const int MAX_SOCKS_PER_THREAD = 64; const int MAX_TOTAL_SOCKETS = 64; if (nThreads > MAX_THREADS) { GTEST_SKIP() << "SKIPPING TEST: NCCL_SOCKET_NTHREADS=" << nThreads << " exceeds maximum " << MAX_THREADS << ". " << "Please provide a reasonable value (e.g., NCCL_SOCKET_NTHREADS=8). " << "Values too large may cause resource exhaustion."; return; } if (nSocksPerThread > MAX_SOCKS_PER_THREAD) { GTEST_SKIP() << "SKIPPING TEST: NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread << " exceeds maximum " << MAX_SOCKS_PER_THREAD << ". " << "Please provide a reasonable value (e.g., NCCL_NSOCKS_PERTHREAD=4). " << "Values too large may cause resource exhaustion."; return; } // Check for potential overflow before multiplication if (nThreads > 0 && nSocksPerThread > INT_MAX / nThreads) { GTEST_SKIP() << "SKIPPING TEST: Configuration would cause integer overflow. " << "NCCL_SOCKET_NTHREADS=" << nThreads << " * NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread << " exceeds maximum integer value. Please use smaller values."; return; } int totalSockets = nThreads * nSocksPerThread; INFO(NCCL_LOG_INFO, "Environment configuration found:"); INFO(NCCL_LOG_INFO, " NCCL_SOCKET_NTHREADS=%d", nThreads); INFO(NCCL_LOG_INFO, " NCCL_NSOCKS_PERTHREAD=%d", nSocksPerThread); INFO(NCCL_LOG_INFO, " Total sockets=%d", totalSockets); // Validate total sockets count if (totalSockets <= 0) { GTEST_SKIP() << "SKIPPING TEST: Invalid configuration - total sockets must be > 0. " << "Current configuration: nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread << " = " << totalSockets << ". " << "Both NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD must be positive integers. " << "Example: export NCCL_SOCKET_NTHREADS=2 && export NCCL_NSOCKS_PERTHREAD=2"; return; } if (totalSockets > MAX_TOTAL_SOCKETS) { GTEST_SKIP() << "SKIPPING TEST: Total sockets " << totalSockets << " exceeds maximum " << MAX_TOTAL_SOCKETS << ". " << "Current configuration: nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread << " = " << totalSockets << ". " << "Please reduce either NCCL_SOCKET_NTHREADS or NCCL_NSOCKS_PERTHREAD. " << "Example: export NCCL_SOCKET_NTHREADS=8 && export NCCL_NSOCKS_PERTHREAD=4"; return; } if (totalSockets > NCCL_NET_MAX_REQUESTS) { GTEST_SKIP() << "SKIPPING TEST: Total sockets " << totalSockets << " exceeds NCCL_NET_MAX_REQUESTS=" << NCCL_NET_MAX_REQUESTS << ". " << "Current configuration: nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread << " = " << totalSockets << ". " << "NCCL network layer can handle at most " << NCCL_NET_MAX_REQUESTS << " concurrent requests. " << "Please reduce configuration to stay within NCCL limits."; return; } INFO(NCCL_LOG_INFO, "Configuration valid - proceeding with test to exercise " "ncclNetSocketGetTask"); // Test socket properties TestSocketProperties(); char handle[NCCL_NET_HANDLE_MAXSIZE]; void *listenComm = nullptr; ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); ASSERT_EQ(result, ncclSuccess) << "Failed to establish listening socket for test execution. " << "ncclNetSocket.listen() returned error code: " << result << ". Verify network device availability and port accessibility."; INFO(NCCL_LOG_INFO, "Testing task creation functionality - ensuring " "ncclNetSocketGetTask is called"); std::vector sendComms; std::vector recvComms; // Establish connection void *sendComm = nullptr; void *recvComm = nullptr; bool connectionSuccess = EstablishConnectionPair(handle, listenComm, sendComm, recvComm); if (connectionSuccess) { sendComms.push_back(sendComm); recvComms.push_back(recvComm); // Test with buffer sizes that will trigger task subdivision std::vector testSizes = GetTestSizes(); for (size_t testSize : testSizes) { INFO(NCCL_LOG_INFO, "\n=== Testing with buffer size: %zu bytes ===", testSize); INFO(NCCL_LOG_INFO, "This should trigger ncclNetSocketGetTask to create " "task subdivision"); std::vector sendMhandles; std::vector recvMhandles; std::vector sendRequests; std::vector recvRequests; std::vector> sendBuffers; std::vector> recvBuffers; // Setup operations for this test size bool setupSuccess = SetupOperationsForSize( sendComm, recvComm, testSize, sendBuffers, recvBuffers, sendMhandles, recvMhandles, sendRequests, recvRequests, 0xAB); if (setupSuccess) { // Progress operations with context about environment variables ProgressOperations(sendRequests[0], recvRequests[0], testSize, " (with nSocks > 0 from environment variables)"); } else { INFO(NCCL_LOG_INFO, "No operations started - skipping progress testing for size %zu", testSize); } // Deregister memory DeregisterMemory(sendComm, recvComm, sendMhandles, recvMhandles, testSize); INFO(NCCL_LOG_INFO, "=== Completed testing for buffer size: %zu bytes ===", testSize); } INFO(NCCL_LOG_INFO, "\n*** TEST SUCCESS: ncclNetSocketGetTask was " "successfully exercised! ***"); } else { INFO(NCCL_LOG_INFO, "No connections established - test passed (network may " "not be available)"); } // Cleanup CleanupCommunicators(sendComms, recvComms, listenComm); INFO(NCCL_LOG_INFO, "TestConcurrentOperationsTaskCreation completed successfully"); } }; // Test concurrent operations task creation in default configuration (without // env vars) TEST_F(NetSocketTests, TestConcurrentOperationsTaskCreationDefault) { INFO(NCCL_LOG_INFO, "Testing task creation functionality in default configuration"); INFO(NCCL_LOG_INFO, "This test exercises ncclNetSocketGetTask regardless of nSocks value"); // Test socket properties TestSocketProperties(); char handle[NCCL_NET_HANDLE_MAXSIZE]; void *listenComm = nullptr; ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); ASSERT_EQ(result, ncclSuccess) << "Failed to establish listening socket for test execution. " << "ncclNetSocket.listen() returned error code: " << result << ". Verify network device availability and port accessibility."; INFO(NCCL_LOG_INFO, "Testing task creation functionality in default mode"); std::vector sendComms; std::vector recvComms; // Establish connection void *sendComm = nullptr; void *recvComm = nullptr; bool connectionSuccess = EstablishConnectionPair(handle, listenComm, sendComm, recvComm); if (connectionSuccess) { sendComms.push_back(sendComm); recvComms.push_back(recvComm); // Test with various buffer sizes std::vector testSizes = GetTestSizes(); for (size_t testSize : testSizes) { INFO(NCCL_LOG_INFO, "\n=== Testing with buffer size: %zu bytes ===", testSize); INFO(NCCL_LOG_INFO, "This exercises ncclNetSocketGetTask task creation logic"); std::vector sendMhandles; std::vector recvMhandles; std::vector sendRequests; std::vector recvRequests; std::vector> sendBuffers; std::vector> recvBuffers; // Setup operations for this test size bool setupSuccess = SetupOperationsForSize( sendComm, recvComm, testSize, sendBuffers, recvBuffers, sendMhandles, recvMhandles, sendRequests, recvRequests, 0xCD); if (setupSuccess) { // Progress operations bool taskExercised = ProgressOperations(sendRequests[0], recvRequests[0], testSize); if (!taskExercised) { INFO(NCCL_LOG_INFO, "*** NOTE: Operations didn't progress as expected for size %zu, " "but API was still exercised ***", testSize); } } else { INFO(NCCL_LOG_INFO, "No operations started - skipping progress testing for size %zu", testSize); } // Deregister memory DeregisterMemory(sendComm, recvComm, sendMhandles, recvMhandles, testSize); INFO(NCCL_LOG_INFO, "=== Completed testing for buffer size: %zu bytes ===", testSize); } INFO(NCCL_LOG_INFO, "\n*** TEST SUCCESS: ncclNetSocketGetTask was " "successfully exercised in default configuration! ***"); } else { INFO(NCCL_LOG_INFO, "No connections established - test passed (network may " "not be available)"); } // Cleanup CleanupCommunicators(sendComms, recvComms, listenComm); INFO(NCCL_LOG_INFO, "TestConcurrentOperationsTaskCreationDefault completed successfully"); } // Test multiple concurrent operations to stress test task creation TEST_F(NetSocketTests, TestConcurrentOperationsTaskCreation) { ProcessIsolatedTestRunner::ExecutionOptions options; options.stopOnFirstFailure = false; // Continue running all tests options.verboseLogging = true; RUN_ISOLATED_TESTS_WITH_OPTIONS(options, ProcessIsolatedTestRunner::TestConfig( "TestConcurrentOperationsTaskCreation", [this]() { RunConcurrentOperationsTaskCreationWithEnvVars(); }) .withEnvironment({{"NCCL_SOCKET_NTHREADS", "1"}, {"NCCL_NSOCKS_PERTHREAD", "2"}, {"NCCL_DEBUG", "TRACE"}, {"NCCL_DEBUG_SUBSYS", "ALL"}}) ); } // Test for invalid device index in listen function TEST_F(NetSocketTests, TestInvalidDeviceIndexListen) { INFO(NCCL_LOG_INFO, "Testing invalid device index in ncclNetSocketListen"); char handle[NCCL_NET_HANDLE_MAXSIZE]; void *listenComm = nullptr; // Test with negative device index ncclResult_t result = ncclNetSocket.listen(-1, handle, &listenComm); INFO(NCCL_LOG_INFO, "Listen with dev=-1 returned: %d", result); EXPECT_EQ(result, ncclInternalError) << "Listen should fail with negative device index. " << "ncclNetSocket.listen() with device index -1 should return ncclInternalError " << "but returned: " << result << ". Verify input validation for device indices."; // Test with device index greater than available devices int invalidDev = ndev + 10; result = ncclNetSocket.listen(invalidDev, handle, &listenComm); INFO(NCCL_LOG_INFO, "Listen with dev=%d (> ndev=%d) returned: %d", invalidDev, ndev, result); EXPECT_EQ(result, ncclInternalError) << "Listen should fail with device index >= ndev. " << "ncclNetSocket.listen() with device index " << invalidDev << " (> ndev=" << ndev << ") should return ncclInternalError but returned: " << result << ". Verify bounds checking for device indices."; INFO(NCCL_LOG_INFO, "TestInvalidDeviceIndexListen completed"); } // Test for invalid device index in connect function TEST_F(NetSocketTests, TestInvalidDeviceIndexConnect) { INFO(NCCL_LOG_INFO, "Testing invalid device index in ncclNetSocketConnect"); char handle[NCCL_NET_HANDLE_MAXSIZE]; void *sendComm = nullptr; ncclNetCommConfig_t config = {}; ncclNetDeviceHandle_t *sendDevComm = nullptr; // Test with negative device index ncclResult_t result = ncclNetSocket.connect(-1, &config, handle, &sendComm, &sendDevComm); INFO(NCCL_LOG_INFO, "Connect with dev=-1 returned: %d", result); EXPECT_EQ(result, ncclInternalError) << "Connect should fail with negative device index. " << "ncclNetSocket.connect() with device index -1 should return ncclInternalError " << "but returned: " << result << ". Verify input validation for device indices."; // Test with device index greater than available devices int invalidDev = ndev + 10; result = ncclNetSocket.connect(invalidDev, &config, handle, &sendComm, &sendDevComm); INFO(NCCL_LOG_INFO, "Connect with dev=%d (> ndev=%d) returned: %d", invalidDev, ndev, result); EXPECT_EQ(result, ncclInternalError) << "Connect should fail with device index >= ndev. " << "ncclNetSocket.connect() with device index " << invalidDev << " (> ndev=" << ndev << ") should return ncclInternalError but returned: " << result << ". Verify bounds checking for device indices."; INFO(NCCL_LOG_INFO, "TestInvalidDeviceIndexConnect completed"); } // Test for NULL request in test function TEST_F(NetSocketTests, TestNullRequestInTest) { INFO(NCCL_LOG_INFO, "Testing NULL request in ncclNetSocketTest"); int done = 0; int size = 0; // Test with NULL request ncclResult_t result = ncclNetSocket.test(nullptr, &done, &size); INFO(NCCL_LOG_INFO, "Test with NULL request returned: %d", result); EXPECT_EQ(result, ncclInternalError) << "Test should fail with NULL request. " << "ncclNetSocket.test() with nullptr request should return ncclInternalError " << "but returned: " << result << ". Verify NULL pointer validation."; INFO(NCCL_LOG_INFO, "TestNullRequestInTest completed"); } // Test for invalid array size in irecv function TEST_F(NetSocketTests, TestInvalidArraySizeIrecv) { INFO(NCCL_LOG_INFO, "Testing invalid array size in ncclNetSocketIrecv"); // Setup a dummy communicator first char handle[NCCL_NET_HANDLE_MAXSIZE]; void *listenComm = nullptr; ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); if (result == ncclSuccess && listenComm) { void *sendComm = nullptr; void *recvComm = nullptr; bool connectionSuccess = EstablishConnectionPair(handle, listenComm, sendComm, recvComm); if (connectionSuccess && recvComm) { // Test with n != 1 (invalid for socket implementation) std::vector buffer1(1024, 0xAA); std::vector buffer2(1024, 0xBB); void *data[2] = {buffer1.data(), buffer2.data()}; size_t sizes[2] = {1024, 1024}; int tags[2] = {0, 1}; void *mhandles[2] = {nullptr, nullptr}; void *phandles[2] = {nullptr, nullptr}; void *request = nullptr; // Test with n=2 (should fail for socket implementation) result = ncclNetSocket.irecv(recvComm, 2, data, sizes, tags, mhandles, phandles, &request); INFO(NCCL_LOG_INFO, "Irecv with n=2 returned: %d", result); EXPECT_EQ(result, ncclInternalError) << "Irecv should fail with n != 1. " << "ncclNetSocket.irecv() with n=2 should return ncclInternalError " << "but returned: " << result << ". Socket implementation only supports n=1."; // Test with n=0 (should fail) result = ncclNetSocket.irecv(recvComm, 0, data, sizes, tags, mhandles, phandles, &request); INFO(NCCL_LOG_INFO, "Irecv with n=0 returned: %d", result); EXPECT_EQ(result, ncclInternalError) << "Irecv should fail with n != 1. " << "ncclNetSocket.irecv() with n=0 should return ncclInternalError " << "but returned: " << result << ". Socket implementation only supports n=1."; // Cleanup communicators if (sendComm) { ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator"; sendComm = nullptr; } if (recvComm) { ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator"; recvComm = nullptr; } } if (listenComm) { ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator"; listenComm = nullptr; } } INFO(NCCL_LOG_INFO, "TestInvalidArraySizeIrecv completed"); } // Test for non-host memory type in regMr function TEST_F(NetSocketTests, TestNonHostMemoryRegMr) { INFO(NCCL_LOG_INFO, "Testing non-host memory type in ncclNetSocketRegMr"); // Setup a dummy communicator first char handle[NCCL_NET_HANDLE_MAXSIZE]; void *listenComm = nullptr; ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); if (result == ncclSuccess && listenComm) { void *sendComm = nullptr; void *recvComm = nullptr; bool connectionSuccess = EstablishConnectionPair(handle, listenComm, sendComm, recvComm); if (connectionSuccess && sendComm) { std::vector buffer(1024, 0xAA); void *mhandle = nullptr; // Test with NCCL_PTR_CUDA (should fail for socket implementation) result = ncclNetSocket.regMr(sendComm, buffer.data(), 1024, NCCL_PTR_CUDA, &mhandle); INFO(NCCL_LOG_INFO, "RegMr with NCCL_PTR_CUDA returned: %d", result); EXPECT_EQ(result, ncclInternalError) << "RegMr should fail with non-host memory type. " << "ncclNetSocket.regMr() with NCCL_PTR_CUDA should return ncclInternalError " << "but returned: " << result << ". Socket implementation only supports NCCL_PTR_HOST."; // Test with valid NCCL_PTR_HOST (should succeed) result = ncclNetSocket.regMr(sendComm, buffer.data(), 1024, NCCL_PTR_HOST, &mhandle); INFO(NCCL_LOG_INFO, "RegMr with NCCL_PTR_HOST returned: %d", result); EXPECT_EQ(result, ncclSuccess) << "RegMr should succeed with host memory type. " << "ncclNetSocket.regMr() with NCCL_PTR_HOST should return ncclSuccess " << "but returned: " << result << ". Verify host memory registration support."; // Cleanup communicators if (sendComm) { ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator"; sendComm = nullptr; } if (recvComm) { ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator"; recvComm = nullptr; } } if (listenComm) { ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator"; listenComm = nullptr; } } INFO(NCCL_LOG_INFO, "TestNonHostMemoryRegMr completed"); } // Test for excessive thread configuration warning TEST_F(NetSocketTests, TestExcessiveThreadConfig) { ProcessIsolatedTestRunner::ExecutionOptions options; options.stopOnFirstFailure = false; // Continue running all tests options.verboseLogging = true; RUN_ISOLATED_TESTS_WITH_OPTIONS(options, ProcessIsolatedTestRunner::TestConfig( "TestExcessiveThreadConfig", [this]() { INFO(NCCL_LOG_INFO, "Testing excessive thread configuration warning"); // Check if the required environment variables are set const char *nThreadsEnv = getenv("NCCL_SOCKET_NTHREADS"); const char *nSocksPerThreadEnv = getenv("NCCL_NSOCKS_PERTHREAD"); if (!nThreadsEnv || !nSocksPerThreadEnv) { GTEST_SKIP() << "SKIPPING TEST: Required environment variables not set. " << "This test requires NCCL_SOCKET_NTHREADS > " "NCCL_NET_MAX_REQUESTS (" << NCCL_NET_MAX_REQUESTS << ") and NCCL_NSOCKS_PERTHREAD = 1 to trigger warning. " << "Environment variables NCCL_SOCKET_NTHREADS and " "NCCL_NSOCKS_PERTHREAD must be set"; return; } // Parse with validation - both must be positive int nThreads = ParseEnvVar(nThreadsEnv, "NCCL_SOCKET_NTHREADS", 0, 1); int nSocksPerThread = ParseEnvVar(nSocksPerThreadEnv, "NCCL_NSOCKS_PERTHREAD", 0, 1); // Check for potential overflow before multiplication if (nThreads > 0 && nSocksPerThread > INT_MAX / nThreads) { GTEST_SKIP() << "SKIPPING TEST: Configuration would cause " "integer overflow. " << "NCCL_SOCKET_NTHREADS=" << nThreads << " * NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread << " exceeds maximum integer value. Please use " "smaller values."; return; } int totalSockets = nThreads * nSocksPerThread; INFO(NCCL_LOG_INFO, "Environment configuration found:"); INFO(NCCL_LOG_INFO, " NCCL_SOCKET_NTHREADS=%d", nThreads); INFO(NCCL_LOG_INFO, " NCCL_NSOCKS_PERTHREAD=%d", nSocksPerThread); INFO(NCCL_LOG_INFO, " Total sockets=%d", totalSockets); // Check if configuration is set to trigger the excessive threads // warning Use NCCL_NET_MAX_REQUESTS instead of arbitrary // MAX_THREADS if (nThreads <= NCCL_NET_MAX_REQUESTS) { GTEST_SKIP() << "SKIPPING TEST: NCCL_SOCKET_NTHREADS must be > " << NCCL_NET_MAX_REQUESTS << " to test excessive thread warning. " << "Current NCCL_SOCKET_NTHREADS=" << nThreads << ". " << "Please set: export NCCL_SOCKET_NTHREADS=" << (NCCL_NET_MAX_REQUESTS + 1) << ". " << "NCCL_SOCKET_NTHREADS must be > NCCL_NET_MAX_REQUESTS (" << NCCL_NET_MAX_REQUESTS << ") to trigger warning"; return; } if (totalSockets > NCCL_NET_MAX_REQUESTS * 10) { // Allow 10x for testing excessive config GTEST_SKIP() << "SKIPPING TEST: Total sockets=" << totalSockets << " is unreasonably large (> " << (NCCL_NET_MAX_REQUESTS * 10) << "). " << "Please use more reasonable values for testing. " "NCCL_NET_MAX_REQUESTS=" << NCCL_NET_MAX_REQUESTS << ". " << "Example: export NCCL_SOCKET_NTHREADS=" << (NCCL_NET_MAX_REQUESTS + 1) << " && export NCCL_NSOCKS_PERTHREAD=1"; return; } INFO(NCCL_LOG_INFO, "Configuration valid for testing excessive threads warning"); INFO(NCCL_LOG_INFO, "NCCL_SOCKET_NTHREADS=%d > NCCL_NET_MAX_REQUESTS=%d", nThreads, NCCL_NET_MAX_REQUESTS); // Test socket properties TestSocketProperties(); // Initialize to trigger the warning logic char handle[NCCL_NET_HANDLE_MAXSIZE]; void *listenComm = nullptr; ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); if (result == ncclSuccess && listenComm) { // The implementation should have limited the threads to // NCCL_NET_MAX_REQUESTS internally INFO(NCCL_LOG_INFO, "*** SUCCESS: Listen succeeded with " "excessive NCCL_SOCKET_NTHREADS - " "limits enforced internally ***"); ncclNetSocket.closeListen(listenComm); } else { INFO(NCCL_LOG_INFO, "Listen failed with result: %d", result); } INFO(NCCL_LOG_INFO, "TestExcessiveThreadConfig completed"); }) .withEnvironment({{"NCCL_SOCKET_NTHREADS", "33"}, {"NCCL_NSOCKS_PERTHREAD", "1"}, {"NCCL_DEBUG", "TRACE"}, {"NCCL_DEBUG_SUBSYS", "ALL"}}) ); } // Test for excessive socket configuration warning TEST_F(NetSocketTests, TestExcessiveSocketConfig) { ProcessIsolatedTestRunner::ExecutionOptions options; options.stopOnFirstFailure = false; // Continue running all tests options.verboseLogging = true; RUN_ISOLATED_TESTS_WITH_OPTIONS(options, ProcessIsolatedTestRunner::TestConfig( "TestExcessiveThreadConfig", [this]() { INFO(NCCL_LOG_INFO, "Testing excessive socket configuration warning"); // Check if the required environment variables are set const char *nThreadsEnv = getenv("NCCL_SOCKET_NTHREADS"); const char *nSocksPerThreadEnv = getenv("NCCL_NSOCKS_PERTHREAD"); if (!nThreadsEnv || !nSocksPerThreadEnv) { GTEST_SKIP() << "SKIPPING TEST: Required environment variables not set. " << "This test requires total sockets (nThreads * " "nSocksPerThread) > MAX_SOCKETS (64). " << "Environment variables NCCL_SOCKET_NTHREADS and " "NCCL_NSOCKS_PERTHREAD must be set"; return; } // Parse with validation - both must be positive int nThreads = ParseEnvVar(nThreadsEnv, "NCCL_SOCKET_NTHREADS", 0, 1); int nSocksPerThread = ParseEnvVar(nSocksPerThreadEnv, "NCCL_NSOCKS_PERTHREAD", 0, 1); // Check for potential overflow before multiplication if (nThreads > 0 && nSocksPerThread > INT_MAX / nThreads) { GTEST_SKIP() << "SKIPPING TEST: Configuration would cause " "integer overflow. " << "NCCL_SOCKET_NTHREADS=" << nThreads << " * NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread << " exceeds maximum integer value. Please use " "smaller values."; return; } int totalSockets = nThreads * nSocksPerThread; INFO(NCCL_LOG_INFO, "Environment configuration found:"); INFO(NCCL_LOG_INFO, " NCCL_SOCKET_NTHREADS=%d", nThreads); INFO(NCCL_LOG_INFO, " NCCL_NSOCKS_PERTHREAD=%d", nSocksPerThread); INFO(NCCL_LOG_INFO, " Total sockets=%d", totalSockets); // Check if configuration is set to trigger the excessive sockets // warning const int MAX_SOCKETS = 64; if (totalSockets <= MAX_SOCKETS) { GTEST_SKIP() << "SKIPPING TEST: Total sockets must be > " << MAX_SOCKETS << " to test excessive socket warning. " << "Current total sockets=" << totalSockets << " (nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread << "). " << "Please set environment variables such that total > " << MAX_SOCKETS << ", e.g.: " << "export NCCL_SOCKET_NTHREADS=9 && export " "NCCL_NSOCKS_PERTHREAD=8. " << "Total sockets must be > MAX_SOCKETS (" << MAX_SOCKETS << ") to trigger warning"; return; } // Additional validation against NCCL_NET_MAX_REQUESTS for // reasonable upper bounds if (totalSockets > NCCL_NET_MAX_REQUESTS * 10) { // Allow 10x for testing excessive config GTEST_SKIP() << "SKIPPING TEST: Total sockets=" << totalSockets << " is unreasonably large (> " << (NCCL_NET_MAX_REQUESTS * 10) << "). " << "Please use more reasonable values for testing. " "NCCL_NET_MAX_REQUESTS=" << NCCL_NET_MAX_REQUESTS << ". " << "Example: export NCCL_SOCKET_NTHREADS=10 && " "export NCCL_NSOCKS_PERTHREAD=10"; return; } INFO(NCCL_LOG_INFO, "Configuration valid for testing excessive sockets warning"); INFO(NCCL_LOG_INFO, "Total sockets=%d > MAX_SOCKETS=64", totalSockets); // Test socket properties TestSocketProperties(); // Initialize to trigger the warning logic char handle[NCCL_NET_HANDLE_MAXSIZE]; void *listenComm = nullptr; ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); if (result == ncclSuccess && listenComm) { // The implementation should have limited the sockets to // MAX_SOCKETS internally INFO(NCCL_LOG_INFO, "*** SUCCESS: Listen succeeded with excessive total " "sockets - limits enforced internally ***"); ncclNetSocket.closeListen(listenComm); } else { INFO(NCCL_LOG_INFO, "Listen failed with result: %d", result); } INFO(NCCL_LOG_INFO, "TestExcessiveSocketConfig completed"); }) .withEnvironment({{"NCCL_SOCKET_NTHREADS", "10"}, {"NCCL_NSOCKS_PERTHREAD", "10"}, {"NCCL_DEBUG", "TRACE"}, {"NCCL_DEBUG_SUBSYS", "ALL"}}) ); } // Test to trigger request allocation failure scenario TEST_F(NetSocketTests, TestRequestAllocationFailure) { INFO(NCCL_LOG_INFO, "Testing request allocation failure scenario"); // Setup communication char handle[NCCL_NET_HANDLE_MAXSIZE]; void *listenComm = nullptr; ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); if (result == ncclSuccess && listenComm) { void *sendComm = nullptr; void *recvComm = nullptr; bool connectionSuccess = EstablishConnectionPair(handle, listenComm, sendComm, recvComm); if (connectionSuccess && sendComm && recvComm) { INFO(NCCL_LOG_INFO, "Attempting to exhaust request pool (MAX_REQUESTS)"); std::vector requests; std::vector> buffers; std::vector mhandles; // Try to allocate many requests to potentially exhaust the pool // MAX_REQUESTS is defined as NCCL_NET_MAX_REQUESTS in the code for (int i = 0; i < (NCCL_NET_MAX_REQUESTS * 10); i++) { // Try to exceed NCCL_NET_MAX_REQUESTS by a reasonable margin buffers.emplace_back(1024, 0xAA + (i % 256)); void *mhandle = nullptr; // Register memory first result = ncclNetSocket.regMr(sendComm, buffers.back().data(), 1024, NCCL_PTR_HOST, &mhandle); EXPECT_EQ(result, ncclSuccess) << "Memory registration failed at iteration " << i << ". ncclNetSocket.regMr() returned error code: " << result << ". Verify memory registration limits and resource availability."; if (result != ncclSuccess) break; mhandles.push_back(mhandle); // Try to create send request void *request = nullptr; result = ncclNetSocket.isend(sendComm, buffers.back().data(), 1024, 0, mhandle, nullptr, &request); if (result == ncclInternalError) { INFO(NCCL_LOG_INFO, "Request allocation failed at iteration %d (expected behavior " "when pool exhausted)", i); break; } else if (result == ncclSuccess) { requests.push_back(request); } else { INFO(NCCL_LOG_INFO, "Unexpected result at iteration %d: %d", i, result); break; } } INFO(NCCL_LOG_INFO, "Successfully allocated %zu requests before failure/completion", requests.size()); // Cleanup: Test any pending requests and deregister memory for (size_t i = 0; i < requests.size(); i++) { if (requests[i]) { int done = 0; int size = 0; ncclNetSocket.test(requests[i], &done, &size); // Don't care about result } } for (size_t i = 0; i < mhandles.size(); i++) { if (mhandles[i]) { ncclNetSocket.deregMr(sendComm, mhandles[i]); } } // Cleanup communicators if (sendComm) { ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator"; sendComm = nullptr; } if (recvComm) { ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator"; recvComm = nullptr; } } if (listenComm) { ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator"; listenComm = nullptr; } } INFO(NCCL_LOG_INFO, "TestRequestAllocationFailure completed"); } // Test for message size mismatch scenario TEST_F(NetSocketTests, TestMessageSizeMismatch) { INFO(NCCL_LOG_INFO, "Testing message size mismatch scenario"); // This test simulates the condition where a receiver expects a smaller // message than what the sender is trying to send, which should trigger the // truncation warning char handle[NCCL_NET_HANDLE_MAXSIZE]; void *listenComm = nullptr; ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); if (result == ncclSuccess && listenComm) { void *sendComm = nullptr; void *recvComm = nullptr; bool connectionSuccess = EstablishConnectionPair(handle, listenComm, sendComm, recvComm); if (connectionSuccess && sendComm && recvComm) { // Large send buffer const size_t sendSize = 2048; std::vector sendBuffer(sendSize, 0xAA); // Small receive buffer (to simulate size mismatch) const size_t recvSize = 1024; // Smaller than send size std::vector recvBuffer(recvSize, 0x00); void *sendMhandle = nullptr; void *recvMhandle = nullptr; // Register memory result = ncclNetSocket.regMr(sendComm, sendBuffer.data(), sendSize, NCCL_PTR_HOST, &sendMhandle); EXPECT_EQ(result, ncclSuccess) << "Failed to register send memory for size mismatch test. " << "ncclNetSocket.regMr() returned error code: " << result << ". Verify memory registration support for buffer size " << sendSize << "."; result = ncclNetSocket.regMr(recvComm, recvBuffer.data(), recvSize, NCCL_PTR_HOST, &recvMhandle); EXPECT_EQ(result, ncclSuccess) << "Failed to register receive memory for size mismatch test. " << "ncclNetSocket.regMr() returned error code: " << result << ". Verify memory registration support for buffer size " << recvSize << "."; // Start send operation with large size void *sendRequest = nullptr; result = ncclNetSocket.isend(sendComm, sendBuffer.data(), sendSize, 0, sendMhandle, nullptr, &sendRequest); EXPECT_EQ(result, ncclSuccess) << "Failed to start send operation for size mismatch test. " << "ncclNetSocket.isend() returned error code: " << result << ". Verify send operation support for buffer size " << sendSize << "."; // Start receive operation with small size void *recvRequest = nullptr; void *recvDataPtr = recvBuffer.data(); size_t recvSizeVar = recvSize; int tag = 0; result = ncclNetSocket.irecv(recvComm, 1, &recvDataPtr, &recvSizeVar, &tag, &recvMhandle, nullptr, &recvRequest); EXPECT_EQ(result, ncclSuccess) << "Failed to start receive operation for size mismatch test. " << "ncclNetSocket.irecv() returned error code: " << result << ". Verify receive operation support for buffer size " << recvSize << "."; // Progress operations - this should eventually trigger the size mismatch // warning for (int i = 0; i < 100; i++) { if (sendRequest) { int sendDone = 0, sendSize_out = 0; ncclResult_t sendTestResult = ncclNetSocket.test(sendRequest, &sendDone, &sendSize_out); if (sendTestResult != ncclSuccess || sendDone) { INFO(NCCL_LOG_INFO, "Send operation completed: result=%d, done=%d", sendTestResult, sendDone); sendRequest = nullptr; // Request is cleaned up by the networking layer } } if (recvRequest) { int recvDone = 0, recvSize_out = 0; ncclResult_t recvTestResult = ncclNetSocket.test(recvRequest, &recvDone, &recvSize_out); if (recvTestResult != ncclSuccess || recvDone) { INFO(NCCL_LOG_INFO, "Recv operation completed: result=%d, done=%d, size=%d", recvTestResult, recvDone, recvSize_out); recvRequest = nullptr; // Request is cleaned up by the networking layer } } if (!sendRequest && !recvRequest) { INFO(NCCL_LOG_INFO, "Both operations completed after %d iterations", i + 1); break; } std::this_thread::sleep_for(std::chrono::milliseconds(10)); } // Cleanup memory handles if (sendMhandle && sendComm) { ncclResult_t deregResult = ncclNetSocket.deregMr(sendComm, sendMhandle); if (deregResult != ncclSuccess) { INFO(NCCL_LOG_INFO, "Warning: Failed to deregister send memory handle: %d", deregResult); } sendMhandle = nullptr; } if (recvMhandle && recvComm) { ncclResult_t deregResult = ncclNetSocket.deregMr(recvComm, recvMhandle); if (deregResult != ncclSuccess) { INFO(NCCL_LOG_INFO, "Warning: Failed to deregister recv memory handle: %d", deregResult); } recvMhandle = nullptr; } // Cleanup communicators if (sendComm) { ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator"; sendComm = nullptr; } if (recvComm) { ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator"; recvComm = nullptr; } } if (listenComm) { ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator"; listenComm = nullptr; } } INFO(NCCL_LOG_INFO, "TestMessageSizeMismatch completed"); } // Test to cover the iflush function that always returns ncclInternalError TEST_F(NetSocketTests, TestIflushAlwaysFails) { INFO(NCCL_LOG_INFO, "Testing ncclNetSocketIflush always returns ncclInternalError"); // This function should always return ncclInternalError for socket // implementation as it doesn't support CUDA pointers and flush operations std::vector buffer(1024, 0xAA); void *data = buffer.data(); int size = 1024; void *mhandle = nullptr; void *request = nullptr; // Test with dummy parameters - should always fail ncclResult_t result = ncclNetSocket.iflush(nullptr, 1, &data, &size, &mhandle, &request); INFO(NCCL_LOG_INFO, "ncclNetSocketIflush returned: %d", result); EXPECT_EQ(result, ncclInternalError) << "iflush should always return ncclInternalError"; INFO(NCCL_LOG_INFO, "TestIflushAlwaysFails completed"); } } // namespace RcclUnitTesting