Add dot functions as amd_mixed_dot overloads
Introduce the dot functions that are available in the device library. Forward-declare their prototypes, and add a HIP API (amd_mixed_dot) that exposes them to device code.
This commit is contained in:
@@ -23,6 +23,7 @@ THE SOFTWARE.
|
||||
#pragma once
|
||||
|
||||
#include "hip_fp16_math_fwd.h"
|
||||
#include "hip_vector_types.h"
|
||||
#include "math_fwd.h"
|
||||
|
||||
#include <hip/hcc_detail/host_defines.h>
|
||||
@@ -119,6 +120,43 @@ uint64_t __make_mantissa(const char* tagp)
|
||||
return __make_mantissa_base10(tagp);
|
||||
}
|
||||
|
||||
// DOT FUNCTIONS
|
||||
// Mixed-precision dot overload for two-element _Float16 vectors with a float
// accumulator. Hands a, b, c and saturate, unchanged and in that order, to the
// device-library routine __ockl_fdot2 and returns whatever it produces.
// NOTE(review): presumably evaluates dot(a, b) + c with `saturate` selecting
// clamping of the result -- confirm against the device-library documentation.
__DEVICE__ inline float amd_mixed_dot(__2f16 a, __2f16 b, float c, bool saturate)
{
    const float result = __ockl_fdot2(a, b, c, saturate);
    return result;
}
|
||||
// Mixed-precision dot overload for short2 operands with an int accumulator.
// Hands a, b, c and saturate, unchanged and in that order, to the
// device-library routine __ockl_sdot2 and returns whatever it produces.
// NOTE(review): presumably evaluates dot(a, b) + c with `saturate` selecting
// clamping of the result -- confirm against the device-library documentation.
__DEVICE__ inline int amd_mixed_dot(short2 a, short2 b, int c, bool saturate)
{
    const int result = __ockl_sdot2(a, b, c, saturate);
    return result;
}
|
||||
// Mixed-precision dot overload for ushort2 operands with a uint accumulator.
// Hands a, b, c and saturate, unchanged and in that order, to the
// device-library routine __ockl_udot2 and returns whatever it produces.
// NOTE(review): presumably evaluates dot(a, b) + c with `saturate` selecting
// clamping of the result -- confirm against the device-library documentation.
__DEVICE__ inline uint amd_mixed_dot(ushort2 a, ushort2 b, uint c, bool saturate)
{
    const uint result = __ockl_udot2(a, b, c, saturate);
    return result;
}
|
||||
// Mixed-precision dot overload for char4 operands with an int accumulator.
// Hands a, b, c and saturate, unchanged and in that order, to the
// device-library routine __ockl_sdot4 and returns whatever it produces.
// NOTE(review): presumably evaluates dot(a, b) + c with `saturate` selecting
// clamping of the result -- confirm against the device-library documentation.
__DEVICE__ inline int amd_mixed_dot(char4 a, char4 b, int c, bool saturate)
{
    const int result = __ockl_sdot4(a, b, c, saturate);
    return result;
}
|
||||
// Mixed-precision dot overload for uchar4 operands with a uint accumulator.
// Hands a, b, c and saturate, unchanged and in that order, to the
// device-library routine __ockl_udot4 and returns whatever it produces.
// NOTE(review): presumably evaluates dot(a, b) + c with `saturate` selecting
// clamping of the result -- confirm against the device-library documentation.
__DEVICE__ inline uint amd_mixed_dot(uchar4 a, uchar4 b, uint c, bool saturate)
{
    const uint result = __ockl_udot4(a, b, c, saturate);
    return result;
}
|
||||
// Mixed-precision dot overload taking plain int operands and an int
// accumulator. Hands a, b, c and saturate, unchanged and in that order, to the
// device-library routine __ockl_sdot8 and returns whatever it produces.
// NOTE(review): the "8" in the callee name suggests a and b are each treated
// as eight packed signed lanes rather than as scalar integers -- confirm
// against the device-library documentation. Callers passing ordinary int
// values will bind to this overload.
__DEVICE__ inline int amd_mixed_dot(int a, int b, int c, bool saturate)
{
    const int result = __ockl_sdot8(a, b, c, saturate);
    return result;
}
|
||||
// Mixed-precision dot overload taking plain uint operands and a uint
// accumulator. Hands a, b, c and saturate, unchanged and in that order, to the
// device-library routine __ockl_udot8 and returns whatever it produces.
// NOTE(review): the "8" in the callee name suggests a and b are each treated
// as eight packed unsigned lanes rather than as scalar integers -- confirm
// against the device-library documentation.
__DEVICE__ inline uint amd_mixed_dot(uint a, uint b, uint c, bool saturate)
{
    const uint result = __ockl_udot8(a, b, c, saturate);
    return result;
}
|
||||
|
||||
// BEGIN FLOAT
|
||||
__DEVICE__
|
||||
inline
|
||||
|
||||
@@ -28,6 +28,30 @@ THE SOFTWARE.
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// DOT FUNCTIONS
|
||||
// Two-element extended vector of _Float16; used as the operand type of the
// __ockl_fdot2 prototype declared below.
typedef _Float16 __2f16 __attribute__((ext_vector_type(2)));
|
||||
// Forward declaration of the device-library routine __ockl_fdot2.
// Takes two __2f16 operands, a float `c` and a bool `s`; returns float.
// Marked __attribute__((const)): the result depends only on the argument
// values, so the compiler may fold repeated identical calls.
// NOTE(review): `s` is presumably the saturate/clamp flag -- the
// amd_mixed_dot wrapper passes its `saturate` argument here; confirm.
__device__ __attribute__((const)) float __ockl_fdot2(__2f16 a, __2f16 b, float c, bool s);
|
||||
// Forward declaration of the device-library routine __ockl_sdot2.
// Takes two short2 operands, an int `c` and a bool `s`; returns int.
// Marked __attribute__((const)): the result depends only on the argument
// values, so the compiler may fold repeated identical calls.
// NOTE(review): `s` is presumably the saturate/clamp flag -- the
// amd_mixed_dot wrapper passes its `saturate` argument here; confirm.
__device__ __attribute__((const)) int __ockl_sdot2(short2 a, short2 b, int c, bool s);
|
||||
// Forward declaration of the device-library routine __ockl_udot2.
// Takes two ushort2 operands, an unsigned int `c` and a bool `s`; returns
// unsigned int.
// Marked __attribute__((const)): the result depends only on the argument
// values, so the compiler may fold repeated identical calls.
// NOTE(review): `s` is presumably the saturate/clamp flag -- the
// amd_mixed_dot wrapper passes its `saturate` argument here; confirm.
__device__ __attribute__((const)) unsigned int __ockl_udot2(ushort2 a, ushort2 b, unsigned int c, bool s);
|
||||
// Forward declaration of the device-library routine __ockl_sdot4.
// Takes two char4 operands, an int `c` and a bool `s`; returns int.
// Marked __attribute__((const)): the result depends only on the argument
// values, so the compiler may fold repeated identical calls.
// NOTE(review): `s` is presumably the saturate/clamp flag -- the
// amd_mixed_dot wrapper passes its `saturate` argument here; confirm.
__device__ __attribute__((const)) int __ockl_sdot4(char4 a, char4 b, int c, bool s);
|
||||
// Forward declaration of the device-library routine __ockl_udot4.
// Takes two uchar4 operands, an unsigned int `c` and a bool `s`; returns
// unsigned int.
// Marked __attribute__((const)): the result depends only on the argument
// values, so the compiler may fold repeated identical calls.
// NOTE(review): `s` is presumably the saturate/clamp flag -- the
// amd_mixed_dot wrapper passes its `saturate` argument here; confirm.
__device__ __attribute__((const)) unsigned int __ockl_udot4(uchar4 a, uchar4 b, unsigned int c, bool s);
|
||||
// Forward declaration of the device-library routine __ockl_sdot8.
// Takes two int operands, an int `c` and a bool `s`; returns int.
// Marked __attribute__((const)): the result depends only on the argument
// values, so the compiler may fold repeated identical calls.
// NOTE(review): the name suggests a and b each carry eight packed signed
// lanes, and `s` is presumably the saturate/clamp flag -- confirm against
// the device-library documentation.
__device__ __attribute__((const)) int __ockl_sdot8(int a, int b, int c, bool s);
|
||||
// Forward declaration of the device-library routine __ockl_udot8.
// Takes two unsigned int operands, an unsigned int `c` and a bool `s`;
// returns unsigned int.
// Marked __attribute__((const)): the result depends only on the argument
// values, so the compiler may fold repeated identical calls.
// NOTE(review): the name suggests a and b each carry eight packed unsigned
// lanes, and `s` is presumably the saturate/clamp flag -- confirm against
// the device-library documentation.
__device__ __attribute__((const)) unsigned int __ockl_udot8(unsigned int a, unsigned int b, unsigned int c, bool s);
|
||||
|
||||
// BEGIN FLOAT
|
||||
__device__
|
||||
__attribute__((const))
|
||||
|
||||
Verwijs in nieuw issue
Block a user