Merge "added erfinv software implementation" into amd-master
[ROCm/clr commit: 1dbb2f5205]
This commit is contained in:
@@ -220,6 +220,7 @@ __device__ double erf(double x);
|
||||
__device__ double erfc(double x);
|
||||
__device__ double erfcinv(double y);
|
||||
__device__ double erfcx(double x);
|
||||
__device__ double erfinv(double x);
|
||||
__device__ double exp(double x);
|
||||
__device__ double exp10(double x);
|
||||
__device__ double exp2(double x);
|
||||
|
||||
@@ -27,6 +27,117 @@ THE SOFTWARE.
|
||||
using namespace hc::precise_math;
|
||||
#endif
|
||||
|
||||
#define __hip_erfinva3 -0.140543331
|
||||
#define __hip_erfinva2 0.914624893
|
||||
#define __hip_erfinva1 -1.645349621
|
||||
#define __hip_erfinva0 0.886226899
|
||||
|
||||
#define __hip_erfinvb4 0.012229801
|
||||
#define __hip_erfinvb3 -0.329097515
|
||||
#define __hip_erfinvb2 1.442710462
|
||||
#define __hip_erfinvb1 -2.118377725
|
||||
#define __hip_erfinvb0 1
|
||||
|
||||
#define __hip_erfinvc3 1.641345311
|
||||
#define __hip_erfinvc2 3.429567803
|
||||
#define __hip_erfinvc1 -1.62490649
|
||||
#define __hip_erfinvc0 -1.970840454
|
||||
|
||||
#define __hip_erfinvd2 1.637067800
|
||||
#define __hip_erfinvd1 3.543889200
|
||||
#define __hip_erfinvd0 1
|
||||
|
||||
#define HIP_PI 3.14159265358979323846
|
||||
|
||||
__device__ float __hip_erfinvf(float x){
|
||||
float ret;
|
||||
int sign;
|
||||
if (x < -1 || x > 1){
|
||||
return NAN;
|
||||
}
|
||||
if (x == 0){
|
||||
return 0;
|
||||
}
|
||||
if (x > 0){
|
||||
sign = 1;
|
||||
} else {
|
||||
sign = -1;
|
||||
x = -x;
|
||||
}
|
||||
if (x <= 0.7) {
|
||||
float x1 = x * x;
|
||||
float x2 = hc::precise_math::fmaf(__hip_erfinva3, x1, __hip_erfinva2);
|
||||
float x3 = hc::precise_math::fmaf(x2, x1, __hip_erfinva1);
|
||||
float x4 = x * hc::precise_math::fmaf(x3, x1, __hip_erfinva0);
|
||||
|
||||
float r1 = hc::precise_math::fmaf(__hip_erfinvb4, x1, __hip_erfinvb3);
|
||||
float r2 = hc::precise_math::fmaf(r1, x1, __hip_erfinvb2);
|
||||
float r3 = hc::precise_math::fmaf(r2, x1, __hip_erfinvb1);
|
||||
ret = x4 / hc::precise_math::fmaf(r3, x1, __hip_erfinvb0);
|
||||
} else {
|
||||
float x1 = hc::precise_math::sqrtf(-hc::precise_math::logf((1 - x) / 2));
|
||||
float x2 = hc::precise_math::fmaf(__hip_erfinvc3, x1, __hip_erfinvc2);
|
||||
float x3 = hc::precise_math::fmaf(x2, x1, __hip_erfinvc1);
|
||||
float x4 = hc::precise_math::fmaf(x3, x1, __hip_erfinvc0);
|
||||
|
||||
float r1 = hc::precise_math::fmaf(__hip_erfinvd2, x1, __hip_erfinvd1);
|
||||
ret = x4 / hc::precise_math::fmaf(r1, x1, __hip_erfinvd0);
|
||||
}
|
||||
|
||||
ret = ret * sign;
|
||||
x = x * sign;
|
||||
|
||||
ret -= (hc::precise_math::erff(ret) - x) / (2 / hc::precise_math::sqrtf(HIP_PI) * hc::precise_math::expf(-ret * ret));
|
||||
ret -= (hc::precise_math::erff(ret) - x) / (2 / hc::precise_math::sqrtf(HIP_PI) * hc::precise_math::expf(-ret * ret));
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ double __hip_erfinv(double x){
|
||||
double ret;
|
||||
int sign;
|
||||
if (x < -1 || x > 1){
|
||||
return NAN;
|
||||
}
|
||||
if (x == 0){
|
||||
return 0;
|
||||
}
|
||||
if (x > 0){
|
||||
sign = 1;
|
||||
} else {
|
||||
sign = -1;
|
||||
x = -x;
|
||||
}
|
||||
if (x <= 0.7) {
|
||||
double x1 = x * x;
|
||||
double x2 = hc::precise_math::fma(__hip_erfinva3, x1, __hip_erfinva2);
|
||||
double x3 = hc::precise_math::fma(x2, x1, __hip_erfinva1);
|
||||
double x4 = x * hc::precise_math::fma(x3, x1, __hip_erfinva0);
|
||||
|
||||
double r1 = hc::precise_math::fma(__hip_erfinvb4, x1, __hip_erfinvb3);
|
||||
double r2 = hc::precise_math::fma(r1, x1, __hip_erfinvb2);
|
||||
double r3 = hc::precise_math::fma(r2, x1, __hip_erfinvb1);
|
||||
ret = x4 / hc::precise_math::fma(r3, x1, __hip_erfinvb0);
|
||||
} else {
|
||||
double x1 = hc::precise_math::sqrt(-hc::precise_math::log((1 - x) / 2));
|
||||
double x2 = hc::precise_math::fma(__hip_erfinvc3, x1, __hip_erfinvc2);
|
||||
double x3 = hc::precise_math::fma(x2, x1, __hip_erfinvc1);
|
||||
double x4 = hc::precise_math::fma(x3, x1, __hip_erfinvc0);
|
||||
|
||||
double r1 = hc::precise_math::fma(__hip_erfinvd2, x1, __hip_erfinvd1);
|
||||
ret = x4 / hc::precise_math::fma(r1, x1, __hip_erfinvd0);
|
||||
}
|
||||
|
||||
ret = ret * sign;
|
||||
x = x * sign;
|
||||
|
||||
ret -= (hc::precise_math::erf(ret) - x) / (2 / hc::precise_math::sqrt(HIP_PI) * hc::precise_math::exp(-ret * ret));
|
||||
ret -= (hc::precise_math::erf(ret) - x) / (2 / hc::precise_math::sqrt(HIP_PI) * hc::precise_math::exp(-ret * ret));
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
__device__ float acosf(float x)
|
||||
{
|
||||
return hc::precise_math::acosf(x);
|
||||
@@ -87,7 +198,10 @@ __device__ float erff(float x)
|
||||
{
|
||||
return hc::precise_math::erff(x);
|
||||
}
|
||||
__device__ float erfinvf(float y);
|
||||
__device__ float erfinvf(float y)
|
||||
{
|
||||
return __hip_erfinvf(y);
|
||||
}
|
||||
__device__ float exp10f(float x)
|
||||
{
|
||||
return hc::precise_math::exp10f(x);
|
||||
@@ -442,6 +556,10 @@ __device__ double erfc(double x)
|
||||
{
|
||||
return hc::precise_math::erfc(x);
|
||||
}
|
||||
__device__ double erfinv(double x)
|
||||
{
|
||||
return __hip_erfinv(x);
|
||||
}
|
||||
__device__ double exp(double x)
|
||||
{
|
||||
return hc::precise_math::exp(x);
|
||||
|
||||
@@ -98,6 +98,10 @@ __global__ void test_rnormf(hipLaunchParm lp, float *a, float *b){
|
||||
b[tid] = rnormf(N, a);
|
||||
}
|
||||
|
||||
__global__ void test_erfinvf(hipLaunchParm lp, float *a, float *b){
|
||||
int tid = hipThreadIdx_x;
|
||||
b[tid] = erff(erfinvf(a[tid]));
|
||||
}
|
||||
|
||||
|
||||
bool run_sincosf(){
|
||||
@@ -591,13 +595,39 @@ assert(passed == 1);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool run_erfinvf(){
|
||||
float *A, *Ad, *B, *Bd;
|
||||
A = new float[N];
|
||||
B = new float[N];
|
||||
for(int i=0;i<N;i++){
|
||||
A[i] = -0.6f;
|
||||
B[i] = 0.0f;
|
||||
}
|
||||
hipMalloc((void**)&Ad, SIZE);
|
||||
hipMalloc((void**)&Bd, SIZE);
|
||||
hipMemcpy(Ad, A, SIZE, hipMemcpyHostToDevice);
|
||||
hipLaunchKernel(test_erfinvf, dim3(1), dim3(N), 0, 0, Ad, Bd);
|
||||
hipMemcpy(B, Bd, SIZE, hipMemcpyDeviceToHost);
|
||||
int passed = 0;
|
||||
for(int i=0;i<512;i++){
|
||||
if(B[i] - A[i] < 0.000001){
|
||||
passed = 1;
|
||||
}
|
||||
}
|
||||
free(A);
|
||||
if(passed == 1){
|
||||
return true;
|
||||
}
|
||||
assert(passed == 1);
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(){
|
||||
if(run_sincosf() && run_sincospif() && run_fdividef() &&
|
||||
run_llrintf() && run_norm3df() && run_norm4df() &&
|
||||
run_normf() && run_rnorm3df() && run_rnorm4df() &&
|
||||
run_rnormf() && run_lroundf() && run_llroundf() &&
|
||||
run_rintf() && run_rhypotf()
|
||||
run_rintf() && run_rhypotf() && run_erfinvf()
|
||||
){
|
||||
passed();
|
||||
}
|
||||
|
||||
@@ -88,6 +88,11 @@ __global__ void test_rnorm(hipLaunchParm lp, double *a, double *b){
|
||||
b[tid] = rnorm(N, a);
|
||||
}
|
||||
|
||||
__global__ void test_erfinv(hipLaunchParm lp, double *a, double *b){
|
||||
int tid = hipThreadIdx_x;
|
||||
b[tid] = erf(erfinv(a[tid]));
|
||||
}
|
||||
|
||||
bool run_sincos(){
|
||||
double *A, *Ad, *B, *C, *Bd, *Cd;
|
||||
A = new double[N];
|
||||
@@ -517,12 +522,38 @@ assert(passed == 1);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool run_erfinv(){
|
||||
double *A, *Ad, *B, *Bd;
|
||||
A = new double[N];
|
||||
B = new double[N];
|
||||
for(int i=0;i<N;i++){
|
||||
A[i] = -0.6;
|
||||
B[i] = 0.0;
|
||||
}
|
||||
hipMalloc((void**)&Ad, SIZE);
|
||||
hipMalloc((void**)&Bd, SIZE);
|
||||
hipMemcpy(Ad, A, SIZE, hipMemcpyHostToDevice);
|
||||
hipLaunchKernel(test_erfinv, dim3(1), dim3(N), 0, 0, Ad, Bd);
|
||||
hipMemcpy(B, Bd, SIZE, hipMemcpyDeviceToHost);
|
||||
int passed = 0;
|
||||
for(int i=0;i<512;i++){
|
||||
if(B[i] - A[i] < 0.000001){
|
||||
passed = 1;
|
||||
}
|
||||
}
|
||||
free(A);
|
||||
if(passed == 1){
|
||||
return true;
|
||||
}
|
||||
assert(passed == 1);
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(){
|
||||
if(run_sincos() && run_sincospi() && run_llrint() && run_norm3d() && run_norm4d() &&
|
||||
run_rnorm3d() && run_rnorm4d() &&
|
||||
run_rnorm() && run_lround() && run_llround() &&
|
||||
run_rint() && run_rhypot()
|
||||
run_rint() && run_rhypot() && run_erfinv()
|
||||
){
|
||||
passed();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user