diff --git a/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/hipamd/include/hip/amd_detail/amd_hip_bf16.h index 1042d459ba..6426cc1610 100644 --- a/hipamd/include/hip/amd_detail/amd_hip_bf16.h +++ b/hipamd/include/hip/amd_detail/amd_hip_bf16.h @@ -128,20 +128,6 @@ #define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline #define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline -#if defined(__AVX512VL__) and defined(__AVX512BF16__) and not defined(__HIP_DEVICE_COMPILE__) -// Enable with -mavx512vl -mavx512bf16 -#if defined(__MINGW64__) -#include -#else -#include -#endif -#define HIP_BF16_AVX512_OP 1 -static_assert(sizeof(__bf16) == sizeof(unsigned short), - "sizeof __bf16 should match sizeof unsigned short"); -#else -#define HIP_BF16_AVX512_OP 0 -#endif - #define HIPRT_ONE_BF16 __ushort_as_bfloat16((unsigned short)0x3F80U) #define HIPRT_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x0000U) #define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) @@ -157,6 +143,7 @@ static_assert(sizeof(__bf16) == sizeof(unsigned short), static_assert(CHAR_BIT == 8, "byte size should be of 8 bits"); #endif static_assert(sizeof(unsigned short) == 2, "size of unsigned short should be 2 bytes"); +static_assert(sizeof(__bf16) == sizeof(unsigned short)); /** * \ingroup HIP_INTRINSIC_BFLOAT16_RAW @@ -182,88 +169,13 @@ typedef struct __attribute__((aligned(4))) { * @{ */ struct __attribute__((aligned(2))) __hip_bfloat16 { - private: - __BF16_HOST_DEVICE_STATIC__ float bfloatraw_2_float(unsigned short val) { -#if HIP_BF16_AVX512_OP - union { - unsigned short us; - __bf16 bf16; - } u = {val}; - return _mm_cvtsbh_ss(u.bf16); -#else - unsigned int uval = val << 16; - union { - unsigned int u32; - float fp32; - } u = {uval}; - return u.fp32; -#endif - } - - __BF16_HOST_DEVICE_STATIC__ unsigned short float_2_bfloatraw(float f) { -#if HIP_BF16_AVX512_OP - union { - __bf16 bf16; - unsigned short us; - } u = {_mm_cvtness_sbh(f)}; - return u.us; -#else - union { - float fp32; - unsigned int u32; - } u = {f}; - if (~u.u32 & 0x7f800000) { - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // Round to nearest, round to even - } else if (u.u32 & 0xffff) { - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bloat16's mantissa bits are all 0. - u.u32 |= 0x10000; // Preserve signaling NaN - } - return static_cast(u.u32 >> 16); -#endif - } - - __BF16_HOST_DEVICE_STATIC__ unsigned short double_2_bfloatraw(double d_in) { - union { - float fp32; - unsigned int u32; - } u = {static_cast(d_in)}; - double d = u.fp32; - - // Round to odd - if ((d_in > 0.0 && d > d_in) || (d_in < 0.0 && d < d_in)) { - u.u32--; - u.u32 |= 1; - } - - return float_2_bfloatraw(u.fp32); - } - protected: - /*! \brief raw representation of bfloat16 */ - unsigned short __x; + union { + /*! \brief raw representation of bfloat16 */ + unsigned short __x; + /*! \brief bf16 represenation */ + __bf16 __x_bf16; + }; public: // TODO: SWDEV-452411 @@ -275,30 +187,29 @@ struct __attribute__((aligned(2))) __hip_bfloat16 { // Casting directly to double might lead to double rounding. /*! \brief create __hip_bfloat16 from an unsigned int */ - __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned int val) - : __x(double_2_bfloatraw(static_cast(val))) {} + __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned int val) : __x_bf16(static_cast<__bf16>(val)) {} /*! \brief create __hip_bfloat16 from a int */ - __BF16_HOST_DEVICE__ __hip_bfloat16(int val) - : __x(double_2_bfloatraw(static_cast(val))) {} + __BF16_HOST_DEVICE__ __hip_bfloat16(int val) : __x_bf16(static_cast<__bf16>(val)) {} /*! \brief create __hip_bfloat16 from an unsigned short */ - __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned short val) - : __x(float_2_bfloatraw(static_cast(val))) {} + __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned short val) : __x_bf16(static_cast<__bf16>(val)) {} /*! \brief create __hip_bfloat16 from a short */ - __BF16_HOST_DEVICE__ __hip_bfloat16(short val) - : __x(float_2_bfloatraw(static_cast(val))) {} + __BF16_HOST_DEVICE__ __hip_bfloat16(short val) : __x_bf16(static_cast<__bf16>(val)) {} /*! \brief create __hip_bfloat16 from a double */ - __BF16_HOST_DEVICE__ __hip_bfloat16(const double val) : __x(double_2_bfloatraw(val)) {} + __BF16_HOST_DEVICE__ __hip_bfloat16(const double val) : __x_bf16(static_cast<__bf16>(val)) {} /*! \brief create __hip_bfloat16 from a float */ - __BF16_HOST_DEVICE__ __hip_bfloat16(const float val) : __x(float_2_bfloatraw(val)) {} + __BF16_HOST_DEVICE__ __hip_bfloat16(const float val) : __x_bf16(static_cast<__bf16>(val)) {} /*! \brief create __hip_bfloat16 from a __hip_bfloat16_raw */ __BF16_HOST_DEVICE__ __hip_bfloat16(const __hip_bfloat16_raw& val) : __x(val.x) {} + /*! \brief create __hip_bfloat16 from __bf16 */ + __BF16_HOST_DEVICE__ __hip_bfloat16(const __bf16 val) : __x_bf16(val) {} + /*! \brief default constructor */ __BF16_HOST_DEVICE__ __hip_bfloat16() = default; @@ -311,96 +222,89 @@ struct __attribute__((aligned(2))) __hip_bfloat16 { } /*! \brief return false if bfloat value is +0.0 or -0.0, returns true otherwise */ - __BF16_HOST_DEVICE__ operator bool() const { - auto val = bfloatraw_2_float(__x); - return val != 0.0f && val != -0.0f; - } + __BF16_HOST_DEVICE__ operator bool() const { return __x_bf16 != 0.0f; } /*! \brief return a casted char from underlying float val */ - __BF16_HOST_DEVICE__ operator char() const { return static_cast(bfloatraw_2_float(__x)); } + __BF16_HOST_DEVICE__ operator char() const { return static_cast(__x_bf16); } /*! \brief return a float */ - __BF16_HOST_DEVICE__ operator float() const { return bfloatraw_2_float(__x); } + __BF16_HOST_DEVICE__ operator float() const { return static_cast(__x_bf16); } /*! \brief return a casted int casted from float of underlying bfloat16 value */ - __BF16_HOST_DEVICE__ operator int() const { return static_cast(bfloatraw_2_float(__x)); } + __BF16_HOST_DEVICE__ operator int() const { return static_cast(__x_bf16); } /*! \brief return a casted long casted from float of underlying bfloat16 value */ - __BF16_HOST_DEVICE__ operator long() const { return static_cast(bfloatraw_2_float(__x)); } + __BF16_HOST_DEVICE__ operator long() const { return static_cast(__x_bf16); } /*! \brief return a casted long long casted from float of underlying bfloat16 value */ - __BF16_HOST_DEVICE__ operator long long() const { - return static_cast(bfloatraw_2_float(__x)); - } + __BF16_HOST_DEVICE__ operator long long() const { return static_cast(__x_bf16); } /*! \brief return a casted short casted from float of underlying bfloat16 value */ - __BF16_HOST_DEVICE__ operator short() const { return static_cast(bfloatraw_2_float(__x)); } + __BF16_HOST_DEVICE__ operator short() const { return static_cast(__x_bf16); } /*! \brief return a casted signed char from float of underlying bfloat16 value */ - __BF16_HOST_DEVICE__ operator signed char() const { - return static_cast(bfloatraw_2_float(__x)); - } + __BF16_HOST_DEVICE__ operator signed char() const { return static_cast(__x_bf16); } /*! \brief return a casted unsigned char casted from float of underlying bfloat16 value */ __BF16_HOST_DEVICE__ operator unsigned char() const { - return static_cast(bfloatraw_2_float(__x)); + return static_cast(__x_bf16); } /*! \brief return a casted unsigned int casted from float of underlying bfloat16 value */ - __BF16_HOST_DEVICE__ operator unsigned int() const { - return static_cast(bfloatraw_2_float(__x)); - } + __BF16_HOST_DEVICE__ operator unsigned int() const { return static_cast(__x_bf16); } /*! \brief return a casted unsigned from float of underlying bfloat16 value */ __BF16_HOST_DEVICE__ operator unsigned long() const { - return static_cast(bfloatraw_2_float(__x)); + return static_cast(__x_bf16); } /*! \brief return a casted unsigned long long from float of underlying bfloat16 value */ __BF16_HOST_DEVICE__ operator unsigned long long() const { - return static_cast(bfloatraw_2_float(__x)); + return static_cast(__x_bf16); } /*! \brief return a casted unsigned short from float of underlying bfloat16 value */ __BF16_HOST_DEVICE__ operator unsigned short() const { - return static_cast(bfloatraw_2_float(__x)); + return static_cast(__x_bf16); } + __BF16_HOST_DEVICE__ operator __bf16() const { return __x_bf16; } + // TODO: SWDEV-452411 add operator which converts unsigned long long and long long to bfloat /*! \brief assign value from an unsigned int */ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned int val) { - __x = float_2_bfloatraw(static_cast(val)); + __x_bf16 = static_cast<__bf16>(val); return *this; } /*! \brief assign value from a int */ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(int val) { - __x = float_2_bfloatraw(static_cast(val)); + __x_bf16 = static_cast<__bf16>(val); return *this; } /*! \brief assign value from an unsigned short */ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned short val) { - __x = float_2_bfloatraw(static_cast(val)); + __x_bf16 = static_cast<__bf16>(val); return *this; } /*! \brief assign value from a short int */ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(short val) { - __x = float_2_bfloatraw(static_cast(val)); + __x_bf16 = static_cast<__bf16>(val); return *this; } /*! \brief assign value from a double */ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const double f) { - __x = float_2_bfloatraw(static_cast(f)); + __x_bf16 = static_cast<__bf16>(f); return *this; } /*! \brief assign value from a float */ __BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const float f) { - __x = float_2_bfloatraw(f); + __x_bf16 = static_cast<__bf16>(f); return *this; } @@ -425,6 +329,8 @@ struct __attribute__((aligned(2))) __hip_bfloat16 { }; /**@}*/ +typedef __bf16 __bf16_2 __attribute__((ext_vector_type(2))); + /** * \defgroup HIP_INTRINSIC_BFLOAT162_STRUCT * \ingroup HIP_INTRINSIC_BFLOAT16 @@ -432,9 +338,16 @@ struct __attribute__((aligned(2))) __hip_bfloat16 { * @{ */ struct __attribute__((aligned(4))) __hip_bfloat162 { + static_assert(sizeof(__hip_bfloat16[2]) == sizeof(__bf16_2)); + public: - __hip_bfloat16 x; /*! \brief raw representation of bfloat16 */ - __hip_bfloat16 y; /*! \brief raw representation of bfloat16 */ + union { + struct { + __hip_bfloat16 x; /*! \brief raw representation of bfloat16 */ + __hip_bfloat16 y; /*! \brief raw representation of bfloat16 */ + }; + __bf16_2 __xy_bf162; + }; public: @@ -450,6 +363,9 @@ struct __attribute__((aligned(4))) __hip_bfloat162 { __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat16& a, const __hip_bfloat16& b) : x(a), y(b) {} + /*! \brief create __hip_bfloat162 from vector of __bf16_2 */ + __BF16_HOST_DEVICE__ __hip_bfloat162(const __bf16_2 in) : __xy_bf162(in) {} + /*! \brief default constructor of __hip_bfloat162 */ __BF16_HOST_DEVICE__ __hip_bfloat162() = default; @@ -462,22 +378,19 @@ struct __attribute__((aligned(4))) __hip_bfloat162 { /*! \brief return a float2 */ __BF16_HOST_DEVICE__ operator float2() const { -#if HIP_BF16_AVX512_OP - union { - __hip_bfloat162_raw raw2; - __bf16 bf162[2]; - static_assert(sizeof(__bf16[2]) == sizeof(__hip_bfloat162_raw)); - } u; - u.raw2 = *this; - __m128bh pbf16{u.bf162[0], u.bf162[1], 0, 0}; - __m128 pf32 = _mm_cvtpbh_ps(pbf16); - float2 ret(pf32[0], pf32[1]); -#else float2 ret(x, y); -#endif return ret; } + /*! \brief return a vector of bf16 */ + __BF16_HOST_DEVICE__ operator __bf16_2() const { return __xy_bf162; } + +/*! \brief return a vector of bf16 */ + __BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __bf16_2 in) { + __xy_bf162 = in; + return *this; + } + /*! \brief assign value from __hip_bfloat162_raw */ __BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162_raw& h2r) { x = __hip_bfloat16(__hip_bfloat16_raw{h2r.x}); @@ -785,14 +698,14 @@ __BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_xor_sync(const unsigned long long u.ui = __shfl_xor_sync(mask, u.ui, delta, width); return u.bf162; } -#endif +#endif // HIP_DISABLE_WARP_SYNC_BUILTINS /** * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Adds two bfloat16 values */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); + return (__bf16)a + (__bf16)b; } /** @@ -800,7 +713,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const * \brief Subtracts two bfloat16 values */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b)); + return (__bf16)a - (__bf16)b; } /** @@ -808,7 +721,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const * \brief Divides two bfloat16 values */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __float2bfloat16(__bfloat162float(a) / __bfloat162float(b)); + return (__bf16)a / (__bf16)b; } /** @@ -817,8 +730,8 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const */ __BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b, const __hip_bfloat16 c) { - return __float2bfloat16( - __ocml_fma_f32(__bfloat162float(a), __bfloat162float(b), __bfloat162float(c))); + return __hip_bfloat16(__builtin_elementwise_fma(__bf16(a), __bf16(b), __bf16(c))); + ; } /** @@ -826,7 +739,7 @@ __BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip * \brief Multiplies two bfloat16 values */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b)); + return (__bf16)a * (__bf16)b; } /** @@ -855,8 +768,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162(__float2bfloat16(__bfloat162float(a.x) / __bfloat162float(b.x)), - __float2bfloat16(__bfloat162float(a.y) / __bfloat162float(b.y))); + return __hip_bfloat162{__bf16_2(a) / __bf16_2(b)}; } /** @@ -864,7 +776,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, * \brief Returns absolute of a bfloat162 */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { - return __hip_bfloat162(__habs(a.x), __habs(a.y)); + return __hip_bfloat162{__habs(a.x), __habs(a.y)}; } /** @@ -873,7 +785,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162(__hadd(a.x, b.x), __hadd(a.y, b.y)); + return __hip_bfloat162{__bf16_2(a) + __bf16_2(b)}; } /** @@ -882,7 +794,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, */ __BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __hip_bfloat162 b, const __hip_bfloat162 c) { - return __hip_bfloat162(__hfma(a.x, b.x, c.x), __hfma(a.y, b.y, c.y)); + return __hip_bfloat162{__builtin_elementwise_fma(__bf16_2(a), __bf16_2(b), __bf16_2(c))}; } /** @@ -891,7 +803,7 @@ __BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __ */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162(__hmul(a.x, b.x), __hmul(a.y, b.y)); + return __hip_bfloat162{__bf16_2(a) * __bf16_2(b)}; } /** @@ -899,7 +811,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, * \brief Converts a bfloat162 into negative */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { - return __hip_bfloat162(__hneg(a.x), __hneg(a.y)); + return __hip_bfloat162{__hneg(a.x), __hneg(a.y)}; } /** @@ -908,7 +820,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) { - return __hip_bfloat162(__hsub(a.x, b.x), __hsub(a.y, b.y)); + return __hip_bfloat162{__bf16_2(a) - __bf16_2(b)}; } /** @@ -1166,7 +1078,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator/=(__hip_bfloat162& l, * \brief Compare two bfloat162 values */ __BF16_HOST_DEVICE_STATIC__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __bfloat162float(a) == __bfloat162float(b); + return (__bf16)a == (__bf16)b; } /** @@ -1174,8 +1086,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __heq(const __hip_bfloat16 a, const __hip_bfloa * \brief Compare two bfloat162 values - unordered equal */ __BF16_HOST_DEVICE_STATIC__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return !(__bfloat162float(a) < __bfloat162float(b)) && - !(__bfloat162float(a) > __bfloat162float(b)); + return !((__bf16)a < (__bf16)b) && !((__bf16)a > (__bf16)b); } /** @@ -1183,7 +1094,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hequ(const __hip_bfloat16 a, const __hip_bflo * \brief Compare two bfloat162 values - greater than */ __BF16_HOST_DEVICE_STATIC__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __bfloat162float(a) > __bfloat162float(b); + return (__bf16)a > (__bf16)b; } /** @@ -1191,7 +1102,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloa * \brief Compare two bfloat162 values - unordered greater than */ __BF16_HOST_DEVICE_STATIC__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return !(__bfloat162float(a) <= __bfloat162float(b)); + return !((__bf16)a <= (__bf16)b); } /** @@ -1199,7 +1110,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hgtu(const __hip_bfloat16 a, const __hip_bflo * \brief Compare two bfloat162 values - greater than equal */ __BF16_HOST_DEVICE_STATIC__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __bfloat162float(a) >= __bfloat162float(b); + return (__bf16)a >= (__bf16)b; } /** @@ -1207,7 +1118,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hge(const __hip_bfloat16 a, const __hip_bfloa * \brief Compare two bfloat162 values - unordered greater than equal */ __BF16_HOST_DEVICE_STATIC__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return !(__bfloat162float(a) < __bfloat162float(b)); + return !((__bf16)a < (__bf16)b); } /** @@ -1215,7 +1126,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hgeu(const __hip_bfloat16 a, const __hip_bflo * \brief Compare two bfloat162 values - not equal */ __BF16_HOST_DEVICE_STATIC__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __bfloat162float(a) != __bfloat162float(b); + return (__bf16)a != (__bf16)b; } /** @@ -1223,7 +1134,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hne(const __hip_bfloat16 a, const __hip_bfloa * \brief Compare two bfloat162 values - unordered not equal */ __BF16_HOST_DEVICE_STATIC__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return !(__bfloat162float(a) == __bfloat162float(b)); + return !((__bf16)a == (__bf16)b); } /** @@ -1231,11 +1142,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hneu(const __hip_bfloat16 a, const __hip_bflo * \brief Compare two bfloat162 values - return max */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) { -#if __HIP_DEVICE_COMPILE__ - return __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b))); -#else - return __float2bfloat16(std::max(__bfloat162float(a), __bfloat162float(b))); -#endif + return (__bf16)a > (__bf16)b ? a : b; } /** @@ -1243,11 +1150,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const * \brief Compare two bfloat162 values - return min */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) { -#if __HIP_DEVICE_COMPILE__ - return __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b))); -#else - return __float2bfloat16(std::min(__bfloat162float(a), __bfloat162float(b))); -#endif + return (__bf16)a < (__bf16)b ? a : b; } /** @@ -1255,7 +1158,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const * \brief Compare two bfloat162 values - less than operator */ __BF16_HOST_DEVICE_STATIC__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __bfloat162float(a) < __bfloat162float(b); + return (__bf16)a < (__bf16)b; } /** @@ -1263,7 +1166,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloa * \brief Compare two bfloat162 values - unordered less than */ __BF16_HOST_DEVICE_STATIC__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return !(__bfloat162float(a) >= __bfloat162float(b)); + return !((__bf16)a >= (__bf16)b); } /** @@ -1271,7 +1174,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hltu(const __hip_bfloat16 a, const __hip_bflo * \brief Compare two bfloat162 values - less than equal */ __BF16_HOST_DEVICE_STATIC__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return __bfloat162float(a) <= __bfloat162float(b); + return (__bf16)a <= (__bf16)b; } /** @@ -1279,7 +1182,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hle(const __hip_bfloat16 a, const __hip_bfloa * \brief Compare two bfloat162 values - unordered less than equal */ __BF16_HOST_DEVICE_STATIC__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { - return !(__bfloat162float(a) > __bfloat162float(b)); + return !((__bf16)a > (__bf16)b); } /**