From 08fcce5ec9048d102601ad53d589caba4d767698 Mon Sep 17 00:00:00 2001 From: Gilbert Lee Date: Thu, 16 May 2019 23:13:49 +0000 Subject: [PATCH] Fixing GoogleTest to 1.8.1 and making changes to tests to support older API --- test/CMakeLists.txt.in | 2 +- test/CorrectnessTest.hpp | 13 ++++++++++ test/test_AllGather.cpp | 45 ++++++++++++++++----------------- test/test_AllReduce.cpp | 45 ++++++++++++++++----------------- test/test_Broadcast.cpp | 46 +++++++++++++++++----------------- test/test_GroupCalls.cpp | 50 ++++++++++++++++--------------------- test/test_Reduce.cpp | 45 ++++++++++++++++----------------- test/test_ReduceScatter.cpp | 46 +++++++++++++++++----------------- 8 files changed, 148 insertions(+), 144 deletions(-) diff --git a/test/CMakeLists.txt.in b/test/CMakeLists.txt.in index 128d29e7c2..af8783aad6 100644 --- a/test/CMakeLists.txt.in +++ b/test/CMakeLists.txt.in @@ -5,7 +5,7 @@ project(googletest-download NONE) include(ExternalProject) ExternalProject_Add(googletest GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG master + GIT_TAG release-1.8.1 SOURCE_DIR "${CMAKE_BINARY_DIR}/googletest-src" BINARY_DIR "${CMAKE_BINARY_DIR}/googletest-build" CONFIGURE_COMMAND "" diff --git a/test/CorrectnessTest.hpp b/test/CorrectnessTest.hpp index 638c028788..a4dbd567ac 100644 --- a/test/CorrectnessTest.hpp +++ b/test/CorrectnessTest.hpp @@ -154,7 +154,10 @@ namespace CorrectnessTests // Create streams streams.resize(numDevices); for (int i = 0; i < numDevices; i++) + { + HIP_CALL(hipSetDevice(i)); HIP_CALL(hipStreamCreate(&streams[i])); + } } // Clean up per TestTuple @@ -219,6 +222,16 @@ namespace CorrectnessTests free(arrayI1); } + void Synchronize() const + { + // Wait for reduction to complete + for (int i = 0; i < numDevices; i++) + { + HIP_CALL(hipSetDevice(i)); + HIP_CALL(hipStreamSynchronize(streams[i])); + } + } + void ValidateResults(Dataset const& dataset) const { int8_t* outputI1 = (int8_t *)malloc(dataset.NumBytes()); diff --git a/test/test_AllGather.cpp b/test/test_AllGather.cpp index d65a45a28e..c6e79fe9c8 100644 --- a/test/test_AllGather.cpp +++ b/test/test_AllGather.cpp @@ -32,32 +32,31 @@ namespace CorrectnessTests } // Wait for reduction to complete - for (int i = 0; i < numDevices; i++) - hipStreamSynchronize(streams[i]); + Synchronize(); // Check results ValidateResults(dataset); } - INSTANTIATE_TEST_SUITE_P(AllGatherCorrectnessSweep, - AllGatherCorrectnessTest, - testing::Combine( - // Reduction operator (not used) - testing::Values(ncclSum), - // Data types - testing::Values(ncclInt8, - ncclUint8, - ncclInt32, - ncclUint32, - ncclInt64, - ncclUint64, - //ncclFloat16, - ncclFloat32, - ncclFloat64), - // Number of elements - testing::Values(3072, 3145728), - // Number of devices - testing::Values(2,3,4), - // In-place or not - testing::Values(false, true))); + INSTANTIATE_TEST_CASE_P(AllGatherCorrectnessSweep, + AllGatherCorrectnessTest, + testing::Combine( + // Reduction operator (not used) + testing::Values(ncclSum), + // Data types + testing::Values(ncclInt8, + ncclUint8, + ncclInt32, + ncclUint32, + ncclInt64, + ncclUint64, + //ncclFloat16, + ncclFloat32, + ncclFloat64), + // Number of elements + testing::Values(3072, 3145728), + // Number of devices + testing::Values(2,3,4), + // In-place or not + testing::Values(false, true))); } // namespace diff --git a/test/test_AllReduce.cpp b/test/test_AllReduce.cpp index d4b35b6890..f77651c84e 100644 --- a/test/test_AllReduce.cpp +++ b/test/test_AllReduce.cpp @@ -28,32 +28,31 @@ namespace CorrectnessTests } // Wait for reduction to complete - for (int i = 0; i < numDevices; i++) - hipStreamSynchronize(streams[i]); + Synchronize(); // Check results ValidateResults(dataset); } - INSTANTIATE_TEST_SUITE_P(AllReduceCorrectnessSweep, - AllReduceCorrectnessTest, - testing::Combine( - // Reduction operator - testing::Values(ncclSum, ncclProd, ncclMax, ncclMin), - // Data types - testing::Values(ncclInt8, - ncclUint8, - ncclInt32, - ncclUint32, - ncclInt64, - ncclUint64, - //ncclFloat16, - ncclFloat32, - ncclFloat64), - // Number of elements - testing::Values(1024, 1048576), - // Number of devices - testing::Values(2,3,4), - // In-place or not - testing::Values(false, true))); + INSTANTIATE_TEST_CASE_P(AllReduceCorrectnessSweep, + AllReduceCorrectnessTest, + testing::Combine( + // Reduction operator + testing::Values(ncclSum, ncclProd, ncclMax, ncclMin), + // Data types + testing::Values(ncclInt8, + ncclUint8, + ncclInt32, + ncclUint32, + ncclInt64, + ncclUint64, + //ncclFloat16, + ncclFloat32, + ncclFloat64), + // Number of elements + testing::Values(1024, 1048576), + // Number of devices + testing::Values(2,3,4), + // In-place or not + testing::Values(false, true))); } // namespace diff --git a/test/test_Broadcast.cpp b/test/test_Broadcast.cpp index c2f47b30ad..2f2a091a6d 100644 --- a/test/test_Broadcast.cpp +++ b/test/test_Broadcast.cpp @@ -34,34 +34,34 @@ namespace CorrectnessTests root, comms[i], streams[i]); } + // Wait for reduction to complete - for (int i = 0; i < numDevices; i++) - hipStreamSynchronize(streams[i]); + Synchronize(); // Check results ValidateResults(dataset); } } - INSTANTIATE_TEST_SUITE_P(BroadcastCorrectnessSweep, - BroadcastCorrectnessTest, - testing::Combine( - // Reduction operator is not used - testing::Values(ncclSum), - // Data types - testing::Values(ncclInt8, - ncclUint8, - ncclInt32, - ncclUint32, - ncclInt64, - ncclUint64, - //ncclFloat16, - ncclFloat32, - ncclFloat64), - // Number of elements - testing::Values(1024, 1048576), - // Number of devices - testing::Values(2,3,4), - // In-place or not - testing::Values(false, true))); + INSTANTIATE_TEST_CASE_P(BroadcastCorrectnessSweep, + BroadcastCorrectnessTest, + testing::Combine( + // Reduction operator is not used + testing::Values(ncclSum), + // Data types + testing::Values(ncclInt8, + ncclUint8, + ncclInt32, + ncclUint32, + ncclInt64, + ncclUint64, + //ncclFloat16, + ncclFloat32, + ncclFloat64), + // Number of elements + testing::Values(1024, 1048576), + // Number of devices + testing::Values(2,3,4), + // In-place or not + testing::Values(false, true))); } // namespace diff --git a/test/test_GroupCalls.cpp b/test/test_GroupCalls.cpp index 9bf0dd5497..d713588c13 100644 --- a/test/test_GroupCalls.cpp +++ b/test/test_GroupCalls.cpp @@ -43,7 +43,6 @@ namespace CorrectnessTests size_t const elemCount = numElements / numDevices; for (int i = 0; i < numDevices; i++) { - HIP_CALL(hipSetDevice(i)); ncclAllGather((int8_t *)datasets[0].inputs[i] + (i * byteCount), datasets[0].outputs[i], elemCount, dataType, comms[i], streams[i]); @@ -52,7 +51,6 @@ namespace CorrectnessTests // AllReduce for (int i = 0; i < numDevices; i++) { - HIP_CALL(hipSetDevice(i)); ncclAllReduce(datasets[1].inputs[i], datasets[1].outputs[i], numElements, dataType, op, comms[i], streams[i]); } @@ -60,7 +58,6 @@ namespace CorrectnessTests // Broadcast for (int i = 0; i < numDevices; i++) { - HIP_CALL(hipSetDevice(i)); ncclBroadcast(datasets[2].inputs[i], datasets[2].outputs[i], numElements, dataType, @@ -70,7 +67,6 @@ namespace CorrectnessTests // Reduce for (int i = 0; i < numDevices; i++) { - HIP_CALL(hipSetDevice(i)); ncclReduce(datasets[3].inputs[i], datasets[3].outputs[i], numElements, dataType, op, @@ -84,15 +80,13 @@ namespace CorrectnessTests (int8_t *)datasets[4].outputs[i] + (i * byteCount), elemCount, dataType, op, comms[i], streams[i]); - HIP_CALL(hipSetDevice(i)); } // Signal end of group call ncclGroupEnd(); // Wait for reduction to complete - for (int i = 0; i < numDevices; i++) - hipStreamSynchronize(streams[i]); + Synchronize(); // Check results for each collective in the group for (int i = 0; i < 5; i++) @@ -101,25 +95,25 @@ namespace CorrectnessTests } } - INSTANTIATE_TEST_SUITE_P(GroupCallsCorrectnessSweep, - GroupCallsCorrectnessTest, - testing::Combine( - // Reduction operator (not used) - testing::Values(ncclSum), - // Data types - testing::Values(ncclInt8, - ncclUint8, - ncclInt32, - ncclUint32, - ncclInt64, - ncclUint64, - //ncclFloat16, - ncclFloat32, - ncclFloat64), - // Number of elements - testing::Values(3072, 3145728), - // Number of devices - testing::Values(2,3,4), - // In-place or not - testing::Values(false, true))); + INSTANTIATE_TEST_CASE_P(GroupCallsCorrectnessSweep, + GroupCallsCorrectnessTest, + testing::Combine( + // Reduction operator (not used) + testing::Values(ncclSum), + // Data types + testing::Values(ncclInt8, + ncclUint8, + ncclInt32, + ncclUint32, + ncclInt64, + ncclUint64, + //ncclFloat16, + ncclFloat32, + ncclFloat64), + // Number of elements + testing::Values(3072, 3145728), + // Number of devices + testing::Values(2,3,4), + // In-place or not + testing::Values(false, true))); } // namespace diff --git a/test/test_Reduce.cpp b/test/test_Reduce.cpp index 089cc97593..bf6c2164e5 100644 --- a/test/test_Reduce.cpp +++ b/test/test_Reduce.cpp @@ -35,33 +35,32 @@ namespace CorrectnessTests } // Wait for reduction to complete - for (int i = 0; i < numDevices; i++) - hipStreamSynchronize(streams[i]); + Synchronize(); // Check results ValidateResults(dataset); } } - INSTANTIATE_TEST_SUITE_P(ReduceCorrectnessSweep, - ReduceCorrectnessTest, - testing::Combine( - // Reduction operator - testing::Values(ncclSum, ncclProd, ncclMax, ncclMin), - // Data types - testing::Values(ncclInt8, - ncclUint8, - ncclInt32, - ncclUint32, - ncclInt64, - ncclUint64, - //ncclFloat16, - ncclFloat32, - ncclFloat64), - // Number of elements - testing::Values(1024, 1048576), - // Number of devices - testing::Values(2,3,4), - // In-place or not - testing::Values(false, true))); + INSTANTIATE_TEST_CASE_P(ReduceCorrectnessSweep, + ReduceCorrectnessTest, + testing::Combine( + // Reduction operator + testing::Values(ncclSum, ncclProd, ncclMax, ncclMin), + // Data types + testing::Values(ncclInt8, + ncclUint8, + ncclInt32, + ncclUint32, + ncclInt64, + ncclUint64, + //ncclFloat16, + ncclFloat32, + ncclFloat64), + // Number of elements + testing::Values(1024, 1048576), + // Number of devices + testing::Values(2,3,4), + // In-place or not + testing::Values(false, true))); } // namespace diff --git a/test/test_ReduceScatter.cpp b/test/test_ReduceScatter.cpp index 10ae2affc8..567ce1fef1 100644 --- a/test/test_ReduceScatter.cpp +++ b/test/test_ReduceScatter.cpp @@ -33,33 +33,33 @@ namespace CorrectnessTests comms[i], streams[i]); } + // Wait for reduction to complete - for (int i = 0; i < numDevices; i++) - hipStreamSynchronize(streams[i]); + Synchronize(); // Check results ValidateResults(dataset); } - INSTANTIATE_TEST_SUITE_P(ReduceScatterCorrectnessSweep, - ReduceScatterCorrectnessTest, - testing::Combine( - // Reduction operator - testing::Values(ncclSum, ncclProd, ncclMax, ncclMin), - // Data types - testing::Values(ncclInt8, - ncclUint8, - ncclInt32, - ncclUint32, - ncclInt64, - ncclUint64, - //ncclFloat16, - ncclFloat32, - ncclFloat64), - // Number of elements - testing::Values(3072, 3145728), - // Number of devices - testing::Values(2,3,4), - // In-place or not - testing::Values(false, true))); + INSTANTIATE_TEST_CASE_P(ReduceScatterCorrectnessSweep, + ReduceScatterCorrectnessTest, + testing::Combine( + // Reduction operator + testing::Values(ncclSum, ncclProd, ncclMax, ncclMin), + // Data types + testing::Values(ncclInt8, + ncclUint8, + ncclInt32, + ncclUint32, + ncclInt64, + ncclUint64, + //ncclFloat16, + ncclFloat32, + ncclFloat64), + // Number of elements + testing::Values(3072, 3145728), + // Number of devices + testing::Values(2,3,4), + // In-place or not + testing::Values(false, true))); } // namespace