diff --git a/CMakeLists.txt b/CMakeLists.txt index bab23ee786..b6ff4847ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,11 @@ cmake_minimum_required(VERSION 2.8.12) +# We use C++14 features, this will add compile option: -std=c++14 +set( CMAKE_CXX_STANDARD 14 ) +# Without this line, it will add -std=gnu++14 instead, which has some issues. +set( CMAKE_CXX_EXTENSIONS OFF ) + set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") project(rccl CXX) diff --git a/src/collectives/collectives.h b/src/collectives/collectives.h index 63fcfd2017..8e5f6eb83d 100644 --- a/src/collectives/collectives.h +++ b/src/collectives/collectives.h @@ -39,7 +39,8 @@ DECL_COLL3(coll, op, u64) \ DECL_COLL3(coll, op, f16) \ DECL_COLL3(coll, op, f32) \ - DECL_COLL3(coll, op, f64) + DECL_COLL3(coll, op, f64) \ + DECL_COLL3(coll, op, b16) #define DECL_COLL(coll) \ DECL_COLL2(coll, sum) \ diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 863198180e..3df81415bc 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -53,7 +53,8 @@ static inline __device__ void exitIfAbortBarrier(int abort) { NCCL_FUNC4(coll, op, u64), \ NCCL_FUNC4(coll, op, f16), \ NCCL_FUNC4(coll, op, f32), \ - NCCL_FUNC4(coll, op, f64) + NCCL_FUNC4(coll, op, f64), \ + NCCL_FUNC4(coll, op, b16) #define NCCL_FUNCS3B(coll, op) \ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ @@ -63,6 +64,7 @@ static inline __device__ void exitIfAbortBarrier(int abort) { NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8) // Must be consistent with ncclRedOp_t @@ -121,20 +123,20 @@ struct Caller{ inline __device__ void NCCL_CALL_FUNCTIONS(struct ncclColl* const c) noexcept { - if (c->funcIndex < 144) { + if (c->funcIndex < 160) { if (c->funcIndex % 4 == 0) ncclBroadcastRing_copy_i8(&c->args); else if (c->funcIndex % 4 == 1) ncclBroadcastRingLL_copy_i8(&c->args); else if (c->funcIndex % 4 == 2) ncclBroadcastTree_copy_i8(&c->args); else ncclBroadcastTreeLL_copy_i8(&c->args); } - else if (c->funcIndex < 288) Caller<144, 288>::call(c); - else if (c->funcIndex < 432) { + else if (c->funcIndex < 320) Caller<160, 320>::call(c); + else if (c->funcIndex < 480) { if (c->funcIndex % 4 == 0) ncclAllGatherRing_copy_i8(&c->args); else if (c->funcIndex % 4 == 1) ncclAllGatherRingLL_copy_i8(&c->args); else if (c->funcIndex % 4 == 2) ncclAllGatherTree_copy_i8(&c->args); else ncclAllGatherTreeLL_copy_i8(&c->args); } - else Caller<432, 720>::call(c); + else Caller<480, 800>::call(c); } static __device__ void load_parallel(void* dst, void* src, size_t size, int tid, uint32_t* abortCount) { @@ -227,7 +229,8 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ IMPL_COLL3(coll, op, ncclFunc, u64, uint64_t, ncclColl, ncclOp, ncclUint64) \ IMPL_COLL3(coll, op, ncclFunc, f16, half, ncclColl, ncclOp, ncclFloat16) \ IMPL_COLL3(coll, op, ncclFunc, f32, float, ncclColl, ncclOp, ncclFloat32) \ - IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64) + IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64) \ + IMPL_COLL3(coll, op, ncclFunc, b16, rccl_bfloat16, ncclColl, ncclOp, ncclBfloat16) #define COLL_UNROLL 2 diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h index b05e06e7d0..4c2c20bd23 100644 --- a/src/collectives/device/common_kernel.h +++ b/src/collectives/device/common_kernel.h @@ -241,6 +241,18 @@ template<> inline __device__ void vStore(volatile half* ptr, const half val) { ((half*)ptr)[0] = val; } + +template<> inline __device__ +rccl_bfloat16 vFetch(const volatile rccl_bfloat16* ptr) { + rccl_bfloat16 r; + r.data = ptr->data; + return r; +} + +template<> inline __device__ +void vStore(volatile rccl_bfloat16* ptr, const rccl_bfloat16 val) { + ptr->data = val.data; +} #endif typedef ulong2 Pack128; diff --git a/src/collectives/device/reduce_kernel.h b/src/collectives/device/reduce_kernel.h index 4c5caa9f28..b347c5e26d 100644 --- a/src/collectives/device/reduce_kernel.h +++ b/src/collectives/device/reduce_kernel.h @@ -134,6 +134,86 @@ struct FuncMin : private FuncBase { } }; +template<> +struct FuncSum { + static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16); + __device__ PackType operator()(PackType x, PackType y) const + { + union converter { PackType storage; rccl_bfloat16 vec[n]; }; + static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter."); + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + for (auto i = 0u; i != n; ++i) { + cr.vec[i] = cx.vec[i] + cy.vec[i]; + } + return cr.storage; + } + __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { + return x + y; + } +}; + +template<> +struct FuncProd { + static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16); + __device__ PackType operator()(PackType x, PackType y) const + { + union converter { PackType storage; rccl_bfloat16 vec[n]; }; + static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter."); + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + for (auto i = 0u; i != n; ++i) { + cr.vec[i] = cx.vec[i] * cy.vec[i]; + } + return cr.storage; + } + __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { + return x * y; + } +}; + +template<> +struct FuncMax { + static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16); + __device__ PackType operator()(PackType x, PackType y) const + { + union converter { PackType storage; rccl_bfloat16 vec[n]; }; + static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter."); + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + for (auto i = 0u; i != n; ++i) { + cr.vec[i] = cx.vec[i] < cy.vec[i] ? cy.vec[i] : cx.vec[i]; + } + return cr.storage; + } + __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { + return x < y ? y : x; + } +}; + +template<> +struct FuncMin { + static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16); + __device__ PackType operator()(PackType x, PackType y) const + { + union converter { PackType storage; rccl_bfloat16 vec[n]; }; + static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter."); + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + for (auto i = 0u; i != n; ++i) { + cr.vec[i] = cx.vec[i] < cy.vec[i] ? cx.vec[i] : cy.vec[i]; + } + return cr.storage; + } + __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { + return x < y ? x : y; + } +}; + #else template diff --git a/src/enqueue.cc b/src/enqueue.cc index a6eb484d4e..3e72f4ca7d 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -30,7 +30,8 @@ NCCL_FUNC4(coll, op, u64), \ NCCL_FUNC4(coll, op, f16), \ NCCL_FUNC4(coll, op, f32), \ - NCCL_FUNC4(coll, op, f64) + NCCL_FUNC4(coll, op, f64), \ + NCCL_FUNC4(coll, op, b16) #define NCCL_FUNCS3B(coll, op) \ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ @@ -40,6 +41,7 @@ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8) // Must be consistent with ncclRedOp_t -- but we only generate kernel for sums. diff --git a/src/include/core.h b/src/include/core.h index 8a08b914b0..1257c4e31a 100644 --- a/src/include/core.h +++ b/src/include/core.h @@ -48,6 +48,7 @@ static __inline__ int ncclTypeSize(ncclDataType_t type) { case ncclUint8: return 1; case ncclFloat16: + case ncclBfloat16: return 2; case ncclInt32: case ncclUint32: diff --git a/src/include/devcomm.h b/src/include/devcomm.h index 30eccab7b8..3f145cddd1 100644 --- a/src/include/devcomm.h +++ b/src/include/devcomm.h @@ -9,6 +9,7 @@ #define NCCL_DEVICE_H_ #include "nccl.h" +#include "rccl_bfloat16.h" #include // Convert volatile access to atomic diff --git a/src/include/rccl_bfloat16.h b/src/include/rccl_bfloat16.h new file mode 100644 index 0000000000..06b053a626 --- /dev/null +++ b/src/include/rccl_bfloat16.h @@ -0,0 +1,253 @@ +/** + * MIT License + * + * Copyright 2019 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/*!\file + * \brief rccl_bfloat16.h provides struct for rccl_bfloat16 typedef + */ + +#ifndef _RCCL_BFLOAT16_H_ +#define _RCCL_BFLOAT16_H_ + +#if __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__)) + +// If this is a C compiler, C++ compiler below C++14, or a host-only compiler, we only +// include a minimal definition of rccl_bfloat16 + +#include +/*! \brief Struct to represent a 16 bit brain floating point number. */ +typedef struct +{ + uint16_t data; +} rccl_bfloat16; + +#else // __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__)) + +#include +#include +#include +#include +#include +#include + +struct rccl_bfloat16 +{ + uint16_t data; + + __host__ __device__ rccl_bfloat16() = default; + + // round upper 16 bits of IEEE float to convert to bfloat16 + explicit constexpr __host__ __device__ rccl_bfloat16(float f) + : data(float_to_bfloat16(f)) + { + } + + // zero extend lower 16 bits of bfloat16 to convert to IEEE float + constexpr __host__ __device__ operator float() const + { + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(data) << 16}; + return u.fp32; + } + +private: + static constexpr __host__ __device__ uint16_t float_to_bfloat16(float f) + { + union + { + float fp32; + uint32_t int32; + } u = {f}; + if(~u.int32 & 0x7f800000) + { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even + } + else if(u.int32 & 0xffff) + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + u.int32 |= 0x10000; // Preserve signaling NaN + } + return uint16_t(u.int32 >> 16); + } +}; + +typedef struct +{ + uint16_t data; +} rccl_bfloat16_public; + +static_assert(std::is_standard_layout{}, + "rccl_bfloat16 is not a standard layout type, and thus is " + "incompatible with C."); + +static_assert(std::is_trivial{}, + "rccl_bfloat16 is not a trivial type, and thus is " + "incompatible with C."); + +static_assert(sizeof(rccl_bfloat16) == sizeof(rccl_bfloat16_public) + && offsetof(rccl_bfloat16, data) == offsetof(rccl_bfloat16_public, data), + "internal rccl_bfloat16 does not match public rccl_bfloat16"); + +inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat16& bf16) +{ + return os << float(bf16); +} +constexpr __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a) +{ + return a; +} +constexpr __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a) +{ + a.data ^= 0x8000; + return a; +} +constexpr __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return rccl_bfloat16(float(a) + float(b)); +} +constexpr __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return rccl_bfloat16(float(a) - float(b)); +} +constexpr __host__ __device__ rccl_bfloat16 operator*(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return rccl_bfloat16(float(a) * float(b)); +} +constexpr __host__ __device__ rccl_bfloat16 operator/(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return rccl_bfloat16(float(a) / float(b)); +} +constexpr __host__ __device__ bool operator<(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return float(a) < float(b); +} +constexpr __host__ __device__ bool operator==(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return float(a) == float(b); +} +constexpr __host__ __device__ bool operator>(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return b < a; +} +constexpr __host__ __device__ bool operator<=(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return !(a > b); +} +constexpr __host__ __device__ bool operator!=(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return !(a == b); +} +constexpr __host__ __device__ bool operator>=(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return !(a < b); +} +constexpr __host__ __device__ rccl_bfloat16& operator+=(rccl_bfloat16& a, rccl_bfloat16 b) +{ + return a = a + b; +} +constexpr __host__ __device__ rccl_bfloat16& operator-=(rccl_bfloat16& a, rccl_bfloat16 b) +{ + return a = a - b; +} +constexpr __host__ __device__ rccl_bfloat16& operator*=(rccl_bfloat16& a, rccl_bfloat16 b) +{ + return a = a * b; +} +constexpr __host__ __device__ rccl_bfloat16& operator/=(rccl_bfloat16& a, rccl_bfloat16 b) +{ + return a = a / b; +} +constexpr __host__ __device__ rccl_bfloat16& operator++(rccl_bfloat16& a) +{ + return a += rccl_bfloat16(1.0f); +} +constexpr __host__ __device__ rccl_bfloat16& operator--(rccl_bfloat16& a) +{ + return a -= rccl_bfloat16(1.0f); +} +constexpr __host__ __device__ rccl_bfloat16 operator++(rccl_bfloat16& a, int) +{ + rccl_bfloat16 orig = a; + ++a; + return orig; +} +constexpr __host__ __device__ rccl_bfloat16 operator--(rccl_bfloat16& a, int) +{ + rccl_bfloat16 orig = a; + --a; + return orig; +} + +namespace std +{ + constexpr __host__ __device__ bool isinf(rccl_bfloat16 a) + { + return !(~a.data & 0x7f80) && !(a.data & 0x7f); + } + constexpr __host__ __device__ bool isnan(rccl_bfloat16 a) + { + return !(~a.data & 0x7f80) && +(a.data & 0x7f); + } + constexpr __host__ __device__ bool iszero(rccl_bfloat16 a) + { + return !(a.data & 0x7fff); + } + inline rccl_bfloat16 sin(rccl_bfloat16 a) + { + return rccl_bfloat16(sinf(float(a))); + } + inline rccl_bfloat16 cos(rccl_bfloat16 a) + { + return rccl_bfloat16(cosf(float(a))); + } +} + +#endif // __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__)) + +#endif // _RCCL_BFLOAT16_H_ diff --git a/src/nccl.h.in b/src/nccl.h.in index 686ed42406..f555bcb495 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -19,6 +19,8 @@ #define NCCL_VERSION_CODE ${NCCL_VERSION} #define NCCL_VERSION(X,Y,Z) ((X) * 1000 + (Y) * 100 + (Z)) +#define RCCL_BFLOAT16 1 + #ifdef __cplusplus extern "C" { #endif @@ -116,7 +118,8 @@ typedef enum { ncclInt8 = 0, ncclChar = 0, ncclFloat16 = 6, ncclHalf = 6, ncclFloat32 = 7, ncclFloat = 7, ncclFloat64 = 8, ncclDouble = 8, - ncclNumTypes = 9 } ncclDataType_t; + ncclBfloat16 = 9, + ncclNumTypes = 10 } ncclDataType_t; /* * Collective communication operations diff --git a/test/CorrectnessTest.hpp b/test/CorrectnessTest.hpp index e7d3c75382..7440bd5ecc 100644 --- a/test/CorrectnessTest.hpp +++ b/test/CorrectnessTest.hpp @@ -11,6 +11,7 @@ #include #include #include "rccl.h" +#include "../include/rccl_bfloat16.h" #define HIP_CALL(x) ASSERT_EQ(x, hipSuccess) #define NCCL_CALL(x) ASSERT_EQ(x, ncclSuccess) @@ -47,6 +48,7 @@ namespace CorrectnessTests case ncclFloat16: return 2; case ncclFloat32: return 4; case ncclFloat64: return 8; + case ncclBfloat16: return 2; default: fprintf(stderr, "[ERROR] Unsupported datatype (%d)\n", dataType); exit(0); @@ -217,6 +219,7 @@ namespace CorrectnessTests uint64_t* arrayU8 = (uint64_t *)arrayI1; float* arrayF4 = (float *)arrayI1; double* arrayF8 = (double *)arrayI1; + rccl_bfloat16* arrayB2 = (rccl_bfloat16 *)arrayI1; // NOTE: Currently half-precision float tests are unsupported due to half being supported // on GPU only and not host @@ -241,6 +244,7 @@ namespace CorrectnessTests case ncclUint64: arrayU8[j] = valueI; break; case ncclFloat32: arrayF4[j] = valueF; break; case ncclFloat64: arrayF8[j] = valueF; break; + case ncclBfloat16: arrayB2[j] = rccl_bfloat16(valueF); break; default: fprintf(stderr, "[ERROR] Unsupported datatype\n"); exit(0); @@ -278,6 +282,7 @@ namespace CorrectnessTests uint64_t* outputU8 = (uint64_t *)outputI1; float* outputF4 = (float *)outputI1; double* outputF8 = (double *)outputI1; + rccl_bfloat16* outputB2 = (rccl_bfloat16 *)outputI1; bool isMatch = true; @@ -295,6 +300,7 @@ namespace CorrectnessTests uint64_t* expectedU8 = (uint64_t *)expectedI1; float* expectedF4 = (float *)expectedI1; double* expectedF8 = (double *)expectedI1; + rccl_bfloat16* expectedB2 = (rccl_bfloat16 *)expectedI1; for (int j = 0; j < dataset.numElements && isMatch; j++) { @@ -308,6 +314,7 @@ namespace CorrectnessTests case ncclUint64: isMatch &= (outputU8[j] == expectedU8[j]); break; case ncclFloat32: isMatch &= (outputF4[j] == expectedF4[j]); break; case ncclFloat64: isMatch &= (outputF8[j] == expectedF8[j]); break; + case ncclBfloat16: isMatch &= (outputB2[j] == expectedB2[j]); break; default: fprintf(stderr, "[ERROR] Unsupported datatype\n"); exit(0); @@ -333,6 +340,8 @@ namespace CorrectnessTests printf("Expected %f. Output %f on device %d[%d]\n", outputF4[j], expectedF4[j], i, j); break; case ncclFloat64: printf("Expected %lf. Output %lf on device %d[%d]\n", outputF8[j], expectedF8[j], i, j); break; + case ncclBfloat16: + printf("Expected %f. Output %f on device %d[%d]\n", (float)outputB2[j], (float)expectedB2[j], i, j); break; default: fprintf(stderr, "[ERROR] Unsupported datatype\n"); exit(0); diff --git a/test/test_AllGather.cpp b/test/test_AllGather.cpp index 2727514186..b28df48e87 100644 --- a/test/test_AllGather.cpp +++ b/test/test_AllGather.cpp @@ -101,7 +101,8 @@ namespace CorrectnessTests ncclUint64, //ncclFloat16, ncclFloat32, - ncclFloat64), + ncclFloat64, + ncclBfloat16), // Number of elements testing::Values(3072, 3145728), // Number of devices diff --git a/test/test_AllReduce.cpp b/test/test_AllReduce.cpp index 0fb7474d0e..0fd5eedf91 100644 --- a/test/test_AllReduce.cpp +++ b/test/test_AllReduce.cpp @@ -50,7 +50,8 @@ namespace CorrectnessTests ncclUint64, //ncclFloat16, ncclFloat32, - ncclFloat64), + ncclFloat64, + ncclBfloat16), // Number of elements testing::Values(1024, 1048576), // Number of devices diff --git a/test/test_AllReduce.hpp b/test/test_AllReduce.hpp index d8867cb649..a056be95a7 100644 --- a/test/test_AllReduce.hpp +++ b/test/test_AllReduce.hpp @@ -29,6 +29,7 @@ namespace CorrectnessTests uint64_t* resultU8 = (uint64_t *)resultI1; float* resultF4 = (float *)resultI1; double* resultF8 = (double *)resultI1; + rccl_bfloat16* resultB2 = (rccl_bfloat16 *)resultI1; // Initialize the result with the first device's array memcpy(resultI1, dataset.expected[0], dataset.NumBytes()); @@ -44,6 +45,7 @@ namespace CorrectnessTests uint64_t* arrayU8 = (uint64_t *)arrayI1; float* arrayF4 = (float *)arrayI1; double* arrayF8 = (double *)arrayI1; + rccl_bfloat16* arrayB2 = (rccl_bfloat16 *)arrayI1; for (int j = 0; j < dataset.numElements; j++) { @@ -57,6 +59,7 @@ namespace CorrectnessTests case ncclUint64: resultU8[j] = ReduceOp(op, resultU8[j], arrayU8[j]); break; case ncclFloat32: resultF4[j] = ReduceOp(op, resultF4[j], arrayF4[j]); break; case ncclFloat64: resultF8[j] = ReduceOp(op, resultF8[j], arrayF8[j]); break; + case ncclBfloat16: resultB2[j] = ReduceOp(op, resultB2[j], arrayB2[j]); break; default: fprintf(stderr, "[ERROR] Unsupported datatype\n"); exit(0); diff --git a/test/test_Broadcast.cpp b/test/test_Broadcast.cpp index 3ed6964785..0e728b0153 100644 --- a/test/test_Broadcast.cpp +++ b/test/test_Broadcast.cpp @@ -59,7 +59,8 @@ namespace CorrectnessTests ncclUint64, //ncclFloat16, ncclFloat32, - ncclFloat64), + ncclFloat64, + ncclBfloat16), // Number of elements testing::Values(1024, 1048576), // Number of devices diff --git a/test/test_CombinedCalls.cpp b/test/test_CombinedCalls.cpp index bdbf55bc20..4b51ab5375 100644 --- a/test/test_CombinedCalls.cpp +++ b/test/test_CombinedCalls.cpp @@ -89,7 +89,8 @@ namespace CorrectnessTests ncclUint64, //ncclFloat16, ncclFloat32, - ncclFloat64), + ncclFloat64, + ncclBfloat16), // Number of elements testing::Values(3072, 3145728), // Number of devices diff --git a/test/test_GroupCalls.cpp b/test/test_GroupCalls.cpp index de1ad0bd76..77780b633d 100644 --- a/test/test_GroupCalls.cpp +++ b/test/test_GroupCalls.cpp @@ -110,7 +110,8 @@ namespace CorrectnessTests ncclUint64, //ncclFloat16, ncclFloat32, - ncclFloat64), + ncclFloat64, + ncclBfloat16), // Number of elements testing::Values(3072, 3145728), // Number of devices diff --git a/test/test_Reduce.cpp b/test/test_Reduce.cpp index dfca79ccf4..9844e928c1 100644 --- a/test/test_Reduce.cpp +++ b/test/test_Reduce.cpp @@ -58,7 +58,8 @@ namespace CorrectnessTests ncclUint64, //ncclFloat16, ncclFloat32, - ncclFloat64), + ncclFloat64, + ncclBfloat16), // Number of elements testing::Values(1024, 1048576), // Number of devices diff --git a/test/test_Reduce.hpp b/test/test_Reduce.hpp index 3ab9d66b44..0303596560 100644 --- a/test/test_Reduce.hpp +++ b/test/test_Reduce.hpp @@ -29,6 +29,7 @@ namespace CorrectnessTests uint64_t* resultU8 = (uint64_t *)resultI1; float* resultF4 = (float *)resultI1; double* resultF8 = (double *)resultI1; + rccl_bfloat16* resultB2 = (rccl_bfloat16 *)resultI1; // Initialize the result with the first device's array memcpy(resultI1, dataset.expected[0], dataset.NumBytes()); @@ -44,6 +45,7 @@ namespace CorrectnessTests uint64_t* arrayU8 = (uint64_t *)arrayI1; float* arrayF4 = (float *)arrayI1; double* arrayF8 = (double *)arrayI1; + rccl_bfloat16* arrayB2 = (rccl_bfloat16 *)arrayI1; for (int j = 0; j < dataset.numElements; j++) { @@ -57,6 +59,7 @@ namespace CorrectnessTests case ncclUint64: resultU8[j] = ReduceOp(op, resultU8[j], arrayU8[j]); break; case ncclFloat32: resultF4[j] = ReduceOp(op, resultF4[j], arrayF4[j]); break; case ncclFloat64: resultF8[j] = ReduceOp(op, resultF8[j], arrayF8[j]); break; + case ncclBfloat16: resultB2[j] = ReduceOp(op, resultB2[j], arrayB2[j]); break; default: fprintf(stderr, "[ERROR] Unsupported datatype\n"); exit(0); diff --git a/test/test_ReduceScatter.cpp b/test/test_ReduceScatter.cpp index d55b514689..11007732b5 100644 --- a/test/test_ReduceScatter.cpp +++ b/test/test_ReduceScatter.cpp @@ -57,7 +57,8 @@ namespace CorrectnessTests ncclUint64, //ncclFloat16, ncclFloat32, - ncclFloat64), + ncclFloat64, + ncclBfloat16), // Number of elements testing::Values(3072, 3145728), // Number of devices diff --git a/test/test_ReduceScatter.hpp b/test/test_ReduceScatter.hpp index a1731f13b2..3a7843e2c1 100644 --- a/test/test_ReduceScatter.hpp +++ b/test/test_ReduceScatter.hpp @@ -29,6 +29,7 @@ namespace CorrectnessTests uint64_t* resultU8 = (uint64_t *)resultI1; float* resultF4 = (float *)resultI1; double* resultF8 = (double *)resultI1; + rccl_bfloat16* resultB2 = (rccl_bfloat16 *)resultI1; // Initialize the result with the first device's array memcpy(resultI1, dataset.expected[0], dataset.NumBytes()); @@ -44,6 +45,7 @@ namespace CorrectnessTests uint64_t* arrayU8 = (uint64_t *)arrayI1; float* arrayF4 = (float *)arrayI1; double* arrayF8 = (double *)arrayI1; + rccl_bfloat16* arrayB2 = (rccl_bfloat16 *)arrayI1; for (int j = 0; j < dataset.numElements; j++) { @@ -57,6 +59,7 @@ namespace CorrectnessTests case ncclUint64: resultU8[j] = ReduceOp(op, resultU8[j], arrayU8[j]); break; case ncclFloat32: resultF4[j] = ReduceOp(op, resultF4[j], arrayF4[j]); break; case ncclFloat64: resultF8[j] = ReduceOp(op, resultF8[j], arrayF8[j]); break; + case ncclBfloat16: resultB2[j] = ReduceOp(op, resultB2[j], arrayB2[j]); break; default: fprintf(stderr, "[ERROR] Unsupported datatype\n"); exit(0);