d7293281f3
[ROCm/rccl commit: 858b4e76eb]
1601 satır
66 KiB
C++
1601 satır
66 KiB
C++
/*************************************************************************
|
|
* Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
|
*
|
|
* See LICENSE.txt for license information
|
|
************************************************************************/
|
|
#include "net.h"

#include "common/ProcessIsolatedTestRunner.hpp"

#include "gtest/gtest.h"

#include <atomic>
#include <cerrno>
#include <chrono>
#include <climits>
#include <cstdlib>
#include <cstring>
#include <string>
#include <thread>
#include <vector>
|
|
|
|
extern ncclNet_t ncclNetSocket;
|
|
|
|
namespace RcclUnitTesting {
|
|
|
|
/**
 * @brief Documents NetSocketTests::EstablishConnectionPair (a helper of the fixture below), which
 * establishes a connected pair of send and receive communicators from a listen handle and a
 * listen communicator.
 *
 * The helper creates the pair (sendComm and recvComm) through the ncclNetSocket API. Partially
 * created communicators are held in internal RAII guards, so a communicator is closed again if
 * only one side of the connection succeeds. The accept and connect retries run in parallel
 * threads, bounded by a per-thread retry count and an overall timeout.
 *
 * @param handle Connection handle written by ncclNetSocket.listen(); connect() uses it to reach
 *        the listener.
 * @param listenComm Pointer to the listen communicator, used for accepting connections.
 * @param[out] sendComm Reference to a pointer that receives the newly created send communicator on success.
 *        WARNING: Ownership of the communicator is transferred to the caller, who is responsible for
 *        closing sendComm. The function sets sendComm to nullptr on entry, so any communicator it
 *        previously held is overwritten and may be leaked. Ensure sendComm is nullptr, or has been
 *        closed, before calling this function.
 * @param[out] recvComm Reference to a pointer that receives the newly created receive communicator on success.
 *        WARNING: Ownership of the communicator is transferred to the caller, who is responsible for
 *        closing recvComm. The function sets recvComm to nullptr on entry, so any communicator it
 *        previously held is overwritten and may be leaked. Ensure recvComm is nullptr, or has been
 *        closed, before calling this function.
 *
 * @return true if both send and receive communicators were established and ownership was
 *         transferred to the caller; false otherwise, in which case any partially created
 *         communicator has already been closed internally.
 */
|
|
class NetSocketTests : public ::testing::Test {
|
|
|
|
private:
|
|
// RAII wrapper for send communicator
|
|
class SendCommGuard {
|
|
private:
|
|
void *comm_ = nullptr;
|
|
|
|
public:
|
|
explicit SendCommGuard(void *comm = nullptr)
|
|
: comm_(comm) {} // default constructor
|
|
|
|
// Move constructor
|
|
SendCommGuard(SendCommGuard &&other) noexcept : comm_(other.comm_) {
|
|
other.comm_ = nullptr;
|
|
}
|
|
|
|
// Move assignment
|
|
SendCommGuard &operator=(SendCommGuard &&other) noexcept {
|
|
if (this != &other) {
|
|
reset();
|
|
comm_ = other.comm_;
|
|
other.comm_ = nullptr;
|
|
}
|
|
return *this;
|
|
}
|
|
|
|
// Disable copy constructor and assignment
|
|
SendCommGuard(const SendCommGuard &other) = delete;
|
|
SendCommGuard &operator=(const SendCommGuard &other) = delete;
|
|
|
|
~SendCommGuard() { reset(); }
|
|
|
|
void reset(void *comm = nullptr) {
|
|
if (comm_ && comm_ != comm) {
|
|
ncclResult_t result = ncclNetSocket.closeSend(comm_);
|
|
ASSERT_EQ(result, ncclSuccess) << "SendCommGuard failed to close send communicator (comm_="
|
|
<< comm_ << "). ncclNetSocket.closeSend() returned error code: "
|
|
<< result << ". This indicates a potential resource leak or "
|
|
<< "invalid communicator state during RAII cleanup.";
|
|
}
|
|
comm_ = comm;
|
|
}
|
|
|
|
void *get() const { return comm_; }
|
|
void *release() {
|
|
void *temp = comm_;
|
|
comm_ = nullptr;
|
|
return temp;
|
|
}
|
|
|
|
explicit operator bool() const { return comm_ != nullptr; }
|
|
};
|
|
|
|
// RAII wrapper for receive communicator
|
|
class RecvCommGuard {
|
|
private:
|
|
void *comm_;
|
|
|
|
public:
|
|
explicit RecvCommGuard(void *comm = nullptr) : comm_(comm) {}
|
|
|
|
// Move constructor
|
|
RecvCommGuard(RecvCommGuard &&other) noexcept : comm_(other.comm_) {
|
|
other.comm_ = nullptr;
|
|
}
|
|
|
|
// Move assignment
|
|
RecvCommGuard &operator=(RecvCommGuard &&other) noexcept {
|
|
if (this != &other) {
|
|
reset();
|
|
comm_ = other.comm_;
|
|
other.comm_ = nullptr;
|
|
}
|
|
return *this;
|
|
}
|
|
|
|
// Disable copy
|
|
RecvCommGuard(const RecvCommGuard &) = delete;
|
|
RecvCommGuard &operator=(const RecvCommGuard &) = delete;
|
|
|
|
~RecvCommGuard() { reset(); }
|
|
|
|
void reset(void *comm = nullptr) {
|
|
if (comm_ && comm_ != comm) {
|
|
ncclResult_t result = ncclNetSocket.closeRecv(comm_);
|
|
ASSERT_EQ(result, ncclSuccess) << "RecvCommGuard failed to close receive communicator (comm_="
|
|
<< comm_ << "). ncclNetSocket.closeRecv() returned error code: "
|
|
<< result << ". This indicates a potential resource leak or "
|
|
<< "invalid communicator state during RAII cleanup.";
|
|
}
|
|
comm_ = comm;
|
|
}
|
|
|
|
void *get() const { return comm_; }
|
|
void *release() {
|
|
void *temp = comm_;
|
|
comm_ = nullptr;
|
|
return temp;
|
|
}
|
|
|
|
explicit operator bool() const { return comm_ != nullptr; }
|
|
};
|
|
|
|
protected:
|
|
void SetUp() override {
|
|
void* ctx = nullptr;
|
|
uint64_t commId = 0;
|
|
ncclNetCommConfig_t config = {};
|
|
ncclDebugLogger_t logFunction = nullptr;
|
|
ncclProfilerCallback_t profFunction = nullptr;
|
|
ncclResult_t result = ncclNetSocket.init(&ctx, commId, &config, logFunction, profFunction);
|
|
ASSERT_EQ(result, ncclSuccess) << "Failed to initialize ncclNetSocket. "
|
|
<< "Error code: " << result
|
|
<< ". Ensure RCCL networking is properly configured.";
|
|
|
|
result = ncclNetSocket.devices(&ndev);
|
|
ASSERT_EQ(result, ncclSuccess) << "Failed to query network devices. "
|
|
<< "Error code: " << result
|
|
<< ". Check if network devices are available and accessible.";
|
|
|
|
if (ndev == 0) {
|
|
GTEST_SKIP() << "No network devices available for testing. "
|
|
<< "Ensure network hardware is present and properly configured.";
|
|
}
|
|
}
|
|
|
|
int ndev = 0;
|
|
|
|
// Common function to test socket properties
|
|
void TestSocketProperties() {
|
|
INFO(NCCL_LOG_INFO, "\n=== Testing socket properties ===");
|
|
|
|
// Test ncclNetSocketGetProperties for each device
|
|
for (int dev = 0; dev < ndev; dev++) {
|
|
ncclNetProperties_t props = {};
|
|
ncclResult_t propsResult = ncclNetSocket.getProperties(dev, &props);
|
|
INFO(NCCL_LOG_INFO, "Device %d - getProperties result: %d", dev,
|
|
propsResult);
|
|
if (propsResult == ncclSuccess) {
|
|
INFO(NCCL_LOG_INFO,
|
|
" Device %d properties: name='%s', pciPath='%s', guid=%llu, "
|
|
"speed=%d, port=%d, maxComms=%d",
|
|
dev, props.name, props.pciPath, (unsigned long long)props.guid,
|
|
props.speed, props.port, props.maxComms);
|
|
}
|
|
EXPECT_EQ(propsResult, ncclSuccess)
|
|
<< "getProperties failed for device " << dev
|
|
<< ". ncclNetSocket.getProperties() returned error code: " << propsResult
|
|
<< ". Verify device " << dev << " is available and properly configured.";
|
|
}
|
|
}
|
|
|
|
// Common function to establish a connection pair with improved reliability
|
|
bool EstablishConnectionPair(void *handle, void *listenComm, void *&sendComm,
|
|
void *&recvComm) {
|
|
// Allow overriding max attempts via environment variable for flexibility
|
|
int maxAttempts = 100;
|
|
const char* maxAttemptsEnv = getenv("RCCL_TEST_NETSOCKET_MAX_ATTEMPTS");
|
|
if (maxAttemptsEnv) {
|
|
maxAttempts = ParseEnvVar(maxAttemptsEnv, "RCCL_TEST_NETSOCKET_MAX_ATTEMPTS", 100, 1);
|
|
}
|
|
|
|
// Allow overriding sleep duration via environment variable for flexibility
|
|
int sleepMs = 100;
|
|
const char* sleepMsEnv = getenv("RCCL_TEST_NETSOCKET_SLEEP_MS");
|
|
if (sleepMsEnv) {
|
|
sleepMs = ParseEnvVar(sleepMsEnv, "RCCL_TEST_NETSOCKET_SLEEP_MS", 100, 1);
|
|
}
|
|
|
|
// Initialize output parameters
|
|
sendComm = nullptr;
|
|
recvComm = nullptr;
|
|
|
|
// RAII guards for automatic cleanup
|
|
SendCommGuard sendGuard;
|
|
RecvCommGuard recvGuard;
|
|
|
|
std::atomic<bool> connectionEstablished{false};
|
|
std::atomic<bool> acceptCompleted{false};
|
|
std::atomic<bool> connectCompleted{false};
|
|
std::atomic<bool> shouldStop{false};
|
|
|
|
INFO(NCCL_LOG_INFO,
|
|
"Establishing connection pair with enhanced reliability");
|
|
|
|
std::thread connectAcceptThread([&]() {
|
|
// Accept thread with longer timeout and better coordination
|
|
std::thread acceptThread([&]() {
|
|
ncclNetDeviceHandle_t *recvDevComm = nullptr;
|
|
void *tempRecvComm = nullptr;
|
|
|
|
// Increased attempts and longer total timeout for reliability
|
|
for (int attempt = 0; attempt < maxAttempts && !shouldStop.load(); attempt++) {
|
|
ncclResult_t acceptResult =
|
|
ncclNetSocket.accept(listenComm, &tempRecvComm, &recvDevComm);
|
|
if (acceptResult == ncclSuccess && tempRecvComm != nullptr) {
|
|
recvGuard.reset(tempRecvComm);
|
|
acceptCompleted.store(true);
|
|
INFO(NCCL_LOG_INFO, "Accept completed successfully on attempt %d",
|
|
attempt + 1);
|
|
break;
|
|
}
|
|
|
|
// Longer sleep for network stability
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs));
|
|
}
|
|
|
|
if (!acceptCompleted.load()) {
|
|
INFO(NCCL_LOG_INFO, "Accept thread timed out after %d attempts", maxAttempts);
|
|
}
|
|
});
|
|
|
|
// Connect thread with longer timeout and better coordination
|
|
std::thread connectThread([&]() {
|
|
ncclNetCommConfig_t config = {};
|
|
ncclNetDeviceHandle_t *sendDevComm = nullptr;
|
|
void *tempSendComm = nullptr;
|
|
|
|
// Give accept thread more time to start listening
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs));
|
|
|
|
// Increased attempts and longer total timeout for reliability
|
|
for (int attempt = 0; attempt < 100 && !shouldStop.load(); attempt++) {
|
|
void* ctx = nullptr;
|
|
int dev = 0;
|
|
ncclResult_t connectResult = ncclNetSocket.connect(ctx, dev, handle, &tempSendComm, &sendDevComm);
|
|
if (connectResult == ncclSuccess && tempSendComm != nullptr) {
|
|
sendGuard.reset(tempSendComm);
|
|
connectCompleted = true;
|
|
INFO(NCCL_LOG_INFO, "Connect completed successfully on attempt %d",
|
|
attempt + 1);
|
|
break;
|
|
}
|
|
|
|
// Longer sleep for network stability
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs));
|
|
}
|
|
|
|
if (!connectCompleted.load()) {
|
|
INFO(NCCL_LOG_INFO, "Connect thread timed out after %d attempts", maxAttempts);
|
|
}
|
|
});
|
|
|
|
// Wait for both threads with overall timeout
|
|
auto startTime = std::chrono::steady_clock::now();
|
|
const auto maxWaitTime =
|
|
std::chrono::seconds(10); // 10 second overall timeout
|
|
|
|
while (!acceptCompleted.load() || !connectCompleted.load()) {
|
|
auto currentTime = std::chrono::steady_clock::now();
|
|
if (currentTime - startTime > maxWaitTime) {
|
|
INFO(NCCL_LOG_INFO,
|
|
"Overall connection timeout reached, stopping threads");
|
|
shouldStop.store(true);
|
|
break;
|
|
}
|
|
|
|
// Check every 100ms
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
|
}
|
|
|
|
acceptThread.join();
|
|
connectThread.join();
|
|
|
|
// Check if both operations completed successfully
|
|
connectionEstablished = acceptCompleted.load() &&
|
|
connectCompleted.load() && sendGuard && recvGuard;
|
|
});
|
|
|
|
connectAcceptThread.join();
|
|
|
|
if (connectionEstablished) {
|
|
// Transfer ownership to output parameters
|
|
sendComm = sendGuard.release();
|
|
recvComm = recvGuard.release();
|
|
|
|
INFO(NCCL_LOG_INFO, "Successfully established connection pair");
|
|
return true;
|
|
} else {
|
|
INFO(NCCL_LOG_INFO,
|
|
"Failed to establish connection pair - accept: %s, connect: %s",
|
|
acceptCompleted.load() ? "success" : "failed",
|
|
connectCompleted.load() ? "success" : "failed");
|
|
// RAII guards will automatically clean up any partial connections
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Common function to setup memory and operations for a test size
|
|
bool SetupOperationsForSize(void *sendComm, void *recvComm, size_t testSize,
|
|
std::vector<std::vector<char>> &sendBuffers,
|
|
std::vector<std::vector<char>> &recvBuffers,
|
|
std::vector<void *> &sendMhandles,
|
|
std::vector<void *> &recvMhandles,
|
|
std::vector<void *> &sendRequests,
|
|
std::vector<void *> &recvRequests,
|
|
uint8_t fillPattern = 0xCD) {
|
|
|
|
// Create buffers
|
|
sendBuffers.emplace_back(testSize, fillPattern);
|
|
recvBuffers.emplace_back(testSize, 0x00);
|
|
|
|
void *sendMhandle = nullptr;
|
|
void *recvMhandle = nullptr;
|
|
|
|
// Register memory
|
|
ncclResult_t sendRegResult =
|
|
ncclNetSocket.regMr(sendComm, sendBuffers.back().data(), testSize,
|
|
NCCL_PTR_HOST, &sendMhandle);
|
|
ncclResult_t recvRegResult =
|
|
ncclNetSocket.regMr(recvComm, recvBuffers.back().data(), testSize,
|
|
NCCL_PTR_HOST, &recvMhandle);
|
|
|
|
// Always add handles to vectors (even if nullptr)
|
|
// to maintain consistency with buffer vectors for proper cleanup
|
|
sendMhandles.push_back(sendMhandle);
|
|
recvMhandles.push_back(recvMhandle);
|
|
|
|
if (sendRegResult == ncclSuccess && recvRegResult == ncclSuccess) {
|
|
INFO(NCCL_LOG_INFO, "Memory registration successful for size %zu",
|
|
testSize);
|
|
|
|
// Start send operation
|
|
void *sendRequest = nullptr;
|
|
ncclResult_t sendResult =
|
|
ncclNetSocket.isend(sendComm, sendBuffers.back().data(), testSize, 0,
|
|
sendMhandle, nullptr, &sendRequest);
|
|
|
|
// Start receive operation
|
|
void *recvRequest = nullptr;
|
|
void *recvDataPtr = recvBuffers.back().data();
|
|
size_t recvSize = testSize;
|
|
int tag = 0;
|
|
ncclResult_t recvResult =
|
|
ncclNetSocket.irecv(recvComm, 1, &recvDataPtr, &recvSize, &tag,
|
|
&recvMhandle, nullptr, &recvRequest);
|
|
|
|
if (sendResult == ncclSuccess && recvResult == ncclSuccess &&
|
|
sendRequest && recvRequest) {
|
|
sendRequests.push_back(sendRequest);
|
|
recvRequests.push_back(recvRequest);
|
|
INFO(NCCL_LOG_INFO, "Successfully started operations for size %zu",
|
|
testSize);
|
|
return true;
|
|
} else {
|
|
INFO(NCCL_LOG_INFO,
|
|
"Failed to start operations - send result: %d, recv result: %d",
|
|
sendResult, recvResult);
|
|
sendRequests.push_back(nullptr);
|
|
recvRequests.push_back(nullptr);
|
|
// NOTE: Memory handles are already in vectors and will be cleaned up by DeregisterMemory
|
|
return false;
|
|
}
|
|
} else {
|
|
INFO(NCCL_LOG_INFO,
|
|
"Failed to register memory - send result: %d, recv result: %d",
|
|
sendRegResult, recvRegResult);
|
|
// NOTE: Even if only one registration succeeded, the handle is in the vector
|
|
// and will be properly cleaned up by DeregisterMemory (it handles nullptr gracefully)
|
|
sendRequests.push_back(nullptr);
|
|
recvRequests.push_back(nullptr);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Common function to progress operations and test ncclNetSocketGetTask
|
|
bool ProgressOperations(void *sendRequest, void *recvRequest, size_t testSize,
|
|
const std::string &testContext = "") {
|
|
const int maxTestIterations = 10;
|
|
bool taskCreationExercised = false;
|
|
|
|
INFO(NCCL_LOG_INFO,
|
|
"Starting progress testing - this exercises ncclNetSocketGetTask%s",
|
|
testContext.c_str());
|
|
|
|
for (int testIter = 0; testIter < maxTestIterations; testIter++) {
|
|
INFO(NCCL_LOG_INFO, " Progress test iteration %d/%d", testIter + 1,
|
|
maxTestIterations);
|
|
|
|
if (sendRequest && recvRequest) {
|
|
int sendDone = 0, recvDone = 0;
|
|
int sendSize = 0, recvSize_out = 0;
|
|
|
|
ncclResult_t sendTestResult =
|
|
ncclNetSocket.test(sendRequest, &sendDone, &sendSize);
|
|
ncclResult_t recvTestResult =
|
|
ncclNetSocket.test(recvRequest, &recvDone, &recvSize_out);
|
|
|
|
INFO(NCCL_LOG_INFO, " Send test: result=%d, done=%d", sendTestResult,
|
|
sendDone);
|
|
INFO(NCCL_LOG_INFO, " Recv test: result=%d, done=%d", recvTestResult,
|
|
recvDone);
|
|
|
|
// If we reach this point with successful or in-progress results,
|
|
// ncclNetSocketGetTask was exercised
|
|
if ((sendTestResult == ncclSuccess ||
|
|
sendTestResult == ncclInProgress) &&
|
|
(recvTestResult == ncclSuccess ||
|
|
recvTestResult == ncclInProgress)) {
|
|
taskCreationExercised = true;
|
|
INFO(NCCL_LOG_INFO,
|
|
" *** SUCCESS: ncclNetSocketGetTask was exercised! ***");
|
|
INFO(NCCL_LOG_INFO,
|
|
" Task exercised with sendTestResult=%d (%s), recvTestResult=%d (%s)",
|
|
sendTestResult,
|
|
(sendTestResult == ncclSuccess) ? "ncclSuccess" : "ncclInProgress",
|
|
recvTestResult,
|
|
(recvTestResult == ncclSuccess) ? "ncclSuccess" : "ncclInProgress");
|
|
}
|
|
|
|
// Count completed operations
|
|
if (sendDone && recvDone) {
|
|
INFO(NCCL_LOG_INFO, " Operations completed successfully!");
|
|
break;
|
|
}
|
|
|
|
// If operations fail, that's okay - we still exercised the code path
|
|
if (sendTestResult != ncclSuccess && sendTestResult != ncclInProgress) {
|
|
INFO(NCCL_LOG_INFO, " Send operation failed, but "
|
|
"ncclNetSocketGetTask was still exercised");
|
|
break;
|
|
}
|
|
if (recvTestResult != ncclSuccess && recvTestResult != ncclInProgress) {
|
|
INFO(NCCL_LOG_INFO, " Recv operation failed, but "
|
|
"ncclNetSocketGetTask was still exercised");
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Give time between tests
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
|
}
|
|
|
|
if (taskCreationExercised) {
|
|
INFO(NCCL_LOG_INFO,
|
|
"*** VERIFICATION: ncclNetSocketGetTask was successfully exercised "
|
|
"for buffer size %zu ***",
|
|
testSize);
|
|
}
|
|
|
|
return taskCreationExercised;
|
|
}
|
|
|
|
// Common function to deregister memory and test ncclNetSocketDeregMr
|
|
void DeregisterMemory(void *sendComm, void *recvComm,
|
|
const std::vector<void *> &sendMhandles,
|
|
const std::vector<void *> &recvMhandles,
|
|
size_t testSize) {
|
|
INFO(NCCL_LOG_INFO,
|
|
"\n=== Testing ncclNetSocketDeregMr for size %zu ===", testSize);
|
|
|
|
// Deregister send memory handles
|
|
for (size_t j = 0; j < sendMhandles.size(); j++) {
|
|
if (sendComm) {
|
|
INFO(NCCL_LOG_INFO, "Deregistering send memory handle %zu for size %zu",
|
|
j, testSize);
|
|
ncclResult_t deregResult =
|
|
ncclNetSocket.deregMr(sendComm, sendMhandles[j]);
|
|
INFO(NCCL_LOG_INFO, "Send memory deregMr result: %d", deregResult);
|
|
EXPECT_EQ(deregResult, ncclSuccess) << "Failed to deregister send memory handle " << j
|
|
<< " for buffer size " << testSize << ". "
|
|
<< "ncclNetSocket.deregMr() returned error code: " << deregResult
|
|
<< ". This may indicate memory registration/deregistration mismatch.";
|
|
}
|
|
}
|
|
|
|
// Deregister receive memory handles
|
|
for (size_t j = 0; j < recvMhandles.size(); j++) {
|
|
if (recvComm) {
|
|
INFO(NCCL_LOG_INFO, "Deregistering recv memory handle %zu for size %zu",
|
|
j, testSize);
|
|
ncclResult_t deregResult =
|
|
ncclNetSocket.deregMr(recvComm, recvMhandles[j]);
|
|
INFO(NCCL_LOG_INFO, "Recv memory deregMr result: %d", deregResult);
|
|
EXPECT_EQ(deregResult, ncclSuccess) << "Failed to deregister send memory handle " << j
|
|
<< " for buffer size " << testSize << ". "
|
|
<< "ncclNetSocket.deregMr() returned error code: " << deregResult
|
|
<< ". This may indicate memory registration/deregistration mismatch.";
|
|
}
|
|
}
|
|
}
|
|
|
|
// Common function to cleanup communicators
|
|
void CleanupCommunicators(const std::vector<void *> &sendComms,
|
|
const std::vector<void *> &recvComms,
|
|
void *listenComm) {
|
|
INFO(NCCL_LOG_INFO, "\nCleaning up communicators...");
|
|
|
|
for (size_t i = 0; i < sendComms.size(); i++) {
|
|
if (sendComms[i]) {
|
|
INFO(NCCL_LOG_INFO, "Closing send communicator %zu", i);
|
|
ncclResult_t closeResult = ncclNetSocket.closeSend(sendComms[i]);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator " << i
|
|
<< ". ncclNetSocket.closeSend() returned error code: " << closeResult
|
|
<< ". This may indicate communicator state corruption or resource cleanup issues.";
|
|
}
|
|
}
|
|
|
|
for (size_t i = 0; i < recvComms.size(); i++) {
|
|
if (recvComms[i]) {
|
|
INFO(NCCL_LOG_INFO, "Closing recv communicator %zu", i);
|
|
ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComms[i]);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator " << i
|
|
<< ". ncclNetSocket.closeRecv() returned error code: " << closeResult
|
|
<< ". This may indicate communicator state corruption or resource cleanup issues.";
|
|
}
|
|
}
|
|
|
|
if (listenComm) {
|
|
INFO(NCCL_LOG_INFO, "Closing listen communicator");
|
|
ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator. "
|
|
<< "ncclNetSocket.closeListen() returned error code: " << closeResult
|
|
<< ". This may indicate listen socket state corruption or resource cleanup issues.";
|
|
listenComm = nullptr;
|
|
}
|
|
}
|
|
|
|
// Common function to get test buffer sizes
|
|
std::vector<size_t> GetTestSizes() {
|
|
return {
|
|
1024, // Small - basic test
|
|
64 * 1024, // MIN_CHUNKSIZE - boundary case
|
|
128 * 1024, // 2x MIN_CHUNKSIZE - will exercise subdivision
|
|
256 * 1024, // 4x MIN_CHUNKSIZE - multiple chunks
|
|
512 * 1024, // 8x MIN_CHUNKSIZE - many chunks
|
|
1024 * 1024, // Large - stress test
|
|
2 * 1024 * 1024 // Very large - comprehensive test
|
|
};
|
|
}
|
|
|
|
// Helper function to safely parse environment variables
|
|
int ParseEnvVar(const char* envVar, const char* envName, int defaultValue = 0, int minValue = 0) {
|
|
if (!envVar || strlen(envVar) == 0) {
|
|
return defaultValue;
|
|
}
|
|
|
|
char* endPtr = nullptr;
|
|
errno = 0;
|
|
long result = std::strtol(envVar, &endPtr, 10);
|
|
|
|
// Check for various error conditions - using ADD_FAILURE instead of GTEST_FAIL
|
|
if (errno == ERANGE) {
|
|
ADD_FAILURE() << "Environment variable " << envName << "='" << envVar
|
|
<< "' is out of range for integer conversion. "
|
|
<< "Please provide a valid integer value.";
|
|
return defaultValue;
|
|
}
|
|
|
|
if (endPtr == envVar) {
|
|
ADD_FAILURE() << "Environment variable " << envName << "='" << envVar
|
|
<< "' is not a valid number. "
|
|
<< "Please provide a valid integer value (e.g., " << envName << "=8).";
|
|
return defaultValue;
|
|
}
|
|
|
|
if (*endPtr != '\0') {
|
|
ADD_FAILURE() << "Environment variable " << envName << "='" << envVar
|
|
<< "' contains non-numeric characters. "
|
|
<< "Please provide a valid integer value (e.g., " << envName << "=8).";
|
|
return defaultValue;
|
|
}
|
|
|
|
if (result < minValue) {
|
|
ADD_FAILURE() << "Environment variable " << envName << "='" << envVar
|
|
<< "' must be >= " << minValue << ". "
|
|
<< "Current value: " << result << ". Please provide a valid positive integer.";
|
|
return defaultValue;
|
|
}
|
|
|
|
if (result > INT_MAX) {
|
|
ADD_FAILURE() << "Environment variable " << envName << "='" << envVar
|
|
<< "' is too large (> " << INT_MAX << "). "
|
|
<< "Please provide a smaller integer value.";
|
|
return defaultValue;
|
|
}
|
|
|
|
return static_cast<int>(result);
|
|
}
|
|
|
|
void RunConcurrentOperationsTaskCreationWithEnvVars() {
|
|
INFO(NCCL_LOG_INFO, "Checking socket configuration environment variables");
|
|
|
|
// Check if the required environment variables are set
|
|
const char *nThreadsEnv = getenv("NCCL_SOCKET_NTHREADS");
|
|
const char *nSocksPerThreadEnv = getenv("NCCL_NSOCKS_PERTHREAD");
|
|
|
|
if (!nThreadsEnv || !nSocksPerThreadEnv) {
|
|
GTEST_SKIP() << "SKIPPING TEST: Required environment variables not set. "
|
|
<< "Please set the following environment variables to run this test: "
|
|
<< "export NCCL_SOCKET_NTHREADS=1 and export NCCL_NSOCKS_PERTHREAD=2. "
|
|
<< "This ensures nSocks > 0 so that ncclNetSocketGetTask gets called. "
|
|
<< "Environment variables NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD must be set";
|
|
return;
|
|
}
|
|
|
|
int nThreads = ParseEnvVar(nThreadsEnv, "NCCL_SOCKET_NTHREADS", 0, 1);
|
|
int nSocksPerThread = ParseEnvVar(nSocksPerThreadEnv, "NCCL_NSOCKS_PERTHREAD", 0, 1);
|
|
|
|
// Additional validation for reasonable upper bounds
|
|
const int MAX_THREADS = 16;
|
|
const int MAX_SOCKS_PER_THREAD = 64;
|
|
const int MAX_TOTAL_SOCKETS = 64;
|
|
|
|
if (nThreads > MAX_THREADS) {
|
|
GTEST_SKIP() << "SKIPPING TEST: NCCL_SOCKET_NTHREADS=" << nThreads << " exceeds maximum " << MAX_THREADS << ". "
|
|
<< "Please provide a reasonable value (e.g., NCCL_SOCKET_NTHREADS=8). "
|
|
<< "Values too large may cause resource exhaustion.";
|
|
return;
|
|
}
|
|
|
|
if (nSocksPerThread > MAX_SOCKS_PER_THREAD) {
|
|
GTEST_SKIP() << "SKIPPING TEST: NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread << " exceeds maximum " << MAX_SOCKS_PER_THREAD << ". "
|
|
<< "Please provide a reasonable value (e.g., NCCL_NSOCKS_PERTHREAD=4). "
|
|
<< "Values too large may cause resource exhaustion.";
|
|
return;
|
|
}
|
|
|
|
// Check for potential overflow before multiplication
|
|
if (nThreads > 0 && nSocksPerThread > INT_MAX / nThreads) {
|
|
GTEST_SKIP() << "SKIPPING TEST: Configuration would cause integer overflow. "
|
|
<< "NCCL_SOCKET_NTHREADS=" << nThreads << " * NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread
|
|
<< " exceeds maximum integer value. Please use smaller values.";
|
|
return;
|
|
}
|
|
|
|
int totalSockets = nThreads * nSocksPerThread;
|
|
|
|
INFO(NCCL_LOG_INFO, "Environment configuration found:");
|
|
INFO(NCCL_LOG_INFO, " NCCL_SOCKET_NTHREADS=%d", nThreads);
|
|
INFO(NCCL_LOG_INFO, " NCCL_NSOCKS_PERTHREAD=%d", nSocksPerThread);
|
|
INFO(NCCL_LOG_INFO, " Total sockets=%d", totalSockets);
|
|
|
|
// Validate total sockets count
|
|
if (totalSockets <= 0) {
|
|
GTEST_SKIP() << "SKIPPING TEST: Invalid configuration - total sockets must be > 0. "
|
|
<< "Current configuration: nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread
|
|
<< " = " << totalSockets << ". "
|
|
<< "Both NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD must be positive integers. "
|
|
<< "Example: export NCCL_SOCKET_NTHREADS=2 && export NCCL_NSOCKS_PERTHREAD=2";
|
|
return;
|
|
}
|
|
|
|
if (totalSockets > MAX_TOTAL_SOCKETS) {
|
|
GTEST_SKIP() << "SKIPPING TEST: Total sockets " << totalSockets << " exceeds maximum " << MAX_TOTAL_SOCKETS << ". "
|
|
<< "Current configuration: nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread
|
|
<< " = " << totalSockets << ". "
|
|
<< "Please reduce either NCCL_SOCKET_NTHREADS or NCCL_NSOCKS_PERTHREAD. "
|
|
<< "Example: export NCCL_SOCKET_NTHREADS=8 && export NCCL_NSOCKS_PERTHREAD=4";
|
|
return;
|
|
}
|
|
|
|
if (totalSockets > NCCL_NET_MAX_REQUESTS) {
|
|
GTEST_SKIP() << "SKIPPING TEST: Total sockets " << totalSockets << " exceeds NCCL_NET_MAX_REQUESTS=" << NCCL_NET_MAX_REQUESTS << ". "
|
|
<< "Current configuration: nThreads=" << nThreads << " * nSocksPerThread=" << nSocksPerThread
|
|
<< " = " << totalSockets << ". "
|
|
<< "NCCL network layer can handle at most " << NCCL_NET_MAX_REQUESTS << " concurrent requests. "
|
|
<< "Please reduce configuration to stay within NCCL limits.";
|
|
return;
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO, "Configuration valid - proceeding with test to exercise "
|
|
"ncclNetSocketGetTask");
|
|
|
|
// Test socket properties
|
|
TestSocketProperties();
|
|
|
|
char handle[NCCL_NET_HANDLE_MAXSIZE];
|
|
void *listenComm = nullptr;
|
|
|
|
ncclResult_t result = ncclNetSocket.listen(nullptr, 0, handle, &listenComm);
|
|
ASSERT_EQ(result, ncclSuccess) << "Failed to establish listening socket for test execution. "
|
|
<< "ncclNetSocket.listen() returned error code: " << result
|
|
<< ". Verify network device availability and port accessibility.";
|
|
|
|
INFO(NCCL_LOG_INFO, "Testing task creation functionality - ensuring "
|
|
"ncclNetSocketGetTask is called");
|
|
|
|
std::vector<void *> sendComms;
|
|
std::vector<void *> recvComms;
|
|
|
|
// Establish connection
|
|
void *sendComm = nullptr;
|
|
void *recvComm = nullptr;
|
|
bool connectionSuccess =
|
|
EstablishConnectionPair(handle, listenComm, sendComm, recvComm);
|
|
|
|
if (connectionSuccess) {
|
|
sendComms.push_back(sendComm);
|
|
recvComms.push_back(recvComm);
|
|
|
|
// Test with buffer sizes that will trigger task subdivision
|
|
std::vector<size_t> testSizes = GetTestSizes();
|
|
|
|
for (size_t testSize : testSizes) {
|
|
INFO(NCCL_LOG_INFO,
|
|
"\n=== Testing with buffer size: %zu bytes ===", testSize);
|
|
INFO(NCCL_LOG_INFO, "This should trigger ncclNetSocketGetTask to create "
|
|
"task subdivision");
|
|
|
|
std::vector<void *> sendMhandles;
|
|
std::vector<void *> recvMhandles;
|
|
std::vector<void *> sendRequests;
|
|
std::vector<void *> recvRequests;
|
|
std::vector<std::vector<char>> sendBuffers;
|
|
std::vector<std::vector<char>> recvBuffers;
|
|
|
|
// Setup operations for this test size
|
|
bool setupSuccess = SetupOperationsForSize(
|
|
sendComm, recvComm, testSize, sendBuffers, recvBuffers, sendMhandles,
|
|
recvMhandles, sendRequests, recvRequests, 0xAB);
|
|
|
|
if (setupSuccess) {
|
|
// Progress operations with context about environment variables
|
|
ProgressOperations(sendRequests[0], recvRequests[0], testSize,
|
|
" (with nSocks > 0 from environment variables)");
|
|
} else {
|
|
INFO(NCCL_LOG_INFO,
|
|
"No operations started - skipping progress testing for size %zu",
|
|
testSize);
|
|
}
|
|
|
|
// Deregister memory
|
|
DeregisterMemory(sendComm, recvComm, sendMhandles, recvMhandles,
|
|
testSize);
|
|
|
|
INFO(NCCL_LOG_INFO,
|
|
"=== Completed testing for buffer size: %zu bytes ===", testSize);
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO, "\n*** TEST SUCCESS: ncclNetSocketGetTask was "
|
|
"successfully exercised! ***");
|
|
} else {
|
|
INFO(NCCL_LOG_INFO, "No connections established - test passed (network may "
|
|
"not be available)");
|
|
}
|
|
|
|
// Cleanup
|
|
CleanupCommunicators(sendComms, recvComms, listenComm);
|
|
INFO(NCCL_LOG_INFO,
|
|
"TestConcurrentOperationsTaskCreation completed successfully");
|
|
}
|
|
|
|
};
|
|
|
|
// Test concurrent operations task creation in default configuration (without
|
|
// env vars)
|
|
TEST_F(NetSocketTests, TestConcurrentOperationsTaskCreationDefault) {
  // Same flow as the env-var variant, but relies on whatever socket thread
  // configuration the process already has.
  INFO(NCCL_LOG_INFO,
       "Testing task creation functionality in default configuration");
  INFO(NCCL_LOG_INFO,
       "This test exercises ncclNetSocketGetTask regardless of nSocks value");

  TestSocketProperties();

  // Listen on device 0 with no plugin context; listen() writes the handle
  // that the connect side uses to reach this endpoint.
  char listenHandle[NCCL_NET_HANDLE_MAXSIZE];
  void *listenComm = nullptr;
  const ncclResult_t listenStatus =
      ncclNetSocket.listen(/*ctx=*/nullptr, /*dev=*/0, listenHandle, &listenComm);
  ASSERT_EQ(listenStatus, ncclSuccess) << "Failed to establish listening socket for test execution. "
                                       << "ncclNetSocket.listen() returned error code: " << listenStatus
                                       << ". Verify network device availability and port accessibility.";

  INFO(NCCL_LOG_INFO, "Testing task creation functionality in default mode");

  void *sendComm = nullptr;
  void *recvComm = nullptr;
  const bool connected =
      EstablishConnectionPair(listenHandle, listenComm, sendComm, recvComm);

  // CleanupCommunicators() takes vectors; they stay empty without a pair.
  std::vector<void *> sendComms;
  std::vector<void *> recvComms;

  if (!connected) {
    INFO(NCCL_LOG_INFO, "No connections established - test passed (network may "
                        "not be available)");
  } else {
    sendComms.push_back(sendComm);
    recvComms.push_back(recvComm);

    const std::vector<size_t> bufferSizes = GetTestSizes();
    for (const size_t bytes : bufferSizes) {
      INFO(NCCL_LOG_INFO,
           "\n=== Testing with buffer size: %zu bytes ===", bytes);
      INFO(NCCL_LOG_INFO,
           "This exercises ncclNetSocketGetTask task creation logic");

      // Per-size state; SetupOperationsForSize() appends one entry to each.
      std::vector<void *> sendMhandles;
      std::vector<void *> recvMhandles;
      std::vector<void *> sendRequests;
      std::vector<void *> recvRequests;
      std::vector<std::vector<char>> sendBuffers;
      std::vector<std::vector<char>> recvBuffers;

      const bool started = SetupOperationsForSize(
          sendComm, recvComm, bytes, sendBuffers, recvBuffers, sendMhandles,
          recvMhandles, sendRequests, recvRequests, 0xCD);

      if (!started) {
        INFO(NCCL_LOG_INFO,
             "No operations started - skipping progress testing for size %zu",
             bytes);
      } else if (!ProgressOperations(sendRequests.front(), recvRequests.front(),
                                     bytes)) {
        INFO(NCCL_LOG_INFO,
             "*** NOTE: Operations didn't progress as expected for size %zu, "
             "but API was still exercised ***",
             bytes);
      }

      DeregisterMemory(sendComm, recvComm, sendMhandles, recvMhandles, bytes);

      INFO(NCCL_LOG_INFO,
           "=== Completed testing for buffer size: %zu bytes ===", bytes);
    }

    INFO(NCCL_LOG_INFO, "\n*** TEST SUCCESS: ncclNetSocketGetTask was "
                        "successfully exercised in default configuration! ***");
  }

  CleanupCommunicators(sendComms, recvComms, listenComm);
  INFO(NCCL_LOG_INFO,
       "TestConcurrentOperationsTaskCreationDefault completed successfully");
}
|
|
|
|
// Test multiple concurrent operations to stress test task creation
|
|
TEST_F(NetSocketTests, TestConcurrentOperationsTaskCreation) {
|
|
ProcessIsolatedTestRunner::ExecutionOptions options;
|
|
options.stopOnFirstFailure = false; // Continue running all tests
|
|
options.verboseLogging = true;
|
|
|
|
RUN_ISOLATED_TESTS_WITH_OPTIONS(options,
|
|
ProcessIsolatedTestRunner::TestConfig(
|
|
"TestConcurrentOperationsTaskCreation",
|
|
[this]() { RunConcurrentOperationsTaskCreationWithEnvVars(); })
|
|
.withEnvironment({{"NCCL_SOCKET_NTHREADS", "1"},
|
|
{"NCCL_NSOCKS_PERTHREAD", "2"},
|
|
{"NCCL_DEBUG", "TRACE"},
|
|
{"NCCL_DEBUG_SUBSYS", "ALL"}})
|
|
);
|
|
}
|
|
|
|
// Test for invalid device index in listen function
|
|
TEST_F(NetSocketTests, TestInvalidDeviceIndexListen) {
|
|
INFO(NCCL_LOG_INFO, "Testing invalid device index in ncclNetSocketListen");
|
|
|
|
char handle[NCCL_NET_HANDLE_MAXSIZE];
|
|
void *listenComm = nullptr;
|
|
|
|
// Test with negative device index
|
|
void* ctx = nullptr;
|
|
int dev = -1;
|
|
ncclResult_t result = ncclNetSocket.listen(ctx, dev, handle, &listenComm);
|
|
INFO(NCCL_LOG_INFO, "Listen with dev=-1 returned: %d", result);
|
|
EXPECT_EQ(result, ncclInternalError)
|
|
<< "Listen should fail with negative device index. "
|
|
<< "ncclNetSocket.listen() with device index -1 should return ncclInternalError "
|
|
<< "but returned: " << result << ". Verify input validation for device indices.";
|
|
|
|
// Test with device index greater than available devices
|
|
int invalidDev = ndev + 10;
|
|
result = ncclNetSocket.listen(ctx, invalidDev, handle, &listenComm);
|
|
INFO(NCCL_LOG_INFO, "Listen with dev=%d (> ndev=%d) returned: %d", invalidDev,
|
|
ndev, result);
|
|
EXPECT_EQ(result, ncclInternalError)
|
|
<< "Listen should fail with device index >= ndev. "
|
|
<< "ncclNetSocket.listen() with device index " << invalidDev << " (> ndev=" << ndev
|
|
<< ") should return ncclInternalError but returned: " << result
|
|
<< ". Verify bounds checking for device indices.";
|
|
|
|
INFO(NCCL_LOG_INFO, "TestInvalidDeviceIndexListen completed");
|
|
}
|
|
|
|
// Test for invalid device index in connect function
|
|
TEST_F(NetSocketTests, TestInvalidDeviceIndexConnect) {
|
|
INFO(NCCL_LOG_INFO, "Testing invalid device index in ncclNetSocketConnect");
|
|
|
|
char handle[NCCL_NET_HANDLE_MAXSIZE];
|
|
void *sendComm = nullptr;
|
|
ncclNetCommConfig_t config = {};
|
|
ncclNetDeviceHandle_t *sendDevComm = nullptr;
|
|
|
|
// Test with negative device index
|
|
void* ctx = nullptr;
|
|
int dev = -1;
|
|
ncclResult_t result = ncclNetSocket.connect(ctx, dev, handle, &sendComm, &sendDevComm);
|
|
INFO(NCCL_LOG_INFO, "Connect with dev=-1 returned: %d", result);
|
|
EXPECT_EQ(result, ncclInternalError)
|
|
<< "Connect should fail with negative device index. "
|
|
<< "ncclNetSocket.connect() with device index -1 should return ncclInternalError "
|
|
<< "but returned: " << result << ". Verify input validation for device indices.";
|
|
|
|
// Test with device index greater than available devices
|
|
int invalidDev = ndev + 10;
|
|
result = ncclNetSocket.connect(ctx, invalidDev, handle, &sendComm, &sendDevComm);
|
|
INFO(NCCL_LOG_INFO, "Connect with dev=%d (> ndev=%d) returned: %d",
|
|
invalidDev, ndev, result);
|
|
EXPECT_EQ(result, ncclInternalError)
|
|
<< "Connect should fail with device index >= ndev. "
|
|
<< "ncclNetSocket.connect() with device index " << invalidDev << " (> ndev=" << ndev
|
|
<< ") should return ncclInternalError but returned: " << result
|
|
<< ". Verify bounds checking for device indices.";
|
|
|
|
INFO(NCCL_LOG_INFO, "TestInvalidDeviceIndexConnect completed");
|
|
}
|
|
|
|
// Test for NULL request in test function
|
|
TEST_F(NetSocketTests, TestNullRequestInTest) {
|
|
INFO(NCCL_LOG_INFO, "Testing NULL request in ncclNetSocketTest");
|
|
|
|
int done = 0;
|
|
int size = 0;
|
|
|
|
// Test with NULL request
|
|
ncclResult_t result = ncclNetSocket.test(nullptr, &done, &size);
|
|
INFO(NCCL_LOG_INFO, "Test with NULL request returned: %d", result);
|
|
EXPECT_EQ(result, ncclInternalError) << "Test should fail with NULL request. "
|
|
<< "ncclNetSocket.test() with nullptr request should return ncclInternalError "
|
|
<< "but returned: " << result << ". Verify NULL pointer validation.";
|
|
|
|
INFO(NCCL_LOG_INFO, "TestNullRequestInTest completed");
|
|
}
|
|
|
|
// Test for invalid array size in irecv function
|
|
TEST_F(NetSocketTests, TestInvalidArraySizeIrecv) {
|
|
INFO(NCCL_LOG_INFO, "Testing invalid array size in ncclNetSocketIrecv");
|
|
|
|
// Setup a dummy communicator first
|
|
char handle[NCCL_NET_HANDLE_MAXSIZE];
|
|
void *listenComm = nullptr;
|
|
ncclResult_t result = ncclNetSocket.listen(nullptr, 0, handle, &listenComm);
|
|
|
|
if (result == ncclSuccess && listenComm) {
|
|
void *sendComm = nullptr;
|
|
void *recvComm = nullptr;
|
|
bool connectionSuccess =
|
|
EstablishConnectionPair(handle, listenComm, sendComm, recvComm);
|
|
|
|
if (connectionSuccess && recvComm) {
|
|
// Test with n != 1 (invalid for socket implementation)
|
|
std::vector<char> buffer1(1024, 0xAA);
|
|
std::vector<char> buffer2(1024, 0xBB);
|
|
void *data[2] = {buffer1.data(), buffer2.data()};
|
|
size_t sizes[2] = {1024, 1024};
|
|
int tags[2] = {0, 1};
|
|
void *mhandles[2] = {nullptr, nullptr};
|
|
void *phandles[2] = {nullptr, nullptr};
|
|
void *request = nullptr;
|
|
|
|
// Test with n=2 (should fail for socket implementation)
|
|
result = ncclNetSocket.irecv(recvComm, 2, data, sizes, tags, mhandles,
|
|
phandles, &request);
|
|
INFO(NCCL_LOG_INFO, "Irecv with n=2 returned: %d", result);
|
|
EXPECT_EQ(result, ncclInternalError) << "Irecv should fail with n != 1. "
|
|
<< "ncclNetSocket.irecv() with n=2 should return ncclInternalError "
|
|
<< "but returned: " << result << ". Socket implementation only supports n=1.";
|
|
|
|
// Test with n=0 (should fail)
|
|
result = ncclNetSocket.irecv(recvComm, 0, data, sizes, tags, mhandles,
|
|
phandles, &request);
|
|
INFO(NCCL_LOG_INFO, "Irecv with n=0 returned: %d", result);
|
|
EXPECT_EQ(result, ncclInternalError) << "Irecv should fail with n != 1. "
|
|
<< "ncclNetSocket.irecv() with n=0 should return ncclInternalError "
|
|
<< "but returned: " << result << ". Socket implementation only supports n=1.";
|
|
|
|
// Cleanup communicators
|
|
if (sendComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator";
|
|
sendComm = nullptr;
|
|
}
|
|
|
|
if (recvComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator";
|
|
recvComm = nullptr;
|
|
}
|
|
}
|
|
|
|
if (listenComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator";
|
|
listenComm = nullptr;
|
|
}
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO, "TestInvalidArraySizeIrecv completed");
|
|
}
|
|
|
|
// Test for non-host memory type in regMr function
|
|
TEST_F(NetSocketTests, TestNonHostMemoryRegMr) {
|
|
INFO(NCCL_LOG_INFO, "Testing non-host memory type in ncclNetSocketRegMr");
|
|
|
|
// Setup a dummy communicator first
|
|
char handle[NCCL_NET_HANDLE_MAXSIZE];
|
|
void *listenComm = nullptr;
|
|
ncclResult_t result = ncclNetSocket.listen(nullptr, 0, handle, &listenComm);
|
|
|
|
if (result == ncclSuccess && listenComm) {
|
|
void *sendComm = nullptr;
|
|
void *recvComm = nullptr;
|
|
bool connectionSuccess =
|
|
EstablishConnectionPair(handle, listenComm, sendComm, recvComm);
|
|
|
|
if (connectionSuccess && sendComm) {
|
|
std::vector<char> buffer(1024, 0xAA);
|
|
void *mhandle = nullptr;
|
|
|
|
// Test with NCCL_PTR_CUDA (should fail for socket implementation)
|
|
result = ncclNetSocket.regMr(sendComm, buffer.data(), 1024, NCCL_PTR_CUDA,
|
|
&mhandle);
|
|
INFO(NCCL_LOG_INFO, "RegMr with NCCL_PTR_CUDA returned: %d", result);
|
|
EXPECT_EQ(result, ncclInternalError)
|
|
<< "RegMr should fail with non-host memory type. "
|
|
<< "ncclNetSocket.regMr() with NCCL_PTR_CUDA should return ncclInternalError "
|
|
<< "but returned: " << result << ". Socket implementation only supports NCCL_PTR_HOST.";
|
|
|
|
// Test with valid NCCL_PTR_HOST (should succeed)
|
|
result = ncclNetSocket.regMr(sendComm, buffer.data(), 1024, NCCL_PTR_HOST,
|
|
&mhandle);
|
|
INFO(NCCL_LOG_INFO, "RegMr with NCCL_PTR_HOST returned: %d", result);
|
|
EXPECT_EQ(result, ncclSuccess)
|
|
<< "RegMr should succeed with host memory type. "
|
|
<< "ncclNetSocket.regMr() with NCCL_PTR_HOST should return ncclSuccess "
|
|
<< "but returned: " << result << ". Verify host memory registration support.";
|
|
|
|
// Cleanup communicators
|
|
if (sendComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator";
|
|
sendComm = nullptr;
|
|
}
|
|
|
|
if (recvComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator";
|
|
recvComm = nullptr;
|
|
}
|
|
}
|
|
|
|
if (listenComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator";
|
|
listenComm = nullptr;
|
|
}
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO, "TestNonHostMemoryRegMr completed");
|
|
}
|
|
|
|
// Test for excessive thread configuration warning
|
|
TEST_F(NetSocketTests, TestExcessiveThreadConfig) {
|
|
ProcessIsolatedTestRunner::ExecutionOptions options;
|
|
options.stopOnFirstFailure = false; // Continue running all tests
|
|
options.verboseLogging = true;
|
|
|
|
RUN_ISOLATED_TESTS_WITH_OPTIONS(options,
|
|
ProcessIsolatedTestRunner::TestConfig(
|
|
"TestExcessiveThreadConfig",
|
|
[this]() {
|
|
INFO(NCCL_LOG_INFO,
|
|
"Testing excessive thread configuration warning");
|
|
|
|
// Check if the required environment variables are set
|
|
const char *nThreadsEnv = getenv("NCCL_SOCKET_NTHREADS");
|
|
const char *nSocksPerThreadEnv = getenv("NCCL_NSOCKS_PERTHREAD");
|
|
|
|
if (!nThreadsEnv || !nSocksPerThreadEnv) {
|
|
GTEST_SKIP()
|
|
<< "SKIPPING TEST: Required environment variables not set. "
|
|
<< "This test requires NCCL_SOCKET_NTHREADS > "
|
|
"NCCL_NET_MAX_REQUESTS ("
|
|
<< NCCL_NET_MAX_REQUESTS
|
|
<< ") and NCCL_NSOCKS_PERTHREAD = 1 to trigger warning. "
|
|
<< "Environment variables NCCL_SOCKET_NTHREADS and "
|
|
"NCCL_NSOCKS_PERTHREAD must be set";
|
|
return;
|
|
}
|
|
|
|
// Parse with validation - both must be positive
|
|
int nThreads =
|
|
ParseEnvVar(nThreadsEnv, "NCCL_SOCKET_NTHREADS", 0, 1);
|
|
int nSocksPerThread =
|
|
ParseEnvVar(nSocksPerThreadEnv, "NCCL_NSOCKS_PERTHREAD", 0, 1);
|
|
|
|
// Check for potential overflow before multiplication
|
|
if (nThreads > 0 && nSocksPerThread > INT_MAX / nThreads) {
|
|
GTEST_SKIP() << "SKIPPING TEST: Configuration would cause "
|
|
"integer overflow. "
|
|
<< "NCCL_SOCKET_NTHREADS=" << nThreads
|
|
<< " * NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread
|
|
<< " exceeds maximum integer value. Please use "
|
|
"smaller values.";
|
|
return;
|
|
}
|
|
|
|
int totalSockets = nThreads * nSocksPerThread;
|
|
|
|
INFO(NCCL_LOG_INFO, "Environment configuration found:");
|
|
INFO(NCCL_LOG_INFO, " NCCL_SOCKET_NTHREADS=%d", nThreads);
|
|
INFO(NCCL_LOG_INFO, " NCCL_NSOCKS_PERTHREAD=%d", nSocksPerThread);
|
|
INFO(NCCL_LOG_INFO, " Total sockets=%d", totalSockets);
|
|
|
|
// Check if configuration is set to trigger the excessive threads
|
|
// warning Use NCCL_NET_MAX_REQUESTS instead of arbitrary
|
|
// MAX_THREADS
|
|
if (nThreads <= NCCL_NET_MAX_REQUESTS) {
|
|
GTEST_SKIP()
|
|
<< "SKIPPING TEST: NCCL_SOCKET_NTHREADS must be > "
|
|
<< NCCL_NET_MAX_REQUESTS
|
|
<< " to test excessive thread warning. "
|
|
<< "Current NCCL_SOCKET_NTHREADS=" << nThreads << ". "
|
|
<< "Please set: export NCCL_SOCKET_NTHREADS="
|
|
<< (NCCL_NET_MAX_REQUESTS + 1) << ". "
|
|
<< "NCCL_SOCKET_NTHREADS must be > NCCL_NET_MAX_REQUESTS ("
|
|
<< NCCL_NET_MAX_REQUESTS << ") to trigger warning";
|
|
return;
|
|
}
|
|
|
|
if (totalSockets >
|
|
NCCL_NET_MAX_REQUESTS *
|
|
10) { // Allow 10x for testing excessive config
|
|
GTEST_SKIP() << "SKIPPING TEST: Total sockets=" << totalSockets
|
|
<< " is unreasonably large (> "
|
|
<< (NCCL_NET_MAX_REQUESTS * 10) << "). "
|
|
<< "Please use more reasonable values for testing. "
|
|
"NCCL_NET_MAX_REQUESTS="
|
|
<< NCCL_NET_MAX_REQUESTS << ". "
|
|
<< "Example: export NCCL_SOCKET_NTHREADS="
|
|
<< (NCCL_NET_MAX_REQUESTS + 1)
|
|
<< " && export NCCL_NSOCKS_PERTHREAD=1";
|
|
return;
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO,
|
|
"Configuration valid for testing excessive threads warning");
|
|
INFO(NCCL_LOG_INFO,
|
|
"NCCL_SOCKET_NTHREADS=%d > NCCL_NET_MAX_REQUESTS=%d", nThreads,
|
|
NCCL_NET_MAX_REQUESTS);
|
|
|
|
// Test socket properties
|
|
TestSocketProperties();
|
|
|
|
// Initialize to trigger the warning logic
|
|
char handle[NCCL_NET_HANDLE_MAXSIZE];
|
|
void *listenComm = nullptr;
|
|
ncclResult_t result = ncclNetSocket.listen(nullptr, 0, handle, &listenComm);
|
|
|
|
if (result == ncclSuccess && listenComm) {
|
|
// The implementation should have limited the threads to
|
|
// NCCL_NET_MAX_REQUESTS internally
|
|
INFO(NCCL_LOG_INFO, "*** SUCCESS: Listen succeeded with "
|
|
"excessive NCCL_SOCKET_NTHREADS - "
|
|
"limits enforced internally ***");
|
|
ncclNetSocket.closeListen(listenComm);
|
|
} else {
|
|
INFO(NCCL_LOG_INFO, "Listen failed with result: %d", result);
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO, "TestExcessiveThreadConfig completed");
|
|
})
|
|
.withEnvironment({{"NCCL_SOCKET_NTHREADS", "33"},
|
|
{"NCCL_NSOCKS_PERTHREAD", "1"},
|
|
{"NCCL_DEBUG", "TRACE"},
|
|
{"NCCL_DEBUG_SUBSYS", "ALL"}})
|
|
);
|
|
}
|
|
|
|
// Test for excessive socket configuration warning
|
|
TEST_F(NetSocketTests, TestExcessiveSocketConfig) {
|
|
ProcessIsolatedTestRunner::ExecutionOptions options;
|
|
options.stopOnFirstFailure = false; // Continue running all tests
|
|
options.verboseLogging = true;
|
|
|
|
RUN_ISOLATED_TESTS_WITH_OPTIONS(options,
|
|
ProcessIsolatedTestRunner::TestConfig(
|
|
"TestExcessiveThreadConfig",
|
|
[this]() {
|
|
INFO(NCCL_LOG_INFO,
|
|
"Testing excessive socket configuration warning");
|
|
|
|
// Check if the required environment variables are set
|
|
const char *nThreadsEnv = getenv("NCCL_SOCKET_NTHREADS");
|
|
const char *nSocksPerThreadEnv = getenv("NCCL_NSOCKS_PERTHREAD");
|
|
|
|
if (!nThreadsEnv || !nSocksPerThreadEnv) {
|
|
GTEST_SKIP()
|
|
<< "SKIPPING TEST: Required environment variables not set. "
|
|
<< "This test requires total sockets (nThreads * "
|
|
"nSocksPerThread) > MAX_SOCKETS (64). "
|
|
<< "Environment variables NCCL_SOCKET_NTHREADS and "
|
|
"NCCL_NSOCKS_PERTHREAD must be set";
|
|
return;
|
|
}
|
|
|
|
// Parse with validation - both must be positive
|
|
int nThreads =
|
|
ParseEnvVar(nThreadsEnv, "NCCL_SOCKET_NTHREADS", 0, 1);
|
|
int nSocksPerThread =
|
|
ParseEnvVar(nSocksPerThreadEnv, "NCCL_NSOCKS_PERTHREAD", 0, 1);
|
|
|
|
// Check for potential overflow before multiplication
|
|
if (nThreads > 0 && nSocksPerThread > INT_MAX / nThreads) {
|
|
GTEST_SKIP() << "SKIPPING TEST: Configuration would cause "
|
|
"integer overflow. "
|
|
<< "NCCL_SOCKET_NTHREADS=" << nThreads
|
|
<< " * NCCL_NSOCKS_PERTHREAD=" << nSocksPerThread
|
|
<< " exceeds maximum integer value. Please use "
|
|
"smaller values.";
|
|
return;
|
|
}
|
|
|
|
int totalSockets = nThreads * nSocksPerThread;
|
|
|
|
INFO(NCCL_LOG_INFO, "Environment configuration found:");
|
|
INFO(NCCL_LOG_INFO, " NCCL_SOCKET_NTHREADS=%d", nThreads);
|
|
INFO(NCCL_LOG_INFO, " NCCL_NSOCKS_PERTHREAD=%d", nSocksPerThread);
|
|
INFO(NCCL_LOG_INFO, " Total sockets=%d", totalSockets);
|
|
|
|
// Check if configuration is set to trigger the excessive sockets
|
|
// warning
|
|
const int MAX_SOCKETS = 64;
|
|
if (totalSockets <= MAX_SOCKETS) {
|
|
GTEST_SKIP()
|
|
<< "SKIPPING TEST: Total sockets must be > " << MAX_SOCKETS
|
|
<< " to test excessive socket warning. "
|
|
<< "Current total sockets=" << totalSockets
|
|
<< " (nThreads=" << nThreads
|
|
<< " * nSocksPerThread=" << nSocksPerThread << "). "
|
|
<< "Please set environment variables such that total > "
|
|
<< MAX_SOCKETS << ", e.g.: "
|
|
<< "export NCCL_SOCKET_NTHREADS=9 && export "
|
|
"NCCL_NSOCKS_PERTHREAD=8. "
|
|
<< "Total sockets must be > MAX_SOCKETS (" << MAX_SOCKETS
|
|
<< ") to trigger warning";
|
|
return;
|
|
}
|
|
|
|
// Additional validation against NCCL_NET_MAX_REQUESTS for
|
|
// reasonable upper bounds
|
|
if (totalSockets >
|
|
NCCL_NET_MAX_REQUESTS *
|
|
10) { // Allow 10x for testing excessive config
|
|
GTEST_SKIP() << "SKIPPING TEST: Total sockets=" << totalSockets
|
|
<< " is unreasonably large (> "
|
|
<< (NCCL_NET_MAX_REQUESTS * 10) << "). "
|
|
<< "Please use more reasonable values for testing. "
|
|
"NCCL_NET_MAX_REQUESTS="
|
|
<< NCCL_NET_MAX_REQUESTS << ". "
|
|
<< "Example: export NCCL_SOCKET_NTHREADS=10 && "
|
|
"export NCCL_NSOCKS_PERTHREAD=10";
|
|
return;
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO,
|
|
"Configuration valid for testing excessive sockets warning");
|
|
INFO(NCCL_LOG_INFO, "Total sockets=%d > MAX_SOCKETS=64",
|
|
totalSockets);
|
|
|
|
// Test socket properties
|
|
TestSocketProperties();
|
|
|
|
// Initialize to trigger the warning logic
|
|
char handle[NCCL_NET_HANDLE_MAXSIZE];
|
|
void *listenComm = nullptr;
|
|
ncclResult_t result = ncclNetSocket.listen(nullptr, 0, handle, &listenComm);
|
|
|
|
if (result == ncclSuccess && listenComm) {
|
|
// The implementation should have limited the sockets to
|
|
// MAX_SOCKETS internally
|
|
INFO(NCCL_LOG_INFO,
|
|
"*** SUCCESS: Listen succeeded with excessive total "
|
|
"sockets - limits enforced internally ***");
|
|
ncclNetSocket.closeListen(listenComm);
|
|
} else {
|
|
INFO(NCCL_LOG_INFO, "Listen failed with result: %d", result);
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO, "TestExcessiveSocketConfig completed");
|
|
})
|
|
.withEnvironment({{"NCCL_SOCKET_NTHREADS", "10"},
|
|
{"NCCL_NSOCKS_PERTHREAD", "10"},
|
|
{"NCCL_DEBUG", "TRACE"},
|
|
{"NCCL_DEBUG_SUBSYS", "ALL"}})
|
|
);
|
|
}
|
|
|
|
// Test to trigger request allocation failure scenario
|
|
TEST_F(NetSocketTests, TestRequestAllocationFailure) {
|
|
INFO(NCCL_LOG_INFO, "Testing request allocation failure scenario");
|
|
|
|
// Setup communication
|
|
char handle[NCCL_NET_HANDLE_MAXSIZE];
|
|
void *listenComm = nullptr;
|
|
ncclResult_t result = ncclNetSocket.listen(nullptr, 0, handle, &listenComm);
|
|
|
|
if (result == ncclSuccess && listenComm) {
|
|
void *sendComm = nullptr;
|
|
void *recvComm = nullptr;
|
|
bool connectionSuccess =
|
|
EstablishConnectionPair(handle, listenComm, sendComm, recvComm);
|
|
|
|
if (connectionSuccess && sendComm && recvComm) {
|
|
INFO(NCCL_LOG_INFO, "Attempting to exhaust request pool (MAX_REQUESTS)");
|
|
|
|
std::vector<void *> requests;
|
|
std::vector<std::vector<char>> buffers;
|
|
std::vector<void *> mhandles;
|
|
|
|
// Try to allocate many requests to potentially exhaust the pool
|
|
// MAX_REQUESTS is defined as NCCL_NET_MAX_REQUESTS in the code
|
|
for (int i = 0; i < (NCCL_NET_MAX_REQUESTS * 10); i++) { // Try to exceed NCCL_NET_MAX_REQUESTS by a reasonable margin
|
|
buffers.emplace_back(1024, 0xAA + (i % 256));
|
|
void *mhandle = nullptr;
|
|
|
|
// Register memory first
|
|
result = ncclNetSocket.regMr(sendComm, buffers.back().data(), 1024,
|
|
NCCL_PTR_HOST, &mhandle);
|
|
EXPECT_EQ(result, ncclSuccess) << "Memory registration failed at iteration " << i
|
|
<< ". ncclNetSocket.regMr() returned error code: " << result
|
|
<< ". Verify memory registration limits and resource availability.";
|
|
if (result != ncclSuccess)
|
|
break;
|
|
mhandles.push_back(mhandle);
|
|
|
|
// Try to create send request
|
|
void *request = nullptr;
|
|
result = ncclNetSocket.isend(sendComm, buffers.back().data(), 1024, 0,
|
|
mhandle, nullptr, &request);
|
|
|
|
if (result == ncclInternalError) {
|
|
INFO(NCCL_LOG_INFO,
|
|
"Request allocation failed at iteration %d (expected behavior "
|
|
"when pool exhausted)",
|
|
i);
|
|
break;
|
|
} else if (result == ncclSuccess) {
|
|
requests.push_back(request);
|
|
} else {
|
|
INFO(NCCL_LOG_INFO, "Unexpected result at iteration %d: %d", i,
|
|
result);
|
|
break;
|
|
}
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO,
|
|
"Successfully allocated %zu requests before failure/completion",
|
|
requests.size());
|
|
|
|
// Cleanup: Test any pending requests and deregister memory
|
|
for (size_t i = 0; i < requests.size(); i++) {
|
|
if (requests[i]) {
|
|
int done = 0;
|
|
int size = 0;
|
|
ncclNetSocket.test(requests[i], &done,
|
|
&size); // Don't care about result
|
|
}
|
|
}
|
|
|
|
for (size_t i = 0; i < mhandles.size(); i++) {
|
|
if (mhandles[i]) {
|
|
ncclNetSocket.deregMr(sendComm, mhandles[i]);
|
|
}
|
|
}
|
|
|
|
// Cleanup communicators
|
|
if (sendComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator";
|
|
sendComm = nullptr;
|
|
}
|
|
if (recvComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator";
|
|
recvComm = nullptr;
|
|
}
|
|
}
|
|
|
|
if (listenComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator";
|
|
listenComm = nullptr;
|
|
}
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO, "TestRequestAllocationFailure completed");
|
|
}
|
|
|
|
// Test for message size mismatch scenario
|
|
TEST_F(NetSocketTests, TestMessageSizeMismatch) {
|
|
INFO(NCCL_LOG_INFO, "Testing message size mismatch scenario");
|
|
|
|
// This test simulates the condition where a receiver expects a smaller
|
|
// message than what the sender is trying to send, which should trigger the
|
|
// truncation warning
|
|
|
|
char handle[NCCL_NET_HANDLE_MAXSIZE];
|
|
void *listenComm = nullptr;
|
|
ncclResult_t result = ncclNetSocket.listen(nullptr, 0, handle, &listenComm);
|
|
|
|
if (result == ncclSuccess && listenComm) {
|
|
void *sendComm = nullptr;
|
|
void *recvComm = nullptr;
|
|
bool connectionSuccess =
|
|
EstablishConnectionPair(handle, listenComm, sendComm, recvComm);
|
|
|
|
if (connectionSuccess && sendComm && recvComm) {
|
|
// Large send buffer
|
|
const size_t sendSize = 2048;
|
|
std::vector<char> sendBuffer(sendSize, 0xAA);
|
|
|
|
// Small receive buffer (to simulate size mismatch)
|
|
const size_t recvSize = 1024; // Smaller than send size
|
|
std::vector<char> recvBuffer(recvSize, 0x00);
|
|
|
|
void *sendMhandle = nullptr;
|
|
void *recvMhandle = nullptr;
|
|
|
|
// Register memory
|
|
result = ncclNetSocket.regMr(sendComm, sendBuffer.data(), sendSize,
|
|
NCCL_PTR_HOST, &sendMhandle);
|
|
EXPECT_EQ(result, ncclSuccess) << "Failed to register send memory for size mismatch test. "
|
|
<< "ncclNetSocket.regMr() returned error code: " << result
|
|
<< ". Verify memory registration support for buffer size " << sendSize << ".";
|
|
|
|
result = ncclNetSocket.regMr(recvComm, recvBuffer.data(), recvSize,
|
|
NCCL_PTR_HOST, &recvMhandle);
|
|
EXPECT_EQ(result, ncclSuccess) << "Failed to register receive memory for size mismatch test. "
|
|
<< "ncclNetSocket.regMr() returned error code: " << result
|
|
<< ". Verify memory registration support for buffer size " << recvSize << ".";
|
|
|
|
// Start send operation with large size
|
|
void *sendRequest = nullptr;
|
|
result = ncclNetSocket.isend(sendComm, sendBuffer.data(), sendSize, 0,
|
|
sendMhandle, nullptr, &sendRequest);
|
|
EXPECT_EQ(result, ncclSuccess) << "Failed to start send operation for size mismatch test. "
|
|
<< "ncclNetSocket.isend() returned error code: " << result
|
|
<< ". Verify send operation support for buffer size " << sendSize << ".";
|
|
|
|
// Start receive operation with small size
|
|
void *recvRequest = nullptr;
|
|
void *recvDataPtr = recvBuffer.data();
|
|
size_t recvSizeVar = recvSize;
|
|
int tag = 0;
|
|
result = ncclNetSocket.irecv(recvComm, 1, &recvDataPtr, &recvSizeVar,
|
|
&tag, &recvMhandle, nullptr, &recvRequest);
|
|
EXPECT_EQ(result, ncclSuccess) << "Failed to start receive operation for size mismatch test. "
|
|
<< "ncclNetSocket.irecv() returned error code: " << result
|
|
<< ". Verify receive operation support for buffer size " << recvSize << ".";
|
|
|
|
// Progress operations - this should eventually trigger the size mismatch
|
|
// warning
|
|
for (int i = 0; i < 100; i++) {
|
|
if (sendRequest) {
|
|
int sendDone = 0, sendSize_out = 0;
|
|
ncclResult_t sendTestResult = ncclNetSocket.test(sendRequest, &sendDone, &sendSize_out);
|
|
if (sendTestResult != ncclSuccess || sendDone) {
|
|
INFO(NCCL_LOG_INFO, "Send operation completed: result=%d, done=%d", sendTestResult, sendDone);
|
|
sendRequest = nullptr; // Request is cleaned up by the networking layer
|
|
}
|
|
}
|
|
|
|
if (recvRequest) {
|
|
int recvDone = 0, recvSize_out = 0;
|
|
ncclResult_t recvTestResult = ncclNetSocket.test(recvRequest, &recvDone, &recvSize_out);
|
|
if (recvTestResult != ncclSuccess || recvDone) {
|
|
INFO(NCCL_LOG_INFO, "Recv operation completed: result=%d, done=%d, size=%d",
|
|
recvTestResult, recvDone, recvSize_out);
|
|
recvRequest = nullptr; // Request is cleaned up by the networking layer
|
|
}
|
|
}
|
|
|
|
if (!sendRequest && !recvRequest) {
|
|
INFO(NCCL_LOG_INFO, "Both operations completed after %d iterations", i + 1);
|
|
break;
|
|
}
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
}
|
|
|
|
// Cleanup memory handles
|
|
if (sendMhandle && sendComm) {
|
|
ncclResult_t deregResult = ncclNetSocket.deregMr(sendComm, sendMhandle);
|
|
if (deregResult != ncclSuccess) {
|
|
INFO(NCCL_LOG_INFO, "Warning: Failed to deregister send memory handle: %d", deregResult);
|
|
}
|
|
sendMhandle = nullptr;
|
|
}
|
|
|
|
if (recvMhandle && recvComm) {
|
|
ncclResult_t deregResult = ncclNetSocket.deregMr(recvComm, recvMhandle);
|
|
if (deregResult != ncclSuccess) {
|
|
INFO(NCCL_LOG_INFO, "Warning: Failed to deregister recv memory handle: %d", deregResult);
|
|
}
|
|
recvMhandle = nullptr;
|
|
}
|
|
|
|
// Cleanup communicators
|
|
if (sendComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeSend(sendComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close send communicator";
|
|
sendComm = nullptr;
|
|
}
|
|
if (recvComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeRecv(recvComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close receive communicator";
|
|
recvComm = nullptr;
|
|
}
|
|
}
|
|
|
|
if (listenComm) {
|
|
ncclResult_t closeResult = ncclNetSocket.closeListen(listenComm);
|
|
EXPECT_EQ(closeResult, ncclSuccess) << "Failed to close listen communicator";
|
|
listenComm = nullptr;
|
|
}
|
|
}
|
|
|
|
INFO(NCCL_LOG_INFO, "TestMessageSizeMismatch completed");
|
|
}
|
|
|
|
// Test to cover the iflush function that always returns ncclInternalError
|
|
TEST_F(NetSocketTests, TestIflushAlwaysFails) {
|
|
INFO(NCCL_LOG_INFO,
|
|
"Testing ncclNetSocketIflush always returns ncclInternalError");
|
|
|
|
// This function should always return ncclInternalError for socket
|
|
// implementation as it doesn't support CUDA pointers and flush operations
|
|
|
|
std::vector<char> buffer(1024, 0xAA);
|
|
void *data = buffer.data();
|
|
int size = 1024;
|
|
void *mhandle = nullptr;
|
|
void *request = nullptr;
|
|
|
|
// Test with dummy parameters - should always fail
|
|
ncclResult_t result =
|
|
ncclNetSocket.iflush(nullptr, 1, &data, &size, &mhandle, &request);
|
|
INFO(NCCL_LOG_INFO, "ncclNetSocketIflush returned: %d", result);
|
|
EXPECT_EQ(result, ncclInternalError)
|
|
<< "iflush should always return ncclInternalError";
|
|
|
|
INFO(NCCL_LOG_INFO, "TestIflushAlwaysFails completed");
|
|
}
|
|
|
|
} // namespace RcclUnitTesting
|