Upcast FP8 to Half (FP16) for Sum Operation (#1775)
* adding hadd and hadd2 support using builtin functions. --------- Co-authored-by: Marzieh Berenjkoub <mberenjk@amd.com>
Этот коммит содержится в:
Обычный файл → Исполняемый файл
+4
-2
@@ -304,11 +304,13 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h
|
||||
SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e5m2, 1, __nv_fp8_e5m2, __nv_fp8_e5m2(fn.isMinNotMax ? __hmin(__half(x), __half(y)) : __hmax(__half(x), __half(y))))
|
||||
SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e5m2, 2, __nv_fp8x2_e5m2, __nv_fp8x2_e5m2(fn.isMinNotMax ? __hmin2(__half2(x), __half2(y)) : __hmax2(__half2(x), __half2(y))))
|
||||
#else
|
||||
SPECIALIZE_REDUCE(FuncSum, rccl_float8, 1, rccl_float8, rccl_float8(float(x) + float(y)))
|
||||
SPECIALIZE_REDUCE(FuncSum, rccl_float8, 1, rccl_float8, hadd(x,y))
|
||||
SPECIALIZE_REDUCE(FuncSum, rccl_float8, 2, fp8x2_storage_t, hadd2(x,y))
|
||||
SPECIALIZE_REDUCE(FuncProd, rccl_float8, 1, rccl_float8, rccl_float8(float(x) * float(y)))
|
||||
SPECIALIZE_REDUCE(FuncMinMax, rccl_float8, 1, rccl_float8, rccl_float8(fn.isMinNotMax ? fminf(float(x), float(y)) : fmaxf(float(x), float(y))))
|
||||
|
||||
SPECIALIZE_REDUCE(FuncSum, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(float(x) + float(y)))
|
||||
SPECIALIZE_REDUCE(FuncSum, rccl_bfloat8, 1, rccl_bfloat8, hadd_b(x,y))
|
||||
SPECIALIZE_REDUCE(FuncSum, rccl_bfloat8, 2, fp8x2_storage_t, hadd2_b(x,y))
|
||||
SPECIALIZE_REDUCE(FuncProd, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(float(x) * float(y)))
|
||||
SPECIALIZE_REDUCE(FuncMinMax, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(fn.isMinNotMax ? fminf(float(x), float(y)) : fmaxf(float(x), float(y))))
|
||||
#endif
|
||||
|
||||
Обычный файл → Исполняемый файл
+183
@@ -26,6 +26,7 @@
|
||||
#include <stdint.h>
|
||||
#include <hip/hip_version.h>
|
||||
|
||||
typedef uint16_t fp8x2_storage_t;
|
||||
#if __cplusplus < 201103L || (!defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__))
|
||||
/*! \brief Struct to represent a 8 bit floating-point number. */
|
||||
|
||||
@@ -52,6 +53,151 @@ typedef __hip_fp8_e4m3 rccl_float8;
|
||||
typedef __hip_fp8_e5m2 rccl_bfloat8;
|
||||
#endif
|
||||
|
||||
typedef _Float16 half_t;
|
||||
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
typedef short shortx2_t __attribute__((ext_vector_type(2)));
|
||||
typedef short __attribute__((ext_vector_type(2))) __amd_shortx2_storage_t;
|
||||
typedef float float2_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
|
||||
inline __device__ rccl_float8 hadd(rccl_float8 x, rccl_float8 y)
|
||||
{
|
||||
#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
|
||||
half2_t v1;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(x.__x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(y.__x, 1.f, 0)));
|
||||
union {
|
||||
shortx2_t i16_vec;
|
||||
rccl_float8 fp8[4];
|
||||
} u{0};
|
||||
u.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(v1, v1, /* scale */ 1.f, 0);
|
||||
return u.fp8[0];
|
||||
#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
uint32_t i32val;
|
||||
rccl_float8 i8val[4];
|
||||
} val;
|
||||
|
||||
float2_t v;
|
||||
uint32_t ival = 0;
|
||||
asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(x.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(y.__x, 0)));
|
||||
return __builtin_amdgcn_cvt_pk_fp8_f32(v[0], v[0], ival, false);
|
||||
#else
|
||||
return rccl_float8(float(x) + float(y));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ rccl_bfloat8 hadd_b(rccl_bfloat8 x, rccl_bfloat8 y)
|
||||
{
|
||||
#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
|
||||
half2_t v1;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(x.__x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(y.__x, 1.f, 0)));
|
||||
union {
|
||||
shortx2_t i16_vec;
|
||||
rccl_bfloat8 fp8[4];
|
||||
} u1{0};
|
||||
u1.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(v1, v1, /* scale */ 1.f, 0);
|
||||
return u1.fp8[0];
|
||||
#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
|
||||
|
||||
float2_t v;
|
||||
uint32_t ival = 0;
|
||||
asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(x.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(y.__x, 0)));
|
||||
return __builtin_amdgcn_cvt_pk_bf8_f32(v[0], v[0], ival, false);
|
||||
#else
|
||||
return rccl_bfloat8(float(x) + float(y));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ fp8x2_storage_t hadd2(fp8x2_storage_t x, fp8x2_storage_t y)
|
||||
{
|
||||
#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
|
||||
half2_t v1;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(y, 1.f, 0)));
|
||||
union {
|
||||
shortx2_t i16_vec;
|
||||
fp8x2_storage_t fp8;
|
||||
} u{0};
|
||||
u.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(v1, v1, /* scale */ 1.f, 0);
|
||||
return u.fp8;
|
||||
#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
|
||||
float2_t v;
|
||||
uint32_t ival = 0;
|
||||
asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(y, 0)));
|
||||
return __builtin_amdgcn_cvt_pk_fp8_f32(v[0], v[1], ival, false);
|
||||
#else
|
||||
union {
|
||||
rccl_float8 fp8[2];
|
||||
fp8x2_storage_t fp8x2;
|
||||
} u, v, w;
|
||||
u.fp8x2 = x;
|
||||
v.fp8x2 = y;
|
||||
w.fp8[0] = hadd(u.fp8[0], v.fp8[0]);
|
||||
w.fp8[1] = hadd(u.fp8[1], v.fp8[1]);
|
||||
return w.fp8x2;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ fp8x2_storage_t hadd2_b(fp8x2_storage_t x, fp8x2_storage_t y)
|
||||
{
|
||||
#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
|
||||
half2_t v1;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(y, 1.f, 0)));
|
||||
union {
|
||||
shortx2_t i16_vec;
|
||||
fp8x2_storage_t fp8;
|
||||
} u{0};
|
||||
u.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(v1, v1, /* scale */ 1.f, 0);
|
||||
return u.fp8;
|
||||
#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
|
||||
float2_t v;
|
||||
uint32_t ival = 0;
|
||||
asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(y, 0)));
|
||||
return __builtin_amdgcn_cvt_pk_bf8_f32(v[0], v[1], ival, false);
|
||||
#else
|
||||
union {
|
||||
rccl_bfloat8 bfp8[2];
|
||||
fp8x2_storage_t bfp8x2;
|
||||
} u, v, w;
|
||||
u.bfp8x2 = x;
|
||||
v.bfp8x2 = y;
|
||||
w.bfp8[0] = hadd_b(u.bfp8[0], v.bfp8[0]);
|
||||
w.bfp8[1] = hadd_b(u.bfp8[1], v.bfp8[1]);
|
||||
return w.bfp8x2;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8)
|
||||
{
|
||||
return os << float(f8);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat8& bf8)
|
||||
{
|
||||
return os << float(bf8);
|
||||
}
|
||||
|
||||
inline __host__ __device__ float operator*(rccl_float8 a, rccl_float8 b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline __host__ __device__ float operator*(rccl_bfloat8 a, rccl_bfloat8 b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline __host__ __device__ float operator*(rccl_float8 a, float b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline __host__ __device__ float operator*(rccl_bfloat8 a, float b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
// For older versions of ROCm that do not include hip_fp8.h,
|
||||
// we provide a local version of the header file as a fallback.
|
||||
#else
|
||||
@@ -714,6 +860,43 @@ namespace std
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ rccl_float8 hadd(rccl_float8 x, rccl_float8 y)
|
||||
{
|
||||
return rccl_float8(float(x) + float(y));
|
||||
}
|
||||
|
||||
inline __device__ fp8x2_storage_t hadd2(fp8x2_storage_t x, fp8x2_storage_t y)
|
||||
{
|
||||
union {
|
||||
rccl_float8 fp8[2];
|
||||
fp8x2_storage_t fp8x2;
|
||||
} u, v, w;
|
||||
u.fp8x2 = x;
|
||||
v.fp8x2 = y;
|
||||
w.fp8[0] = hadd(u.fp8[0], v.fp8[0]);
|
||||
w.fp8[1] = hadd(u.fp8[1], v.fp8[1]);
|
||||
|
||||
return w.fp8x2;
|
||||
}
|
||||
|
||||
inline __device__ rccl_bfloat8 hadd_b(rccl_bfloat8 x, rccl_bfloat8 y)
|
||||
{
|
||||
return rccl_bfloat8(float(x) + float(y));
|
||||
}
|
||||
|
||||
inline __device__ fp8x2_storage_t hadd2_b(fp8x2_storage_t x, fp8x2_storage_t y) {
|
||||
union {
|
||||
rccl_bfloat8 fp8[2];
|
||||
fp8x2_storage_t fp8x2;
|
||||
} u, v, w;
|
||||
u.fp8x2 = x;
|
||||
v.fp8x2 = y;
|
||||
w.fp8[0] = hadd_b(u.fp8[0], v.fp8[0]);
|
||||
w.fp8[1] = hadd_b(u.fp8[1], v.fp8[1]);
|
||||
|
||||
return w.fp8x2;
|
||||
}
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8)
|
||||
{
|
||||
|
||||
Ссылка в новой задаче
Block a user