diff --git a/hipamd/include/hip/amd_detail/amd_hip_fp8.h b/hipamd/include/hip/amd_detail/amd_hip_fp8.h index be8e63a545..e9b1f4a335 100644 --- a/hipamd/include/hip/amd_detail/amd_hip_fp8.h +++ b/hipamd/include/hip/amd_detail/amd_hip_fp8.h @@ -44,14 +44,25 @@ #elif (defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__ #define HIP_FP8_TYPE_OCP 1 #define HIP_FP8_TYPE_FNUZ 0 -#elif __HIP_DEVICE_COMPILE__ -#define HIP_FP8_TYPE_FNUZ 0 -#define HIP_FP8_TYPE_OCP 0 -#else // Host +#else #define HIP_FP8_TYPE_FNUZ 1 #define HIP_FP8_TYPE_OCP 1 #endif +#if defined(__HIPCC_RTC__) + #if HIP_FP8_TYPE_FNUZ + #define ENABLE_FNUZ_HIPRTC 1 + #else + #define ENABLE_FNUZ_HIPRTC 0 + #endif + #if HIP_FP8_TYPE_OCP + #define ENABLE_OCP_HIPRTC 1 + #else + #define ENABLE_OCP_HIPRTC 0 + #endif +#endif + + #if !defined(__HIPCC_RTC__) #include #include @@ -474,7 +485,7 @@ __FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we, // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) if (exponent <= 0) { - mantissa |= 1 << wmo; + mantissa |= 1ull << wmo; mantissa >>= 1 - exponent; exponent = 0; } @@ -540,8 +551,8 @@ static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate static __device__ __hip_fp8x2_storage_t cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) { union { - static_assert(sizeof(float2) == sizeof(unsigned int[2])); - static_assert(sizeof(float2) == sizeof(unsigned short[4])); + static_assert(sizeof(float2) == sizeof(unsigned int[2]), "size mismatch"); + static_assert(sizeof(float2) == sizeof(unsigned short[4]), "size mismatch"); float2 fval; unsigned int i32val[2]; unsigned short i16val[4]; @@ -656,8 +667,7 @@ __FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8( int we = interp == __HIP_E4M3_FNUZ ? 4 : 5; int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2; return internal::cast_to_f8(f, wm, we, sat == __HIP_SATFINITE); - } - if (interp == __HIP_E4M3 || interp == __HIP_E5M2) { + } else { int we = interp == __HIP_E4M3 ? 4 : 5; int wm = interp == __HIP_E4M3 ? 3 : 2; return internal::cast_to_f8(f, wm, we, sat == __HIP_SATFINITE); @@ -716,8 +726,7 @@ __FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8( int we = interp == __HIP_E4M3_FNUZ ? 4 : 5; int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2; return internal::cast_to_f8(d, wm, we, sat == __HIP_SATFINITE); - } - if (interp == __HIP_E4M3 || interp == __HIP_E5M2) { + } else { int we = interp == __HIP_E4M3 ? 4 : 5; int wm = interp == __HIP_E4M3 ? 3 : 2; return internal::cast_to_f8(d, wm, we, sat == __HIP_SATFINITE); @@ -822,8 +831,7 @@ __FP8_HOST_STATIC__ __half_raw __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_ unsigned int we = interp == __HIP_E4M3_FNUZ ? 4 : 5; unsigned int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2; return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)}; - } - if (interp == __HIP_E4M3 || interp == __HIP_E5M2) { + } else { unsigned int we = interp == __HIP_E4M3 ? 4 : 5; unsigned int wm = interp == __HIP_E4M3 ? 3 : 2; return __half_raw{internal::cast_from_f8<_Float16, false>(x, wm, we)}; @@ -903,6 +911,8 @@ __FP8_HOST_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2( * \brief struct representing single fp8 number with e4m3 interpretation * */ + +#if !defined(ENABLE_FNUZ_HIPRTC) || ENABLE_FNUZ_HIPRTC struct __hip_fp8_e4m3_fnuz { __hip_fp8_storage_t __x; //! raw storage of fp8 number constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE; @@ -2014,10 +2024,15 @@ struct __hip_fp8x4_e5m2_fnuz { } }; +#endif // ENABLE_FNUZ_HIPRTC + /** * \brief struct representing ocp fp8 numbers with e4m3 interpretation * * */ + +#if !defined(ENABLE_OCP_HIPRTC) || ENABLE_OCP_HIPRTC + struct __hip_fp8_e4m3 { __hip_fp8_storage_t __x; //! raw storage of fp8 number constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE; @@ -3135,4 +3150,5 @@ struct __hip_fp8x4_e5m2 { return float4(low.x, low.y, high.x, high.y); } }; +#endif // ENABLE_OCP_HIPRTC #endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_ \ No newline at end of file