SWDEV-474146 - use __bf16 to do operations

Change-Id: I568dfa97238fd760f5362a8e560c33402f96cff3
Этот коммит содержится в:
Jatin Chaudhary
2024-12-05 00:21:36 +00:00
коммит произвёл Jatin Jaikishan Chaudhary
родитель e560d94d2c
Коммит c23913f6e7
+88 -185
Просмотреть файл
@@ -128,20 +128,6 @@
#define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline
#define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline
#if defined(__AVX512VL__) and defined(__AVX512BF16__) and not defined(__HIP_DEVICE_COMPILE__)
// Enable with -mavx512vl -mavx512bf16
#if defined(__MINGW64__)
#include <intrin.h>
#else
#include <immintrin.h>
#endif
#define HIP_BF16_AVX512_OP 1
static_assert(sizeof(__bf16) == sizeof(unsigned short),
"sizeof __bf16 should match sizeof unsigned short");
#else
#define HIP_BF16_AVX512_OP 0
#endif
#define HIPRT_ONE_BF16 __ushort_as_bfloat16((unsigned short)0x3F80U)
#define HIPRT_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x0000U)
#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
@@ -157,6 +143,7 @@ static_assert(sizeof(__bf16) == sizeof(unsigned short),
static_assert(CHAR_BIT == 8, "byte size should be of 8 bits");
#endif
static_assert(sizeof(unsigned short) == 2, "size of unsigned short should be 2 bytes");
static_assert(sizeof(__bf16) == sizeof(unsigned short));
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_RAW
@@ -182,88 +169,13 @@ typedef struct __attribute__((aligned(4))) {
* @{
*/
struct __attribute__((aligned(2))) __hip_bfloat16 {
private:
__BF16_HOST_DEVICE_STATIC__ float bfloatraw_2_float(unsigned short val) {
#if HIP_BF16_AVX512_OP
union {
unsigned short us;
__bf16 bf16;
} u = {val};
return _mm_cvtsbh_ss(u.bf16);
#else
unsigned int uval = val << 16;
union {
unsigned int u32;
float fp32;
} u = {uval};
return u.fp32;
#endif
}
__BF16_HOST_DEVICE_STATIC__ unsigned short float_2_bfloatraw(float f) {
#if HIP_BF16_AVX512_OP
union {
__bf16 bf16;
unsigned short us;
} u = {_mm_cvtness_sbh(f)};
return u.us;
#else
union {
float fp32;
unsigned int u32;
} u = {f};
if (~u.u32 & 0x7f800000) {
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // Round to nearest, round to even
} else if (u.u32 & 0xffff) {
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
u.u32 |= 0x10000; // Preserve signaling NaN
}
return static_cast<unsigned short>(u.u32 >> 16);
#endif
}
__BF16_HOST_DEVICE_STATIC__ unsigned short double_2_bfloatraw(double d_in) {
union {
float fp32;
unsigned int u32;
} u = {static_cast<float>(d_in)};
double d = u.fp32;
// Round to odd
if ((d_in > 0.0 && d > d_in) || (d_in < 0.0 && d < d_in)) {
u.u32--;
u.u32 |= 1;
}
return float_2_bfloatraw(u.fp32);
}
protected:
/*! \brief raw representation of bfloat16 */
unsigned short __x;
union {
/*! \brief raw representation of bfloat16 */
unsigned short __x;
/*! \brief bf16 representation */
__bf16 __x_bf16;
};
public:
// TODO: SWDEV-452411
@@ -275,30 +187,29 @@ struct __attribute__((aligned(2))) __hip_bfloat16 {
// Casting directly to double might lead to double rounding.
/*! \brief create __hip_bfloat16 from an unsigned int */
__BF16_HOST_DEVICE__ __hip_bfloat16(unsigned int val)
: __x(double_2_bfloatraw(static_cast<double>(val))) {}
__BF16_HOST_DEVICE__ __hip_bfloat16(unsigned int val) : __x_bf16(static_cast<__bf16>(val)) {}
/*! \brief create __hip_bfloat16 from an int */
__BF16_HOST_DEVICE__ __hip_bfloat16(int val)
: __x(double_2_bfloatraw(static_cast<double>(val))) {}
__BF16_HOST_DEVICE__ __hip_bfloat16(int val) : __x_bf16(static_cast<__bf16>(val)) {}
/*! \brief create __hip_bfloat16 from an unsigned short */
__BF16_HOST_DEVICE__ __hip_bfloat16(unsigned short val)
: __x(float_2_bfloatraw(static_cast<float>(val))) {}
__BF16_HOST_DEVICE__ __hip_bfloat16(unsigned short val) : __x_bf16(static_cast<__bf16>(val)) {}
/*! \brief create __hip_bfloat16 from a short */
__BF16_HOST_DEVICE__ __hip_bfloat16(short val)
: __x(float_2_bfloatraw(static_cast<float>(val))) {}
__BF16_HOST_DEVICE__ __hip_bfloat16(short val) : __x_bf16(static_cast<__bf16>(val)) {}
/*! \brief create __hip_bfloat16 from a double */
__BF16_HOST_DEVICE__ __hip_bfloat16(const double val) : __x(double_2_bfloatraw(val)) {}
__BF16_HOST_DEVICE__ __hip_bfloat16(const double val) : __x_bf16(static_cast<__bf16>(val)) {}
/*! \brief create __hip_bfloat16 from a float */
__BF16_HOST_DEVICE__ __hip_bfloat16(const float val) : __x(float_2_bfloatraw(val)) {}
__BF16_HOST_DEVICE__ __hip_bfloat16(const float val) : __x_bf16(static_cast<__bf16>(val)) {}
/*! \brief create __hip_bfloat16 from a __hip_bfloat16_raw */
__BF16_HOST_DEVICE__ __hip_bfloat16(const __hip_bfloat16_raw& val) : __x(val.x) {}
/*! \brief create __hip_bfloat16 from __bf16 */
__BF16_HOST_DEVICE__ __hip_bfloat16(const __bf16 val) : __x_bf16(val) {}
/*! \brief default constructor */
__BF16_HOST_DEVICE__ __hip_bfloat16() = default;
@@ -311,96 +222,89 @@ struct __attribute__((aligned(2))) __hip_bfloat16 {
}
/*! \brief return false if bfloat value is +0.0 or -0.0, returns true otherwise */
__BF16_HOST_DEVICE__ operator bool() const {
auto val = bfloatraw_2_float(__x);
return val != 0.0f && val != -0.0f;
}
__BF16_HOST_DEVICE__ operator bool() const { return __x_bf16 != 0.0f; }
/*! \brief return a casted char from underlying float val */
__BF16_HOST_DEVICE__ operator char() const { return static_cast<char>(bfloatraw_2_float(__x)); }
__BF16_HOST_DEVICE__ operator char() const { return static_cast<char>(__x_bf16); }
/*! \brief return a float */
__BF16_HOST_DEVICE__ operator float() const { return bfloatraw_2_float(__x); }
__BF16_HOST_DEVICE__ operator float() const { return static_cast<float>(__x_bf16); }
/*! \brief return a casted int casted from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator int() const { return static_cast<int>(bfloatraw_2_float(__x)); }
__BF16_HOST_DEVICE__ operator int() const { return static_cast<int>(__x_bf16); }
/*! \brief return a casted long casted from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator long() const { return static_cast<long>(bfloatraw_2_float(__x)); }
__BF16_HOST_DEVICE__ operator long() const { return static_cast<long>(__x_bf16); }
/*! \brief return a casted long long casted from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator long long() const {
return static_cast<long long>(bfloatraw_2_float(__x));
}
__BF16_HOST_DEVICE__ operator long long() const { return static_cast<long long>(__x_bf16); }
/*! \brief return a casted short casted from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator short() const { return static_cast<short>(bfloatraw_2_float(__x)); }
__BF16_HOST_DEVICE__ operator short() const { return static_cast<short>(__x_bf16); }
/*! \brief return a casted signed char from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator signed char() const {
return static_cast<signed char>(bfloatraw_2_float(__x));
}
__BF16_HOST_DEVICE__ operator signed char() const { return static_cast<signed char>(__x_bf16); }
/*! \brief return a casted unsigned char casted from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator unsigned char() const {
return static_cast<unsigned char>(bfloatraw_2_float(__x));
return static_cast<unsigned char>(__x_bf16);
}
/*! \brief return a casted unsigned int casted from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator unsigned int() const {
return static_cast<unsigned int>(bfloatraw_2_float(__x));
}
__BF16_HOST_DEVICE__ operator unsigned int() const { return static_cast<unsigned int>(__x_bf16); }
/*! \brief return a casted unsigned long from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator unsigned long() const {
return static_cast<unsigned long>(bfloatraw_2_float(__x));
return static_cast<unsigned long>(__x_bf16);
}
/*! \brief return a casted unsigned long long from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator unsigned long long() const {
return static_cast<unsigned long long>(bfloatraw_2_float(__x));
return static_cast<unsigned long long>(__x_bf16);
}
/*! \brief return a casted unsigned short from float of underlying bfloat16 value */
__BF16_HOST_DEVICE__ operator unsigned short() const {
return static_cast<unsigned short>(bfloatraw_2_float(__x));
return static_cast<unsigned short>(__x_bf16);
}
__BF16_HOST_DEVICE__ operator __bf16() const { return __x_bf16; }
// TODO: SWDEV-452411 add operator which converts unsigned long long and long long to bfloat
/*! \brief assign value from an unsigned int */
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned int val) {
__x = float_2_bfloatraw(static_cast<float>(val));
__x_bf16 = static_cast<__bf16>(val);
return *this;
}
/*! \brief assign value from an int */
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(int val) {
__x = float_2_bfloatraw(static_cast<float>(val));
__x_bf16 = static_cast<__bf16>(val);
return *this;
}
/*! \brief assign value from an unsigned short */
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(unsigned short val) {
__x = float_2_bfloatraw(static_cast<float>(val));
__x_bf16 = static_cast<__bf16>(val);
return *this;
}
/*! \brief assign value from a short int */
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(short val) {
__x = float_2_bfloatraw(static_cast<float>(val));
__x_bf16 = static_cast<__bf16>(val);
return *this;
}
/*! \brief assign value from a double */
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const double f) {
__x = float_2_bfloatraw(static_cast<float>(f));
__x_bf16 = static_cast<__bf16>(f);
return *this;
}
/*! \brief assign value from a float */
__BF16_HOST_DEVICE__ __hip_bfloat16& operator=(const float f) {
__x = float_2_bfloatraw(f);
__x_bf16 = static_cast<__bf16>(f);
return *this;
}
@@ -425,6 +329,8 @@ struct __attribute__((aligned(2))) __hip_bfloat16 {
};
/**@}*/
typedef __bf16 __bf16_2 __attribute__((ext_vector_type(2)));
/**
* \defgroup HIP_INTRINSIC_BFLOAT162_STRUCT
* \ingroup HIP_INTRINSIC_BFLOAT16
@@ -432,9 +338,16 @@ struct __attribute__((aligned(2))) __hip_bfloat16 {
* @{
*/
struct __attribute__((aligned(4))) __hip_bfloat162 {
static_assert(sizeof(__hip_bfloat16[2]) == sizeof(__bf16_2));
public:
__hip_bfloat16 x; /*! \brief raw representation of bfloat16 */
__hip_bfloat16 y; /*! \brief raw representation of bfloat16 */
union {
struct {
__hip_bfloat16 x; /*! \brief raw representation of bfloat16 */
__hip_bfloat16 y; /*! \brief raw representation of bfloat16 */
};
__bf16_2 __xy_bf162;
};
public:
@@ -450,6 +363,9 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
__BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat16& a, const __hip_bfloat16& b)
: x(a), y(b) {}
/*! \brief create __hip_bfloat162 from vector of __bf16_2 */
__BF16_HOST_DEVICE__ __hip_bfloat162(const __bf16_2 in) : __xy_bf162(in) {}
/*! \brief default constructor of __hip_bfloat162 */
__BF16_HOST_DEVICE__ __hip_bfloat162() = default;
@@ -462,22 +378,19 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
/*! \brief return a float2 */
__BF16_HOST_DEVICE__ operator float2() const {
#if HIP_BF16_AVX512_OP
union {
__hip_bfloat162_raw raw2;
__bf16 bf162[2];
static_assert(sizeof(__bf16[2]) == sizeof(__hip_bfloat162_raw));
} u;
u.raw2 = *this;
__m128bh pbf16{u.bf162[0], u.bf162[1], 0, 0};
__m128 pf32 = _mm_cvtpbh_ps(pbf16);
float2 ret(pf32[0], pf32[1]);
#else
float2 ret(x, y);
#endif
return ret;
}
/*! \brief return a vector of bf16 */
__BF16_HOST_DEVICE__ operator __bf16_2() const { return __xy_bf162; }
/*! \brief assign value from a vector of bf16 */
__BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __bf16_2 in) {
__xy_bf162 = in;
return *this;
}
/*! \brief assign value from __hip_bfloat162_raw */
__BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162_raw& h2r) {
x = __hip_bfloat16(__hip_bfloat16_raw{h2r.x});
@@ -785,14 +698,14 @@ __BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_xor_sync(const unsigned long long
u.ui = __shfl_xor_sync(mask, u.ui, delta, width);
return u.bf162;
}
#endif
#endif // HIP_DISABLE_WARP_SYNC_BUILTINS
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Adds two bfloat16 values
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
return (__bf16)a + (__bf16)b;
}
/**
@@ -800,7 +713,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const
* \brief Subtracts two bfloat16 values
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b));
return (__bf16)a - (__bf16)b;
}
/**
@@ -808,7 +721,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const
* \brief Divides two bfloat16 values
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __float2bfloat16(__bfloat162float(a) / __bfloat162float(b));
return (__bf16)a / (__bf16)b;
}
/**
@@ -817,8 +730,8 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const
*/
__BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip_bfloat16 b,
const __hip_bfloat16 c) {
return __float2bfloat16(
__ocml_fma_f32(__bfloat162float(a), __bfloat162float(b), __bfloat162float(c)));
return __hip_bfloat16(__builtin_elementwise_fma(__bf16(a), __bf16(b), __bf16(c)));
;
}
/**
@@ -826,7 +739,7 @@ __BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, const __hip
* \brief Multiplies two bfloat16 values
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b));
return (__bf16)a * (__bf16)b;
}
/**
@@ -855,8 +768,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __habs(const __hip_bfloat16 a) {
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __h2div(const __hip_bfloat162 a,
const __hip_bfloat162 b) {
return __hip_bfloat162(__float2bfloat16(__bfloat162float(a.x) / __bfloat162float(b.x)),
__float2bfloat16(__bfloat162float(a.y) / __bfloat162float(b.y)));
return __hip_bfloat162{__bf16_2(a) / __bf16_2(b)};
}
/**
@@ -864,7 +776,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __h2div(const __hip_bfloat162 a,
* \brief Returns absolute of a bfloat162
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) {
return __hip_bfloat162(__habs(a.x), __habs(a.y));
return __hip_bfloat162{__habs(a.x), __habs(a.y)};
}
/**
@@ -873,7 +785,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) {
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a,
const __hip_bfloat162 b) {
return __hip_bfloat162(__hadd(a.x, b.x), __hadd(a.y, b.y));
return __hip_bfloat162{__bf16_2(a) + __bf16_2(b)};
}
/**
@@ -882,7 +794,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a,
*/
__BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __hip_bfloat162 b,
const __hip_bfloat162 c) {
return __hip_bfloat162(__hfma(a.x, b.x, c.x), __hfma(a.y, b.y, c.y));
return __hip_bfloat162{__builtin_elementwise_fma(__bf16_2(a), __bf16_2(b), __bf16_2(c))};
}
/**
@@ -891,7 +803,7 @@ __BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, const __
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a,
const __hip_bfloat162 b) {
return __hip_bfloat162(__hmul(a.x, b.x), __hmul(a.y, b.y));
return __hip_bfloat162{__bf16_2(a) * __bf16_2(b)};
}
/**
@@ -899,7 +811,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a,
* \brief Converts a bfloat162 into negative
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) {
return __hip_bfloat162(__hneg(a.x), __hneg(a.y));
return __hip_bfloat162{__hneg(a.x), __hneg(a.y)};
}
/**
@@ -908,7 +820,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) {
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a,
const __hip_bfloat162 b) {
return __hip_bfloat162(__hsub(a.x, b.x), __hsub(a.y, b.y));
return __hip_bfloat162{__bf16_2(a) - __bf16_2(b)};
}
/**
@@ -1166,7 +1078,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162& operator/=(__hip_bfloat162& l,
* \brief Compare two bfloat162 values
*/
__BF16_HOST_DEVICE_STATIC__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __bfloat162float(a) == __bfloat162float(b);
return (__bf16)a == (__bf16)b;
}
/**
@@ -1174,8 +1086,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __heq(const __hip_bfloat16 a, const __hip_bfloa
* \brief Compare two bfloat16 values - unordered equal
*/
__BF16_HOST_DEVICE_STATIC__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return !(__bfloat162float(a) < __bfloat162float(b)) &&
!(__bfloat162float(a) > __bfloat162float(b));
return !((__bf16)a < (__bf16)b) && !((__bf16)a > (__bf16)b);
}
/**
@@ -1183,7 +1094,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hequ(const __hip_bfloat16 a, const __hip_bflo
* \brief Compare two bfloat16 values - greater than
*/
__BF16_HOST_DEVICE_STATIC__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __bfloat162float(a) > __bfloat162float(b);
return (__bf16)a > (__bf16)b;
}
/**
@@ -1191,7 +1102,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloa
* \brief Compare two bfloat16 values - unordered greater than
*/
__BF16_HOST_DEVICE_STATIC__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return !(__bfloat162float(a) <= __bfloat162float(b));
return !((__bf16)a <= (__bf16)b);
}
/**
@@ -1199,7 +1110,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hgtu(const __hip_bfloat16 a, const __hip_bflo
* \brief Compare two bfloat16 values - greater than equal
*/
__BF16_HOST_DEVICE_STATIC__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __bfloat162float(a) >= __bfloat162float(b);
return (__bf16)a >= (__bf16)b;
}
/**
@@ -1207,7 +1118,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hge(const __hip_bfloat16 a, const __hip_bfloa
* \brief Compare two bfloat16 values - unordered greater than equal
*/
__BF16_HOST_DEVICE_STATIC__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return !(__bfloat162float(a) < __bfloat162float(b));
return !((__bf16)a < (__bf16)b);
}
/**
@@ -1215,7 +1126,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hgeu(const __hip_bfloat16 a, const __hip_bflo
* \brief Compare two bfloat16 values - not equal
*/
__BF16_HOST_DEVICE_STATIC__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __bfloat162float(a) != __bfloat162float(b);
return (__bf16)a != (__bf16)b;
}
/**
@@ -1223,7 +1134,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hne(const __hip_bfloat16 a, const __hip_bfloa
* \brief Compare two bfloat16 values - unordered not equal
*/
__BF16_HOST_DEVICE_STATIC__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return !(__bfloat162float(a) == __bfloat162float(b));
return !((__bf16)a == (__bf16)b);
}
/**
@@ -1231,11 +1142,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hneu(const __hip_bfloat16 a, const __hip_bflo
* \brief Compare two bfloat16 values - return max
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) {
#if __HIP_DEVICE_COMPILE__
return __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b)));
#else
return __float2bfloat16(std::max(__bfloat162float(a), __bfloat162float(b)));
#endif
return (__bf16)a > (__bf16)b ? a : b;
}
/**
@@ -1243,11 +1150,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const
* \brief Compare two bfloat16 values - return min
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) {
#if __HIP_DEVICE_COMPILE__
return __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b)));
#else
return __float2bfloat16(std::min(__bfloat162float(a), __bfloat162float(b)));
#endif
return (__bf16)a < (__bf16)b ? a : b;
}
/**
@@ -1255,7 +1158,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const
* \brief Compare two bfloat16 values - less than
*/
__BF16_HOST_DEVICE_STATIC__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __bfloat162float(a) < __bfloat162float(b);
return (__bf16)a < (__bf16)b;
}
/**
@@ -1263,7 +1166,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloa
* \brief Compare two bfloat16 values - unordered less than
*/
__BF16_HOST_DEVICE_STATIC__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return !(__bfloat162float(a) >= __bfloat162float(b));
return !((__bf16)a >= (__bf16)b);
}
/**
@@ -1271,7 +1174,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hltu(const __hip_bfloat16 a, const __hip_bflo
* \brief Compare two bfloat16 values - less than equal
*/
__BF16_HOST_DEVICE_STATIC__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return __bfloat162float(a) <= __bfloat162float(b);
return (__bf16)a <= (__bf16)b;
}
/**
@@ -1279,7 +1182,7 @@ __BF16_HOST_DEVICE_STATIC__ bool __hle(const __hip_bfloat16 a, const __hip_bfloa
* \brief Compare two bfloat16 values - unordered less than equal
*/
__BF16_HOST_DEVICE_STATIC__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) {
return !(__bfloat162float(a) > __bfloat162float(b));
return !((__bf16)a > (__bf16)b);
}
/**