2
0

Adding FP16 cases to unit tests(#1093)

Signed-off-by: Tim Hu <timhu102@amd.com>
Este cometimento está contido em:
Tim
2024-02-26 12:08:04 -05:00
cometido por GitHub
ascendente 74f9e5db64
cometimento 0d06b0f1de
6 ficheiros modificados com 63 adições e 14 eliminações
+1 -1
Ver ficheiro
@@ -13,7 +13,7 @@ namespace RcclUnitTesting
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollAllGather};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat32};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat32};
std::vector<ncclRedOp_t> const redOps = {ncclSum};
std::vector<int> const roots = {0};
std::vector<int> const numElements = {1048576, 500};
+1 -1
Ver ficheiro
@@ -16,7 +16,7 @@ namespace RcclUnitTesting
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollAllToAll};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat32};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat32};
std::vector<ncclRedOp_t> const redOps = {ncclSum};
std::vector<int> const roots = {0};
std::vector<int> const numElements = {1048576, 1024};
+1 -1
Ver ficheiro
@@ -13,7 +13,7 @@ namespace RcclUnitTesting
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollBroadcast};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat32};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat32};
std::vector<ncclRedOp_t> const redOps = {ncclSum};
std::vector<int> const roots = {0};
std::vector<int> const numElements = {1048576, 500};
+49
Ver ficheiro
@@ -118,6 +118,55 @@ namespace RcclUnitTesting
testBed.Finalize();
}
// Test identical collectives with different data type
TEST(GroupCall, MixedDataType)
{
TestBed testBed;
// Configuration
std::vector<ncclFunc_t> const funcTypes = {ncclCollAllReduce, ncclCollAllReduce, ncclCollAllReduce};
std::vector<ncclRedOp_t> const redOps = {ncclSum, ncclSum, ncclSum};
std::vector<ncclDataType_t> const dataTypes = {ncclFloat16, ncclFloat32, ncclFloat64};
std::vector<int> const numElements = {1048576, 384 * 1024, 384};
int const numCollPerGroup = numElements.size();
bool const inPlace = false;
bool const useManagedMem = false;
bool isCorrect = true;
for (int totalRanks : testBed.ev.GetNumGpusList())
for (int isMultiProcess : testBed.ev.GetIsMultiProcessList())
{
// Test either single process all GPUs, or 1 process per GPU
int const numProcesses = isMultiProcess ? totalRanks : 1;
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks), numCollPerGroup);
if (testBed.ev.showNames)
INFO("%s %d-ranks GroupCall MixedDayaType\n", isMultiProcess ? "MP" : "SP", totalRanks);
// Set up the different collectives within the group
for (int collIdx = 0; collIdx < numCollPerGroup; ++collIdx)
{
OptionalColArgs options;
options.redOp = redOps[collIdx];
testBed.SetCollectiveArgs(funcTypes[collIdx],
dataTypes[collIdx],
numElements[collIdx],
numElements[collIdx],
options,
collIdx);
}
testBed.AllocateMem(inPlace, useManagedMem);
testBed.PrepareData();
testBed.ExecuteCollectives();
testBed.ValidateResults(isCorrect);
testBed.DeallocateMem();
testBed.DestroyComms();
}
testBed.Finalize();
}
TEST(GroupCall, Multistream)
{
TestBed testBed;
+1 -1
Ver ficheiro
@@ -12,7 +12,7 @@ namespace RcclUnitTesting
TestBed testBed;
// Configuration
std::vector<ncclDataType_t> const& dataTypes = {ncclInt32, ncclFloat64};
std::vector<ncclDataType_t> const& dataTypes = {ncclInt32, ncclFloat16, ncclFloat64};
std::vector<int> const numElements = {1048576, 53327, 1024, 0};
bool const inPlace = false;
bool const useManagedMem = false;
+10 -10
Ver ficheiro
@@ -274,7 +274,7 @@ namespace RcclUnitTesting
case ncclFloat16: F2[idx] = __float2half(__half2float(F2[idx])/divisor); break;
case ncclFloat32: F4[idx] /= divisor; break;
case ncclFloat64: F8[idx] /= divisor; break;
case ncclBfloat16: B2[idx] = (rccl_bfloat16((float)(B2[idx]) / divisor)); break;
case ncclBfloat16: B2[idx] = (rccl_bfloat16((float)(B2[idx]) / divisor)); break;
default:
ERROR("Unsupported datatype\n");
return TEST_FAIL;
@@ -295,15 +295,15 @@ namespace RcclUnitTesting
{
switch (dataType)
{
case ncclInt8: isMatch = (I1[idx] == expected.I1[idx]); break;
case ncclUint8: isMatch = (U1[idx] == expected.U1[idx]); break;
case ncclInt32: isMatch = (I4[idx] == expected.I4[idx]); break;
case ncclUint32: isMatch = (U4[idx] == expected.U4[idx]); break;
case ncclInt64: isMatch = (I8[idx] == expected.I8[idx]); break;
case ncclUint64: isMatch = (U8[idx] == expected.U8[idx]); break;
case ncclFloat16: isMatch = (fabs(__half2float(F2[idx]) - __half2float(expected.F2[idx])) < 9e-2); break;
case ncclFloat32: isMatch = (fabs(F4[idx] - expected.F4[idx]) < 1e-5); break;
case ncclFloat64: isMatch = (fabs(F8[idx] - expected.F8[idx]) < 1e-12); break;
case ncclInt8: isMatch = (I1[idx] == expected.I1[idx]); break;
case ncclUint8: isMatch = (U1[idx] == expected.U1[idx]); break;
case ncclInt32: isMatch = (I4[idx] == expected.I4[idx]); break;
case ncclUint32: isMatch = (U4[idx] == expected.U4[idx]); break;
case ncclInt64: isMatch = (I8[idx] == expected.I8[idx]); break;
case ncclUint64: isMatch = (U8[idx] == expected.U8[idx]); break;
case ncclFloat16: isMatch = (fabs(__half2float(F2[idx]) - __half2float(expected.F2[idx])) < 9e-2); break;
case ncclFloat32: isMatch = (fabs(F4[idx] - expected.F4[idx]) < 1e-5); break;
case ncclFloat64: isMatch = (fabs(F8[idx] - expected.F8[idx]) < 1e-12); break;
case ncclBfloat16: isMatch = (fabs((float)B2[idx] - (float)expected.B2[idx]) < 9e-2); break;
default:
ERROR("Unsupported datatype\n");