diff --git a/hipamd/include/hip/amd_detail/amd_hip_fp8.h b/hipamd/include/hip/amd_detail/amd_hip_fp8.h index 47419cc61b..be8e63a545 100644 --- a/hipamd/include/hip/amd_detail/amd_hip_fp8.h +++ b/hipamd/include/hip/amd_detail/amd_hip_fp8.h @@ -61,6 +61,7 @@ #include "amd_hip_fp16.h" // __half_raw #include "amd_hip_bf16.h" // bf16 #include "math_fwd.h" // ocml device functions +#include "hip_assert.h" // hip assertions #endif // !defined(__HIPCC_RTC__) #if defined(__HIPCC_RTC__) @@ -72,6 +73,8 @@ #endif // __HIPCC_RTC__ #define __FP8_HOST__ __host__ +#define __FP8_HOST_STATIC__ __FP8_HOST__ static inline + #if !defined(__HIPCC_RTC__) static_assert(CHAR_BIT == 8, "byte size should be of 8 bits"); @@ -118,14 +121,41 @@ typedef unsigned short int __hip_fp8x2_storage_t; */ typedef unsigned int __hip_fp8x4_storage_t; + namespace internal { +// Assertions to check for supported conversion types +#define __assert_ocp_support(interp) \ + { \ + if (interp != __HIP_E4M3 && interp != __HIP_E5M2) { \ + __hip_assert(false && "type is unsupported by current target device"); \ + } \ + } +#define __assert_fnuz_support(interp) \ + { \ + if (interp != __HIP_E4M3_FNUZ && interp != __HIP_E5M2_FNUZ) { \ + __hip_assert(false && "type is unsupported by current target device"); \ + } \ + } + +__FP8_HOST_DEVICE_STATIC__ void __is_interpret_supported(__hip_fp8_interpretation_t interp) { +#if __HIP_DEVICE_COMPILE__ +#if HIP_FP8_TYPE_OCP + __assert_ocp_support(interp); +#endif +#if HIP_FP8_TYPE_FNUZ + __assert_fnuz_support(interp); +#endif +#endif +} + // The conversion function is from rocblas // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39 // This has been modified to add double types conversion as well template __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false, - bool stoch = false, unsigned int rng = 0) { + bool stoch = false, + unsigned int rng = 0) { constexpr bool is_half = __hip_internal::is_same::value; constexpr bool is_float = __hip_internal::is_same::value; constexpr bool is_double = __hip_internal::is_same::value; @@ -330,7 +360,8 @@ after shift right by 4 bits, it would look like midpoint. // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220 // This has been modified to handle double types as well template -__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we, bool clip = false) { +__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we, + bool clip = false) { constexpr bool is_half = __hip_internal::is_same::value; constexpr bool is_float = __hip_internal::is_same::value; constexpr bool is_double = __hip_internal::is_same::value; @@ -605,22 +636,30 @@ __FP8_HOST_DEVICE_STATIC__ bool hip_fp8_ocp_is_inf(__hip_fp8_storage_t a, * * \param f float number * \param sat saturation of fp8 - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __hip_fp8_storage_t */ -__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8( - const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { #if HIP_FP8_CVT_FAST_PATH - return internal::cast_to_f8_from_f32(f, sat == __HIP_SATFINITE, type); -#else // HIP_FP8_CVT_FAST_PATH - if (type == __HIP_E4M3_FNUZ || type == __HIP_E5M2_FNUZ) { - int we = type == __HIP_E4M3_FNUZ ? 4 : 5; - int wm = type == __HIP_E4M3_FNUZ ? 3 : 2; +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8( + const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); + return internal::cast_to_f8_from_f32(f, sat == __HIP_SATFINITE, interp); +#else +#if HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8( + const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8( + const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#endif + if (interp == __HIP_E4M3_FNUZ || interp == __HIP_E5M2_FNUZ) { + int we = interp == __HIP_E4M3_FNUZ ? 4 : 5; + int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2; return internal::cast_to_f8(f, wm, we, sat == __HIP_SATFINITE); } - if (type == __HIP_E4M3 || type == __HIP_E5M2) { - int we = type == __HIP_E4M3 ? 4 : 5; - int wm = type == __HIP_E4M3 ? 3 : 2; + if (interp == __HIP_E4M3 || interp == __HIP_E5M2) { + int we = interp == __HIP_E4M3 ? 4 : 5; + int wm = interp == __HIP_E4M3 ? 3 : 2; return internal::cast_to_f8(f, wm, we, sat == __HIP_SATFINITE); } #endif // HIP_FP8_CVT_FAST_PATH @@ -632,18 +671,26 @@ __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8( * * \param f2 float2 number * \param sat saturation of fp8 - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __hip_fp8x2_storage_t */ -__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2( - const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { #if HIP_FP8_CVT_FAST_PATH - return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, type); +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2( + const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); + return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, interp); #else - return static_cast<__hip_fp8x2_storage_t>( - static_cast(__hip_cvt_float_to_fp8(f2.y, sat, type)) << 8 | - static_cast(__hip_cvt_float_to_fp8(f2.x, sat, type))); +#if HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2( + const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2( + const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { #endif + return static_cast<__hip_fp8x2_storage_t>( + static_cast(__hip_cvt_float_to_fp8(f2.y, sat, interp)) << 8 | + static_cast(__hip_cvt_float_to_fp8(f2.x, sat, interp))); +#endif // HIP_FP8_CVT_FAST_PATH } /** @@ -651,19 +698,28 @@ __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2( * * \param d double val * \param sat saturation of fp8 - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __hip_fp8_storage_t */ +#if HIP_FP8_CVT_FAST_PATH __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8( - const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { - if (type == __HIP_E4M3_FNUZ || type == __HIP_E5M2_FNUZ) { - int we = type == __HIP_E4M3_FNUZ ? 4 : 5; - int wm = type == __HIP_E4M3_FNUZ ? 3 : 2; + const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); +#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8( + const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8( + const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#endif + if (interp == __HIP_E4M3_FNUZ || interp == __HIP_E5M2_FNUZ) { + int we = interp == __HIP_E4M3_FNUZ ? 4 : 5; + int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2; return internal::cast_to_f8(d, wm, we, sat == __HIP_SATFINITE); } - if (type == __HIP_E4M3 || type == __HIP_E5M2) { - int we = type == __HIP_E4M3 ? 4 : 5; - int wm = type == __HIP_E4M3 ? 3 : 2; + if (interp == __HIP_E4M3 || interp == __HIP_E5M2) { + int we = interp == __HIP_E4M3 ? 4 : 5; + int wm = interp == __HIP_E4M3 ? 3 : 2; return internal::cast_to_f8(d, wm, we, sat == __HIP_SATFINITE); } } @@ -673,14 +729,23 @@ __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8( * * \param d2 double2 val * \param sat saturation of fp8 - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __hip_fp8x2_storage_t */ +#if HIP_FP8_CVT_FAST_PATH __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2( - const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { + const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); +#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2( + const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2( + const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#endif return static_cast<__hip_fp8x2_storage_t>( - static_cast(__hip_cvt_double_to_fp8(d2.y, sat, type)) << 8 | - static_cast(__hip_cvt_double_to_fp8(d2.x, sat, type))); + static_cast(__hip_cvt_double_to_fp8(d2.y, sat, interp)) << 8 | + static_cast(__hip_cvt_double_to_fp8(d2.x, sat, interp))); } /** @@ -688,14 +753,25 @@ __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2( * * \param hr __hip_bfloat16_raw val * \param sat saturation of fp8 - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __hip_fp8_storage_t */ +#if HIP_FP8_CVT_FAST_PATH __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat, - const __hip_fp8_interpretation_t type) { + const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); +#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t +__hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __hip_fp8_storage_t +__hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t interp) { +#endif float fval = __hip_bfloat16(hr); - return __hip_cvt_float_to_fp8(fval, sat, type); + return __hip_cvt_float_to_fp8(fval, sat, interp); } /** @@ -703,33 +779,53 @@ __hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation * * \param hr __hip_bfloat162_raw value * \param sat saturation of fp8 - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __hip_fp8x2_storage_t */ +#if HIP_FP8_CVT_FAST_PATH __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat, - const __hip_fp8_interpretation_t type) { + const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); +#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t +__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __hip_fp8x2_storage_t +__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t interp) { +#endif float2 f2 = __hip_bfloat162(hr); - return __hip_cvt_float2_to_fp8x2(f2, sat, type); + return __hip_cvt_float2_to_fp8x2(f2, sat, interp); } /** * \brief convert @p __hip_fp8_storage_t to __half_raw * * \param x __hip_fp8_storage_t val - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __half_raw */ +#if HIP_FP8_CVT_FAST_PATH __FP8_HOST_DEVICE_STATIC__ __half_raw -__hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t type) { - if (type == __HIP_E4M3_FNUZ || type == __HIP_E5M2_FNUZ) { - unsigned int we = type == __HIP_E4M3_FNUZ ? 4 : 5; - unsigned int wm = type == __HIP_E4M3_FNUZ ? 3 : 2; +__hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); +#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __half_raw +__hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __half_raw __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, + const __hip_fp8_interpretation_t interp) { +#endif + if (interp == __HIP_E4M3_FNUZ || interp == __HIP_E5M2_FNUZ) { + unsigned int we = interp == __HIP_E4M3_FNUZ ? 4 : 5; + unsigned int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2; return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)}; } - if (type == __HIP_E4M3 || type == __HIP_E5M2) { - unsigned int we = type == __HIP_E4M3 ? 4 : 5; - unsigned int wm = type == __HIP_E4M3 ? 3 : 2; + if (interp == __HIP_E4M3 || interp == __HIP_E5M2) { + unsigned int we = interp == __HIP_E4M3 ? 4 : 5; + unsigned int wm = interp == __HIP_E4M3 ? 3 : 2; return __half_raw{internal::cast_from_f8<_Float16, false>(x, wm, we)}; } } @@ -738,15 +834,24 @@ __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpreta * \brief convert @p __hip_fp8x2_storage_t to __half2_raw * * \param x __hip_fp8x2_storage_t val - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __half2_raw */ -__FP8_HOST_DEVICE_STATIC__ __half2_raw -__hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t type) { +#if HIP_FP8_CVT_FAST_PATH +__FP8_HOST_DEVICE_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2( + const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); +#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2( + const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2( + const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t interp) { +#endif __half2 ret(static_cast<__half>( - __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x & 0xFF), type)), + __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x & 0xFF), interp)), static_cast<__half>( - __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x >> 8), type))); + __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x >> 8), interp))); return static_cast<__half2_raw>(ret); } @@ -755,12 +860,21 @@ __hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_inter * * \param x __half_raw value * \param sat saturation of fp8 - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __hip_fp8_storage_t */ +#if HIP_FP8_CVT_FAST_PATH __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8( - const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { - return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, type); + const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); +#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8( + const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8( + const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#endif + return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, interp); } /** @@ -768,12 +882,21 @@ __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8( * * \param x __half2_raw value * \param sat saturation of fp8 - * \param type interpretation of fp8 + * \param interp interpretation of fp8 * \return __hip_fp8x2_storage_t */ +#if HIP_FP8_CVT_FAST_PATH __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2( - const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) { - return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, type); + const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { + internal::__is_interpret_supported(interp); +#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2( + const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#else +__FP8_HOST_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2( + const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) { +#endif + return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, interp); } /** @@ -797,7 +920,8 @@ struct __hip_fp8_e4m3_fnuz { __FP8_HOST__ __hip_fp8_e4m3_fnuz(const long int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from int */ #if HIP_FP8_TYPE_FNUZ @@ -806,7 +930,8 @@ struct __hip_fp8_e4m3_fnuz { __FP8_HOST__ __hip_fp8_e4m3_fnuz(const int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from short int */ #if HIP_FP8_TYPE_FNUZ @@ -815,7 +940,8 @@ struct __hip_fp8_e4m3_fnuz { __FP8_HOST__ __hip_fp8_e4m3_fnuz(const short int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from unsigned long */ #if HIP_FP8_TYPE_FNUZ @@ -824,7 +950,8 @@ struct __hip_fp8_e4m3_fnuz { __FP8_HOST__ __hip_fp8_e4m3_fnuz(const unsigned long int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from unsigned int */ #if HIP_FP8_TYPE_FNUZ @@ -833,7 +960,8 @@ struct __hip_fp8_e4m3_fnuz { __FP8_HOST__ __hip_fp8_e4m3_fnuz(const unsigned int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from unsigned short */ #if HIP_FP8_TYPE_FNUZ @@ -842,7 +970,8 @@ struct __hip_fp8_e4m3_fnuz { __FP8_HOST__ __hip_fp8_e4m3_fnuz(const unsigned short int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from double */ #if HIP_FP8_TYPE_FNUZ @@ -850,7 +979,8 @@ struct __hip_fp8_e4m3_fnuz { #else __FP8_HOST__ __hip_fp8_e4m3_fnuz(const double f) #endif - : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) { + } /*! create fp8 e4m3 from float */ #if HIP_FP8_TYPE_FNUZ @@ -858,7 +988,8 @@ struct __hip_fp8_e4m3_fnuz { #else __FP8_HOST__ __hip_fp8_e4m3_fnuz(const float f) #endif - : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) { + } /*! create fp8 e4m3 from __hip_bfloat16 */ #if HIP_FP8_TYPE_FNUZ @@ -867,7 +998,8 @@ struct __hip_fp8_e4m3_fnuz { __FP8_HOST__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f) #endif : __x(__hip_cvt_float_to_fp8(static_cast(f), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from __half */ #if HIP_FP8_TYPE_FNUZ @@ -876,7 +1008,8 @@ struct __hip_fp8_e4m3_fnuz { __FP8_HOST__ __hip_fp8_e4m3_fnuz(const __half f) #endif : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! default construct fp8 e4m3 */ #if HIP_FP8_TYPE_FNUZ @@ -1118,17 +1251,17 @@ struct __hip_fp8_e4m3_fnuz { #else __FP8_HOST__ operator unsigned short int() const { #endif - if (internal::hip_fp8_fnuz_is_nan(__x)) { - return 0; - } + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } - float fval = *this; - auto llval = static_cast(fval); - if (llval <= 0) { - return 0; - } - return static_cast(fval); - } + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } }; /** @@ -1148,7 +1281,8 @@ struct __hip_fp8x2_e4m3_fnuz { #else __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const double2 val) #endif - : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e4m3 type from float2 */ #if HIP_FP8_TYPE_FNUZ @@ -1156,7 +1290,8 @@ struct __hip_fp8x2_e4m3_fnuz { #else __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const float2 val) #endif - : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e4m3 type from __hip_bfloat162 */ #if HIP_FP8_TYPE_FNUZ @@ -1164,7 +1299,8 @@ struct __hip_fp8x2_e4m3_fnuz { #else __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val) #endif - : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e4m3 type from __half2 */ #if HIP_FP8_TYPE_FNUZ @@ -1172,7 +1308,8 @@ struct __hip_fp8x2_e4m3_fnuz { #else __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const __half2 val) #endif - : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! Default construct of fp8x2 e4m3 */ #if HIP_FP8_TYPE_FNUZ @@ -1197,7 +1334,7 @@ struct __hip_fp8x2_e4m3_fnuz { __FP8_HOST__ operator float2() const { #endif #if HIP_FP8_CVT_FAST_PATH - return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret); + return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret); #else return float2(internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x & 0xFF), __wm, __we), @@ -1235,7 +1372,8 @@ struct __hip_fp8x4_e4m3_fnuz { << 16 | reinterpret_cast(__hip_cvt_double_to_fp8( val.w, __default_saturation, __default_interpret)) - << 24))} {} + << 24))} { + } /*! create fp8x4 e4m3 type from float4 */ #if HIP_FP8_TYPE_FNUZ @@ -1254,7 +1392,8 @@ struct __hip_fp8x4_e4m3_fnuz { << 16 | reinterpret_cast(__hip_cvt_float_to_fp8( val.w, __default_saturation, __default_interpret)) - << 24))} {} + << 24))} { + } /*! create fp8x4 e4m3 type from two __hip_bfloat162 */ #if HIP_FP8_TYPE_FNUZ @@ -1267,7 +1406,8 @@ struct __hip_fp8x4_e4m3_fnuz { __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) | reinterpret_cast( __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret)) - << 16))) {} + << 16))) { + } /*! create fp8x4 e4m3 type from two __half2 */ #if HIP_FP8_TYPE_FNUZ @@ -1280,7 +1420,8 @@ struct __hip_fp8x4_e4m3_fnuz { high, __default_saturation, __default_interpret)) | reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( low, __default_saturation, __default_interpret)) - << 16))) {} + << 16))) { + } /*! Default construct fp8x4 e4m3 */ #if HIP_FP8_TYPE_FNUZ @@ -1295,7 +1436,7 @@ struct __hip_fp8x4_e4m3_fnuz { #else __FP8_HOST__ operator float4() const { #endif - auto x = __x; // bypass const + auto x = __x; // bypass const auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1); #if HIP_FP8_CVT_FAST_PATH @@ -1337,7 +1478,8 @@ struct __hip_fp8_e5m2_fnuz { __FP8_HOST__ __hip_fp8_e5m2_fnuz(const long int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from int */ #if HIP_FP8_TYPE_FNUZ @@ -1346,7 +1488,8 @@ struct __hip_fp8_e5m2_fnuz { __FP8_HOST__ __hip_fp8_e5m2_fnuz(const int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from short int */ #if HIP_FP8_TYPE_FNUZ @@ -1355,7 +1498,8 @@ struct __hip_fp8_e5m2_fnuz { __FP8_HOST__ __hip_fp8_e5m2_fnuz(const short int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from unsigned long */ #if HIP_FP8_TYPE_FNUZ @@ -1364,7 +1508,8 @@ struct __hip_fp8_e5m2_fnuz { __FP8_HOST__ __hip_fp8_e5m2_fnuz(const unsigned long int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from unsigned int */ #if HIP_FP8_TYPE_FNUZ @@ -1373,7 +1518,8 @@ struct __hip_fp8_e5m2_fnuz { __FP8_HOST__ __hip_fp8_e5m2_fnuz(const unsigned int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from unsigned short */ #if HIP_FP8_TYPE_FNUZ @@ -1382,7 +1528,8 @@ struct __hip_fp8_e5m2_fnuz { __FP8_HOST__ __hip_fp8_e5m2_fnuz(const unsigned short int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from double */ #if HIP_FP8_TYPE_FNUZ @@ -1390,7 +1537,8 @@ struct __hip_fp8_e5m2_fnuz { #else __FP8_HOST__ __hip_fp8_e5m2_fnuz(const double f) #endif - : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) { + } /*! create fp8 e5m2 type from float */ #if HIP_FP8_TYPE_FNUZ @@ -1398,7 +1546,8 @@ struct __hip_fp8_e5m2_fnuz { #else __FP8_HOST__ __hip_fp8_e5m2_fnuz(const float f) #endif - : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) { + } /*! create fp8 e5m2 type from __hip_bfloat16 */ #if HIP_FP8_TYPE_FNUZ @@ -1407,7 +1556,8 @@ struct __hip_fp8_e5m2_fnuz { __FP8_HOST__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f) #endif : __x(__hip_cvt_float_to_fp8(static_cast(f), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from __hip_bfloat16 */ #if HIP_FP8_TYPE_FNUZ @@ -1416,7 +1566,8 @@ struct __hip_fp8_e5m2_fnuz { __FP8_HOST__ __hip_fp8_e5m2_fnuz(const __half f) #endif : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! default construct fp8 e5m2 */ #if HIP_FP8_TYPE_FNUZ @@ -1659,7 +1810,7 @@ struct __hip_fp8_e5m2_fnuz { __FP8_HOST__ operator unsigned short int() const { #endif if (internal::hip_fp8_fnuz_is_nan(__x)) { - return 0; + return 0; } float fval = *this; @@ -1688,7 +1839,8 @@ struct __hip_fp8x2_e5m2_fnuz { #else __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const double2 val) #endif - : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e5m2 type from float2 */ #if HIP_FP8_TYPE_FNUZ @@ -1696,7 +1848,8 @@ struct __hip_fp8x2_e5m2_fnuz { #else __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const float2 val) #endif - : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e5m2 type from __hip_bfloat162 */ #if HIP_FP8_TYPE_FNUZ @@ -1704,7 +1857,8 @@ struct __hip_fp8x2_e5m2_fnuz { #else __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val) #endif - : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e5m2 type from __half2 */ #if HIP_FP8_TYPE_FNUZ @@ -1712,7 +1866,8 @@ struct __hip_fp8x2_e5m2_fnuz { #else __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const __half2 val) #endif - : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! default construct fp8x2 e5m2 */ #if HIP_FP8_TYPE_FNUZ @@ -1775,7 +1930,8 @@ struct __hip_fp8x4_e5m2_fnuz { << 16 | reinterpret_cast(__hip_cvt_double_to_fp8( val.w, __default_saturation, __default_interpret)) - << 24))) {} + << 24))) { + } /*! create fp8x4 e5m2 type from float4 */ #if HIP_FP8_TYPE_FNUZ @@ -1794,7 +1950,8 @@ struct __hip_fp8x4_e5m2_fnuz { << 16 | reinterpret_cast(__hip_cvt_float_to_fp8( val.w, __default_saturation, __default_interpret)) - << 24))) {} + << 24))) { + } /*! create fp8x4 e5m2 type from two __hip_bfloat162 */ #if HIP_FP8_TYPE_FNUZ @@ -1807,7 +1964,8 @@ struct __hip_fp8x4_e5m2_fnuz { __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) | reinterpret_cast( __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret)) - << 16))) {} + << 16))) { + } /*! create fp8x4 e5m2 type from two __half2 */ #if HIP_FP8_TYPE_FNUZ @@ -1820,7 +1978,8 @@ struct __hip_fp8x4_e5m2_fnuz { high, __default_saturation, __default_interpret)) | reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( low, __default_saturation, __default_interpret)) - << 16))) {} + << 16))) { + } /* default construct fp8x4 e5m2 */ #if HIP_FP8_TYPE_FNUZ @@ -1835,7 +1994,7 @@ struct __hip_fp8x4_e5m2_fnuz { #else __FP8_HOST__ operator float4() const { #endif - auto x = __x; // bypass const + auto x = __x; // bypass const auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1); #if HIP_FP8_CVT_FAST_PATH @@ -1860,7 +2019,7 @@ struct __hip_fp8x4_e5m2_fnuz { * * */ struct __hip_fp8_e4m3 { - __hip_fp8_storage_t __x; //! raw storage of fp8 number + __hip_fp8_storage_t __x; //! raw storage of fp8 number constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE; constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3; constexpr static unsigned int __we = 4; @@ -1871,21 +2030,23 @@ struct __hip_fp8_e4m3 { /*! create fp8 e4m3 from long */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const long int val) + __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const long int val) #else -__FP8_HOST__ __hip_fp8_e4m3(const long int val) + __FP8_HOST__ __hip_fp8_e4m3(const long int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from int */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const int val) + __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const int val) #else -__FP8_HOST__ __hip_fp8_e4m3(const int val) + __FP8_HOST__ __hip_fp8_e4m3(const int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from short int */ __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const short int val) @@ -1894,87 +2055,94 @@ __FP8_HOST__ __hip_fp8_e4m3(const int val) /*! create fp8 e4m3 from unsigned long */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned long int val) + __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned long int val) #else -__FP8_HOST__ __hip_fp8_e4m3(const unsigned long int val) + __FP8_HOST__ __hip_fp8_e4m3(const unsigned long int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from unsigned int */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned int val) + __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned int val) #else -__FP8_HOST__ __hip_fp8_e4m3(const unsigned int val) + __FP8_HOST__ __hip_fp8_e4m3(const unsigned int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from unsigned short */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned short int val) + __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned short int val) #else -__FP8_HOST__ __hip_fp8_e4m3(const unsigned short int val) + __FP8_HOST__ __hip_fp8_e4m3(const unsigned short int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from double */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const double f) + __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const double f) #else -__FP8_HOST__ __hip_fp8_e4m3(const double f) + __FP8_HOST__ __hip_fp8_e4m3(const double f) #endif - : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) { + } /*! create fp8 e4m3 from float */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const float f) + __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const float f) #else -__FP8_HOST__ __hip_fp8_e4m3(const float f) + __FP8_HOST__ __hip_fp8_e4m3(const float f) #endif - : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) { + } /*! create fp8 e4m3 from __hip_bfloat16 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __hip_bfloat16 f) + __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __hip_bfloat16 f) #else -__FP8_HOST__ __hip_fp8_e4m3(const __hip_bfloat16 f) + __FP8_HOST__ __hip_fp8_e4m3(const __hip_bfloat16 f) #endif : __x(__hip_cvt_float_to_fp8(static_cast(f), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e4m3 from __half */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __half f) + __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __half f) #else -__FP8_HOST__ __hip_fp8_e4m3(const __half f) + __FP8_HOST__ __hip_fp8_e4m3(const __half f) #endif : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! default construct fp8 e4m3 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e4m3() = default; + __FP8_HOST_DEVICE__ __hip_fp8_e4m3() = default; #else -__FP8_HOST__ __hip_fp8_e4m3() = default; + __FP8_HOST__ __hip_fp8_e4m3() = default; #endif /*! convert fp8 e4m3 to __half */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator __half() const { + __FP8_HOST_DEVICE__ operator __half() const { #else -__FP8_HOST__ operator __half() const { + __FP8_HOST__ operator __half() const { #endif return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret)); } /*! convert fp8 e4m3 to __hip_bfloat16 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator __hip_bfloat16() const { + __FP8_HOST_DEVICE__ operator __hip_bfloat16() const { #else -__FP8_HOST__ operator __hip_bfloat16() const { + __FP8_HOST__ operator __hip_bfloat16() const { #endif float f = *this; return __hip_bfloat16(f); @@ -1982,9 +2150,9 @@ __FP8_HOST__ operator __hip_bfloat16() const { /*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator bool() const { + __FP8_HOST_DEVICE__ operator bool() const { #else -__FP8_HOST__ operator bool() const { + __FP8_HOST__ operator bool() const { #endif // it can be 0x00 (+0.0) since 0x80 will be nan return !(static_cast(__x) == 0 || static_cast(__x) == 0x80); @@ -1992,11 +2160,11 @@ __FP8_HOST__ operator bool() const { /*! convert fp8 e4m3 to char, clamp number to CHAR_MIN/CHAR_MAX if its out of range */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator char() const { + __FP8_HOST_DEVICE__ operator char() const { #else -__FP8_HOST__ operator char() const { + __FP8_HOST__ operator char() const { #endif - if (internal::hip_fp8_ocp_is_nan(__x,__default_interpret)) { + if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; } @@ -2012,18 +2180,18 @@ __FP8_HOST__ operator char() const { /*! convert fp8 e4m3 to double */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator double() const { + __FP8_HOST_DEVICE__ operator double() const { #else -__FP8_HOST__ operator double() const { + __FP8_HOST__ operator double() const { #endif return internal::cast_from_f8(__x, __wm, __we); } /*! convert fp8 e4m3 to float */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator float() const { + __FP8_HOST_DEVICE__ operator float() const { #else -__FP8_HOST__ operator float() const { + __FP8_HOST__ operator float() const { #endif #if HIP_FP8_CVT_FAST_PATH return internal::cast_to_f32_from_f8(__x, __default_interpret); @@ -2034,9 +2202,9 @@ __FP8_HOST__ operator float() const { /*! convert fp8 e4m3 to int, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator int() const { + __FP8_HOST_DEVICE__ operator int() const { #else -__FP8_HOST__ operator int() const { + __FP8_HOST__ operator int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2048,9 +2216,9 @@ __FP8_HOST__ operator int() const { /*! convert fp8 e4m3 to long, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator long int() const { + __FP8_HOST_DEVICE__ operator long int() const { #else -__FP8_HOST__ operator long int() const { + __FP8_HOST__ operator long int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2062,9 +2230,9 @@ __FP8_HOST__ operator long int() const { /*! convert fp8 e4m3 to long long, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator long long int() const { + __FP8_HOST_DEVICE__ operator long long int() const { #else -__FP8_HOST__ operator long long int() const { + __FP8_HOST__ operator long long int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2076,9 +2244,9 @@ __FP8_HOST__ operator long long int() const { /*! convert fp8 e4m3 to short int, clamp out of bound values, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator short int() const { + __FP8_HOST_DEVICE__ operator short int() const { #else -__FP8_HOST__ operator short int() const { + __FP8_HOST__ operator short int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2096,9 +2264,9 @@ __FP8_HOST__ operator short int() const { /*! convert fp8 e4m3 to signed char, clamp out of bound values, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator signed char() const { + __FP8_HOST_DEVICE__ operator signed char() const { #else -__FP8_HOST__ operator signed char() const { + __FP8_HOST__ operator signed char() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2116,9 +2284,9 @@ __FP8_HOST__ operator signed char() const { /*! convert fp8 e4m3 to unsigned char, clamp out of bound values, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned char() const { + __FP8_HOST_DEVICE__ operator unsigned char() const { #else -__FP8_HOST__ operator unsigned char() const { + __FP8_HOST__ operator unsigned char() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2136,9 +2304,9 @@ __FP8_HOST__ operator unsigned char() const { /*! convert fp8 e4m3 to unsigned int, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned int() const { + __FP8_HOST_DEVICE__ operator unsigned int() const { #else -__FP8_HOST__ operator unsigned int() const { + __FP8_HOST__ operator unsigned int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2154,9 +2322,9 @@ __FP8_HOST__ operator unsigned int() const { /*! convert fp8 e4m3 to unsigned long, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned long int() const { + __FP8_HOST_DEVICE__ operator unsigned long int() const { #else -__FP8_HOST__ operator unsigned long int() const { + __FP8_HOST__ operator unsigned long int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2172,9 +2340,9 @@ __FP8_HOST__ operator unsigned long int() const { /*! convert fp8 e4m3 to long long int, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned long long int() const { + __FP8_HOST_DEVICE__ operator unsigned long long int() const { #else -__FP8_HOST__ operator unsigned long long int() const { + __FP8_HOST__ operator unsigned long long int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2190,11 +2358,11 @@ __FP8_HOST__ operator unsigned long long int() const { /*! convert fp8 e4m3 to unsigned short, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned short int() const { + __FP8_HOST_DEVICE__ operator unsigned short int() const { #else -__FP8_HOST__ operator unsigned short int() const { + __FP8_HOST__ operator unsigned short int() const { #endif - if (internal::hip_fp8_ocp_is_nan(__x,__default_interpret)) { + if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; } @@ -2221,63 +2389,69 @@ struct __hip_fp8x2_e4m3 { /*! create fp8x2 e4m3 type from double2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const double2 val) + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const double2 val) #else -__FP8_HOST__ __hip_fp8x2_e4m3(const double2 val) + __FP8_HOST__ __hip_fp8x2_e4m3(const double2 val) #endif - : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e4m3 type from float2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const float2 val) + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const float2 val) #else -__FP8_HOST__ __hip_fp8x2_e4m3(const float2 val) + __FP8_HOST__ __hip_fp8x2_e4m3(const float2 val) #endif - : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e4m3 type from __hip_bfloat162 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __hip_bfloat162 val) + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __hip_bfloat162 val) #else -__FP8_HOST__ __hip_fp8x2_e4m3(const __hip_bfloat162 val) + __FP8_HOST__ __hip_fp8x2_e4m3(const __hip_bfloat162 val) #endif - : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e4m3 type from __half2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __half2 val) + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __half2 val) #else -__FP8_HOST__ __hip_fp8x2_e4m3(const __half2 val) + __FP8_HOST__ __hip_fp8x2_e4m3(const __half2 val) #endif - : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! Default construct of fp8x2 e4m3 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3() = default; + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3() = default; #else -__FP8_HOST__ __hip_fp8x2_e4m3() = default; + __FP8_HOST__ __hip_fp8x2_e4m3() = default; #endif /*! convert fp8x2 e4m3 to __half2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator __half2() const { + __FP8_HOST_DEVICE__ operator __half2() const { #else -__FP8_HOST__ operator __half2() const { + __FP8_HOST__ operator __half2() const { #endif return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret)); } /*! convert fp8x2 e4m3 to float2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator float2() const { + __FP8_HOST_DEVICE__ operator float2() const { #else -__FP8_HOST__ operator float2() const { + __FP8_HOST__ operator float2() const { #endif #if HIP_FP8_CVT_FAST_PATH - return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret); + return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret); #else - return float2(internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x & 0xFF), __wm, __we), - internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x >> 8), __wm, __we)); + return float2(internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x & 0xFF), + __wm, __we), + internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x >> 8), + __wm, __we)); #endif } }; @@ -2296,11 +2470,11 @@ struct __hip_fp8x4_e4m3 { /*! create fp8x4 e4m3 type from double4 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const double4 val) + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const double4 val) #else -__FP8_HOST__ __hip_fp8x4_e4m3(const double4 val) + __FP8_HOST__ __hip_fp8x4_e4m3(const double4 val) #endif - : __x{reinterpret_cast<__hip_fp8x4_storage_t>( + : __x{reinterpret_cast<__hip_fp8x4_storage_t>( static_cast(reinterpret_cast(__hip_cvt_double_to_fp8( val.x, __default_saturation, __default_interpret)) | reinterpret_cast(__hip_cvt_double_to_fp8( @@ -2311,13 +2485,14 @@ __FP8_HOST__ __hip_fp8x4_e4m3(const double4 val) << 16 | reinterpret_cast(__hip_cvt_double_to_fp8( val.w, __default_saturation, __default_interpret)) - << 24))} {} + << 24))} { + } /*! create fp8x4 e4m3 type from float4 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const float4 val) + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const float4 val) #else -__FP8_HOST__ __hip_fp8x4_e4m3(const float4 val) + __FP8_HOST__ __hip_fp8x4_e4m3(const float4 val) #endif : __x{reinterpret_cast<__hip_fp8x4_storage_t>( static_cast(reinterpret_cast(__hip_cvt_float_to_fp8( @@ -2330,48 +2505,51 @@ __FP8_HOST__ __hip_fp8x4_e4m3(const float4 val) << 16 | reinterpret_cast(__hip_cvt_float_to_fp8( val.w, __default_saturation, __default_interpret)) - << 24))} {} + << 24))} { + } /*! create fp8x4 e4m3 type from two __hip_bfloat162 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high) + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high) #else -__FP8_HOST__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high) + __FP8_HOST__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high) #endif : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( reinterpret_cast( __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) | reinterpret_cast( __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret)) - << 16))) {} + << 16))) { + } /*! create fp8x4 e4m3 type from two __half2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high) + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high) #else -__FP8_HOST__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high) + __FP8_HOST__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high) #endif : __x(reinterpret_cast<__hip_fp8x4_storage_t>( static_cast(reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( high, __default_saturation, __default_interpret)) | reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( low, __default_saturation, __default_interpret)) - << 16))) {} + << 16))) { + } /*! Default construct fp8x4 e4m3 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3() = default; + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3() = default; #else -__FP8_HOST__ __hip_fp8x4_e4m3() = default; + __FP8_HOST__ __hip_fp8x4_e4m3() = default; #endif /*! convert fp8x4 e4m3 to float4 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator float4() const { + __FP8_HOST_DEVICE__ operator float4() const { #else -__FP8_HOST__ operator float4() const { + __FP8_HOST__ operator float4() const { #endif - auto x = __x; // bypass const + auto x = __x; // bypass const auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1); #if HIP_FP8_CVT_FAST_PATH @@ -2409,126 +2587,137 @@ struct __hip_fp8_e5m2 { /*! create fp8 e5m2 type from long */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const long int val) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const long int val) #else -__FP8_HOST__ __hip_fp8_e5m2(const long int val) + __FP8_HOST__ __hip_fp8_e5m2(const long int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from int */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const int val) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const int val) #else -__FP8_HOST__ __hip_fp8_e5m2(const int val) + __FP8_HOST__ __hip_fp8_e5m2(const int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from short int */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const short int val) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const short int val) #else -__FP8_HOST__ __hip_fp8_e5m2(const short int val) + __FP8_HOST__ __hip_fp8_e5m2(const short int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from unsigned long */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned long int val) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned long int val) #else -__FP8_HOST__ __hip_fp8_e5m2(const unsigned long int val) + __FP8_HOST__ __hip_fp8_e5m2(const unsigned long int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from unsigned int */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned int val) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned int val) #else -__FP8_HOST__ __hip_fp8_e5m2(const unsigned int val) + __FP8_HOST__ __hip_fp8_e5m2(const unsigned int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from unsigned short */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned short int val) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned short int val) #else -__FP8_HOST__ __hip_fp8_e5m2(const unsigned short int val) + __FP8_HOST__ __hip_fp8_e5m2(const unsigned short int val) #endif : __x(__hip_cvt_float_to_fp8(static_cast(val), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from double */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const double f) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const double f) #else -__FP8_HOST__ __hip_fp8_e5m2(const double f) + __FP8_HOST__ __hip_fp8_e5m2(const double f) #endif - : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) { + } /*! create fp8 e5m2 type from float */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const float f) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const float f) #else -__FP8_HOST__ __hip_fp8_e5m2(const float f) + __FP8_HOST__ __hip_fp8_e5m2(const float f) #endif - : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) { + } /*! create fp8 e5m2 type from __hip_bfloat16 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __hip_bfloat16 f) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __hip_bfloat16 f) #else -__FP8_HOST__ __hip_fp8_e5m2(const __hip_bfloat16 f) + __FP8_HOST__ __hip_fp8_e5m2(const __hip_bfloat16 f) #endif : __x(__hip_cvt_float_to_fp8(static_cast(f), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! create fp8 e5m2 type from __hip_bfloat16 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __half f) + __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __half f) #else -__FP8_HOST__ __hip_fp8_e5m2(const __half f) + __FP8_HOST__ __hip_fp8_e5m2(const __half f) #endif : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation, - __default_interpret)) {} + __default_interpret)) { + } /*! default construct fp8 e5m2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8_e5m2() = default; + __FP8_HOST_DEVICE__ __hip_fp8_e5m2() = default; #else -__FP8_HOST__ __hip_fp8_e5m2() = default; + __FP8_HOST__ __hip_fp8_e5m2() = default; #endif /*! convert fp8 e5m2 to float */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator float() const { + __FP8_HOST_DEVICE__ operator float() const { #else -__FP8_HOST__ operator float() const { + __FP8_HOST__ operator float() const { #endif #if HIP_FP8_CVT_FAST_PATH return internal::cast_to_f32_from_f8(__x, __default_interpret); #else - return internal::cast_from_f8(__x, __wm, __we, __default_saturation == __HIP_SATFINITE); + return internal::cast_from_f8(__x, __wm, __we, + __default_saturation == __HIP_SATFINITE); #endif } /*! convert fp8 e5m2 to __half */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator __half() const { + __FP8_HOST_DEVICE__ operator __half() const { #else -__FP8_HOST__ operator __half() const { + __FP8_HOST__ operator __half() const { #endif return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret)); } /*! convert fp8 e5m2 to __hip_bfloat16 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator __hip_bfloat16() const { + __FP8_HOST_DEVICE__ operator __hip_bfloat16() const { #else -__FP8_HOST__ operator __hip_bfloat16() const { + __FP8_HOST__ operator __hip_bfloat16() const { #endif float f = *this; return __hip_bfloat16(f); @@ -2536,9 +2725,9 @@ __FP8_HOST__ operator __hip_bfloat16() const { /*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator bool() const { + __FP8_HOST_DEVICE__ operator bool() const { #else -__FP8_HOST__ operator bool() const { + __FP8_HOST__ operator bool() const { #endif // it can be 0x00 (+0.0) since 0x80 will be nan return !(static_cast(__x) == 0 || static_cast(__x) == 0x80); @@ -2546,9 +2735,9 @@ __FP8_HOST__ operator bool() const { /*! convert fp8 e5m2 to char, clamp out of bound values, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator char() const { + __FP8_HOST_DEVICE__ operator char() const { #else -__FP8_HOST__ operator char() const { + __FP8_HOST__ operator char() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2566,18 +2755,19 @@ __FP8_HOST__ operator char() const { /*! convert fp8 e5m2 to double */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator double() const { + __FP8_HOST_DEVICE__ operator double() const { #else -__FP8_HOST__ operator double() const { + __FP8_HOST__ operator double() const { #endif - return internal::cast_from_f8(__x, __wm, __we, __default_saturation == __HIP_SATFINITE); + return internal::cast_from_f8(__x, __wm, __we, + __default_saturation == __HIP_SATFINITE); } /*! convert fp8 e5m2 to int, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator int() const { + __FP8_HOST_DEVICE__ operator int() const { #else -__FP8_HOST__ operator int() const { + __FP8_HOST__ operator int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2589,9 +2779,9 @@ __FP8_HOST__ operator int() const { /*! convert fp8 e5m2 to long, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator long int() const { + __FP8_HOST_DEVICE__ operator long int() const { #else -__FP8_HOST__ operator long int() const { + __FP8_HOST__ operator long int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2603,9 +2793,9 @@ __FP8_HOST__ operator long int() const { /*! convert fp8 e5m2 to long long, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator long long int() const { + __FP8_HOST_DEVICE__ operator long long int() const { #else -__FP8_HOST__ operator long long int() const { + __FP8_HOST__ operator long long int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2617,9 +2807,9 @@ __FP8_HOST__ operator long long int() const { /*! convert fp8 e5m2 to short, clamp out of bound values, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator short int() const { + __FP8_HOST_DEVICE__ operator short int() const { #else -__FP8_HOST__ operator short int() const { + __FP8_HOST__ operator short int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2637,9 +2827,9 @@ __FP8_HOST__ operator short int() const { /*! convert fp8 e5m2 to signed char, clamp out of bound values, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator signed char() const { + __FP8_HOST_DEVICE__ operator signed char() const { #else -__FP8_HOST__ operator signed char() const { + __FP8_HOST__ operator signed char() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2657,9 +2847,9 @@ __FP8_HOST__ operator signed char() const { /*! convert fp8 e5m2 to unsigned char, clamp out of bound values, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned char() const { + __FP8_HOST_DEVICE__ operator unsigned char() const { #else -__FP8_HOST__ operator unsigned char() const { + __FP8_HOST__ operator unsigned char() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2677,9 +2867,9 @@ __FP8_HOST__ operator unsigned char() const { /*! convert fp8 e5m2 to unsigned int, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned int() const { + __FP8_HOST_DEVICE__ operator unsigned int() const { #else -__FP8_HOST__ operator unsigned int() const { + __FP8_HOST__ operator unsigned int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2695,9 +2885,9 @@ __FP8_HOST__ operator unsigned int() const { /*! convert fp8 e5m2 to unsigned long, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned long int() const { + __FP8_HOST_DEVICE__ operator unsigned long int() const { #else -__FP8_HOST__ operator unsigned long int() const { + __FP8_HOST__ operator unsigned long int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2713,9 +2903,9 @@ __FP8_HOST__ operator unsigned long int() const { /*! convert fp8 e5m2 to unsigned long long, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned long long int() const { + __FP8_HOST_DEVICE__ operator unsigned long long int() const { #else -__FP8_HOST__ operator unsigned long long int() const { + __FP8_HOST__ operator unsigned long long int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; @@ -2731,16 +2921,16 @@ __FP8_HOST__ operator unsigned long long int() const { /*! convert fp8 e5m2 to unsigned short, return 0 if value is NaN */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator unsigned short int() const { + __FP8_HOST_DEVICE__ operator unsigned short int() const { #else -__FP8_HOST__ operator unsigned short int() const { + __FP8_HOST__ operator unsigned short int() const { #endif if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) { return 0; - } +} float fval = *this; - auto llval = static_cast(fval); + auto llval = static_cast(fval); if (llval <= 0) { return 0; } @@ -2762,63 +2952,70 @@ struct __hip_fp8x2_e5m2 { /*! create fp8x2 e5m2 type from double2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const double2 val) + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const double2 val) #else -__FP8_HOST__ __hip_fp8x2_e5m2(const double2 val) + __FP8_HOST__ __hip_fp8x2_e5m2(const double2 val) #endif - : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e5m2 type from float2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const float2 val) + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const float2 val) #else -__FP8_HOST__ __hip_fp8x2_e5m2(const float2 val) + __FP8_HOST__ __hip_fp8x2_e5m2(const float2 val) #endif - : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e5m2 type from __hip_bfloat162 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __hip_bfloat162 val) + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __hip_bfloat162 val) #else -__FP8_HOST__ __hip_fp8x2_e5m2(const __hip_bfloat162 val) + __FP8_HOST__ __hip_fp8x2_e5m2(const __hip_bfloat162 val) #endif - : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! create fp8x2 e5m2 type from __half2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __half2 val) + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __half2 val) #else -__FP8_HOST__ __hip_fp8x2_e5m2(const __half2 val) + __FP8_HOST__ __hip_fp8x2_e5m2(const __half2 val) #endif - : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {} + : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) { + } /*! default construct fp8x2 e5m2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2() = default; + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2() = default; #else -__FP8_HOST__ __hip_fp8x2_e5m2() = default; + __FP8_HOST__ __hip_fp8x2_e5m2() = default; #endif /*! convert fp8x2 e5m2 to __half2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator __half2() const { + __FP8_HOST_DEVICE__ operator __half2() const { #else -__FP8_HOST__ operator __half2() const { + __FP8_HOST__ operator __half2() const { #endif return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret)); } /*! convert fp8x2 e5m2 to float2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator float2() const { + __FP8_HOST_DEVICE__ operator float2() const { #else -__FP8_HOST__ operator float2() const { + __FP8_HOST__ operator float2() const { #endif #if HIP_FP8_CVT_FAST_PATH return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret); #else - return float2(internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x & 0xFF), __wm, __we, __default_saturation == __HIP_SATFINITE), - internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE)); + return float2( + internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x & 0xFF), __wm, + __we, __default_saturation == __HIP_SATFINITE), + internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(__x >> 8), __wm, __we, + __default_saturation == __HIP_SATFINITE)); #endif } }; @@ -2836,9 +3033,9 @@ struct __hip_fp8x4_e5m2 { /*! create fp8x4 e5m2 type from double4 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const double4 val) + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const double4 val) #else -__FP8_HOST__ __hip_fp8x4_e5m2(const double4 val) + __FP8_HOST__ __hip_fp8x4_e5m2(const double4 val) #endif : __x(reinterpret_cast<__hip_fp8x4_storage_t>( static_cast(reinterpret_cast(__hip_cvt_double_to_fp8( @@ -2851,13 +3048,14 @@ __FP8_HOST__ __hip_fp8x4_e5m2(const double4 val) << 16 | reinterpret_cast(__hip_cvt_double_to_fp8( val.w, __default_saturation, __default_interpret)) - << 24))) {} + << 24))) { + } /*! create fp8x4 e5m2 type from float4 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const float4 val) + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const float4 val) #else -__FP8_HOST__ __hip_fp8x4_e5m2(const float4 val) + __FP8_HOST__ __hip_fp8x4_e5m2(const float4 val) #endif : __x(reinterpret_cast<__hip_fp8x4_storage_t>( static_cast(reinterpret_cast(__hip_cvt_float_to_fp8( @@ -2870,62 +3068,69 @@ __FP8_HOST__ __hip_fp8x4_e5m2(const float4 val) << 16 | reinterpret_cast(__hip_cvt_float_to_fp8( val.w, __default_saturation, __default_interpret)) - << 24))) {} + << 24))) { + } /*! create fp8x4 e5m2 type from two __hip_bfloat162 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high) + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high) #else -__FP8_HOST__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high) + __FP8_HOST__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high) #endif : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( reinterpret_cast( __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) | reinterpret_cast( __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret)) - << 16))) {} + << 16))) { + } /*! create fp8x4 e5m2 type from two __half2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high) + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high) #else -__FP8_HOST__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high) + __FP8_HOST__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high) #endif : __x(reinterpret_cast<__hip_fp8x4_storage_t>( static_cast(reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( high, __default_saturation, __default_interpret)) | reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( low, __default_saturation, __default_interpret)) - << 16))) {} + << 16))) { + } /* default construct fp8x4 e5m2 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2() = default; + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2() = default; #else -__FP8_HOST__ __hip_fp8x4_e5m2() = default; + __FP8_HOST__ __hip_fp8x4_e5m2() = default; #endif /*! convert fp8x4 e5m2 to float4 */ #if HIP_FP8_TYPE_OCP -__FP8_HOST_DEVICE__ operator float4() const { + __FP8_HOST_DEVICE__ operator float4() const { #else -__FP8_HOST__ operator float4() const { + __FP8_HOST__ operator float4() const { #endif - auto x = __x; // bypass const + auto x = __x; // bypass const auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1); #if HIP_FP8_CVT_FAST_PATH float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret); float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret); #else - float2 high = float2(internal::cast_from_f8( - static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE), - internal::cast_from_f8( - static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE)); - float2 low = float2(internal::cast_from_f8( - static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE), - internal::cast_from_f8( - static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE)); + float2 high = float2( + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we, + __default_saturation == __HIP_SATFINITE), + internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), + __wm, __we, __default_saturation == __HIP_SATFINITE)); + float2 low = float2( + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we, + __default_saturation == __HIP_SATFINITE), + internal::cast_from_f8(static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, + __we, __default_saturation == __HIP_SATFINITE)); #endif return float4(low.x, low.y, high.x, high.y); }