SWDEV-379007 - use avx instruction for bf16 cvt
AMD CPUs have had avx512_bf16 support for quite some time now (from
consumer Ryzen 7000 series to enterprise grade CPUs). This
patch should allow users to use the hardware bf16 unit when running the
__host__ variants of the function. This can be enabled via `hipcc ...
-mavx512vl -mavx512bf16`.
Change-Id: I67c377afc95ddfe8d45a048dce078a247d4a1878
[ROCm/clr commit: 49349f168c]
This commit is contained in:
کامیت شده توسط
Jatin Jaikishan Chaudhary
والد
7f195e2996
کامیت
80f02a1534
@@ -116,6 +116,20 @@
|
||||
#define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline
|
||||
#define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline
|
||||
|
||||
#if defined(__AVX512VL__) and defined(__AVX512BF16__) and not defined(__HIP_DEVICE_COMPILE__)
|
||||
// Enable with -mavx512vl -mavx512bf16
|
||||
#if defined(__MINGW64__)
|
||||
#include <intrin.h>
|
||||
#else
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
#define HIP_BF16_AVX512_OP 1
|
||||
static_assert(sizeof(__bf16) == sizeof(unsigned short),
|
||||
"sizeof __bf16 should match sizeof unsigned short");
|
||||
#else
|
||||
#define HIP_BF16_AVX512_OP 0
|
||||
#endif
|
||||
|
||||
#define HIPRT_ONE_BF16 __float2bfloat16(1.0f)
|
||||
#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f)
|
||||
#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
|
||||
@@ -140,7 +154,6 @@ typedef struct __attribute__((aligned(2))) {
|
||||
unsigned short x;
|
||||
} __hip_bfloat16_raw;
|
||||
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_RAW
|
||||
* \brief represents raw bfloat16x2 vector type
|
||||
@@ -159,14 +172,29 @@ typedef struct __attribute__((aligned(4))) {
|
||||
struct __attribute__((aligned(2))) __hip_bfloat16 {
|
||||
private:
|
||||
__BF16_HOST_DEVICE_STATIC__ float bfloatraw_2_float(unsigned short val) {
|
||||
#if HIP_BF16_AVX512_OP
|
||||
union {
|
||||
unsigned short us;
|
||||
__bf16 bf16;
|
||||
} u = {val};
|
||||
return _mm_cvtsbh_ss(u.bf16);
|
||||
#else
|
||||
unsigned int uval = val << 16;
|
||||
union {
|
||||
unsigned int u32;
|
||||
float fp32;
|
||||
} u = {uval};
|
||||
return u.fp32;
|
||||
#endif
|
||||
}
|
||||
__BF16_HOST_DEVICE_STATIC__ unsigned short float_2_bfloatraw(float f) {
|
||||
#if HIP_BF16_AVX512_OP
|
||||
union {
|
||||
__bf16 bf16;
|
||||
unsigned short us;
|
||||
} u = {_mm_cvtness_sbh(f)};
|
||||
return u.us;
|
||||
#else
|
||||
union {
|
||||
float fp32;
|
||||
unsigned int u32;
|
||||
@@ -201,7 +229,9 @@ struct __attribute__((aligned(2))) __hip_bfloat16 {
|
||||
u.u32 |= 0x10000; // Preserve signaling NaN
|
||||
}
|
||||
return static_cast<unsigned short>(u.u32 >> 16);
|
||||
#endif
|
||||
}
|
||||
|
||||
__BF16_HOST_DEVICE_STATIC__ unsigned short double_2_bfloatraw(double d_in) {
|
||||
union {
|
||||
float fp32;
|
||||
@@ -393,6 +423,7 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
|
||||
__hip_bfloat16 x; /*! \brief raw representation of bfloat16 */
|
||||
__hip_bfloat16 y; /*! \brief raw representation of bfloat16 */
|
||||
|
||||
|
||||
public:
|
||||
/*! \brief create __hip_bfloat162 from __hip_bfloat162_raw */
|
||||
__BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162_raw& h2r)
|
||||
@@ -422,22 +453,34 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
|
||||
|
||||
/*! \brief return a float2 */
|
||||
__BF16_HOST_DEVICE__ operator float2() const {
|
||||
#if HIP_BF16_AVX512_OP
|
||||
union {
|
||||
__hip_bfloat162_raw raw2;
|
||||
__bf16 bf162[2];
|
||||
static_assert(sizeof(__bf16[2]) == sizeof(__hip_bfloat162_raw));
|
||||
} u;
|
||||
u.raw2 = *this;
|
||||
__m128bh pbf16{u.bf162[0], u.bf162[1], 0, 0};
|
||||
__m128 pf32 = _mm_cvtpbh_ps(pbf16);
|
||||
float2 ret(pf32[0], pf32[1]);
|
||||
#else
|
||||
float2 ret(x, y);
|
||||
#endif
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*! \brief assign value from __hip_bfloat162_raw */
|
||||
__BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162_raw& h2r) {
|
||||
x = __hip_bfloat16(__hip_bfloat16{h2r.x});
|
||||
y = __hip_bfloat16(__hip_bfloat16{h2r.y});
|
||||
x = __hip_bfloat16(__hip_bfloat16_raw{h2r.x});
|
||||
y = __hip_bfloat16(__hip_bfloat16_raw{h2r.y});
|
||||
return *this;
|
||||
}
|
||||
|
||||
/*! \brief assign value from __hip_bfloat162 */
|
||||
__BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162& src) {
|
||||
__hip_bfloat162_raw hr = src;
|
||||
x = __hip_bfloat16(__hip_bfloat16{hr.x});
|
||||
y = __hip_bfloat16(__hip_bfloat16{hr.y});
|
||||
x = __hip_bfloat16(__hip_bfloat16_raw{hr.x});
|
||||
y = __hip_bfloat16(__hip_bfloat16_raw{hr.y});
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
@@ -466,9 +509,8 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __float2bfloat16(float f) {
|
||||
* \brief Converts and moves bfloat162 to float2
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
|
||||
__hip_bfloat162_raw hr = a;
|
||||
return float2{__bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.x})),
|
||||
__bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.y}))};
|
||||
float2 ret = a;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -576,7 +618,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) {
|
||||
__hip_bfloat162_raw hr = a;
|
||||
return __hip_bfloat162(__hip_bfloat162{hr.x, hr.x});
|
||||
return __hip_bfloat162(hr.x, hr.x);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
مرجع در شماره جدید
Block a user