From 39483c55f874c8ab2aaa26dae11158ffda7552fd Mon Sep 17 00:00:00 2001 From: mberenjk <146776561+mberenjk@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:39:02 -0600 Subject: [PATCH] =?UTF-8?q?Initializing=20all=20ranks=20to=20the=20same=20?= =?UTF-8?q?value=20to=20avoid=20failure=20of=20=20UT=20AllR=E2=80=A6=20(#1?= =?UTF-8?q?459)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Initializing all ranks to the same value to avoid failure of UT AllReduce for FP8 type Co-authored-by: Marzieh Berenjkoub --- test/AllReduceTests.cpp | 2 +- test/common/PtrUnion.cpp | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/AllReduceTests.cpp b/test/AllReduceTests.cpp index ac393dfe8e..fe9579ec55 100644 --- a/test/AllReduceTests.cpp +++ b/test/AllReduceTests.cpp @@ -13,7 +13,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollAllReduce}; - std::vector const dataTypes = {ncclFloat32}; + std::vector const dataTypes = {ncclFloat32, ncclFp8E4M3, ncclFp8E5M2}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {393216, 384}; diff --git a/test/common/PtrUnion.cpp b/test/common/PtrUnion.cpp index 7ed1558f1e..facf60b342 100644 --- a/test/common/PtrUnion.cpp +++ b/test/common/PtrUnion.cpp @@ -148,7 +148,9 @@ namespace RcclUnitTesting for (int i = 0; i < numElements; i++) { - int valueI = (globalRank + i) % 256; + // Due to floating-point math not being commutative, the ordering in which ranks are added will matter. + // For lower-precision data types, we initialize all ranks to the same value to avoid this + int valueI = (dataType == ncclFp8E4M3 || dataType == ncclFp8E5M2)? (i % 16) :(globalRank + i) % 256; double valueF = 1.0L/((double)valueI+1.0L); temp.Set(dataType, i, valueI, valueF); }