Revert __nv_bfloat16 back to hip_bfloat16 (#64)
[ROCm/rccl-tests commit: 7a7a5969d0]
This commit is contained in:
@@ -28,7 +28,7 @@

 using std::uint64_t;
 using std::uint32_t;
-using bfloat16 = __nv_bfloat16;
+using bfloat16 = hip_bfloat16;

 template<typename T>
 struct float_traits;
@@ -91,7 +91,7 @@ template<>
 struct IsIntegral<__half>: std::false_type {};
 #if RCCL_BFLOAT16 == 1
 template<>
-struct IsIntegral<__nv_bfloat16>: std::false_type {};
+struct IsIntegral<hip_bfloat16>: std::false_type {};
 #endif
 }

@@ -126,7 +126,7 @@ namespace {
 }
 #if RCCL_BFLOAT16 == 1
 template<>
-__host__ __device__ __nv_bfloat16 castTo<__nv_bfloat16>(float x) {
+__host__ __device__ hip_bfloat16 castTo<hip_bfloat16>(float x) {
 return hip_bfloat16(x);
 }
 #endif
@@ -153,7 +153,7 @@ struct ReduceSum {
 return __float2half(__half2float(a) + __half2float(b));
 }
 #if RCCL_BFLOAT16 == 1
-__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
+__host__ __device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b) const {
 return hip_bfloat16(static_cast<float>(a) + static_cast<float>(b));
 }
 #endif
@@ -169,7 +169,7 @@ struct ReduceProd {
 return __float2half(__half2float(a) * __half2float(b));
 }
 #if RCCL_BFLOAT16 == 1
-__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
+__host__ __device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b) const {
 return hip_bfloat16(static_cast<float>(a) * static_cast<float>(b));
 }
 #endif
@@ -185,7 +185,7 @@ struct ReduceMin {
 return __half2float(a) < __half2float(b) ? a : b;
 }
 #if RCCL_BFLOAT16 == 1
-__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
+__host__ __device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b) const {
 return static_cast<float>(a) < static_cast<float>(b) ? a : b;
 }
 #endif
@@ -201,7 +201,7 @@ struct ReduceMax {
 return __half2float(a) > __half2float(b) ? a : b;
 }
 #if RCCL_BFLOAT16 == 1
-__host__ __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const {
+__host__ __device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b) const {
 return static_cast<float>(a) > static_cast<float>(b) ? a : b;
 }
 #endif
@@ -280,7 +280,7 @@ struct FloatLayout<__half> {
 };
 #if RCCL_BFLOAT16 == 1
 template<>
-struct FloatLayout<__nv_bfloat16> {
+struct FloatLayout<hip_bfloat16> {
 static constexpr int exponent_bits = 8, mantissa_bits = 7;
 static constexpr int exponent_bias = (1<<(exponent_bits-1))-1;
 };
@@ -814,7 +814,7 @@ void prepareInput1(
 case ncclUint64: CASE_TY(uint64_t)
 case ncclFloat16: CASE_TY(__half)
 #if HAVE_ncclBfloat16
-case ncclBfloat16: CASE_TY(__nv_bfloat16)
+case ncclBfloat16: CASE_TY(hip_bfloat16)
 #endif
 case ncclFloat32: CASE_TY(float)
 case ncclFloat64: CASE_TY(double)
@@ -890,7 +890,7 @@ void prepareExpected1(
 case ncclUint64: CASE_TY(uint64_t)
 case ncclFloat16: CASE_TY(__half)
 #if HAVE_ncclBfloat16
-case ncclBfloat16: CASE_TY(__nv_bfloat16)
+case ncclBfloat16: CASE_TY(hip_bfloat16)
 #endif
 case ncclFloat32: CASE_TY(float)
 case ncclFloat64: CASE_TY(double)
@@ -1112,7 +1112,7 @@ void ncclVerifiableVerify(
 case ncclUint64: CASE_TY(uint64_t, uint64_t)
 case ncclFloat16: CASE_TY(__half, uint16_t)
 #if HAVE_ncclBfloat16
-case ncclBfloat16: CASE_TY(__nv_bfloat16, uint16_t)
+case ncclBfloat16: CASE_TY(hip_bfloat16, uint16_t)
 #endif
 case ncclFloat32: CASE_TY(float, uint32_t)
 case ncclFloat64: CASE_TY(double, uint64_t)
@@ -1179,7 +1179,7 @@ __global__ void sweep() {
 sweep1<uint64_t>(ncclUint64, "uint64");
 sweep1<__half>(ncclFloat16, "half");
 #if HAVE_ncclBfloat16
-sweep1<__nv_bfloat16>(ncclBfloat16, "bfloat16");
+sweep1<hip_bfloat16>(ncclBfloat16, "bfloat16");
 #endif
 sweep1<float>(ncclFloat32, "float");
 sweep1<double>(ncclFloat64, "double");
Reference in New Issue
Block a user