Improving build time by removing the gfx11xx and host code from rccl_float8.h (#1789)

* Reduce extra build time by no longer using hip_fp8 on the gfx11xx arch

---------

Co-authored-by: Marzieh Berenjkoub <mberenjk@amd.com>

[ROCm/rccl commit: 697bee4ee8]
このコミットが含まれているのは:
mberenjk
2025-07-09 14:03:47 -05:00
committed by GitHub
コミット 1623fcc7a1
3個のファイルの変更6行の追加39行の削除
+2 -35
ファイルの表示
@@ -40,50 +40,17 @@ typedef struct
} rccl_bfloat8;
// __cplusplus < 201103L || (!defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__))
#elif HIP_VERSION >= 60300000
#elif HIP_VERSION >= 60300000 && !(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1030__))
#include <hip/hip_fp8.h>
#if __HIP_DEVICE_COMPILE__ && (defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) || (defined(__gfx1100__) || defined(__gfx1101__)))//HIP_FP8_TYPE_OCP is enabled.
typedef __hip_fp8_e4m3 rccl_float8;
typedef __hip_fp8_e5m2 rccl_bfloat8;
#elif __HIP_DEVICE_COMPILE__ && (defined(__gfx942__))
#if __HIP_DEVICE_COMPILE__ && (defined(__gfx942__))
typedef __hip_fp8_e4m3_fnuz rccl_float8;
typedef __hip_fp8_e5m2_fnuz rccl_bfloat8;
#else
typedef __hip_fp8_e4m3 rccl_float8;
typedef __hip_fp8_e5m2 rccl_bfloat8;
#endif
// Stream insertion for rccl_float8: the value is converted to float and
// that float is written to the stream. Returns the same stream for chaining.
inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8)
{
    const float widened = static_cast<float>(f8);
    os << widened;
    return os;
}
// Stream insertion for rccl_bfloat8: the value is converted to float and
// that float is written to the stream. Returns the same stream for chaining.
inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat8& bf8)
{
    const float widened = static_cast<float>(bf8);
    os << widened;
    return os;
}
// Multiplies two rccl_float8 values. Both operands are converted to float
// first, and the product is returned as a float (not as rccl_float8).
inline __host__ __device__ float operator*(rccl_float8 a, rccl_float8 b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs * rhs;
}
// Multiplies two rccl_bfloat8 values. Both operands are converted to float
// first, and the product is returned as a float (not as rccl_bfloat8).
inline __host__ __device__ float operator*(rccl_bfloat8 a, rccl_bfloat8 b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs * rhs;
}
// Multiplies an rccl_float8 by a float. The rccl_float8 operand is converted
// to float and the multiplication is carried out in float.
inline __host__ __device__ float operator*(rccl_float8 a, float b)
{
    const float lhs = static_cast<float>(a);
    return lhs * b;
}
// Multiplies an rccl_bfloat8 by a float. The rccl_bfloat8 operand is converted
// to float and the multiplication is carried out in float.
inline __host__ __device__ float operator*(rccl_bfloat8 a, float b)
{
    const float lhs = static_cast<float>(a);
    return lhs * b;
}
// For older versions of ROCm that do not include hip_fp8.h,
// we provide a local version of the header file as a fallback.
+2 -2
ファイルの表示
@@ -201,10 +201,10 @@ namespace RcclUnitTesting
case ncclUint32: ss << scalarsPerRank.U4[this->globalRank]; break;
case ncclInt64: ss << scalarsPerRank.I8[this->globalRank]; break;
case ncclUint64: ss << scalarsPerRank.U8[this->globalRank]; break;
case ncclFloat8e4m3: ss << scalarsPerRank.F1[this->globalRank]; break;
case ncclFloat8e4m3: ss << (float)scalarsPerRank.F1[this->globalRank]; break;
case ncclFloat32: ss << scalarsPerRank.F4[this->globalRank]; break;
case ncclFloat64: ss << scalarsPerRank.F8[this->globalRank]; break;
case ncclFloat8e5m2: ss << scalarsPerRank.B1[this->globalRank]; break;
case ncclFloat8e5m2: ss << (float)scalarsPerRank.B1[this->globalRank]; break;
case ncclBfloat16: ss << scalarsPerRank.B2[this->globalRank]; break;
default: ss << "(UNKNOWN)";
}
+2 -2
ファイルの表示
@@ -234,11 +234,11 @@ namespace RcclUnitTesting
case ncclUint32: U4[idx] *= scalarsPerRank.U4[rank]; break;
case ncclInt64: I8[idx] *= scalarsPerRank.I8[rank]; break;
case ncclUint64: U8[idx] *= scalarsPerRank.U8[rank]; break;
case ncclFloat8e4m3: F1[idx] = rccl_float8(F1[idx] * scalarsPerRank.F1[rank]); break;
case ncclFloat8e4m3: F1[idx] = rccl_float8((float)F1[idx] * (float)scalarsPerRank.F1[rank]); break;
case ncclFloat16: F2[idx] = __float2half(__half2float(F2[idx]) * __half2float(scalarsPerRank.F2[rank])); break;
case ncclFloat32: F4[idx] *= scalarsPerRank.F4[rank]; break;
case ncclFloat64: F8[idx] *= scalarsPerRank.F8[rank]; break;
case ncclFloat8e5m2: B1[idx] = rccl_bfloat8(B1[idx] * scalarsPerRank.B1[rank]); break;
case ncclFloat8e5m2: B1[idx] = rccl_bfloat8((float)B1[idx] * (float)scalarsPerRank.B1[rank]); break;
case ncclBfloat16: B2[idx] *= scalarsPerRank.B2[rank]; break;
default:
ERROR("Unsupported datatype\n");