diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 93efccc01f..3ad45217b8 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -55,6 +55,7 @@ if(BUILD_TESTS)
     ReduceTests.cpp
     ScatterTests.cpp
     SendRecvTests.cpp
+    StandaloneTests.cpp
   )
 endif()
 
diff --git a/test/StandaloneTests.cpp b/test/StandaloneTests.cpp
new file mode 100644
index 0000000000..e4d7b506d6
--- /dev/null
+++ b/test/StandaloneTests.cpp
@@ -0,0 +1,156 @@
+/*************************************************************************
+ * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#include <map>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include <rccl/rccl.h>
+
+#include "StandaloneUtils.hpp"
+
+namespace RcclUnitTesting {
+  // Splits the communicators round-robin into two colors and checks each
+  // sub-communicator's rank and size.
+  TEST(Standalone, SplitComms_RankCheck)
+  {
+    // Check for multi-gpu
+    int numDevices = 0;
+    HIPCALL(hipGetDeviceCount(&numDevices));
+    if (numDevices < 2) {
+      GTEST_SKIP() << "This test requires at least 2 devices.";
+    }
+
+    // Initialize the original comms
+    std::vector<ncclComm_t> comms(numDevices);
+    NCCLCHECK(ncclCommInitAll(comms.data(), numDevices, nullptr));
+
+    // Split into new comms (round-robin)
+    std::vector<ncclComm_t> subComms(numDevices);
+    const int numSubComms = 2;
+
+    // Number of ranks assigned to each color
+    std::map<int, int> mapCounter;
+    NCCLCHECK(ncclGroupStart());
+    for (int localRank = 0; localRank < numDevices; localRank++) {
+      NCCLCHECK(ncclCommSplit(comms[localRank], localRank % numSubComms, localRank, &subComms[localRank], nullptr));
+      mapCounter[localRank % numSubComms]++;
+    }
+    NCCLCHECK(ncclGroupEnd());
+
+    // Check that new comms have correct subranks / ranks
+    // (EXPECT_* rather than ASSERT_* so the cleanup below always runs)
+    for (int i = 0; i < numDevices; i++) {
+      // NCCLCHECK only logs on failure, so give the outputs a defined value
+      int subCommRank = -1, subCommNRank = -1;
+      NCCLCHECK(ncclCommUserRank(subComms[i], &subCommRank));
+      NCCLCHECK(ncclCommCount(subComms[i], &subCommNRank));
+
+      EXPECT_EQ(subCommRank, i / numSubComms);
+      EXPECT_EQ(subCommNRank, mapCounter[i % numSubComms]);
+    }
+
+    // Clean up comms
+    for (auto& subComm : subComms)
+      NCCLCHECK(ncclCommDestroy(subComm));
+    for (auto& comm : comms)
+      NCCLCHECK(ncclCommDestroy(comm));
+  }
+
+  // Splits every rank into a single color (keyed by its original rank) and
+  // checks that rank and size match the original communicator.
+  TEST(Standalone, SplitComms_OneColor)
+  {
+    // Check for multi-gpu
+    int numDevices = 0;
+    HIPCALL(hipGetDeviceCount(&numDevices));
+    if (numDevices < 2) {
+      GTEST_SKIP() << "This test requires at least 2 devices.";
+    }
+
+    // Initialize the original comms
+    std::vector<ncclComm_t> comms(numDevices);
+    NCCLCHECK(ncclCommInitAll(comms.data(), numDevices, nullptr));
+
+    // Split into new comms (all of the same color)
+    std::vector<ncclComm_t> subComms(numDevices);
+    NCCLCHECK(ncclGroupStart());
+    for (int localRank = 0; localRank < numDevices; localRank++)
+      NCCLCHECK(ncclCommSplit(comms[localRank], 0, localRank, &subComms[localRank], nullptr));
+    NCCLCHECK(ncclGroupEnd());
+
+    // Validate results (EXPECT_* so the cleanup below always runs)
+    for (int i = 0; i < numDevices; i++) {
+      // Distinct sentinels: two failed queries must not compare equal
+      int originalRank = -1, originalNRank = -1;
+      NCCLCHECK(ncclCommUserRank(comms[i], &originalRank));
+      NCCLCHECK(ncclCommCount(comms[i], &originalNRank));
+
+      int subCommRank = -2, subCommNRank = -2;
+      NCCLCHECK(ncclCommUserRank(subComms[i], &subCommRank));
+      NCCLCHECK(ncclCommCount(subComms[i], &subCommNRank));
+
+      EXPECT_EQ(originalRank, subCommRank);
+      EXPECT_EQ(originalNRank, subCommNRank);
+    }
+
+    // Clean up comms
+    for (auto& subComm : subComms)
+      NCCLCHECK(ncclCommDestroy(subComm));
+    for (auto& comm : comms)
+      NCCLCHECK(ncclCommDestroy(comm));
+  }
+
+  // Splits only the lower half of the ranks into a new communicator (the rest
+  // pass NCCL_SPLIT_NOCOLOR) and checks the members' rank/size and that the
+  // excluded ranks' sub-communicators stay null.
+  TEST(Standalone, SplitComms_Reduce)
+  {
+    // Check for multi-gpu
+    int numDevices = 0;
+    HIPCALL(hipGetDeviceCount(&numDevices));
+    if (numDevices < 2) {
+      GTEST_SKIP() << "This test requires at least 2 devices.";
+    }
+
+    // Initialize the original comms
+    std::vector<ncclComm_t> comms(numDevices);
+    NCCLCHECK(ncclCommInitAll(comms.data(), numDevices, nullptr));
+
+    // Split into new comms
+    const int numReducedRanks = numDevices / 2;
+    std::vector<ncclComm_t> subComms(numDevices);
+    NCCLCHECK(ncclGroupStart());
+    for (int localRank = 0; localRank < numDevices; localRank++)
+      NCCLCHECK(ncclCommSplit(comms[localRank],
+                              localRank < numReducedRanks ? 0 : NCCL_SPLIT_NOCOLOR,
+                              localRank, &subComms[localRank], nullptr));
+    NCCLCHECK(ncclGroupEnd());
+
+    // Validate results (EXPECT_* so the cleanup below always runs)
+    for (int i = 0; i < numDevices; i++) {
+      int originalRank = -1, originalNRank = -1;
+      NCCLCHECK(ncclCommUserRank(comms[i], &originalRank));
+      NCCLCHECK(ncclCommCount(comms[i], &originalNRank));
+
+      if (i < numReducedRanks) {
+        int subCommRank = -2, subCommNRank = -2;
+        NCCLCHECK(ncclCommUserRank(subComms[i], &subCommRank));
+        NCCLCHECK(ncclCommCount(subComms[i], &subCommNRank));
+
+        EXPECT_EQ(originalRank, subCommRank);
+        EXPECT_EQ(subCommNRank, numReducedRanks);
+      } else {
+        EXPECT_EQ(subComms[i], nullptr);
+      }
+    }
+
+    // Cleanup comms
+    for (auto& subComm : subComms)
+      NCCLCHECK(ncclCommDestroy(subComm));
+    for (auto& comm : comms)
+      NCCLCHECK(ncclCommDestroy(comm));
+  }
+}
diff --git a/test/common/StandaloneUtils.hpp b/test/common/StandaloneUtils.hpp
new file mode 100644
index 0000000000..5be85c0c33
--- /dev/null
+++ b/test/common/StandaloneUtils.hpp
@@ -0,0 +1,36 @@
+#ifndef STANDALONE_UTILS_H
+#define STANDALONE_UTILS_H
+
+#include <cstdio>   // printf
+#include <cstdlib>  // exit
+
+// Error-checking helpers for the standalone tests.
+// This header does not pull in the HIP or RCCL declarations the macros use
+// (hipError_t, ncclResult_t, ...); include those before using the macros.
+
+// Runs a HIP call; on any result other than hipSuccess, prints the error
+// string with its line/file and terminates the process with exit(-1).
+#define HIPCALL(cmd) \
+  do { \
+    hipError_t error = (cmd); \
+    if (error != hipSuccess) \
+    { \
+      printf("Encountered HIP error (%s) at line %d in file %s\n", \
+             hipGetErrorString(error), __LINE__, __FILE__); \
+      exit(-1); \
+    } \
+  } while (0)
+
+// Runs an NCCL/RCCL call; on any result other than ncclSuccess, prints the
+// error string with its file/line.
+// NOTE(review): unlike HIPCALL this only logs and lets execution continue;
+// confirm whether that is intentional.
+#define NCCLCHECK(cmd) do { \
+    ncclResult_t res = (cmd); \
+    if (res != ncclSuccess) { \
+      printf("NCCL failure %s:%d '%s'\n", \
+             __FILE__,__LINE__,ncclGetErrorString(res)); \
+    } \
+} while(0)
+
+#endif