/*************************************************************************
 * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/
#include "TestBed.hpp"
#include "CallCollectiveForked.hpp"

// Unit tests for the AllGather collective.
//
// The first six tests each build one configuration (data types, element
// counts, in-place / managed-memory / HIP-graph flags) and hand it to
// TestBed::RunSimpleSweep.  The last two build host-side buffers and hand
// them to callCollectiveForked.
//
// NOTE(review): redOps and roots are filled in for every sweep because
// RunSimpleSweep takes them as arguments; AllGather itself presumably
// ignores both -- confirm against TestBed.
namespace RcclUnitTesting
{
  // Out-of-place sweep, regular (non-managed) memory, HIP graphs disabled.
  // Covers fp16 and fp32 at one large and one small element count.
  TEST(AllGather, OutOfPlace)
  {
    TestBed testBed;

    // Configuration
    std::vector const funcTypes       = {ncclCollAllGather};
    std::vector const dataTypes       = {ncclFloat16, ncclFloat32};
    std::vector const redOps          = {ncclSum};
    std::vector const roots           = {0};
    std::vector const numElements     = {1048576, 500};
    std::vector const inPlaceList     = {false};
    std::vector const managedMemList  = {false};
    std::vector const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place sweep with the HIP-graph flag enabled.
  // Covers bf16, fp64 and both fp8 variants at a single element count.
  TEST(AllGather, OutOfPlaceGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector const funcTypes       = {ncclCollAllGather};
    std::vector const dataTypes       = {ncclBfloat16, ncclFloat64, ncclFloat8e4m3, ncclFloat8e5m2};
    std::vector const redOps          = {ncclSum};
    std::vector const roots           = {0};
    std::vector const numElements     = {586};
    std::vector const inPlaceList     = {false};
    std::vector const managedMemList  = {false};
    std::vector const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // In-place sweep, regular memory, HIP graphs disabled.
  // Covers int32 at two element counts.
  TEST(AllGather, InPlace)
  {
    TestBed testBed;

    // Configuration
    std::vector const funcTypes       = {ncclCollAllGather};
    std::vector const dataTypes       = {ncclInt32};
    std::vector const redOps          = {ncclSum};
    std::vector const roots           = {0};
    std::vector const numElements     = {104857, 264};
    std::vector const inPlaceList     = {true};
    std::vector const managedMemList  = {false};
    std::vector const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // In-place sweep with the HIP-graph flag enabled.
  // Covers int8 and int64 at a single element count.
  TEST(AllGather, InPlaceGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector const funcTypes       = {ncclCollAllGather};
    std::vector const dataTypes       = {ncclInt8, ncclInt64};
    std::vector const redOps          = {ncclSum};
    std::vector const roots           = {0};
    std::vector const numElements     = {958};
    std::vector const inPlaceList     = {true};
    std::vector const managedMemList  = {false};
    std::vector const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place sweep with the managed-memory flag enabled, HIP graphs
  // disabled.  Covers uint8 at two element counts.
  TEST(AllGather, ManagedMem)
  {
    TestBed testBed;

    // Configuration
    std::vector const funcTypes       = {ncclCollAllGather};
    std::vector const dataTypes       = {ncclUint8};
    std::vector const redOps          = {ncclSum};
    std::vector const roots           = {0};
    std::vector const numElements     = {1039203, 2500};
    std::vector const inPlaceList     = {false};
    std::vector const managedMemList  = {true};
    std::vector const useHipGraphList = {false};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Out-of-place sweep with both the managed-memory and HIP-graph flags
  // enabled.  Covers uint32 and uint64 at a single element count.
  TEST(AllGather, ManagedMemGraph)
  {
    TestBed testBed;

    // Configuration
    std::vector const funcTypes       = {ncclCollAllGather};
    std::vector const dataTypes       = {ncclUint32, ncclUint64};
    std::vector const redOps          = {ncclSum};
    std::vector const roots           = {0};
    std::vector const numElements     = {896};
    std::vector const inPlaceList     = {false};
    std::vector const managedMemList  = {true};
    std::vector const useHipGraphList = {true};

    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    testBed.Finalize();
  }

  // Builds a send buffer holding 0..count-1 and an expected buffer made of
  // nranks back-to-back copies of it, then passes both (plus a zeroed receive
  // buffer) to callCollectiveForked without the optional managed-memory
  // argument.
  // NOTE(review): the buffer-registration behaviour implied by the test name
  // lives inside callCollectiveForked and is not visible here.
  TEST(AllGather, UserBufferRegistration)
  {
    const int    nranks = 8;
    const size_t count  = 2048;

    std::vector sendBuff(count, 0);
    std::vector recvBuff(nranks*count, 0);
    std::vector expected(nranks*count, 0);

    // size_t index: avoids comparing a signed counter against size_t count.
    for (size_t i = 0; i < count; ++i) {
      sendBuff[i] = static_cast<int>(i);
    }

    // Slot r of the expected output is a verbatim copy of sendBuff.
    for (int r = 0; r < nranks; ++r) {
      const size_t offset = static_cast<size_t>(r) * count;
      for (size_t i = 0; i < count; ++i) {
        expected[offset + i] = sendBuff[i];
      }
    }

    callCollectiveForked(nranks, ncclCollAllGather, sendBuff, recvBuff, expected);
  }

  // Same buffers as UserBufferRegistration, but passes use_managed_mem = true
  // as the trailing argument of callCollectiveForked.
  TEST(AllGather, ManagedMemUserBufferRegistration)
  {
    const int    nranks = 8;
    const size_t count  = 2048;

    std::vector sendBuff(count, 0);
    std::vector recvBuff(nranks*count, 0);
    std::vector expected(nranks*count, 0);
    const bool use_managed_mem = true;

    // size_t index: avoids comparing a signed counter against size_t count.
    for (size_t i = 0; i < count; ++i) {
      sendBuff[i] = static_cast<int>(i);
    }

    // Slot r of the expected output is a verbatim copy of sendBuff.
    for (int r = 0; r < nranks; ++r) {
      const size_t offset = static_cast<size_t>(r) * count;
      for (size_t i = 0; i < count; ++i) {
        expected[offset + i] = sendBuff[i];
      }
    }

    callCollectiveForked(nranks, ncclCollAllGather, sendBuff, recvBuff, expected, use_managed_mem);
  }
}