[UT] Start supporting multiple group calls and graphs (#1151)
* Start supporting multiple group calls UT
This commit is contained in:
+89
-2
@@ -195,14 +195,14 @@ namespace RcclUnitTesting
|
||||
isMultiProcess ? "MP" : "SP", totalRanks, numCollPerGroup, numStreamsPerGroup);
|
||||
|
||||
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks),
|
||||
numCollPerGroup, true, numStreamsPerGroup);
|
||||
numCollPerGroup, numStreamsPerGroup);
|
||||
|
||||
// Set up each collective in group in different stream (modulo numStreamsPerGroup)
|
||||
options.redOp = ncclSum;
|
||||
for (int collIdx = 0; collIdx < numCollPerGroup; ++collIdx)
|
||||
{
|
||||
testBed.SetCollectiveArgs(ncclCollAllReduce, ncclFloat, numElements, numElements,
|
||||
options, collIdx, -1, collIdx % numStreamsPerGroup);
|
||||
options, collIdx, 0, -1, collIdx % numStreamsPerGroup);
|
||||
}
|
||||
|
||||
testBed.AllocateMem(inPlace, useManagedMem);
|
||||
@@ -216,4 +216,91 @@ namespace RcclUnitTesting
|
||||
}
|
||||
testBed.Finalize();
|
||||
}
|
||||
|
||||
// Sets up several distinct group calls on the same TestBed and then launches
// them repeatedly in an interleaved order, validating results after each launch.
//
// Every configuration list below is indexed by group call: entry g describes
// group call g (its collectives, element counts, data type, reduction op,
// number of collectives/streams and whether it is launched through a graph).
TEST(GroupCall, MultiGroupCall)
{
  TestBed testBed;

  // Configuration
  std::vector<std::vector<ncclFunc_t>> const groupCalls = {{ncclCollAllReduce, ncclCollAllGather},
                                                           {ncclCollAllToAll,  ncclCollGather},
                                                           {ncclCollBroadcast, ncclCollReduceScatter}};
  // numElements[g][c] is the element count handed to GetNumElementsForFuncType
  // for collective c of group call g
  std::vector<std::vector<int>> const numElements        = {{1250, 1048576}, {384, 384 * 1024}, {1048576, 127}};
  std::vector<ncclDataType_t>   const dataTypes          = {ncclFloat16, ncclFloat32, ncclBfloat16};
  std::vector<ncclRedOp_t>      const redops             = {ncclSum, ncclProd, ncclMax};
  std::vector<int>              const numCollsPerGroup   = {2, 2, 2};
  std::vector<int>              const numStreamsPerGroup = {1, 1, 1};
  std::vector<bool>             const useHipGraphList    = {true, false, true};
  bool const inPlace       = false;
  bool const useManagedMem = false;
  bool const useBlocking   = true;
  // Kept as int so that it can be compared/combined with the int loop indices
  // below without mixing signed and unsigned arithmetic
  int  const numGroupCalls = static_cast<int>(groupCalls.size());
  int  const numIterations = 10;

  // Handed to ValidateResults on every iteration
  // NOTE(review): presumably updated by reference inside ValidateResults - confirm
  bool isCorrect = true;
  for (int totalRanks : testBed.ev.GetNumGpusList())
  for (int isMultiProcess : testBed.ev.GetIsMultiProcessList())
  {
    // One process per rank in multi-process mode, otherwise a single process
    int const numProcesses = isMultiProcess ? totalRanks : 1;

    // Initialize comms by specifying the # of group calls
    testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks),
                      numCollsPerGroup, numStreamsPerGroup, numGroupCalls, useBlocking);

    if (testBed.ev.showNames)
      INFO("%s %d-ranks GroupCall MultiGroupCall\n", isMultiProcess ? "MP" : "SP", totalRanks);

    // Configure every group call up front, before any of them is launched
    for (int groupCallIdx = 0; groupCallIdx < numGroupCalls; ++groupCallIdx)
    {
      // Read-only view of this group call's collectives (no per-iteration copy)
      std::vector<ncclFunc_t> const& funcTypes = groupCalls[groupCallIdx];

      OptionalColArgs options;
      options.redOp = redops[groupCallIdx];
      options.root  = 0;

      for (int collIdx = 0; collIdx < numCollsPerGroup[groupCallIdx]; ++collIdx)
      {
        // Input/output element counts are derived per collective type and rank count
        int numInputElements;
        int numOutputElements;
        CollectiveArgs::GetNumElementsForFuncType(funcTypes[collIdx],
                                                  numElements[groupCallIdx][collIdx],
                                                  totalRanks,
                                                  &numInputElements,
                                                  &numOutputElements);

        testBed.SetCollectiveArgs(funcTypes[collIdx],
                                  dataTypes[groupCallIdx],
                                  numInputElements,
                                  numOutputElements,
                                  options,
                                  collIdx,
                                  groupCallIdx);
      }

      testBed.AllocateMem(inPlace, useManagedMem, groupCallIdx);
      testBed.PrepareData(groupCallIdx);

      // Stream capture in advance for HIP graph enabled collective groups
      if (useHipGraphList[groupCallIdx])
      {
        testBed.ExecuteCollectives({}, groupCallIdx, useHipGraphList[groupCallIdx]);
      }
    }

    // Execute collectives based on groupIdx
    for (int i = 0; i < numIterations; ++i)
    {
      // Cycle through the group calls in round-robin order (0, 1, 2, 0, ...)
      int const groupCallIdx = i % numGroupCalls;

      // Use graphs if enabled otherwise execute the collective
      if (useHipGraphList[groupCallIdx]) testBed.LaunchGraphs(groupCallIdx);
      else                               testBed.ExecuteCollectives({}, groupCallIdx);
      testBed.ValidateResults(isCorrect, groupCallIdx);
    }

    testBed.DeallocateMem();
    testBed.DestroyGraphs();
    testBed.DestroyComms();
  }
  testBed.Finalize();
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user