Improving build time by removing the gfx11xx and host code from rccl_float8.h (#1789)
* removing extra build time by removing the gfx11xx arch from using hip_fp8
---------
Co-authored-by: Marzieh Berenjkoub <mberenjk@amd.com>
[ROCm/rccl commit: 697bee4ee8]
このコミットが含まれているのは:
@@ -40,50 +40,17 @@ typedef struct
|
||||
} rccl_bfloat8;
|
||||
|
||||
// __cplusplus < 201103L || (!defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__))
|
||||
#elif HIP_VERSION >= 60300000
|
||||
#elif HIP_VERSION >= 60300000 && !(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1030__))
|
||||
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
#if __HIP_DEVICE_COMPILE__ && (defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) || (defined(__gfx1100__) || defined(__gfx1101__)))//HIP_FP8_TYPE_OCP is enabled.
|
||||
typedef __hip_fp8_e4m3 rccl_float8;
|
||||
typedef __hip_fp8_e5m2 rccl_bfloat8;
|
||||
#elif __HIP_DEVICE_COMPILE__ && (defined(__gfx942__))
|
||||
#if __HIP_DEVICE_COMPILE__ && (defined(__gfx942__))
|
||||
typedef __hip_fp8_e4m3_fnuz rccl_float8;
|
||||
typedef __hip_fp8_e5m2_fnuz rccl_bfloat8;
|
||||
#else
|
||||
typedef __hip_fp8_e4m3 rccl_float8;
|
||||
typedef __hip_fp8_e5m2 rccl_bfloat8;
|
||||
#endif
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8)
|
||||
{
|
||||
return os << float(f8);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat8& bf8)
|
||||
{
|
||||
return os << float(bf8);
|
||||
}
|
||||
|
||||
inline __host__ __device__ float operator*(rccl_float8 a, rccl_float8 b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline __host__ __device__ float operator*(rccl_bfloat8 a, rccl_bfloat8 b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline __host__ __device__ float operator*(rccl_float8 a, float b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline __host__ __device__ float operator*(rccl_bfloat8 a, float b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
// For older versions of ROCm that do not include hip_fp8.h,
|
||||
// we provide a local version of the header file as a fallback.
|
||||
|
||||
@@ -201,10 +201,10 @@ namespace RcclUnitTesting
|
||||
case ncclUint32: ss << scalarsPerRank.U4[this->globalRank]; break;
|
||||
case ncclInt64: ss << scalarsPerRank.I8[this->globalRank]; break;
|
||||
case ncclUint64: ss << scalarsPerRank.U8[this->globalRank]; break;
|
||||
case ncclFloat8e4m3: ss << scalarsPerRank.F1[this->globalRank]; break;
|
||||
case ncclFloat8e4m3: ss << (float)scalarsPerRank.F1[this->globalRank]; break;
|
||||
case ncclFloat32: ss << scalarsPerRank.F4[this->globalRank]; break;
|
||||
case ncclFloat64: ss << scalarsPerRank.F8[this->globalRank]; break;
|
||||
case ncclFloat8e5m2: ss << scalarsPerRank.B1[this->globalRank]; break;
|
||||
case ncclFloat8e5m2: ss << (float)scalarsPerRank.B1[this->globalRank]; break;
|
||||
case ncclBfloat16: ss << scalarsPerRank.B2[this->globalRank]; break;
|
||||
default: ss << "(UNKNOWN)";
|
||||
}
|
||||
|
||||
@@ -234,11 +234,11 @@ namespace RcclUnitTesting
|
||||
case ncclUint32: U4[idx] *= scalarsPerRank.U4[rank]; break;
|
||||
case ncclInt64: I8[idx] *= scalarsPerRank.I8[rank]; break;
|
||||
case ncclUint64: U8[idx] *= scalarsPerRank.U8[rank]; break;
|
||||
case ncclFloat8e4m3: F1[idx] = rccl_float8(F1[idx] * scalarsPerRank.F1[rank]); break;
|
||||
case ncclFloat8e4m3: F1[idx] = rccl_float8((float)F1[idx] * (float)scalarsPerRank.F1[rank]); break;
|
||||
case ncclFloat16: F2[idx] = __float2half(__half2float(F2[idx]) * __half2float(scalarsPerRank.F2[rank])); break;
|
||||
case ncclFloat32: F4[idx] *= scalarsPerRank.F4[rank]; break;
|
||||
case ncclFloat64: F8[idx] *= scalarsPerRank.F8[rank]; break;
|
||||
case ncclFloat8e5m2: B1[idx] = rccl_bfloat8(B1[idx] * scalarsPerRank.B1[rank]); break;
|
||||
case ncclFloat8e5m2: B1[idx] = rccl_bfloat8((float)B1[idx] * (float)scalarsPerRank.B1[rank]); break;
|
||||
case ncclBfloat16: B2[idx] *= scalarsPerRank.B2[rank]; break;
|
||||
default:
|
||||
ERROR("Unsupported datatype\n");
|
||||
|
||||
新しいイシューから参照
ユーザーをブロックする