diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f5a1790a7..55e3372ed1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -369,7 +369,6 @@ set(SRC_FILES src/include/param.h src/include/profiler.h src/include/proxy.h - src/include/rccl_bfloat16.h src/include/rccl_vars.h src/include/rccl_float8.h src/include/rocm_smi_wrap.h @@ -422,7 +421,6 @@ set(SRC_FILES src/include/param.h src/include/profiler.h src/include/proxy.h - src/include/rccl_bfloat16.h src/include/rccl_vars.h src/include/rocm_smi_wrap.h src/include/rocmwrap.h diff --git a/cmake/Generator.cmake b/cmake/Generator.cmake index 91fc54d98a..03a2b3fb7a 100644 --- a/cmake/Generator.cmake +++ b/cmake/Generator.cmake @@ -25,9 +25,9 @@ set(ALL_COLLS "AllGather" "AllReduce" "AllToAllPivot" "Broadcast" "Reduce" "Redu set(ALL_ALGOS "TREE" "RING" "COLLNET_DIRECT" "COLLNET_CHAIN") set(ALL_PROTOS "LL" "LL128" "SIMPLE") set(ALL_REDOPS "Sum" "Prod" "MinMax" "PreMulSum" "SumPostDiv") -set(ALL_TYPES "int8_t" "uint8_t" "int32_t" "uint32_t" "int64_t" "uint64_t" "half" "float" "double" "rccl_bfloat16" "rccl_float8" "rccl_bfloat8") +set(ALL_TYPES "int8_t" "uint8_t" "int32_t" "uint32_t" "int64_t" "uint64_t" "half" "float" "double" "hip_bfloat16" "rccl_float8" "rccl_bfloat8") -set(FLOATS_LIST "half" "float" "double" "rccl_bfloat16" "rccl_float8" "rccl_bfloat8") +set(FLOATS_LIST "half" "float" "double" "hip_bfloat16" "rccl_float8" "rccl_bfloat8") ################################################################################ # The command line argument is used as a regex to filter the functions @@ -435,4 +435,4 @@ function(gen_functions CONFIG_INPUT) gen_host_table() ## Generate host_table.cpp set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE) -endfunction() \ No newline at end of file +endfunction() diff --git a/src/device/msccl_kernel_impl.h b/src/device/msccl_kernel_impl.h index 3baf4be35c..98af766a93 100644 --- a/src/device/msccl_kernel_impl.h +++ b/src/device/msccl_kernel_impl.h @@ -414,7 +414,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \ - MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps) \ + MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, hip_bfloat16, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_float8, fullOps) \ MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat8, fullOps) diff --git a/src/device/onerank.cu b/src/device/onerank.cu index 62f4540a2a..d016e5efba 100644 --- a/src/device/onerank.cu +++ b/src/device/onerank.cu @@ -65,7 +65,7 @@ ncclResult_t ncclLaunchOneRank(void* dst, void const* src, size_t nElts, struct case ncclUint64: kernel = (void const*)&oneRankReduce>; break; case ncclFloat16: kernel = (void const*)&oneRankReduce>; break; #if defined(RCCL_BFLOAT16) - case ncclBfloat16: kernel = (void const*)&oneRankReduce>; break; + case ncclBfloat16: kernel = (void const*)&oneRankReduce>; break; #endif #if defined(RCCL_FLOAT8) case ncclFp8E4M3: kernel = (void const*)&oneRankReduce>; break; diff --git a/src/device/reduce_kernel.h b/src/device/reduce_kernel.h index 09d9314a9f..9483a4095e 100644 --- a/src/device/reduce_kernel.h +++ b/src/device/reduce_kernel.h @@ -22,7 +22,7 @@ template<> struct IsFloatingPoint: std::true_type {}; #if defined(RCCL_BFLOAT16) template<> -struct IsFloatingPoint: std::true_type {}; +struct IsFloatingPoint: std::true_type {}; #endif #if defined(RCCL_FLOAT8) template<> @@ -257,9 +257,9 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 1, __nv_bfloat16, fn.isMinNotMax ? __hmin(x, y) : __hmax(x, y)) SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 2, __nv_bfloat162, fn.isMinNotMax ? __hmin2(x, y) : __hmax2(x, y)) #else - SPECIALIZE_REDUCE(FuncSum, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)((float)(x) + (float)(y))) - SPECIALIZE_REDUCE(FuncProd, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)((float)(x) * (float)(y))) - SPECIALIZE_REDUCE(FuncMinMax, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)(fn.isMinNotMax ? fminf((float)(x), (float)(y)) : fmaxf((float)(x), (float)(y)))) + SPECIALIZE_REDUCE(FuncSum, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) + (float)(y))) + SPECIALIZE_REDUCE(FuncProd, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) * (float)(y))) + SPECIALIZE_REDUCE(FuncMinMax, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)(fn.isMinNotMax ? fminf((float)(x), (float)(y)) : fmaxf((float)(x), (float)(y)))) #endif #endif @@ -386,8 +386,8 @@ struct FuncPreMulSum { #if defined(RCCL_BFLOAT16) template<> - struct FuncPreMulSum { - using EltType = rccl_bfloat16; + struct FuncPreMulSum { + using EltType = hip_bfloat16; #if __CUDA_ARCH__ >= 800 __nv_bfloat162 scalar; __device__ FuncPreMulSum(uint64_t opArg=0) { @@ -399,7 +399,7 @@ struct FuncPreMulSum { #else float scalar; __device__ FuncPreMulSum(uint64_t opArg=0) { - union { uint64_t u64; rccl_bfloat16 val; }; + union { uint64_t u64; hip_bfloat16 val; }; u64 = opArg; scalar = (float)(val); } @@ -481,21 +481,21 @@ struct Apply_PreOp, /*EltPerPack=*/1> { #if defined(RCCL_BFLOAT16) template<> - struct Apply_PreOp, /*EltPerPack=*/1> { + struct Apply_PreOp, /*EltPerPack=*/1> { static constexpr bool IsIdentity = false; - __device__ static BytePack preOp( - FuncPreMulSum fn, BytePack a + __device__ static BytePack preOp( + FuncPreMulSum fn, BytePack a ) { #if __CUDA_ARCH__ >= 800 return toPack<__nv_bfloat16>(__hmul(fromPack<__nv_bfloat16>(a), fn.scalar.x)); #else - return toPack((rccl_bfloat16)((float)(fromPack(a)) * fn.scalar)); + return toPack((hip_bfloat16)((float)(fromPack(a)) * fn.scalar)); #endif } }; #if __CUDA_ARCH__ >= 800 template<> - struct Apply_PreOp, /*EltPerPack=*/2> { + struct Apply_PreOp, /*EltPerPack=*/2> { static constexpr bool IsIdentity = false; __device__ static BytePack preOp( FuncPreMulSum<__nv_bfloat16> fn, BytePack a @@ -732,8 +732,8 @@ struct Apply_LoadMultimem { DEFINE_Apply_LoadMultimem_minmax_v4x2_and_subhalf(half, f16x2, u32) #if defined(RCCL_BFLOAT16) - DEFINE_Apply_LoadMultimem_sum_v4x2_and_subhalf(rccl_bfloat16, bf16x2, u32) - DEFINE_Apply_LoadMultimem_minmax_v4x2_and_subhalf(rccl_bfloat16, bf16x2, u32) + DEFINE_Apply_LoadMultimem_sum_v4x2_and_subhalf(hip_bfloat16, bf16x2, u32) + DEFINE_Apply_LoadMultimem_minmax_v4x2_and_subhalf(hip_bfloat16, bf16x2, u32) #endif #else template diff --git a/src/enqueue.cc b/src/enqueue.cc index 5c501796bf..7aae7b8908 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -1558,7 +1558,7 @@ static ncclResult_t hostToDevRedOp( float f32; double f64; #if defined(RCCL_BFLOAT16) - rccl_bfloat16 bf16; + hip_bfloat16 bf16; #endif #if defined(RCCL_FLOAT8) rccl_float8 fp8_e4m3; @@ -1602,7 +1602,7 @@ static ncclResult_t hostToDevRedOp( #if defined(RCCL_BFLOAT16) case ncclBfloat16: opFull->op = ncclDevPreMulSum; - bf16 = (rccl_bfloat16)(float(1.0/comm->nRanks)); + bf16 = (hip_bfloat16)(float(1.0/comm->nRanks)); break; #endif #if defined(RCCL_FLOAT8) diff --git a/src/include/device.h b/src/include/device.h index 0fc2acacfd..772b09571d 100644 --- a/src/include/device.h +++ b/src/include/device.h @@ -11,7 +11,7 @@ #include "nccl.h" #include "rccl_float8.h" -#include "rccl_bfloat16.h" +#include #include "nccl_common.h" #include "align.h" #include "collectives.h" diff --git a/src/include/msccl/msccl_kernel.h b/src/include/msccl/msccl_kernel.h index be343a264f..f7b334e97c 100644 --- a/src/include/msccl/msccl_kernel.h +++ b/src/include/msccl/msccl_kernel.h @@ -26,7 +26,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps)(struct n MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \ - MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps) \ + MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, hip_bfloat16, fullOps) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_float8, fullOps) \ MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat8, fullOps) diff --git a/src/include/rccl_bfloat16.h b/src/include/rccl_bfloat16.h deleted file mode 100755 index cbc6e059a5..0000000000 --- a/src/include/rccl_bfloat16.h +++ /dev/null @@ -1,274 +0,0 @@ -/** - * MIT License - * - * Copyright 2019-2020 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -/*!\file - * \brief rccl_bfloat16.h provides struct for rccl_bfloat16 typedef - */ - -#ifndef _RCCL_BFLOAT16_H_ -#define _RCCL_BFLOAT16_H_ - -#if __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__) && !defined(__HIP_PLATFORM_HCC__)) - -// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only -// include a minimal definition of rccl_bfloat16 - -#include -/*! \brief Struct to represent a 16 bit brain floating point number. */ -typedef struct -{ - uint16_t data; -} rccl_bfloat16; - -#else // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__) && !defined(__HIP_PLATFORM_HCC__)) - -#include -#include -#include -#include -#include -#include - -struct rccl_bfloat16 -{ - uint16_t data; - - enum truncate_t - { - truncate - }; - - __host__ __device__ rccl_bfloat16() = default; - - // round upper 16 bits of IEEE float to convert to bfloat16 - explicit __host__ __device__ rccl_bfloat16(float f) - : data(float_to_bfloat16(f)) - { - } - - explicit __host__ __device__ rccl_bfloat16(float f, truncate_t) - : data(truncate_float_to_bfloat16(f)) - { - } - - // zero extend lower 16 bits of bfloat16 to convert to IEEE float - __host__ __device__ operator float() const - { - union - { - uint32_t int32; - float fp32; - } u = {uint32_t(data) << 16}; - return u.fp32; - } - -private: - static __host__ __device__ uint16_t float_to_bfloat16(float f) - { - union - { - float fp32; - uint32_t int32; - } u = {f}; - if(~u.int32 & 0x7f800000) - { - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even - } - else if(u.int32 & 0xffff) - { - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bloat16's mantissa bits are all 0. - u.int32 |= 0x10000; // Preserve signaling NaN - } - return uint16_t(u.int32 >> 16); - } - - // Truncate instead of rounding, preserving SNaN - static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f) - { - union - { - float fp32; - uint32_t int32; - } u = {f}; - return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff)); - } -}; - -typedef struct -{ - uint16_t data; -} rccl_bfloat16_public; - -static_assert(std::is_standard_layout{}, - "rccl_bfloat16 is not a standard layout type, and thus is " - "incompatible with C."); - -static_assert(std::is_trivial{}, - "rccl_bfloat16 is not a trivial type, and thus is " - "incompatible with C."); - -static_assert(sizeof(rccl_bfloat16) == sizeof(rccl_bfloat16_public) - && offsetof(rccl_bfloat16, data) == offsetof(rccl_bfloat16_public, data), - "internal rccl_bfloat16 does not match public rccl_bfloat16"); - -inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat16& bf16) -{ - return os << float(bf16); -} -inline __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a) -{ - return a; -} -inline __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a) -{ - a.data ^= 0x8000; - return a; -} -inline __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return rccl_bfloat16(float(a) + float(b)); -} -inline __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return rccl_bfloat16(float(a) - float(b)); -} -inline __host__ __device__ rccl_bfloat16 operator*(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return rccl_bfloat16(float(a) * float(b)); -} -inline __host__ __device__ rccl_bfloat16 operator/(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return rccl_bfloat16(float(a) / float(b)); -} -inline __host__ __device__ bool operator<(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return float(a) < float(b); -} -inline __host__ __device__ bool operator==(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return float(a) == float(b); -} -inline __host__ __device__ bool operator>(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return b < a; -} -inline __host__ __device__ bool operator<=(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return !(a > b); -} -inline __host__ __device__ bool operator!=(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return !(a == b); -} -inline __host__ __device__ bool operator>=(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return !(a < b); -} -inline __host__ __device__ rccl_bfloat16& operator+=(rccl_bfloat16& a, rccl_bfloat16 b) -{ - return a = a + b; -} -inline __host__ __device__ rccl_bfloat16& operator-=(rccl_bfloat16& a, rccl_bfloat16 b) -{ - return a = a - b; -} -inline __host__ __device__ rccl_bfloat16& operator*=(rccl_bfloat16& a, rccl_bfloat16 b) -{ - return a = a * b; -} -inline __host__ __device__ rccl_bfloat16& operator/=(rccl_bfloat16& a, rccl_bfloat16 b) -{ - return a = a / b; -} -inline __host__ __device__ rccl_bfloat16& operator++(rccl_bfloat16& a) -{ - return a += rccl_bfloat16(1.0f); -} -inline __host__ __device__ rccl_bfloat16& operator--(rccl_bfloat16& a) -{ - return a -= rccl_bfloat16(1.0f); -} -inline __host__ __device__ rccl_bfloat16 operator++(rccl_bfloat16& a, int) -{ - rccl_bfloat16 orig = a; - ++a; - return orig; -} -inline __host__ __device__ rccl_bfloat16 operator--(rccl_bfloat16& a, int) -{ - rccl_bfloat16 orig = a; - --a; - return orig; -} - -namespace std -{ - constexpr __host__ __device__ bool isinf(rccl_bfloat16 a) - { - return !(~a.data & 0x7f80) && !(a.data & 0x7f); - } - constexpr __host__ __device__ bool isnan(rccl_bfloat16 a) - { - return !(~a.data & 0x7f80) && +(a.data & 0x7f); - } - constexpr __host__ __device__ bool iszero(rccl_bfloat16 a) - { - return !(a.data & 0x7fff); - } - inline rccl_bfloat16 sin(rccl_bfloat16 a) - { - return rccl_bfloat16(sinf(float(a))); - } - inline rccl_bfloat16 cos(rccl_bfloat16 a) - { - return rccl_bfloat16(cosf(float(a))); - } -} - -#endif // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__)) - -#endif // _RCCL_BFLOAT16_H_ diff --git a/src/misc/msccl/msccl_setup.cc b/src/misc/msccl/msccl_setup.cc index e755f0f6a0..611371c21c 100644 --- a/src/misc/msccl/msccl_setup.cc +++ b/src/misc/msccl/msccl_setup.cc @@ -220,7 +220,7 @@ static ncclResult_t hostToDevRedOp( uint64_t u64; half f16; #if defined(RCCL_BFLOAT16) - rccl_bfloat16 bf16; + hip_bfloat16 bf16; #endif #if defined(RCCL_FLOAT8) rccl_float8 fp8_e4m3; @@ -266,7 +266,7 @@ static ncclResult_t hostToDevRedOp( #if defined(RCCL_BFLOAT16) case ncclBfloat16: opFull->op = ncclDevPreMulSum; - bf16 = (rccl_bfloat16)(float(1.0/comm->nRanks)); + bf16 = (hip_bfloat16)(float(1.0/comm->nRanks)); break; #endif #if defined(RCCL_FLOAT8) @@ -325,7 +325,7 @@ static ncclResult_t hostToDevRedOp( MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half, fullOps), \ MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float, fullOps), \ MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double, fullOps), \ - MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps), \ + MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, hip_bfloat16, fullOps), \ MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_float8, fullOps), \ MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat8, fullOps) diff --git a/test/common/PtrUnion.cpp b/test/common/PtrUnion.cpp index ea3aaf922d..7d67a5cdc0 100644 --- a/test/common/PtrUnion.cpp +++ b/test/common/PtrUnion.cpp @@ -167,7 +167,7 @@ namespace RcclUnitTesting case ncclFloat32: F4[idx] = valueF; break; case ncclFloat64: F8[idx] = valueF; break; case ncclFp8E5M2: B1[idx] = rccl_bfloat8(valueF); break; - case ncclBfloat16: B2[idx] = rccl_bfloat16(static_cast(valueF)); break; + case ncclBfloat16: B2[idx] = hip_bfloat16(static_cast(valueF)); break; default: ERROR("Unsupported datatype\n"); return TEST_FAIL; @@ -286,7 +286,7 @@ namespace RcclUnitTesting case ncclFloat32: F4[idx] /= divisor; break; case ncclFloat64: F8[idx] /= divisor; break; case ncclFp8E5M2: B1[idx] = (rccl_bfloat8((float)(B1[idx]) / divisor)); break; - case ncclBfloat16: B2[idx] = (rccl_bfloat16((float)(B2[idx]) / divisor)); break; + case ncclBfloat16: B2[idx] = (hip_bfloat16((float)(B2[idx]) / divisor)); break; default: ERROR("Unsupported datatype\n"); return TEST_FAIL; diff --git a/test/common/PtrUnion.hpp b/test/common/PtrUnion.hpp index 467ef5a53f..bed042f176 100644 --- a/test/common/PtrUnion.hpp +++ b/test/common/PtrUnion.hpp @@ -8,7 +8,7 @@ #include "ErrCode.hpp" #include "rccl/rccl.h" #include "rccl_float8.h" -#include "rccl_bfloat16.h" +#include #include "hip/hip_fp16.h" namespace RcclUnitTesting @@ -48,7 +48,7 @@ namespace RcclUnitTesting float* F4; // ncclFloat32 double* F8; // ncclFloat64 rccl_bfloat8* B1; // ncclFp8E5M2 - rccl_bfloat16* B2; // ncclBfloat16 + hip_bfloat16* B2; // ncclBfloat16 constexpr PtrUnion() : ptr(nullptr) {}