SWDEV-379007 - use avx instruction for bf16 cvt

AMD CPUs have had AVX512_BF16 support for quite some time now (from the
consumer Ryzen 7000 series to enterprise-grade CPUs). This
patch should allow users to use the hardware bf16 unit when running the
__host__ variants of the bf16 conversion functions. This can be enabled via `hipcc ...
-mavx512vl -mavx512bf16`.

Change-Id: I67c377afc95ddfe8d45a048dce078a247d4a1878


[ROCm/clr commit: 49349f168c]
This commit is contained in:
Jatin Chaudhary
2024-04-12 00:16:11 +01:00
کامیت شده توسط Jatin Jaikishan Chaudhary
والد 7f195e2996
کامیت 80f02a1534
@@ -116,6 +116,20 @@
#define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline
#define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline
// Select the host-side AVX512-BF16 conversion path. It is used only when both
// __AVX512VL__ and __AVX512BF16__ are defined (-mavx512vl -mavx512bf16) and
// __HIP_DEVICE_COMPILE__ is not.
#if defined(__AVX512VL__) && defined(__AVX512BF16__) && !defined(__HIP_DEVICE_COMPILE__)
// Intrinsics header: <intrin.h> when __MINGW64__ is defined, <immintrin.h> otherwise.
#ifdef __MINGW64__
#include <intrin.h>
#else
#include <immintrin.h>
#endif
// The conversion helpers view unsigned short storage as __bf16 through a union,
// so the two types must have the same size.
static_assert(sizeof(unsigned short) == sizeof(__bf16),
              "sizeof __bf16 should match sizeof unsigned short");
#define HIP_BF16_AVX512_OP 1
#else
#define HIP_BF16_AVX512_OP 0
#endif
#define HIPRT_ONE_BF16 __float2bfloat16(1.0f)
#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f)
#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
@@ -140,7 +154,6 @@ typedef struct __attribute__((aligned(2))) {
unsigned short x;
} __hip_bfloat16_raw;
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_RAW
* \brief represents raw bfloat16x2 vector type
@@ -159,14 +172,29 @@ typedef struct __attribute__((aligned(4))) {
struct __attribute__((aligned(2))) __hip_bfloat16 {
private:
/*! \brief Convert a raw 16-bit bfloat16 pattern to float.
 *
 * \param val 16-bit pattern to convert
 * \return the float obtained from val
 */
__BF16_HOST_DEVICE_STATIC__ float bfloatraw_2_float(unsigned short val) {
#if HIP_BF16_AVX512_OP
  // AVX512-BF16 host build: view the 16 bits as __bf16 and let the intrinsic convert.
  union {
    unsigned short us;
    __bf16 bf16;
  } u = {val};
  return _mm_cvtsbh_ss(u.bf16);
#else
  // Widen before shifting: val would otherwise be promoted to (signed) int, and any
  // pattern with bit 15 set would be shifted into the sign bit.
  unsigned int uval = static_cast<unsigned int>(val) << 16;
  // Reinterpret the 32-bit pattern (val in the upper half, zeros below) as a float.
  union {
    unsigned int u32;
    float fp32;
  } u = {uval};
  return u.fp32;
#endif
}
__BF16_HOST_DEVICE_STATIC__ unsigned short float_2_bfloatraw(float f) {
#if HIP_BF16_AVX512_OP
union {
__bf16 bf16;
unsigned short us;
} u = {_mm_cvtness_sbh(f)};
return u.us;
#else
union {
float fp32;
unsigned int u32;
@@ -201,7 +229,9 @@ struct __attribute__((aligned(2))) __hip_bfloat16 {
u.u32 |= 0x10000; // Preserve signaling NaN
}
return static_cast<unsigned short>(u.u32 >> 16);
#endif
}
__BF16_HOST_DEVICE_STATIC__ unsigned short double_2_bfloatraw(double d_in) {
union {
float fp32;
@@ -393,6 +423,7 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
__hip_bfloat16 x; /*! \brief raw representation of bfloat16 */
__hip_bfloat16 y; /*! \brief raw representation of bfloat16 */
public:
/*! \brief create __hip_bfloat162 from __hip_bfloat162_raw */
__BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162_raw& h2r)
@@ -422,22 +453,34 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
/*! \brief return a float2
 *
 * \return float2 built from the two stored bfloat16 values
 */
__BF16_HOST_DEVICE__ operator float2() const {
#if HIP_BF16_AVX512_OP
  // AVX512-BF16 host build: convert both values with a single packed conversion.
  union {
    __hip_bfloat162_raw raw2;
    __bf16 bf162[2];
    // The message form keeps this valid before C++17 and matches the
    // file-level sizeof(__bf16) check.
    static_assert(sizeof(__bf16[2]) == sizeof(__hip_bfloat162_raw),
                  "sizeof __bf16[2] should match sizeof __hip_bfloat162_raw");
  } u;
  u.raw2 = *this;
  // Only the first two lanes carry data; the rest are zero-filled.
  __m128bh pbf16{u.bf162[0], u.bf162[1], 0, 0};
  __m128 pf32 = _mm_cvtpbh_ps(pbf16);
  float2 ret(pf32[0], pf32[1]);
#else
  float2 ret(x, y);
#endif
  return ret;
}
/*! \brief assign value from __hip_bfloat162_raw
 *
 * \param h2r raw pair whose x and y fields are copied into this object
 * \return reference to *this
 */
__BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162_raw& h2r) {
// NOTE(review): x and y are each assigned twice below; the second pair (built from
// __hip_bfloat16_raw) overwrites the first. The first pair looks like the
// pre-change side of this diff hunk - confirm before relying on it.
x = __hip_bfloat16(__hip_bfloat16{h2r.x});
y = __hip_bfloat16(__hip_bfloat16{h2r.y});
x = __hip_bfloat16(__hip_bfloat16_raw{h2r.x});
y = __hip_bfloat16(__hip_bfloat16_raw{h2r.y});
return *this;
}
/*! \brief assign value from __hip_bfloat162
 *
 * \param src object to copy from; it is first converted to __hip_bfloat162_raw
 * \return reference to *this
 */
__BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162& src) {
__hip_bfloat162_raw hr = src;
// NOTE(review): x and y are each assigned twice below; the second pair (built from
// __hip_bfloat16_raw) overwrites the first. The first pair looks like the
// pre-change side of this diff hunk - confirm before relying on it.
x = __hip_bfloat16(__hip_bfloat16{hr.x});
y = __hip_bfloat16(__hip_bfloat16{hr.y});
x = __hip_bfloat16(__hip_bfloat16_raw{hr.x});
y = __hip_bfloat16(__hip_bfloat16_raw{hr.y});
return *this;
}
};
@@ -466,9 +509,8 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __float2bfloat16(float f) {
* \brief Converts and moves bfloat162 to float2
*/
__BF16_HOST_DEVICE_STATIC__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
__hip_bfloat162_raw hr = a;
// Convert each raw half through __bfloat162float and return the pair.
return float2{__bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.x})),
__bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.y}))};
// NOTE(review): the two statements below follow a return and are unreachable as
// written. They look like the replacement side of this diff hunk (conversion of a
// to float2 in one step) - confirm which version is intended.
float2 ret = a;
return ret;
}
/**
@@ -576,7 +618,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) {
__hip_bfloat162_raw hr = a;
// Both components of the result are taken from hr.x.
return __hip_bfloat162(__hip_bfloat162{hr.x, hr.x});
// NOTE(review): the statement below follows a return and is unreachable as written.
// It looks like the replacement side of this diff hunk - confirm which version is
// intended.
return __hip_bfloat162(hr.x, hr.x);
}
/**