diff --git a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h index 132abf713d..8d07119f27 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h +++ b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h @@ -111,9 +111,8 @@ #if !defined(__HIPCC_RTC__) #include #include "amd_hip_vector_types.h" // float2 etc -#include "device_library_decls.h" // ocml conversion functions -#include "math_fwd.h" // ocml device functions #if defined(__clang__) && defined(__HIP__) +#include "math_fwd.h" // ocml device functions #include // define warpSize #include // Sync functions #endif @@ -338,7 +337,11 @@ struct __attribute__((aligned(2))) __hip_bfloat16 { }; /**@}*/ +#if defined(__clang__) typedef __bf16 __bf16_2 __attribute__((ext_vector_type(2))); +#else +typedef __bf16 __bf16_2 __attribute__((vector_size(sizeof(__bf16) * 2))); +#endif /** * \defgroup HIP_INTRINSIC_BFLOAT162_STRUCT @@ -350,6 +353,7 @@ struct __attribute__((aligned(4))) __hip_bfloat162 { static_assert(sizeof(__hip_bfloat16[2]) == sizeof(__bf16_2)); public: +#if defined(__clang__) union { struct { __hip_bfloat16 x; /*! \brief raw representation of bfloat16 */ @@ -357,7 +361,12 @@ struct __attribute__((aligned(4))) __hip_bfloat162 { }; __bf16_2 __xy_bf162; }; - +#else + /* GCC does not support anonymous structs with members that have non-trivial constructors (Clang + allows this as an extension). Expose x and y directly instead. */ + __hip_bfloat16 x; + __hip_bfloat16 y; +#endif public: /*! \brief create __hip_bfloat162 from __hip_bfloat162_raw */ @@ -373,7 +382,11 @@ struct __attribute__((aligned(4))) __hip_bfloat162 { : x(a), y(b) {} /*! \brief create __hip_bfloat162 from vector of __bf16_2 */ +#if defined(__clang__) __BF16_HOST_DEVICE__ __hip_bfloat162(const __bf16_2 in) : __xy_bf162(in) {} +#else + __BF16_HOST_DEVICE__ __hip_bfloat162(const __bf16_2 in) : x{in[0]}, y{in[1]} {} +#endif /*! \brief default constructor of __hip_bfloat162 */ __BF16_HOST_DEVICE__ __hip_bfloat162() = default; @@ -392,11 +405,22 @@ struct __attribute__((aligned(4))) __hip_bfloat162 { } /*! \brief return a vector of bf16 */ - __BF16_HOST_DEVICE__ operator __bf16_2() const { return __xy_bf162; } + __BF16_HOST_DEVICE__ operator __bf16_2() const { +#if defined(__clang__) + return __xy_bf162; +#else + return __bf16_2{x, y}; +#endif + } /*! \brief return a vector of bf16 */ __BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __bf16_2 in) { +#if defined(__clang__) __xy_bf162 = in; +#else + x = __hip_bfloat16{in[0]}; + y = __hip_bfloat16{in[1]}; +#endif return *this; } @@ -835,6 +859,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const return (__bf16)a / (__bf16)b; } +#if defined(__clang__) && defined(__HIP__) /** * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Performs FMA of given bfloat16 values @@ -844,6 +869,7 @@ __BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip return __hip_bfloat16(__builtin_elementwise_fma(__bf16(a), __bf16(b), __bf16(c))); ; } +#endif /** * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH @@ -919,6 +945,8 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2_rn(const __hip_bfloat162 a, return __hip_bfloat162{__bf16_2(a) + __bf16_2(b)}; } + +#if defined(__clang__) && defined(__HIP__) /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH * \brief Performs FMA of given bfloat162 values @@ -927,6 +955,7 @@ __BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __ const __hip_bfloat162 c) { return __hip_bfloat162{__builtin_elementwise_fma(__bf16_2(a), __bf16_2(b), __bf16_2(c))}; } +#endif /** * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH @@ -1639,6 +1668,7 @@ __BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat162& l, const __hi return fl.x >= fr.x && fl.x >= fr.y; } +#if defined(__clang__) && defined(__HIP__) /** * \ingroup HIP_INTRINSIC_BFLOAT16_MATH * \brief Calculate ceil of bfloat16 @@ -1883,7 +1913,6 @@ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) { return __hip_bfloat162(htrunc(h.x), htrunc(h.y)); } -#if defined(__clang__) && defined(__HIP__) /** * \ingroup HIP_INTRINSIC_BFLOAT162_MATH * \brief Atomic add bfloat162