Files
rocm-systems/test/BroadcastTests.cpp
T
Andy li 6777e65c1d Enable fp8 support (#1101)
* initial checkin

* resolve cr comments

* resolve the build issue

* fix the data correctness issue

* update fp8 header file and update the unit test for fp8 support

* remove fp16 from fp8 headers

* fix ut issue and catch up the latest code from develop

* update according to cr comments

* update ut according to cr comments

* update num floats for each SumPostDiv from 4 to 6

* update fp8 header file name

* fix the typo
2024-03-08 15:17:53 -08:00

124 строки
5.0 KiB
C++

/*************************************************************************
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "TestBed.hpp"
namespace RcclUnitTesting
{
// Broadcast sweep over ncclFloat16 / ncclFloat32 with root 0 and two element
// counts; in-place, managed-memory and HIP-graph options are all disabled.
TEST(Broadcast, OutOfPlace)
{
  TestBed bed;

  // Collective, payload types and reduction op handed to the sweep.
  // NOTE(review): the reduction op is presumably irrelevant for broadcast and
  // only supplied because RunSimpleSweep takes one - confirm.
  const std::vector<ncclFunc_t>     collectives  = {ncclCollBroadcast};
  const std::vector<ncclDataType_t> elementTypes = {ncclFloat16, ncclFloat32};
  const std::vector<ncclRedOp_t>    reductions   = {ncclSum};

  // Root and element counts to iterate over.
  const std::vector<int> rootRanks     = {0};
  const std::vector<int> elementCounts = {1048576, 500};

  // Mode switches (one value each, so each is fixed for this test).
  const std::vector<bool> inPlaceModes    = {false};
  const std::vector<bool> managedMemModes = {false};
  const std::vector<bool> hipGraphModes   = {false};

  bed.RunSimpleSweep(collectives, elementTypes, reductions, rootRanks, elementCounts,
                     inPlaceModes, managedMemModes, hipGraphModes);
  bed.Finalize();
}
// Broadcast sweep over ncclBfloat16 / ncclFloat64 / ncclFp8E4M3 / ncclFp8E5M2
// with root 0 and a single element count; the HIP-graph option is enabled,
// in-place and managed-memory options are disabled.
TEST(Broadcast, OutOfPlaceGraph)
{
  TestBed bed;

  // Collective, payload types and reduction op handed to the sweep.
  // NOTE(review): the reduction op is presumably irrelevant for broadcast and
  // only supplied because RunSimpleSweep takes one - confirm.
  const std::vector<ncclFunc_t>     collectives  = {ncclCollBroadcast};
  const std::vector<ncclDataType_t> elementTypes = {ncclBfloat16, ncclFloat64, ncclFp8E4M3, ncclFp8E5M2};
  const std::vector<ncclRedOp_t>    reductions   = {ncclSum};

  // Root and element counts to iterate over.
  const std::vector<int> rootRanks     = {0};
  const std::vector<int> elementCounts = {586};

  // Mode switches (one value each, so each is fixed for this test).
  const std::vector<bool> inPlaceModes    = {false};
  const std::vector<bool> managedMemModes = {false};
  const std::vector<bool> hipGraphModes   = {true};

  bed.RunSimpleSweep(collectives, elementTypes, reductions, rootRanks, elementCounts,
                     inPlaceModes, managedMemModes, hipGraphModes);
  bed.Finalize();
}
// Broadcast sweep over ncclInt32 with root 1 and two element counts; the
// in-place option is enabled, managed-memory and HIP-graph options are disabled.
TEST(Broadcast, InPlace)
{
  TestBed bed;

  // Collective, payload type and reduction op handed to the sweep.
  // NOTE(review): the reduction op is presumably irrelevant for broadcast and
  // only supplied because RunSimpleSweep takes one - confirm.
  const std::vector<ncclFunc_t>     collectives  = {ncclCollBroadcast};
  const std::vector<ncclDataType_t> elementTypes = {ncclInt32};
  const std::vector<ncclRedOp_t>    reductions   = {ncclSum};

  // Root and element counts to iterate over.
  const std::vector<int> rootRanks     = {1};
  const std::vector<int> elementCounts = {104857, 264};

  // Mode switches (one value each, so each is fixed for this test).
  const std::vector<bool> inPlaceModes    = {true};
  const std::vector<bool> managedMemModes = {false};
  const std::vector<bool> hipGraphModes   = {false};

  bed.RunSimpleSweep(collectives, elementTypes, reductions, rootRanks, elementCounts,
                     inPlaceModes, managedMemModes, hipGraphModes);
  bed.Finalize();
}
// Broadcast sweep over ncclInt8 / ncclInt64 with root 1 and a single element
// count; in-place and HIP-graph options are enabled, managed memory is disabled.
TEST(Broadcast, InPlaceGraph)
{
  TestBed bed;

  // Collective, payload types and reduction op handed to the sweep.
  // NOTE(review): the reduction op is presumably irrelevant for broadcast and
  // only supplied because RunSimpleSweep takes one - confirm.
  const std::vector<ncclFunc_t>     collectives  = {ncclCollBroadcast};
  const std::vector<ncclDataType_t> elementTypes = {ncclInt8, ncclInt64};
  const std::vector<ncclRedOp_t>    reductions   = {ncclSum};

  // Root and element counts to iterate over.
  const std::vector<int> rootRanks     = {1};
  const std::vector<int> elementCounts = {958};

  // Mode switches (one value each, so each is fixed for this test).
  const std::vector<bool> inPlaceModes    = {true};
  const std::vector<bool> managedMemModes = {false};
  const std::vector<bool> hipGraphModes   = {true};

  bed.RunSimpleSweep(collectives, elementTypes, reductions, rootRanks, elementCounts,
                     inPlaceModes, managedMemModes, hipGraphModes);
  bed.Finalize();
}
// Broadcast sweep over ncclUint8 with root 0 and two element counts; the
// managed-memory option is enabled, in-place and HIP-graph options are disabled.
TEST(Broadcast, ManagedMem)
{
  TestBed bed;

  // Collective, payload type and reduction op handed to the sweep.
  // NOTE(review): the reduction op is presumably irrelevant for broadcast and
  // only supplied because RunSimpleSweep takes one - confirm.
  const std::vector<ncclFunc_t>     collectives  = {ncclCollBroadcast};
  const std::vector<ncclDataType_t> elementTypes = {ncclUint8};
  const std::vector<ncclRedOp_t>    reductions   = {ncclSum};

  // Root and element counts to iterate over.
  const std::vector<int> rootRanks     = {0};
  const std::vector<int> elementCounts = {1039203, 2500};

  // Mode switches (one value each, so each is fixed for this test).
  const std::vector<bool> inPlaceModes    = {false};
  const std::vector<bool> managedMemModes = {true};
  const std::vector<bool> hipGraphModes   = {false};

  bed.RunSimpleSweep(collectives, elementTypes, reductions, rootRanks, elementCounts,
                     inPlaceModes, managedMemModes, hipGraphModes);
  bed.Finalize();
}
// Broadcast sweep over ncclUint32 / ncclUint64 with root 0 and a single element
// count; managed-memory and HIP-graph options are enabled, in-place is disabled.
TEST(Broadcast, ManagedMemGraph)
{
  TestBed bed;

  // Collective, payload types and reduction op handed to the sweep.
  // NOTE(review): the reduction op is presumably irrelevant for broadcast and
  // only supplied because RunSimpleSweep takes one - confirm.
  const std::vector<ncclFunc_t>     collectives  = {ncclCollBroadcast};
  const std::vector<ncclDataType_t> elementTypes = {ncclUint32, ncclUint64};
  const std::vector<ncclRedOp_t>    reductions   = {ncclSum};

  // Root and element counts to iterate over.
  const std::vector<int> rootRanks     = {0};
  const std::vector<int> elementCounts = {896};

  // Mode switches (one value each, so each is fixed for this test).
  const std::vector<bool> inPlaceModes    = {false};
  const std::vector<bool> managedMemModes = {true};
  const std::vector<bool> hipGraphModes   = {true};

  bed.RunSimpleSweep(collectives, elementTypes, reductions, rootRanks, elementCounts,
                     inPlaceModes, managedMemModes, hipGraphModes);
  bed.Finalize();
}
}