2
0

Revert __nv_bfloat16 back to hip_bfloat16 (#64)

Este cometimento está contido em:
Bertan Dogancay
2024-03-06 11:11:44 -07:00
cometido por GitHub
ascendente 88cf7dbf45
cometimento 7a7a5969d0
2 ficheiros modificados com 12 adições e 12 eliminações
+1 -1
Ver ficheiro
@@ -28,7 +28,7 @@
using std::uint64_t;
using std::uint32_t;
-using bfloat16 = __nv_bfloat16;
+using bfloat16 = hip_bfloat16;
template<typename T>
struct float_traits;
+11 -11
Ver ficheiro
@@ -91,7 +91,7 @@ template<>
struct IsIntegral<__half>: std::false_type {};
#if RCCL_BFLOAT16 == 1
template<>
-struct IsIntegral<__nv_bfloat16>: std::false_type {};
+struct IsIntegral<hip_bfloat16>: std::false_type {};
#endif
}
@@ -126,7 +126,7 @@ namespace {
}
#if RCCL_BFLOAT16 == 1
template<>
-__host__ __device__ __nv_bfloat16 castTo<__nv_bfloat16>(float x) {
+__host__ __device__ hip_bfloat16 castTo<hip_bfloat16>(float x) {
return hip_bfloat16(x);
}
#endif
@@ -153,7 +153,7 @@ struct ReduceSum {
return __float2half(__half2float(a) + __half2float(b));
}
#if RCCL_BFLOAT16 == 1
-__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
+__host__ __device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b) const {
return hip_bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
#endif
@@ -169,7 +169,7 @@ struct ReduceProd {
return __float2half(__half2float(a) * __half2float(b));
}
#if RCCL_BFLOAT16 == 1
-__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
+__host__ __device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b) const {
return hip_bfloat16(static_cast<float>(a) * static_cast<float>(b));
}
#endif
@@ -185,7 +185,7 @@ struct ReduceMin {
return __half2float(a) < __half2float(b) ? a : b;
}
#if RCCL_BFLOAT16 == 1
-__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
+__host__ __device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b) const {
return static_cast<float>(a) < static_cast<float>(b) ? a : b;
}
#endif
@@ -201,7 +201,7 @@ struct ReduceMax {
return __half2float(a) > __half2float(b) ? a : b;
}
#if RCCL_BFLOAT16 == 1
-__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
+__host__ __device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b) const {
return static_cast<float>(a) > static_cast<float>(b) ? a : b;
}
#endif
@@ -280,7 +280,7 @@ struct FloatLayout<__half> {
};
#if RCCL_BFLOAT16 == 1
template<>
-struct FloatLayout<__nv_bfloat16> {
+struct FloatLayout<hip_bfloat16> {
static constexpr int exponent_bits = 8, mantissa_bits = 7;
static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
};
@@ -814,7 +814,7 @@ void prepareInput1(
case ncclUint64: CASE_TY(uint64_t)
case ncclFloat16: CASE_TY(__half)
#if HAVE_ncclBfloat16
-case ncclBfloat16: CASE_TY(__nv_bfloat16)
+case ncclBfloat16: CASE_TY(hip_bfloat16)
#endif
case ncclFloat32: CASE_TY(float)
case ncclFloat64: CASE_TY(double)
@@ -890,7 +890,7 @@ void prepareExpected1(
case ncclUint64: CASE_TY(uint64_t)
case ncclFloat16: CASE_TY(__half)
#if HAVE_ncclBfloat16
-case ncclBfloat16: CASE_TY(__nv_bfloat16)
+case ncclBfloat16: CASE_TY(hip_bfloat16)
#endif
case ncclFloat32: CASE_TY(float)
case ncclFloat64: CASE_TY(double)
@@ -1112,7 +1112,7 @@ void ncclVerifiableVerify(
case ncclUint64: CASE_TY(uint64_t, uint64_t)
case ncclFloat16: CASE_TY(__half, uint16_t)
#if HAVE_ncclBfloat16
-case ncclBfloat16: CASE_TY(__nv_bfloat16, uint16_t)
+case ncclBfloat16: CASE_TY(hip_bfloat16, uint16_t)
#endif
case ncclFloat32: CASE_TY(float, uint32_t)
case ncclFloat64: CASE_TY(double, uint64_t)
@@ -1179,7 +1179,7 @@ __global__ void sweep() {
sweep1<uint64_t>(ncclUint64, "uint64");
sweep1<__half>(ncclFloat16, "half");
#if HAVE_ncclBfloat16
-sweep1<__nv_bfloat16>(ncclBfloat16, "bfloat16");
+sweep1<hip_bfloat16>(ncclBfloat16, "bfloat16");
#endif
sweep1<float>(ncclFloat32, "float");
sweep1<double>(ncclFloat64, "double");