From 80f02a1534fbc2d46faaab4bd0db0f373d1fc6c3 Mon Sep 17 00:00:00 2001 From: Jatin Chaudhary Date: Fri, 12 Apr 2024 00:16:11 +0100 Subject: [PATCH] SWDEV-379007 - use avx instruction for bf16 cvt AMD CPUs have had avx512_bf16 support for quite some time now (from consumer Ryzen 7000 series to enterprise grade CPUs). This patch should allow users to use the hardware bf16 unit when running the __host__ variants of the function. This can be enabled via `hipcc ... -mavx512vl -mavx512bf16`. Change-Id: I67c377afc95ddfe8d45a048dce078a247d4a1878 [ROCm/clr commit: 49349f168cba7ab7876cce9669747d5bfab32496] --- .../include/hip/amd_detail/amd_hip_bf16.h | 60 ++++++++++++++++--- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h index 4ddafeafe3..193ca9174b 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h +++ b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h @@ -116,6 +116,20 @@ #define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline #define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline +#if defined(__AVX512VL__) and defined(__AVX512BF16__) and not defined(__HIP_DEVICE_COMPILE__) +// Enable with -mavx512vl -mavx512bf16 +#if defined(__MINGW64__) +#include <intrin.h> +#else +#include <immintrin.h> +#endif +#define HIP_BF16_AVX512_OP 1 +static_assert(sizeof(__bf16) == sizeof(unsigned short), + "sizeof __bf16 should match sizeof unsigned short"); +#else +#define HIP_BF16_AVX512_OP 0 +#endif + #define HIPRT_ONE_BF16 __float2bfloat16(1.0f) #define HIPRT_ZERO_BF16 __float2bfloat16(0.0f) #define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) @@ -140,7 +154,6 @@ typedef struct __attribute__((aligned(2))) { unsigned short x; } __hip_bfloat16_raw; - /** * \ingroup HIP_INTRINSIC_BFLOAT162_RAW * \brief represents raw bfloat16x2 vector type @@ -159,14 +172,29 @@ typedef struct __attribute__((aligned(4))) { struct 
__attribute__((aligned(2))) __hip_bfloat16 { private: __BF16_HOST_DEVICE_STATIC__ float bfloatraw_2_float(unsigned short val) { +#if HIP_BF16_AVX512_OP + union { + unsigned short us; + __bf16 bf16; + } u = {val}; + return _mm_cvtsbh_ss(u.bf16); +#else unsigned int uval = val << 16; union { unsigned int u32; float fp32; } u = {uval}; return u.fp32; +#endif } __BF16_HOST_DEVICE_STATIC__ unsigned short float_2_bfloatraw(float f) { +#if HIP_BF16_AVX512_OP + union { + __bf16 bf16; + unsigned short us; + } u = {_mm_cvtness_sbh(f)}; + return u.us; +#else union { float fp32; unsigned int u32; @@ -201,7 +229,9 @@ struct __attribute__((aligned(2))) __hip_bfloat16 { u.u32 |= 0x10000; // Preserve signaling NaN } return static_cast<unsigned short>(u.u32 >> 16); +#endif } + __BF16_HOST_DEVICE_STATIC__ unsigned short double_2_bfloatraw(double d_in) { union { float fp32; @@ -393,6 +423,7 @@ struct __attribute__((aligned(4))) __hip_bfloat162 { __hip_bfloat16 x; /*! \brief raw representation of bfloat16 */ __hip_bfloat16 y; /*! \brief raw representation of bfloat16 */ + public: /*! \brief create __hip_bfloat162 from __hip_bfloat162_raw */ __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162_raw& h2r) @@ -422,22 +453,34 @@ struct __attribute__((aligned(4))) __hip_bfloat162 { /*! \brief return a float2 */ __BF16_HOST_DEVICE__ operator float2() const { +#if HIP_BF16_AVX512_OP + union { + __hip_bfloat162_raw raw2; + __bf16 bf162[2]; + static_assert(sizeof(__bf16[2]) == sizeof(__hip_bfloat162_raw)); + } u; + u.raw2 = *this; + __m128bh pbf16{u.bf162[0], u.bf162[1], 0, 0}; + __m128 pf32 = _mm_cvtpbh_ps(pbf16); + float2 ret(pf32[0], pf32[1]); +#else float2 ret(x, y); +#endif return ret; } /*! 
\brief assign value from __hip_bfloat162_raw */ __BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162_raw& h2r) { - x = __hip_bfloat16(__hip_bfloat16{h2r.x}); - y = __hip_bfloat16(__hip_bfloat16{h2r.y}); + x = __hip_bfloat16(__hip_bfloat16_raw{h2r.x}); + y = __hip_bfloat16(__hip_bfloat16_raw{h2r.y}); return *this; } /*! \brief assign value from __hip_bfloat162 */ __BF16_HOST_DEVICE__ __hip_bfloat162& operator=(const __hip_bfloat162& src) { __hip_bfloat162_raw hr = src; - x = __hip_bfloat16(__hip_bfloat16{hr.x}); - y = __hip_bfloat16(__hip_bfloat16{hr.y}); + x = __hip_bfloat16(__hip_bfloat16_raw{hr.x}); + y = __hip_bfloat16(__hip_bfloat16_raw{hr.y}); return *this; } }; @@ -466,9 +509,8 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __float2bfloat16(float f) { * \brief Converts and moves bfloat162 to float2 */ __BF16_HOST_DEVICE_STATIC__ float2 __bfloat1622float2(const __hip_bfloat162 a) { - __hip_bfloat162_raw hr = a; - return float2{__bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.x})), - __bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.y}))}; + float2 ret = a; + return ret; } /** @@ -576,7 +618,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) { __hip_bfloat162_raw hr = a; - return __hip_bfloat162(__hip_bfloat162{hr.x, hr.x}); + return __hip_bfloat162(hr.x, hr.x); } /**