SWDEV-468039,SWDEV-482579 - Enable FP8 SW Conversions on pre gfx940 archs

1) SW Conversions for ocp and fnuz are enabled on pre mi300 archs
2) for mi300 only fnuz is enabled
3) for gfx1200 only ocp is enabled

Change-Id: I90373752a2d15eff20d5deec874ed396ba4e1788
This commit is contained in:
Rahul Manocha
2024-09-18 12:05:16 -07:00
zatwierdzone przez Rahul Manocha
rodzic 8657a77029
commit e729f08704
@@ -44,14 +44,25 @@
#elif (defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
#define HIP_FP8_TYPE_OCP 1
#define HIP_FP8_TYPE_FNUZ 0
#elif __HIP_DEVICE_COMPILE__
#define HIP_FP8_TYPE_FNUZ 0
#define HIP_FP8_TYPE_OCP 0
#else // Host
#else
#define HIP_FP8_TYPE_FNUZ 1
#define HIP_FP8_TYPE_OCP 1
#endif
#if defined(__HIPCC_RTC__)
#if HIP_FP8_TYPE_FNUZ
#define ENABLE_FNUZ_HIPRTC 1
#else
#define ENABLE_FNUZ_HIPRTC 0
#endif
#if HIP_FP8_TYPE_OCP
#define ENABLE_OCP_HIPRTC 1
#else
#define ENABLE_OCP_HIPRTC 0
#endif
#endif
#if !defined(__HIPCC_RTC__)
#include <hip/amd_detail/amd_hip_common.h>
#include <climits>
@@ -474,7 +485,7 @@ __FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we,
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if (exponent <= 0) {
mantissa |= 1 << wmo;
mantissa |= 1ull << wmo;
mantissa >>= 1 - exponent;
exponent = 0;
}
@@ -540,8 +551,8 @@ static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate
static __device__ __hip_fp8x2_storage_t
cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) {
union {
static_assert(sizeof(float2) == sizeof(unsigned int[2]));
static_assert(sizeof(float2) == sizeof(unsigned short[4]));
static_assert(sizeof(float2) == sizeof(unsigned int[2]), "size mismatch");
static_assert(sizeof(float2) == sizeof(unsigned short[4]), "size mismatch");
float2 fval;
unsigned int i32val[2];
unsigned short i16val[4];
@@ -656,8 +667,7 @@ __FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(
int we = interp == __HIP_E4M3_FNUZ ? 4 : 5;
int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2;
return internal::cast_to_f8<float, true>(f, wm, we, sat == __HIP_SATFINITE);
}
if (interp == __HIP_E4M3 || interp == __HIP_E5M2) {
} else {
int we = interp == __HIP_E4M3 ? 4 : 5;
int wm = interp == __HIP_E4M3 ? 3 : 2;
return internal::cast_to_f8<float, false>(f, wm, we, sat == __HIP_SATFINITE);
@@ -716,8 +726,7 @@ __FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(
int we = interp == __HIP_E4M3_FNUZ ? 4 : 5;
int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2;
return internal::cast_to_f8<double, true>(d, wm, we, sat == __HIP_SATFINITE);
}
if (interp == __HIP_E4M3 || interp == __HIP_E5M2) {
} else {
int we = interp == __HIP_E4M3 ? 4 : 5;
int wm = interp == __HIP_E4M3 ? 3 : 2;
return internal::cast_to_f8<double, false>(d, wm, we, sat == __HIP_SATFINITE);
@@ -822,8 +831,7 @@ __FP8_HOST_STATIC__ __half_raw __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_
unsigned int we = interp == __HIP_E4M3_FNUZ ? 4 : 5;
unsigned int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2;
return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)};
}
if (interp == __HIP_E4M3 || interp == __HIP_E5M2) {
} else {
unsigned int we = interp == __HIP_E4M3 ? 4 : 5;
unsigned int wm = interp == __HIP_E4M3 ? 3 : 2;
return __half_raw{internal::cast_from_f8<_Float16, false>(x, wm, we)};
@@ -903,6 +911,8 @@ __FP8_HOST_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(
* \brief struct representing single fp8 number with e4m3 interpretation
*
*/
#if !defined(ENABLE_FNUZ_HIPRTC) || ENABLE_FNUZ_HIPRTC
struct __hip_fp8_e4m3_fnuz {
__hip_fp8_storage_t __x; //! raw storage of fp8 number
constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
@@ -2014,10 +2024,15 @@ struct __hip_fp8x4_e5m2_fnuz {
}
};
#endif // ENABLE_FNUZ_HIPRTC
/**
* \brief struct representing ocp fp8 numbers with e4m3 interpretation
*
* */
#if !defined(ENABLE_OCP_HIPRTC) || ENABLE_OCP_HIPRTC
struct __hip_fp8_e4m3 {
__hip_fp8_storage_t __x; //! raw storage of fp8 number
constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
@@ -3135,4 +3150,5 @@ struct __hip_fp8x4_e5m2 {
return float4(low.x, low.y, high.x, high.y);
}
};
#endif // ENABLE_OCP_HIPRTC
#endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_