From b0853ccd516cee04879e890a897967bbaa368375 Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Fri, 13 Nov 2020 17:57:44 -0800 Subject: [PATCH] gtest: add scatter to combined calls and use loops (#303) * gtest: add scatter to combined calls and use loops * gtest: run validation inside loop * gtest: revert small element count to 2520 * gtest: fix memory leak in validation --- test/CorrectnessTest.hpp | 1 + test/test_CombinedCalls.cpp | 73 ++++++++++++++++++++++--------------- 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/test/CorrectnessTest.hpp b/test/CorrectnessTest.hpp index fb11027b64..12c5fbe3f1 100644 --- a/test/CorrectnessTest.hpp +++ b/test/CorrectnessTest.hpp @@ -444,6 +444,7 @@ namespace CorrectnessTests } ASSERT_EQ(isMatch, true); } + free(outputI1); } // Passed in parameters from TestTuple diff --git a/test/test_CombinedCalls.cpp b/test/test_CombinedCalls.cpp index 7d1b23c98c..3916b4b8cc 100644 --- a/test/test_CombinedCalls.cpp +++ b/test/test_CombinedCalls.cpp @@ -10,6 +10,7 @@ #include "test_Broadcast.hpp" #include "test_Reduce.hpp" #include "test_ReduceScatter.hpp" +#include "test_Scatter.hpp" namespace CorrectnessTests { @@ -25,6 +26,10 @@ namespace CorrectnessTests FillDatasetWithPattern(datasets[i]); } + Dataset scatter_dataset; + scatter_dataset.Initialize(numDevices, numElements, dataType, inPlace, ncclCollScatter); + FillDatasetWithPattern(scatter_dataset); + // Compute expected results for each dataset in combined int const root = 0; AllGatherCorrectnessTest::ComputeExpectedResults(datasets[0]); @@ -32,46 +37,56 @@ namespace CorrectnessTests BroadcastCorrectnessTest::ComputeExpectedResults(datasets[2], root); ReduceCorrectnessTest::ComputeExpectedResults(datasets[3], op, root); ReduceScatterCorrectnessTest::ComputeExpectedResults(datasets[4], op); + ScatterCorrectnessTest::ComputeExpectedResults(scatter_dataset, root); size_t const byteCount = datasets[0].NumBytes() / numDevices; size_t const elemCount = numElements / numDevices; - ncclGroupStart(); - for (int i = 0; i < numDevices; i++) + for (int j = 0; j < 10; j++) { - ncclAllGather((int8_t *)datasets[0].inputs[i] + (i * byteCount), - datasets[0].outputs[i], elemCount, - dataType, comms[i], streams[i]); + ncclGroupStart(); + for (int i = 0; i < numDevices; i++) + { + ncclScatter(scatter_dataset.inputs[i], + scatter_dataset.outputs[i], + numElements, dataType, + root, comms[i], streams[i]); - ncclAllReduce(datasets[1].inputs[i], datasets[1].outputs[i], - numElements, dataType, op, comms[i], streams[i]); + ncclAllGather((int8_t *)datasets[0].inputs[i] + (i * byteCount), + datasets[0].outputs[i], elemCount, + dataType, comms[i], streams[i]); - ncclBroadcast(datasets[2].inputs[i], - datasets[2].outputs[i], - numElements, dataType, - root, comms[i], streams[i]); + ncclAllReduce(datasets[1].inputs[i], datasets[1].outputs[i], + numElements, dataType, op, comms[i], streams[i]); - ncclReduce(datasets[3].inputs[i], - datasets[3].outputs[i], - numElements, dataType, op, - root, comms[i], streams[i]); + ncclBroadcast(datasets[2].inputs[i], + datasets[2].outputs[i], + numElements, dataType, + root, comms[i], streams[i]); - ncclReduceScatter(datasets[4].inputs[i], - (int8_t *)datasets[4].outputs[i] + (i * byteCount), - elemCount, dataType, op, - comms[i], streams[i]); + ncclReduce(datasets[3].inputs[i], + datasets[3].outputs[i], + numElements, dataType, op, + root, comms[i], streams[i]); + + ncclReduceScatter(datasets[4].inputs[i], + (int8_t *)datasets[4].outputs[i] + (i * byteCount), + elemCount, dataType, op, + comms[i], streams[i]); + } + ncclGroupEnd(); + // Wait for reduction to complete + Synchronize(); + // Check results for each collective in the combined + for (int i = 0; i < 5; i++) + ValidateResults(datasets[i]); + + ValidateResults(scatter_dataset); } - ncclGroupEnd(); - // Wait for reduction to complete - Synchronize(); - - // Check results for each collective in the combined for (int i = 0; i < 5; i++) - { - ValidateResults(datasets[i]); datasets[i].Release(); - } + scatter_dataset.Release(); } INSTANTIATE_TEST_SUITE_P(CombinedCallsCorrectnessSweep, @@ -95,7 +110,7 @@ namespace CorrectnessTests // Number of devices testing::Values(2,3,4,5,6,7,8), // In-place or not - testing::Values(false, true), - testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1")), + testing::Values(false), + testing::Values("RCCL_ENABLE_CLIQUE=0", "RCCL_ENABLE_CLIQUE=1", "RCCL_ALLTOALL_KERNEL_DISABLE=0", "RCCL_ALLTOALL_KERNEL_DISABLE=1")), CorrectnessTest::PrintToStringParamName()); } // namespace