2021-01-15 16:34:36 -07:00
|
|
|
/*************************************************************************
|
|
|
|
|
* Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
|
*
|
|
|
|
|
* See LICENSE.txt for license information
|
|
|
|
|
************************************************************************/
|
|
|
|
|
#ifndef TEST_ALLTOALL_MULTI_PROCESS_HPP
|
|
|
|
|
#define TEST_ALLTOALL_MULTI_PROCESS_HPP
|
|
|
|
|
|
|
|
|
|
#include <cstddef>
#include <cstdint>
#include <vector>

#include "CorrectnessTest.hpp"
|
|
|
|
|
|
|
|
|
|
namespace CorrectnessTests
|
|
|
|
|
{
|
|
|
|
|
class AllToAllMultiProcessCorrectnessTest : public MultiProcessCorrectnessTest
|
|
|
|
|
{
|
|
|
|
|
public:
|
2021-02-05 17:49:25 -07:00
|
|
|
static void ComputeExpectedResults(Dataset& dataset, std::vector<int> const& ranks)
|
2021-01-15 16:34:36 -07:00
|
|
|
{
|
2021-02-05 17:49:25 -07:00
|
|
|
for (int i = 0; i < ranks.size(); i++)
|
2021-01-15 16:34:36 -07:00
|
|
|
{
|
2021-02-05 17:49:25 -07:00
|
|
|
int rank = ranks[i];
|
|
|
|
|
for (int j = 0; j < dataset.numDevices; j++)
|
|
|
|
|
{
|
|
|
|
|
HIP_CALL(hipMemcpy((int8_t *)dataset.expected[j]+dataset.NumBytes()*rank, (int8_t *)dataset.inputs[rank]+dataset.NumBytes()*j,
|
|
|
|
|
dataset.NumBytes(), hipMemcpyDeviceToHost));
|
|
|
|
|
}
|
2021-01-15 16:34:36 -07:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2021-02-05 17:49:25 -07:00
|
|
|
void TestAllToAll(int rank, Dataset& dataset, bool& pass)
|
2021-01-15 16:34:36 -07:00
|
|
|
{
|
|
|
|
|
SetUpPerProcess(rank, ncclCollAllToAll, comms[rank], streams[rank], dataset);
|
|
|
|
|
|
2021-02-05 17:49:25 -07:00
|
|
|
if (numDevices > numDevicesAvailable)
|
|
|
|
|
{
|
|
|
|
|
pass = true;
|
|
|
|
|
return;
|
|
|
|
|
}
|
2021-01-15 16:34:36 -07:00
|
|
|
|
|
|
|
|
// Prepare input / output / expected results
|
|
|
|
|
FillDatasetWithPattern(dataset, rank);
|
2021-02-05 17:49:25 -07:00
|
|
|
ComputeExpectedResults(dataset, std::vector<int>(1, rank));
|
2021-01-15 16:34:36 -07:00
|
|
|
|
|
|
|
|
// Launch the reduction
|
|
|
|
|
ncclAllToAll(dataset.inputs[rank],
|
|
|
|
|
dataset.outputs[rank],
|
|
|
|
|
numElements, dataType,
|
|
|
|
|
comms[rank], streams[rank]);
|
|
|
|
|
|
|
|
|
|
// Wait for reduction to complete
|
|
|
|
|
HIP_CALL(hipStreamSynchronize(streams[rank]));
|
|
|
|
|
|
|
|
|
|
// Check results
|
2021-02-05 17:49:25 -07:00
|
|
|
pass = ValidateResults(dataset, rank);
|
2021-01-15 16:34:36 -07:00
|
|
|
|
|
|
|
|
TearDownPerProcess(comms[rank], streams[rank]);
|
|
|
|
|
dataset.Release(rank);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#endif
|