diff --git a/projects/rccl/src/device/reduce_kernel.h b/projects/rccl/src/device/reduce_kernel.h index 593f868eae..b0635f2f21 100755 --- a/projects/rccl/src/device/reduce_kernel.h +++ b/projects/rccl/src/device/reduce_kernel.h @@ -414,7 +414,7 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 1, __nv_bfloat16, fn.isMinNotMax ? __hmin(x, y) : __hmax(x, y)) // coverity[copy_constructor_call] SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 2, __nv_bfloat162, fn.isMinNotMax ? __hmin2(x, y) : __hmax2(x, y)) -#else +#elif ROCM_VERSION < 60000 SPECIALIZE_REDUCE(FuncSum, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) + (float)(y))) SPECIALIZE_REDUCE(FuncProd, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) * (float)(y))) SPECIALIZE_REDUCE(FuncMinMax, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)(fn.isMinNotMax ? fminf((float)(x), (float)(y)) : fmaxf((float)(x), (float)(y)))) diff --git a/projects/rccl/src/include/device.h b/projects/rccl/src/include/device.h index d42b43e3a1..0cf756974c 100644 --- a/projects/rccl/src/include/device.h +++ b/projects/rccl/src/include/device.h @@ -11,7 +11,13 @@ #include "nccl.h" #include "rccl_float8.h" -#include <hip/hip_bfloat16.h> +#if ROCM_VERSION >= 60000 + // hip_bf16.h should be used from ROCm 6.0 + #include <hip/hip_bf16.h> + typedef __hip_bfloat16 hip_bfloat16; +#else + #include <hip/hip_bfloat16.h> +#endif #include "nccl_common.h" #include "bitops.h" #include "symmetric.h"