diff --git a/test/CorrectnessTest.hpp b/test/CorrectnessTest.hpp index fd7187a810..15cf6bcf20 100644 --- a/test/CorrectnessTest.hpp +++ b/test/CorrectnessTest.hpp @@ -456,6 +456,16 @@ namespace CorrectnessTests // Make the test tuple parameters accessible std::tie(op, dataType, numElements, numDevices, inPlace, envVals) = GetParam(); + // Collect the number of available GPUs + HIP_CALL(hipGetDeviceCount(&numDevicesAvailable)); + + // Only proceed with testing if there are enough GPUs + if (numDevices > numDevicesAvailable) + { + GTEST_SKIP(); + return; + } + envString = 0; numTokens = 0; if (strcmp(envVals, "")) { @@ -477,22 +487,6 @@ namespace CorrectnessTests } } - // Collect the number of available GPUs - HIP_CALL(hipGetDeviceCount(&numDevicesAvailable)); - - // Only proceed with testing if there are enough GPUs - if (numDevices > numDevicesAvailable) - { - fprintf(stdout, "[ SKIPPED ] Test requires %d devices (only %d available)\n", - numDevices, numDevicesAvailable); - - // Modify the number of devices so that tear-down doesn't occur - // This is temporary until GTEST_SKIP() becomes available - numDevices = 0; - numDevicesAvailable = -1; - return; - } - // Initialize communicators comms.resize(numDevices); NCCL_CALL(ncclCommInitAll(comms.data(), numDevices, NULL)); @@ -509,6 +503,8 @@ namespace CorrectnessTests // Clean up per TestTuple void TearDown() override { + if (IsSkipped()) return; + // Release communicators and streams for (int i = 0; i < numDevices; i++) {