diff --git a/include/hip/hcc_detail/math_functions.h b/include/hip/hcc_detail/math_functions.h index 5482f34093..e717df07c1 100644 --- a/include/hip/hcc_detail/math_functions.h +++ b/include/hip/hcc_detail/math_functions.h @@ -57,6 +57,7 @@ __device__ float exp2f(float x); __device__ float expf(float x); __device__ float expm1f(float x); __device__ int abs(int x); +__device__ long long abs(long long x); __device__ float fabsf(float x); __device__ float fdimf(float x, float y); __device__ float fdividef(float x, float y); diff --git a/src/hip_fp16.cpp b/src/hip_fp16.cpp index e4c1b43786..3004b7805d 100644 --- a/src/hip_fp16.cpp +++ b/src/hip_fp16.cpp @@ -29,10 +29,6 @@ struct hipHalfHolder { }; }; -#define HINF 65504 - -__device__ static struct hipHalfHolder __hInfValue = {HINF}; - __device__ __half __hadd(__half a, __half b) { return a + b; } __device__ __half __hadd_sat(__half a, __half b) { return a + b; } @@ -63,9 +59,21 @@ __device__ bool __hge(__half a, __half b) { return a >= b ? true : false; } __device__ bool __hgt(__half a, __half b) { return a > b ? true : false; } -__device__ bool __hisinf(__half a) { return a == HINF ? true : false; } +__device__ bool __hisinf(__half a) { + hipHalfHolder hH; + hH.h = a; + // mask with 0x7fff to drop the sign bit + // 0x7c00 is bit pattern for inf (exp = 11111, significand = 0) + return ((hH.s & 0x7fff) == 0x7c00) ? true : false; +} -__device__ bool __hisnan(__half a) { return a > HINF ? true : false; } +__device__ bool __hisnan(__half a) { + hipHalfHolder hH; + hH.h = a; + // mask with 0x7fff to drop the sign bit + // values above 0x7c00 are NaN bit patterns (exp = 11111, significand != 0) + return ((hH.s & 0x7fff) > 0x7c00) ? true : false; +} __device__ bool __hle(__half a, __half b) { return a <= b ? true : false; } @@ -124,8 +132,8 @@ __device__ __half2 __hgt2(__half2 a, __half2 b) { __device__ __half2 __hisnan2(__half2 a) { __half2 c; - c.x = (a.x > HINF) ? (__half)1 : (__half)0; - c.y = (a.y > HINF) ? 
(__half)1 : (__half)0; + c.x = (__hisnan(a.x)) ? (__half)1 : (__half)0; + c.y = (__hisnan(a.y)) ? (__half)1 : (__half)0; return c; } diff --git a/src/math_functions.cpp b/src/math_functions.cpp index 3c0a7f6541..dedc40f2ae 100644 --- a/src/math_functions.cpp +++ b/src/math_functions.cpp @@ -56,6 +56,9 @@ __device__ float expm1f(float x) { return hc::precise_math::expm1f(x); } __device__ int abs(int x) { return x >= 0 ? x : -x; // TODO - optimize with OCML } +__device__ long long abs(long long x) { + return x >= 0 ? x : -x; +} __device__ float fabsf(float x) { return hc::precise_math::fabsf(x); } __device__ float fdimf(float x, float y) { return hc::precise_math::fdimf(x, y); } __device__ float fdividef(float x, float y) { return x / y; } @@ -220,14 +223,7 @@ __device__ double j0(double x) { return __hip_j0(x); } __device__ double j1(double x) { return __hip_j1(x); } __device__ double jn(int n, double x) { return __hip_jn(n, x); } __device__ double ldexp(double x, int exp) { return hc::precise_math::ldexp(x, exp); } -__device__ double lgamma(double x) { - double val = 0.0; - double y = x - 1; - while (y > 0) { - val += log(y--); - } - return val; -} +__device__ double lgamma(double x) { return hc::precise_math::lgamma(x); } __device__ long long int llrint(double x) { long long int y = hc::precise_math::round(x); return y;