From e729f08704a90c49bf00a4a6261ce10d11218b71 Mon Sep 17 00:00:00 2001 From: Rahul Manocha Date: Wed, 18 Sep 2024 12:05:16 -0700 Subject: [PATCH] SWDEV-468039,SWDEV-482579 - Enable FP8 SW Conversions on pre gfx940 archs 1) SW Conversions for ocp and fnuz are enabled on pre mi300 archs 2) for mi300 only fnuz is enabled 3) for gfx1200 only ocp is enabled Change-Id: I90373752a2d15eff20d5deec874ed396ba4e1788 --- hipamd/include/hip/amd_detail/amd_hip_fp8.h | 42 ++++++++++++++------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/hipamd/include/hip/amd_detail/amd_hip_fp8.h b/hipamd/include/hip/amd_detail/amd_hip_fp8.h index be8e63a545..e9b1f4a335 100644 --- a/hipamd/include/hip/amd_detail/amd_hip_fp8.h +++ b/hipamd/include/hip/amd_detail/amd_hip_fp8.h @@ -44,14 +44,25 @@ #elif (defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__ #define HIP_FP8_TYPE_OCP 1 #define HIP_FP8_TYPE_FNUZ 0 -#elif __HIP_DEVICE_COMPILE__ -#define HIP_FP8_TYPE_FNUZ 0 -#define HIP_FP8_TYPE_OCP 0 -#else // Host +#else #define HIP_FP8_TYPE_FNUZ 1 #define HIP_FP8_TYPE_OCP 1 #endif +#if defined(__HIPCC_RTC__) + #if HIP_FP8_TYPE_FNUZ + #define ENABLE_FNUZ_HIPRTC 1 + #else + #define ENABLE_FNUZ_HIPRTC 0 + #endif + #if HIP_FP8_TYPE_OCP + #define ENABLE_OCP_HIPRTC 1 + #else + #define ENABLE_OCP_HIPRTC 0 + #endif +#endif + + #if !defined(__HIPCC_RTC__) #include #include @@ -474,7 +485,7 @@ __FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we, // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) if (exponent <= 0) { - mantissa |= 1 << wmo; + mantissa |= 1ull << wmo; mantissa >>= 1 - exponent; exponent = 0; } @@ -540,8 +551,8 @@ static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate static __device__ __hip_fp8x2_storage_t cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) { union { - static_assert(sizeof(float2) == sizeof(unsigned int[2])); - static_assert(sizeof(float2) == sizeof(unsigned short[4])); + static_assert(sizeof(float2) == sizeof(unsigned int[2]), "size mismatch"); + static_assert(sizeof(float2) == sizeof(unsigned short[4]), "size mismatch"); float2 fval; unsigned int i32val[2]; unsigned short i16val[4]; @@ -656,8 +667,7 @@ __FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8( int we = interp == __HIP_E4M3_FNUZ ? 4 : 5; int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2; return internal::cast_to_f8(f, wm, we, sat == __HIP_SATFINITE); - } - if (interp == __HIP_E4M3 || interp == __HIP_E5M2) { + } else { int we = interp == __HIP_E4M3 ? 4 : 5; int wm = interp == __HIP_E4M3 ? 3 : 2; return internal::cast_to_f8(f, wm, we, sat == __HIP_SATFINITE); @@ -716,8 +726,7 @@ __FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8( int we = interp == __HIP_E4M3_FNUZ ? 4 : 5; int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2; return internal::cast_to_f8(d, wm, we, sat == __HIP_SATFINITE); - } - if (interp == __HIP_E4M3 || interp == __HIP_E5M2) { + } else { int we = interp == __HIP_E4M3 ? 4 : 5; int wm = interp == __HIP_E4M3 ? 3 : 2; return internal::cast_to_f8(d, wm, we, sat == __HIP_SATFINITE); @@ -822,8 +831,7 @@ __FP8_HOST_STATIC__ __half_raw __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_ unsigned int we = interp == __HIP_E4M3_FNUZ ? 4 : 5; unsigned int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2; return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)}; - } - if (interp == __HIP_E4M3 || interp == __HIP_E5M2) { + } else { unsigned int we = interp == __HIP_E4M3 ? 4 : 5; unsigned int wm = interp == __HIP_E4M3 ? 3 : 2; return __half_raw{internal::cast_from_f8<_Float16, false>(x, wm, we)}; @@ -903,6 +911,8 @@ __FP8_HOST_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2( * \brief struct representing single fp8 number with e4m3 interpretation * */ + +#if !defined(ENABLE_FNUZ_HIPRTC) || ENABLE_FNUZ_HIPRTC struct __hip_fp8_e4m3_fnuz { __hip_fp8_storage_t __x; //! raw storage of fp8 number constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE; @@ -2014,10 +2024,15 @@ struct __hip_fp8x4_e5m2_fnuz { } }; +#endif // ENABLE_FNUZ_HIPRTC + /** * \brief struct representing ocp fp8 numbers with e4m3 interpretation * * */ + +#if !defined(ENABLE_OCP_HIPRTC) || ENABLE_OCP_HIPRTC + struct __hip_fp8_e4m3 { __hip_fp8_storage_t __x; //! raw storage of fp8 number constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE; @@ -3135,4 +3150,5 @@ struct __hip_fp8x4_e5m2 { return float4(low.x, low.y, high.x, high.y); } }; +#endif // ENABLE_OCP_HIPRTC #endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_ \ No newline at end of file