diff --git a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_complex.h b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_complex.h index 7dc014b6fd..933fd4e165 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_complex.h +++ b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_complex.h @@ -20,8 +20,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -/* The header defines complex numbers and related functions*/ - #ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COMPLEX_H #define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COMPLEX_H @@ -55,6 +53,100 @@ THE SOFTWARE. #endif #endif // !defined(__HIPCC_RTC__) +#if __cplusplus +#define COMPLEX_NEG_OP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline type operator-(const type& op) { \ + type ret; \ + ret.x = -op.x; \ + ret.y = -op.y; \ + return ret; \ + } + +#define COMPLEX_EQ_OP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline bool operator==(const type& lhs, const type& rhs) { \ + return lhs.x == rhs.x && lhs.y == rhs.y; \ + } + +#define COMPLEX_NE_OP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline bool operator!=(const type& lhs, const type& rhs) { \ + return !(lhs == rhs); \ + } + +#define COMPLEX_ADD_OP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline type operator+(const type& lhs, const type& rhs) { \ + type ret; \ + ret.x = lhs.x + rhs.x; \ + ret.y = lhs.y + rhs.y; \ + return ret; \ + } + +#define COMPLEX_SUB_OP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline type operator-(const type& lhs, const type& rhs) { \ + type ret; \ + ret.x = lhs.x - rhs.x; \ + ret.y = lhs.y - rhs.y; \ + return ret; \ + } + +#define COMPLEX_MUL_OP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline type operator*(const type& lhs, const type& rhs) { \ + type ret; \ + ret.x = lhs.x * rhs.x - lhs.y * rhs.y; \ + ret.y = lhs.x * rhs.y + lhs.y * rhs.x; \ + return ret; \ + } + +#define COMPLEX_DIV_OP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline type operator/(const type& lhs, const type& rhs) { \ + type ret; \ + ret.x = (lhs.x * rhs.x + lhs.y * rhs.y); \ + ret.y = (rhs.x * lhs.y - lhs.x * rhs.y); \ + ret.x = ret.x / (rhs.x * rhs.x + rhs.y * rhs.y); \ + ret.y = ret.y / (rhs.x * rhs.x + rhs.y * rhs.y); \ + return ret; \ + } + +#define COMPLEX_ADD_PREOP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline type& operator+=(type& lhs, const type& rhs) { \ + lhs.x += rhs.x; \ + lhs.y += rhs.y; \ + return lhs; \ + } + +#define COMPLEX_SUB_PREOP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline type& operator-=(type& lhs, const type& rhs) { \ + lhs.x -= rhs.x; \ + lhs.y -= rhs.y; \ + return lhs; \ + } + +#define COMPLEX_MUL_PREOP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline type& operator*=(type& lhs, const type& rhs) { \ + type temp{lhs}; \ + lhs.x = rhs.x * temp.x - rhs.y * temp.y; \ + lhs.y = rhs.y * temp.x + rhs.x * temp.y; \ + return lhs; \ + } + +#define COMPLEX_DIV_PREOP_OVERLOAD(type) \ + __HOST_DEVICE__ static inline type& operator/=(type& lhs, const type& rhs) { \ + type temp; \ + temp.x = (lhs.x*rhs.x + lhs.y * rhs.y) / (rhs.x*rhs.x + rhs.y*rhs.y); \ + temp.y = (lhs.y * rhs.x - lhs.x * rhs.y) / (rhs.x*rhs.x + rhs.y*rhs.y); \ + lhs = temp; \ + return lhs; \ + } + +#define COMPLEX_SCALAR_PRODUCT(type, type1) \ + __HOST_DEVICE__ static inline type operator*(const type& lhs, type1 rhs) { \ + type ret; \ + ret.x = lhs.x * rhs; \ + ret.y = lhs.y * rhs; \ + return ret; \ + } + +#endif + typedef float2 hipFloatComplex; __HOST_DEVICE__ static inline float hipCrealf(hipFloatComplex z) { return z.x; } @@ -148,6 +240,56 @@ __HOST_DEVICE__ static inline hipDoubleComplex hipCdiv(hipDoubleComplex p, hipDo __HOST_DEVICE__ static inline double hipCabs(hipDoubleComplex z) { return sqrt(hipCsqabs(z)); } + +#if __cplusplus + +COMPLEX_NEG_OP_OVERLOAD(hipFloatComplex) +COMPLEX_EQ_OP_OVERLOAD(hipFloatComplex) +COMPLEX_NE_OP_OVERLOAD(hipFloatComplex) +COMPLEX_ADD_OP_OVERLOAD(hipFloatComplex) +COMPLEX_SUB_OP_OVERLOAD(hipFloatComplex) +COMPLEX_MUL_OP_OVERLOAD(hipFloatComplex) +COMPLEX_DIV_OP_OVERLOAD(hipFloatComplex) +COMPLEX_ADD_PREOP_OVERLOAD(hipFloatComplex) +COMPLEX_SUB_PREOP_OVERLOAD(hipFloatComplex) +COMPLEX_MUL_PREOP_OVERLOAD(hipFloatComplex) +COMPLEX_DIV_PREOP_OVERLOAD(hipFloatComplex) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, unsigned short) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, signed short) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, unsigned int) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, signed int) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, float) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, unsigned long) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, signed long) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, double) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, signed long long) +COMPLEX_SCALAR_PRODUCT(hipFloatComplex, unsigned long long) + +COMPLEX_NEG_OP_OVERLOAD(hipDoubleComplex) +COMPLEX_EQ_OP_OVERLOAD(hipDoubleComplex) +COMPLEX_NE_OP_OVERLOAD(hipDoubleComplex) +COMPLEX_ADD_OP_OVERLOAD(hipDoubleComplex) +COMPLEX_SUB_OP_OVERLOAD(hipDoubleComplex) +COMPLEX_MUL_OP_OVERLOAD(hipDoubleComplex) +COMPLEX_DIV_OP_OVERLOAD(hipDoubleComplex) +COMPLEX_ADD_PREOP_OVERLOAD(hipDoubleComplex) +COMPLEX_SUB_PREOP_OVERLOAD(hipDoubleComplex) +COMPLEX_MUL_PREOP_OVERLOAD(hipDoubleComplex) +COMPLEX_DIV_PREOP_OVERLOAD(hipDoubleComplex) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, unsigned short) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, signed short) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, unsigned int) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, signed int) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, float) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, unsigned long) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, signed long) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, double) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, signed long long) +COMPLEX_SCALAR_PRODUCT(hipDoubleComplex, unsigned long long) + +#endif + + typedef hipFloatComplex hipComplex; __HOST_DEVICE__ static inline hipComplex make_hipComplex(float x, float y) {