/************************************************************************* * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ // Note: InPlace is not supported for All-To-All #include "TestBed.hpp" namespace RcclUnitTesting { TEST(AllToAll, OutOfPlace) { TestBed testBed; // Configuration std::vector const funcTypes = {ncclCollAllToAll}; std::vector const dataTypes = {ncclFloat16, ncclFloat32}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {1048576, 1024}; std::vector const inPlaceList = {false}; std::vector const managedMemList = {false}; std::vector const useHipGraphList = {false}; testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements, inPlaceList, managedMemList, useHipGraphList); testBed.Finalize(); } TEST(AllToAll, OutOfPlaceGraph) { TestBed testBed; // Configuration std::vector const funcTypes = {ncclCollAllToAll}; std::vector const dataTypes = {ncclFloat64, ncclBfloat16, ncclFp8E4M3, ncclFp8E5M2}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {5685}; std::vector const inPlaceList = {false}; std::vector const managedMemList = {false}; std::vector const useHipGraphList = {true}; testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements, inPlaceList, managedMemList, useHipGraphList); testBed.Finalize(); } TEST(AllToAll, ManagedMem) { TestBed testBed; // Configuration std::vector const funcTypes = {ncclCollAllToAll}; std::vector const dataTypes = {ncclUint8}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {384 * 1024, 1024}; std::vector const inPlaceList = {false}; std::vector const managedMemList = {true}; std::vector const useHipGraphList = {false}; testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements, inPlaceList, managedMemList, useHipGraphList); testBed.Finalize(); } TEST(AllToAll, ManagedMemGraph) { TestBed testBed; // Configuration std::vector const funcTypes = {ncclCollAllToAll}; std::vector const dataTypes = {ncclUint32, ncclUint64}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {1048576}; std::vector const inPlaceList = {false}; std::vector const managedMemList = {true}; std::vector const useHipGraphList = {true}; testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements, inPlaceList, managedMemList, useHipGraphList); testBed.Finalize(); } TEST(AllToAll, Channels) { TestBed testBed; if(testBed.ev.maxGpus >= 8) { if(testBed.ev.isGfx94) { // Configuration std::vector const funcTypes = {ncclCollAllToAll}; std::vector const dataTypes = {ncclBfloat16}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {64 * 1024 * 1024, 1024}; std::vector const inPlaceList = {false}; std::vector const managedMemList = {false}; std::vector const useHipGraphList = {false, true}; std::vector const channelList = {"112"}; bool const enableSweep = false; for (auto channel : channelList) { setenv("NCCL_MIN_NCHANNELS", channel, 1); testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements, inPlaceList, managedMemList, useHipGraphList, enableSweep); testBed.Finalize(); unsetenv("NCCL_MIN_NCHANNELS"); } } } } }