diff --git a/projects/rccl/src/include/device.h b/projects/rccl/src/include/device.h index 6944f3921b..f722f1fe21 100644 --- a/projects/rccl/src/include/device.h +++ b/projects/rccl/src/include/device.h @@ -12,9 +12,16 @@ #include "nccl.h" #include "rccl_float8.h" #if ROCM_VERSION >= 60000 - // hip_bf16.h should be used from ROCm 6.0 - #include - typedef __hip_bfloat16 hip_bfloat16; + // This is a workaround for the fact that the old hip_bfloat16.h header file may still be used by some rocm files. + // The _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_ and _HIP_BFLOAT16_H_ macros are defined in the old hip_bfloat16.h header + #if !defined(_HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_) && !defined(_HIP_BFLOAT16_H_) + #define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_ + #define _HIP_BFLOAT16_H_ + #include + typedef __hip_bfloat16 hip_bfloat16; + #else + #error "RCCL is not using the correct hip_bf16.h file. Please make sure that the correct header is included!" + #endif #else #include #endif