Comhaid
rocm-systems/test/test_ScatterMultiProcess.hpp
T
Stanley Tsang d00b7d17bd Update MP UT to support arbitrary # of GPUs; multiple bugfixes (#16)
* Fixing temp file creation/deletion for Clique kernel mode.

* Refactoring of MP unit tests; include bugfixes and general support for any number of GPUs

* GroupCall MP UT properly quits when too many devices specified

* MP UT will programmatically set NCCL_COMM_ID if not specified; updated install script
2021-02-05 16:49:25 -08:00

69 línte
2.3 KiB
C++

/*************************************************************************
* Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#ifndef TEST_SCATTER_MULTI_PROCESS_HPP
#define TEST_SCATTER_MULTI_PROCESS_HPP
#include "CorrectnessTest.hpp"
namespace CorrectnessTests
{
/**
 * Multi-process correctness test for the scatter collective.
 *
 * Each participating process calls TestScatter() with its own rank.  Every
 * rank is tried as the scatter root in turn, and the test passes only if the
 * results validate for every root.
 */
class ScatterMultiProcessCorrectnessTest : public MultiProcessCorrectnessTest
{
public:
    /**
     * Fills dataset.expected[] with the data each device should receive when
     * `root` scatters its input buffer.
     *
     * Only the caller whose rank equals `root` does any work; all other ranks
     * return immediately.  For device i, the i-th chunk of dataset.inputs[root]
     * (dataset.NumBytes() bytes starting at byte offset i * dataset.NumBytes())
     * is copied into dataset.expected[i].
     *
     * @param dataset  Buffers holding inputs and expected results.
     * @param root     Rank acting as the scatter source.
     * @param rank     Rank of the calling process.
     */
    static void ComputeExpectedResults(Dataset& dataset, int const root, int const rank)
    {
        if (rank != root)
        {
            return;
        }

        for (int i = 0; i < dataset.numDevices; i++)
        {
            // Byte-wise pointer arithmetic selects chunk i of the root's input buffer.
            HIP_CALL(hipMemcpy(dataset.expected[i], (int8_t *)dataset.inputs[root]+dataset.NumBytes()*i,
                               dataset.NumBytes(), hipMemcpyDeviceToHost));
        }
    }

    /**
     * Runs the scatter test for one process.
     *
     * @param rank     Rank of the calling process; indexes comms[] and streams[].
     * @param dataset  Dataset used for inputs, outputs and expected results.
     * @param pass     Output.  Set to true when more devices were requested than
     *                 are available (test skipped), or when validation succeeded
     *                 for every root; set to false otherwise, including when the
     *                 NCCL_COMM_ID environment variable is not set.
     */
    void TestScatter(int rank, Dataset& dataset, bool& pass)
    {
        // Prepare input / output / expected results
        SetUpPerProcess(rank, ncclCollScatter, comms[rank], streams[rank], dataset);

        // Not enough devices to run this configuration: report success and skip.
        if (numDevices > numDevicesAvailable)
        {
            pass = true;
            return;
        }

        // The barrier is keyed on NCCL_COMM_ID.  Passing a null pointer to
        // std::atoi is undefined behaviour, so fail the test cleanly instead.
        char const* const commIdText = std::getenv("NCCL_COMM_ID");
        if (commIdText == nullptr)
        {
            pass = false;
            TearDownPerProcess(comms[rank], streams[rank]);
            dataset.Release(rank);
            return;
        }
        Barrier barrier(rank, numDevices, std::atoi(commIdText));

        // Accumulate the verdict across roots so that a failure for an early
        // root is not overwritten by a later root that succeeds.
        bool allRootsPassed = true;

        // Test each possible root
        for (int root = 0; root < numDevices; root++)
        {
            // Prepare input / output / expected results
            FillDatasetWithPattern(dataset, rank);
            ComputeExpectedResults(dataset, root, rank);

            // Launch the scatter (1 process per GPU)
            // NOTE(review): the status returned by ncclScatter is not checked here;
            // a failed launch is only detected indirectly through validation.
            ncclScatter(dataset.inputs[rank],
                        dataset.outputs[rank],
                        numElements, dataType,
                        root, comms[rank], streams[rank]);

            // Wait for scatter to complete
            HIP_CALL(hipStreamSynchronize(streams[rank]));

            // Check results.  Validation and the barrier always run, even after a
            // failure, so that every process performs the same sequence of steps.
            if (!ValidateResults(dataset, rank))
            {
                allRootsPassed = false;
            }

            barrier.Wait();
        }

        pass = allRootsPassed;

        TearDownPerProcess(comms[rank], streams[rank]);
        dataset.Release(rank);
    }
};
}
#endif