SWDEV-548892 - Stop using ocml sqrt wrappers (#1716)
This commit is contained in:
@@ -1748,7 +1748,9 @@ __BF16_DEVICE_STATIC__ __hip_bfloat16 hsin(const __hip_bfloat16 h) {
|
||||
* \brief Calculate sqrt of bfloat16
|
||||
*/
|
||||
__BF16_DEVICE_STATIC__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) {
|
||||
return __float2bfloat16(__ocml_sqrt_f32(__bfloat162float(h)));
|
||||
// FIXME: Just directly use elementwise sqrt on the bfloat value
|
||||
// and don't promote
|
||||
return __float2bfloat16(__builtin_elementwise_sqrt(__bfloat162float(h)));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -980,7 +980,7 @@ inline __device__ __half hrsqrt(__half x) {
|
||||
return __half_raw{__ocml_rsqrt_f16(static_cast<__half_raw>(x).data)};
|
||||
}
|
||||
inline __device__ __half hsqrt(__half x) {
|
||||
return __half_raw{__ocml_sqrt_f16(static_cast<__half_raw>(x).data)};
|
||||
return __half_raw{__builtin_elementwise_sqrt(static_cast<__half_raw>(x).data)};
|
||||
}
|
||||
inline __HOST_DEVICE__ bool __hisinf(__half x) {
|
||||
__half_raw hr = x;
|
||||
@@ -1015,7 +1015,9 @@ inline __device__ __half2 h2rcp(__half2 x) {
|
||||
return _Float16_2{_Float16_2{static_cast<_Float16>(1.0f), static_cast<_Float16>(1.0f)} / x.data};
|
||||
}
|
||||
inline __device__ __half2 h2rsqrt(__half2 x) { return __ocml_rsqrt_2f16(x); }
|
||||
inline __device__ __half2 h2sqrt(__half2 x) { return __ocml_sqrt_2f16(x); }
|
||||
inline __device__ __half2 h2sqrt(__half2 x) {
|
||||
return __half2{__builtin_elementwise_sqrt(static_cast<__half2_raw>(x).data)};
|
||||
}
|
||||
inline __device__ __half2 __hisinf2(__half2 x) {
|
||||
auto r = __ocml_isinf_2f16(x);
|
||||
return __half2{_Float16_2{static_cast<_Float16>(r.x), static_cast<_Float16>(r.y)}};
|
||||
|
||||
Reference in New Issue
Block a user