SWDEV-525933 - add constexpr operators for fp16/bf16 (#199)
[ROCm/clr commit: 2f73e1385b]
This commit is contained in:
committed by
GitHub
orang tua
e1d2194b75
melakukan
a8630e866d
@@ -123,7 +123,6 @@
|
||||
#if defined(__HIPCC_RTC__)
|
||||
#define __BF16_HOST_DEVICE__ __BF16_DEVICE__
|
||||
#else
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#define __BF16_HOST_DEVICE__ __host__ __BF16_DEVICE__
|
||||
@@ -215,7 +214,7 @@ struct __attribute__((aligned(2))) __hip_bfloat16 {
|
||||
__BF16_HOST_DEVICE__ __hip_bfloat16(const float val) : __x_bf16(static_cast<__bf16>(val)) {}
|
||||
|
||||
/*! \brief create __hip_bfloat16 from a __hip_bfloat16_raw */
|
||||
__BF16_HOST_DEVICE__ __hip_bfloat16(const __hip_bfloat16_raw& val) : __x(val.x) {}
|
||||
__BF16_HOST_DEVICE__ constexpr __hip_bfloat16(const __hip_bfloat16_raw& val) : __x(val.x) {}
|
||||
|
||||
/*! \brief create __hip_bfloat16 from __bf16 */
|
||||
__BF16_HOST_DEVICE__ __hip_bfloat16(const __bf16 val) : __x_bf16(val) {}
|
||||
@@ -232,7 +231,7 @@ struct __attribute__((aligned(2))) __hip_bfloat16 {
|
||||
}
|
||||
|
||||
/*! \brief return false if bfloat value is +0.0 or -0.0, returns true otherwise */
|
||||
__BF16_HOST_DEVICE__ operator bool() const { return __x_bf16 != 0.0f; }
|
||||
__BF16_HOST_DEVICE__ constexpr operator bool() const { return __x_bf16 != 0.0f; }
|
||||
|
||||
/*! \brief return a casted char from underlying float val */
|
||||
__BF16_HOST_DEVICE__ operator char() const { return static_cast<char>(__x_bf16); }
|
||||
@@ -370,7 +369,7 @@ struct __attribute__((aligned(4))) __hip_bfloat162 {
|
||||
__BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162& val) : x(val.x), y(val.y) {}
|
||||
|
||||
/*! \brief create __hip_bfloat162 from two __hip_bfloat16 */
|
||||
__BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat16& a, const __hip_bfloat16& b)
|
||||
__BF16_HOST_DEVICE__ constexpr __hip_bfloat162(const __hip_bfloat16& a, const __hip_bfloat16& b)
|
||||
: x(a), y(b) {}
|
||||
|
||||
/*! \brief create __hip_bfloat162 from vector of __bf16_2 */
|
||||
|
||||
@@ -99,7 +99,7 @@ THE SOFTWARE.
|
||||
// CREATORS
|
||||
__HOST_DEVICE__
|
||||
__half() = default;
|
||||
__HOST_DEVICE__
|
||||
__HOST_DEVICE__ constexpr
|
||||
__half(const __half_raw& x) : data{x.data} {}
|
||||
#if !defined(__HIP_NO_HALF_CONVERSIONS__)
|
||||
__HOST_DEVICE__
|
||||
@@ -363,12 +363,9 @@ THE SOFTWARE.
|
||||
__half2(const __half2_raw& xx) : data{xx.data} {}
|
||||
__HOST_DEVICE__
|
||||
__half2(decltype(data) xx) : data{xx} {}
|
||||
__HOST_DEVICE__
|
||||
__HOST_DEVICE__ constexpr
|
||||
__half2(const __half& xx, const __half& yy)
|
||||
:
|
||||
data{static_cast<__half_raw>(xx).data,
|
||||
static_cast<__half_raw>(yy).data}
|
||||
{}
|
||||
: x(xx), y(yy) {}
|
||||
__HOST_DEVICE__
|
||||
__half2(const __half2&) = default;
|
||||
__HOST_DEVICE__
|
||||
|
||||
Reference in New Issue
Block a user