SWDEV-466747 - add shfl functions in bfloat16

Change-Id: Ide7d7e1d449783cced8867abf43ff45f5bce113a


[ROCm/clr commit: e43176bde9]
Этот коммит содержится в:
Jatin Chaudhary
2024-07-15 14:05:52 +01:00
коммит произвёл Jatin Jaikishan Chaudhary
родитель f33637b1c6
Коммит 0193d66679
+151 -12
Просмотреть файл
@@ -69,6 +69,18 @@
* To use these functions, include the header file \p hip_bf16.h in your program.
*/
/**
* \defgroup HIP_INTRINSIC_BFLOAT16_MOVE Bfloat16 Data Movement Functions
* \ingroup HIP_INTRINSIC_BFLOAT16
* To use these functions, include the header file \p hip_bf16.h in your program.
*/
/**
* \defgroup HIP_INTRINSIC_BFLOAT162_MOVE Bfloat162 Data Movement Functions
* \ingroup HIP_INTRINSIC_BFLOAT16
* To use these functions, include the header file \p hip_bf16.h in your program.
*/
/**
* \defgroup HIP_INTRINSIC_BFLOAT16_MATH Bfloat16 Math Functions
* \ingroup HIP_INTRINSIC_BFLOAT16
@@ -98,7 +110,8 @@
#if !defined(__HIPCC_RTC__)
#include <hip/amd_detail/amd_hip_common.h>
#endif // !defined(__HIPCC_RTC__)
#include <hip/amd_detail/amd_warp_functions.h> // Sync functions
#endif // !defined(__HIPCC_RTC__)
#include "amd_hip_vector_types.h" // float2 etc
#include "device_library_decls.h" // ocml conversion functions
@@ -133,8 +146,8 @@ static_assert(sizeof(__bf16) == sizeof(unsigned short),
#define HIP_BF16_AVX512_OP 0
#endif
// NOTE(review): this listing is a diff rendering whose +/- markers were lost. Per the hunk
// header (-133,8 +146,8) the two __float2bfloat16 definitions below are the lines removed by
// the commit.
#define HIPRT_ONE_BF16 __float2bfloat16(1.0f)
#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f)
// Replacements added by the commit: ONE and ZERO are now spelled as raw 16-bit patterns via
// __ushort_as_bfloat16, the same form the remaining constants already use.
#define HIPRT_ONE_BF16 __ushort_as_bfloat16((unsigned short)0x3F80U)
#define HIPRT_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x0000U)
// Unchanged context lines.
#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
#define HIPRT_MAX_NORMAL_BF16 __ushort_as_bfloat16((unsigned short)0x7F7FU)
#define HIPRT_MIN_DENORM_BF16 __ushort_as_bfloat16((unsigned short)0x0001U)
@@ -525,8 +538,12 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfl
* \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer
*/
__BF16_HOST_DEVICE_STATIC__ short int __bfloat16_as_short(const __hip_bfloat16 h) {
// NOTE(review): diff rendering with the +/- markers lost. Per the hunk header
// (-525,8 +538,12) the next two lines are the body removed by the commit; it initialised a
// short directly from h.
short ret = h;
return ret;
// Body added by the commit: h is stored in a union with a short int of the same size
// (enforced by the static_assert) and the short int member is returned.
static_assert(sizeof(__hip_bfloat16) == sizeof(short int));
union {
__hip_bfloat16 bf16;
short int si;
} u{h};
return u.si;
}
/**
@@ -534,8 +551,12 @@ __BF16_HOST_DEVICE_STATIC__ short int __bfloat16_as_short(const __hip_bfloat16 h
* \brief Reinterprets bits in a __hip_bfloat16 as an unsigned short integer
*/
__BF16_HOST_DEVICE_STATIC__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) {
// NOTE(review): diff rendering with the +/- markers lost. Per the hunk header
// (-534,8 +551,12) the next two lines are the body removed by the commit; it initialised an
// unsigned short directly from h.
unsigned short ret = h;
return ret;
// Body added by the commit: h is stored in a union with an unsigned short int of the same
// size (enforced by the static_assert) and the unsigned short int member is returned.
static_assert(sizeof(__hip_bfloat16) == sizeof(unsigned short int));
union {
__hip_bfloat16 bf16;
unsigned short int usi;
} u{h};
return u.usi;
}
/**
@@ -599,9 +620,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
* \brief Returns low 16 bits of __hip_bfloat162
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) {
// NOTE(review): diff rendering with the +/- markers lost. Per the hunk header
// (-599,9 +620,7) this three-line definition is the one removed by the commit.
return __hip_bfloat16(a.x);
}
// Definition added by the commit: returns a.x directly instead of wrapping it in an explicit
// __hip_bfloat16(...) construction.
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; }
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
@@ -641,7 +660,12 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat1
* \brief Reinterprets short int into a bfloat16
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __short_as_bfloat16(const short int a) {
// NOTE(review): diff rendering with the +/- markers lost. Per the hunk header
// (-641,7 +660,12) the next line is the body removed by the commit; it constructed a
// __hip_bfloat16 from the short.
return __hip_bfloat16(a);
// Body added by the commit: a is stored in a union with a __hip_bfloat16 of the same size
// (enforced by the static_assert) and the __hip_bfloat16 member is returned.
static_assert(sizeof(__hip_bfloat16) == sizeof(short int));
union {
short int si;
__hip_bfloat16 bf16;
} u{a};
return u.bf16;
}
/**
@@ -649,9 +673,124 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __short_as_bfloat16(const short int a
* \brief Reinterprets unsigned short int into a bfloat16
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) {
// NOTE(review): diff rendering with the +/- markers lost. Per the hunk header
// (-649,9 +673,124) the next line is the body removed by the commit; it constructed a
// __hip_bfloat16 from the unsigned short.
return __hip_bfloat16(a);
// Body added by the commit: a is stored in a union with a __hip_bfloat16 of the same size
// (enforced by the static_assert) and the __hip_bfloat16 member is returned.
static_assert(sizeof(__hip_bfloat16) == sizeof(unsigned short int));
union {
unsigned short int usi;
__hip_bfloat16 bf16;
} u{a};
return u.bf16;
}
#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS
/**
 * \ingroup HIP_INTRINSIC_BFLOAT16_MOVE
 * \brief shfl down sync warp intrinsic for bfloat16
 *
 * Reinterprets \p in as an unsigned short, hands it together with \p mask, \p delta and
 * \p width to __shfl_down_sync, and reinterprets the returned bits as a __hip_bfloat16.
 */
__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_down_sync(const unsigned long long mask,
                                                       const __hip_bfloat16 in,
                                                       const unsigned int delta,
                                                       const int width = __AMDGCN_WAVEFRONT_SIZE) {
  const unsigned short int bits = __bfloat16_as_ushort(in);
  const unsigned short int shuffled = __shfl_down_sync(mask, bits, delta, width);
  return __ushort_as_bfloat16(shuffled);
}
/**
 * \ingroup HIP_INTRINSIC_BFLOAT162_MOVE
 * \brief shfl down sync warp intrinsic for bfloat162
 *
 * The whole __hip_bfloat162 is viewed as one unsigned int through a union, that unsigned int is
 * passed with \p mask, \p delta and \p width to __shfl_down_sync, and the result is read back
 * as a __hip_bfloat162.
 */
__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_down_sync(const unsigned long long mask,
                                                        const __hip_bfloat162 in,
                                                        const unsigned int delta,
                                                        const int width = __AMDGCN_WAVEFRONT_SIZE) {
  // The union view below is only meaningful when both members have the same size.
  static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
  union {
    __hip_bfloat162 pair;
    unsigned int bits;
  } pun{in};
  // Explicit template arguments: only the function-template form of __shfl_down_sync is a
  // candidate for this call.
  const unsigned int moved =
      __shfl_down_sync<unsigned long long, unsigned int>(mask, pun.bits, delta, width);
  pun.bits = moved;
  return pun.pair;
}
/**
 * \ingroup HIP_INTRINSIC_BFLOAT16_MOVE
 * \brief shfl sync warp intrinsic for bfloat16
 *
 * Reinterprets \p in as an unsigned short, hands it together with \p mask, \p delta and
 * \p width to __shfl_sync, and reinterprets the returned bits as a __hip_bfloat16.
 */
__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_sync(const unsigned long long mask,
                                                  const __hip_bfloat16 in, const int delta,
                                                  const int width = __AMDGCN_WAVEFRONT_SIZE) {
  const unsigned short int bits = __bfloat16_as_ushort(in);
  const unsigned short int shuffled = __shfl_sync(mask, bits, delta, width);
  return __ushort_as_bfloat16(shuffled);
}
/**
 * \ingroup HIP_INTRINSIC_BFLOAT162_MOVE
 * \brief shfl sync warp intrinsic for bfloat162
 *
 * The whole __hip_bfloat162 is viewed as one unsigned int through a union, that unsigned int is
 * passed with \p mask, \p delta and \p width to __shfl_sync, and the result is read back as a
 * __hip_bfloat162.
 */
__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_sync(const unsigned long long mask,
                                                   const __hip_bfloat162 in, const int delta,
                                                   const int width = __AMDGCN_WAVEFRONT_SIZE) {
  // The union view below is only meaningful when both members have the same size.
  static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
  union {
    __hip_bfloat162 pair;
    unsigned int bits;
  } pun{in};
  const unsigned int moved = __shfl_sync(mask, pun.bits, delta, width);
  pun.bits = moved;
  return pun.pair;
}
/**
 * \ingroup HIP_INTRINSIC_BFLOAT16_MOVE
 * \brief shfl up sync warp intrinsic for bfloat16
 *
 * Reinterprets \p in as an unsigned short, hands it together with \p mask, \p delta and
 * \p width to __shfl_up_sync, and reinterprets the returned bits as a __hip_bfloat16.
 */
__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_up_sync(const unsigned long long mask,
                                                     const __hip_bfloat16 in,
                                                     const unsigned int delta,
                                                     const int width = __AMDGCN_WAVEFRONT_SIZE) {
  const unsigned short int bits = __bfloat16_as_ushort(in);
  const unsigned short int shuffled = __shfl_up_sync(mask, bits, delta, width);
  return __ushort_as_bfloat16(shuffled);
}
/**
 * \ingroup HIP_INTRINSIC_BFLOAT162_MOVE
 * \brief shfl up sync warp intrinsic for bfloat162
 *
 * The whole __hip_bfloat162 is viewed as one unsigned int through a union, that unsigned int is
 * passed with \p mask, \p delta and \p width to __shfl_up_sync, and the result is read back as
 * a __hip_bfloat162.
 */
__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_up_sync(const unsigned long long mask,
                                                      const __hip_bfloat162 in,
                                                      const unsigned int delta,
                                                      const int width = __AMDGCN_WAVEFRONT_SIZE) {
  // The union view below is only meaningful when both members have the same size.
  static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
  union {
    __hip_bfloat162 pair;
    unsigned int bits;
  } pun{in};
  const unsigned int moved = __shfl_up_sync(mask, pun.bits, delta, width);
  pun.bits = moved;
  return pun.pair;
}
/**
 * \ingroup HIP_INTRINSIC_BFLOAT16_MOVE
 * \brief shfl xor sync warp intrinsic for bfloat16
 *
 * Reinterprets \p in as an unsigned short, hands it together with \p mask, \p delta and
 * \p width to __shfl_xor_sync, and reinterprets the returned bits as a __hip_bfloat16.
 */
__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_xor_sync(const unsigned long long mask,
                                                      const __hip_bfloat16 in, const int delta,
                                                      const int width = __AMDGCN_WAVEFRONT_SIZE) {
  const unsigned short int bits = __bfloat16_as_ushort(in);
  const unsigned short int shuffled = __shfl_xor_sync(mask, bits, delta, width);
  return __ushort_as_bfloat16(shuffled);
}
/**
 * \ingroup HIP_INTRINSIC_BFLOAT162_MOVE
 * \brief shfl xor sync warp intrinsic for bfloat162
 *
 * The whole __hip_bfloat162 is viewed as one unsigned int through a union, that unsigned int is
 * passed with \p mask, \p delta and \p width to __shfl_xor_sync, and the result is read back as
 * a __hip_bfloat162.
 */
__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_xor_sync(const unsigned long long mask,
                                                       const __hip_bfloat162 in, const int delta,
                                                       const int width = __AMDGCN_WAVEFRONT_SIZE) {
  // The union view below is only meaningful when both members have the same size.
  static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
  union {
    __hip_bfloat162 pair;
    unsigned int bits;
  } pun{in};
  const unsigned int moved = __shfl_xor_sync(mask, pun.bits, delta, width);
  pun.bits = moved;
  return pun.pair;
}
#endif
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Adds two bfloat16 values