gtest: add scatter to combined calls and use loops (#303)
* gtest: add scatter to combined calls and use loops * gtest: run validation inside loop * gtest: revert small element count to 2520 * gtest: fix memory leak in validation
This commit is contained in:
@@ -444,6 +444,7 @@ namespace CorrectnessTests
|
||||
}
|
||||
ASSERT_EQ(isMatch, true);
|
||||
}
|
||||
free(outputI1);
|
||||
}
|
||||
|
||||
// Passed in parameters from TestTuple
|
||||
|
||||
+44
-29
@@ -10,6 +10,7 @@
|
||||
#include "test_Broadcast.hpp"
|
||||
#include "test_Reduce.hpp"
|
||||
#include "test_ReduceScatter.hpp"
|
||||
#include "test_Scatter.hpp"
|
||||
|
||||
namespace CorrectnessTests
|
||||
{
|
||||
@@ -25,6 +26,10 @@ namespace CorrectnessTests
|
||||
FillDatasetWithPattern(datasets[i]);
|
||||
}
|
||||
|
||||
Dataset scatter_dataset;
|
||||
scatter_dataset.Initialize(numDevices, numElements, dataType, inPlace, ncclCollScatter);
|
||||
FillDatasetWithPattern(scatter_dataset);
|
||||
|
||||
// Compute expected results for each dataset in combined
|
||||
int const root = 0;
|
||||
AllGatherCorrectnessTest::ComputeExpectedResults(datasets[0]);
|
||||
@@ -32,46 +37,56 @@ namespace CorrectnessTests
|
||||
BroadcastCorrectnessTest::ComputeExpectedResults(datasets[2], root);
|
||||
ReduceCorrectnessTest::ComputeExpectedResults(datasets[3], op, root);
|
||||
ReduceScatterCorrectnessTest::ComputeExpectedResults(datasets[4], op);
|
||||
ScatterCorrectnessTest::ComputeExpectedResults(scatter_dataset, root);
|
||||
|
||||
size_t const byteCount = datasets[0].NumBytes() / numDevices;
|
||||
size_t const elemCount = numElements / numDevices;
|
||||
|
||||
ncclGroupStart();
|
||||
for (int i = 0; i < numDevices; i++)
|
||||
for (int j = 0; j < 10; j++)
|
||||
{
|
||||
ncclAllGather((int8_t *)datasets[0].inputs[i] + (i * byteCount),
|
||||
datasets[0].outputs[i], elemCount,
|
||||
dataType, comms[i], streams[i]);
|
||||
ncclGroupStart();
|
||||
for (int i = 0; i < numDevices; i++)
|
||||
{
|
||||
ncclScatter(scatter_dataset.inputs[i],
|
||||
scatter_dataset.outputs[i],
|
||||
numElements, dataType,
|
||||
root, comms[i], streams[i]);
|
||||
|
||||
ncclAllReduce(datasets[1].inputs[i], datasets[1].outputs[i],
|
||||
numElements, dataType, op, comms[i], streams[i]);
|
||||
ncclAllGather((int8_t *)datasets[0].inputs[i] + (i * byteCount),
|
||||
datasets[0].outputs[i], elemCount,
|
||||
dataType, comms[i], streams[i]);
|
||||
|
||||
ncclBroadcast(datasets[2].inputs[i],
|
||||
datasets[2].outputs[i],
|
||||
numElements, dataType,
|
||||
root, comms[i], streams[i]);
|
||||
ncclAllReduce(datasets[1].inputs[i], datasets[1].outputs[i],
|
||||
numElements, dataType, op, comms[i], streams[i]);
|
||||
|
||||
ncclReduce(datasets[3].inputs[i],
|
||||
datasets[3].outputs[i],
|
||||
numElements, dataType, op,
|
||||
root, comms[i], streams[i]);
|
||||
ncclBroadcast(datasets[2].inputs[i],
|
||||
datasets[2].outputs[i],
|
||||
numElements, dataType,
|
||||
root, comms[i], streams[i]);
|
||||
|
||||
ncclReduceScatter(datasets[4].inputs[i],
|
||||
(int8_t *)datasets[4].outputs[i] + (i * byteCount),
|
||||
elemCount, dataType, op,
|
||||
comms[i], streams[i]);
|
||||
ncclReduce(datasets[3].inputs[i],
|
||||
datasets[3].outputs[i],
|
||||
numElements, dataType, op,
|
||||
root, comms[i], streams[i]);
|
||||
|
||||
ncclReduceScatter(datasets[4].inputs[i],
|
||||
(int8_t *)datasets[4].outputs[i] + (i * byteCount),
|
||||
elemCount, dataType, op,
|
||||
comms[i], streams[i]);
|
||||
}
|
||||
ncclGroupEnd();
|
||||
// Wait for reduction to complete
|
||||
Synchronize();
|
||||
// Check results for each collective in the combined
|
||||
for (int i = 0; i < 5; i++)
|
||||
ValidateResults(datasets[i]);
|
||||
|
||||
ValidateResults(scatter_dataset);
|
||||
}
|
||||
ncclGroupEnd();
|
||||
|
||||
// Wait for reduction to complete
|
||||
Synchronize();
|
||||
|
||||
// Check results for each collective in the combined
|
||||
for (int i = 0; i < 5; i++)
|
||||
{
|
||||
ValidateResults(datasets[i]);
|
||||
datasets[i].Release();
|
||||
}
|
||||
scatter_dataset.Release();
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(CombinedCallsCorrectnessSweep,
|
||||
@@ -95,7 +110,7 @@ namespace CorrectnessTests
|
||||
// Number of devices
|
||||
testing::Values(2,3,4,5,6,7,8),
|
||||
// In-place or not
|
||||
testing::Values(false, true),
|
||||
testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1")),
|
||||
testing::Values(false),
|
||||
testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1", "RCCL_ALLTOALL_KERNEL_DISABLE=0", "RCCL_ALLTOALL_KERNEL_DISABLE=1")),
|
||||
CorrectnessTest::PrintToStringParamName());
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user