From f500628ef2ed243e425b66fe0a468af5373bf239 Mon Sep 17 00:00:00 2001 From: ishkool Date: Wed, 27 Aug 2025 19:54:21 +0530 Subject: [PATCH] Code coverage tests for net_socket.cc (#1840) * Code coverage UTs for net_socket.cc * Addressed review comments --------- Co-authored-by: Atul Kulkarni [ROCm/rccl commit: c288fbf1b2c037a20d968d000675b7bb0f361270] --- projects/rccl/test/CMakeLists.txt | 3 +- projects/rccl/test/NetSocketTests.cpp | 1493 +++++++++++++++++++++++++ 2 files changed, 1495 insertions(+), 1 deletion(-) create mode 100644 projects/rccl/test/NetSocketTests.cpp diff --git a/projects/rccl/test/CMakeLists.txt b/projects/rccl/test/CMakeLists.txt index 95bf4da95a..37fea0ca88 100644 --- a/projects/rccl/test/CMakeLists.txt +++ b/projects/rccl/test/CMakeLists.txt @@ -142,10 +142,11 @@ if(BUILD_TESTS) AllocTests.cpp ArgCheckTests.cpp BitOpsTests.cpp - IpcsocketTests.cpp CollRegTests.cpp CommTests.cpp EnqueueTests.cpp + IpcsocketTests.cpp + NetSocketTests.cpp P2pTests.cpp ProxyTests.cpp RcclWrapTests.cpp diff --git a/projects/rccl/test/NetSocketTests.cpp b/projects/rccl/test/NetSocketTests.cpp new file mode 100644 index 0000000000..f02bd234ae --- /dev/null +++ b/projects/rccl/test/NetSocketTests.cpp @@ -0,0 +1,1493 @@ +/************************************************************************* + * Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "net.h" +#include "gtest/gtest.h" +#include +#include +#include + +extern ncclNet_t ncclNetSocket; + +namespace RcclUnitTesting { + +/** + * @brief Establishes a reliable connection pair (send and receive communicators) using the provided handle and listen communicator. + * + * This function attempts to create a pair of connected communicators (sendComm and recvComm) using the ncclNetSocket API. + * It uses internal RAII guards to ensure proper cleanup in case of partial failures. The function coordinates accept and connect + * operations in parallel threads, with extended timeouts and retries for robustness. + * + * @param handle Pointer to the device handle used for connection. + * @param listenComm Pointer to the listen communicator, used for accepting connections. + * @param[out] sendComm Reference to a pointer that will receive the newly created send communicator on success. + * WARNING: The ownership of the communicator is transferred to the caller. The caller is responsible for + * closing and cleaning up sendComm. If sendComm was previously pointing to a resource (e.g., a unique_ptr or + * other managed pointer), it will be overwritten and the previous resource may be leaked. Ensure sendComm is + * either nullptr or properly released before calling this function. + * @param[out] recvComm Reference to a pointer that will receive the newly created receive communicator on success. + * WARNING: The ownership of the communicator is transferred to the caller. The caller is responsible for + * closing and cleaning up recvComm. If recvComm was previously pointing to a resource (e.g., a unique_ptr or + * other managed pointer), it will be overwritten and the previous resource may be leaked. Ensure recvComm is + * either nullptr or properly released before calling this function. + * + * @return true if both send and receive communicators were successfully established and ownership transferred to the caller; + * false otherwise (in which case all resources are cleaned up internally). + */ +class NetSocketTests : public ::testing::Test { + +private: + // RAII wrapper for send communicator + class SendCommGuard { + private: + void *comm_ = nullptr; + + public: + explicit SendCommGuard(void *comm = nullptr) + : comm_(comm) {} // default constructor + + // Move constructor + SendCommGuard(SendCommGuard &&other) noexcept : comm_(other.comm_) { + other.comm_ = nullptr; + } + + // Move assignment + SendCommGuard &operator=(SendCommGuard &&other) noexcept { + if (this != &other) { + reset(); + comm_ = other.comm_; + other.comm_ = nullptr; + } + return *this; + } + + // Disable copy constructor and assignment + SendCommGuard(const SendCommGuard &other) = delete; + SendCommGuard &operator=(const SendCommGuard &other) = delete; + + ~SendCommGuard() { reset(); } + + void reset(void *comm = nullptr) { + if (comm_ && comm_ != comm) { + ncclResult_t result = ncclNetSocket.closeSend(comm_); + ASSERT_EQ(result, ncclSuccess) << "SendCommGuard failed to close send communicator (comm_=" + << comm_ << "). ncclNetSocket.closeSend() returned error code: " + << result << ". This indicates a potential resource leak or " + << "invalid communicator state during RAII cleanup."; + } + comm_ = comm; + } + + void *get() const { return comm_; } + void *release() { + void *temp = comm_; + comm_ = nullptr; + return temp; + } + + explicit operator bool() const { return comm_ != nullptr; } + }; + + // RAII wrapper for receive communicator + class RecvCommGuard { + private: + void *comm_; + + public: + explicit RecvCommGuard(void *comm = nullptr) : comm_(comm) {} + + // Move constructor + RecvCommGuard(RecvCommGuard &&other) noexcept : comm_(other.comm_) { + other.comm_ = nullptr; + } + + // Move assignment + RecvCommGuard &operator=(RecvCommGuard &&other) noexcept { + if (this != &other) { + reset(); + comm_ = other.comm_; + other.comm_ = nullptr; + } + return *this; + } + + // Disable copy + RecvCommGuard(const RecvCommGuard &) = delete; + RecvCommGuard &operator=(const RecvCommGuard &) = delete; + + ~RecvCommGuard() { reset(); } + + void reset(void *comm = nullptr) { + if (comm_ && comm_ != comm) { + ncclResult_t result = ncclNetSocket.closeRecv(comm_); + ASSERT_EQ(result, ncclSuccess) << "RecvCommGuard failed to close receive communicator (comm_=" + << comm_ << "). ncclNetSocket.closeRecv() returned error code: " + << result << ". This indicates a potential resource leak or " + << "invalid communicator state during RAII cleanup."; + } + comm_ = comm; + } + + void *get() const { return comm_; } + void *release() { + void *temp = comm_; + comm_ = nullptr; + return temp; + } + + explicit operator bool() const { return comm_ != nullptr; } + }; + +protected: + void SetUp() override { + ncclResult_t result = ncclNetSocket.init(nullptr, nullptr); + ASSERT_EQ(result, ncclSuccess) << "Failed to initialize ncclNetSocket. " + << "Error code: " << result + << ". Ensure RCCL networking is properly configured."; + + result = ncclNetSocket.devices(&ndev); + ASSERT_EQ(result, ncclSuccess) << "Failed to query network devices. " + << "Error code: " << result + << ". Check if network devices are available and accessible."; + + if (ndev == 0) { + GTEST_SKIP() << "No network devices available for testing. " + << "Ensure network hardware is present and properly configured."; + } + } + + int ndev = 0; + + // Common function to test socket properties + void TestSocketProperties() { + INFO(NCCL_LOG_INFO, "\n=== Testing socket properties ==="); + + // Test ncclNetSocketGetProperties for each device + for (int dev = 0; dev < ndev; dev++) { + ncclNetProperties_t props = {}; + ncclResult_t propsResult = ncclNetSocket.getProperties(dev, &props); + INFO(NCCL_LOG_INFO, "Device %d - getProperties result: %d", dev, + propsResult); + if (propsResult == ncclSuccess) { + INFO(NCCL_LOG_INFO, + " Device %d properties: name='%s', pciPath='%s', guid=%llu, " + "speed=%d, port=%d, maxComms=%d", + dev, props.name, props.pciPath, (unsigned long long)props.guid, + props.speed, props.port, props.maxComms); + } + EXPECT_EQ(propsResult, ncclSuccess) + << "getProperties failed for device " << dev + << ". ncclNetSocket.getProperties() returned error code: " << propsResult + << ". Verify device " << dev << " is available and properly configured."; + } + } + + // Common function to establish a connection pair with improved reliability + bool EstablishConnectionPair(void *handle, void *listenComm, void *&sendComm, + void *&recvComm) { + // Allow overriding max attempts via environment variable for flexibility + int maxAttempts = 100; + const char* maxAttemptsEnv = getenv("RCCL_TEST_NETSOCKET_MAX_ATTEMPTS"); + if (maxAttemptsEnv) { + maxAttempts = ParseEnvVar(maxAttemptsEnv, "RCCL_TEST_NETSOCKET_MAX_ATTEMPTS", 100, 1); + } + + // Allow overriding sleep duration via environment variable for flexibility + int sleepMs = 100; + const char* sleepMsEnv = getenv("RCCL_TEST_NETSOCKET_SLEEP_MS"); + if (sleepMsEnv) { + sleepMs = ParseEnvVar(sleepMsEnv, "RCCL_TEST_NETSOCKET_SLEEP_MS", 100, 1); + } + + // Initialize output parameters + sendComm = nullptr; + recvComm = nullptr; + + // RAII guards for automatic cleanup + SendCommGuard sendGuard; + RecvCommGuard recvGuard; + + std::atomic connectionEstablished{false}; + std::atomic acceptCompleted{false}; + std::atomic connectCompleted{false}; + std::atomic shouldStop{false}; + + INFO(NCCL_LOG_INFO, + "Establishing connection pair with enhanced reliability"); + + std::thread connectAcceptThread([&]() { + // Accept thread with longer timeout and better coordination + std::thread acceptThread([&]() { + ncclNetDeviceHandle_t *recvDevComm = nullptr; + void *tempRecvComm = nullptr; + + // Increased attempts and longer total timeout for reliability + for (int attempt = 0; attempt < maxAttempts && !shouldStop.load(); attempt++) { + ncclResult_t acceptResult = + ncclNetSocket.accept(listenComm, &tempRecvComm, &recvDevComm); + if (acceptResult == ncclSuccess && tempRecvComm != nullptr) { + recvGuard.reset(tempRecvComm); + acceptCompleted.store(true); + INFO(NCCL_LOG_INFO, "Accept completed successfully on attempt %d", + attempt + 1); + break; + } + + // Longer sleep for network stability + std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); + } + + if (!acceptCompleted.load()) { + INFO(NCCL_LOG_INFO, "Accept thread timed out after %d attempts", maxAttempts); + } + }); + + // Connect thread with longer timeout and better coordination + std::thread connectThread([&]() { + ncclNetCommConfig_t config = {}; + ncclNetDeviceHandle_t *sendDevComm = nullptr; + void *tempSendComm = nullptr; + + // Give accept thread more time to start listening + std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); + + // Increased attempts and longer total timeout for reliability + for (int attempt = 0; attempt < 100 && !shouldStop.load(); attempt++) { + ncclResult_t connectResult = ncclNetSocket.connect( + 0, &config, handle, &tempSendComm, &sendDevComm); + if (connectResult == ncclSuccess && tempSendComm != nullptr) { + sendGuard.reset(tempSendComm); + connectCompleted = true; + INFO(NCCL_LOG_INFO, "Connect completed successfully on attempt %d", + attempt + 1); + break; + } + + // Longer sleep for network stability + std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs)); + } + + if (!connectCompleted.load()) { + INFO(NCCL_LOG_INFO, "Connect thread timed out after %d attempts", maxAttempts); + } + }); + + // Wait for both threads with overall timeout + auto startTime = std::chrono::steady_clock::now(); + const auto maxWaitTime = + std::chrono::seconds(10); // 10 second overall timeout + + while (!acceptCompleted.load() || !connectCompleted.load()) { + auto currentTime = std::chrono::steady_clock::now(); + if (currentTime - startTime > maxWaitTime) { + INFO(NCCL_LOG_INFO, + "Overall connection timeout reached, stopping threads"); + shouldStop.store(true); + break; + } + + // Check every 100ms + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + acceptThread.join(); + connectThread.join(); + + // Check if both operations completed successfully + connectionEstablished = acceptCompleted.load() && + connectCompleted.load() && sendGuard && recvGuard; + }); + + connectAcceptThread.join(); + + if (connectionEstablished) { + // Transfer ownership to output parameters + sendComm = sendGuard.release(); + recvComm = recvGuard.release(); + + INFO(NCCL_LOG_INFO, "Successfully established connection pair"); + return true; + } else { + INFO(NCCL_LOG_INFO, + "Failed to establish connection pair - accept: %s, connect: %s", + acceptCompleted.load() ? "success" : "failed", + connectCompleted.load() ? "success" : "failed"); + // RAII guards will automatically clean up any partial connections + return false; + } + } + + // Common function to setup memory and operations for a test size + bool SetupOperationsForSize(void *sendComm, void *recvComm, size_t testSize, + std::vector> &sendBuffers, + std::vector> &recvBuffers, + std::vector &sendMhandles, + std::vector &recvMhandles, + std::vector &sendRequests, + std::vector &recvRequests, + uint8_t fillPattern = 0xCD) { + + // Create buffers + sendBuffers.emplace_back(testSize, fillPattern); + recvBuffers.emplace_back(testSize, 0x00); + + void *sendMhandle = nullptr; + void *recvMhandle = nullptr; + + // Register memory + ncclResult_t sendRegResult = + ncclNetSocket.regMr(sendComm, sendBuffers.back().data(), testSize, + NCCL_PTR_HOST, &sendMhandle); + ncclResult_t recvRegResult = + ncclNetSocket.regMr(recvComm, recvBuffers.back().data(), testSize, + NCCL_PTR_HOST, &recvMhandle); + + // Always add handles to vectors (even if nullptr) + // to maintain consistency with buffer vectors for proper cleanup + sendMhandles.push_back(sendMhandle); + recvMhandles.push_back(recvMhandle); + + if (sendRegResult == ncclSuccess && recvRegResult == ncclSuccess) { + INFO(NCCL_LOG_INFO, "Memory registration successful for size %zu", + testSize); + + // Start send operation + void *sendRequest = nullptr; + ncclResult_t sendResult = + ncclNetSocket.isend(sendComm, sendBuffers.back().data(), testSize, 0, + sendMhandle, nullptr, &sendRequest); + + // Start receive operation + void *recvRequest = nullptr; + void *recvDataPtr = recvBuffers.back().data(); + size_t recvSize = testSize; + int tag = 0; + ncclResult_t recvResult = + ncclNetSocket.irecv(recvComm, 1, &recvDataPtr, &recvSize, &tag, + &recvMhandle, nullptr, &recvRequest); + + if (sendResult == ncclSuccess && recvResult == ncclSuccess && + sendRequest && recvRequest) { + sendRequests.push_back(sendRequest); + recvRequests.push_back(recvRequest); + INFO(NCCL_LOG_INFO, "Successfully started operations for size %zu", + testSize); + return true; + } else { + INFO(NCCL_LOG_INFO, + "Failed to start operations - send result: %d, recv result: %d", + sendResult, recvResult); + sendRequests.push_back(nullptr); + recvRequests.push_back(nullptr); + // NOTE: Memory handles are already in vectors and will be cleaned up by DeregisterMemory + return false; + } + } else { + INFO(NCCL_LOG_INFO, + "Failed to register memory - send result: %d, recv result: %d", + sendRegResult, recvRegResult); + // NOTE: Even if only one registration succeeded, the handle is in the vector + // and will be properly cleaned up by DeregisterMemory (it handles nullptr gracefully) + sendRequests.push_back(nullptr); + recvRequests.push_back(nullptr); + return false; + } + } + + // Common function to progress operations and test ncclNetSocketGetTask + bool ProgressOperations(void *sendRequest, void *recvRequest, size_t testSize, + const std::string &testContext = "") { + const int maxTestIterations = 10; + bool taskCreationExercised = false; + + INFO(NCCL_LOG_INFO, + "Starting progress testing - this exercises ncclNetSocketGetTask%s", + testContext.c_str()); + + for (int testIter = 0; testIter < maxTestIterations; testIter++) { + INFO(NCCL_LOG_INFO, " Progress test iteration %d/%d", testIter + 1, + maxTestIterations); + + if (sendRequest && recvRequest) { + int sendDone = 0, recvDone = 0; + int sendSize = 0, recvSize_out = 0; + + ncclResult_t sendTestResult = + ncclNetSocket.test(sendRequest, &sendDone, &sendSize); + ncclResult_t recvTestResult = + ncclNetSocket.test(recvRequest, &recvDone, &recvSize_out); + + INFO(NCCL_LOG_INFO, " Send test: result=%d, done=%d", sendTestResult, + sendDone); + INFO(NCCL_LOG_INFO, " Recv test: result=%d, done=%d", recvTestResult, + recvDone); + + // If we reach this point with successful or in-progress results, + // ncclNetSocketGetTask was exercised + if ((sendTestResult == ncclSuccess || + sendTestResult == ncclInProgress) && + (recvTestResult == ncclSuccess || + recvTestResult == ncclInProgress)) { + taskCreationExercised = true; + INFO(NCCL_LOG_INFO, + " *** SUCCESS: ncclNetSocketGetTask was exercised! ***"); + INFO(NCCL_LOG_INFO, + " Task exercised with sendTestResult=%d (%s), recvTestResult=%d (%s)", + sendTestResult, + (sendTestResult == ncclSuccess) ? "ncclSuccess" : "ncclInProgress", + recvTestResult, + (recvTestResult == ncclSuccess) ? "ncclSuccess" : "ncclInProgress"); + } + + // Count completed operations + if (sendDone && recvDone) { + INFO(NCCL_LOG_INFO, " Operations completed successfully!"); + break; + } + + // If operations fail, that's okay - we still exercised the code path + if (sendTestResult != ncclSuccess && sendTestResult != ncclInProgress) { + INFO(NCCL_LOG_INFO, " Send operation failed, but " + "ncclNetSocketGetTask was still exercised"); + break; + } + if (recvTestResult != ncclSuccess && recvTestResult != ncclInProgress) { + INFO(NCCL_LOG_INFO, " Recv operation failed, but " + "ncclNetSocketGetTask was still exercised"); + break; + } + } + + // Give time between tests + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + + if (taskCreationExercised) { + INFO(NCCL_LOG_INFO, + "*** VERIFICATION: ncclNetSocketGetTask was successfully exercised " + "for buffer size %zu ***", + testSize); + } + + return taskCreationExercised; + } + + // Common function to deregister memory and test ncclNetSocketDeregMr + void DeregisterMemory(void *sendComm, void *recvComm, + const std::vector &sendMhandles, + const std::vector &recvMhandles, + size_t testSize) { + INFO(NCCL_LOG_INFO, + "\n=== Testing ncclNetSocketDeregMr for size %zu ===", testSize); + + // Deregister send memory handles + for (size_t j = 0; j < sendMhandles.size(); j++) { + if (sendComm) { + INFO(NCCL_LOG_INFO, "Deregistering send memory handle %zu for size %zu", + j, testSize); + ncclResult_t deregResult = + ncclNetSocket.deregMr(sendComm, sendMhandles[j]); + INFO(NCCL_LOG_INFO, "Send memory deregMr result: %d", deregResult); + EXPECT_EQ(deregResult, ncclSuccess) << "Failed to deregister send memory handle " << j + << " for buffer size " << testSize << ". " + << "ncclNetSocket.deregMr() returned error code: " << deregResult + << ". This may indicate memory registration/deregistration mismatch."; + } + } + + // Deregister receive memory handles + for (size_t j = 0; j < recvMhandles.size(); j++) { + if (recvComm) { + INFO(NCCL_LOG_INFO, "Deregistering recv memory handle %zu for size %zu", + j, testSize); + ncclResult_t deregResult = + ncclNetSocket.deregMr(recvComm, recvMhandles[j]); + INFO(NCCL_LOG_INFO, "Recv memory deregMr result: %d", deregResult); + EXPECT_EQ(deregResult, ncclSuccess) << "Failed to deregister send memory handle " << j + << " for buffer size " << testSize << ". " + << "ncclNetSocket.deregMr() returned error code: " << deregResult + << ". This may indicate memory registration/deregistration mismatch."; + } + } + } + + // Common function to cleanup communicators + void CleanupCommunicators(const std::vector &sendComms, + const std::vector &recvComms, + void *listenComm) { + INFO(NCCL_LOG_INFO, "\nCleaning up communicators..."); + + for (size_t i = 0; i < sendComms.size(); i++) { + if (sendComms[i]) { + INFO(NCCL_LOG_INFO, "Closing send communicator %zu", i); + ncclResult_t closeResult = ncclNetSocket.closeSend(sendComms[i]); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator " << i + << ". ncclNetSocket.closeSend() returned error code: " << closeResult + << ". This may indicate communicator state corruption or resource cleanup issues."; + } + } + + for (size_t i = 0; i < recvComms.size(); i++) { + if (recvComms[i]) { + INFO(NCCL_LOG_INFO, "Closing recv communicator %zu", i); + ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComms[i]); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator " << i + << ". ncclNetSocket.closeRecv() returned error code: " << closeResult + << ". This may indicate communicator state corruption or resource cleanup issues."; + } + } + + if (listenComm) { + INFO(NCCL_LOG_INFO, "Closing listen communicator"); + ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator. " + << "ncclNetSocket.closeListen() returned error code: " << closeResult + << ". This may indicate listen socket state corruption or resource cleanup issues."; + listenComm = nullptr; + } + } + + // Common function to get test buffer sizes + std::vector GetTestSizes() { + return { + 1024, // Small - basic test + 64 * 1024, // MIN_CHUNKSIZE - boundary case + 128 * 1024, // 2x MIN_CHUNKSIZE - will exercise subdivision + 256 * 1024, // 4x MIN_CHUNKSIZE - multiple chunks + 512 * 1024, // 8x MIN_CHUNKSIZE - many chunks + 1024 * 1024, // Large - stress test + 2 * 1024 * 1024 // Very large - comprehensive test + }; + } + + // Helper function to safely parse environment variables + int ParseEnvVar(const char* envVar, const char* envName, int defaultValue = 0, int minValue = 0) { + if (!envVar || strlen(envVar) == 0) { + return defaultValue; + } + + char* endPtr = nullptr; + errno = 0; + long result = std::strtol(envVar, &endPtr, 10); + + // Check for various error conditions - using ADD_FAILURE instead of GTEST_FAIL + if (errno == ERANGE) { + ADD_FAILURE() << "Environment variable " << envName << "='" << envVar + << "' is out of range for integer conversion. " + << "Please provide a valid integer value."; + return defaultValue; + } + + if (endPtr == envVar) { + ADD_FAILURE() << "Environment variable " << envName << "='" << envVar + << "' is not a valid number. " + << "Please provide a valid integer value (e.g., " << envName << "=8)."; + return defaultValue; + } + + if (*endPtr != '\0') { + ADD_FAILURE() << "Environment variable " << envName << "='" << envVar + << "' contains non-numeric characters. " + << "Please provide a valid integer value (e.g., " << envName << "=8)."; + return defaultValue; + } + + if (result < minValue) { + ADD_FAILURE() << "Environment variable " << envName << "='" << envVar + << "' must be >= " << minValue << ". " + << "Current value: " << result << ". Please provide a valid positive integer."; + return defaultValue; + } + + if (result > INT_MAX) { + ADD_FAILURE() << "Environment variable " << envName << "='" << envVar + << "' is too large (> " << INT_MAX << "). " + << "Please provide a smaller integer value."; + return defaultValue; + } + + return static_cast(result); + } + +}; + +// Test concurrent operations task creation in default configuration (without +// env vars) +TEST_F(NetSocketTests, TestConcurrentOperationsTaskCreationDefault) { + INFO(NCCL_LOG_INFO, + "Testing task creation functionality in default configuration"); + INFO(NCCL_LOG_INFO, + "This test exercises ncclNetSocketGetTask regardless of nSocks value"); + + // Test socket properties + TestSocketProperties(); + + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *listenComm = nullptr; + + ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); + ASSERT_EQ(result, ncclSuccess) << "Failed to establish listening socket for test execution. " + << "ncclNetSocket.listen() returned error code: " << result + << ". Verify network device availability and port accessibility."; + + INFO(NCCL_LOG_INFO, "Testing task creation functionality in default mode"); + + std::vector sendComms; + std::vector recvComms; + + // Establish connection + void *sendComm = nullptr; + void *recvComm = nullptr; + bool connectionSuccess = + EstablishConnectionPair(handle, listenComm, sendComm, recvComm); + + if (connectionSuccess) { + sendComms.push_back(sendComm); + recvComms.push_back(recvComm); + + // Test with various buffer sizes + std::vector testSizes = GetTestSizes(); + + for (size_t testSize : testSizes) { + INFO(NCCL_LOG_INFO, + "\n=== Testing with buffer size: %zu bytes ===", testSize); + INFO(NCCL_LOG_INFO, + "This exercises ncclNetSocketGetTask task creation logic"); + + std::vector sendMhandles; + std::vector recvMhandles; + std::vector sendRequests; + std::vector recvRequests; + std::vector> sendBuffers; + std::vector> recvBuffers; + + // Setup operations for this test size + bool setupSuccess = SetupOperationsForSize( + sendComm, recvComm, testSize, sendBuffers, recvBuffers, sendMhandles, + recvMhandles, sendRequests, recvRequests, 0xCD); + + if (setupSuccess) { + // Progress operations + bool taskExercised = + ProgressOperations(sendRequests[0], recvRequests[0], testSize); + + if (!taskExercised) { + INFO(NCCL_LOG_INFO, + "*** NOTE: Operations didn't progress as expected for size %zu, " + "but API was still exercised ***", + testSize); + } + } else { + INFO(NCCL_LOG_INFO, + "No operations started - skipping progress testing for size %zu", + testSize); + } + + // Deregister memory + DeregisterMemory(sendComm, recvComm, sendMhandles, recvMhandles, + testSize); + + INFO(NCCL_LOG_INFO, + "=== Completed testing for buffer size: %zu bytes ===", testSize); + } + + INFO(NCCL_LOG_INFO, "\n*** TEST SUCCESS: ncclNetSocketGetTask was " + "successfully exercised in default configuration! ***"); + } else { + INFO(NCCL_LOG_INFO, "No connections established - test passed (network may " + "not be available)"); + } + + // Cleanup + CleanupCommunicators(sendComms, recvComms, listenComm); + INFO(NCCL_LOG_INFO, + "TestConcurrentOperationsTaskCreationDefault completed successfully"); +} + +// Test multiple concurrent operations to stress test task creation +TEST_F(NetSocketTests, TestConcurrentOperationsTaskCreation) { + INFO(NCCL_LOG_INFO, "Checking socket configuration environment variables"); + + // Check if the required environment variables are set + const char *nThreadsEnv = getenv("NCCL_SOCKET_NTHREADS"); + const char *nSocksPerThreadEnv = getenv("NCCL_NSOCKS_PERTHREAD"); + + if (!nThreadsEnv || !nSocksPerThreadEnv) { + GTEST_SKIP() << "SKIPPING TEST: Required environment variables not set. " + << "Please set the following environment variables to run this test: " + << "export NCCL_SOCKET_NTHREADS=1 and export NCCL_NSOCKS_PERTHREAD=2. " + << "This ensures nSocks > 0 so that ncclNetSocketGetTask gets called. " + << "Environment variables NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD must be set"; + return; + } + + int nThreads = ParseEnvVar(nThreadsEnv, "NCCL_SOCKET_NTHREADS", 0, 1); + int nSocksPerThread = ParseEnvVar(nSocksPerThreadEnv, "NCCL_NSOCKS_PERTHREAD", 0, 1); + + // Additional validation for reasonable upper bounds + const int MAX_THREADS = 16; + const int MAX_SOCKS_PER_THREAD = 64; + const int MAX_TOTAL_SOCKETS = 64; + + if (nThreads > MAX_THREADS) { + GTEST_SKIP() << "SKIPPING TEST: NCCL_SOCKET_NTHREADS=" << nThreads << " exceeds maximum " << MAX_THREADS << ". " + << "Please provide a reasonable value (e.g., NCCL_SOCKET_NTHREADS=8). " + << "Values too large may cause resource exhaustion."; + return; + } + + if (nSocksPerThread > MAX_SOCKS_PER_THREAD) { + GTEST_SKIP() << "SKIPPING TEST: NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread << " exceeds maximum " << MAX_SOCKS_PER_THREAD << ". " + << "Please provide a reasonable value (e.g., NCCL_NSOCKS_PERTHREAD=4). " + << "Values too large may cause resource exhaustion."; + return; + } + + // Check for potential overflow before multiplication + if (nThreads > 0 && nSocksPerThread > INT_MAX / nThreads) { + GTEST_SKIP() << "SKIPPING TEST: Configuration would cause integer overflow. " + << "NCCL_SOCKET_NTHREADS=" << nThreads << " * NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread + << " exceeds maximum integer value. Please use smaller values."; + return; + } + + int totalSockets = nThreads * nSocksPerThread; + + INFO(NCCL_LOG_INFO, "Environment configuration found:"); + INFO(NCCL_LOG_INFO, " NCCL_SOCKET_NTHREADS=%d", nThreads); + INFO(NCCL_LOG_INFO, " NCCL_NSOCKS_PERTHREAD=%d", nSocksPerThread); + INFO(NCCL_LOG_INFO, " Total sockets=%d", totalSockets); + + // Validate total sockets count + if (totalSockets <= 0) { + GTEST_SKIP() << "SKIPPING TEST: Invalid configuration - total sockets must be > 0. " + << "Current configuration: nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread + << " = " << totalSockets << ". " + << "Both NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD must be positive integers. " + << "Example: export NCCL_SOCKET_NTHREADS=2 && export NCCL_NSOCKS_PERTHREAD=2"; + return; + } + + if (totalSockets > MAX_TOTAL_SOCKETS) { + GTEST_SKIP() << "SKIPPING TEST: Total sockets " << totalSockets << " exceeds maximum " << MAX_TOTAL_SOCKETS << ". " + << "Current configuration: nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread + << " = " << totalSockets << ". " + << "Please reduce either NCCL_SOCKET_NTHREADS or NCCL_NSOCKS_PERTHREAD. " + << "Example: export NCCL_SOCKET_NTHREADS=8 && export NCCL_NSOCKS_PERTHREAD=4"; + return; + } + + if (totalSockets > NCCL_NET_MAX_REQUESTS) { + GTEST_SKIP() << "SKIPPING TEST: Total sockets " << totalSockets << " exceeds NCCL_NET_MAX_REQUESTS=" << NCCL_NET_MAX_REQUESTS << ". " + << "Current configuration: nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread + << " = " << totalSockets << ". " + << "NCCL network layer can handle at most " << NCCL_NET_MAX_REQUESTS << " concurrent requests. " + << "Please reduce configuration to stay within NCCL limits."; + return; + } + + INFO(NCCL_LOG_INFO, "Configuration valid - proceeding with test to exercise " + "ncclNetSocketGetTask"); + + // Test socket properties + TestSocketProperties(); + + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *listenComm = nullptr; + + ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); + ASSERT_EQ(result, ncclSuccess) << "Failed to establish listening socket for test execution. " + << "ncclNetSocket.listen() returned error code: " << result + << ". Verify network device availability and port accessibility."; + + INFO(NCCL_LOG_INFO, "Testing task creation functionality - ensuring " + "ncclNetSocketGetTask is called"); + + std::vector sendComms; + std::vector recvComms; + + // Establish connection + void *sendComm = nullptr; + void *recvComm = nullptr; + bool connectionSuccess = + EstablishConnectionPair(handle, listenComm, sendComm, recvComm); + + if (connectionSuccess) { + sendComms.push_back(sendComm); + recvComms.push_back(recvComm); + + // Test with buffer sizes that will trigger task subdivision + std::vector testSizes = GetTestSizes(); + + for (size_t testSize : testSizes) { + INFO(NCCL_LOG_INFO, + "\n=== Testing with buffer size: %zu bytes ===", testSize); + INFO(NCCL_LOG_INFO, "This should trigger ncclNetSocketGetTask to create " + "task subdivision"); + + std::vector sendMhandles; + std::vector recvMhandles; + std::vector sendRequests; + std::vector recvRequests; + std::vector> sendBuffers; + std::vector> recvBuffers; + + // Setup operations for this test size + bool setupSuccess = SetupOperationsForSize( + sendComm, recvComm, testSize, sendBuffers, recvBuffers, sendMhandles, + recvMhandles, sendRequests, recvRequests, 0xAB); + + if (setupSuccess) { + // Progress operations with context about environment variables + ProgressOperations(sendRequests[0], recvRequests[0], testSize, + " (with nSocks > 0 from environment variables)"); + } else { + INFO(NCCL_LOG_INFO, + "No operations started - skipping progress testing for size %zu", + testSize); + } + + // Deregister memory + DeregisterMemory(sendComm, recvComm, sendMhandles, recvMhandles, + testSize); + + INFO(NCCL_LOG_INFO, + "=== Completed testing for buffer size: %zu bytes ===", testSize); + } + + INFO(NCCL_LOG_INFO, "\n*** TEST SUCCESS: ncclNetSocketGetTask was " + "successfully exercised! ***"); + } else { + INFO(NCCL_LOG_INFO, "No connections established - test passed (network may " + "not be available)"); + } + + // Cleanup + CleanupCommunicators(sendComms, recvComms, listenComm); + INFO(NCCL_LOG_INFO, + "TestConcurrentOperationsTaskCreation completed successfully"); +} + +// Test for invalid device index in listen function +TEST_F(NetSocketTests, TestInvalidDeviceIndexListen) { + INFO(NCCL_LOG_INFO, "Testing invalid device index in ncclNetSocketListen"); + + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *listenComm = nullptr; + + // Test with negative device index + ncclResult_t result = ncclNetSocket.listen(-1, handle, &listenComm); + INFO(NCCL_LOG_INFO, "Listen with dev=-1 returned: %d", result); + EXPECT_EQ(result, ncclInternalError) + << "Listen should fail with negative device index. " + << "ncclNetSocket.listen() with device index -1 should return ncclInternalError " + << "but returned: " << result << ". Verify input validation for device indices."; + + // Test with device index greater than available devices + int invalidDev = ndev + 10; + result = ncclNetSocket.listen(invalidDev, handle, &listenComm); + INFO(NCCL_LOG_INFO, "Listen with dev=%d (> ndev=%d) returned: %d", invalidDev, + ndev, result); + EXPECT_EQ(result, ncclInternalError) + << "Listen should fail with device index >= ndev. " + << "ncclNetSocket.listen() with device index " << invalidDev << " (> ndev=" << ndev + << ") should return ncclInternalError but returned: " << result + << ". Verify bounds checking for device indices."; + + INFO(NCCL_LOG_INFO, "TestInvalidDeviceIndexListen completed"); +} + +// Test for invalid device index in connect function +TEST_F(NetSocketTests, TestInvalidDeviceIndexConnect) { + INFO(NCCL_LOG_INFO, "Testing invalid device index in ncclNetSocketConnect"); + + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *sendComm = nullptr; + ncclNetCommConfig_t config = {}; + ncclNetDeviceHandle_t *sendDevComm = nullptr; + + // Test with negative device index + ncclResult_t result = + ncclNetSocket.connect(-1, &config, handle, &sendComm, &sendDevComm); + INFO(NCCL_LOG_INFO, "Connect with dev=-1 returned: %d", result); + EXPECT_EQ(result, ncclInternalError) + << "Connect should fail with negative device index. " + << "ncclNetSocket.connect() with device index -1 should return ncclInternalError " + << "but returned: " << result << ". Verify input validation for device indices."; + + // Test with device index greater than available devices + int invalidDev = ndev + 10; + result = ncclNetSocket.connect(invalidDev, &config, handle, &sendComm, + &sendDevComm); + INFO(NCCL_LOG_INFO, "Connect with dev=%d (> ndev=%d) returned: %d", + invalidDev, ndev, result); + EXPECT_EQ(result, ncclInternalError) + << "Connect should fail with device index >= ndev. " + << "ncclNetSocket.connect() with device index " << invalidDev << " (> ndev=" << ndev + << ") should return ncclInternalError but returned: " << result + << ". Verify bounds checking for device indices."; + + INFO(NCCL_LOG_INFO, "TestInvalidDeviceIndexConnect completed"); +} + +// Test for NULL request in test function +TEST_F(NetSocketTests, TestNullRequestInTest) { + INFO(NCCL_LOG_INFO, "Testing NULL request in ncclNetSocketTest"); + + int done = 0; + int size = 0; + + // Test with NULL request + ncclResult_t result = ncclNetSocket.test(nullptr, &done, &size); + INFO(NCCL_LOG_INFO, "Test with NULL request returned: %d", result); + EXPECT_EQ(result, ncclInternalError) << "Test should fail with NULL request. " + << "ncclNetSocket.test() with nullptr request should return ncclInternalError " + << "but returned: " << result << ". Verify NULL pointer validation."; + + INFO(NCCL_LOG_INFO, "TestNullRequestInTest completed"); +} + +// Test for invalid array size in irecv function +TEST_F(NetSocketTests, TestInvalidArraySizeIrecv) { + INFO(NCCL_LOG_INFO, "Testing invalid array size in ncclNetSocketIrecv"); + + // Setup a dummy communicator first + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *listenComm = nullptr; + ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); + + if (result == ncclSuccess && listenComm) { + void *sendComm = nullptr; + void *recvComm = nullptr; + bool connectionSuccess = + EstablishConnectionPair(handle, listenComm, sendComm, recvComm); + + if (connectionSuccess && recvComm) { + // Test with n != 1 (invalid for socket implementation) + std::vector buffer1(1024, 0xAA); + std::vector buffer2(1024, 0xBB); + void *data[2] = {buffer1.data(), buffer2.data()}; + size_t sizes[2] = {1024, 1024}; + int tags[2] = {0, 1}; + void *mhandles[2] = {nullptr, nullptr}; + void *phandles[2] = {nullptr, nullptr}; + void *request = nullptr; + + // Test with n=2 (should fail for socket implementation) + result = ncclNetSocket.irecv(recvComm, 2, data, sizes, tags, mhandles, + phandles, &request); + INFO(NCCL_LOG_INFO, "Irecv with n=2 returned: %d", result); + EXPECT_EQ(result, ncclInternalError) << "Irecv should fail with n != 1. " + << "ncclNetSocket.irecv() with n=2 should return ncclInternalError " + << "but returned: " << result << ". Socket implementation only supports n=1."; + + // Test with n=0 (should fail) + result = ncclNetSocket.irecv(recvComm, 0, data, sizes, tags, mhandles, + phandles, &request); + INFO(NCCL_LOG_INFO, "Irecv with n=0 returned: %d", result); + EXPECT_EQ(result, ncclInternalError) << "Irecv should fail with n != 1. " + << "ncclNetSocket.irecv() with n=0 should return ncclInternalError " + << "but returned: " << result << ". Socket implementation only supports n=1."; + + // Cleanup communicators + if (sendComm) { + ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator"; + sendComm = nullptr; + } + + if (recvComm) { + ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator"; + recvComm = nullptr; + } + } + + if (listenComm) { + ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator"; + listenComm = nullptr; + } + } + + INFO(NCCL_LOG_INFO, "TestInvalidArraySizeIrecv completed"); +} + +// Test for non-host memory type in regMr function +TEST_F(NetSocketTests, TestNonHostMemoryRegMr) { + INFO(NCCL_LOG_INFO, "Testing non-host memory type in ncclNetSocketRegMr"); + + // Setup a dummy communicator first + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *listenComm = nullptr; + ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); + + if (result == ncclSuccess && listenComm) { + void *sendComm = nullptr; + void *recvComm = nullptr; + bool connectionSuccess = + EstablishConnectionPair(handle, listenComm, sendComm, recvComm); + + if (connectionSuccess && sendComm) { + std::vector buffer(1024, 0xAA); + void *mhandle = nullptr; + + // Test with NCCL_PTR_CUDA (should fail for socket implementation) + result = ncclNetSocket.regMr(sendComm, buffer.data(), 1024, NCCL_PTR_CUDA, + &mhandle); + INFO(NCCL_LOG_INFO, "RegMr with NCCL_PTR_CUDA returned: %d", result); + EXPECT_EQ(result, ncclInternalError) + << "RegMr should fail with non-host memory type. " + << "ncclNetSocket.regMr() with NCCL_PTR_CUDA should return ncclInternalError " + << "but returned: " << result << ". Socket implementation only supports NCCL_PTR_HOST."; + + // Test with valid NCCL_PTR_HOST (should succeed) + result = ncclNetSocket.regMr(sendComm, buffer.data(), 1024, NCCL_PTR_HOST, + &mhandle); + INFO(NCCL_LOG_INFO, "RegMr with NCCL_PTR_HOST returned: %d", result); + EXPECT_EQ(result, ncclSuccess) + << "RegMr should succeed with host memory type. " + << "ncclNetSocket.regMr() with NCCL_PTR_HOST should return ncclSuccess " + << "but returned: " << result << ". Verify host memory registration support."; + + // Cleanup communicators + if (sendComm) { + ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator"; + sendComm = nullptr; + } + + if (recvComm) { + ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator"; + recvComm = nullptr; + } + } + + if (listenComm) { + ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator"; + listenComm = nullptr; + } + } + + INFO(NCCL_LOG_INFO, "TestNonHostMemoryRegMr completed"); +} + +// Test for excessive thread configuration warning +TEST_F(NetSocketTests, TestExcessiveThreadConfig) { + INFO(NCCL_LOG_INFO, "Testing excessive thread configuration warning"); + + // Check if the required environment variables are set + const char *nThreadsEnv = getenv("NCCL_SOCKET_NTHREADS"); + const char *nSocksPerThreadEnv = getenv("NCCL_NSOCKS_PERTHREAD"); + + if (!nThreadsEnv || !nSocksPerThreadEnv) { + GTEST_SKIP() << "SKIPPING TEST: Required environment variables not set. " + << "This test requires NCCL_SOCKET_NTHREADS > NCCL_NET_MAX_REQUESTS (" << NCCL_NET_MAX_REQUESTS << ") and NCCL_NSOCKS_PERTHREAD = 1 to trigger warning. " + << "Environment variables NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD must be set"; + return; + } + + // Parse with validation - both must be positive + int nThreads = ParseEnvVar(nThreadsEnv, "NCCL_SOCKET_NTHREADS", 0, 1); + int nSocksPerThread = ParseEnvVar(nSocksPerThreadEnv, "NCCL_NSOCKS_PERTHREAD", 0, 1); + + // Check for potential overflow before multiplication + if (nThreads > 0 && nSocksPerThread > INT_MAX / nThreads) { + GTEST_SKIP() << "SKIPPING TEST: Configuration would cause integer overflow. " + << "NCCL_SOCKET_NTHREADS=" << nThreads << " * NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread + << " exceeds maximum integer value. Please use smaller values."; + return; + } + + int totalSockets = nThreads * nSocksPerThread; + + INFO(NCCL_LOG_INFO, "Environment configuration found:"); + INFO(NCCL_LOG_INFO, " NCCL_SOCKET_NTHREADS=%d", nThreads); + INFO(NCCL_LOG_INFO, " NCCL_NSOCKS_PERTHREAD=%d", nSocksPerThread); + INFO(NCCL_LOG_INFO, " Total sockets=%d", totalSockets); + + // Check if configuration is set to trigger the excessive threads warning + // Use NCCL_NET_MAX_REQUESTS instead of arbitrary MAX_THREADS + if (nThreads <= NCCL_NET_MAX_REQUESTS) { + GTEST_SKIP() << "SKIPPING TEST: NCCL_SOCKET_NTHREADS must be > " << NCCL_NET_MAX_REQUESTS << " to test excessive thread warning. " + << "Current NCCL_SOCKET_NTHREADS=" << nThreads << ". " + << "Please set: export NCCL_SOCKET_NTHREADS=" << (NCCL_NET_MAX_REQUESTS + 1) << ". " + << "NCCL_SOCKET_NTHREADS must be > NCCL_NET_MAX_REQUESTS (" << NCCL_NET_MAX_REQUESTS << ") to trigger warning"; + return; + } + + if (totalSockets > NCCL_NET_MAX_REQUESTS * 10) { // Allow 10x for testing excessive config + GTEST_SKIP() << "SKIPPING TEST: Total sockets=" << totalSockets << " is unreasonably large (> " << (NCCL_NET_MAX_REQUESTS * 10) << "). " + << "Please use more reasonable values for testing. NCCL_NET_MAX_REQUESTS=" << NCCL_NET_MAX_REQUESTS << ". " + << "Example: export NCCL_SOCKET_NTHREADS=" << (NCCL_NET_MAX_REQUESTS + 1) << " && export NCCL_NSOCKS_PERTHREAD=1"; + return; + } + + INFO(NCCL_LOG_INFO, + "Configuration valid for testing excessive threads warning"); + INFO(NCCL_LOG_INFO, "NCCL_SOCKET_NTHREADS=%d > NCCL_NET_MAX_REQUESTS=%d", nThreads, NCCL_NET_MAX_REQUESTS); + + // Test socket properties + TestSocketProperties(); + + // Initialize to trigger the warning logic + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *listenComm = nullptr; + ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); + + if (result == ncclSuccess && listenComm) { + // The implementation should have limited the threads to NCCL_NET_MAX_REQUESTS + // internally + INFO(NCCL_LOG_INFO, + "*** SUCCESS: Listen succeeded with excessive NCCL_SOCKET_NTHREADS - " + "limits enforced internally ***"); + ncclNetSocket.closeListen(listenComm); + } else { + INFO(NCCL_LOG_INFO, "Listen failed with result: %d", result); + } + + INFO(NCCL_LOG_INFO, "TestExcessiveThreadConfig completed"); +} + +// Test for excessive socket configuration warning +TEST_F(NetSocketTests, TestExcessiveSocketConfig) { + INFO(NCCL_LOG_INFO, "Testing excessive socket configuration warning"); + + // Check if the required environment variables are set + const char *nThreadsEnv = getenv("NCCL_SOCKET_NTHREADS"); + const char *nSocksPerThreadEnv = getenv("NCCL_NSOCKS_PERTHREAD"); + + if (!nThreadsEnv || !nSocksPerThreadEnv) { + GTEST_SKIP() << "SKIPPING TEST: Required environment variables not set. " + << "This test requires total sockets (nThreads * nSocksPerThread) > MAX_SOCKETS (64). " + << "Environment variables NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD must be set"; + return; + } + + // Parse with validation - both must be positive + int nThreads = ParseEnvVar(nThreadsEnv, "NCCL_SOCKET_NTHREADS", 0, 1); + int nSocksPerThread = ParseEnvVar(nSocksPerThreadEnv, "NCCL_NSOCKS_PERTHREAD", 0, 1); + + // Check for potential overflow before multiplication + if (nThreads > 0 && nSocksPerThread > INT_MAX / nThreads) { + GTEST_SKIP() << "SKIPPING TEST: Configuration would cause integer overflow. " + << "NCCL_SOCKET_NTHREADS=" << nThreads << " * NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread + << " exceeds maximum integer value. Please use smaller values."; + return; + } + + int totalSockets = nThreads * nSocksPerThread; + + INFO(NCCL_LOG_INFO, "Environment configuration found:"); + INFO(NCCL_LOG_INFO, " NCCL_SOCKET_NTHREADS=%d", nThreads); + INFO(NCCL_LOG_INFO, " NCCL_NSOCKS_PERTHREAD=%d", nSocksPerThread); + INFO(NCCL_LOG_INFO, " Total sockets=%d", totalSockets); + + // Check if configuration is set to trigger the excessive sockets warning + const int MAX_SOCKETS = 64; + if (totalSockets <= MAX_SOCKETS) { + GTEST_SKIP() << "SKIPPING TEST: Total sockets must be > " << MAX_SOCKETS << " to test excessive socket warning. " + << "Current total sockets=" << totalSockets + << " (nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread << "). " + << "Please set environment variables such that total > " << MAX_SOCKETS << ", e.g.: " + << "export NCCL_SOCKET_NTHREADS=9 && export NCCL_NSOCKS_PERTHREAD=8. " + << "Total sockets must be > MAX_SOCKETS (" << MAX_SOCKETS << ") to trigger warning"; + return; + } + + // Additional validation against NCCL_NET_MAX_REQUESTS for reasonable upper bounds + if (totalSockets > NCCL_NET_MAX_REQUESTS * 10) { // Allow 10x for testing excessive config + GTEST_SKIP() << "SKIPPING TEST: Total sockets=" << totalSockets << " is unreasonably large (> " << (NCCL_NET_MAX_REQUESTS * 10) << "). " + << "Please use more reasonable values for testing. NCCL_NET_MAX_REQUESTS=" << NCCL_NET_MAX_REQUESTS << ". " + << "Example: export NCCL_SOCKET_NTHREADS=10 && export NCCL_NSOCKS_PERTHREAD=10"; + return; + } + + INFO(NCCL_LOG_INFO, + "Configuration valid for testing excessive sockets warning"); + INFO(NCCL_LOG_INFO, "Total sockets=%d > MAX_SOCKETS=64", totalSockets); + + // Test socket properties + TestSocketProperties(); + + // Initialize to trigger the warning logic + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *listenComm = nullptr; + ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); + + if (result == ncclSuccess && listenComm) { + // The implementation should have limited the sockets to MAX_SOCKETS + // internally + INFO(NCCL_LOG_INFO, "*** SUCCESS: Listen succeeded with excessive total " + "sockets - limits enforced internally ***"); + ncclNetSocket.closeListen(listenComm); + } else { + INFO(NCCL_LOG_INFO, "Listen failed with result: %d", result); + } + + INFO(NCCL_LOG_INFO, "TestExcessiveSocketConfig completed"); +} + +// Test to trigger request allocation failure scenario +TEST_F(NetSocketTests, TestRequestAllocationFailure) { + INFO(NCCL_LOG_INFO, "Testing request allocation failure scenario"); + + // Setup communication + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *listenComm = nullptr; + ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); + + if (result == ncclSuccess && listenComm) { + void *sendComm = nullptr; + void *recvComm = nullptr; + bool connectionSuccess = + EstablishConnectionPair(handle, listenComm, sendComm, recvComm); + + if (connectionSuccess && sendComm && recvComm) { + INFO(NCCL_LOG_INFO, "Attempting to exhaust request pool (MAX_REQUESTS)"); + + std::vector requests; + std::vector> buffers; + std::vector mhandles; + + // Try to allocate many requests to potentially exhaust the pool + // MAX_REQUESTS is defined as NCCL_NET_MAX_REQUESTS in the code + for (int i = 0; i < (NCCL_NET_MAX_REQUESTS * 10); i++) { // Try to exceed NCCL_NET_MAX_REQUESTS by a reasonable margin + buffers.emplace_back(1024, 0xAA + (i % 256)); + void *mhandle = nullptr; + + // Register memory first + result = ncclNetSocket.regMr(sendComm, buffers.back().data(), 1024, + NCCL_PTR_HOST, &mhandle); + EXPECT_EQ(result, ncclSuccess) << "Memory registration failed at iteration " << i + << ". ncclNetSocket.regMr() returned error code: " << result + << ". Verify memory registration limits and resource availability."; + if (result != ncclSuccess) + break; + mhandles.push_back(mhandle); + + // Try to create send request + void *request = nullptr; + result = ncclNetSocket.isend(sendComm, buffers.back().data(), 1024, 0, + mhandle, nullptr, &request); + + if (result == ncclInternalError) { + INFO(NCCL_LOG_INFO, + "Request allocation failed at iteration %d (expected behavior " + "when pool exhausted)", + i); + break; + } else if (result == ncclSuccess) { + requests.push_back(request); + } else { + INFO(NCCL_LOG_INFO, "Unexpected result at iteration %d: %d", i, + result); + break; + } + } + + INFO(NCCL_LOG_INFO, + "Successfully allocated %zu requests before failure/completion", + requests.size()); + + // Cleanup: Test any pending requests and deregister memory + for (size_t i = 0; i < requests.size(); i++) { + if (requests[i]) { + int done = 0; + int size = 0; + ncclNetSocket.test(requests[i], &done, + &size); // Don't care about result + } + } + + for (size_t i = 0; i < mhandles.size(); i++) { + if (mhandles[i]) { + ncclNetSocket.deregMr(sendComm, mhandles[i]); + } + } + + // Cleanup communicators + if (sendComm) { + ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator"; + sendComm = nullptr; + } + if (recvComm) { + ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator"; + recvComm = nullptr; + } + } + + if (listenComm) { + ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator"; + listenComm = nullptr; + } + } + + INFO(NCCL_LOG_INFO, "TestRequestAllocationFailure completed"); +} + +// Test for message size mismatch scenario +TEST_F(NetSocketTests, TestMessageSizeMismatch) { + INFO(NCCL_LOG_INFO, "Testing message size mismatch scenario"); + + // This test simulates the condition where a receiver expects a smaller + // message than what the sender is trying to send, which should trigger the + // truncation warning + + char handle[NCCL_NET_HANDLE_MAXSIZE]; + void *listenComm = nullptr; + ncclResult_t result = ncclNetSocket.listen(0, handle, &listenComm); + + if (result == ncclSuccess && listenComm) { + void *sendComm = nullptr; + void *recvComm = nullptr; + bool connectionSuccess = + EstablishConnectionPair(handle, listenComm, sendComm, recvComm); + + if (connectionSuccess && sendComm && recvComm) { + // Large send buffer + const size_t sendSize = 2048; + std::vector sendBuffer(sendSize, 0xAA); + + // Small receive buffer (to simulate size mismatch) + const size_t recvSize = 1024; // Smaller than send size + std::vector recvBuffer(recvSize, 0x00); + + void *sendMhandle = nullptr; + void *recvMhandle = nullptr; + + // Register memory + result = ncclNetSocket.regMr(sendComm, sendBuffer.data(), sendSize, + NCCL_PTR_HOST, &sendMhandle); + EXPECT_EQ(result, ncclSuccess) << "Failed to register send memory for size mismatch test. " + << "ncclNetSocket.regMr() returned error code: " << result + << ". Verify memory registration support for buffer size " << sendSize << "."; + + result = ncclNetSocket.regMr(recvComm, recvBuffer.data(), recvSize, + NCCL_PTR_HOST, &recvMhandle); + EXPECT_EQ(result, ncclSuccess) << "Failed to register receive memory for size mismatch test. " + << "ncclNetSocket.regMr() returned error code: " << result + << ". Verify memory registration support for buffer size " << recvSize << "."; + + // Start send operation with large size + void *sendRequest = nullptr; + result = ncclNetSocket.isend(sendComm, sendBuffer.data(), sendSize, 0, + sendMhandle, nullptr, &sendRequest); + EXPECT_EQ(result, ncclSuccess) << "Failed to start send operation for size mismatch test. " + << "ncclNetSocket.isend() returned error code: " << result + << ". Verify send operation support for buffer size " << sendSize << "."; + + // Start receive operation with small size + void *recvRequest = nullptr; + void *recvDataPtr = recvBuffer.data(); + size_t recvSizeVar = recvSize; + int tag = 0; + result = ncclNetSocket.irecv(recvComm, 1, &recvDataPtr, &recvSizeVar, + &tag, &recvMhandle, nullptr, &recvRequest); + EXPECT_EQ(result, ncclSuccess) << "Failed to start receive operation for size mismatch test. " + << "ncclNetSocket.irecv() returned error code: " << result + << ". Verify receive operation support for buffer size " << recvSize << "."; + + // Progress operations - this should eventually trigger the size mismatch + // warning + for (int i = 0; i < 100; i++) { + if (sendRequest) { + int sendDone = 0, sendSize_out = 0; + ncclResult_t sendTestResult = ncclNetSocket.test(sendRequest, &sendDone, &sendSize_out); + if (sendTestResult != ncclSuccess || sendDone) { + INFO(NCCL_LOG_INFO, "Send operation completed: result=%d, done=%d", sendTestResult, sendDone); + sendRequest = nullptr; // Request is cleaned up by the networking layer + } + } + + if (recvRequest) { + int recvDone = 0, recvSize_out = 0; + ncclResult_t recvTestResult = ncclNetSocket.test(recvRequest, &recvDone, &recvSize_out); + if (recvTestResult != ncclSuccess || recvDone) { + INFO(NCCL_LOG_INFO, "Recv operation completed: result=%d, done=%d, size=%d", + recvTestResult, recvDone, recvSize_out); + recvRequest = nullptr; // Request is cleaned up by the networking layer + } + } + + if (!sendRequest && !recvRequest) { + INFO(NCCL_LOG_INFO, "Both operations completed after %d iterations", i + 1); + break; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + // Cleanup memory handles + if (sendMhandle && sendComm) { + ncclResult_t deregResult = ncclNetSocket.deregMr(sendComm, sendMhandle); + if (deregResult != ncclSuccess) { + INFO(NCCL_LOG_INFO, "Warning: Failed to deregister send memory handle: %d", deregResult); + } + sendMhandle = nullptr; + } + + if (recvMhandle && recvComm) { + ncclResult_t deregResult = ncclNetSocket.deregMr(recvComm, recvMhandle); + if (deregResult != ncclSuccess) { + INFO(NCCL_LOG_INFO, "Warning: Failed to deregister recv memory handle: %d", deregResult); + } + recvMhandle = nullptr; + } + + // Cleanup communicators + if (sendComm) { + ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator"; + sendComm = nullptr; + } + if (recvComm) { + ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator"; + recvComm = nullptr; + } + } + + if (listenComm) { + ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm); + EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator"; + listenComm = nullptr; + } + } + + INFO(NCCL_LOG_INFO, "TestMessageSizeMismatch completed"); +} + +// Test to cover the iflush function that always returns ncclInternalError +TEST_F(NetSocketTests, TestIflushAlwaysFails) { + INFO(NCCL_LOG_INFO, + "Testing ncclNetSocketIflush always returns ncclInternalError"); + + // This function should always return ncclInternalError for socket + // implementation as it doesn't support CUDA pointers and flush operations + + std::vector buffer(1024, 0xAA); + void *data = buffer.data(); + int size = 1024; + void *mhandle = nullptr; + void *request = nullptr; + + // Test with dummy parameters - should always fail + ncclResult_t result = + ncclNetSocket.iflush(nullptr, 1, &data, &size, &mhandle, &request); + INFO(NCCL_LOG_INFO, "ncclNetSocketIflush returned: %d", result); + EXPECT_EQ(result, ncclInternalError) + << "iflush should always return ncclInternalError"; + + INFO(NCCL_LOG_INFO, "TestIflushAlwaysFails completed"); +} + +} // namespace RcclUnitTesting