SWDEV-466747 - add shfl functions in bfloat16
Change-Id: Ide7d7e1d449783cced8867abf43ff45f5bce113a
[ROCm/clr commit: e43176bde9]
Этот коммит содержится в:
коммит произвёл
Jatin Jaikishan Chaudhary
родитель
f33637b1c6
Коммит
0193d66679
@@ -69,6 +69,18 @@
|
||||
* To use these functions, include the header file \p hip_bf16.h in your program.
|
||||
*/
|
||||
|
||||
/**
|
||||
* \defgroup HIP_INTRINSIC_BFLOAT16_MOVE Bfloat16 Data Movement Functions
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16
|
||||
* To use these functions, include the header file \p hip_bf16.h in your program.
|
||||
*/
|
||||
|
||||
/**
|
||||
* \defgroup HIP_INTRINSIC_BFLOAT162_MOVE Bfloat162 Data Movement Functions
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16
|
||||
* To use these functions, include the header file \p hip_bf16.h in your program.
|
||||
*/
|
||||
|
||||
/**
|
||||
* \defgroup HIP_INTRINSIC_BFLOAT16_MATH Bfloat16 Math Functions
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16
|
||||
@@ -98,7 +110,8 @@
|
||||
|
||||
#if !defined(__HIPCC_RTC__)
|
||||
#include <hip/amd_detail/amd_hip_common.h>
|
||||
#endif // !defined(__HIPCC_RTC__)
|
||||
#include <hip/amd_detail/amd_warp_functions.h> // Sync functions
|
||||
#endif // !defined(__HIPCC_RTC__)
|
||||
|
||||
#include "amd_hip_vector_types.h" // float2 etc
|
||||
#include "device_library_decls.h" // ocml conversion functions
|
||||
@@ -133,8 +146,8 @@ static_assert(sizeof(__bf16) == sizeof(unsigned short),
|
||||
#define HIP_BF16_AVX512_OP 0
|
||||
#endif
|
||||
|
||||
#define HIPRT_ONE_BF16 __float2bfloat16(1.0f)
|
||||
#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f)
|
||||
#define HIPRT_ONE_BF16 __ushort_as_bfloat16((unsigned short)0x3F80U)
|
||||
#define HIPRT_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x0000U)
|
||||
#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
|
||||
#define HIPRT_MAX_NORMAL_BF16 __ushort_as_bfloat16((unsigned short)0x7F7FU)
|
||||
#define HIPRT_MIN_DENORM_BF16 __ushort_as_bfloat16((unsigned short)0x0001U)
|
||||
@@ -525,8 +538,12 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfl
|
||||
* \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ short int __bfloat16_as_short(const __hip_bfloat16 h) {
|
||||
short ret = h;
|
||||
return ret;
|
||||
static_assert(sizeof(__hip_bfloat16) == sizeof(short int));
|
||||
union {
|
||||
__hip_bfloat16 bf16;
|
||||
short int si;
|
||||
} u{h};
|
||||
return u.si;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -534,8 +551,12 @@ __BF16_HOST_DEVICE_STATIC__ short int __bfloat16_as_short(const __hip_bfloat16 h
|
||||
* \brief Reinterprets bits in a __hip_bfloat16 as an unsigned short integer
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) {
|
||||
unsigned short ret = h;
|
||||
return ret;
|
||||
static_assert(sizeof(__hip_bfloat16) == sizeof(unsigned short int));
|
||||
union {
|
||||
__hip_bfloat16 bf16;
|
||||
unsigned short int usi;
|
||||
} u{h};
|
||||
return u.usi;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -599,9 +620,7 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
||||
* \brief Returns low 16 bits of __hip_bfloat162
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) {
|
||||
return __hip_bfloat16(a.x);
|
||||
}
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; }
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
||||
@@ -641,7 +660,12 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat1
|
||||
* \brief Reinterprets short int into a bfloat16
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __short_as_bfloat16(const short int a) {
|
||||
return __hip_bfloat16(a);
|
||||
static_assert(sizeof(__hip_bfloat16) == sizeof(short int));
|
||||
union {
|
||||
short int si;
|
||||
__hip_bfloat16 bf16;
|
||||
} u{a};
|
||||
return u.bf16;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -649,9 +673,124 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __short_as_bfloat16(const short int a
|
||||
* \brief Reinterprets unsigned short int into a bfloat16
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) {
|
||||
return __hip_bfloat16(a);
|
||||
static_assert(sizeof(__hip_bfloat16) == sizeof(unsigned short int));
|
||||
union {
|
||||
unsigned short int usi;
|
||||
__hip_bfloat16 bf16;
|
||||
} u{a};
|
||||
return u.bf16;
|
||||
}
|
||||
|
||||
#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_MOVE
|
||||
* \brief shfl down sync warp intrinsic for bfloat16
|
||||
*/
|
||||
__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_down_sync(const unsigned long long mask,
|
||||
const __hip_bfloat16 in,
|
||||
const unsigned int delta,
|
||||
const int width = __AMDGCN_WAVEFRONT_SIZE) {
|
||||
return __ushort_as_bfloat16(__shfl_down_sync(mask, __bfloat16_as_ushort(in), delta, width));
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_MOVE
|
||||
* \brief shfl down sync warp intrinsic for bfloat162
|
||||
*/
|
||||
__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_down_sync(const unsigned long long mask,
|
||||
const __hip_bfloat162 in,
|
||||
const unsigned int delta,
|
||||
const int width = __AMDGCN_WAVEFRONT_SIZE) {
|
||||
static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
|
||||
union {
|
||||
__hip_bfloat162 bf162;
|
||||
unsigned int ui;
|
||||
} u{in};
|
||||
u.ui = __shfl_down_sync<unsigned long long, unsigned int>(mask, u.ui, delta, width);
|
||||
return u.bf162;
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_MOVE
|
||||
* \brief shfl sync warp intrinsic for bfloat16
|
||||
*/
|
||||
__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_sync(const unsigned long long mask,
|
||||
const __hip_bfloat16 in, const int delta,
|
||||
const int width = __AMDGCN_WAVEFRONT_SIZE) {
|
||||
return __ushort_as_bfloat16(__shfl_sync(mask, __bfloat16_as_ushort(in), delta, width));
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_MOVE
|
||||
* \brief shfl sync warp intrinsic for bfloat162
|
||||
*/
|
||||
__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_sync(const unsigned long long mask,
|
||||
const __hip_bfloat162 in, const int delta,
|
||||
const int width = __AMDGCN_WAVEFRONT_SIZE) {
|
||||
static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
|
||||
union {
|
||||
__hip_bfloat162 bf162;
|
||||
unsigned int ui;
|
||||
} u{in};
|
||||
u.ui = __shfl_sync(mask, u.ui, delta, width);
|
||||
return u.bf162;
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_MOVE
|
||||
* \brief shfl up sync warp intrinsic for bfloat16
|
||||
*/
|
||||
__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_up_sync(const unsigned long long mask,
|
||||
const __hip_bfloat16 in,
|
||||
const unsigned int delta,
|
||||
const int width = __AMDGCN_WAVEFRONT_SIZE) {
|
||||
return __ushort_as_bfloat16(__shfl_up_sync(mask, __bfloat16_as_ushort(in), delta, width));
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_MOVE
|
||||
* \brief shfl up sync warp intrinsic for bfloat162
|
||||
*/
|
||||
__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_up_sync(const unsigned long long mask,
|
||||
const __hip_bfloat162 in,
|
||||
const unsigned int delta,
|
||||
const int width = __AMDGCN_WAVEFRONT_SIZE) {
|
||||
static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
|
||||
union {
|
||||
__hip_bfloat162 bf162;
|
||||
unsigned int ui;
|
||||
} u{in};
|
||||
u.ui = __shfl_up_sync(mask, u.ui, delta, width);
|
||||
return u.bf162;
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_MOVE
|
||||
* \brief shfl xor sync warp intrinsic for bfloat16
|
||||
*/
|
||||
__BF16_DEVICE_STATIC__ __hip_bfloat16 __shfl_xor_sync(const unsigned long long mask,
|
||||
const __hip_bfloat16 in, const int delta,
|
||||
const int width = __AMDGCN_WAVEFRONT_SIZE) {
|
||||
return __ushort_as_bfloat16(__shfl_xor_sync(mask, __bfloat16_as_ushort(in), delta, width));
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_MOVE
|
||||
* \brief shfl xor sync warp intrinsic for bfloat162
|
||||
*/
|
||||
__BF16_DEVICE_STATIC__ __hip_bfloat162 __shfl_xor_sync(const unsigned long long mask,
|
||||
const __hip_bfloat162 in, const int delta,
|
||||
const int width = __AMDGCN_WAVEFRONT_SIZE) {
|
||||
static_assert(sizeof(__hip_bfloat162) == sizeof(unsigned int));
|
||||
union {
|
||||
__hip_bfloat162 bf162;
|
||||
unsigned int ui;
|
||||
} u{in};
|
||||
u.ui = __shfl_xor_sync(mask, u.ui, delta, width);
|
||||
return u.bf162;
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
||||
* \brief Adds two bfloat16 values
|
||||
|
||||
Ссылка в новой задаче
Block a user