Upcast FP8 to Half (FP16) for Sum Operation (#1775)

* Add hadd/hadd2 (FP8) and hadd_b/hadd2_b (BF8) sum helpers implemented with AMDGPU builtin functions.

---------

Co-authored-by: Marzieh Berenjkoub <mberenjk@amd.com>
Этот коммит содержится в:
mberenjk
2025-07-29 11:33:06 -05:00
коммит произвёл GitHub
родитель 9843adaab2
Коммит c84ee3d298
2 изменённых файлов: 187 добавлений и 2 удалений
Обычный файл → Исполняемый файл
+4 -2
Просмотреть файл
@@ -304,11 +304,13 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h
// Min/max of a single __nv_fp8_e5m2 value: both operands are converted to __half, reduced
// with __hmin or __hmax (chosen by fn.isMinNotMax), and the result is converted back.
SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e5m2, 1, __nv_fp8_e5m2, __nv_fp8_e5m2(fn.isMinNotMax ? __hmin(__half(x), __half(y)) : __hmax(__half(x), __half(y))))
// Two-element pack: same selection, performed on __half2 with __hmin2 / __hmax2.
SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e5m2, 2, __nv_fp8x2_e5m2, __nv_fp8x2_e5m2(fn.isMinNotMax ? __hmin2(__half2(x), __half2(y)) : __hmax2(__half2(x), __half2(y))))
#else
// rccl_float8 reductions.
// Sum goes through hadd() (one element) and hadd2() (two packed elements) instead of a
// float round-trip. The older float-based FuncSum specialization for one element is
// intentionally not kept: it has the same (function, type, elements-per-pack) key as the
// hadd()-based one, and two specializations for the same key would collide.
// NOTE(review): assumes SPECIALIZE_REDUCE emits one specialization per key - confirm
// against the macro definition.
SPECIALIZE_REDUCE(FuncSum, rccl_float8, 1, rccl_float8, hadd(x,y))
SPECIALIZE_REDUCE(FuncSum, rccl_float8, 2, fp8x2_storage_t, hadd2(x,y))
// Product and min/max are still computed through float.
SPECIALIZE_REDUCE(FuncProd, rccl_float8, 1, rccl_float8, rccl_float8(float(x) * float(y)))
SPECIALIZE_REDUCE(FuncMinMax, rccl_float8, 1, rccl_float8, rccl_float8(fn.isMinNotMax ? fminf(float(x), float(y)) : fmaxf(float(x), float(y))))
// rccl_bfloat8 reductions: same layout, with hadd_b() / hadd2_b() for the sum.
SPECIALIZE_REDUCE(FuncSum, rccl_bfloat8, 1, rccl_bfloat8, hadd_b(x,y))
SPECIALIZE_REDUCE(FuncSum, rccl_bfloat8, 2, fp8x2_storage_t, hadd2_b(x,y))
SPECIALIZE_REDUCE(FuncProd, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(float(x) * float(y)))
SPECIALIZE_REDUCE(FuncMinMax, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(fn.isMinNotMax ? fminf(float(x), float(y)) : fmaxf(float(x), float(y))))
#endif
Обычный файл → Исполняемый файл
+183
Просмотреть файл
@@ -26,6 +26,7 @@
#include <stdint.h>
#include <hip/hip_version.h>
typedef uint16_t fp8x2_storage_t;
#if __cplusplus < 201103L || (!defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__))
/*! \brief Struct to represent a 8 bit floating-point number. */
@@ -52,6 +53,151 @@ typedef __hip_fp8_e4m3 rccl_float8;
typedef __hip_fp8_e5m2 rccl_bfloat8;
#endif
// Scalar and two-lane vector aliases. half2_t, shortx2_t and float2_t are the operand and
// result types of the packed-add helpers (hadd, hadd_b, hadd2, hadd2_b) in this header.
// Scalar 16-bit float.
typedef _Float16 half_t;
// Two _Float16 lanes.
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
// Two short lanes.
typedef short shortx2_t __attribute__((ext_vector_type(2)));
// Second two-short-lane alias; not referenced in the visible part of this header.
typedef short __attribute__((ext_vector_type(2))) __amd_shortx2_storage_t;
// Two float lanes.
typedef float float2_t __attribute__((ext_vector_type(2)));
/*! \brief Adds two rccl_float8 values and returns the sum as rccl_float8.
 *
 *  - gfx950 device code: both operands are widened to packed f16 (scale 1.0), added with
 *    v_pk_add_f16, and narrowed back to FP8 with the matching builtin.
 *  - gfx942 device code: both operands are widened to packed f32, added with v_pk_add_f32,
 *    and narrowed back to FP8 with the matching builtin.
 *  - anything else: the sum is computed through float.
 *
 *  \param x first addend
 *  \param y second addend
 *  \return  the sum of x and y converted back to rccl_float8
 */
inline __device__ rccl_float8 hadd(rccl_float8 x, rccl_float8 y)
{
#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
  half2_t v1;
  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(x.__x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(y.__x, 1.f, 0)));
  // The narrowing builtin yields a packed vector; the union exposes its first byte as FP8.
  union {
    shortx2_t i16_vec;
    rccl_float8 fp8[4];
  } u{0};
  u.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(v1, v1, /* scale */ 1.f, 0);
  return u.fp8[0];
#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
  float2_t v;
  uint32_t ival = 0;
  asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(x.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(y.__x, 0)));
  // The builtin returns a packed integer whose low byte holds the converted v[0].
  // Store that byte as raw FP8 storage rather than returning the integer directly.
  // NOTE(review): a direct return relies on an implicit integer-to-rccl_float8 conversion;
  // if that conversion is numeric, the packed bit pattern would be converted as a value.
  const uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v[0], v[0], ival, false);
  rccl_float8 result = x;  // copy only to obtain an object; __x is overwritten below
  result.__x = static_cast<decltype(result.__x)>(packed & 0xFFu);
  return result;
#else
  return rccl_float8(float(x) + float(y));
#endif
}
/*! \brief Adds two rccl_bfloat8 values and returns the sum as rccl_bfloat8.
 *
 *  - gfx950 device code: both operands are widened to packed f16 (scale 1.0), added with
 *    v_pk_add_f16, and narrowed back to BF8 with the matching builtin.
 *  - gfx942 device code: both operands are widened to packed f32, added with v_pk_add_f32,
 *    and narrowed back to BF8 with the matching builtin.
 *  - anything else: the sum is computed through float.
 *
 *  \param x first addend
 *  \param y second addend
 *  \return  the sum of x and y converted back to rccl_bfloat8
 */
inline __device__ rccl_bfloat8 hadd_b(rccl_bfloat8 x, rccl_bfloat8 y)
{
#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
  half2_t v1;
  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(x.__x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(y.__x, 1.f, 0)));
  // The narrowing builtin yields a packed vector; the union exposes its first byte as BF8.
  union {
    shortx2_t i16_vec;
    rccl_bfloat8 fp8[4];
  } u1{0};
  u1.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(v1, v1, /* scale */ 1.f, 0);
  return u1.fp8[0];
#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
  float2_t v;
  uint32_t ival = 0;
  asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(x.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(y.__x, 0)));
  // The builtin returns a packed integer whose low byte holds the converted v[0].
  // Store that byte as raw BF8 storage rather than returning the integer directly.
  // NOTE(review): a direct return relies on an implicit integer-to-rccl_bfloat8 conversion;
  // if that conversion is numeric, the packed bit pattern would be converted as a value.
  const uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v[0], v[0], ival, false);
  rccl_bfloat8 result = x;  // copy only to obtain an object; __x is overwritten below
  result.__x = static_cast<decltype(result.__x)>(packed & 0xFFu);
  return result;
#else
  return rccl_bfloat8(float(x) + float(y));
#endif
}
/*! \brief Adds two packed pairs of rccl_float8 values and returns the packed pair of sums.
 *
 *  - gfx950 device code: packed f16 add (v_pk_add_f16) between widening/narrowing builtins.
 *  - gfx942 device code: packed f32 add (v_pk_add_f32) between widening/narrowing builtins.
 *  - anything else: each of the two elements is added with hadd().
 *
 *  \param x first packed pair
 *  \param y second packed pair
 *  \return  packed pair holding the element-wise sums
 */
inline __device__ fp8x2_storage_t hadd2(fp8x2_storage_t x, fp8x2_storage_t y)
{
#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
  half2_t sum_f16;
  asm volatile("v_pk_add_f16 %0, %1, %2"
               : "=v"(sum_f16)
               : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(x, 1.f, 0)),
                 "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(y, 1.f, 0)));
  // The narrowed result is read back through the union's 16-bit member.
  union {
    shortx2_t i16_vec;
    fp8x2_storage_t fp8;
  } packed{0};
  packed.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(sum_f16, sum_f16, /* scale */ 1.f, 0);
  return packed.fp8;
#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
  float2_t sum_f32;
  uint32_t previous = 0;
  asm volatile("v_pk_add_f32 %0, %1, %2"
               : "=v"(sum_f32)
               : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(x, 0)),
                 "v"(__builtin_amdgcn_cvt_pk_f32_fp8(y, 0)));
  // The builtin's integer result is converted to the 16-bit storage type on return.
  return __builtin_amdgcn_cvt_pk_fp8_f32(sum_f32[0], sum_f32[1], previous, false);
#else
  // View each 16-bit pack as two FP8 elements and add them one at a time.
  union {
    rccl_float8 fp8[2];
    fp8x2_storage_t fp8x2;
  } lhs, rhs, out;
  lhs.fp8x2 = x;
  rhs.fp8x2 = y;
  for (int lane = 0; lane < 2; ++lane) {
    out.fp8[lane] = hadd(lhs.fp8[lane], rhs.fp8[lane]);
  }
  return out.fp8x2;
#endif
}
/*! \brief Adds two packed pairs of rccl_bfloat8 values and returns the packed pair of sums.
 *
 *  - gfx950 device code: packed f16 add (v_pk_add_f16) between widening/narrowing builtins.
 *  - gfx942 device code: packed f32 add (v_pk_add_f32) between widening/narrowing builtins.
 *  - anything else: each of the two elements is added with hadd_b().
 *
 *  \param x first packed pair
 *  \param y second packed pair
 *  \return  packed pair holding the element-wise sums
 */
inline __device__ fp8x2_storage_t hadd2_b(fp8x2_storage_t x, fp8x2_storage_t y)
{
#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
  half2_t sum_f16;
  asm volatile("v_pk_add_f16 %0, %1, %2"
               : "=v"(sum_f16)
               : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(x, 1.f, 0)),
                 "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(y, 1.f, 0)));
  // The narrowed result is read back through the union's 16-bit member.
  union {
    shortx2_t i16_vec;
    fp8x2_storage_t fp8;
  } packed{0};
  packed.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(sum_f16, sum_f16, /* scale */ 1.f, 0);
  return packed.fp8;
#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
  float2_t sum_f32;
  uint32_t previous = 0;
  asm volatile("v_pk_add_f32 %0, %1, %2"
               : "=v"(sum_f32)
               : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(x, 0)),
                 "v"(__builtin_amdgcn_cvt_pk_f32_bf8(y, 0)));
  // The builtin's integer result is converted to the 16-bit storage type on return.
  return __builtin_amdgcn_cvt_pk_bf8_f32(sum_f32[0], sum_f32[1], previous, false);
#else
  // View each 16-bit pack as two BF8 elements and add them one at a time.
  union {
    rccl_bfloat8 bfp8[2];
    fp8x2_storage_t bfp8x2;
  } lhs, rhs, out;
  lhs.bfp8x2 = x;
  rhs.bfp8x2 = y;
  for (int lane = 0; lane < 2; ++lane) {
    out.bfp8[lane] = hadd_b(lhs.bfp8[lane], rhs.bfp8[lane]);
  }
  return out.bfp8x2;
#endif
}
/*! \brief Writes an rccl_float8 to a stream as its float value. */
inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8)
{
  const float as_float = float(f8);
  os << as_float;
  return os;
}
/*! \brief Writes an rccl_bfloat8 to a stream as its float value. */
inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat8& bf8)
{
  const float as_float = float(bf8);
  os << as_float;
  return os;
}
/*! \brief Multiplies two rccl_float8 values; both are converted to float and the product is returned as float. */
inline __host__ __device__ float operator*(rccl_float8 a, rccl_float8 b)
{
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs * rhs;
}
/*! \brief Multiplies two rccl_bfloat8 values; both are converted to float and the product is returned as float. */
inline __host__ __device__ float operator*(rccl_bfloat8 a, rccl_bfloat8 b)
{
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs * rhs;
}
/*! \brief Multiplies an rccl_float8 (converted to float) by a float and returns the float product. */
inline __host__ __device__ float operator*(rccl_float8 a, float b)
{
  const float lhs = float(a);
  return lhs * b;
}
/*! \brief Multiplies an rccl_bfloat8 (converted to float) by a float and returns the float product. */
inline __host__ __device__ float operator*(rccl_bfloat8 a, float b)
{
  const float lhs = float(a);
  return lhs * b;
}
// For older versions of ROCm that do not include hip_fp8.h,
// we provide a local version of the header file as a fallback.
#else
@@ -714,6 +860,43 @@ namespace std
}
}
/*! \brief Adds two rccl_float8 values through float and converts the sum back to rccl_float8. */
inline __device__ rccl_float8 hadd(rccl_float8 x, rccl_float8 y)
{
  const float sum = float(x) + float(y);
  return rccl_float8(sum);
}
/*! \brief Adds two packed pairs of rccl_float8 values element by element using hadd().
 *
 *  \param x first packed pair
 *  \param y second packed pair
 *  \return  packed pair holding the two sums
 */
inline __device__ fp8x2_storage_t hadd2(fp8x2_storage_t x, fp8x2_storage_t y)
{
  // View each 16-bit pack as two FP8 elements.
  union {
    rccl_float8 fp8[2];
    fp8x2_storage_t fp8x2;
  } lhs, rhs, out;
  lhs.fp8x2 = x;
  rhs.fp8x2 = y;
  for (int lane = 0; lane < 2; ++lane) {
    out.fp8[lane] = hadd(lhs.fp8[lane], rhs.fp8[lane]);
  }
  return out.fp8x2;
}
/*! \brief Adds two rccl_bfloat8 values through float and converts the sum back to rccl_bfloat8. */
inline __device__ rccl_bfloat8 hadd_b(rccl_bfloat8 x, rccl_bfloat8 y)
{
  const float sum = float(x) + float(y);
  return rccl_bfloat8(sum);
}
/*! \brief Adds two packed pairs of rccl_bfloat8 values element by element using hadd_b().
 *
 *  \param x first packed pair
 *  \param y second packed pair
 *  \return  packed pair holding the two sums
 */
inline __device__ fp8x2_storage_t hadd2_b(fp8x2_storage_t x, fp8x2_storage_t y)
{
  // View each 16-bit pack as two BF8 elements.
  union {
    rccl_bfloat8 fp8[2];
    fp8x2_storage_t fp8x2;
  } lhs, rhs, out;
  lhs.fp8x2 = x;
  rhs.fp8x2 = y;
  for (int lane = 0; lane < 2; ++lane) {
    out.fp8[lane] = hadd_b(lhs.fp8[lane], rhs.fp8[lane]);
  }
  return out.fp8x2;
}
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8)
{