diff --git a/test/AllGatherTests.cpp b/test/AllGatherTests.cpp index edb0f31052..1be1a7f42e 100644 --- a/test/AllGatherTests.cpp +++ b/test/AllGatherTests.cpp @@ -13,7 +13,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollAllGather}; - std::vector const dataTypes = {ncclFloat32}; + std::vector const dataTypes = {ncclFloat16, ncclFloat32}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {1048576, 500}; diff --git a/test/AllToAllTests.cpp b/test/AllToAllTests.cpp index b006ef8870..5c8ea0b420 100644 --- a/test/AllToAllTests.cpp +++ b/test/AllToAllTests.cpp @@ -16,7 +16,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollAllToAll}; - std::vector const dataTypes = {ncclFloat32}; + std::vector const dataTypes = {ncclFloat16, ncclFloat32}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {1048576, 1024}; diff --git a/test/BroadcastTests.cpp b/test/BroadcastTests.cpp index a50c879c9e..d925ef4d68 100644 --- a/test/BroadcastTests.cpp +++ b/test/BroadcastTests.cpp @@ -13,7 +13,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollBroadcast}; - std::vector const dataTypes = {ncclFloat32}; + std::vector const dataTypes = {ncclFloat16, ncclFloat32}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {1048576, 500}; diff --git a/test/GroupCallTests.cpp b/test/GroupCallTests.cpp index 52ca200203..6ed0ff7f82 100644 --- a/test/GroupCallTests.cpp +++ b/test/GroupCallTests.cpp @@ -118,6 +118,55 @@ namespace RcclUnitTesting testBed.Finalize(); } + // Test identical collectives with different data types + TEST(GroupCall, MixedDataType) + { + TestBed testBed; + + // Configuration + std::vector const funcTypes = {ncclCollAllReduce, ncclCollAllReduce, ncclCollAllReduce}; + std::vector const redOps = {ncclSum, ncclSum, ncclSum}; + 
std::vector const dataTypes = {ncclFloat16, ncclFloat32, ncclFloat64}; + std::vector const numElements = {1048576, 384 * 1024, 384}; + + int const numCollPerGroup = numElements.size(); + bool const inPlace = false; + bool const useManagedMem = false; + + bool isCorrect = true; + for (int totalRanks : testBed.ev.GetNumGpusList()) + for (int isMultiProcess : testBed.ev.GetIsMultiProcessList()) + { + // Test either single process all GPUs, or 1 process per GPU + int const numProcesses = isMultiProcess ? totalRanks : 1; + testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks), numCollPerGroup); + + if (testBed.ev.showNames) + INFO("%s %d-ranks GroupCall MixedDataType\n", isMultiProcess ? "MP" : "SP", totalRanks); + + // Set up the different collectives within the group + for (int collIdx = 0; collIdx < numCollPerGroup; ++collIdx) + { + OptionalColArgs options; + options.redOp = redOps[collIdx]; + testBed.SetCollectiveArgs(funcTypes[collIdx], + dataTypes[collIdx], + numElements[collIdx], + numElements[collIdx], + options, + collIdx); + } + + testBed.AllocateMem(inPlace, useManagedMem); + testBed.PrepareData(); + testBed.ExecuteCollectives(); + testBed.ValidateResults(isCorrect); + testBed.DeallocateMem(); + testBed.DestroyComms(); + } + testBed.Finalize(); + } + TEST(GroupCall, Multistream) { TestBed testBed; diff --git a/test/SendRecvTests.cpp b/test/SendRecvTests.cpp index 4b0d9eef1b..53c603e4c1 100644 --- a/test/SendRecvTests.cpp +++ b/test/SendRecvTests.cpp @@ -12,7 +12,7 @@ namespace RcclUnitTesting TestBed testBed; // Configuration - std::vector const& dataTypes = {ncclInt32, ncclFloat64}; + std::vector const& dataTypes = {ncclInt32, ncclFloat16, ncclFloat64}; std::vector const numElements = {1048576, 53327, 1024, 0}; bool const inPlace = false; bool const useManagedMem = false; diff --git a/test/common/PtrUnion.cpp b/test/common/PtrUnion.cpp index d06afa8194..7602ea2e7d 100644 --- a/test/common/PtrUnion.cpp +++ b/test/common/PtrUnion.cpp @@ -274,7 
+274,7 @@ namespace RcclUnitTesting case ncclFloat16: F2[idx] = __float2half(__half2float(F2[idx])/divisor); break; case ncclFloat32: F4[idx] /= divisor; break; case ncclFloat64: F8[idx] /= divisor; break; - case ncclBfloat16: B2[idx] = (rccl_bfloat16((float)(B2[idx]) / divisor)); break; + case ncclBfloat16: B2[idx] = (rccl_bfloat16((float)(B2[idx]) / divisor)); break; default: ERROR("Unsupported datatype\n"); return TEST_FAIL; @@ -295,15 +295,15 @@ namespace RcclUnitTesting { switch (dataType) { - case ncclInt8: isMatch = (I1[idx] == expected.I1[idx]); break; - case ncclUint8: isMatch = (U1[idx] == expected.U1[idx]); break; - case ncclInt32: isMatch = (I4[idx] == expected.I4[idx]); break; - case ncclUint32: isMatch = (U4[idx] == expected.U4[idx]); break; - case ncclInt64: isMatch = (I8[idx] == expected.I8[idx]); break; - case ncclUint64: isMatch = (U8[idx] == expected.U8[idx]); break; - case ncclFloat16: isMatch = (fabs(__half2float(F2[idx]) - __half2float(expected.F2[idx])) < 9e-2); break; - case ncclFloat32: isMatch = (fabs(F4[idx] - expected.F4[idx]) < 1e-5); break; - case ncclFloat64: isMatch = (fabs(F8[idx] - expected.F8[idx]) < 1e-12); break; + case ncclInt8: isMatch = (I1[idx] == expected.I1[idx]); break; + case ncclUint8: isMatch = (U1[idx] == expected.U1[idx]); break; + case ncclInt32: isMatch = (I4[idx] == expected.I4[idx]); break; + case ncclUint32: isMatch = (U4[idx] == expected.U4[idx]); break; + case ncclInt64: isMatch = (I8[idx] == expected.I8[idx]); break; + case ncclUint64: isMatch = (U8[idx] == expected.U8[idx]); break; + case ncclFloat16: isMatch = (fabs(__half2float(F2[idx]) - __half2float(expected.F2[idx])) < 9e-2); break; + case ncclFloat32: isMatch = (fabs(F4[idx] - expected.F4[idx]) < 1e-5); break; + case ncclFloat64: isMatch = (fabs(F8[idx] - expected.F8[idx]) < 1e-12); break; case ncclBfloat16: isMatch = (fabs((float)B2[idx] - (float)expected.B2[idx]) < 9e-2); break; default: ERROR("Unsupported datatype\n");