gtest: add scatter to combined calls and use loops (#303)

* gtest: add scatter to combined calls and use loops

* gtest: run validation inside loop

* gtest: revert small element count to 2520

* gtest: fix memory leak in validation
This commit is contained in:
Wenkai Du
2020-11-13 17:57:44 -08:00
committed by GitHub
parent 2958f7eace
commit b0853ccd51
2 changed files with 45 additions and 29 deletions
+1
View File
@@ -444,6 +444,7 @@ namespace CorrectnessTests
}
ASSERT_EQ(isMatch, true);
}
free(outputI1);
}
// Passed in parameters from TestTuple
+44 -29
View File
@@ -10,6 +10,7 @@
#include "test_Broadcast.hpp"
#include "test_Reduce.hpp"
#include "test_ReduceScatter.hpp"
#include "test_Scatter.hpp"
namespace CorrectnessTests
{
@@ -25,6 +26,10 @@ namespace CorrectnessTests
FillDatasetWithPattern(datasets[i]);
}
Dataset scatter_dataset;
scatter_dataset.Initialize(numDevices, numElements, dataType, inPlace, ncclCollScatter);
FillDatasetWithPattern(scatter_dataset);
// Compute expected results for each dataset in combined
int const root = 0;
AllGatherCorrectnessTest::ComputeExpectedResults(datasets[0]);
@@ -32,46 +37,56 @@ namespace CorrectnessTests
BroadcastCorrectnessTest::ComputeExpectedResults(datasets[2], root);
ReduceCorrectnessTest::ComputeExpectedResults(datasets[3], op, root);
ReduceScatterCorrectnessTest::ComputeExpectedResults(datasets[4], op);
ScatterCorrectnessTest::ComputeExpectedResults(scatter_dataset, root);
size_t const byteCount = datasets[0].NumBytes() / numDevices;
size_t const elemCount = numElements / numDevices;
ncclGroupStart();
for (int i = 0; i < numDevices; i++)
for (int j = 0; j < 10; j++)
{
ncclAllGather((int8_t *)datasets[0].inputs[i] + (i * byteCount),
datasets[0].outputs[i], elemCount,
dataType, comms[i], streams[i]);
ncclGroupStart();
for (int i = 0; i < numDevices; i++)
{
ncclScatter(scatter_dataset.inputs[i],
scatter_dataset.outputs[i],
numElements, dataType,
root, comms[i], streams[i]);
ncclAllReduce(datasets[1].inputs[i], datasets[1].outputs[i],
numElements, dataType, op, comms[i], streams[i]);
ncclAllGather((int8_t *)datasets[0].inputs[i] + (i * byteCount),
datasets[0].outputs[i], elemCount,
dataType, comms[i], streams[i]);
ncclBroadcast(datasets[2].inputs[i],
datasets[2].outputs[i],
numElements, dataType,
root, comms[i], streams[i]);
ncclAllReduce(datasets[1].inputs[i], datasets[1].outputs[i],
numElements, dataType, op, comms[i], streams[i]);
ncclReduce(datasets[3].inputs[i],
datasets[3].outputs[i],
numElements, dataType, op,
root, comms[i], streams[i]);
ncclBroadcast(datasets[2].inputs[i],
datasets[2].outputs[i],
numElements, dataType,
root, comms[i], streams[i]);
ncclReduceScatter(datasets[4].inputs[i],
(int8_t *)datasets[4].outputs[i] + (i * byteCount),
elemCount, dataType, op,
comms[i], streams[i]);
ncclReduce(datasets[3].inputs[i],
datasets[3].outputs[i],
numElements, dataType, op,
root, comms[i], streams[i]);
ncclReduceScatter(datasets[4].inputs[i],
(int8_t *)datasets[4].outputs[i] + (i * byteCount),
elemCount, dataType, op,
comms[i], streams[i]);
}
ncclGroupEnd();
// Wait for reduction to complete
Synchronize();
// Check results for each collective in the combined
for (int i = 0; i < 5; i++)
ValidateResults(datasets[i]);
ValidateResults(scatter_dataset);
}
ncclGroupEnd();
// Wait for reduction to complete
Synchronize();
// Check results for each collective in the combined
for (int i = 0; i < 5; i++)
{
ValidateResults(datasets[i]);
datasets[i].Release();
}
scatter_dataset.Release();
}
INSTANTIATE_TEST_SUITE_P(CombinedCallsCorrectnessSweep,
@@ -95,7 +110,7 @@ namespace CorrectnessTests
// Number of devices
testing::Values(2,3,4,5,6,7,8),
// In-place or not
testing::Values(false, true),
testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1")),
testing::Values(false),
testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1", "RCCL_ALLTOALL_KERNEL_DISABLE=0", "RCCL_ALLTOALL_KERNEL_DISABLE=1")),
CorrectnessTest::PrintToStringParamName());
} // namespace