[UT] Start supporting multiple group calls and graphs (#1151)

* Start supporting multiple group calls UT
This commit is contained in:
Bertan Dogancay
2024-04-25 11:11:16 -06:00
committed by GitHub
parent efe99057b0
commit 0ec41f1386
8 changed files with 566 additions and 240 deletions
+89 -2
View File
@@ -195,14 +195,14 @@ namespace RcclUnitTesting
isMultiProcess ? "MP" : "SP", totalRanks, numCollPerGroup, numStreamsPerGroup);
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks),
numCollPerGroup, true, numStreamsPerGroup);
numCollPerGroup, numStreamsPerGroup);
// Set up each collective in group in different stream (modulo numStreamsPerGroup)
options.redOp = ncclSum;
for (int collIdx = 0; collIdx < numCollPerGroup; ++collIdx)
{
testBed.SetCollectiveArgs(ncclCollAllReduce, ncclFloat, numElements, numElements,
options, collIdx, -1, collIdx % numStreamsPerGroup);
options, collIdx, 0, -1, collIdx % numStreamsPerGroup);
}
testBed.AllocateMem(inPlace, useManagedMem);
@@ -216,4 +216,91 @@ namespace RcclUnitTesting
}
testBed.Finalize();
}
TEST(GroupCall, MultiGroupCall)
{
TestBed testBed;
// Configuration
std::vector<std::vector<ncclFunc_t>> const groupCalls = {{ncclCollAllReduce, ncclCollAllGather},
{ncclCollAllToAll, ncclCollGather},
{ncclCollBroadcast, ncclCollReduceScatter}};
std::vector<std::vector<int>> const numElements = {{1250, 1048576}, {384, 384 * 1024}, {1048576, 127}};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat32, ncclBfloat16};
std::vector<ncclRedOp_t> const redops = {ncclSum, ncclProd, ncclMax};
std::vector<int> const numCollsPerGroup = {2, 2, 2};
std::vector<int> const numStreamsPerGroup = {1, 1, 1};
std::vector<bool> const useHipGraphList = {true, false, true};
bool const inPlace = false;
bool const useManagedMem = false;
bool const useBlocking = true;
int const numGroupCalls = groupCalls.size();
int const numIterations = 10;
bool isCorrect = true;
for (int totalRanks : testBed.ev.GetNumGpusList())
for (int isMultiProcess : testBed.ev.GetIsMultiProcessList())
{
int const numProcesses = isMultiProcess ? totalRanks : 1;
// Initialize comms by specifying the # of group calls
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks), numCollsPerGroup, numStreamsPerGroup, numGroupCalls, useBlocking);
if (testBed.ev.showNames)
INFO("%s %d-ranks GroupCall MultiGroupCall\n", isMultiProcess ? "MP" : "SP", totalRanks);
for (int groupCallIdx = 0; groupCallIdx < groupCalls.size(); ++groupCallIdx)
{
std::vector<ncclFunc_t> funcTypes = groupCalls[groupCallIdx];
OptionalColArgs options;
options.redOp = redops[groupCallIdx];
options.root = 0;
for (int collIdx = 0; collIdx < numCollsPerGroup[groupCallIdx]; ++collIdx)
{
int numInputElements;
int numOutputElements;
CollectiveArgs::GetNumElementsForFuncType(funcTypes[collIdx],
numElements[groupCallIdx][collIdx],
totalRanks,
&numInputElements,
&numOutputElements);
testBed.SetCollectiveArgs(funcTypes[collIdx],
dataTypes[groupCallIdx],
numInputElements,
numOutputElements,
options,
collIdx,
groupCallIdx);
}
testBed.AllocateMem(inPlace, useManagedMem, groupCallIdx);
testBed.PrepareData(groupCallIdx);
// Stream capture in advance for HIP graph enabled collective groups
if (useHipGraphList[groupCallIdx])
{
testBed.ExecuteCollectives({}, groupCallIdx, useHipGraphList[groupCallIdx]);
}
}
// Execute collectives based on groupIdx
for (int i = 0; i < numIterations; ++i)
{
// Select a random group call
int groupCallIdx = i % groupCalls.size();
// Use graphs if enabled otherwise execute the collective
if (useHipGraphList[groupCallIdx]) testBed.LaunchGraphs(groupCallIdx);
else testBed.ExecuteCollectives({}, groupCallIdx);
testBed.ValidateResults(isCorrect, groupCallIdx);
}
testBed.DeallocateMem();
testBed.DestroyGraphs();
testBed.DestroyComms();
}
testBed.Finalize();
}
}