SWDEV-365299 - update the implementation of __float2half_{rd,ru,rz} to call __ocml_cvt{rtn,rtp,rtz}_f16_f32

Change-Id: I6cd711fbeb0e02a1caa03ac7f3dd9c8f8fdbac01
This commit is contained in:
Anusha GodavarthySurya
2022-11-01 09:34:05 +00:00
parent 766c40aab8
commit f7ca4b8fb9
2 changed files with 32 additions and 4 deletions
+24 -4
View File
@@ -640,7 +640,6 @@ THE SOFTWARE.
return r;
}
// TODO: rounding behaviour is not correct.
// float -> half | half2
inline
__HOST_DEVICE__
@@ -654,24 +653,45 @@ THE SOFTWARE.
{
return __half_raw{static_cast<_Float16>(x)};
}
#if !defined(__HIPCC_RTC__)
// TODO: rounding behaviour is not correct for host functions.
inline
__HOST_DEVICE__
__host__
__half __float2half_rz(float x)
{
return __half_raw{static_cast<_Float16>(x)};
}
inline
__HOST_DEVICE__
__host__
__half __float2half_rd(float x)
{
return __half_raw{static_cast<_Float16>(x)};
}
inline
__HOST_DEVICE__
__host__
__half __float2half_ru(float x)
{
return __half_raw{static_cast<_Float16>(x)};
}
#endif
// Device overload of __float2half_rz.
// Hands the float to the OCML routine __ocml_cvtrtz_f16_f32 and wraps the
// _Float16 it returns in a __half_raw, which is then returned as __half.
// ("rtz" presumably means round-toward-zero -- confirm against OCML docs.)
inline __device__
__half __float2half_rz(float x)
{
    const _Float16 converted = __ocml_cvtrtz_f16_f32(x);
    return __half_raw{converted};
}
// Device overload of __float2half_rd.
// Hands the float to the OCML routine __ocml_cvtrtn_f16_f32 and wraps the
// _Float16 it returns in a __half_raw, which is then returned as __half.
// ("rtn" presumably means round-toward-negative-infinity -- confirm against
// OCML docs.)
inline __device__
__half __float2half_rd(float x)
{
    const _Float16 converted = __ocml_cvtrtn_f16_f32(x);
    return __half_raw{converted};
}
// Device overload of __float2half_ru.
// Hands the float to the OCML routine __ocml_cvtrtp_f16_f32 and wraps the
// _Float16 it returns in a __half_raw, which is then returned as __half.
// ("rtp" presumably means round-toward-positive-infinity -- confirm against
// OCML docs.)
inline __device__
__half __float2half_ru(float x)
{
    const _Float16 converted = __ocml_cvtrtp_f16_f32(x);
    return __half_raw{converted};
}
inline
__HOST_DEVICE__
__half2 __float2half2_rn(float x)
@@ -83,10 +83,18 @@ extern "C"
// OCML entry points operating on the packed __2f16 type.
__device__ __2f16 __ocml_sin_2f16(__2f16);
__device__ __attribute__((const)) __2f16 __ocml_sqrt_2f16(__2f16);
__device__ __attribute__((const)) __2f16 __ocml_trunc_2f16(__2f16);
// float -> _Float16 conversion entry points used by the device overloads of
// __float2half_rd / _ru / _rz. The rtn / rtp / rtz suffixes presumably select
// round-toward-negative / -positive / -zero -- confirm against the OCML docs.
__device__ __attribute__((const)) _Float16 __ocml_cvtrtn_f16_f32(float);
__device__ __attribute__((const)) _Float16 __ocml_cvtrtp_f16_f32(float);
__device__ __attribute__((const)) _Float16 __ocml_cvtrtz_f16_f32(float);
}
#endif // !__CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__
// TODO: remove these once they are available in the clang header
// __clang_hip_libdevice_declares.h.
// These declarations sit after the #endif that closes the
// !__CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__ section, i.e. outside that guard.
extern "C" {
__device__ __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
__device__ __attribute__((const)) _Float16 __ocml_fmin_f16(_Float16, _Float16);
// float -> _Float16 conversions called by the device overloads of
// __float2half_rd / _ru / _rz. The same three prototypes also appear inside
// the guarded section above; repeating a declaration is harmless.
__device__ __attribute__((const)) _Float16 __ocml_cvtrtn_f16_f32(float);
__device__ __attribute__((const)) _Float16 __ocml_cvtrtp_f16_f32(float);
__device__ __attribute__((const)) _Float16 __ocml_cvtrtz_f16_f32(float);
}