SWDEV-571222 - Fix bf16 headers on gcc (#2260)

GCC does not support anonymous structs whose members have non-trivial constructors (Clang accepts this as an extension). This commit changes the header to drop the union when compiling with GCC and to expose `x` and `y` directly as members instead. This should be a non-breaking change for other compilers.
This commit is contained in:
Fábio Mestre
2026-01-16 15:02:48 +00:00
zatwierdzone przez GitHub
rodzic 7794ac9ac6
commit e6236417f7
@@ -111,9 +111,8 @@
#if !defined(__HIPCC_RTC__)
#include <hip/amd_detail/amd_hip_common.h>
#include "amd_hip_vector_types.h" // float2 etc
#include "device_library_decls.h" // ocml conversion functions
#include "math_fwd.h" // ocml device functions
#if defined(__clang__) && defined(__HIP__)
#include "math_fwd.h" // ocml device functions
#include <hip/amd_detail/amd_warp_functions.h> // define warpSize
#include <hip/amd_detail/amd_warp_sync_functions.h> // Sync functions
#endif
@@ -338,7 +337,11 @@ struct __attribute__((aligned(2))) __hip_bfloat16 {
};
/**@}*/
/* Two-element vector of __bf16. The attribute used to declare it depends on the compiler. */
#if defined(__clang__)
/* Clang: declare the vector with ext_vector_type, giving the element count (2). */
typedef __bf16 __bf16_2 __attribute__((ext_vector_type(2)));
#else
/* Other compilers: declare the vector with vector_size, giving the total size in bytes
(two __bf16 elements). */
typedef __bf16 __bf16_2 __attribute__((vector_size(sizeof(__bf16) * 2)));
#endif
/**
* \defgroup HIP_INTRINSIC_BFLOAT162_STRUCT
@@ -350,6 +353,7 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
static_assert(sizeof(__hip_bfloat16[2]) == sizeof(__bf16_2));
public:
#if defined(__clang__)
union {
struct {
__hip_bfloat16 x; /*! \brief raw representation of bfloat16 */
@@ -357,7 +361,12 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
};
__bf16_2 __xy_bf162;
};
#else
/* GCC does not support anonymous structs with members that have non-trivial constructors (Clang
allows this as an extension). Expose x and y directly instead. */
__hip_bfloat16 x;
__hip_bfloat16 y;
#endif
public:
/*! \brief create __hip_bfloat162 from __hip_bfloat162_raw */
@@ -373,7 +382,11 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
: x(a), y(b) {}
/*! \brief create __hip_bfloat162 from vector of __bf16_2 */
#if defined(__clang__)
/* Clang: initialize the __bf16_2 member of the union directly from the input vector. */
__BF16_HOST_DEVICE__ __hip_bfloat162(const __bf16_2 in) : __xy_bf162(in) {}
#else
/* Other compilers: there is no __xy_bf162 member in this branch, so initialize x from
element 0 and y from element 1 of the input vector. */
__BF16_HOST_DEVICE__ __hip_bfloat162(const __bf16_2 in) : x{in[0]}, y{in[1]} {}
#endif
/*! \brief default constructor of __hip_bfloat162 */
__BF16_HOST_DEVICE__ __hip_bfloat162() = default;
@@ -392,11 +405,22 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
}
/*! \brief return a vector of bf16 */
__BF16_HOST_DEVICE__ operator __bf16_2() const { return __xy_bf162; }
/*! \brief convert this object to a __bf16_2 vector holding x and y */
__BF16_HOST_DEVICE__ operator __bf16_2() const {
#if defined(__clang__)
/* Clang: the union already stores the pair as a __bf16_2, so return that member. */
return __xy_bf162;
#else
/* Other compilers: build the vector from the separate x and y members. */
return __bf16_2{x, y};
#endif
}
/*! \brief assign from a vector of bf16 and return a reference to this object */
__BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __bf16_2 in) {
#if defined(__clang__)
/* Clang: overwrite the __bf16_2 member of the union in one store. */
__xy_bf162 = in;
#else
/* Other compilers: there is no __xy_bf162 member in this branch, so assign element 0 to x
and element 1 to y. */
x = __hip_bfloat16{in[0]};
y = __hip_bfloat16{in[1]};
#endif
return *this;
}
@@ -835,6 +859,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const
return (__bf16)a / (__bf16)b;
}
#if defined(__clang__) && defined(__HIP__)
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Performs FMA of given bfloat16 values
@@ -844,6 +869,7 @@ __BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip
return __hip_bfloat16(__builtin_elementwise_fma(__bf16(a), __bf16(b), __bf16(c)));
;
}
#endif
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
@@ -919,6 +945,8 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2_rn(const __hip_bfloat162 a,
return __hip_bfloat162{__bf16_2(a) + __bf16_2(b)};
}
#if defined(__clang__) && defined(__HIP__)
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
* \brief Performs FMA of given bfloat162 values
@@ -927,6 +955,7 @@ __BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __
const __hip_bfloat162 c) {
return __hip_bfloat162{__builtin_elementwise_fma(__bf16_2(a), __bf16_2(b), __bf16_2(c))};
}
#endif
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
@@ -1639,6 +1668,7 @@ __BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat162& l, const __hi
return fl.x >= fr.x && fl.x >= fr.y;
}
#if defined(__clang__) && defined(__HIP__)
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_MATH
* \brief Calculate ceil of bfloat16
@@ -1883,7 +1913,6 @@ __BF16_DEVICE_STATIC__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) {
return __hip_bfloat162(htrunc(h.x), htrunc(h.y));
}
#if defined(__clang__) && defined(__HIP__)
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_MATH
* \brief Atomic add bfloat162