SWDEV-511239 - make fp8 standalone host compileable
- Use correct header in device_library_decl - use std:: instead of __hip_internal:: for host compilation - hide device specific stuff behind __clang__ and __HIP__ check Change-Id: I2f3647e00555ed0e79f9954a459c41394c3cd49b
Этот коммит содержится в:
коммит произвёл
Jatin Jaikishan Chaudhary
родитель
0391aec14a
Коммит
c3f49c8788
@@ -110,10 +110,12 @@
|
||||
|
||||
#if !defined(__HIPCC_RTC__)
|
||||
#include <hip/amd_detail/amd_hip_common.h>
|
||||
#include <hip/amd_detail/amd_warp_functions.h> // Sync functions
|
||||
#include "amd_hip_vector_types.h" // float2 etc
|
||||
#include "device_library_decls.h" // ocml conversion functions
|
||||
#include "math_fwd.h" // ocml device functions
|
||||
#if defined(__clang__) and defined(__HIP__)
|
||||
#include <hip/amd_detail/amd_warp_functions.h> // Sync functions
|
||||
#endif
|
||||
#endif // !defined(__HIPCC_RTC__)
|
||||
|
||||
#define __BF16_DEVICE__ __device__
|
||||
|
||||
@@ -28,8 +28,9 @@ THE SOFTWARE.
|
||||
#define __HOST_DEVICE__ __device__
|
||||
#else
|
||||
#define __HOST_DEVICE__ __host__ __device__
|
||||
#include <hip/amd_detail/amd_hip_common.h>
|
||||
#include "hip/amd_detail/host_defines.h"
|
||||
#include "amd_hip_common.h"
|
||||
#include "host_defines.h"
|
||||
#include "amd_hip_vector_types.h"
|
||||
#include <assert.h>
|
||||
#if defined(__cplusplus)
|
||||
#include <algorithm>
|
||||
@@ -73,7 +74,6 @@ THE SOFTWARE.
|
||||
#if defined(__cplusplus)
|
||||
#if !defined(__HIPCC_RTC__)
|
||||
#include "hip_fp16_math_fwd.h"
|
||||
#include "amd_hip_vector_types.h"
|
||||
#include "host_defines.h"
|
||||
#include "amd_device_functions.h"
|
||||
#include "amd_warp_functions.h"
|
||||
|
||||
@@ -188,9 +188,15 @@ template <typename T, bool is_fnuz>
|
||||
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false,
|
||||
bool stoch = false,
|
||||
unsigned int rng = 0) {
|
||||
#if defined(__clang__) and defined(__HIP__)
|
||||
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
|
||||
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
|
||||
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
|
||||
#else // compiling for host
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
constexpr bool is_double = std::is_same<T, double>::value;
|
||||
#endif // defined(__clang__) and defined(__HIP__)
|
||||
static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8");
|
||||
|
||||
const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
|
||||
@@ -395,9 +401,15 @@ after shift right by 4 bits, it would look like midpoint.
|
||||
template <typename T, bool is_fnuz>
|
||||
__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we,
|
||||
bool clip = false) {
|
||||
#if defined(__clang__) and defined(__HIP__)
|
||||
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
|
||||
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
|
||||
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
|
||||
#else
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
constexpr bool is_double = std::is_same<T, double>::value;
|
||||
#endif // defined(__clang__) and defined(__HIP__)
|
||||
static_assert(is_half || is_float || is_double, "only half, float and double are supported");
|
||||
|
||||
constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
|
||||
@@ -478,10 +490,16 @@ __FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we,
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__clang__) and defined(__HIP__)
|
||||
typename __hip_internal::conditional<
|
||||
sizeof(T) == 2, unsigned short int,
|
||||
typename __hip_internal::conditional<sizeof(T) == 4, unsigned int,
|
||||
unsigned long long>::type>::type retval;
|
||||
#else
|
||||
typename std::conditional<sizeof(T) == 2, unsigned short int,
|
||||
typename std::conditional<sizeof(T) == 4, unsigned int,
|
||||
unsigned long long>::type>::type retval;
|
||||
#endif
|
||||
|
||||
if (we == 5 && is_half && !is_fnuz) {
|
||||
retval = x << 8;
|
||||
|
||||
@@ -33,6 +33,11 @@ THE SOFTWARE.
|
||||
|
||||
#if !defined(__HIPCC_RTC__)
|
||||
#include "hip/amd_detail/host_defines.h"
|
||||
#if __cplusplus
|
||||
#include <cstdint>
|
||||
#else
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
typedef unsigned char uchar;
|
||||
|
||||
@@ -22,6 +22,8 @@ THE SOFTWARE.
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__clang__) and defined(__HIP__)
|
||||
|
||||
// abort
|
||||
extern "C" __device__ inline __attribute__((weak))
|
||||
void abort() {
|
||||
@@ -99,3 +101,5 @@ void __assertfail()
|
||||
__builtin_trap(); \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#endif // defined(__clang__) and defined(__HIP__)
|
||||
|
||||
@@ -234,6 +234,12 @@ struct __half2_raw {
|
||||
{
|
||||
return __internal_half2float(static_cast<__half_raw>(x).x);
|
||||
}
|
||||
inline
|
||||
float2 __half22float2(__half2 x)
|
||||
{
|
||||
return float2{__internal_half2float(static_cast<__half2_raw>(x).x),
|
||||
__internal_half2float(static_cast<__half2_raw>(x).x)};
|
||||
}
|
||||
|
||||
inline
|
||||
float __low2float(__half2 x)
|
||||
|
||||
Ссылка в новой задаче
Block a user