diff --git a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h index 52fab34a42..a731db3e6e 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h +++ b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h @@ -1748,7 +1748,9 @@ __BF16_DEVICE_STATIC__ __hip_bfloat16 hsin(const __hip_bfloat16 h) { * \brief Calculate sqrt of bfloat16 */ __BF16_DEVICE_STATIC__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) { - return __float2bfloat16(__ocml_sqrt_f32(__bfloat162float(h))); + // FIXME: Call __builtin_elementwise_sqrt directly on the bfloat16 value + // instead of promoting to float and converting the result back. + return __float2bfloat16(__builtin_elementwise_sqrt(__bfloat162float(h))); } /** diff --git a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp16.h b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp16.h index cdb75affba..8869a44723 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp16.h +++ b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp16.h @@ -980,7 +980,7 @@ inline __device__ __half hrsqrt(__half x) { return __half_raw{__ocml_rsqrt_f16(static_cast<__half_raw>(x).data)}; } inline __device__ __half hsqrt(__half x) { - return __half_raw{__ocml_sqrt_f16(static_cast<__half_raw>(x).data)}; + return __half_raw{__builtin_elementwise_sqrt(static_cast<__half_raw>(x).data)}; } inline __HOST_DEVICE__ bool __hisinf(__half x) { __half_raw hr = x; @@ -1015,7 +1015,9 @@ inline __device__ __half2 h2rcp(__half2 x) { return _Float16_2{_Float16_2{static_cast<_Float16>(1.0f), static_cast<_Float16>(1.0f)} / x.data}; } inline __device__ __half2 h2rsqrt(__half2 x) { return __ocml_rsqrt_2f16(x); } -inline __device__ __half2 h2sqrt(__half2 x) { return __ocml_sqrt_2f16(x); } +inline __device__ __half2 h2sqrt(__half2 x) { + return __half2{__builtin_elementwise_sqrt(static_cast<__half2_raw>(x).data)}; +} inline __device__ __half2 __hisinf2(__half2 x) { auto r = __ocml_isinf_2f16(x); return 
__half2{_Float16_2{static_cast<_Float16>(r.x), static_cast<_Float16>(r.y)}};