From 5e109ed400d74d87bf73e573b5bc14ddafba55a3 Mon Sep 17 00:00:00 2001 From: Wenkai Du Date: Fri, 15 Nov 2019 10:39:48 -0800 Subject: [PATCH] Add bfloat16 support in RCCL Preprocessor symbol RCCL_BFLOAT16 is used as feature indicator --- CMakeLists.txt | 5 + src/collectives/collectives.h | 3 +- src/collectives/device/common.h | 15 +- src/collectives/device/common_kernel.h | 12 ++ src/collectives/device/reduce_kernel.h | 80 ++++++++ src/enqueue.cc | 4 +- src/include/core.h | 1 + src/include/devcomm.h | 1 + src/include/rccl_bfloat16.h | 253 +++++++++++++++++++++++++ src/nccl.h.in | 5 +- 10 files changed, 370 insertions(+), 9 deletions(-) create mode 100644 src/include/rccl_bfloat16.h diff --git a/CMakeLists.txt b/CMakeLists.txt index bab23ee786..b6ff4847ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,11 @@ cmake_minimum_required(VERSION 2.8.12) +# We use C++14 features, this will add compile option: -std=c++14 +set( CMAKE_CXX_STANDARD 14 ) +# Without this line, it will add -std=gnu++14 instead, which has some issues. 
+set( CMAKE_CXX_EXTENSIONS OFF ) + set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") project(rccl CXX) diff --git a/src/collectives/collectives.h b/src/collectives/collectives.h index 63fcfd2017..8e5f6eb83d 100644 --- a/src/collectives/collectives.h +++ b/src/collectives/collectives.h @@ -39,7 +39,8 @@ DECL_COLL3(coll, op, u64) \ DECL_COLL3(coll, op, f16) \ DECL_COLL3(coll, op, f32) \ - DECL_COLL3(coll, op, f64) + DECL_COLL3(coll, op, f64) \ + DECL_COLL3(coll, op, b16) #define DECL_COLL(coll) \ DECL_COLL2(coll, sum) \ diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 863198180e..3df81415bc 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -53,7 +53,8 @@ static inline __device__ void exitIfAbortBarrier(int abort) { NCCL_FUNC4(coll, op, u64), \ NCCL_FUNC4(coll, op, f16), \ NCCL_FUNC4(coll, op, f32), \ - NCCL_FUNC4(coll, op, f64) + NCCL_FUNC4(coll, op, f64), \ + NCCL_FUNC4(coll, op, b16) #define NCCL_FUNCS3B(coll, op) \ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ @@ -63,6 +64,7 @@ static inline __device__ void exitIfAbortBarrier(int abort) { NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8) // Must be consistent with ncclRedOp_t @@ -121,20 +123,20 @@ struct Caller{ inline __device__ void NCCL_CALL_FUNCTIONS(struct ncclColl* const c) noexcept { - if (c->funcIndex < 144) { + if (c->funcIndex < 160) { if (c->funcIndex % 4 == 0) ncclBroadcastRing_copy_i8(&c->args); else if (c->funcIndex % 4 == 1) ncclBroadcastRingLL_copy_i8(&c->args); else if (c->funcIndex % 4 == 2) ncclBroadcastTree_copy_i8(&c->args); else ncclBroadcastTreeLL_copy_i8(&c->args); } - else if (c->funcIndex < 288) Caller<144, 288>::call(c); - else if (c->funcIndex < 432) { + else if (c->funcIndex < 320) Caller<160, 320>::call(c); + else if (c->funcIndex < 480) { if (c->funcIndex % 4 == 0) ncclAllGatherRing_copy_i8(&c->args); else if 
(c->funcIndex % 4 == 1) ncclAllGatherRingLL_copy_i8(&c->args); else if (c->funcIndex % 4 == 2) ncclAllGatherTree_copy_i8(&c->args); else ncclAllGatherTreeLL_copy_i8(&c->args); } - else Caller<432, 720>::call(c); + else Caller<480, 800>::call(c); } static __device__ void load_parallel(void* dst, void* src, size_t size, int tid, uint32_t* abortCount) { @@ -227,7 +229,8 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ IMPL_COLL3(coll, op, ncclFunc, u64, uint64_t, ncclColl, ncclOp, ncclUint64) \ IMPL_COLL3(coll, op, ncclFunc, f16, half, ncclColl, ncclOp, ncclFloat16) \ IMPL_COLL3(coll, op, ncclFunc, f32, float, ncclColl, ncclOp, ncclFloat32) \ - IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64) + IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64) \ + IMPL_COLL3(coll, op, ncclFunc, b16, rccl_bfloat16, ncclColl, ncclOp, ncclBfloat16) #define COLL_UNROLL 2 diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h index b05e06e7d0..4c2c20bd23 100644 --- a/src/collectives/device/common_kernel.h +++ b/src/collectives/device/common_kernel.h @@ -241,6 +241,18 @@ template<> inline __device__ void vStore(volatile half* ptr, const half val) { ((half*)ptr)[0] = val; } + +template<> inline __device__ +rccl_bfloat16 vFetch(const volatile rccl_bfloat16* ptr) { + rccl_bfloat16 r; + r.data = ptr->data; + return r; +} + +template<> inline __device__ +void vStore(volatile rccl_bfloat16* ptr, const rccl_bfloat16 val) { + ptr->data = val.data; +} #endif typedef ulong2 Pack128; diff --git a/src/collectives/device/reduce_kernel.h b/src/collectives/device/reduce_kernel.h index 4c5caa9f28..b347c5e26d 100644 --- a/src/collectives/device/reduce_kernel.h +++ b/src/collectives/device/reduce_kernel.h @@ -134,6 +134,86 @@ struct FuncMin : private FuncBase { } }; +template<> +struct FuncSum<rccl_bfloat16> { // adds the rccl_bfloat16 elements packed in a PackType, element by element + static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16); + __device__ PackType 
operator()(PackType x, PackType y) const + { + union converter { PackType storage; rccl_bfloat16 vec[n]; }; + static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter."); + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + for (auto i = 0u; i != n; ++i) { + cr.vec[i] = cx.vec[i] + cy.vec[i]; + } + return cr.storage; + } + __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { + return x + y; + } +}; + +template<> +struct FuncProd<rccl_bfloat16> { // multiplies the rccl_bfloat16 elements packed in a PackType, element by element + static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16); + __device__ PackType operator()(PackType x, PackType y) const + { + union converter { PackType storage; rccl_bfloat16 vec[n]; }; + static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter."); + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + for (auto i = 0u; i != n; ++i) { + cr.vec[i] = cx.vec[i] * cy.vec[i]; + } + return cr.storage; + } + __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { + return x * y; + } +}; + +template<> +struct FuncMax<rccl_bfloat16> { // selects y over x only where x < y, for each packed rccl_bfloat16 element + static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16); + __device__ PackType operator()(PackType x, PackType y) const + { + union converter { PackType storage; rccl_bfloat16 vec[n]; }; + static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter."); + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + for (auto i = 0u; i != n; ++i) { + cr.vec[i] = cx.vec[i] < cy.vec[i] ? cy.vec[i] : cx.vec[i]; + } + return cr.storage; + } + __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { + return x < y ? 
y : x; + } +}; + +template<> +struct FuncMin<rccl_bfloat16> { // selects x over y only where x < y, for each packed rccl_bfloat16 element + static constexpr auto n = sizeof(PackType) / sizeof(rccl_bfloat16); + __device__ PackType operator()(PackType x, PackType y) const + { + union converter { PackType storage; rccl_bfloat16 vec[n]; }; + static_assert(sizeof(PackType) == sizeof(converter), "PackType must be the same size of converter."); + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + for (auto i = 0u; i != n; ++i) { + cr.vec[i] = cx.vec[i] < cy.vec[i] ? cx.vec[i] : cy.vec[i]; + } + return cr.storage; + } + __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { + return x < y ? x : y; + } +}; + #else template diff --git a/src/enqueue.cc b/src/enqueue.cc index a6eb484d4e..3e72f4ca7d 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -30,7 +30,8 @@ NCCL_FUNC4(coll, op, u64), \ NCCL_FUNC4(coll, op, f16), \ NCCL_FUNC4(coll, op, f32), \ - NCCL_FUNC4(coll, op, f64) + NCCL_FUNC4(coll, op, f64), \ + NCCL_FUNC4(coll, op, b16) #define NCCL_FUNCS3B(coll, op) \ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ @@ -40,6 +41,7 @@ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ NCCL_FUNC4(coll, op, i8) // Must be consistent with ncclRedOp_t -- but we only generate kernel for sums. 
diff --git a/src/include/core.h b/src/include/core.h index 8a08b914b0..1257c4e31a 100644 --- a/src/include/core.h +++ b/src/include/core.h @@ -48,6 +48,7 @@ static __inline__ int ncclTypeSize(ncclDataType_t type) { case ncclUint8: return 1; case ncclFloat16: + case ncclBfloat16: return 2; case ncclInt32: case ncclUint32: diff --git a/src/include/devcomm.h b/src/include/devcomm.h index 30eccab7b8..3f145cddd1 100644 --- a/src/include/devcomm.h +++ b/src/include/devcomm.h @@ -9,6 +9,7 @@ #define NCCL_DEVICE_H_ #include "nccl.h" +#include "rccl_bfloat16.h" #include // Convert volatile access to atomic diff --git a/src/include/rccl_bfloat16.h b/src/include/rccl_bfloat16.h new file mode 100644 index 0000000000..06b053a626 --- /dev/null +++ b/src/include/rccl_bfloat16.h @@ -0,0 +1,253 @@ +/** + * MIT License + * + * Copyright 2019 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +/*!\file + * \brief rccl_bfloat16.h provides struct for rccl_bfloat16 typedef + */ + +#ifndef _RCCL_BFLOAT16_H_ +#define _RCCL_BFLOAT16_H_ + +#if __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__)) + +// If this is a C compiler, C++ compiler below C++14, or a host-only compiler, we only +// include a minimal definition of rccl_bfloat16 + +#include +/*! \brief Struct to represent a 16 bit brain floating point number. */ +typedef struct +{ + uint16_t data; +} rccl_bfloat16; + +#else // __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__)) + +#include +#include +#include +#include +#include +#include + +struct rccl_bfloat16 +{ + uint16_t data; + + __host__ __device__ rccl_bfloat16() = default; + + // round upper 16 bits of IEEE float to convert to bfloat16 + explicit constexpr __host__ __device__ rccl_bfloat16(float f) + : data(float_to_bfloat16(f)) + { + } + + // zero extend lower 16 bits of bfloat16 to convert to IEEE float + constexpr __host__ __device__ operator float() const + { + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(data) << 16}; + return u.fp32; + } + +private: + static constexpr __host__ __device__ uint16_t float_to_bfloat16(float f) + { + union + { + float fp32; + uint32_t int32; + } u = {f}; + if(~u.int32 & 0x7f800000) + { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. 
If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even + } + else if(u.int32 & 0xffff) + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bfloat16's mantissa bits are all 0. 
+ u.int32 |= 0x10000; // Preserve signaling NaN + } + return uint16_t(u.int32 >> 16); + } +}; + +typedef struct +{ + uint16_t data; +} rccl_bfloat16_public; + +static_assert(std::is_standard_layout<rccl_bfloat16>{}, + "rccl_bfloat16 is not a standard layout type, and thus is " + "incompatible with C."); + +static_assert(std::is_trivial<rccl_bfloat16>{}, + "rccl_bfloat16 is not a trivial type, and thus is " + "incompatible with C."); + +static_assert(sizeof(rccl_bfloat16) == sizeof(rccl_bfloat16_public) + && offsetof(rccl_bfloat16, data) == offsetof(rccl_bfloat16_public, data), + "internal rccl_bfloat16 does not match public rccl_bfloat16"); + +inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat16& bf16) +{ + return os << float(bf16); +} +constexpr __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a) +{ + return a; +} +constexpr __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a) +{ + a.data ^= 0x8000; + return a; +} +constexpr __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return rccl_bfloat16(float(a) + float(b)); +} +constexpr __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return rccl_bfloat16(float(a) - float(b)); +} +constexpr __host__ __device__ rccl_bfloat16 operator*(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return rccl_bfloat16(float(a) * float(b)); +} +constexpr __host__ __device__ rccl_bfloat16 operator/(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return rccl_bfloat16(float(a) / float(b)); +} +constexpr __host__ __device__ bool operator<(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return float(a) < float(b); +} +constexpr __host__ __device__ bool operator==(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return float(a) == float(b); +} +constexpr __host__ __device__ bool operator>(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return b < a; +} +constexpr __host__ __device__ bool operator<=(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return float(a) <= float(b); // compare as float so the result is false when either operand is NaN +} +constexpr __host__ __device__ bool operator!=(rccl_bfloat16 a, 
rccl_bfloat16 b) +{ + return !(a == b); +} +constexpr __host__ __device__ bool operator>=(rccl_bfloat16 a, rccl_bfloat16 b) +{ + return float(a) >= float(b); // compare as float so the result is false when either operand is NaN +} +constexpr __host__ __device__ rccl_bfloat16& operator+=(rccl_bfloat16& a, rccl_bfloat16 b) +{ + return a = a + b; +} +constexpr __host__ __device__ rccl_bfloat16& operator-=(rccl_bfloat16& a, rccl_bfloat16 b) +{ + return a = a - b; +} +constexpr __host__ __device__ rccl_bfloat16& operator*=(rccl_bfloat16& a, rccl_bfloat16 b) +{ + return a = a * b; +} +constexpr __host__ __device__ rccl_bfloat16& operator/=(rccl_bfloat16& a, rccl_bfloat16 b) +{ + return a = a / b; +} +constexpr __host__ __device__ rccl_bfloat16& operator++(rccl_bfloat16& a) +{ + return a += rccl_bfloat16(1.0f); +} +constexpr __host__ __device__ rccl_bfloat16& operator--(rccl_bfloat16& a) +{ + return a -= rccl_bfloat16(1.0f); +} +constexpr __host__ __device__ rccl_bfloat16 operator++(rccl_bfloat16& a, int) +{ + rccl_bfloat16 orig = a; + ++a; + return orig; +} +constexpr __host__ __device__ rccl_bfloat16 operator--(rccl_bfloat16& a, int) +{ + rccl_bfloat16 orig = a; + --a; + return orig; +} + +namespace std +{ + constexpr __host__ __device__ bool isinf(rccl_bfloat16 a) + { + return !(~a.data & 0x7f80) && !(a.data & 0x7f); + } + constexpr __host__ __device__ bool isnan(rccl_bfloat16 a) + { + return !(~a.data & 0x7f80) && +(a.data & 0x7f); + } + constexpr __host__ __device__ bool iszero(rccl_bfloat16 a) + { + return !(a.data & 0x7fff); + } + inline rccl_bfloat16 sin(rccl_bfloat16 a) + { + return rccl_bfloat16(sinf(float(a))); + } + inline rccl_bfloat16 cos(rccl_bfloat16 a) + { + return rccl_bfloat16(cosf(float(a))); + } +} + +#endif // __cplusplus < 201402L || (!defined(__HCC__) && !defined(__HIPCC__)) + +#endif // _RCCL_BFLOAT16_H_ diff --git a/src/nccl.h.in b/src/nccl.h.in index 686ed42406..f555bcb495 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -19,6 +19,8 @@ #define NCCL_VERSION_CODE ${NCCL_VERSION} #define NCCL_VERSION(X,Y,Z) ((X) * 1000 + (Y) * 
100 + (Z)) +#define RCCL_BFLOAT16 1 + #ifdef __cplusplus extern "C" { #endif @@ -116,7 +118,8 @@ typedef enum { ncclInt8 = 0, ncclChar = 0, ncclFloat16 = 6, ncclHalf = 6, ncclFloat32 = 7, ncclFloat = 7, ncclFloat64 = 8, ncclDouble = 8, - ncclNumTypes = 9 } ncclDataType_t; + ncclBfloat16 = 9, + ncclNumTypes = 10 } ncclDataType_t; /* * Collective communication operations