122 lines
4.5 KiB
C++
122 lines
4.5 KiB
C++
/*************************************************************************
 * Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/
|
|
#include "test_GroupCalls.hpp"
|
|
|
|
#include "test_AllGather.hpp"
|
|
#include "test_AllReduce.hpp"
|
|
#include "test_Broadcast.hpp"
|
|
#include "test_Reduce.hpp"
|
|
#include "test_ReduceScatter.hpp"
|
|
|
|
#include <omp.h>
|
|
|
|
namespace CorrectnessTests
|
|
{
|
|
// Issues one collective of each kind (AllGather, AllReduce, Broadcast, Reduce,
// ReduceScatter) inside a single ncclGroupStart()/ncclGroupEnd() pair, each on
// its own dataset, and then validates every dataset against the expected
// results computed up front.
//
// Every NCCL call's return code is checked with a non-fatal EXPECT_EQ, so a
// failed launch is reported at the call that failed while the remaining
// validation and the dataset cleanup at the end of the test still run.
TEST_P(GroupCallsCorrectnessTest, Correctness)
{
    // Nothing to run when the sweep asks for more devices than are available
    if (numDevices > numDevicesAvailable) return;

    // Create one dataset per collective in the group
    std::vector<Dataset> datasets(5);
    for (Dataset& dataset : datasets)
    {
        dataset.Initialize(numDevices, numElements, dataType, inPlace);
        FillDatasetWithPattern(dataset);
    }

    // Compute expected results for each dataset in the group
    // (index -> collective: 0 AllGather, 1 AllReduce, 2 Broadcast,
    //  3 Reduce, 4 ReduceScatter)
    int const root = 0;
    AllGatherCorrectnessTest::ComputeExpectedResults(datasets[0]);
    AllReduceCorrectnessTest::ComputeExpectedResults(datasets[1], op);
    BroadcastCorrectnessTest::ComputeExpectedResults(datasets[2], root);
    ReduceCorrectnessTest::ComputeExpectedResults(datasets[3], op, root);
    ReduceScatterCorrectnessTest::ComputeExpectedResults(datasets[4], op);

    // Start a group call
    EXPECT_EQ(ncclSuccess, ncclGroupStart());

    // Per-device slice used by AllGather (send side) and ReduceScatter
    // (receive side). All datasets are initialized with identical parameters,
    // so the byte count taken from datasets[0] is reused for datasets[4].
    // NOTE: the integer divisions assume numElements is a multiple of
    // numDevices; the sweep values in this file (3072 and 3145728 elements on
    // 2, 3 or 4 devices) satisfy this.
    size_t const byteCount = datasets[0].NumBytes() / numDevices;
    size_t const elemCount = numElements / numDevices;

    // AllGather: device i sends the slice of its input buffer that starts at
    // byte offset i * byteCount
    for (int i = 0; i < numDevices; i++)
    {
        EXPECT_EQ(ncclSuccess,
                  ncclAllGather((int8_t *)datasets[0].inputs[i] + (i * byteCount),
                                datasets[0].outputs[i], elemCount,
                                dataType, comms[i], streams[i]));
    }

    // AllReduce
    for (int i = 0; i < numDevices; i++)
    {
        EXPECT_EQ(ncclSuccess,
                  ncclAllReduce(datasets[1].inputs[i], datasets[1].outputs[i],
                                numElements, dataType, op, comms[i], streams[i]));
    }

    // Broadcast
    for (int i = 0; i < numDevices; i++)
    {
        EXPECT_EQ(ncclSuccess,
                  ncclBroadcast(datasets[2].inputs[i],
                                datasets[2].outputs[i],
                                numElements, dataType,
                                root, comms[i], streams[i]));
    }

    // Reduce
    for (int i = 0; i < numDevices; i++)
    {
        EXPECT_EQ(ncclSuccess,
                  ncclReduce(datasets[3].inputs[i],
                             datasets[3].outputs[i],
                             numElements, dataType, op,
                             root, comms[i], streams[i]));
    }

    // ReduceScatter: device i receives its result at byte offset
    // i * byteCount of its output buffer
    for (int i = 0; i < numDevices; i++)
    {
        EXPECT_EQ(ncclSuccess,
                  ncclReduceScatter(datasets[4].inputs[i],
                                    (int8_t *)datasets[4].outputs[i] + (i * byteCount),
                                    elemCount, dataType, op,
                                    comms[i], streams[i]));
    }

    // Signal end of group call
    EXPECT_EQ(ncclSuccess, ncclGroupEnd());

    // Wait for all collectives in the group to complete
    Synchronize();

    // Check results for each collective in the group, then free its dataset
    for (Dataset& dataset : datasets)
    {
        ValidateResults(dataset);
        dataset.Release();
    }
}
|
|
|
|
// Instantiates GroupCallsCorrectnessTest once for every combination of the
// value lists below.
INSTANTIATE_TEST_CASE_P(GroupCallsCorrectnessSweep,
                        GroupCallsCorrectnessTest,
                        testing::Combine(
                            // Reduction operator (passed to the AllReduce, Reduce
                            // and ReduceScatter calls made by the test)
                            testing::Values(ncclSum),
                            // Element data types (ncclFloat16 is left disabled)
                            testing::Values(ncclInt8,  ncclUint8,
                                            ncclInt32, ncclUint32,
                                            ncclInt64, ncclUint64,
                                            //ncclFloat16,
                                            ncclFloat32, ncclFloat64,
                                            ncclBfloat16),
                            // Element counts
                            testing::Values(3072, 3145728),
                            // Device counts
                            testing::Values(2, 3, 4),
                            // Out-of-place first, then in-place
                            testing::Values(false, true)));
|
|
} // namespace
|