Added alltoallv test and optional args variable on collective args (#514)

* Added alltoallv test and optional args variable on collective args
这个提交包含在:
akolliasAMD
2022-03-18 13:55:11 -04:00
提交者 GitHub
父节点 a04da71647
当前提交 65ea3d80db
修改 13 个文件,包含 284 行新增和 154 行删除
+56 -16
查看文件
@@ -23,6 +23,7 @@ namespace RcclUnitTesting
case ncclCollGather: return DefaultPrepData_Gather(collArgs, false);
case ncclCollScatter: return DefaultPrepData_Scatter(collArgs);
case ncclCollAllToAll: return DefaultPrepData_AllToAll(collArgs);
case ncclCollAllToAllv: return DefaultPrepData_AllToAllv(collArgs);
case ncclCollSend: return DefaultPrepData_Send(collArgs);
case ncclCollRecv: return DefaultPrepData_Recv(collArgs);
default:
@@ -64,15 +65,15 @@ namespace RcclUnitTesting
CHECK_CALL(collArgs.outputGpu.ClearGpuMem(numBytes));
// Only root needs input pattern
if (collArgs.globalRank == collArgs.root)
if (collArgs.globalRank == collArgs.options.root)
CHECK_CALL(collArgs.inputGpu.FillPattern(collArgs.dataType,
collArgs.numInputElements,
collArgs.root, true));
collArgs.options.root, true));
// Otherwise all other ranks expected output is the same as input of root
return collArgs.expected.FillPattern(collArgs.dataType,
collArgs.numInputElements,
collArgs.root,
collArgs.options.root,
false);
}
@@ -96,11 +97,11 @@ namespace RcclUnitTesting
CHECK_CALL(result.ClearCpuMem(numBytes));
// If average or custom reduction operator is used, perform a summation instead
ncclRedOp_t const tempOp = (collArgs.redOp >= ncclAvg ? ncclSum : collArgs.redOp);
ncclRedOp_t const tempOp = (collArgs.options.redOp >= ncclAvg ? ncclSum : collArgs.options.redOp);
// Loop over each rank and generate their input into a temp buffer, then reduce
PtrUnion scalarsPerRank;
scalarsPerRank.Attach(collArgs.scalarTransport.ptr);
scalarsPerRank.Attach(collArgs.options.scalarTransport.ptr);
PtrUnion tempInputCpu;
CHECK_CALL(tempInputCpu.Attach(collArgs.outputCpu));
@@ -117,14 +118,14 @@ namespace RcclUnitTesting
// Scale the temporary input by local scalar for this rank
// (Used by custom reduction ops)
if (collArgs.scalarMode >= 0)
if (collArgs.options.scalarMode >= 0)
{
CHECK_CALL(tempInputCpu.Scale(collArgs.dataType, collArgs.numInputElements,
scalarsPerRank, rank));
}
// Any rank that requires output reduces the scaled-inputs
if (isAllReduce || collArgs.root == collArgs.globalRank)
if (isAllReduce || collArgs.options.root == collArgs.globalRank)
{
if (rank == 0)
{
@@ -139,7 +140,7 @@ namespace RcclUnitTesting
}
// Perform averaging if necessary
if (collArgs.redOp == ncclAvg && (isAllReduce || collArgs.root == collArgs.globalRank))
if (collArgs.options.redOp == ncclAvg && (isAllReduce || collArgs.options.root == collArgs.globalRank))
{
CHECK_CALL(result.DivideByInt(collArgs.dataType, collArgs.numInputElements, collArgs.totalRanks));
}
@@ -176,7 +177,7 @@ namespace RcclUnitTesting
{
CHECK_HIP(hipMemcpy(collArgs.inputGpu.ptr, tempInputCpu.ptr, numInputBytes, hipMemcpyHostToDevice));
}
if (isAllGather || collArgs.root == collArgs.globalRank)
if (isAllGather || collArgs.options.root == collArgs.globalRank)
{
memcpy(result.I1 + (rank * numInputBytes), tempInputCpu.ptr, numInputBytes);
}
@@ -207,11 +208,11 @@ namespace RcclUnitTesting
CHECK_CALL(tempResultCpu.ClearCpuMem(numInputBytes));
// If average or custom reduction operator is used, perform a summation instead
ncclRedOp_t const tempOp = (collArgs.redOp >= ncclAvg ? ncclSum : collArgs.redOp);
ncclRedOp_t const tempOp = (collArgs.options.redOp >= ncclAvg ? ncclSum : collArgs.options.redOp);
// Loop over each rank and generate the input / scale / reduce
PtrUnion scalarsPerRank;
scalarsPerRank.Attach(collArgs.scalarTransport.ptr);
scalarsPerRank.Attach(collArgs.options.scalarTransport.ptr);
for (int rank = 0; rank < collArgs.totalRanks; ++rank)
{
CHECK_CALL(tempInputCpu.FillPattern(collArgs.dataType, collArgs.numInputElements, rank, false));
@@ -229,7 +230,7 @@ namespace RcclUnitTesting
// Scale the temporary input by local scalar for this rank
// (Used by custom reduction ops)
if (collArgs.scalarMode >= 0)
if (collArgs.options.scalarMode >= 0)
{
CHECK_CALL(tempInputCpu.Scale(collArgs.dataType, collArgs.numInputElements,
scalarsPerRank, rank));
@@ -247,7 +248,7 @@ namespace RcclUnitTesting
}
// Perform averaging if necessary
if (collArgs.redOp == ncclAvg)
if (collArgs.options.redOp == ncclAvg)
{
CHECK_CALL(tempResultCpu.DivideByInt(collArgs.dataType, collArgs.numInputElements, collArgs.totalRanks));
}
@@ -279,10 +280,10 @@ namespace RcclUnitTesting
// Generate input as if on root rank - each rank will receive a portion
PtrUnion tempInput;
tempInput.AllocateCpuMem(numInputBytes);
tempInput.FillPattern(collArgs.dataType, collArgs.numInputElements, collArgs.root, false);
tempInput.FillPattern(collArgs.dataType, collArgs.numInputElements, collArgs.options.root, false);
// Copy input to root rank
if (collArgs.globalRank == collArgs.root)
if (collArgs.globalRank == collArgs.options.root)
{
if (hipMemcpy(collArgs.inputGpu.ptr, tempInput.ptr, numInputBytes, hipMemcpyHostToDevice) != hipSuccess)
{
@@ -341,6 +342,45 @@ namespace RcclUnitTesting
return TEST_SUCCESS;
}
/// Prepares input and expected-output data for an AllToAllv collective on this rank.
///
/// The count/displacement tables in collArgs.options (sendcounts, sdispls,
/// recvcounts, rdispls) are read as flattened [rank * totalRanks + peer] tables
/// whose entries are expressed in elements of collArgs.dataType.
///
/// Steps performed:
///   1. Clear the GPU output buffer.
///   2. For every sending rank, regenerate that rank's input pattern on the host
///      and copy the slice destined for this rank into collArgs.expected.
///   3. Generate this rank's own input pattern and upload it to collArgs.inputGpu.
///
/// @param collArgs Collective arguments; outputGpu, expected and inputGpu are written.
/// @return TEST_SUCCESS on success, otherwise the error reported by the failing step.
ErrCode DefaultPrepData_AllToAllv(CollectiveArgs &collArgs)
{
  CHECK_CALL(CheckAllocation(collArgs));

  size_t const bytesPerElement = DataTypeToBytes(collArgs.dataType);
  size_t const numInputBytes   = collArgs.numInputElements * bytesPerElement;
  size_t const numOutputBytes  = collArgs.numOutputElements * bytesPerElement;

  // Largest (displacement + count) over every sender/receiver pair, in elements.
  // This is how many pattern elements must be generated per sender so that any
  // slice referenced by the tables lies inside the staging buffer.
  size_t maxNumElements = 0;
  for (int sendRank = 0; sendRank < collArgs.totalRanks; ++sendRank)
  {
    for (int recvRank = 0; recvRank < collArgs.totalRanks; ++recvRank)
    {
      int const idx = sendRank * collArgs.totalRanks + recvRank;
      size_t const rankSendCount = collArgs.options.sdispls[idx] + collArgs.options.sendcounts[idx];
      maxNumElements = std::max(maxNumElements, rankSendCount);
    }
  }

  // Clear outputs on all ranks (prior to input in case of in-place)
  CHECK_CALL(collArgs.outputGpu.ClearGpuMem(numOutputBytes));

  // The staging buffer is reused for this rank's own input further down, so it
  // must hold numInputElements as well; sizing it from maxNumElements alone
  // overflows whenever the tables do not span the whole input buffer.
  size_t const stagingElements = std::max<size_t>(maxNumElements, collArgs.numInputElements);
  PtrUnion tempInput;
  CHECK_CALL(tempInput.AllocateCpuMem(stagingElements * bytesPerElement));

  // From here on, failures are recorded rather than returned immediately so
  // that the staging buffer is always released before leaving the function.
  ErrCode prepResult = TEST_SUCCESS;

  // Build the expected output: regenerate each sender's input and copy the
  // portion it sends to this rank to the receive displacement for that sender.
  for (int sendRank = 0; sendRank < collArgs.totalRanks; ++sendRank)
  {
    prepResult = tempInput.FillPattern(collArgs.dataType, maxNumElements, sendRank, false);
    if (prepResult != TEST_SUCCESS) break;

    size_t const recvDspls = collArgs.options.rdispls[collArgs.globalRank * collArgs.totalRanks + sendRank] * bytesPerElement;
    size_t const sendDspls = collArgs.options.sdispls[sendRank * collArgs.totalRanks + collArgs.globalRank] * bytesPerElement;
    size_t const numBytes  = collArgs.options.recvcounts[collArgs.globalRank * collArgs.totalRanks + sendRank] * bytesPerElement;
    memcpy(collArgs.expected.U1 + recvDspls, tempInput.U1 + sendDspls, numBytes);
  }

  // Generate this rank's own input and upload it to the GPU input buffer.
  if (prepResult == TEST_SUCCESS)
  {
    prepResult = tempInput.FillPattern(collArgs.dataType, collArgs.numInputElements, collArgs.globalRank, false);
  }

  hipError_t hipCopyResult = hipSuccess;
  if (prepResult == TEST_SUCCESS)
  {
    hipCopyResult = hipMemcpy(collArgs.inputGpu.ptr, tempInput.ptr, numInputBytes, hipMemcpyHostToDevice);
  }

  // Release the staging buffer before any error is reported.
  tempInput.FreeCpuMem();

  if (prepResult != TEST_SUCCESS) return prepResult;
  CHECK_HIP(hipCopyResult);
  return TEST_SUCCESS;
}
ErrCode DefaultPrepData_Send(CollectiveArgs &collArgs)
{
CHECK_CALL(CheckAllocation(collArgs));
@@ -354,7 +394,7 @@ namespace RcclUnitTesting
CHECK_CALL(CheckAllocation(collArgs));
return collArgs.expected.FillPattern(collArgs.dataType,
collArgs.numOutputElements,
collArgs.root,
collArgs.options.root,
false);
}
}