From 16f9dbff6c1227a8781fc54094fdab0654d54fbc Mon Sep 17 00:00:00 2001 From: Jatin Chaudhary Date: Tue, 4 Feb 2025 00:20:33 +0000 Subject: [PATCH] SWDEV-511239 - make fp8 standalone host compileable - Use correct header in device_library_decl - use std:: instead of __hip_internal:: for host compilation - hide device specific stuff behind __clang__ and __HIP__ check Change-Id: I2f3647e00555ed0e79f9954a459c41394c3cd49b [ROCm/clr commit: c3f49c878883c687d5eb7d283b78e5f8a87c5d7b] --- .../include/hip/amd_detail/amd_hip_bf16.h | 4 +++- .../include/hip/amd_detail/amd_hip_fp16.h | 6 +++--- .../include/hip/amd_detail/amd_hip_fp8.h | 18 ++++++++++++++++++ .../hip/amd_detail/device_library_decls.h | 5 +++++ .../hipamd/include/hip/amd_detail/hip_assert.h | 4 ++++ .../include/hip/amd_detail/hip_fp16_gcc.h | 6 ++++++ 6 files changed, 39 insertions(+), 4 deletions(-) diff --git a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h index 6426cc1610..f865bee996 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h +++ b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_bf16.h @@ -110,10 +110,12 @@ #if !defined(__HIPCC_RTC__) #include -#include // Sync functions #include "amd_hip_vector_types.h" // float2 etc #include "device_library_decls.h" // ocml conversion functions #include "math_fwd.h" // ocml device functions +#if defined(__clang__) and defined(__HIP__) +#include // Sync functions +#endif #endif // !defined(__HIPCC_RTC__) #define __BF16_DEVICE__ __device__ diff --git a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp16.h b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp16.h index c68abb0e9f..8dbc7a997c 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp16.h +++ b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp16.h @@ -28,8 +28,9 @@ THE SOFTWARE. #define __HOST_DEVICE__ __device__ #else #define __HOST_DEVICE__ __host__ __device__ - #include - #include "hip/amd_detail/host_defines.h" + #include "amd_hip_common.h" + #include "host_defines.h" + #include "amd_hip_vector_types.h" #include #if defined(__cplusplus) #include @@ -73,7 +74,6 @@ THE SOFTWARE. #if defined(__cplusplus) #if !defined(__HIPCC_RTC__) #include "hip_fp16_math_fwd.h" - #include "amd_hip_vector_types.h" #include "host_defines.h" #include "amd_device_functions.h" #include "amd_warp_functions.h" diff --git a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp8.h b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp8.h index 5ed745b668..e730ff3e16 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp8.h +++ b/projects/clr/hipamd/include/hip/amd_detail/amd_hip_fp8.h @@ -188,9 +188,15 @@ template __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false, bool stoch = false, unsigned int rng = 0) { +#if defined(__clang__) and defined(__HIP__) constexpr bool is_half = __hip_internal::is_same::value; constexpr bool is_float = __hip_internal::is_same::value; constexpr bool is_double = __hip_internal::is_same::value; +#else // compiling for host + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + constexpr bool is_double = std::is_same::value; +#endif // defined(__clang__) and defined(__HIP__) static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8"); const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10); @@ -395,9 +401,15 @@ after shift right by 4 bits, it would look like midpoint. template __FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we, bool clip = false) { +#if defined(__clang__) and defined(__HIP__) constexpr bool is_half = __hip_internal::is_same::value; constexpr bool is_float = __hip_internal::is_same::value; constexpr bool is_double = __hip_internal::is_same::value; +#else + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + constexpr bool is_double = std::is_same::value; +#endif // defined(__clang__) and defined(__HIP__) static_assert(is_half || is_float || is_double, "only half, float and double are supported"); constexpr int weo = is_half ? 5 : (is_float ? 8 : 11); @@ -478,10 +490,16 @@ __FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we, } } +#if defined(__clang__) and defined(__HIP__) typename __hip_internal::conditional< sizeof(T) == 2, unsigned short int, typename __hip_internal::conditional::type>::type retval; +#else + typename std::conditional::type>::type retval; +#endif if (we == 5 && is_half && !is_fnuz) { retval = x << 8; diff --git a/projects/clr/hipamd/include/hip/amd_detail/device_library_decls.h b/projects/clr/hipamd/include/hip/amd_detail/device_library_decls.h index edc4692c83..81f645052c 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/device_library_decls.h +++ b/projects/clr/hipamd/include/hip/amd_detail/device_library_decls.h @@ -33,6 +33,11 @@ THE SOFTWARE. #if !defined(__HIPCC_RTC__) #include "hip/amd_detail/host_defines.h" +#if __cplusplus +#include +#else +#include +#endif #endif typedef unsigned char uchar; diff --git a/projects/clr/hipamd/include/hip/amd_detail/hip_assert.h b/projects/clr/hipamd/include/hip/amd_detail/hip_assert.h index 7d634eae0d..716a9f228d 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/hip_assert.h +++ b/projects/clr/hipamd/include/hip/amd_detail/hip_assert.h @@ -22,6 +22,8 @@ THE SOFTWARE. #pragma once +#if defined(__clang__) and defined(__HIP__) + // abort extern "C" __device__ inline __attribute__((weak)) void abort() { @@ -99,3 +101,5 @@ void __assertfail() __builtin_trap(); \ } while (0) #endif + +#endif // defined(__clang__) and defined(__HIP__) diff --git a/projects/clr/hipamd/include/hip/amd_detail/hip_fp16_gcc.h b/projects/clr/hipamd/include/hip/amd_detail/hip_fp16_gcc.h index e76a7fff3a..8fac7a6660 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/hip_fp16_gcc.h +++ b/projects/clr/hipamd/include/hip/amd_detail/hip_fp16_gcc.h @@ -234,6 +234,12 @@ struct __half2_raw { { return __internal_half2float(static_cast<__half_raw>(x).x); } + inline + float2 __half22float2(__half2 x) + { + return float2{__internal_half2float(static_cast<__half2_raw>(x).x), + __internal_half2float(static_cast<__half2_raw>(x).x)}; + } inline float __low2float(__half2 x)