replacing rccl_bfloat16 with hip_bfloat16 (#1126)

Co-authored-by: mberenjk <mberenjk@amd.com>

[ROCm/rccl commit: 428837ffe4]
Tento commit je obsažen v:
mberenjk
2024-04-11 11:30:37 -05:00
odevzdal GitHub
rodič 165d51b255
revize da835cff9c
12 změnil soubory, kde provedl 30 přidání a 306 odebrání
-2
Zobrazit soubor
@@ -369,7 +369,6 @@ set(SRC_FILES
src/include/param.h
src/include/profiler.h
src/include/proxy.h
src/include/rccl_bfloat16.h
src/include/rccl_vars.h
src/include/rccl_float8.h
src/include/rocm_smi_wrap.h
@@ -422,7 +421,6 @@ set(SRC_FILES
src/include/param.h
src/include/profiler.h
src/include/proxy.h
src/include/rccl_bfloat16.h
src/include/rccl_vars.h
src/include/rocm_smi_wrap.h
src/include/rocmwrap.h
+3 -3
Zobrazit soubor
@@ -25,9 +25,9 @@ set(ALL_COLLS "AllGather" "AllReduce" "AllToAllPivot" "Broadcast" "Reduce" "Redu
set(ALL_ALGOS "TREE" "RING" "COLLNET_DIRECT" "COLLNET_CHAIN")
set(ALL_PROTOS "LL" "LL128" "SIMPLE")
set(ALL_REDOPS "Sum" "Prod" "MinMax" "PreMulSum" "SumPostDiv")
set(ALL_TYPES "int8_t" "uint8_t" "int32_t" "uint32_t" "int64_t" "uint64_t" "half" "float" "double" "rccl_bfloat16" "rccl_float8" "rccl_bfloat8")
set(ALL_TYPES "int8_t" "uint8_t" "int32_t" "uint32_t" "int64_t" "uint64_t" "half" "float" "double" "hip_bfloat16" "rccl_float8" "rccl_bfloat8")
set(FLOATS_LIST "half" "float" "double" "rccl_bfloat16" "rccl_float8" "rccl_bfloat8")
set(FLOATS_LIST "half" "float" "double" "hip_bfloat16" "rccl_float8" "rccl_bfloat8")
################################################################################
# The command line argument is used as a regex to filter the functions
@@ -435,4 +435,4 @@ function(gen_functions CONFIG_INPUT)
gen_host_table() ## Generate host_table.cpp
set(HIP_SOURCES ${HIP_SOURCES} PARENT_SCOPE)
endfunction()
endfunction()
+1 -1
Zobrazit soubor
@@ -414,7 +414,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, hip_bfloat16, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_float8, fullOps) \
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat8, fullOps)
+1 -1
Zobrazit soubor
@@ -65,7 +65,7 @@ ncclResult_t ncclLaunchOneRank(void* dst, void const* src, size_t nElts, struct
case ncclUint64: kernel = (void const*)&oneRankReduce<FuncPreMulSum<uint64_t>>; break;
case ncclFloat16: kernel = (void const*)&oneRankReduce<FuncPreMulSum<half>>; break;
#if defined(RCCL_BFLOAT16)
case ncclBfloat16: kernel = (void const*)&oneRankReduce<FuncPreMulSum<rccl_bfloat16>>; break;
case ncclBfloat16: kernel = (void const*)&oneRankReduce<FuncPreMulSum<hip_bfloat16>>; break;
#endif
#if defined(RCCL_FLOAT8)
case ncclFp8E4M3: kernel = (void const*)&oneRankReduce<FuncPreMulSum<rccl_float8>>; break;
+14 -14
Zobrazit soubor
@@ -22,7 +22,7 @@ template<>
struct IsFloatingPoint<half>: std::true_type {};
#if defined(RCCL_BFLOAT16)
template<>
struct IsFloatingPoint<rccl_bfloat16>: std::true_type {};
struct IsFloatingPoint<hip_bfloat16>: std::true_type {};
#endif
#if defined(RCCL_FLOAT8)
template<>
@@ -257,9 +257,9 @@ SPECIALIZE_REDUCE(FuncMinMax, half, 1, half, fn.isMinNotMax ? __hmin(x, y) : __h
SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 1, __nv_bfloat16, fn.isMinNotMax ? __hmin(x, y) : __hmax(x, y))
SPECIALIZE_REDUCE(FuncMinMax, __nv_bfloat16, 2, __nv_bfloat162, fn.isMinNotMax ? __hmin2(x, y) : __hmax2(x, y))
#else
SPECIALIZE_REDUCE(FuncSum, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)((float)(x) + (float)(y)))
SPECIALIZE_REDUCE(FuncProd, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)((float)(x) * (float)(y)))
SPECIALIZE_REDUCE(FuncMinMax, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)(fn.isMinNotMax ? fminf((float)(x), (float)(y)) : fmaxf((float)(x), (float)(y))))
SPECIALIZE_REDUCE(FuncSum, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) + (float)(y)))
SPECIALIZE_REDUCE(FuncProd, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)((float)(x) * (float)(y)))
SPECIALIZE_REDUCE(FuncMinMax, hip_bfloat16, 1, hip_bfloat16, (hip_bfloat16)(fn.isMinNotMax ? fminf((float)(x), (float)(y)) : fmaxf((float)(x), (float)(y))))
#endif
#endif
@@ -386,8 +386,8 @@ struct FuncPreMulSum<half> {
#if defined(RCCL_BFLOAT16)
template<>
struct FuncPreMulSum<rccl_bfloat16> {
using EltType = rccl_bfloat16;
struct FuncPreMulSum<hip_bfloat16> {
using EltType = hip_bfloat16;
#if __CUDA_ARCH__ >= 800
__nv_bfloat162 scalar;
__device__ FuncPreMulSum(uint64_t opArg=0) {
@@ -399,7 +399,7 @@ struct FuncPreMulSum<half> {
#else
float scalar;
__device__ FuncPreMulSum(uint64_t opArg=0) {
union { uint64_t u64; rccl_bfloat16 val; };
union { uint64_t u64; hip_bfloat16 val; };
u64 = opArg;
scalar = (float)(val);
}
@@ -481,21 +481,21 @@ struct Apply_PreOp<FuncPreMulSum<half>, /*EltPerPack=*/1> {
#if defined(RCCL_BFLOAT16)
template<>
struct Apply_PreOp<FuncPreMulSum<rccl_bfloat16>, /*EltPerPack=*/1> {
struct Apply_PreOp<FuncPreMulSum<hip_bfloat16>, /*EltPerPack=*/1> {
static constexpr bool IsIdentity = false;
__device__ static BytePack<sizeof(rccl_bfloat16)> preOp(
FuncPreMulSum<rccl_bfloat16> fn, BytePack<sizeof(rccl_bfloat16)> a
__device__ static BytePack<sizeof(hip_bfloat16)> preOp(
FuncPreMulSum<hip_bfloat16> fn, BytePack<sizeof(hip_bfloat16)> a
) {
#if __CUDA_ARCH__ >= 800
return toPack<__nv_bfloat16>(__hmul(fromPack<__nv_bfloat16>(a), fn.scalar.x));
#else
return toPack<rccl_bfloat16>((rccl_bfloat16)((float)(fromPack<rccl_bfloat16>(a)) * fn.scalar));
return toPack<hip_bfloat16>((hip_bfloat16)((float)(fromPack<hip_bfloat16>(a)) * fn.scalar));
#endif
}
};
#if __CUDA_ARCH__ >= 800
template<>
struct Apply_PreOp<FuncPreMulSum<rccl_bfloat16>, /*EltPerPack=*/2> {
struct Apply_PreOp<FuncPreMulSum<hip_bfloat16>, /*EltPerPack=*/2> {
static constexpr bool IsIdentity = false;
__device__ static BytePack<sizeof(__nv_bfloat162)> preOp(
FuncPreMulSum<__nv_bfloat16> fn, BytePack<sizeof(__nv_bfloat162)> a
@@ -732,8 +732,8 @@ struct Apply_LoadMultimem {
DEFINE_Apply_LoadMultimem_minmax_v4x2_and_subhalf(half, f16x2, u32)
#if defined(RCCL_BFLOAT16)
DEFINE_Apply_LoadMultimem_sum_v4x2_and_subhalf(rccl_bfloat16, bf16x2, u32)
DEFINE_Apply_LoadMultimem_minmax_v4x2_and_subhalf(rccl_bfloat16, bf16x2, u32)
DEFINE_Apply_LoadMultimem_sum_v4x2_and_subhalf(hip_bfloat16, bf16x2, u32)
DEFINE_Apply_LoadMultimem_minmax_v4x2_and_subhalf(hip_bfloat16, bf16x2, u32)
#endif
#else
template<typename Fn>
+2 -2
Zobrazit soubor
@@ -1558,7 +1558,7 @@ static ncclResult_t hostToDevRedOp(
float f32;
double f64;
#if defined(RCCL_BFLOAT16)
rccl_bfloat16 bf16;
hip_bfloat16 bf16;
#endif
#if defined(RCCL_FLOAT8)
rccl_float8 fp8_e4m3;
@@ -1602,7 +1602,7 @@ static ncclResult_t hostToDevRedOp(
#if defined(RCCL_BFLOAT16)
case ncclBfloat16:
opFull->op = ncclDevPreMulSum;
bf16 = (rccl_bfloat16)(float(1.0/comm->nRanks));
bf16 = (hip_bfloat16)(float(1.0/comm->nRanks));
break;
#endif
#if defined(RCCL_FLOAT8)
+1 -1
Zobrazit soubor
@@ -11,7 +11,7 @@
#include "nccl.h"
#include "rccl_float8.h"
#include "rccl_bfloat16.h"
#include <hip/hip_bfloat16.h>
#include "nccl_common.h"
#include "align.h"
#include "collectives.h"
+1 -1
Zobrazit soubor
@@ -26,7 +26,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, proto, fullOps)(struct n
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, hip_bfloat16, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_float8, fullOps) \
MSCCL_DECL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat8, fullOps)
-274
Zobrazit soubor
@@ -1,274 +0,0 @@
/**
* MIT License
*
* Copyright 2019-2020 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
/*!\file
* \brief rccl_bfloat16.h provides struct for rccl_bfloat16 typedef
*/
#ifndef _RCCL_BFLOAT16_H_
#define _RCCL_BFLOAT16_H_
#if __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__) && !defined(__HIP_PLATFORM_HCC__))
// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
// include a minimal definition of rccl_bfloat16
#include <stdint.h>
/*! \brief Struct to represent a 16 bit brain floating point number. */
typedef struct
{
uint16_t data;
} rccl_bfloat16;
#else // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__) && !defined(__HIP_PLATFORM_HCC__))
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <hip/hip_runtime.h>
#include <ostream>
#include <type_traits>
struct rccl_bfloat16
{
uint16_t data;
enum truncate_t
{
truncate
};
__host__ __device__ rccl_bfloat16() = default;
// round upper 16 bits of IEEE float to convert to bfloat16
explicit __host__ __device__ rccl_bfloat16(float f)
: data(float_to_bfloat16(f))
{
}
explicit __host__ __device__ rccl_bfloat16(float f, truncate_t)
: data(truncate_float_to_bfloat16(f))
{
}
// zero extend lower 16 bits of bfloat16 to convert to IEEE float
__host__ __device__ operator float() const
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(data) << 16};
return u.fp32;
}
private:
static __host__ __device__ uint16_t float_to_bfloat16(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
}
// Truncate instead of rounding, preserving SNaN
static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}
};
typedef struct
{
uint16_t data;
} rccl_bfloat16_public;
static_assert(std::is_standard_layout<rccl_bfloat16>{},
"rccl_bfloat16 is not a standard layout type, and thus is "
"incompatible with C.");
static_assert(std::is_trivial<rccl_bfloat16>{},
"rccl_bfloat16 is not a trivial type, and thus is "
"incompatible with C.");
static_assert(sizeof(rccl_bfloat16) == sizeof(rccl_bfloat16_public)
&& offsetof(rccl_bfloat16, data) == offsetof(rccl_bfloat16_public, data),
"internal rccl_bfloat16 does not match public rccl_bfloat16");
inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat16& bf16)
{
return os << float(bf16);
}
inline __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a)
{
return a;
}
inline __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a)
{
a.data ^= 0x8000;
return a;
}
inline __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a, rccl_bfloat16 b)
{
return rccl_bfloat16(float(a) + float(b));
}
inline __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a, rccl_bfloat16 b)
{
return rccl_bfloat16(float(a) - float(b));
}
inline __host__ __device__ rccl_bfloat16 operator*(rccl_bfloat16 a, rccl_bfloat16 b)
{
return rccl_bfloat16(float(a) * float(b));
}
inline __host__ __device__ rccl_bfloat16 operator/(rccl_bfloat16 a, rccl_bfloat16 b)
{
return rccl_bfloat16(float(a) / float(b));
}
inline __host__ __device__ bool operator<(rccl_bfloat16 a, rccl_bfloat16 b)
{
return float(a) < float(b);
}
inline __host__ __device__ bool operator==(rccl_bfloat16 a, rccl_bfloat16 b)
{
return float(a) == float(b);
}
inline __host__ __device__ bool operator>(rccl_bfloat16 a, rccl_bfloat16 b)
{
return b < a;
}
inline __host__ __device__ bool operator<=(rccl_bfloat16 a, rccl_bfloat16 b)
{
return !(a > b);
}
inline __host__ __device__ bool operator!=(rccl_bfloat16 a, rccl_bfloat16 b)
{
return !(a == b);
}
inline __host__ __device__ bool operator>=(rccl_bfloat16 a, rccl_bfloat16 b)
{
return !(a < b);
}
inline __host__ __device__ rccl_bfloat16& operator+=(rccl_bfloat16& a, rccl_bfloat16 b)
{
return a = a + b;
}
inline __host__ __device__ rccl_bfloat16& operator-=(rccl_bfloat16& a, rccl_bfloat16 b)
{
return a = a - b;
}
inline __host__ __device__ rccl_bfloat16& operator*=(rccl_bfloat16& a, rccl_bfloat16 b)
{
return a = a * b;
}
inline __host__ __device__ rccl_bfloat16& operator/=(rccl_bfloat16& a, rccl_bfloat16 b)
{
return a = a / b;
}
inline __host__ __device__ rccl_bfloat16& operator++(rccl_bfloat16& a)
{
return a += rccl_bfloat16(1.0f);
}
inline __host__ __device__ rccl_bfloat16& operator--(rccl_bfloat16& a)
{
return a -= rccl_bfloat16(1.0f);
}
inline __host__ __device__ rccl_bfloat16 operator++(rccl_bfloat16& a, int)
{
rccl_bfloat16 orig = a;
++a;
return orig;
}
inline __host__ __device__ rccl_bfloat16 operator--(rccl_bfloat16& a, int)
{
rccl_bfloat16 orig = a;
--a;
return orig;
}
namespace std
{
constexpr __host__ __device__ bool isinf(rccl_bfloat16 a)
{
return !(~a.data & 0x7f80) && !(a.data & 0x7f);
}
constexpr __host__ __device__ bool isnan(rccl_bfloat16 a)
{
return !(~a.data & 0x7f80) && +(a.data & 0x7f);
}
constexpr __host__ __device__ bool iszero(rccl_bfloat16 a)
{
return !(a.data & 0x7fff);
}
inline rccl_bfloat16 sin(rccl_bfloat16 a)
{
return rccl_bfloat16(sinf(float(a)));
}
inline rccl_bfloat16 cos(rccl_bfloat16 a)
{
return rccl_bfloat16(cosf(float(a)));
}
}
#endif // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__))
#endif // _RCCL_BFLOAT16_H_
+3 -3
Zobrazit soubor
@@ -220,7 +220,7 @@ static ncclResult_t hostToDevRedOp(
uint64_t u64;
half f16;
#if defined(RCCL_BFLOAT16)
rccl_bfloat16 bf16;
hip_bfloat16 bf16;
#endif
#if defined(RCCL_FLOAT8)
rccl_float8 fp8_e4m3;
@@ -266,7 +266,7 @@ static ncclResult_t hostToDevRedOp(
#if defined(RCCL_BFLOAT16)
case ncclBfloat16:
opFull->op = ncclDevPreMulSum;
bf16 = (rccl_bfloat16)(float(1.0/comm->nRanks));
bf16 = (hip_bfloat16)(float(1.0/comm->nRanks));
break;
#endif
#if defined(RCCL_FLOAT8)
@@ -325,7 +325,7 @@ static ncclResult_t hostToDevRedOp(
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, hip_bfloat16, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_float8, fullOps), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat8, fullOps)
+2 -2
Zobrazit soubor
@@ -167,7 +167,7 @@ namespace RcclUnitTesting
case ncclFloat32: F4[idx] = valueF; break;
case ncclFloat64: F8[idx] = valueF; break;
case ncclFp8E5M2: B1[idx] = rccl_bfloat8(valueF); break;
case ncclBfloat16: B2[idx] = rccl_bfloat16(static_cast<float>(valueF)); break;
case ncclBfloat16: B2[idx] = hip_bfloat16(static_cast<float>(valueF)); break;
default:
ERROR("Unsupported datatype\n");
return TEST_FAIL;
@@ -286,7 +286,7 @@ namespace RcclUnitTesting
case ncclFloat32: F4[idx] /= divisor; break;
case ncclFloat64: F8[idx] /= divisor; break;
case ncclFp8E5M2: B1[idx] = (rccl_bfloat8((float)(B1[idx]) / divisor)); break;
case ncclBfloat16: B2[idx] = (rccl_bfloat16((float)(B2[idx]) / divisor)); break;
case ncclBfloat16: B2[idx] = (hip_bfloat16((float)(B2[idx]) / divisor)); break;
default:
ERROR("Unsupported datatype\n");
return TEST_FAIL;
+2 -2
Zobrazit soubor
@@ -8,7 +8,7 @@
#include "ErrCode.hpp"
#include "rccl/rccl.h"
#include "rccl_float8.h"
#include "rccl_bfloat16.h"
#include <hip/hip_bfloat16.h>
#include "hip/hip_fp16.h"
namespace RcclUnitTesting
@@ -48,7 +48,7 @@ namespace RcclUnitTesting
float* F4; // ncclFloat32
double* F8; // ncclFloat64
rccl_bfloat8* B1; // ncclFp8E5M2
rccl_bfloat16* B2; // ncclBfloat16
hip_bfloat16* B2; // ncclBfloat16
constexpr PtrUnion() : ptr(nullptr) {}