diff --git a/projects/clr/hipamd/include/hip/hcc_detail/math_functions.h b/projects/clr/hipamd/include/hip/hcc_detail/math_functions.h index 63e48fab29..69c9f358c0 100644 --- a/projects/clr/hipamd/include/hip/hcc_detail/math_functions.h +++ b/projects/clr/hipamd/include/hip/hcc_detail/math_functions.h @@ -23,6 +23,7 @@ THE SOFTWARE. #pragma once #include "hip_fp16_math_fwd.h" +#include "hip_vector_types.h" #include "math_fwd.h" #include @@ -119,6 +120,43 @@ uint64_t __make_mantissa(const char* tagp) return __make_mantissa_base10(tagp); } +// DOT FUNCTIONS +__DEVICE__ +inline +float amd_mixed_dot(__2f16 a, __2f16 b, float c, bool saturate) { + return __ockl_fdot2(a, b, c, saturate); +} +__DEVICE__ +inline +int amd_mixed_dot(short2 a, short2 b, int c, bool saturate) { + return __ockl_sdot2(a, b, c, saturate); +} +__DEVICE__ +inline +uint amd_mixed_dot(ushort2 a, ushort2 b, uint c, bool saturate) { + return __ockl_udot2(a, b, c, saturate); +} +__DEVICE__ +inline +int amd_mixed_dot(char4 a, char4 b, int c, bool saturate) { + return __ockl_sdot4(a, b, c, saturate); +} +__DEVICE__ +inline +uint amd_mixed_dot(uchar4 a, uchar4 b, uint c, bool saturate) { + return __ockl_udot4(a, b, c, saturate); +} +__DEVICE__ +inline +int amd_mixed_dot(int a, int b, int c, bool saturate) { + return __ockl_sdot8(a, b, c, saturate); +} +__DEVICE__ +inline +uint amd_mixed_dot(uint a, uint b, uint c, bool saturate) { + return __ockl_udot8(a, b, c, saturate); +} + // BEGIN FLOAT __DEVICE__ inline diff --git a/projects/clr/hipamd/include/hip/hcc_detail/math_fwd.h b/projects/clr/hipamd/include/hip/hcc_detail/math_fwd.h index e5594924ba..df611dfe80 100644 --- a/projects/clr/hipamd/include/hip/hcc_detail/math_fwd.h +++ b/projects/clr/hipamd/include/hip/hcc_detail/math_fwd.h @@ -28,6 +28,30 @@ THE SOFTWARE. extern "C" { #endif +// DOT FUNCTIONS +typedef _Float16 __2f16 __attribute__((ext_vector_type(2))); +__device__ +__attribute__((const)) +float __ockl_fdot2(__2f16 a, __2f16 b, float c, bool s); +__device__ +__attribute__((const)) +int __ockl_sdot2(short2 a, short2 b, int c, bool s); +__device__ +__attribute__((const)) +unsigned int __ockl_udot2(ushort2 a, ushort2 b, unsigned int c, bool s); +__device__ +__attribute__((const)) +int __ockl_sdot4(char4 a, char4 b, int c, bool s); +__device__ +__attribute__((const)) +unsigned int __ockl_udot4(uchar4 a, uchar4 b, unsigned int c, bool s); +__device__ +__attribute__((const)) +int __ockl_sdot8(int a, int b, int c, bool s); +__device__ +__attribute__((const)) +unsigned int __ockl_udot8(unsigned int a, unsigned int b, unsigned int c, bool s); + // BEGIN FLOAT __device__ __attribute__((const))