fp8 reduction: add hadd/hadd_b/hadd2/hadd2_b helpers and use them for FuncSum

src/device/reduce_kernel.h
  FuncSum for rccl_float8 / rccl_bfloat8 calls hadd() / hadd_b() instead of
  the inline float round trip, and gains 2-element specializations that
  operate on fp8x2_storage_t through hadd2() / hadd2_b().

src/include/rccl_float8.h
  Adds fp8x2_storage_t (uint16_t), a few vector typedefs, the four helpers
  (gfx950 path: packed f16 add; gfx942 path: packed f32 add; otherwise a
  float round trip), stream output and float-returning operator* overloads,
  and float-only versions of the four helpers further down the header.

Review notes
  * gfx942 branches of hadd() / hadd_b(): the packed conversion result is
    stored in a union and element 0 is returned, matching what the gfx950
    branches do. Returning the integer directly relied on an implicit
    integer -> fp8 conversion and left the declared union unused.
  * Both files also change mode 100644 -> 100755; confirm that is intended.
  * The two "#include" context lines of the first rccl_float8.h hunk carry no
    header name in this copy; indentation and the position of blank context
    lines are best-effort. The index lines are those of the original patch.

diff --git a/src/device/reduce_kernel.h b/src/device/reduce_kernel.h
old mode 100644
new mode 100755
index a4df12c25e..f4e922bc7f
--- a/src/device/reduce_kernel.h
+++ b/src/device/reduce_kernel.h
@@ -304,11 +304,13 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h
   SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e5m2, 1, __nv_fp8_e5m2, __nv_fp8_e5m2(fn.isMinNotMax ? __hmin(__half(x), __half(y)) : __hmax(__half(x), __half(y))))
   SPECIALIZE_REDUCE(FuncMinMax, __nv_fp8_e5m2, 2, __nv_fp8x2_e5m2, __nv_fp8x2_e5m2(fn.isMinNotMax ? __hmin2(__half2(x), __half2(y)) : __hmax2(__half2(x), __half2(y))))
 #else
-  SPECIALIZE_REDUCE(FuncSum, rccl_float8, 1, rccl_float8, rccl_float8(float(x) + float(y)))
+  SPECIALIZE_REDUCE(FuncSum, rccl_float8, 1, rccl_float8, hadd(x,y))
+  SPECIALIZE_REDUCE(FuncSum, rccl_float8, 2, fp8x2_storage_t, hadd2(x,y))
   SPECIALIZE_REDUCE(FuncProd, rccl_float8, 1, rccl_float8, rccl_float8(float(x) * float(y)))
   SPECIALIZE_REDUCE(FuncMinMax, rccl_float8, 1, rccl_float8, rccl_float8(fn.isMinNotMax ? fminf(float(x), float(y)) : fmaxf(float(x), float(y))))
 
-  SPECIALIZE_REDUCE(FuncSum, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(float(x) + float(y)))
+  SPECIALIZE_REDUCE(FuncSum, rccl_bfloat8, 1, rccl_bfloat8, hadd_b(x,y))
+  SPECIALIZE_REDUCE(FuncSum, rccl_bfloat8, 2, fp8x2_storage_t, hadd2_b(x,y))
   SPECIALIZE_REDUCE(FuncProd, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(float(x) * float(y)))
   SPECIALIZE_REDUCE(FuncMinMax, rccl_bfloat8, 1, rccl_bfloat8, rccl_bfloat8(fn.isMinNotMax ? fminf(float(x), float(y)) : fmaxf(float(x), float(y))))
 #endif
diff --git a/src/include/rccl_float8.h b/src/include/rccl_float8.h
old mode 100644
new mode 100755
index d5dd31ceb1..4dc975918c
--- a/src/include/rccl_float8.h
+++ b/src/include/rccl_float8.h
@@ -26,6 +26,7 @@
 #include
 #include
 
+typedef uint16_t fp8x2_storage_t;
 #if __cplusplus < 201103L || (!defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__))
 /*! \brief Struct to represent a 8 bit floating-point number. */
 
@@ -52,6 +53,170 @@
 typedef __hip_fp8_e4m3 rccl_float8;
 typedef __hip_fp8_e5m2 rccl_bfloat8;
 #endif
+typedef _Float16 half_t;
+typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
+
+typedef short shortx2_t __attribute__((ext_vector_type(2)));
+typedef short __attribute__((ext_vector_type(2))) __amd_shortx2_storage_t;
+typedef float float2_t __attribute__((ext_vector_type(2)));
+
+
+// Adds two rccl_float8 values and returns the sum as rccl_float8.
+// gfx950 uses a packed f16 add and gfx942 a packed f32 add on the converted
+// inputs; every other target converts to float, adds, and converts back.
+inline __device__ rccl_float8 hadd(rccl_float8 x, rccl_float8 y)
+{
+#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
+  half2_t v1;
+  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(x.__x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(y.__x, 1.f, 0)));
+  union {
+    shortx2_t i16_vec;
+    rccl_float8 fp8[4];
+  } u{0};
+  u.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(v1, v1, /* scale */ 1.f, 0);
+  return u.fp8[0];
+#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
+  union
+  {
+    uint32_t i32val;
+    rccl_float8 i8val[4];
+  } val;
+
+  float2_t v;
+  uint32_t ival = 0;
+  asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(x.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(y.__x, 0)));
+  // The conversion builtin hands back packed fp8 bit patterns in an integer.
+  // Store them in the union and return element 0 unchanged rather than
+  // relying on an implicit integer -> rccl_float8 conversion at the return.
+  val.i32val = __builtin_amdgcn_cvt_pk_fp8_f32(v[0], v[0], ival, false);
+  return val.i8val[0];
+#else
+  return rccl_float8(float(x) + float(y));
+#endif
+}
+
+// rccl_bfloat8 counterpart of hadd(): same three code paths, bf8 conversions.
+inline __device__ rccl_bfloat8 hadd_b(rccl_bfloat8 x, rccl_bfloat8 y)
+{
+#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
+  half2_t v1;
+  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(x.__x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(y.__x, 1.f, 0)));
+  union {
+    shortx2_t i16_vec;
+    rccl_bfloat8 fp8[4];
+  } u1{0};
+  u1.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(v1, v1, /* scale */ 1.f, 0);
+  return u1.fp8[0];
+#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
+  union
+  {
+    uint32_t i32val;
+    rccl_bfloat8 i8val[4];
+  } val;
+
+  float2_t v;
+  uint32_t ival = 0;
+  asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(x.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(y.__x, 0)));
+  // Return element 0 of the packed result bit-for-bit (no implicit integer -> bf8 conversion).
+  val.i32val = __builtin_amdgcn_cvt_pk_bf8_f32(v[0], v[0], ival, false);
+  return val.i8val[0];
+#else
+  return rccl_bfloat8(float(x) + float(y));
+#endif
+}
+
+// Element-wise sum of two fp8 pairs, each pair packed into 16 bits.
+inline __device__ fp8x2_storage_t hadd2(fp8x2_storage_t x, fp8x2_storage_t y)
+{
+#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
+  half2_t v1;
+  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_fp8(y, 1.f, 0)));
+  union {
+    shortx2_t i16_vec;
+    fp8x2_storage_t fp8;
+  } u{0};
+  u.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(v1, v1, /* scale */ 1.f, 0);
+  return u.fp8;
+#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
+  float2_t v;
+  uint32_t ival = 0;
+  asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(y, 0)));
+  return __builtin_amdgcn_cvt_pk_fp8_f32(v[0], v[1], ival, false);
+#else
+  union {
+    rccl_float8 fp8[2];
+    fp8x2_storage_t fp8x2;
+  } u, v, w;
+  u.fp8x2 = x;
+  v.fp8x2 = y;
+  w.fp8[0] = hadd(u.fp8[0], v.fp8[0]);
+  w.fp8[1] = hadd(u.fp8[1], v.fp8[1]);
+  return w.fp8x2;
+#endif
+}
+
+// Element-wise sum of two bf8 pairs, each pair packed into 16 bits.
+inline __device__ fp8x2_storage_t hadd2_b(fp8x2_storage_t x, fp8x2_storage_t y)
+{
+#if __HIP_DEVICE_COMPILE__ && defined(__gfx950__)
+  half2_t v1;
+  asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(v1) : "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(x, 1.f, 0)), "v"(__builtin_amdgcn_cvt_scalef32_pk_f16_bf8(y, 1.f, 0)));
+  union {
+    shortx2_t i16_vec;
+    fp8x2_storage_t fp8;
+  } u{0};
+  u.i16_vec = __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(v1, v1, /* scale */ 1.f, 0);
+  return u.fp8;
+#elif __HIP_DEVICE_COMPILE__ && defined(__gfx942__)
+  float2_t v;
+  uint32_t ival = 0;
+  asm volatile("v_pk_add_f32 %0, %1, %2" : "=v"(v) : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(y, 0)));
+  return __builtin_amdgcn_cvt_pk_bf8_f32(v[0], v[1], ival, false);
+#else
+  union {
+    rccl_bfloat8 bfp8[2];
+    fp8x2_storage_t bfp8x2;
+  } u, v, w;
+  u.bfp8x2 = x;
+  v.bfp8x2 = y;
+  w.bfp8[0] = hadd_b(u.bfp8[0], v.bfp8[0]);
+  w.bfp8[1] = hadd_b(u.bfp8[1], v.bfp8[1]);
+  return w.bfp8x2;
+#endif
+}
+
+// Stream output: prints the value converted to float.
+inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8)
+{
+  return os << float(f8);
+}
+
+inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat8& bf8)
+{
+  return os << float(bf8);
+}
+
+// Products: both operands are converted to float and the result is a float.
+inline __host__ __device__ float operator*(rccl_float8 a, rccl_float8 b)
+{
+  return float(a) * float(b);
+}
+
+inline __host__ __device__ float operator*(rccl_bfloat8 a, rccl_bfloat8 b)
+{
+  return float(a) * float(b);
+}
+
+inline __host__ __device__ float operator*(rccl_float8 a, float b)
+{
+  return float(a) * float(b);
+}
+
+inline __host__ __device__ float operator*(rccl_bfloat8 a, float b)
+{
+  return float(a) * float(b);
+}
+
 // For older versions of ROCm that do not include hip_fp8.h,
 // we provide a local version of the header file as a fallback.
 #else
@@ -714,6 +879,48 @@ namespace std
     }
 }
 
+// Scalar fp8 sum: add in float, convert back.
+inline __device__ rccl_float8 hadd(rccl_float8 x, rccl_float8 y)
+{
+  return rccl_float8(float(x) + float(y));
+}
+
+// Element-wise sum of two fp8 pairs, each pair packed into 16 bits.
+inline __device__ fp8x2_storage_t hadd2(fp8x2_storage_t x, fp8x2_storage_t y)
+{
+  union {
+    rccl_float8 fp8[2];
+    fp8x2_storage_t fp8x2;
+  } u, v, w;
+  u.fp8x2 = x;
+  v.fp8x2 = y;
+  w.fp8[0] = hadd(u.fp8[0], v.fp8[0]);
+  w.fp8[1] = hadd(u.fp8[1], v.fp8[1]);
+
+  return w.fp8x2;
+}
+
+// Scalar bf8 sum: add in float, convert back.
+inline __device__ rccl_bfloat8 hadd_b(rccl_bfloat8 x, rccl_bfloat8 y)
+{
+  return rccl_bfloat8(float(x) + float(y));
+}
+
+// Element-wise sum of two bf8 pairs, each pair packed into 16 bits.
+inline __device__ fp8x2_storage_t hadd2_b(fp8x2_storage_t x, fp8x2_storage_t y)
+{
+  union {
+    rccl_bfloat8 fp8[2];
+    fp8x2_storage_t fp8x2;
+  } u, v, w;
+  u.fp8x2 = x;
+  v.fp8x2 = y;
+  w.fp8[0] = hadd_b(u.fp8[0], v.fp8[0]);
+  w.fp8[1] = hadd_b(u.fp8[1], v.fp8[1]);
+
+  return w.fp8x2;
+}
+
 // Special operator overloading
 inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8)
 {