e373bd44bf
* initial checkin
* resolve cr comments
* resolve the build issue
* fix the data correctness issue
* update fp8 header file and update the unit test for fp8 support
* remove fp16 from fp8 headers
* fix ut issue and catch up the latest code from develop
* update according to cr comments
* update ut according to cr comments
* update num floats for each SumPostDiv from 4 to 6
* update fp8 header file name
* fix the typo
[ROCm/rccl commit: 6777e65c1d]
124 строки
5.1 KiB
C++
124 строки
5.1 KiB
C++
/*************************************************************************
|
|
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
|
|
*
|
|
* See LICENSE.txt for license information
|
|
************************************************************************/
|
|
#include "TestBed.hpp"
|
|
|
|
namespace RcclUnitTesting
|
|
{
|
|
TEST(ReduceScatter, OutOfPlace)
|
|
{
|
|
TestBed testBed;
|
|
|
|
// Configuration
|
|
std::vector<ncclFunc_t> const funcTypes = {ncclCollReduceScatter};
|
|
std::vector<ncclDataType_t> const dataTypes = {ncclFloat32};
|
|
std::vector<ncclRedOp_t> const redOps = {ncclMax};
|
|
std::vector<int> const roots = {0};
|
|
std::vector<int> const numElements = {393216, 384};
|
|
std::vector<bool> const inPlaceList = {false};
|
|
std::vector<bool> const managedMemList = {false};
|
|
std::vector<bool> const useHipGraphList = {false};
|
|
|
|
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
|
|
inPlaceList, managedMemList, useHipGraphList);
|
|
testBed.Finalize();
|
|
}
|
|
|
|
TEST(ReduceScatter, OutOfPlaceGraph)
|
|
{
|
|
TestBed testBed;
|
|
|
|
// Configuration
|
|
std::vector<ncclFunc_t> const funcTypes = {ncclCollReduceScatter};
|
|
std::vector<ncclDataType_t> const dataTypes = {ncclFloat64, ncclBfloat16, ncclFp8E4M3, ncclFp8E5M2};
|
|
std::vector<ncclRedOp_t> const redOps = {ncclMax};
|
|
std::vector<int> const roots = {0};
|
|
std::vector<int> const numElements = {1048576};
|
|
std::vector<bool> const inPlaceList = {false};
|
|
std::vector<bool> const managedMemList = {false};
|
|
std::vector<bool> const useHipGraphList = {true};
|
|
|
|
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
|
|
inPlaceList, managedMemList, useHipGraphList);
|
|
testBed.Finalize();
|
|
}
|
|
|
|
TEST(ReduceScatter, InPlace)
|
|
{
|
|
TestBed testBed;
|
|
|
|
// Configuration
|
|
std::vector<ncclFunc_t> const funcTypes = {ncclCollReduceScatter};
|
|
std::vector<ncclDataType_t> const dataTypes = {ncclInt32};
|
|
std::vector<ncclRedOp_t> const redOps = {ncclProd};
|
|
std::vector<int> const roots = {0, 1};
|
|
std::vector<int> const numElements = {542357};
|
|
std::vector<bool> const inPlaceList = {true};
|
|
std::vector<bool> const managedMemList = {false};
|
|
std::vector<bool> const useHipGraphList = {false};
|
|
|
|
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
|
|
inPlaceList, managedMemList, useHipGraphList);
|
|
testBed.Finalize();
|
|
}
|
|
|
|
TEST(ReduceScatter, InPlaceGraph)
|
|
{
|
|
TestBed testBed;
|
|
|
|
// Configuration
|
|
std::vector<ncclFunc_t> const funcTypes = {ncclCollReduceScatter};
|
|
std::vector<ncclDataType_t> const dataTypes = {ncclUint8, ncclFloat16};
|
|
std::vector<ncclRedOp_t> const redOps = {ncclMin};
|
|
std::vector<int> const roots = {0};
|
|
std::vector<int> const numElements = {246};;
|
|
std::vector<bool> const inPlaceList = {true};
|
|
std::vector<bool> const managedMemList = {false};
|
|
std::vector<bool> const useHipGraphList = {true};
|
|
|
|
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
|
|
inPlaceList, managedMemList, useHipGraphList);
|
|
testBed.Finalize();
|
|
}
|
|
|
|
TEST(ReduceScatter, ManagedMem)
|
|
{
|
|
TestBed testBed;
|
|
|
|
// Configuration
|
|
std::vector<ncclFunc_t> const funcTypes = {ncclCollReduceScatter};
|
|
std::vector<ncclDataType_t> const dataTypes = {ncclInt64, ncclUint8};
|
|
std::vector<ncclRedOp_t> const redOps = {ncclAvg};
|
|
std::vector<int> const roots = {0};
|
|
std::vector<int> const numElements = {1024};
|
|
std::vector<bool> const inPlaceList = {false};
|
|
std::vector<bool> const managedMemList = {true};
|
|
std::vector<bool> const useHipGraphList = {false};
|
|
|
|
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
|
|
inPlaceList, managedMemList, useHipGraphList);
|
|
testBed.Finalize();
|
|
}
|
|
|
|
TEST(ReduceScatter, ManagedMemGraph)
|
|
{
|
|
TestBed testBed;
|
|
|
|
// Configuration
|
|
std::vector<ncclFunc_t> const funcTypes = {ncclCollReduceScatter};
|
|
std::vector<ncclDataType_t> const dataTypes = {ncclUint32, ncclUint64};
|
|
std::vector<ncclRedOp_t> const redOps = {ncclAvg};
|
|
std::vector<int> const roots = {0};
|
|
std::vector<int> const numElements = {6485423};
|
|
std::vector<bool> const inPlaceList = {false};
|
|
std::vector<bool> const managedMemList = {true};
|
|
std::vector<bool> const useHipGraphList = {true};
|
|
|
|
testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
|
|
inPlaceList, managedMemList, useHipGraphList);
|
|
testBed.Finalize();
|
|
}
|
|
}
|