diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 55de9c9649..31755c9f46 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -67,6 +67,13 @@ if(BUILD_TESTS) target_link_libraries(UnitTestsMultiProcess PRIVATE ${GTEST_BOTH_LIBRARIES}) target_link_libraries(UnitTestsMultiProcess PRIVATE hip::host hip::device) + find_program( rocminfo_executable rocminfo ) + execute_process(COMMAND bash "-c" "${rocminfo_executable} | grep 'Device Type' | grep GPU | wc -l | tr -d '\n'" OUTPUT_VARIABLE gtest_num_gpus) + if(${gtest_num_gpus} MATCHES "0" OR ${gtest_num_gpus} MATCHES "1") + set(gtest_num_gpus,"2") + endif() + target_compile_options(UnitTests PRIVATE -DGTESTS_NUM_GPUS=${gtest_num_gpus}) + # UnitTests using static library of rccl requires passing rccl # through -l and -L instead of command line input. if(BUILD_STATIC) diff --git a/test/CorrectnessTest.hpp b/test/CorrectnessTest.hpp index 9df068c162..7bd6dbe0e2 100644 --- a/test/CorrectnessTest.hpp +++ b/test/CorrectnessTest.hpp @@ -812,7 +812,7 @@ dropback: case ncclUint64: isMatch &= (outputU8[j] == expectedU8[j]); break; case ncclFloat32: isMatch &= (fabs(outputF4[j] - expectedF4[j]) < 1e-5); break; case ncclFloat64: isMatch &= (fabs(outputF8[j] - expectedF8[j]) < 1e-12); break; - case ncclBfloat16: isMatch &= (fabs((float)outputB2[j] - (float)expectedB2[j]) < 5e-2); break; + case ncclBfloat16: isMatch &= (fabs((float)outputB2[j] - (float)expectedB2[j]) < 9e-2); break; default: fprintf(stderr, "[ERROR] Unsupported datatype\n"); exit(0); diff --git a/test/test_AllGather.cpp b/test/test_AllGather.cpp index 295892bbd8..280c1cea48 100644 --- a/test/test_AllGather.cpp +++ b/test/test_AllGather.cpp @@ -9,6 +9,8 @@ namespace CorrectnessTests { TEST_P(AllGatherCorrectnessTest, Correctness) { + // Adjust numElements to be multiple of numDevices + numElements = (numElements/numDevices)*numDevices; if (numDevices > numDevicesAvailable) return; if (numElements % numDevices != 0) return; @@ -107,7 +109,7 @@ namespace CorrectnessTests // Number of elements testing::Values(2520, 3026520), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false, true), testing::Values("")), diff --git a/test/test_AllReduce.cpp b/test/test_AllReduce.cpp index 0dab003c33..3a8697bcc6 100644 --- a/test/test_AllReduce.cpp +++ b/test/test_AllReduce.cpp @@ -46,7 +46,7 @@ namespace CorrectnessTests // Number of elements testing::Values(1024, 1048576), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false, true), testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1")), @@ -71,7 +71,7 @@ namespace CorrectnessTests // Number of elements testing::Values(1024, 1048576), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false, true), testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1")), diff --git a/test/test_AllReduceGroup.cpp b/test/test_AllReduceGroup.cpp index fe265a9c91..fb24ab7d94 100644 --- a/test/test_AllReduceGroup.cpp +++ b/test/test_AllReduceGroup.cpp @@ -58,7 +58,7 @@ namespace CorrectnessTests // Number of elements testing::Values(1024, 1048576), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false, true), testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1")), @@ -74,7 +74,7 @@ namespace CorrectnessTests // Number of elements testing::Values(1024, 1048576), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false, true), testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1")), diff --git a/test/test_AllToAll.cpp b/test/test_AllToAll.cpp index 17ab932978..8997a319c5 100644 --- a/test/test_AllToAll.cpp +++ b/test/test_AllToAll.cpp @@ -59,7 +59,7 @@ namespace CorrectnessTests // Number of elements testing::Values(1024, 1048576), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false), testing::Values("")), diff --git a/test/test_AllToAllv.cpp b/test/test_AllToAllv.cpp index 1d204fb30e..cb303c3679 100644 --- a/test/test_AllToAllv.cpp +++ b/test/test_AllToAllv.cpp @@ -67,7 +67,7 @@ namespace CorrectnessTests // Number of elements testing::Values(2520, 3026520), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false), testing::Values("")), diff --git a/test/test_Broadcast.cpp b/test/test_Broadcast.cpp index d273a67780..173f5f2c52 100644 --- a/test/test_Broadcast.cpp +++ b/test/test_Broadcast.cpp @@ -63,7 +63,7 @@ namespace CorrectnessTests // Number of elements testing::Values(1024, 1048576), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false, true), testing::Values("")), diff --git a/test/test_CombinedCalls.cpp b/test/test_CombinedCalls.cpp index 278c1bf067..b951470606 100644 --- a/test/test_CombinedCalls.cpp +++ b/test/test_CombinedCalls.cpp @@ -27,6 +27,8 @@ namespace CorrectnessTests ncclFuncs.push_back(ncclCollReduce); ncclFuncs.push_back(ncclCollReduceScatter); + // Adjust numElements to be multiple of numDevices + numElements = (numElements/numDevices)*numDevices; for (int i = 0; i < datasets.size(); i++) { datasets[i].Initialize(numDevices, numElements, dataType, inPlace, ncclFuncs[i]); @@ -119,7 +121,7 @@ namespace CorrectnessTests // Number of elements testing::Values(2520, 3026520), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false), testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1", "RCCL_P2P_NET_DISABLE=0", "RCCL_P2P_NET_DISABLE=1")), diff --git a/test/test_Gather.cpp b/test/test_Gather.cpp index 837ec30ea7..8bf4edd6d5 100644 --- a/test/test_Gather.cpp +++ b/test/test_Gather.cpp @@ -63,7 +63,7 @@ namespace CorrectnessTests // Number of elements testing::Values(1024, 1048576), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false), testing::Values("")), diff --git a/test/test_GroupCalls.cpp b/test/test_GroupCalls.cpp index 1199846477..cb05ab6e5e 100644 --- a/test/test_GroupCalls.cpp +++ b/test/test_GroupCalls.cpp @@ -26,6 +26,8 @@ namespace CorrectnessTests ncclFuncs.push_back(ncclCollReduce); ncclFuncs.push_back(ncclCollReduceScatter); + // Adjust numElements to be multiple of numDevices + numElements = (numElements/numDevices)*numDevices; for (int i = 0; i < datasets.size(); i++) { datasets[i].Initialize(numDevices, numElements, dataType, inPlace, ncclFuncs[i]); @@ -120,7 +122,7 @@ namespace CorrectnessTests // Number of elements testing::Values(2520, 3026520), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false, true), testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1")), diff --git a/test/test_Reduce.cpp b/test/test_Reduce.cpp index 8927ba59c2..35b4576e9e 100644 --- a/test/test_Reduce.cpp +++ b/test/test_Reduce.cpp @@ -63,7 +63,7 @@ namespace CorrectnessTests // Number of elements testing::Values(1024, 1048576), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false, true), testing::Values("")), diff --git a/test/test_ReduceScatter.cpp b/test/test_ReduceScatter.cpp index bed3a5f5f3..7462a3ab6b 100644 --- a/test/test_ReduceScatter.cpp +++ b/test/test_ReduceScatter.cpp @@ -10,6 +10,8 @@ namespace CorrectnessTests { TEST_P(ReduceScatterCorrectnessTest, Correctness) { + // Adjust numElements to be multiple of numDevices + numElements = (numElements/numDevices)*numDevices; if (numDevices > numDevicesAvailable) return; if (numElements % numDevices != 0) return; @@ -61,7 +63,7 @@ namespace CorrectnessTests // Number of elements testing::Values(2520, 3026520), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false, true), testing::Values("")), diff --git a/test/test_Scatter.cpp b/test/test_Scatter.cpp index 82f4f9088b..8ceec37589 100644 --- a/test/test_Scatter.cpp +++ b/test/test_Scatter.cpp @@ -63,7 +63,7 @@ namespace CorrectnessTests // Number of elements testing::Values(1024, 1048576), // Number of devices - testing::Values(2,3,4,5,6,7,8), + testing::Range(2,(GTESTS_NUM_GPUS+1)), // In-place or not testing::Values(false), testing::Values("")),