From cc6e259a0208c380d21cf2e7636714720f1263f4 Mon Sep 17 00:00:00 2001 From: Atul Kulkarni Date: Thu, 4 Dec 2025 10:02:06 -0600 Subject: [PATCH] Fix rccl test suite to use hip_bf16.h instead of hip_bfloat16.h for the __bf16 intrinsic (#2082) --- test/common/CollectiveArgs.cpp | 2 +- test/common/PtrUnion.cpp | 10 +++++----- test/common/PtrUnion.hpp | 8 +++++++- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/test/common/CollectiveArgs.cpp b/test/common/CollectiveArgs.cpp index 93fe0d588f..e9a84255a9 100644 --- a/test/common/CollectiveArgs.cpp +++ b/test/common/CollectiveArgs.cpp @@ -225,7 +225,7 @@ namespace RcclUnitTesting case ncclFloat32: ss << scalarsPerRank.F4[this->globalRank]; break; case ncclFloat64: ss << scalarsPerRank.F8[this->globalRank]; break; case ncclFloat8e5m2: ss << (float)scalarsPerRank.B1[this->globalRank]; break; - case ncclBfloat16: ss << scalarsPerRank.B2[this->globalRank]; break; + case ncclBfloat16: ss << (float)scalarsPerRank.B2[this->globalRank]; break; default: ss << "(UNKNOWN)"; } ss << " "; diff --git a/test/common/PtrUnion.cpp b/test/common/PtrUnion.cpp index 9664a41bac..c089ac3221 100644 --- a/test/common/PtrUnion.cpp +++ b/test/common/PtrUnion.cpp @@ -202,11 +202,11 @@ namespace RcclUnitTesting case ncclUint32: valueI = U4[idx]; break; case ncclInt64: valueI = I8[idx]; break; case ncclUint64: valueI = U8[idx]; break; - case ncclFloat8e4m3: valueF = float(F1[idx]); break; + case ncclFloat8e4m3: valueF = float(F1[idx]); break; case ncclFloat16: valueF = __half2float(F2[idx]); break; case ncclFloat32: valueF = F4[idx]; break; case ncclFloat64: valueF = F8[idx]; break; - case ncclFloat8e5m2: valueF = float(B1[idx]); break; + case ncclFloat8e5m2: valueF = float(B1[idx]); break; case ncclBfloat16: valueF = B2[idx]; break; default: ERROR("Unsupported datatype\n"); @@ -274,7 +274,7 @@ namespace RcclUnitTesting case ncclFloat32: F4[idx] = ReduceOp(op, F4[idx], inputCpu.F4[idx]); break; case ncclFloat64: F8[idx] = ReduceOp(op, 
F8[idx], inputCpu.F8[idx]); break; case ncclFloat8e5m2: B1[idx] = rccl_bfloat8(ReduceOp(op, float(B1[idx]), float(inputCpu.B1[idx]))); break; - case ncclBfloat16: B2[idx] = ReduceOp(op, B2[idx], inputCpu.B2[idx]); break; + case ncclBfloat16: B2[idx] = hip_bfloat16(ReduceOp(op, float(B2[idx]), float(inputCpu.B2[idx]))); break; default: ERROR("Unsupported datatype\n"); return TEST_FAIL; @@ -360,7 +360,7 @@ namespace RcclUnitTesting case ncclUint64: ERROR("Expected output: %lu. Actual output: %lu at index %lu\n", expected.U8[idx], U8[idx], idx); break; case ncclFloat8e4m3: - ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.F1[idx], (float)F1[idx], idx); + ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.F1[idx], (float)F1[idx], idx); break; case ncclFloat16: ERROR("Expected output: %f. Actual output: %f at index %lu\n", __half2float(expected.F2[idx]), __half2float(F2[idx]), idx); break; case ncclFloat32: @@ -368,7 +368,7 @@ namespace RcclUnitTesting case ncclFloat64: ERROR("Expected output: %lf. Actual output: %lf at index %lu\n", expected.F8[idx], F8[idx], idx); break; case ncclFloat8e5m2: - ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B1[idx], (float)B1[idx], idx); + ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B1[idx], (float)B1[idx], idx); break; case ncclBfloat16: ERROR("Expected output: %f. 
Actual output: %f at index %lu\n", (float)expected.B2[idx], (float)B2[idx], idx); break; default: diff --git a/test/common/PtrUnion.hpp b/test/common/PtrUnion.hpp index 75c1255d2b..83a9c296bf 100644 --- a/test/common/PtrUnion.hpp +++ b/test/common/PtrUnion.hpp @@ -8,7 +8,13 @@ #include "ErrCode.hpp" #include "rccl/rccl.h" #include "rccl_float8.h" -#include <hip/hip_bfloat16.h> +#if ROCM_VERSION >= 60000 + // hip_bf16.h should be used from ROCm 6.0 + #include <hip/hip_bf16.h> + typedef __hip_bfloat16 hip_bfloat16; +#else + #include <hip/hip_bfloat16.h> +#endif #include "hip/hip_fp16.h" namespace RcclUnitTesting