SWDEV-539481 - Add _rn variants of fp16 APIs (#582)

* Add _rn variants of fp16 APIs
* cover bf16 as well
This commit is contained in:
Chaudhary, Jatin Jaikishan
2025-08-12 17:28:38 +01:00
gecommit door GitHub
bovenliggende c60888b0cd
commit 4d5fe2206d
2 gewijzigde bestanden met toevoegingen van 115 en 0 verwijderingen
@@ -781,6 +781,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const
return (__bf16)a + (__bf16)b;
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Adds two bfloat16 values, will not fuse into fma
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd_rn(const __hip_bfloat16 a,
const __hip_bfloat16 b) {
#pragma clang fp contract(off)
return (__bf16)a + (__bf16)b;
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Subtracts two bfloat16 values
@@ -789,6 +799,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const
return (__bf16)a - (__bf16)b;
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Subtracts two bfloat16 values, will not fuse into fma
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub_rn(const __hip_bfloat16 a,
const __hip_bfloat16 b) {
#pragma clang fp contract(off)
return (__bf16)a - (__bf16)b;
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Divides two bfloat16 values
@@ -815,6 +835,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const
return (__bf16)a * (__bf16)b;
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Multiplies two bfloat16 values, will not fuse into fma
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul_rn(const __hip_bfloat16 a,
const __hip_bfloat16 b) {
#pragma clang fp contract(off)
return (__bf16)a * (__bf16)b;
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Negate a bfloat16 value
@@ -861,6 +891,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a,
return __hip_bfloat162{__bf16_2(a) + __bf16_2(b)};
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
* \brief Adds two bfloat162 values, will not fuse into fma
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2_rn(const __hip_bfloat162 a,
const __hip_bfloat162 b) {
#pragma clang fp contract(off)
return __hip_bfloat162{__bf16_2(a) + __bf16_2(b)};
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
* \brief Performs FMA of given bfloat162 values
@@ -879,6 +919,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a,
return __hip_bfloat162{__bf16_2(a) * __bf16_2(b)};
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
* \brief Multiplies two bfloat162 values, will not fuse into fma
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2_rn(const __hip_bfloat162 a,
const __hip_bfloat162 b) {
#pragma clang fp contract(off)
return __hip_bfloat162{__bf16_2(a) * __bf16_2(b)};
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
* \brief Converts a bfloat162 into negative
@@ -896,6 +946,16 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a,
return __hip_bfloat162{__bf16_2(a) - __bf16_2(b)};
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT162_ARITH
* \brief Subtracts two bfloat162 values, will not fuse into fma
*/
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2_rn(const __hip_bfloat162 a,
const __hip_bfloat162 b) {
#pragma clang fp contract(off)
return __hip_bfloat162{__bf16_2(a) - __bf16_2(b)};
}
/**
* \ingroup HIP_INTRINSIC_BFLOAT16_ARITH
* \brief Operator to multiply two __hip_bfloat16 numbers
@@ -1367,6 +1367,15 @@ THE SOFTWARE.
return __half_raw{
static_cast<__half_raw>(x).data +
static_cast<__half_raw>(y).data};
}
inline
__HOST_DEVICE__
__half __hadd_rn(__half x, __half y)
{
#pragma clang fp contract(off)
return __half_raw{
static_cast<__half_raw>(x).data +
static_cast<__half_raw>(y).data};
}
inline
__HOST_DEVICE__
@@ -1390,6 +1399,15 @@ THE SOFTWARE.
}
inline
__HOST_DEVICE__
__half __hsub_rn(__half x, __half y)
{
#pragma clang fp contract(off)
return __half_raw{
static_cast<__half_raw>(x).data -
static_cast<__half_raw>(y).data};
}
inline
__HOST_DEVICE__
__half __hmul(__half x, __half y)
{
return __half_raw{
@@ -1398,6 +1416,15 @@ THE SOFTWARE.
}
inline
__HOST_DEVICE__
__half __hmul_rn(__half x, __half y)
{
#pragma clang fp contract(off)
return __half_raw{
static_cast<__half_raw>(x).data *
static_cast<__half_raw>(y).data};
}
inline
__HOST_DEVICE__
__half __hadd_sat(__half x, __half y)
{
return __clamp_01(__hadd(x, y));
@@ -1446,6 +1473,16 @@ THE SOFTWARE.
static_cast<__half2_raw>(x).data +
static_cast<__half2_raw>(y).data};
}
inline
__HOST_DEVICE__
__half2 __hadd2_rn(__half2 x, __half2 y)
{
#pragma clang fp contract(off)
return __half2{
static_cast<__half2_raw>(x).data +
static_cast<__half2_raw>(y).data};
}
inline
__HOST_DEVICE__
__half2 __habs2(__half2 x)
@@ -1462,6 +1499,15 @@ THE SOFTWARE.
}
inline
__HOST_DEVICE__
__half2 __hsub2_rn(__half2 x, __half2 y)
{
#pragma clang fp contract(off)
return __half2{
static_cast<__half2_raw>(x).data -
static_cast<__half2_raw>(y).data};
}
inline
__HOST_DEVICE__
__half2 __hmul2(__half2 x, __half2 y)
{
return __half2{
@@ -1470,6 +1516,15 @@ THE SOFTWARE.
}
inline
__HOST_DEVICE__
__half2 __hmul2_rn(__half2 x, __half2 y)
{
#pragma clang fp contract(off)
return __half2{
static_cast<__half2_raw>(x).data *
static_cast<__half2_raw>(y).data};
}
inline
__HOST_DEVICE__
__half2 __hadd2_sat(__half2 x, __half2 y)
{
auto r = static_cast<__half2_raw>(__hadd2(x, y));