Files
rocm-systems/test/test_AllGatherMultiProcess.hpp
T
Stanley Tsang 0b2bfdd6d8 Multiprocess unit test various fixes (#367)
* Re-enabling mp unit tests

* Fixing shared memory leak and other bugs related to shared mem for MP unit tests

* Revert 43bfbfc97bf9edbae1f386d461439091618ff8ed

* Further tightening up unlinks

* Moving test check macros to separate header file

* Tightening up shared memory unlinking for clique kernels, add munmap for host barrier for MP unit tests

* Updating new MP unit test

* Fixing mqueue bug

* Fixing memory leak in MP unit tests
2021-05-14 09:38:49 -06:00

82 行
2.9 KiB
C++

/*************************************************************************
* Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#ifndef TEST_ALLGATHER_MULTI_PROCESS_HPP
#define TEST_ALLGATHER_MULTI_PROCESS_HPP
#include "CorrectnessTest.hpp"
namespace CorrectnessTests
{
/// Multi-process correctness test for the AllGather collective.
///
/// Each process drives a single rank: it fills its inputs, contributes its
/// slice to the expected result, runs ncclAllGather and validates the output.
class AllGatherMultiProcessCorrectnessTest : public MultiProcessCorrectnessTest
{
public:
    /// Builds the expected AllGather result for the ranks handled by the caller.
    ///
    /// For every rank in @p ranks, the rank's own slice (byte offset
    /// rank * sliceBytes of its input buffer) is copied into dataset.expected[0]
    /// at the same offset. After barrier.Wait() returns, rank 0 (if present in
    /// @p ranks) copies dataset.expected[0] into every other dataset.expected[i].
    ///
    /// @param dataset     Buffers to read inputs from and write expected results to.
    /// @param barrier     Waited on once, between the slice copies and the rank-0 copy.
    /// @param numDevices  Currently unused; the device count is read from dataset.numDevices.
    /// @param ranks       Ranks handled by the calling process.
    static void ComputeExpectedResults(Dataset& dataset, Barrier& barrier, int const numDevices, std::vector<int> const& ranks)
    {
        // Size in bytes of the slice contributed by a single rank
        size_t const sliceBytes = dataset.NumBytes() / dataset.numDevices;

        // Each rank places its own slice into rank 0's expected buffer
        for (int const rank : ranks)
        {
            size_t const offset = rank * sliceBytes;
            HIP_CALL(hipMemcpy(static_cast<char*>(dataset.expected[0]) + offset,
                               (int8_t *)dataset.inputs[rank] + offset,
                               sliceBytes, hipMemcpyDeviceToHost));
        }

        // All slices must be in place before rank 0 copies the full answer out
        barrier.Wait();

        // Rank 0 sends answer to other ranks
        for (int const rank : ranks)
        {
            if (rank != 0) continue;

            for (int dst = 0; dst < dataset.numDevices; dst++)
            {
                if (dst == rank) continue; // expected[0] already holds the answer
                memcpy(dataset.expected[dst], dataset.expected[0], dataset.NumBytes());
            }
        }
    }

    /// Runs the AllGather test for one rank and reports the outcome in @p pass.
    ///
    /// @param rank     Rank driven by the calling process.
    /// @param dataset  Dataset holding the input, output and expected buffers.
    /// @param pass     Set to true when the configuration is skipped, otherwise
    ///                 to the value returned by ValidateResults.
    void TestAllGather(int rank, Dataset& dataset, bool& pass)
    {
        // Per-process setup for this rank
        SetUpPerProcess(rank, ncclCollAllGather, comms[rank], streams[rank], dataset);

        // Skip (reported as a pass) when there are not enough devices or the
        // element count cannot be split evenly across the devices.
        // NOTE(review): this path returns without TearDownPerProcess / dataset.Release;
        // confirm that SetUpPerProcess leaves nothing to clean up in this case.
        if (numDevices > numDevicesAvailable || numElements % numDevices != 0)
        {
            pass = true;
            return;
        }

        // NOTE(review): getenv returns nullptr when NCCL_COMM_ID is unset, and building
        // a std::string from nullptr is undefined; presumably the test harness always
        // sets the variable - confirm.
        Barrier barrier(rank, numDevices, StripPortNumberFromCommId(std::string(getenv("NCCL_COMM_ID"))));

        // Prepare input / output / expected results
        FillDatasetWithPattern(dataset, rank);
        ComputeExpectedResults(dataset, barrier, numDevices, std::vector<int>(1, rank));

        size_t const sliceBytes = dataset.NumBytes() / numDevices;
        size_t const sendCount  = dataset.numElements / numDevices;

        // Launch the all-gather (1 process per GPU); this rank sends only its own slice.
        // NOTE(review): the status returned by ncclAllGather is not inspected here;
        // a failed launch is only detected indirectly through ValidateResults.
        ncclAllGather((int8_t *)dataset.inputs[rank] + (rank * sliceBytes),
                      dataset.outputs[rank], sendCount,
                      dataType, comms[rank], streams[rank]);

        // Wait for the all-gather to complete
        HIP_CALL(hipStreamSynchronize(streams[rank]));

        // Check results
        pass = ValidateResults(dataset, rank);

        TearDownPerProcess(comms[rank], streams[rank]);
        dataset.Release(rank);
    }
};
}
#endif