Adding FP16 cases to unit tests(#1093)
Signed-off-by: Tim Hu <timhu102@amd.com>
Este cometimento está contido em:
@@ -13,7 +13,7 @@ namespace RcclUnitTesting
|
||||
|
||||
// Configuration
|
||||
std::vector<ncclFunc_t> const funcTypes = {ncclCollAllGather};
|
||||
std::vector<ncclDataType_t> const dataTypes = {ncclFloat32};
|
||||
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat32};
|
||||
std::vector<ncclRedOp_t> const redOps = {ncclSum};
|
||||
std::vector<int> const roots = {0};
|
||||
std::vector<int> const numElements = {1048576, 500};
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace RcclUnitTesting
|
||||
|
||||
// Configuration
|
||||
std::vector<ncclFunc_t> const funcTypes = {ncclCollAllToAll};
|
||||
std::vector<ncclDataType_t> const dataTypes = {ncclFloat32};
|
||||
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat32};
|
||||
std::vector<ncclRedOp_t> const redOps = {ncclSum};
|
||||
std::vector<int> const roots = {0};
|
||||
std::vector<int> const numElements = {1048576, 1024};
|
||||
|
||||
@@ -13,7 +13,7 @@ namespace RcclUnitTesting
|
||||
|
||||
// Configuration
|
||||
std::vector<ncclFunc_t> const funcTypes = {ncclCollBroadcast};
|
||||
std::vector<ncclDataType_t> const dataTypes = {ncclFloat32};
|
||||
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat32};
|
||||
std::vector<ncclRedOp_t> const redOps = {ncclSum};
|
||||
std::vector<int> const roots = {0};
|
||||
std::vector<int> const numElements = {1048576, 500};
|
||||
|
||||
@@ -118,6 +118,55 @@ namespace RcclUnitTesting
|
||||
testBed.Finalize();
|
||||
}
|
||||
|
||||
// Test identical collectives with different data type
|
||||
TEST(GroupCall, MixedDataType)
|
||||
{
|
||||
TestBed testBed;
|
||||
|
||||
// Configuration
|
||||
std::vector<ncclFunc_t> const funcTypes = {ncclCollAllReduce, ncclCollAllReduce, ncclCollAllReduce};
|
||||
std::vector<ncclRedOp_t> const redOps = {ncclSum, ncclSum, ncclSum};
|
||||
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat32, ncclFloat64};
|
||||
std::vector<int> const numElements = {1048576, 384 * 1024, 384};
|
||||
|
||||
int const numCollPerGroup = numElements.size();
|
||||
bool const inPlace = false;
|
||||
bool const useManagedMem = false;
|
||||
|
||||
bool isCorrect = true;
|
||||
for (int totalRanks : testBed.ev.GetNumGpusList())
|
||||
for (int isMultiProcess : testBed.ev.GetIsMultiProcessList())
|
||||
{
|
||||
// Test either single process all GPUs, or 1 process per GPU
|
||||
int const numProcesses = isMultiProcess ? totalRanks : 1;
|
||||
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks), numCollPerGroup);
|
||||
|
||||
if (testBed.ev.showNames)
|
||||
INFO("%s %d-ranks GroupCall MixedDayaType\n", isMultiProcess ? "MP" : "SP", totalRanks);
|
||||
|
||||
// Set up the different collectives within the group
|
||||
for (int collIdx = 0; collIdx < numCollPerGroup; ++collIdx)
|
||||
{
|
||||
OptionalColArgs options;
|
||||
options.redOp = redOps[collIdx];
|
||||
testBed.SetCollectiveArgs(funcTypes[collIdx],
|
||||
dataTypes[collIdx],
|
||||
numElements[collIdx],
|
||||
numElements[collIdx],
|
||||
options,
|
||||
collIdx);
|
||||
}
|
||||
|
||||
testBed.AllocateMem(inPlace, useManagedMem);
|
||||
testBed.PrepareData();
|
||||
testBed.ExecuteCollectives();
|
||||
testBed.ValidateResults(isCorrect);
|
||||
testBed.DeallocateMem();
|
||||
testBed.DestroyComms();
|
||||
}
|
||||
testBed.Finalize();
|
||||
}
|
||||
|
||||
TEST(GroupCall, Multistream)
|
||||
{
|
||||
TestBed testBed;
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace RcclUnitTesting
|
||||
TestBed testBed;
|
||||
|
||||
// Configuration
|
||||
std::vector<ncclDataType_t> const& dataTypes = {ncclInt32, ncclFloat64};
|
||||
std::vector<ncclDataType_t> const& dataTypes = {ncclInt32, ncclFloat16, ncclFloat64};
|
||||
std::vector<int> const numElements = {1048576, 53327, 1024, 0};
|
||||
bool const inPlace = false;
|
||||
bool const useManagedMem = false;
|
||||
|
||||
+10
-10
@@ -274,7 +274,7 @@ namespace RcclUnitTesting
|
||||
case ncclFloat16: F2[idx] = __float2half(__half2float(F2[idx])/divisor); break;
|
||||
case ncclFloat32: F4[idx] /= divisor; break;
|
||||
case ncclFloat64: F8[idx] /= divisor; break;
|
||||
case ncclBfloat16: B2[idx] = (rccl_bfloat16((float)(B2[idx]) / divisor)); break;
|
||||
case ncclBfloat16: B2[idx] = (rccl_bfloat16((float)(B2[idx]) / divisor)); break;
|
||||
default:
|
||||
ERROR("Unsupported datatype\n");
|
||||
return TEST_FAIL;
|
||||
@@ -295,15 +295,15 @@ namespace RcclUnitTesting
|
||||
{
|
||||
switch (dataType)
|
||||
{
|
||||
case ncclInt8: isMatch = (I1[idx] == expected.I1[idx]); break;
|
||||
case ncclUint8: isMatch = (U1[idx] == expected.U1[idx]); break;
|
||||
case ncclInt32: isMatch = (I4[idx] == expected.I4[idx]); break;
|
||||
case ncclUint32: isMatch = (U4[idx] == expected.U4[idx]); break;
|
||||
case ncclInt64: isMatch = (I8[idx] == expected.I8[idx]); break;
|
||||
case ncclUint64: isMatch = (U8[idx] == expected.U8[idx]); break;
|
||||
case ncclFloat16: isMatch = (fabs(__half2float(F2[idx]) - __half2float(expected.F2[idx])) < 9e-2); break;
|
||||
case ncclFloat32: isMatch = (fabs(F4[idx] - expected.F4[idx]) < 1e-5); break;
|
||||
case ncclFloat64: isMatch = (fabs(F8[idx] - expected.F8[idx]) < 1e-12); break;
|
||||
case ncclInt8: isMatch = (I1[idx] == expected.I1[idx]); break;
|
||||
case ncclUint8: isMatch = (U1[idx] == expected.U1[idx]); break;
|
||||
case ncclInt32: isMatch = (I4[idx] == expected.I4[idx]); break;
|
||||
case ncclUint32: isMatch = (U4[idx] == expected.U4[idx]); break;
|
||||
case ncclInt64: isMatch = (I8[idx] == expected.I8[idx]); break;
|
||||
case ncclUint64: isMatch = (U8[idx] == expected.U8[idx]); break;
|
||||
case ncclFloat16: isMatch = (fabs(__half2float(F2[idx]) - __half2float(expected.F2[idx])) < 9e-2); break;
|
||||
case ncclFloat32: isMatch = (fabs(F4[idx] - expected.F4[idx]) < 1e-5); break;
|
||||
case ncclFloat64: isMatch = (fabs(F8[idx] - expected.F8[idx]) < 1e-12); break;
|
||||
case ncclBfloat16: isMatch = (fabs((float)B2[idx] - (float)expected.B2[idx]) < 9e-2); break;
|
||||
default:
|
||||
ERROR("Unsupported datatype\n");
|
||||
|
||||
Criar uma nova questão referindo esta
Bloquear um utilizador