Fixing GoogleTest to 1.8.1 and making changes to tests to support older API

이 커밋은 다음에 포함됨:
Gilbert Lee
2019-05-16 23:13:49 +00:00
부모 11f78df04d
커밋 08fcce5ec9
8개의 변경된 파일148개의 추가작업 그리고 144개의 파일을 삭제
+1 -1
파일 보기
@@ -5,7 +5,7 @@ project(googletest-download NONE)
include(ExternalProject)
ExternalProject_Add(googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG master
GIT_TAG release-1.8.1
SOURCE_DIR "${CMAKE_BINARY_DIR}/googletest-src"
BINARY_DIR "${CMAKE_BINARY_DIR}/googletest-build"
CONFIGURE_COMMAND ""
+13
파일 보기
@@ -154,7 +154,10 @@ namespace CorrectnessTests
// Create streams
streams.resize(numDevices);
for (int i = 0; i < numDevices; i++)
{
HIP_CALL(hipSetDevice(i));
HIP_CALL(hipStreamCreate(&streams[i]));
}
}
// Clean up per TestTuple
@@ -219,6 +222,16 @@ namespace CorrectnessTests
free(arrayI1);
}
void Synchronize() const
{
// Wait for reduction to complete
for (int i = 0; i < numDevices; i++)
{
HIP_CALL(hipSetDevice(i));
HIP_CALL(hipStreamSynchronize(streams[i]));
}
}
void ValidateResults(Dataset const& dataset) const
{
int8_t* outputI1 = (int8_t *)malloc(dataset.NumBytes());
+22 -23
파일 보기
@@ -32,32 +32,31 @@ namespace CorrectnessTests
}
// Wait for reduction to complete
for (int i = 0; i < numDevices; i++)
hipStreamSynchronize(streams[i]);
Synchronize();
// Check results
ValidateResults(dataset);
}
INSTANTIATE_TEST_SUITE_P(AllGatherCorrectnessSweep,
AllGatherCorrectnessTest,
testing::Combine(
// Reduction operator (not used)
testing::Values(ncclSum),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
INSTANTIATE_TEST_CASE_P(AllGatherCorrectnessSweep,
AllGatherCorrectnessTest,
testing::Combine(
// Reduction operator (not used)
testing::Values(ncclSum),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
} // namespace
+22 -23
파일 보기
@@ -28,32 +28,31 @@ namespace CorrectnessTests
}
// Wait for reduction to complete
for (int i = 0; i < numDevices; i++)
hipStreamSynchronize(streams[i]);
Synchronize();
// Check results
ValidateResults(dataset);
}
INSTANTIATE_TEST_SUITE_P(AllReduceCorrectnessSweep,
AllReduceCorrectnessTest,
testing::Combine(
// Reduction operator
testing::Values(ncclSum, ncclProd, ncclMax, ncclMin),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
INSTANTIATE_TEST_CASE_P(AllReduceCorrectnessSweep,
AllReduceCorrectnessTest,
testing::Combine(
// Reduction operator
testing::Values(ncclSum, ncclProd, ncclMax, ncclMin),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
} // namespace
+23 -23
파일 보기
@@ -34,34 +34,34 @@ namespace CorrectnessTests
root, comms[i], streams[i]);
}
// Wait for reduction to complete
for (int i = 0; i < numDevices; i++)
hipStreamSynchronize(streams[i]);
Synchronize();
// Check results
ValidateResults(dataset);
}
}
INSTANTIATE_TEST_SUITE_P(BroadcastCorrectnessSweep,
BroadcastCorrectnessTest,
testing::Combine(
// Reduction operator is not used
testing::Values(ncclSum),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
INSTANTIATE_TEST_CASE_P(BroadcastCorrectnessSweep,
BroadcastCorrectnessTest,
testing::Combine(
// Reduction operator is not used
testing::Values(ncclSum),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
} // namespace
+22 -28
파일 보기
@@ -43,7 +43,6 @@ namespace CorrectnessTests
size_t const elemCount = numElements / numDevices;
for (int i = 0; i < numDevices; i++)
{
HIP_CALL(hipSetDevice(i));
ncclAllGather((int8_t *)datasets[0].inputs[i] + (i * byteCount),
datasets[0].outputs[i], elemCount,
dataType, comms[i], streams[i]);
@@ -52,7 +51,6 @@ namespace CorrectnessTests
// AllReduce
for (int i = 0; i < numDevices; i++)
{
HIP_CALL(hipSetDevice(i));
ncclAllReduce(datasets[1].inputs[i], datasets[1].outputs[i],
numElements, dataType, op, comms[i], streams[i]);
}
@@ -60,7 +58,6 @@ namespace CorrectnessTests
// Broadcast
for (int i = 0; i < numDevices; i++)
{
HIP_CALL(hipSetDevice(i));
ncclBroadcast(datasets[2].inputs[i],
datasets[2].outputs[i],
numElements, dataType,
@@ -70,7 +67,6 @@ namespace CorrectnessTests
// Reduce
for (int i = 0; i < numDevices; i++)
{
HIP_CALL(hipSetDevice(i));
ncclReduce(datasets[3].inputs[i],
datasets[3].outputs[i],
numElements, dataType, op,
@@ -84,15 +80,13 @@ namespace CorrectnessTests
(int8_t *)datasets[4].outputs[i] + (i * byteCount),
elemCount, dataType, op,
comms[i], streams[i]);
HIP_CALL(hipSetDevice(i));
}
// Signal end of group call
ncclGroupEnd();
// Wait for reduction to complete
for (int i = 0; i < numDevices; i++)
hipStreamSynchronize(streams[i]);
Synchronize();
// Check results for each collective in the group
for (int i = 0; i < 5; i++)
@@ -101,25 +95,25 @@ namespace CorrectnessTests
}
}
INSTANTIATE_TEST_SUITE_P(GroupCallsCorrectnessSweep,
GroupCallsCorrectnessTest,
testing::Combine(
// Reduction operator (not used)
testing::Values(ncclSum),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
INSTANTIATE_TEST_CASE_P(GroupCallsCorrectnessSweep,
GroupCallsCorrectnessTest,
testing::Combine(
// Reduction operator (not used)
testing::Values(ncclSum),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
} // namespace
+22 -23
파일 보기
@@ -35,33 +35,32 @@ namespace CorrectnessTests
}
// Wait for reduction to complete
for (int i = 0; i < numDevices; i++)
hipStreamSynchronize(streams[i]);
Synchronize();
// Check results
ValidateResults(dataset);
}
}
INSTANTIATE_TEST_SUITE_P(ReduceCorrectnessSweep,
ReduceCorrectnessTest,
testing::Combine(
// Reduction operator
testing::Values(ncclSum, ncclProd, ncclMax, ncclMin),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
INSTANTIATE_TEST_CASE_P(ReduceCorrectnessSweep,
ReduceCorrectnessTest,
testing::Combine(
// Reduction operator
testing::Values(ncclSum, ncclProd, ncclMax, ncclMin),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
} // namespace
+23 -23
파일 보기
@@ -33,33 +33,33 @@ namespace CorrectnessTests
comms[i], streams[i]);
}
// Wait for reduction to complete
for (int i = 0; i < numDevices; i++)
hipStreamSynchronize(streams[i]);
Synchronize();
// Check results
ValidateResults(dataset);
}
INSTANTIATE_TEST_SUITE_P(ReduceScatterCorrectnessSweep,
ReduceScatterCorrectnessTest,
testing::Combine(
// Reduction operator
testing::Values(ncclSum, ncclProd, ncclMax, ncclMin),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
INSTANTIATE_TEST_CASE_P(ReduceScatterCorrectnessSweep,
ReduceScatterCorrectnessTest,
testing::Combine(
// Reduction operator
testing::Values(ncclSum, ncclProd, ncclMax, ncclMin),
// Data types
testing::Values(ncclInt8,
ncclUint8,
ncclInt32,
ncclUint32,
ncclInt64,
ncclUint64,
//ncclFloat16,
ncclFloat32,
ncclFloat64),
// Number of elements
testing::Values(3072, 3145728),
// Number of devices
testing::Values(2,3,4),
// In-place or not
testing::Values(false, true)));
} // namespace