// Review notes for the patch below (the patch text itself is unchanged).
//
// What the patch does, per file:
//  * CollectiveArgs.cpp: in the ncclBfloat16 case, the per-rank scalar B2[globalRank] is cast to
//    float before being streamed into `ss`, the same way the ncclFloat8e5m2 case above it does.
//  * PtrUnion.cpp: the ncclBfloat16 reduction converts both operands to float, calls ReduceOp on
//    the floats, and wraps the result in hip_bfloat16(...), mirroring the ncclFloat8e5m2 /
//    rccl_bfloat8 case directly above it. The remaining -/+ pairs in this file (ncclFloat8e4m3 and
//    ncclFloat8e5m2 lines) read identically in this copy; presumably whitespace/alignment-only
//    changes -- confirm against the original patch.
//  * PtrUnion.hpp: a single #include is replaced by an #if on ROCM_VERSION >= 60000. The true
//    branch includes a header (hip_bf16.h according to the in-patch comment) and typedefs
//    __hip_bfloat16 as hip_bfloat16; the #else branch includes a header as before.
//
// NOTE(review): the #include targets in the PtrUnion.hpp hunk are missing in this copy (looks like
//   the <...> part was stripped) -- restore them before trying to apply the patch.
// NOTE(review): `case ncclBfloat16: valueF = B2[idx];` in the first PtrUnion.cpp hunk is unchanged
//   context and still relies on an implicit conversion, unlike the sites this patch makes
//   explicit -- confirm it still compiles when hip_bfloat16 is __hip_bfloat16.
// NOTE(review): an identifier that is not defined evaluates to 0 inside #if, so if ROCM_VERSION is
//   not defined before PtrUnion.hpp is preprocessed the #else branch is selected silently --
//   confirm the build always defines it.
diff --git a/projects/rccl/test/common/CollectiveArgs.cpp b/projects/rccl/test/common/CollectiveArgs.cpp index 93fe0d588f..e9a84255a9 100644 --- a/projects/rccl/test/common/CollectiveArgs.cpp +++ b/projects/rccl/test/common/CollectiveArgs.cpp @@ -225,7 +225,7 @@ namespace RcclUnitTesting case ncclFloat32: ss << scalarsPerRank.F4[this->globalRank]; break; case ncclFloat64: ss << scalarsPerRank.F8[this->globalRank]; break; case ncclFloat8e5m2: ss << (float)scalarsPerRank.B1[this->globalRank]; break; - case ncclBfloat16: ss << scalarsPerRank.B2[this->globalRank]; break; + case ncclBfloat16: ss << (float)scalarsPerRank.B2[this->globalRank]; break; default: ss << "(UNKNOWN)"; } ss << " "; diff --git a/projects/rccl/test/common/PtrUnion.cpp b/projects/rccl/test/common/PtrUnion.cpp index 9664a41bac..c089ac3221 100644 --- a/projects/rccl/test/common/PtrUnion.cpp +++ b/projects/rccl/test/common/PtrUnion.cpp @@ -202,11 +202,11 @@ namespace RcclUnitTesting case ncclUint32: valueI = U4[idx]; break; case ncclInt64: valueI = I8[idx]; break; case ncclUint64: valueI = U8[idx]; break; - case ncclFloat8e4m3: valueF = float(F1[idx]); break; + case ncclFloat8e4m3: valueF = float(F1[idx]); break; case ncclFloat16: valueF = __half2float(F2[idx]); break; case ncclFloat32: valueF = F4[idx]; break; case ncclFloat64: valueF = F8[idx]; break; - case ncclFloat8e5m2: valueF = float(B1[idx]); break; + case ncclFloat8e5m2: valueF = float(B1[idx]); break; case ncclBfloat16: valueF = B2[idx]; break; default: ERROR("Unsupported datatype\n"); @@ -274,7 +274,7 @@ namespace RcclUnitTesting case ncclFloat32: F4[idx] = ReduceOp(op, F4[idx], inputCpu.F4[idx]); break; case ncclFloat64: F8[idx] = ReduceOp(op, F8[idx], inputCpu.F8[idx]); break; case ncclFloat8e5m2: B1[idx] = rccl_bfloat8(ReduceOp(op, float(B1[idx]), float(inputCpu.B1[idx]))); break; - case ncclBfloat16: B2[idx] = ReduceOp(op, B2[idx], inputCpu.B2[idx]); break; + case ncclBfloat16: B2[idx] = hip_bfloat16(ReduceOp(op, float(B2[idx]), 
float(inputCpu.B2[idx]))); break; default: ERROR("Unsupported datatype\n"); return TEST_FAIL; @@ -360,7 +360,7 @@ namespace RcclUnitTesting case ncclUint64: ERROR("Expected output: %lu. Actual output: %lu at index %lu\n", expected.U8[idx], U8[idx], idx); break; case ncclFloat8e4m3: - ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.F1[idx], (float)F1[idx], idx); + ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.F1[idx], (float)F1[idx], idx); break; case ncclFloat16: ERROR("Expected output: %f. Actual output: %f at index %lu\n", __half2float(expected.F2[idx]), __half2float(F2[idx]), idx); break; case ncclFloat32: @@ -368,7 +368,7 @@ namespace RcclUnitTesting case ncclFloat64: ERROR("Expected output: %lf. Actual output: %lf at index %lu\n", expected.F8[idx], F8[idx], idx); break; case ncclFloat8e5m2: - ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B1[idx], (float)B1[idx], idx); + ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B1[idx], (float)B1[idx], idx); break; case ncclBfloat16: ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B2[idx], (float)B2[idx], idx); break; default: diff --git a/projects/rccl/test/common/PtrUnion.hpp b/projects/rccl/test/common/PtrUnion.hpp index 75c1255d2b..83a9c296bf 100644 --- a/projects/rccl/test/common/PtrUnion.hpp +++ b/projects/rccl/test/common/PtrUnion.hpp @@ -8,7 +8,13 @@ #include "ErrCode.hpp" #include "rccl/rccl.h" #include "rccl_float8.h" -#include +#if ROCM_VERSION >= 60000 + // hip_bf16.h should be used from ROCm 6.0 + #include + typedef __hip_bfloat16 hip_bfloat16; +#else + #include +#endif #include "hip/hip_fp16.h" namespace RcclUnitTesting