diff --git a/src/clique/ShmObject.h b/src/clique/ShmObject.h index 9759292c88..e62adec65f 100644 --- a/src/clique/ShmObject.h +++ b/src/clique/ShmObject.h @@ -72,17 +72,7 @@ ShmObject(size_t size, std::string fileName, int rank, int numRanks, int projid) { if (m_alloc) { - if (m_rank == 0) - { - std::string tmpFileName = "/tmp/" + m_shmName; - remove(tmpFileName.c_str()); - } - int retVal = shm_unlink(m_shmName.c_str()); - if (retVal == -1 && errno != ENOENT) - { - WARN("Call to shm_unlink in ShmObject failed : %s", strerror(errno)); - return ncclSystemError; - } + SYSCHECK(munmap(m_shmPtr, m_shmSize), "munmap"); } return ncclSuccess; } @@ -158,22 +148,42 @@ ncclResult_t ShmObject::Open() { NCCLCHECK(BroadcastAndCloseMessageQueue(mq_desc, false)); WARN("Call to ShmObject::Open in root rank failed : %s", strerror(errno)); + if (resultSetup == ncclSuccess) + { + Close(); + } return ncclSystemError; } + ncclResult_t result; + + // Broadcast two sets of messages: one set is consumed by the other ranks to acknowledge root rank + // has successfully opened shared memory; second set is consumed by the other ranks to indicate + // that they have successfully opened shared memory and root rank can now unlink shared memory + NCCLCHECK(BroadcastMessage(mq_desc, true)); NCCLCHECK(BroadcastAndCloseMessageQueue(mq_desc, true)); + + int retVal = shm_unlink(m_shmName.c_str()); + if (retVal == -1 && errno != ENOENT) + { + WARN("Call to shm_unlink in ShmObject failed : %s", strerror(errno)); + return ncclSystemError; + } } else { char msg_text[1]; ncclResult_t res; NCCLCHECKGOTO(MsgQueueRecv(mq_desc, &msg_text[0], sizeof(msg_text)), res, dropback); - NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false)); + if (msg_text[0] == 'P') { NCCLCHECK(shmSetup(m_shmName.c_str(), m_shmSize, &shmFd, (void**)&m_shmPtr, 0)); + NCCLCHECKGOTO(MsgQueueRecv(mq_desc, &msg_text[0], sizeof(msg_text)), res, dropback); + NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false)); } else { + NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false)); WARN("Call to shm_open from non-root rank in ShmObject failed : %s", strerror(errno)); return ncclSystemError; } @@ -188,8 +198,10 @@ ncclResult_t ShmObject::Open() return ncclSuccess; dropback: - WARN("Rank %d unable to receive message from root. Closing message queue.", m_rank); + WARN("Rank %d failed ShmObject::Open(). Closing message queue.", m_rank); NCCLCHECK(MsgQueueClose(m_shmName, mq_desc, false)); + SYSCHECK(shm_unlink(m_shmName.c_str()), "shm_unlink"); + NCCLCHECK(Close()); return ncclSystemError; } diff --git a/test/CorrectnessTest.hpp b/test/CorrectnessTest.hpp index 15cf6bcf20..33f5b56058 100644 --- a/test/CorrectnessTest.hpp +++ b/test/CorrectnessTest.hpp @@ -26,8 +26,7 @@ #include "rccl.h" #include "../include/rccl_bfloat16.h" -#define HIP_CALL(x) ASSERT_EQ(x, hipSuccess) -#define NCCL_CALL(x) ASSERT_EQ(x, ncclSuccess) +#include "TestChecks.hpp" #define MAX_ENV_TOKENS 16 @@ -114,20 +113,21 @@ namespace CorrectnessTests inPlace = inPlace_; function = func_; + inputs.resize(numDevices); + outputs.resize(numDevices); + expected.resize(numDevices); + for (int i = 0; i < numDevices_; i++) { - void* ptr = (void*)mmap(NULL, sizeof(void*), PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0); - inputs.push_back(ptr); + inputs[i] = (void*)mmap(NULL, sizeof(void*), PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0); } for (int i = 0; i < numDevices_; i++) { - void* ptr = (void*)mmap(NULL, sizeof(void*), PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0); - outputs.push_back(ptr); + outputs[i] = (void*)mmap(NULL, sizeof(void*), PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0); } for (int i = 0; i < numDevices_; i++) { - void* ptr = (void*)mmap(NULL, NumBytes(ncclOutputBuffer), PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0); - expected.push_back(ptr); + expected[i] = (void*)mmap(NULL, NumBytes(ncclOutputBuffer), PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0); } } @@ -198,6 +198,19 @@ namespace CorrectnessTests hipFree(inputs[rank]); } + void ReleaseRootProcess() + { + for (int i = 0; i < numDevices; i++) + { + munmap(inputs[i], sizeof(void*)); + munmap(outputs[i], sizeof(void*)); + munmap(expected[i], NumBytes(ncclOutputBuffer)); + } + inputs.clear(); + outputs.clear(); + expected.clear(); + } + // Creates a dataset by pointing to an existing dataset // Primarily to allow for testing with different starting byte-alignments void ExtractSubDataset(size_t const startElement, @@ -255,35 +268,52 @@ namespace CorrectnessTests if (rank == 0) { - InitSemaphore(smSize, mutexName, 1, mutex); - InitSemaphore(smSize, turnstile1Name, 0, turnstile1); - InitSemaphore(smSize, turnstile2Name, 0, turnstile2); - OpenSharedMemoryVariable(sizeof(int), counterName, true, counter); - OpenSharedMemoryVariable(smSize, tinyBarrierName, true, tinyBarrier); + NCCLCHECK_BARRIER_TEST(InitSemaphore(smSize, mutexName, 1, mutex), "InitSemaphore", rank); + NCCLCHECK_BARRIER_TEST(InitSemaphore(smSize, turnstile1Name, 0, turnstile1), "InitSemaphore", rank); + NCCLCHECK_BARRIER_TEST(InitSemaphore(smSize, turnstile2Name, 0, turnstile2), "InitSemaphore", rank); + NCCLCHECK_BARRIER_TEST(OpenSharedMemoryVariable(sizeof(int), counterName, true, counter), "OpenSharedMemoryVariable", rank); + NCCLCHECK_BARRIER_TEST(OpenSharedMemoryVariable(smSize, tinyBarrierName, true, tinyBarrier), "OpenSharedMemoryVariable", rank); } else { - OpenSharedMemoryVariable(smSize, tinyBarrierName, false, tinyBarrier); - OpenSemaphore(smSize, mutexName, mutex); - OpenSemaphore(smSize, turnstile1Name, turnstile1); - OpenSemaphore(smSize, turnstile2Name, turnstile2); - OpenSharedMemoryVariable(sizeof(int), counterName, false, counter); + NCCLCHECK_BARRIER_TEST(OpenSharedMemoryVariable(smSize, tinyBarrierName, false, tinyBarrier), "OpenSharedMemoryVariable", rank); + NCCLCHECK_BARRIER_TEST(OpenSemaphore(smSize, mutexName, mutex), "OpenSemaphore", rank); + NCCLCHECK_BARRIER_TEST(OpenSemaphore(smSize, turnstile1Name, turnstile1), "OpenSemaphore", rank); + NCCLCHECK_BARRIER_TEST(OpenSemaphore(smSize, turnstile2Name, turnstile2), "OpenSemaphore", rank); + NCCLCHECK_BARRIER_TEST(OpenSharedMemoryVariable(sizeof(int), counterName, false, counter), "OpenSharedMemoryVariable", rank); } + ncclResult_t res = Wait(20); + if (res != ncclSuccess) + { + printf("Rank %d timed out during Barrier initialization.\n", rank); + } + ClearShmFiles(uniqueId); } + // Wait with no timeout void Wait() { Part1(); Part2(); } + // Wait with timeout option + ncclResult_t Wait(int timeoutSecs) + { + NCCLCHECK_TEST(Part1(timeoutSecs), "Part 1 of Barrier Wait"); + NCCLCHECK_TEST(Part2(timeoutSecs), "Part 2 of Barrier Wait"); + + return ncclSuccess; + } + ~Barrier() { - shm_unlink(mutexName.c_str()); - shm_unlink(turnstile1Name.c_str()); - shm_unlink(turnstile2Name.c_str()); - shm_unlink(counterName.c_str()); - shm_unlink(tinyBarrierName.c_str()); + size_t smSize = sizeof(sem_t); + munmap(mutex, smSize); + munmap(turnstile1, smSize); + munmap(turnstile2, smSize); + munmap(tinyBarrier, smSize); + munmap(counter, sizeof(int)); } static void ClearShmFiles(int uniqueId) @@ -311,38 +341,59 @@ namespace CorrectnessTests } private: template - void OpenSharedMemoryVariable(size_t size, std::string name, bool create, T& val) + ncclResult_t OpenSharedMemoryVariable(size_t size, std::string name, bool create, T& val) { int protection = PROT_READ | PROT_WRITE; int visibility = MAP_SHARED; int fd; + std::string msg_open("shm_open "); + msg_open.append(name); if (create) { - fd = shm_open(name.c_str(), O_CREAT | O_RDWR, S_IRUSR | S_IWUSR); - ftruncate(fd, size); + SYSCHECKVAL_TEST(shm_open(name.c_str(), O_CREAT | O_RDWR, S_IRUSR | S_IWUSR), msg_open.c_str(), fd); + SYSCHECK_GOTO_TEST(ftruncate(fd, size), "ftruncate", dropback); } else { do { - // TODO: Error checking so we don't just infinite loop fd = shm_open(name.c_str(), O_RDWR, S_IRUSR | S_IWUSR); } while (fd == -1 && errno == ENOENT); + if (fd == -1 && errno != ENOENT) + { + printf("Call to %s failed: %s\n", msg_open.c_str(), strerror(errno)); + return ncclSystemError; + } } val = (T)mmap(NULL, size, protection, visibility, fd, 0); close(fd); + if (val == MAP_FAILED) + { + goto dropback; + } + + return ncclSuccess; +dropback: + std::string msg_unlink("shm_unlink "); + msg_unlink.append(name); + SYSCHECK_TEST(shm_unlink(name.c_str()), "shm_unlink"); + return ncclSystemError; } - void InitSemaphore(size_t size, std::string name, int semValue, sem_t*& semaphore) + ncclResult_t InitSemaphore(size_t size, std::string name, int semValue, sem_t*& semaphore) { - OpenSharedMemoryVariable(size, name, true, semaphore); - sem_init(semaphore, 1, semValue); + ncclResult_t res = OpenSharedMemoryVariable(size, name, true, semaphore); + std::string msg_init("sem_init "); + msg_init.append(name); + SYSCHECK_TEST(sem_init(semaphore, 1, semValue), "sem_init"); + + return res; } - void OpenSemaphore(size_t size, std::string name, sem_t*& semaphore) + ncclResult_t OpenSemaphore(size_t size, std::string name, sem_t*& semaphore) { - OpenSharedMemoryVariable(size, name, false, semaphore); + return OpenSharedMemoryVariable(size, name, false, semaphore); } void Part1() @@ -367,6 +418,40 @@ namespace CorrectnessTests sem_wait(turnstile2); } + ncclResult_t Part1(int timeoutSecs) + { + struct timespec ts; + SYSCHECK_TEST(clock_gettime(CLOCK_REALTIME, &ts), "clock_gettime 1"); + ts.tv_sec += timeoutSecs; + + SYSCHECK_TEST(sem_timedwait(mutex, &ts), "sem_timedwait 1-1"); + if (++(*counter) == numRanks) + { + SYSCHECK_TEST(sem_post_batch(turnstile1, numRanks), "sem_post_batch 1"); + } + SYSCHECK_TEST(sem_post(mutex), "sem_post 1"); + SYSCHECK_TEST(sem_timedwait(turnstile1, &ts), "sem_timedwait 1-2"); + + return ncclSuccess; + } + + ncclResult_t Part2(int timeoutSecs) + { + struct timespec ts; + SYSCHECK_TEST(clock_gettime(CLOCK_REALTIME, &ts), "clock_gettime 2"); + ts.tv_sec += timeoutSecs; + + SYSCHECK_TEST(sem_timedwait(mutex, &ts), "sem_timedwait 2"); + if (--(*counter) == 0) + { + SYSCHECK_TEST(sem_post_batch(turnstile2, numRanks), "sem_post_batch 2"); + } + SYSCHECK_TEST(sem_post(mutex), "sem_post 2"); + SYSCHECK_TEST(sem_timedwait(turnstile2, &ts), "sem_timedwait 2-2"); + + return ncclSuccess; + } + int sem_post_batch(sem_t*& sem, int n) { int ret = 0; @@ -739,7 +824,7 @@ namespace CorrectnessTests comms.resize(numDevices); streams.resize(numDevices); dataset = (Dataset*)mmap(NULL, sizeof(Dataset), PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0); - Barrier::ClearShmFiles(std::atoi(getenv("NCCL_COMM_ID"))); + Barrier::ClearShmFiles(StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); } void TearDown() override @@ -975,12 +1060,13 @@ namespace CorrectnessTests { int numProcesses = pids.size(); int status[numProcesses]; + for (int i = 0; i < numProcesses; i++) { waitpid(pids[i], &status[i], 0); - ASSERT_NE(WIFEXITED(status[i]), 0) << "[ERROR] Child process " << i << " did not exit cleanly."; - ASSERT_EQ(WEXITSTATUS(status[i]), EXIT_SUCCESS) << "[ERROR] Child process " << i << " had a test failure."; + EXPECT_NE(WIFEXITED(status[i]), 0) << "[ERROR] Child process " << i << " did not exit cleanly."; + EXPECT_EQ(WEXITSTATUS(status[i]), EXIT_SUCCESS) << "[ERROR] Child process " << i << " had a test failure."; } } @@ -996,6 +1082,13 @@ namespace CorrectnessTests } } + int StripPortNumberFromCommId(std::string commId) + { + size_t pos = commId.find(":"); + std::string portNumString = commId.substr(pos + 1); + return std::atoi(portNumString.c_str()); + } + Dataset* dataset; }; diff --git a/test/TestChecks.hpp b/test/TestChecks.hpp new file mode 100644 index 0000000000..c2e3331d02 --- /dev/null +++ b/test/TestChecks.hpp @@ -0,0 +1,63 @@ +#ifndef TESTCHECKS_HPP +#define TESTCHECKS_HPP + +#define HIP_CALL(x) ASSERT_EQ(x, hipSuccess) +#define NCCL_CALL(x) ASSERT_EQ(x, ncclSuccess) + +#define SYSCHECK_TEST(call, name) do { \ + int retval; \ + SYSCHECKVAL_TEST(call, name, retval); \ +} while (false) + +#define SYSCHECKVAL_TEST(call, name, retval) do { \ + SYSCHECKSYNC_TEST(call, name, retval); \ + if (retval == -1) { \ + printf("Call to %s failed : %s\n", name, strerror(errno)); \ + fflush(stdout); \ + return ncclSystemError; \ + } \ +} while (false) + +#define SYSCHECK_GOTO_TEST(call, name, label) do { \ + int retval; \ + SYSCHECKVAL_GOTO_TEST(call, name, retval, label); \ +} while (false) + +#define SYSCHECKVAL_GOTO_TEST(call, name, retval, label) do { \ + SYSCHECKSYNC_TEST(call, name, retval); \ + if (retval == -1) { \ + printf("Call to %s failed : %s\n", name, strerror(errno)); \ + fflush(stdout); \ + goto label; \ + } \ +} while (false) + +#define SYSCHECKSYNC_TEST(call, name, retval) do { \ + retval = call; \ + if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \ + } else { \ + break; \ + } \ +} while(true) + +#define NCCLCHECK_BARRIER_TEST(call, name, rank) do { \ + ncclResult_t retval; \ + retval = call; \ + if (retval != ncclSuccess) { \ + printf("Rank %d call to %s failed : %s\n", rank, name, strerror(errno)); \ + fflush(stdout); \ + return; \ + } \ +} while (false) + +#define NCCLCHECK_TEST(call, name) do { \ + ncclResult_t retval; \ + retval = call; \ + if (retval != ncclSuccess) { \ + printf("Call to %s failed : %s\n", name, strerror(errno)); \ + fflush(stdout); \ + return retval; \ + } \ +} while (false) + +#endif diff --git a/test/test_AllGatherMultiProcess.cpp b/test/test_AllGatherMultiProcess.cpp index 1a41298485..c6e84ae132 100644 --- a/test/test_AllGatherMultiProcess.cpp +++ b/test/test_AllGatherMultiProcess.cpp @@ -30,6 +30,7 @@ namespace CorrectnessTests } ValidateProcesses(pids); + dataset->ReleaseRootProcess(); } INSTANTIATE_TEST_SUITE_P(AllGatherMultiProcessCorrectnessSweep, diff --git a/test/test_AllGatherMultiProcess.hpp b/test/test_AllGatherMultiProcess.hpp index 46e4504d5d..90a72624e2 100644 --- a/test/test_AllGatherMultiProcess.hpp +++ b/test/test_AllGatherMultiProcess.hpp @@ -51,7 +51,7 @@ namespace CorrectnessTests return; } - Barrier barrier(rank, numDevices, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(rank, numDevices, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); // Prepare input / output / expected results FillDatasetWithPattern(dataset, rank); diff --git a/test/test_AllReduceGroupMultiProcess.cpp b/test/test_AllReduceGroupMultiProcess.cpp index 1c43a9d3e2..270f81a03d 100644 --- a/test/test_AllReduceGroupMultiProcess.cpp +++ b/test/test_AllReduceGroupMultiProcess.cpp @@ -58,6 +58,7 @@ namespace CorrectnessTests for (int i = 0; i < datasets.size(); i++) { + datasets[i]->ReleaseRootProcess(); munmap(datasets[i], sizeof(Dataset)); } } diff --git a/test/test_AllReduceGroupMultiProcess.hpp b/test/test_AllReduceGroupMultiProcess.hpp index 9eb141ba58..167d8a525a 100644 --- a/test/test_AllReduceGroupMultiProcess.hpp +++ b/test/test_AllReduceGroupMultiProcess.hpp @@ -36,7 +36,7 @@ namespace CorrectnessTests } int numProcesses = numDevices / ranks.size(); - Barrier barrier(process, numProcesses, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(process, numProcesses, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); for (int i = 0; i < ranks.size(); i++) { diff --git a/test/test_AllReduceMultiProcess.cpp b/test/test_AllReduceMultiProcess.cpp index dbea0fec75..5a73b630f1 100644 --- a/test/test_AllReduceMultiProcess.cpp +++ b/test/test_AllReduceMultiProcess.cpp @@ -31,6 +31,7 @@ namespace CorrectnessTests } ValidateProcesses(pids); + dataset->ReleaseRootProcess(); } INSTANTIATE_TEST_SUITE_P(AllReduceMultiProcessCorrectnessSweep, diff --git a/test/test_AllReduceMultiProcess.hpp b/test/test_AllReduceMultiProcess.hpp index f7c00de221..41ef265cd1 100644 --- a/test/test_AllReduceMultiProcess.hpp +++ b/test/test_AllReduceMultiProcess.hpp @@ -92,7 +92,7 @@ namespace CorrectnessTests return; } - Barrier barrier(rank, numDevices, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(rank, numDevices, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); // Prepare input / output / expected results FillDatasetWithPattern(dataset, rank); diff --git a/test/test_AllToAllMultiProcess.cpp b/test/test_AllToAllMultiProcess.cpp index 8b4ae6d417..33cc5c98b4 100644 --- a/test/test_AllToAllMultiProcess.cpp +++ b/test/test_AllToAllMultiProcess.cpp @@ -31,6 +31,7 @@ namespace CorrectnessTests } ValidateProcesses(pids); + dataset->ReleaseRootProcess(); } INSTANTIATE_TEST_SUITE_P(AllToAllMultiProcessCorrectnessSweep, diff --git a/test/test_BroadcastMultiProcess.cpp b/test/test_BroadcastMultiProcess.cpp index 793f118eda..9871c56868 100644 --- a/test/test_BroadcastMultiProcess.cpp +++ b/test/test_BroadcastMultiProcess.cpp @@ -38,6 +38,7 @@ namespace CorrectnessTests } ValidateProcesses(pids); + dataset->ReleaseRootProcess(); } INSTANTIATE_TEST_SUITE_P(BroadcastMultiProcessCorrectnessSweep, diff --git a/test/test_BroadcastMultiProcess.hpp b/test/test_BroadcastMultiProcess.hpp index 1b5a36e57c..7335c15985 100644 --- a/test/test_BroadcastMultiProcess.hpp +++ b/test/test_BroadcastMultiProcess.hpp @@ -43,7 +43,7 @@ namespace CorrectnessTests return; } - Barrier barrier(rank, numDevices, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(rank, numDevices, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); // Test each possible root for (int root = 0; root < numDevices; root++) diff --git a/test/test_CombinedCallsMultiProcess.cpp b/test/test_CombinedCallsMultiProcess.cpp index 9f030f3fcd..e877ffe267 100644 --- a/test/test_CombinedCallsMultiProcess.cpp +++ b/test/test_CombinedCallsMultiProcess.cpp @@ -49,6 +49,7 @@ namespace CorrectnessTests for (int i = 0; i < datasets.size(); i++) { + datasets[i]->ReleaseRootProcess(); munmap(datasets[i], sizeof(Dataset)); } } diff --git a/test/test_CombinedCallsMultiProcess.hpp b/test/test_CombinedCallsMultiProcess.hpp index 9b7ba6bf8b..18501b6fa1 100644 --- a/test/test_CombinedCallsMultiProcess.hpp +++ b/test/test_CombinedCallsMultiProcess.hpp @@ -30,7 +30,7 @@ namespace CorrectnessTests return; } - Barrier barrier(rank, numDevices, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(rank, numDevices, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); // Compute expected results for each dataset in combined int const root = 0; diff --git a/test/test_GatherMultiProcess.cpp b/test/test_GatherMultiProcess.cpp index c95f0ee21d..02649072fb 100644 --- a/test/test_GatherMultiProcess.cpp +++ b/test/test_GatherMultiProcess.cpp @@ -31,6 +31,7 @@ namespace CorrectnessTests } ValidateProcesses(pids); + dataset->ReleaseRootProcess(); } INSTANTIATE_TEST_SUITE_P(GatherMultiProcessCorrectnessSweep, diff --git a/test/test_GatherMultiProcess.hpp b/test/test_GatherMultiProcess.hpp index 0f05e7bef8..ab022b052e 100644 --- a/test/test_GatherMultiProcess.hpp +++ b/test/test_GatherMultiProcess.hpp @@ -29,7 +29,7 @@ namespace CorrectnessTests return; } - Barrier barrier(rank, numDevices, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(rank, numDevices, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); // Test each possible root for (int root = 0; root < numDevices; root++) diff --git a/test/test_GroupCallsMultiProcess.cpp b/test/test_GroupCallsMultiProcess.cpp index b6f8c3fce3..75377cbcb4 100644 --- a/test/test_GroupCallsMultiProcess.cpp +++ b/test/test_GroupCallsMultiProcess.cpp @@ -60,6 +60,7 @@ namespace CorrectnessTests for (int i = 0; i < datasets.size(); i++) { + datasets[i]->ReleaseRootProcess(); munmap(datasets[i], sizeof(Dataset)); } } diff --git a/test/test_GroupCallsMultiProcess.hpp b/test/test_GroupCallsMultiProcess.hpp index 8942ba0b1e..7eb7a58d27 100644 --- a/test/test_GroupCallsMultiProcess.hpp +++ b/test/test_GroupCallsMultiProcess.hpp @@ -41,7 +41,7 @@ namespace CorrectnessTests } int numProcesses = numDevices / ranks.size(); - Barrier barrier(process, numProcesses, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(process, numProcesses, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); for (int i = 0; i < ranks.size(); i++) { diff --git a/test/test_ReduceMultiProcess.cpp b/test/test_ReduceMultiProcess.cpp index 7ac0462a98..c98cc5e0ef 100644 --- a/test/test_ReduceMultiProcess.cpp +++ b/test/test_ReduceMultiProcess.cpp @@ -31,6 +31,7 @@ namespace CorrectnessTests } ValidateProcesses(pids); + dataset->ReleaseRootProcess(); } INSTANTIATE_TEST_SUITE_P(ReduceMultiProcessCorrectnessSweep, diff --git a/test/test_ReduceMultiProcess.hpp b/test/test_ReduceMultiProcess.hpp index 173cb241b9..d2b0ab045f 100644 --- a/test/test_ReduceMultiProcess.hpp +++ b/test/test_ReduceMultiProcess.hpp @@ -101,7 +101,7 @@ namespace CorrectnessTests return; } - Barrier barrier(rank, numDevices, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(rank, numDevices, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); // Test each possible root for (int root = 0; root < numDevices; root++) diff --git a/test/test_ReduceScatterMultiProcess.cpp b/test/test_ReduceScatterMultiProcess.cpp index 1f230c64a3..1d101712f5 100644 --- a/test/test_ReduceScatterMultiProcess.cpp +++ b/test/test_ReduceScatterMultiProcess.cpp @@ -31,6 +31,7 @@ namespace CorrectnessTests } ValidateProcesses(pids); + dataset->ReleaseRootProcess(); } INSTANTIATE_TEST_SUITE_P(ReduceScatterMultiProcessCorrectnessSweep, diff --git a/test/test_ReduceScatterMultiProcess.hpp b/test/test_ReduceScatterMultiProcess.hpp index 2b57b42f63..5921c0bd94 100644 --- a/test/test_ReduceScatterMultiProcess.hpp +++ b/test/test_ReduceScatterMultiProcess.hpp @@ -116,7 +116,7 @@ namespace CorrectnessTests return; } - Barrier barrier(rank, numDevices, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(rank, numDevices, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); // Prepare input / output / expected results FillDatasetWithPattern(dataset, rank); diff --git a/test/test_ScatterMultiProcess.cpp b/test/test_ScatterMultiProcess.cpp index c732e4e91d..b23bc2f599 100644 --- a/test/test_ScatterMultiProcess.cpp +++ b/test/test_ScatterMultiProcess.cpp @@ -31,6 +31,7 @@ namespace CorrectnessTests } ValidateProcesses(pids); + dataset->ReleaseRootProcess(); } INSTANTIATE_TEST_SUITE_P(ScatterMultiProcessCorrectnessSweep, diff --git a/test/test_ScatterMultiProcess.hpp b/test/test_ScatterMultiProcess.hpp index 0a14916c92..332774f3d3 100644 --- a/test/test_ScatterMultiProcess.hpp +++ b/test/test_ScatterMultiProcess.hpp @@ -34,7 +34,7 @@ namespace CorrectnessTests return; } - Barrier barrier(rank, numDevices, std::atoi(getenv("NCCL_COMM_ID"))); + Barrier barrier(rank, numDevices, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID")))); // Test each possible root for (int root = 0; root < numDevices; root++)