Fix rccl test suite to use hip_bf16.h instead of hip_bfloat16.h for the __bf16 intrinsic (#2082)

Este commit está contenido en:
Atul Kulkarni
2025-12-04 10:02:06 -06:00
cometido por GitHub
padre 7c12b0b76b
commit cc6e259a02
Se han modificado 3 ficheros con 13 adiciones y 7 borrados
+1 -1
Ver fichero
@@ -225,7 +225,7 @@ namespace RcclUnitTesting
case ncclFloat32: ss << scalarsPerRank.F4[this->globalRank]; break;
case ncclFloat64: ss << scalarsPerRank.F8[this->globalRank]; break;
case ncclFloat8e5m2: ss << (float)scalarsPerRank.B1[this->globalRank]; break;
-case ncclBfloat16: ss << scalarsPerRank.B2[this->globalRank]; break;
+case ncclBfloat16: ss << (float)scalarsPerRank.B2[this->globalRank]; break;
default: ss << "(UNKNOWN)";
}
ss << " ";
+5 -5
Ver fichero
@@ -202,11 +202,11 @@ namespace RcclUnitTesting
case ncclUint32: valueI = U4[idx]; break;
case ncclInt64: valueI = I8[idx]; break;
case ncclUint64: valueI = U8[idx]; break;
-case ncclFloat8e4m3: valueF = float(F1[idx]); break;
+case ncclFloat8e4m3: valueF = float(F1[idx]); break;
case ncclFloat16: valueF = __half2float(F2[idx]); break;
case ncclFloat32: valueF = F4[idx]; break;
case ncclFloat64: valueF = F8[idx]; break;
-case ncclFloat8e5m2: valueF = float(B1[idx]); break;
+case ncclFloat8e5m2: valueF = float(B1[idx]); break;
case ncclBfloat16: valueF = B2[idx]; break;
default:
ERROR("Unsupported datatype\n");
@@ -274,7 +274,7 @@ namespace RcclUnitTesting
case ncclFloat32: F4[idx] = ReduceOp(op, F4[idx], inputCpu.F4[idx]); break;
case ncclFloat64: F8[idx] = ReduceOp(op, F8[idx], inputCpu.F8[idx]); break;
case ncclFloat8e5m2: B1[idx] = rccl_bfloat8(ReduceOp(op, float(B1[idx]), float(inputCpu.B1[idx]))); break;
-case ncclBfloat16: B2[idx] = ReduceOp(op, B2[idx], inputCpu.B2[idx]); break;
+case ncclBfloat16: B2[idx] = hip_bfloat16(ReduceOp(op, float(B2[idx]), float(inputCpu.B2[idx]))); break;
default:
ERROR("Unsupported datatype\n");
return TEST_FAIL;
@@ -360,7 +360,7 @@ namespace RcclUnitTesting
case ncclUint64:
ERROR("Expected output: %lu. Actual output: %lu at index %lu\n", expected.U8[idx], U8[idx], idx); break;
case ncclFloat8e4m3:
-ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.F1[idx], (float)F1[idx], idx);
+ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.F1[idx], (float)F1[idx], idx); break;
case ncclFloat16:
ERROR("Expected output: %f. Actual output: %f at index %lu\n", __half2float(expected.F2[idx]), __half2float(F2[idx]), idx); break;
case ncclFloat32:
@@ -368,7 +368,7 @@ namespace RcclUnitTesting
case ncclFloat64:
ERROR("Expected output: %lf. Actual output: %lf at index %lu\n", expected.F8[idx], F8[idx], idx); break;
case ncclFloat8e5m2:
-ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B1[idx], (float)B1[idx], idx);
+ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B1[idx], (float)B1[idx], idx); break;
case ncclBfloat16:
ERROR("Expected output: %f. Actual output: %f at index %lu\n", (float)expected.B2[idx], (float)B2[idx], idx); break;
default:
+7 -1
Ver fichero
@@ -8,7 +8,13 @@
#include "ErrCode.hpp"
#include "rccl/rccl.h"
#include "rccl_float8.h"
-#include <hip/hip_bfloat16.h>
+#if ROCM_VERSION >= 60000
+// hip_bf16.h should be used from ROCm 6.0
+#include <hip/hip_bf16.h>
+typedef __hip_bfloat16 hip_bfloat16;
+#else
+#include <hip/hip_bfloat16.h>
+#endif
#include "hip/hip_fp16.h"
namespace RcclUnitTesting