SWDEV-539481 - Add _rn variants of fp16 APIs (#582)
* Add _rn variants of fp16 APIs * cover bf16 as well
This commit is contained in:
gecommit door
GitHub
bovenliggende
c60888b0cd
commit
4d5fe2206d
@@ -781,6 +781,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const
|
||||
return (__bf16)a + (__bf16)b;
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
||||
* \brief Adds two bfloat16 values, will not fuse into fma
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd_rn(const __hip_bfloat16 a,
|
||||
const __hip_bfloat16 b) {
|
||||
#pragma clang fp contract(off)
|
||||
return (__bf16)a + (__bf16)b;
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
||||
* \brief Subtracts two bfloat16 values
|
||||
@@ -789,6 +799,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const
|
||||
return (__bf16)a - (__bf16)b;
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
||||
* \brief Subtracts two bfloat16 values, will not fuse into fma
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub_rn(const __hip_bfloat16 a,
|
||||
const __hip_bfloat16 b) {
|
||||
#pragma clang fp contract(off)
|
||||
return (__bf16)a - (__bf16)b;
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
||||
* \brief Divides two bfloat16 values
|
||||
@@ -815,6 +835,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const
|
||||
return (__bf16)a * (__bf16)b;
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
||||
* \brief Multiplies two bfloat16 values, will not fuse into fma
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul_rn(const __hip_bfloat16 a,
|
||||
const __hip_bfloat16 b) {
|
||||
#pragma clang fp contract(off)
|
||||
return (__bf16)a * (__bf16)b;
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
||||
* \brief Negate a bfloat16 value
|
||||
@@ -861,6 +891,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a,
|
||||
return __hip_bfloat162{__bf16_2(a) + __bf16_2(b)};
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
||||
* \brief Adds two bfloat162 values, will not fuse into fma
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2_rn(const __hip_bfloat162 a,
|
||||
const __hip_bfloat162 b) {
|
||||
#pragma clang fp contract(off)
|
||||
return __hip_bfloat162{__bf16_2(a) + __bf16_2(b)};
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
||||
* \brief Performs FMA of given bfloat162 values
|
||||
@@ -879,6 +919,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a,
|
||||
return __hip_bfloat162{__bf16_2(a) * __bf16_2(b)};
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
||||
* \brief Multiplies two bfloat162 values, will not fuse into fma
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2_rn(const __hip_bfloat162 a,
|
||||
const __hip_bfloat162 b) {
|
||||
#pragma clang fp contract(off)
|
||||
return __hip_bfloat162{__bf16_2(a) * __bf16_2(b)};
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
||||
* \brief Converts a bfloat162 into negative
|
||||
@@ -896,6 +946,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a,
|
||||
return __hip_bfloat162{__bf16_2(a) - __bf16_2(b)};
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
|
||||
* \brief Subtracts two bfloat162 values, will not fuse into fma
|
||||
*/
|
||||
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2_rn(const __hip_bfloat162 a,
|
||||
const __hip_bfloat162 b) {
|
||||
#pragma clang fp contract(off)
|
||||
return __hip_bfloat162{__bf16_2(a) - __bf16_2(b)};
|
||||
}
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
|
||||
* \brief Operator to multiply two __hip_bfloat16 numbers
|
||||
|
||||
@@ -1367,6 +1367,15 @@ THE SOFTWARE.
|
||||
return __half_raw{
|
||||
static_cast<__half_raw>(x).data +
|
||||
static_cast<__half_raw>(y).data};
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half __hadd_rn(__half x, __half y)
|
||||
{
|
||||
#pragma clang fp contract(off)
|
||||
return __half_raw{
|
||||
static_cast<__half_raw>(x).data +
|
||||
static_cast<__half_raw>(y).data};
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
@@ -1390,6 +1399,15 @@ THE SOFTWARE.
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half __hsub_rn(__half x, __half y)
|
||||
{
|
||||
#pragma clang fp contract(off)
|
||||
return __half_raw{
|
||||
static_cast<__half_raw>(x).data -
|
||||
static_cast<__half_raw>(y).data};
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half __hmul(__half x, __half y)
|
||||
{
|
||||
return __half_raw{
|
||||
@@ -1398,6 +1416,15 @@ THE SOFTWARE.
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half __hmul_rn(__half x, __half y)
|
||||
{
|
||||
#pragma clang fp contract(off)
|
||||
return __half_raw{
|
||||
static_cast<__half_raw>(x).data *
|
||||
static_cast<__half_raw>(y).data};
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half __hadd_sat(__half x, __half y)
|
||||
{
|
||||
return __clamp_01(__hadd(x, y));
|
||||
@@ -1446,6 +1473,16 @@ THE SOFTWARE.
|
||||
static_cast<__half2_raw>(x).data +
|
||||
static_cast<__half2_raw>(y).data};
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half2 __hadd2_rn(__half2 x, __half2 y)
|
||||
{
|
||||
#pragma clang fp contract(off)
|
||||
return __half2{
|
||||
static_cast<__half2_raw>(x).data +
|
||||
static_cast<__half2_raw>(y).data};
|
||||
}
|
||||
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half2 __habs2(__half2 x)
|
||||
@@ -1462,6 +1499,15 @@ THE SOFTWARE.
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half2 __hsub2_rn(__half2 x, __half2 y)
|
||||
{
|
||||
#pragma clang fp contract(off)
|
||||
return __half2{
|
||||
static_cast<__half2_raw>(x).data -
|
||||
static_cast<__half2_raw>(y).data};
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half2 __hmul2(__half2 x, __half2 y)
|
||||
{
|
||||
return __half2{
|
||||
@@ -1470,6 +1516,15 @@ THE SOFTWARE.
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half2 __hmul2_rn(__half2 x, __half2 y)
|
||||
{
|
||||
#pragma clang fp contract(off)
|
||||
return __half2{
|
||||
static_cast<__half2_raw>(x).data *
|
||||
static_cast<__half2_raw>(y).data};
|
||||
}
|
||||
inline
|
||||
__HOST_DEVICE__
|
||||
__half2 __hadd2_sat(__half2 x, __half2 y)
|
||||
{
|
||||
auto r = static_cast<__half2_raw>(__hadd2(x, y));
|
||||
|
||||
Verwijs in nieuw issue
Block a user