SWDEV-548892 - Stop using ocml sqrt wrappers (#1716)

This commit is contained in:
Matt Arsenault
2025-11-13 16:19:44 -08:00
committed by GitHub
parent 65b607b0bd
commit 42e91b8934
2 changed files with 7 additions and 3 deletions
@@ -1748,7 +1748,9 @@ __BF16_DEVICE_STATIC__ __hip_bfloat16 hsin(const __hip_bfloat16 h) {
* \brief Calculate sqrt of bfloat16
*/
__BF16_DEVICE_STATIC__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) {
return __float2bfloat16(__ocml_sqrt_f32(__bfloat162float(h)));
// FIXME: Just directly use elementwise sqrt on the bfloat value
// and don't promote
return __float2bfloat16(__builtin_elementwise_sqrt(__bfloat162float(h)));
}
/**
@@ -980,7 +980,7 @@ inline __device__ __half hrsqrt(__half x) {
return __half_raw{__ocml_rsqrt_f16(static_cast<__half_raw>(x).data)};
}
inline __device__ __half hsqrt(__half x) {
return __half_raw{__ocml_sqrt_f16(static_cast<__half_raw>(x).data)};
return __half_raw{__builtin_elementwise_sqrt(static_cast<__half_raw>(x).data)};
}
inline __HOST_DEVICE__ bool __hisinf(__half x) {
__half_raw hr = x;
@@ -1015,7 +1015,9 @@ inline __device__ __half2 h2rcp(__half2 x) {
return _Float16_2{_Float16_2{static_cast<_Float16>(1.0f), static_cast<_Float16>(1.0f)} / x.data};
}
inline __device__ __half2 h2rsqrt(__half2 x) { return __ocml_rsqrt_2f16(x); }
inline __device__ __half2 h2sqrt(__half2 x) { return __ocml_sqrt_2f16(x); }
inline __device__ __half2 h2sqrt(__half2 x) {
return __half2{__builtin_elementwise_sqrt(static_cast<__half2_raw>(x).data)};
}
inline __device__ __half2 __hisinf2(__half2 x) {
auto r = __ocml_isinf_2f16(x);
return __half2{_Float16_2{static_cast<_Float16>(r.x), static_cast<_Float16>(r.y)}};