Files
rocm-systems/test/GatherTests.cpp
T
akolliasAMD cf8cfa88a8 Re-enabled graph tests (#736)
* enabled graph tests
* joined multi and single process CI testing
2023-06-29 08:08:17 -06:00

125 lines
5.0 KiB
C++

/*************************************************************************
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "TestBed.hpp"
namespace RcclUnitTesting
{
TEST(Gather, OutOfPlace)
{
TestBed testBed;
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollGather};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat32};
std::vector<ncclRedOp_t> const redOps = {ncclSum};
std::vector<int> const roots = {1};
std::vector<int> const numElements = {1048576, 127};
std::vector<bool> const inPlaceList = {false};
std::vector<bool> const managedMemList = {false};
std::vector<bool> const useHipGraphList = {false};
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
inPlaceList, managedMemList, useHipGraphList);
testBed.Finalize();
}
TEST(Gather, OutOfPlaceGraph)
{
TestBed testBed;
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollGather};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat64};
std::vector<ncclRedOp_t> const redOps = {ncclSum};
std::vector<int> const roots = {1};
std::vector<int> const numElements = {924873};
std::vector<bool> const inPlaceList = {false};
std::vector<bool> const managedMemList = {false};
std::vector<bool> const useHipGraphList = {true};
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
inPlaceList, managedMemList, useHipGraphList);
testBed.Finalize();
}
TEST(Gather, InPlace)
{
TestBed testBed;
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollGather};
std::vector<ncclDataType_t> const dataTypes = {ncclInt8, ncclInt32};
std::vector<ncclRedOp_t> const redOps = {ncclSum};
std::vector<int> const roots = {0};
std::vector<int> const numElements = {13576};
std::vector<bool> const inPlaceList = {true};
std::vector<bool> const managedMemList = {false};
std::vector<bool> const useHipGraphList = {false};
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
inPlaceList, managedMemList, useHipGraphList);
testBed.Finalize();
}
TEST(Gather, InPlaceGraph)
{
TestBed testBed;
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollGather};
std::vector<ncclDataType_t> const dataTypes = {ncclInt64, ncclFloat16};
std::vector<ncclRedOp_t> const redOps = {ncclSum};
std::vector<int> const roots = {0};
std::vector<int> const numElements = {1048576};
std::vector<bool> const inPlaceList = {true};
std::vector<bool> const managedMemList = {false};
std::vector<bool> const useHipGraphList = {true};
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
inPlaceList, managedMemList, useHipGraphList);
testBed.Finalize();
}
TEST(Gather, ManagedMem)
{
TestBed testBed;
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollGather};
std::vector<ncclDataType_t> const dataTypes = {ncclUint8, ncclUint32};
std::vector<ncclRedOp_t> const redOps = {ncclSum};
std::vector<int> const roots = {1};
std::vector<int> const numElements = {1051234};
std::vector<bool> const inPlaceList = {false};
std::vector<bool> const managedMemList = {true};
std::vector<bool> const useHipGraphList = {false};
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
inPlaceList, managedMemList, useHipGraphList);
testBed.Finalize();
}
TEST(Gather, ManagedMemGraph)
{
TestBed testBed;
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollGather};
std::vector<ncclDataType_t> const dataTypes = {ncclUint64};
std::vector<ncclRedOp_t> const redOps = {ncclSum};
std::vector<int> const roots = {1};
std::vector<int> const numElements = {5231};
std::vector<bool> const inPlaceList = {false};
std::vector<bool> const managedMemList = {true};
std::vector<bool> const useHipGraphList = {true};
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
inPlaceList, managedMemList, useHipGraphList);
testBed.Finalize();
}
}