// rocm-systems/test/test_AllToAllMultiProcess.hpp
/*************************************************************************
* Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#ifndef TEST_ALLTOALL_MULTI_PROCESS_HPP
#define TEST_ALLTOALL_MULTI_PROCESS_HPP
#include "CorrectnessTest.hpp"
namespace CorrectnessTests
{
class AllToAllMultiProcessCorrectnessTest : public MultiProcessCorrectnessTest
{
public:
static void ComputeExpectedResults(Dataset& dataset, std::vector<int> const& ranks)
2021-01-15 16:34:36 -07:00
{
for (int i = 0; i < ranks.size(); i++)
2021-01-15 16:34:36 -07:00
{
int rank = ranks[i];
for (int j = 0; j < dataset.numDevices; j++)
{
HIP_CALL(hipMemcpy((int8_t *)dataset.expected[j]+dataset.NumBytes()*rank, (int8_t *)dataset.inputs[rank]+dataset.NumBytes()*j,
dataset.NumBytes(), hipMemcpyDeviceToHost));
}
2021-01-15 16:34:36 -07:00
}
}
void TestAllToAll(int rank, Dataset& dataset, bool& pass)
2021-01-15 16:34:36 -07:00
{
SetUpPerProcess(rank, ncclCollAllToAll, comms[rank], streams[rank], dataset);
if (numDevices > numDevicesAvailable)
{
pass = true;
return;
}
2021-01-15 16:34:36 -07:00
// Prepare input / output / expected results
FillDatasetWithPattern(dataset, rank);
ComputeExpectedResults(dataset, std::vector<int>(1, rank));
2021-01-15 16:34:36 -07:00
// Launch the reduction
ncclAllToAll(dataset.inputs[rank],
dataset.outputs[rank],
numElements, dataType,
comms[rank], streams[rank]);
// Wait for reduction to complete
HIP_CALL(hipStreamSynchronize(streams[rank]));
// Check results
pass = ValidateResults(dataset, rank);
2021-01-15 16:34:36 -07:00
TearDownPerProcess(comms[rank], streams[rank]);
dataset.Release(rank);
}
};
}
#endif