From 0193d6667918c8cc830ff7a3ce2dd24a4a0a6959 Mon Sep 17 00:00:00 2001 From: Jatin Chaudhary Date: Mon, 15 Jul 2024 14:05:52 +0100 Subject: [PATCH] SWDEV-466747 - add shfl functions in bfloat16 Change-Id: Ide7d7e1d449783cced8867abf43ff45f5bce113a [ROCm/clr commit: e43176bde93ee82ba402d3102c18b6acddc566ed] --- .../include/hip/amd_detail/amd_hip_bf16.h | 163 ++++++++++++++++-- 1 file changed, 151 insertions(+), 12 deletions(-) diff --git a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h index 42e91ee4f2..6deaab77fa 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h +++ b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h @@ -69,6 +69,18 @@ * To use these functions, include the header file \p hip_bf16.h in your program. */ +/** + * \defgroup HIP_INTRINSIC_BFLOAT16_MOVE Bfloat16 Data Movement Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT162_MOVE Bfloat162 Data Movement Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your program. + */ + /** * \defgroup HIP_INTRINSIC_BFLOAT16_MATH Bfloat16 Math Functions * \ingroup HIP_INTRINSIC_BFLOAT16 @@ -98,7 +110,8 @@ #if !defined(__HIPCC_RTC__) #include -#endif // !defined(__HIPCC_RTC__) +#include // Sync functions +#endif // !defined(__HIPCC_RTC__) #include "amd_hip_vector_types.h" // float2 etc #include "device_library_decls.h" // ocml conversion functions @@ -133,8 +146,8 @@ static_assert(sizeof(__bf16) == sizeof(unsigned short), #define HIP_BF16_AVX512_OP 0 #endif -#define HIPRT_ONE_BF16 __float2bfloat16(1.0f) -#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f) +#define HIPRT_ONE_BF16 __ushort_as_bfloat16((unsigned short)0x3F80U) +#define HIPRT_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x0000U) #define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) #define HIPRT_MAX_NORMAL_BF16 __ushort_as_bfloat16((unsigned short)0x7F7FU) #define HIPRT_MIN_DENORM_BF16 __ushort_as_bfloat16((unsigned short)0x0001U) @@ -525,8 +538,12 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfl * \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer */ __BF16_HOST_DEVICE_STATIC__ short int __bfloat16_as_short(const __hip_bfloat16 h) { - short ret = h; - return ret; + static_assert(sizeof(__hip_bfloat16) == sizeof(short int)); + union { + __hip_bfloat16 bf16; + short int si; + } u{h}; + return u.si; } /** @@ -534,8 +551,12 @@ __BF16_HOST_DEVICE_STATIC__ short int __bfloat16_as_short(const __hip_bfloat16 h * \brief Reinterprets bits in a __hip_bfloat16 as an unsigned signed short integer */ __BF16_HOST_DEVICE_STATIC__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { - unsigned short ret = h; - return ret; + static_assert(sizeof(__hip_bfloat16) == sizeof(unsigned short int)); + union { + __hip_bfloat16 bf16; + unsigned short int usi; + } u{h}; + return u.usi; } /** @@ -599,9 +620,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Returns low 16 bits of __hip_bfloat162 */ -__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { - return __hip_bfloat16(a.x); -} +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV @@ -641,7 +660,12 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat1 * \brief Reinterprets short int into a bfloat16 */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __short_as_bfloat16(const short int a) { - return __hip_bfloat16(a); + static_assert(sizeof(__hip_bfloat16) == sizeof(short int)); + union { + short int si; + __hip_bfloat16 bf16; + } u{a}; + return u.bf16; } /** @@ -649,9 +673,124 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __short_as_bfloat16(const short int a * \brief Reinterprets unsigned short int into a bfloat16 */ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { - return __hip_bfloat16(a); + static_assert(sizeof(__hip_bfloat16) == sizeof(unsigned short int)); + union { + unsigned short int usi; + __hip_bfloat16 bf16; + } u{a}; + return u.bf16; } +#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MOVE + * \brief shfl down warp intrinsic for bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_down_sync(const unsigned long long mask, + const __hip_bfloat16 in, + const unsigned int delta, + const int width = __AMDGCN_WAVEFRONT_SIZE) { + return __ushort_as_bfloat16(__shfl_down_sync(mask, __bfloat16_as_ushort(in), delta, width)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MOVE + * \brief shfl down warp intrinsic for bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_down_sync(const unsigned long long mask, + const __hip_bfloat162 in, + const unsigned int delta, + const int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int)); + union { + __hip_bfloat162 bf162; + unsigned int ui; + } u{in}; + u.ui = __shfl_down_sync(mask, u.ui, delta, width); + return u.bf162; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MOVE + * \brief shfl sync warp intrinsic for bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_sync(const unsigned long long mask, + const __hip_bfloat16 in, const int delta, + const int width = __AMDGCN_WAVEFRONT_SIZE) { + return __ushort_as_bfloat16(__shfl_sync(mask, __bfloat16_as_ushort(in), delta, width)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MOVE + * \brief shfl sync warp intrinsic for bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_sync(const unsigned long long mask, + const __hip_bfloat162 in, const int delta, + const int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int)); + union { + __hip_bfloat162 bf162; + unsigned int ui; + } u{in}; + u.ui = __shfl_sync(mask, u.ui, delta, width); + return u.bf162; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MOVE + * \brief shfl up sync warp intrinsic for bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_up_sync(const unsigned long long mask, + const __hip_bfloat16 in, + const unsigned int delta, + const int width = __AMDGCN_WAVEFRONT_SIZE) { + return __ushort_as_bfloat16(__shfl_up_sync(mask, __bfloat16_as_ushort(in), delta, width)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MOVE + * \brief shfl up sync warp intrinsic for bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_up_sync(const unsigned long long mask, + const __hip_bfloat162 in, + const unsigned int delta, + const int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int)); + union { + __hip_bfloat162 bf162; + unsigned int ui; + } u{in}; + u.ui = __shfl_up_sync(mask, u.ui, delta, width); + return u.bf162; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MOVE + * \brief shfl xor sync warp intrinsic for bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_xor_sync(const unsigned long long mask, + const __hip_bfloat16 in, const int delta, + const int width = __AMDGCN_WAVEFRONT_SIZE) { + return __ushort_as_bfloat16(__shfl_xor_sync(mask, __bfloat16_as_ushort(in), delta, width)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MOVE + * \brief shfl xor sync warp intrinsic for bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_xor_sync(const unsigned long long mask, + const __hip_bfloat162 in, const int delta, + const int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int)); + union { + __hip_bfloat162 bf162; + unsigned int ui; + } u{in}; + u.ui = __shfl_xor_sync(mask, u.ui, delta, width); + return u.bf162; +} +#endif + /** * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH * \brief Adds two bfloat16 values