SWDEV-511239 - make fp8 standalone host-compilable

- Use the correct header in device_library_decls.h
- Use std:: instead of __hip_internal:: for host compilation
- Hide device-specific code behind the __clang__ and __HIP__ check

Change-Id: I2f3647e00555ed0e79f9954a459c41394c3cd49b


[ROCm/clr commit: c3f49c8788]
Этот коммит содержится в:
Jatin Chaudhary
2025-02-04 00:20:33 +00:00
коммит произвёл Jatin Jaikishan Chaudhary
родитель 508d043176
Коммит 16f9dbff6c
6 изменённых файлов: 39 добавлений и 4 удалений
+3 -1
Просмотреть файл
@@ -110,10 +110,12 @@
#if !defined(__HIPCC_RTC__)
#include <hip/amd_detail/amd_hip_common.h>
#include <hip/amd_detail/amd_warp_functions.h> // Sync functions
#include "amd_hip_vector_types.h" // float2 etc
#include "device_library_decls.h" // ocml conversion functions
#include "math_fwd.h" // ocml device functions
#if defined(__clang__) and defined(__HIP__)
#include <hip/amd_detail/amd_warp_functions.h> // Sync functions
#endif
#endif // !defined(__HIPCC_RTC__)
#define __BF16_DEVICE__ __device__
+3 -3
Просмотреть файл
@@ -28,8 +28,9 @@ THE SOFTWARE.
#define __HOST_DEVICE__ __device__
#else
#define __HOST_DEVICE__ __host__ __device__
#include <hip/amd_detail/amd_hip_common.h>
#include "hip/amd_detail/host_defines.h"
#include "amd_hip_common.h"
#include "host_defines.h"
#include "amd_hip_vector_types.h"
#include <assert.h>
#if defined(__cplusplus)
#include <algorithm>
@@ -73,7 +74,6 @@ THE SOFTWARE.
#if defined(__cplusplus)
#if !defined(__HIPCC_RTC__)
#include "hip_fp16_math_fwd.h"
#include "amd_hip_vector_types.h"
#include "host_defines.h"
#include "amd_device_functions.h"
#include "amd_warp_functions.h"
+18
Просмотреть файл
@@ -188,9 +188,15 @@ template <typename T, bool is_fnuz>
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false,
bool stoch = false,
unsigned int rng = 0) {
#if defined(__clang__) and defined(__HIP__)
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
#else // compiling for host
constexpr bool is_half = std::is_same<T, _Float16>::value;
constexpr bool is_float = std::is_same<T, float>::value;
constexpr bool is_double = std::is_same<T, double>::value;
#endif // defined(__clang__) and defined(__HIP__)
static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8");
const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
@@ -395,9 +401,15 @@ after shift right by 4 bits, it would look like midpoint.
template <typename T, bool is_fnuz>
__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we,
bool clip = false) {
#if defined(__clang__) and defined(__HIP__)
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
#else
constexpr bool is_half = std::is_same<T, _Float16>::value;
constexpr bool is_float = std::is_same<T, float>::value;
constexpr bool is_double = std::is_same<T, double>::value;
#endif // defined(__clang__) and defined(__HIP__)
static_assert(is_half || is_float || is_double, "only half, float and double are supported");
constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
@@ -478,10 +490,16 @@ __FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we,
}
}
#if defined(__clang__) and defined(__HIP__)
typename __hip_internal::conditional<
sizeof(T) == 2, unsigned short int,
typename __hip_internal::conditional<sizeof(T) == 4, unsigned int,
unsigned long long>::type>::type retval;
#else
typename std::conditional<sizeof(T) == 2, unsigned short int,
typename std::conditional<sizeof(T) == 4, unsigned int,
unsigned long long>::type>::type retval;
#endif
if (we == 5 && is_half && !is_fnuz) {
retval = x << 8;
+5
Просмотреть файл
@@ -33,6 +33,11 @@ THE SOFTWARE.
#if !defined(__HIPCC_RTC__)
#include "hip/amd_detail/host_defines.h"
#if __cplusplus
#include <cstdint>
#else
#include <stdint.h>
#endif
#endif
typedef unsigned char uchar;
+4
Просмотреть файл
@@ -22,6 +22,8 @@ THE SOFTWARE.
#pragma once
#if defined(__clang__) and defined(__HIP__)
// abort
extern "C" __device__ inline __attribute__((weak))
void abort() {
@@ -99,3 +101,5 @@ void __assertfail()
__builtin_trap(); \
} while (0)
#endif
#endif // defined(__clang__) and defined(__HIP__)
+6
Просмотреть файл
@@ -234,6 +234,12 @@ struct __half2_raw {
{
return __internal_half2float(static_cast<__half_raw>(x).x);
}
// Converts a packed __half2 into a float2 on the host.
//
// The argument is viewed through its __half2_raw representation and each
// 16-bit member is widened separately with __internal_half2float:
//   result.x <- raw.x
//   result.y <- raw.y
//
// The second component must come from `.y`; reading `.x` twice would duplicate
// the first half and silently drop the second one.
// NOTE(review): relies on the host-side __half2_raw exposing a `.y` member next
// to `.x` (the struct body is not visible in this hunk) - confirm.
inline
float2 __half22float2(__half2 x)
{
    return float2{__internal_half2float(static_cast<__half2_raw>(x).x),
                  __internal_half2float(static_cast<__half2_raw>(x).y)};
}
inline
float __low2float(__half2 x)