SWDEV-468039,SWDEV-482579 - Enable FP8 SW Conversions on pre gfx940 archs
1) SW conversions for OCP and FNUZ are enabled on pre-MI300 archs; 2) on MI300, only FNUZ is enabled; 3) on gfx1200, only OCP is enabled. Change-Id: I90373752a2d15eff20d5deec874ed396ba4e1788
This commit is contained in:
committed by
Rahul Manocha
parent
8657a77029
commit
e729f08704
@@ -44,14 +44,25 @@
|
||||
#elif (defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
|
||||
#define HIP_FP8_TYPE_OCP 1
|
||||
#define HIP_FP8_TYPE_FNUZ 0
|
||||
#elif __HIP_DEVICE_COMPILE__
|
||||
#define HIP_FP8_TYPE_FNUZ 0
|
||||
#define HIP_FP8_TYPE_OCP 0
|
||||
#else // Host
|
||||
#else
|
||||
#define HIP_FP8_TYPE_FNUZ 1
|
||||
#define HIP_FP8_TYPE_OCP 1
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC_RTC__)
|
||||
#if HIP_FP8_TYPE_FNUZ
|
||||
#define ENABLE_FNUZ_HIPRTC 1
|
||||
#else
|
||||
#define ENABLE_FNUZ_HIPRTC 0
|
||||
#endif
|
||||
#if HIP_FP8_TYPE_OCP
|
||||
#define ENABLE_OCP_HIPRTC 1
|
||||
#else
|
||||
#define ENABLE_OCP_HIPRTC 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
#if !defined(__HIPCC_RTC__)
|
||||
#include <hip/amd_detail/amd_hip_common.h>
|
||||
#include <climits>
|
||||
@@ -474,7 +485,7 @@ __FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we,
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if (exponent <= 0) {
|
||||
mantissa |= 1 << wmo;
|
||||
mantissa |= 1ull << wmo;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
@@ -540,8 +551,8 @@ static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate
|
||||
static __device__ __hip_fp8x2_storage_t
|
||||
cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) {
|
||||
union {
|
||||
static_assert(sizeof(float2) == sizeof(unsigned int[2]));
|
||||
static_assert(sizeof(float2) == sizeof(unsigned short[4]));
|
||||
static_assert(sizeof(float2) == sizeof(unsigned int[2]), "size mismatch");
|
||||
static_assert(sizeof(float2) == sizeof(unsigned short[4]), "size mismatch");
|
||||
float2 fval;
|
||||
unsigned int i32val[2];
|
||||
unsigned short i16val[4];
|
||||
@@ -656,8 +667,7 @@ __FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(
|
||||
int we = interp == __HIP_E4M3_FNUZ ? 4 : 5;
|
||||
int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2;
|
||||
return internal::cast_to_f8<float, true>(f, wm, we, sat == __HIP_SATFINITE);
|
||||
}
|
||||
if (interp == __HIP_E4M3 || interp == __HIP_E5M2) {
|
||||
} else {
|
||||
int we = interp == __HIP_E4M3 ? 4 : 5;
|
||||
int wm = interp == __HIP_E4M3 ? 3 : 2;
|
||||
return internal::cast_to_f8<float, false>(f, wm, we, sat == __HIP_SATFINITE);
|
||||
@@ -716,8 +726,7 @@ __FP8_HOST_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(
|
||||
int we = interp == __HIP_E4M3_FNUZ ? 4 : 5;
|
||||
int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2;
|
||||
return internal::cast_to_f8<double, true>(d, wm, we, sat == __HIP_SATFINITE);
|
||||
}
|
||||
if (interp == __HIP_E4M3 || interp == __HIP_E5M2) {
|
||||
} else {
|
||||
int we = interp == __HIP_E4M3 ? 4 : 5;
|
||||
int wm = interp == __HIP_E4M3 ? 3 : 2;
|
||||
return internal::cast_to_f8<double, false>(d, wm, we, sat == __HIP_SATFINITE);
|
||||
@@ -822,8 +831,7 @@ __FP8_HOST_STATIC__ __half_raw __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_
|
||||
unsigned int we = interp == __HIP_E4M3_FNUZ ? 4 : 5;
|
||||
unsigned int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2;
|
||||
return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)};
|
||||
}
|
||||
if (interp == __HIP_E4M3 || interp == __HIP_E5M2) {
|
||||
} else {
|
||||
unsigned int we = interp == __HIP_E4M3 ? 4 : 5;
|
||||
unsigned int wm = interp == __HIP_E4M3 ? 3 : 2;
|
||||
return __half_raw{internal::cast_from_f8<_Float16, false>(x, wm, we)};
|
||||
@@ -903,6 +911,8 @@ __FP8_HOST_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(
|
||||
* \brief struct representing single fp8 number with e4m3 interpretation
|
||||
*
|
||||
*/
|
||||
|
||||
#if !defined(ENABLE_FNUZ_HIPRTC) || ENABLE_FNUZ_HIPRTC
|
||||
struct __hip_fp8_e4m3_fnuz {
|
||||
__hip_fp8_storage_t __x; //! raw storage of fp8 number
|
||||
constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
|
||||
@@ -2014,10 +2024,15 @@ struct __hip_fp8x4_e5m2_fnuz {
|
||||
}
|
||||
};
|
||||
|
||||
#endif // ENABLE_FNUZ_HIPRTC
|
||||
|
||||
/**
|
||||
* \brief struct representing ocp fp8 numbers with e4m3 interpretation
|
||||
*
|
||||
* */
|
||||
|
||||
#if !defined(ENABLE_OCP_HIPRTC) || ENABLE_OCP_HIPRTC
|
||||
|
||||
struct __hip_fp8_e4m3 {
|
||||
__hip_fp8_storage_t __x; //! raw storage of fp8 number
|
||||
constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
|
||||
@@ -3135,4 +3150,5 @@ struct __hip_fp8x4_e5m2 {
|
||||
return float4(low.x, low.y, high.x, high.y);
|
||||
}
|
||||
};
|
||||
#endif // ENABLE_OCP_HIPRTC
|
||||
#endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
|
||||
Reference in New Issue
Block a user