// File view: rocm-systems/test/AllToAllTests.cpp (115 lines, 4.8 KiB, C++)
/*************************************************************************
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
// Note: In-place operation is not supported for all-to-all, so every test in this file
// runs out-of-place (inPlaceList = {false}).
#include <cstdlib>
#include <string>

#include "TestBed.hpp"
namespace RcclUnitTesting
{
  namespace
  {
    // Overrides an environment variable for the lifetime of this object. On destruction it
    // restores the variable's prior state: the previous value if one was set, otherwise it
    // removes the variable. Restoring (rather than unconditionally unsetting) keeps a value the
    // user exported before launching the test binary intact for the tests that run afterwards.
    class ScopedEnvVar
    {
    public:
      ScopedEnvVar(char const* name, char const* value) : name_(name)
      {
        char const* const prev = getenv(name);
        if (prev != nullptr)
        {
          hadPrev_   = true;
          // Copy now: POSIX allows a later setenv()/unsetenv() to invalidate getenv()'s pointer.
          prevValue_ = prev;
        }
        setenv(name, value, 1);
      }

      ~ScopedEnvVar()
      {
        if (hadPrev_)
          setenv(name_.c_str(), prevValue_.c_str(), 1);
        else
          unsetenv(name_.c_str());
      }

      ScopedEnvVar(ScopedEnvVar const&)            = delete;
      ScopedEnvVar& operator=(ScopedEnvVar const&) = delete;

    private:
      std::string name_;
      bool        hadPrev_ = false;
      std::string prevValue_;
    };
  }

  // Out-of-place all-to-all on device memory, launched eagerly (no HIP graph).
  // redOps/roots are passed because RunSimpleSweep takes them; presumably they are ignored
  // for all-to-all (NOTE(review): confirm in TestBed).
  TEST(AlltoAll, OutOfPlace)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAlltoAll};
    std::vector<ncclDataType_t> const dataTypes       = {ncclFloat16, ncclFloat32};
    std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {1048576, 1024};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place all-to-all on device memory, launched through a HIP graph, covering the
  // 64-bit float, bfloat16 and both 8-bit float formats at an odd (non-power-of-two) size.
  TEST(AlltoAll, OutOfPlaceGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAlltoAll};
    std::vector<ncclDataType_t> const dataTypes       = {ncclFloat64, ncclBfloat16, ncclFloat8e4m3, ncclFloat8e5m2};
    std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {5685};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {false};
    std::vector<bool>           const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place all-to-all using managed memory (managedMemList = {true}), launched eagerly.
  TEST(AlltoAll, ManagedMem)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAlltoAll};
    std::vector<ncclDataType_t> const dataTypes       = {ncclUint8};
    std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {384 * 1024, 1024};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {true};
    std::vector<bool>           const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place all-to-all using managed memory, launched through a HIP graph.
  TEST(AlltoAll, ManagedMemGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAlltoAll};
    std::vector<ncclDataType_t> const dataTypes       = {ncclUint32, ncclUint64};
    std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
    std::vector<int>            const roots           = {0};
    std::vector<int>            const numElements     = {1048576};
    std::vector<bool>           const inPlaceList     = {false};
    std::vector<bool>           const managedMemList  = {true};
    std::vector<bool>           const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Runs all-to-all with NCCL_MIN_NCHANNELS forced to each value in channelList, both eagerly
  // and through a HIP graph, with enableSweep = false.
  // Only runs when the environment reports at least 8 GPUs and isGfx94.
  // NOTE(review): otherwise the test reports success without exercising anything; consider
  // GTEST_SKIP() so the skip is visible in test output.
  TEST(AlltoAll, Channels)
  {
    TestBed testBed;

    if (testBed.ev.maxGpus >= 8 && testBed.ev.isGfx94)
    {
      // Configuration
      std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAlltoAll};
      std::vector<ncclDataType_t> const dataTypes       = {ncclBfloat16};
      std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
      std::vector<int>            const roots           = {0};
      std::vector<int>            const numElements     = {64 * 1024 * 1024, 1024};
      std::vector<bool>           const inPlaceList     = {false};
      std::vector<bool>           const managedMemList  = {false};
      std::vector<bool>           const useHipGraphList = {false, true};
      std::vector<const char *>   const channelList     = {"112"};
      bool                        const enableSweep     = false;

      for (char const* channel : channelList)
      {
        // Override for this iteration only. The variable's previous state is restored when the
        // guard is destroyed at the end of the iteration.
        ScopedEnvVar const minChannels("NCCL_MIN_NCHANNELS", channel);
        testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                               inPlaceList, managedMemList, useHipGraphList, enableSweep);
      }

      // Finalize once after every channel configuration, as the other tests in this file do.
      // NOTE(review): presumably Finalize() releases the TestBed, in which case calling it
      // inside the loop would break any second entry in channelList.
      testBed.Finalize();
    }
  }
}