From eb65dadfc5e054c53a7aafa6af750bd7f6fec4db Mon Sep 17 00:00:00 2001 From: mberenjk <146776561+mberenjk@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:00:20 -0500 Subject: [PATCH] replacing rccl_bfloat16 with hip_bfloat16 (#70) Co-authored-by: Marzieh Berenjkoub --- src/CMakeLists.txt | 1 - src/common.cu | 6 +- src/rccl_bfloat16.h | 274 -------------------------------------------- 3 files changed, 3 insertions(+), 278 deletions(-) delete mode 100644 src/rccl_bfloat16.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f0aae7cc10..12aeb641a7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -73,7 +73,6 @@ set(COMMON_FILES common.cu nccl1_compat.h rccl_float8.h - rccl_bfloat16.h timer.h timer.cc ../verifiable/verifiable.h diff --git a/src/common.cu b/src/common.cu index b034ca768c..6e1306830f 100644 --- a/src/common.cu +++ b/src/common.cu @@ -9,7 +9,7 @@ #include "cuda_runtime.h" #include "rccl_float8.h" -#include "rccl_bfloat16.h" +#include #include "common.h" #include #include @@ -408,7 +408,7 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t int8_t i8; uint8_t u8; int32_t i32; uint32_t u32; int64_t i64; uint64_t u64; half f16; float f32; double f64; #if defined(RCCL_BFLOAT16) - rccl_bfloat16 bf16; + hip_bfloat16 bf16; #endif #if defined(RCCL_FLOAT8) rccl_float8 fp8_e4m3; rccl_bfloat8 fp8_e5m2; @@ -425,7 +425,7 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t case ncclFloat32: f32 = ncclVerifiablePremulScalar(rank); break; case ncclFloat64: f64 = ncclVerifiablePremulScalar(rank); break; #if defined(RCCL_BFLOAT16) - case ncclBfloat16: bf16 = ncclVerifiablePremulScalar(rank); break; + case ncclBfloat16: bf16 = ncclVerifiablePremulScalar(rank); break; #endif #if defined(RCCL_FLOAT8) case ncclFp8E4M3: fp8_e4m3 = ncclVerifiablePremulScalar(rank); break; diff --git a/src/rccl_bfloat16.h b/src/rccl_bfloat16.h deleted file mode 100644 index cbc6e059a5..0000000000 --- a/src/rccl_bfloat16.h +++ /dev/null @@ -1,274 +0,0 @@ -/** - * MIT License - * - * Copyright 2019-2020 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -/*!\file - * \brief rccl_bfloat16.h provides struct for rccl_bfloat16 typedef - */ - -#ifndef _RCCL_BFLOAT16_H_ -#define _RCCL_BFLOAT16_H_ - -#if __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__) && !defined(__HIP_PLATFORM_HCC__)) - -// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only -// include a minimal definition of rccl_bfloat16 - -#include -/*! \brief Struct to represent a 16 bit brain floating point number. */ -typedef struct -{ - uint16_t data; -} rccl_bfloat16; - -#else // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__) && !defined(__HIP_PLATFORM_HCC__)) - -#include -#include -#include -#include -#include -#include - -struct rccl_bfloat16 -{ - uint16_t data; - - enum truncate_t - { - truncate - }; - - __host__ __device__ rccl_bfloat16() = default; - - // round upper 16 bits of IEEE float to convert to bfloat16 - explicit __host__ __device__ rccl_bfloat16(float f) - : data(float_to_bfloat16(f)) - { - } - - explicit __host__ __device__ rccl_bfloat16(float f, truncate_t) - : data(truncate_float_to_bfloat16(f)) - { - } - - // zero extend lower 16 bits of bfloat16 to convert to IEEE float - __host__ __device__ operator float() const - { - union - { - uint32_t int32; - float fp32; - } u = {uint32_t(data) << 16}; - return u.fp32; - } - -private: - static __host__ __device__ uint16_t float_to_bfloat16(float f) - { - union - { - float fp32; - uint32_t int32; - } u = {f}; - if(~u.int32 & 0x7f800000) - { - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even - } - else if(u.int32 & 0xffff) - { - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bloat16's mantissa bits are all 0. - u.int32 |= 0x10000; // Preserve signaling NaN - } - return uint16_t(u.int32 >> 16); - } - - // Truncate instead of rounding, preserving SNaN - static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f) - { - union - { - float fp32; - uint32_t int32; - } u = {f}; - return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff)); - } -}; - -typedef struct -{ - uint16_t data; -} rccl_bfloat16_public; - -static_assert(std::is_standard_layout{}, - "rccl_bfloat16 is not a standard layout type, and thus is " - "incompatible with C."); - -static_assert(std::is_trivial{}, - "rccl_bfloat16 is not a trivial type, and thus is " - "incompatible with C."); - -static_assert(sizeof(rccl_bfloat16) == sizeof(rccl_bfloat16_public) - && offsetof(rccl_bfloat16, data) == offsetof(rccl_bfloat16_public, data), - "internal rccl_bfloat16 does not match public rccl_bfloat16"); - -inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat16& bf16) -{ - return os << float(bf16); -} -inline __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a) -{ - return a; -} -inline __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a) -{ - a.data ^= 0x8000; - return a; -} -inline __host__ __device__ rccl_bfloat16 operator+(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return rccl_bfloat16(float(a) + float(b)); -} -inline __host__ __device__ rccl_bfloat16 operator-(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return rccl_bfloat16(float(a) - float(b)); -} -inline __host__ __device__ rccl_bfloat16 operator*(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return rccl_bfloat16(float(a) * float(b)); -} -inline __host__ __device__ rccl_bfloat16 operator/(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return rccl_bfloat16(float(a) / float(b)); -} -inline __host__ __device__ bool operator<(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return float(a) < float(b); -} -inline __host__ __device__ bool operator==(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return float(a) == float(b); -} -inline __host__ __device__ bool operator>(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return b < a; -} -inline __host__ __device__ bool operator<=(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return !(a > b); -} -inline __host__ __device__ bool operator!=(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return !(a == b); -} -inline __host__ __device__ bool operator>=(rccl_bfloat16 a, rccl_bfloat16 b) -{ - return !(a < b); -} -inline __host__ __device__ rccl_bfloat16& operator+=(rccl_bfloat16& a, rccl_bfloat16 b) -{ - return a = a + b; -} -inline __host__ __device__ rccl_bfloat16& operator-=(rccl_bfloat16& a, rccl_bfloat16 b) -{ - return a = a - b; -} -inline __host__ __device__ rccl_bfloat16& operator*=(rccl_bfloat16& a, rccl_bfloat16 b) -{ - return a = a * b; -} -inline __host__ __device__ rccl_bfloat16& operator/=(rccl_bfloat16& a, rccl_bfloat16 b) -{ - return a = a / b; -} -inline __host__ __device__ rccl_bfloat16& operator++(rccl_bfloat16& a) -{ - return a += rccl_bfloat16(1.0f); -} -inline __host__ __device__ rccl_bfloat16& operator--(rccl_bfloat16& a) -{ - return a -= rccl_bfloat16(1.0f); -} -inline __host__ __device__ rccl_bfloat16 operator++(rccl_bfloat16& a, int) -{ - rccl_bfloat16 orig = a; - ++a; - return orig; -} -inline __host__ __device__ rccl_bfloat16 operator--(rccl_bfloat16& a, int) -{ - rccl_bfloat16 orig = a; - --a; - return orig; -} - -namespace std -{ - constexpr __host__ __device__ bool isinf(rccl_bfloat16 a) - { - return !(~a.data & 0x7f80) && !(a.data & 0x7f); - } - constexpr __host__ __device__ bool isnan(rccl_bfloat16 a) - { - return !(~a.data & 0x7f80) && +(a.data & 0x7f); - } - constexpr __host__ __device__ bool iszero(rccl_bfloat16 a) - { - return !(a.data & 0x7fff); - } - inline rccl_bfloat16 sin(rccl_bfloat16 a) - { - return rccl_bfloat16(sinf(float(a))); - } - inline rccl_bfloat16 cos(rccl_bfloat16 a) - { - return rccl_bfloat16(cosf(float(a))); - } -} - -#endif // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__)) - -#endif // _RCCL_BFLOAT16_H_