EXSWHTEC-338 - Implement tests for half2 type casting intrinsics #422
Change-Id: I5492fa7d54573d45bfdb9320e74ccc6ca7640d2d
This commit is contained in:
committed by
Rakesh Roy
parent
91cff794b8
commit
50031b5c44
@@ -37,7 +37,3 @@ struct CmdOptions {
|
||||
};
|
||||
|
||||
extern CmdOptions cmd_options;
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
|
||||
>>>>>>> c08a2a5d (Merge branch 'develop' into casting_int_tests)
|
||||
|
||||
@@ -311,6 +311,13 @@ TEST_CASE("Unit_atomicDec_Negative_Parameters") {}
|
||||
* @}
|
||||
*/
|
||||
|
||||
/**
|
||||
* @defgroup MathTest Math Device Functions
|
||||
* @{
|
||||
* This section describes tests for device math functions of HIP runtime API.
|
||||
* @}
|
||||
*/
|
||||
|
||||
/**
|
||||
* @defgroup PrintfTest Printf API Management
|
||||
* @{
|
||||
|
||||
@@ -112,6 +112,9 @@ template <typename T> class LinearAllocGuard {
|
||||
T* host_ptr_ = nullptr;
|
||||
|
||||
void dealloc() {
|
||||
if (ptr_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
// No Catch macros, don't want to possibly throw in the destructor
|
||||
if (ptr_ != nullptr) {
|
||||
switch (allocation_type_) {
|
||||
|
||||
@@ -32,6 +32,7 @@ set(TEST_SRC
|
||||
casting_double_funcs.cc
|
||||
casting_float_funcs.cc
|
||||
casting_int_funcs.cc
|
||||
casting_half2_funcs.cc
|
||||
)
|
||||
|
||||
if(HIP_PLATFORM MATCHES "nvidia")
|
||||
@@ -116,3 +117,8 @@ add_test(NAME Unit_Device_casting_int_Negative
|
||||
COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../compileAndCaptureOutput.py
|
||||
${CMAKE_CURRENT_SOURCE_DIR} ${HIP_PLATFORM} ${HIP_PATH}
|
||||
casting_int_negative_kernels.cc 92)
|
||||
|
||||
# Registers the negative test for the half2 casting intrinsics: runs
# compileAndCaptureOutput.py on casting_half2_negative_kernels.cc.
# NOTE(review): the trailing 53 is presumably the expected number of compile
# errors (11 unary shells x 3 kernels + 5 binary shells x 4 kernels = 53) -
# confirm against compileAndCaptureOutput.py and keep it in sync with the
# kernel shells in that source file.
add_test(NAME Unit_Device_casting_half2_Negative
|
||||
COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../compileAndCaptureOutput.py
|
||||
${CMAKE_CURRENT_SOURCE_DIR} ${HIP_PLATFORM} ${HIP_PATH}
|
||||
casting_half2_negative_kernels.cc 53)
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#define FLOAT16_MAX 65504.0f
|
||||
|
||||
// Thin host/device wrapper around a single __half value.
// All constructors and conversion operators are intentionally implicit, so a Float16 can be
// passed wherever a __half, __half2 or float is expected (and vice versa).
class Float16 {
|
||||
public:
|
||||
// Leaves the wrapped value default-initialised.
__host__ __device__ Float16() = default;
|
||||
// Wraps the given __half unchanged.
__host__ __device__ Float16(__half x) : x_{x} {}
|
||||
// Wraps the result of __low2half(x).
__host__ __device__ Float16(__half2 x) : x_{__low2half(x)} {}
|
||||
// Wraps the result of __float2half(x).
__host__ __device__ Float16(float x) : x_{__float2half(x)} {}
|
||||
|
||||
// Equality is decided on the `data` member of the __half_raw representation of both operands.
// NOTE(review): whether this is a bit-pattern or a floating-point comparison depends on the
// type of __half_raw::data - confirm before relying on NaN / signed-zero behaviour.
__host__ __device__ bool operator==(Float16 other) const {
|
||||
return static_cast<__half_raw>(x_).data == static_cast<__half_raw>(other.x_).data;
|
||||
}
|
||||
// Negation of operator==.
__host__ __device__ bool operator!=(Float16 other) const { return !(*this == other); }
|
||||
|
||||
// Returns the wrapped __half.
__host__ __device__ operator __half() const { return x_; }
|
||||
// Returns __half2half2 applied to the wrapped value.
__host__ __device__ operator __half2() const { return __half2half2(x_); }
|
||||
// Returns __half2float applied to the wrapped value.
__host__ __device__ operator float() const { return __half2float(x_); }
|
||||
|
||||
private:
|
||||
// The wrapped half-precision value.
__half x_;
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& o, Float16 x) {
|
||||
o << static_cast<float>(x);
|
||||
return o;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -73,8 +73,13 @@ void BinaryFloatingPointBruteForceTest(kernel_sig<T, TArg, TArg> kernel,
|
||||
thread_pool.Post([=, &x1s, &x2s] {
|
||||
const auto generator = [=] {
|
||||
static thread_local std::mt19937 rng(std::random_device{}());
|
||||
std::uniform_real_distribution<RefType_t<TArg>> unif_dist(a, b);
|
||||
return static_cast<TArg>(unif_dist(rng));
|
||||
if constexpr (std::is_same_v<TArg, Float16>) {
|
||||
std::uniform_real_distribution<RefType_t<Float16>> unif_dist(-FLOAT16_MAX, FLOAT16_MAX);
|
||||
return static_cast<Float16>(unif_dist(rng));
|
||||
} else {
|
||||
std::uniform_real_distribution<RefType_t<TArg>> unif_dist(a, b);
|
||||
return static_cast<TArg>(unif_dist(rng));
|
||||
}
|
||||
};
|
||||
std::generate(x1s.ptr() + base_idx, x1s.ptr() + base_idx + sub_batch_size, generator);
|
||||
std::generate(x2s.ptr() + base_idx, x2s.ptr() + base_idx + sub_batch_size, generator);
|
||||
@@ -94,7 +99,8 @@ void BinaryFloatingPointSpecialValuesTest(kernel_sig<T, TArg, TArg> kernel,
|
||||
ref_sig<RT, RTArg, RTArg> ref_func,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
const auto [grid_size, block_size] = GetOccupancyMaxPotentialBlockSize(kernel);
|
||||
const auto values = std::get<SpecialVals<TArg>>(kSpecialValRegistry);
|
||||
using SpecialValsType = std::conditional_t<std::is_same_v<TArg, Float16>, float, TArg>;
|
||||
const auto values = std::get<SpecialVals<SpecialValsType>>(kSpecialValRegistry);
|
||||
|
||||
const auto size = values.size * values.size;
|
||||
LinearAllocGuard<TArg> x1s{LinearAllocs::hipHostMalloc, size * sizeof(TArg)};
|
||||
@@ -122,7 +128,6 @@ void BinaryFloatingPointTest(kernel_sig<T, TArg, TArg> kernel, ref_sig<RT, RTArg
|
||||
SECTION("Brute force") { BinaryFloatingPointBruteForceTest(kernel, ref_func, validator_builder); }
|
||||
}
|
||||
|
||||
|
||||
#define MATH_BINARY_WITHIN_ULP_TEST_DEF(kern_name, ref_func, sp_ulp, dp_ulp) \
|
||||
MATH_BINARY_KERNEL_DEF(kern_name) \
|
||||
\
|
||||
|
||||
@@ -36,6 +36,17 @@ namespace cg = cooperative_groups;
|
||||
} \
|
||||
}
|
||||
|
||||
// Defines `func_name##_kernel`, a __global__ wrapper around the two-argument function
// `func_name`. Each thread starts at its rank within the grid and advances by the grid size,
// storing ys[i] = func_name(x1s[i], x2s[i]) for every i < num_xs.
// T1 is the element type of the output array, T2 the element type of both input arrays.
#define CAST_BINARY_KERNEL_DEF(func_name, T1, T2) \
|
||||
__global__ void func_name##_kernel(T1* const ys, const size_t num_xs, T2* const x1s, \
|
||||
T2* const x2s) { \
|
||||
const auto tid = cg::this_grid().thread_rank(); \
|
||||
const auto stride = cg::this_grid().size(); \
|
||||
\
|
||||
for (auto i = tid; i < num_xs; i += stride) { \
|
||||
ys[i] = func_name(x1s[i], x2s[i]); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define CAST_F2I_REF_DEF(func_name, T1, T2, ref_func) \
|
||||
T1 func_name##_ref(T2 arg) { \
|
||||
if (arg >= static_cast<T2>(std::numeric_limits<T1>::max())) \
|
||||
@@ -71,13 +82,66 @@ namespace cg = cooperative_groups;
|
||||
return result; \
|
||||
}
|
||||
|
||||
|
||||
/**
 * Host reference for the "reinterpret bits" intrinsics: returns a T1 whose object representation
 * is copied from `arg`.
 *
 * Only the bytes both types have in common are copied, so the function never reads past the end
 * of `arg` nor writes past the end of the result; any remaining bytes of the result keep their
 * value-initialised state. For equally sized types this is a plain bit copy.
 *
 * @tparam T1 Result type.
 * @tparam T2 Source type.
 * @param arg Value whose bytes are reinterpreted.
 * @return A T1 built from the leading bytes of `arg`.
 */
template <typename T1, typename T2> T1 type2_as_type1_ref(T2 arg) {
  // Value-initialise so bytes not covered by the copy are well defined.
  T1 tmp{};
  constexpr size_t copy_size = sizeof(T1) < sizeof(T2) ? sizeof(T1) : sizeof(T2);
  memcpy(&tmp, &arg, copy_size);
  return tmp;
}
|
||||
|
||||
template <typename T, typename RT, typename RTArg, typename ValidatorBuilder>
|
||||
void CastUnaryHalfPrecisionBruteForceTest(kernel_sig<T, Float16> kernel,
|
||||
ref_sig<RT, RTArg> ref_func,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
const auto [grid_size, block_size] = GetOccupancyMaxPotentialBlockSize(kernel);
|
||||
uint64_t stop = std::numeric_limits<uint16_t>::max() + 1ul;
|
||||
const auto max_batch_size =
|
||||
std::min(GetMaxAllowedDeviceMemoryUsage() / (sizeof(Float16) + sizeof(T)), stop);
|
||||
LinearAllocGuard<Float16> values{LinearAllocs::hipHostMalloc, max_batch_size * sizeof(Float16)};
|
||||
|
||||
MathTest math_test(kernel, max_batch_size);
|
||||
|
||||
auto batch_size = max_batch_size;
|
||||
const auto num_threads = thread_pool.thread_count();
|
||||
|
||||
for (uint64_t v = 0u; v < stop;) {
|
||||
batch_size = std::min<uint64_t>(max_batch_size, stop - v);
|
||||
|
||||
const auto min_sub_batch_size = batch_size / num_threads;
|
||||
const auto tail = batch_size % num_threads;
|
||||
|
||||
auto base_idx = 0u;
|
||||
for (auto i = 0u; i < num_threads; ++i) {
|
||||
const auto sub_batch_size = min_sub_batch_size + (i < tail);
|
||||
|
||||
thread_pool.Post([=, &values] {
|
||||
auto t = v;
|
||||
uint16_t val;
|
||||
for (auto j = 0u; j < sub_batch_size; ++j) {
|
||||
val = static_cast<uint16_t>(t++);
|
||||
values.ptr()[base_idx + j] = *reinterpret_cast<Float16*>(&val);
|
||||
if (std::isnan(values.ptr()[base_idx + j]) || std::isinf(values.ptr()[base_idx + j])) {
|
||||
values.ptr()[base_idx + j] = 0;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
v += sub_batch_size;
|
||||
base_idx += sub_batch_size;
|
||||
}
|
||||
|
||||
thread_pool.Wait();
|
||||
|
||||
math_test.Run(validator_builder, grid_size, block_size, ref_func, batch_size, values.ptr());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename RT, typename RTArg, typename ValidatorBuilder>
|
||||
void CastUnaryHalfPrecisionTest(kernel_sig<T, Float16> kernel, ref_sig<RT, RTArg> ref,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
SECTION("Brute force") { CastUnaryHalfPrecisionBruteForceTest(kernel, ref, validator_builder); }
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename ValidatorBuilder>
|
||||
void CastDoublePrecisionSpecialValuesTest(kernel_sig<T, double> kernel, ref_sig<T, double> ref_func,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
/*
|
||||
Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "math_common.hh"
|
||||
#include "validators.hh"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// Defines `func_name##_kernel`, a __global__ wrapper around a function taking one __half2.
// For every i < num_xs (each thread starts at its rank within the grid and advances by the grid
// size) the input is built as __half2{xs[i], -xs[i]} and the result is stored in ys[i].
// Reference functions for these kernels therefore receive only x and must account for the
// second element being -x. T is the element type of the output array.
#define CAST_HALF2_KERNEL_DEF(func_name, T) \
|
||||
__global__ void func_name##_kernel(T* const ys, const size_t num_xs, Float16* const xs) { \
|
||||
const auto tid = cg::this_grid().thread_rank(); \
|
||||
const auto stride = cg::this_grid().size(); \
|
||||
\
|
||||
for (auto i = tid; i < num_xs; i += stride) { \
|
||||
ys[i] = func_name(__half2{xs[i], -xs[i]}); \
|
||||
} \
|
||||
}
|
||||
|
||||
// Defines `func_name##_kernel`, a __global__ wrapper around a function taking two __half2
// arguments. For every i < num_xs (each thread starts at its rank within the grid and advances
// by the grid size) the arguments are built as __half2{x1s[i], -x1s[i]} and
// __half2{x2s[i], -x2s[i]}, and the result is stored in ys[i].
// T is the element type of the output array.
#define CAST_BINARY_HALF2_KERNEL_DEF(func_name, T) \
|
||||
__global__ void func_name##_kernel(T* const ys, const size_t num_xs, Float16* const x1s, \
|
||||
Float16* const x2s) { \
|
||||
const auto tid = cg::this_grid().thread_rank(); \
|
||||
const auto stride = cg::this_grid().size(); \
|
||||
\
|
||||
for (auto i = tid; i < num_xs; i += stride) { \
|
||||
ys[i] = func_name(__half2{x1s[i], -x1s[i]}, __half2{x2s[i], -x2s[i]}); \
|
||||
} \
|
||||
}
|
||||
|
||||
/**
 * Matcher for float2 results that validates the `x` and `y` components independently.
 *
 * One scalar matcher per component is obtained from the validator builder `vb`; a float2
 * matches only when both component matchers accept their respective component.
 */
template <typename VB> class Float2Validator : public MatcherBase<float2> {
  // Type of the scalar matcher the builder yields for a float target.
  using ComponentMatcher = decltype(std::declval<VB>()(float()));

 public:
  Float2Validator(const float2& target, const VB& vb)
      : x_matcher_{vb(target.x)}, y_matcher_{vb(target.y)} {}

  bool match(const float2& val) const override {
    if (!x_matcher_->match(val.x)) {
      return false;
    }
    return y_matcher_->match(val.y);
  }

  std::string describe() const override {
    const auto x_text = x_matcher_->describe();
    const auto y_text = y_matcher_->describe();
    return "<" + x_text + ", " + y_text + ">";
  }

 private:
  ComponentMatcher x_matcher_;
  ComponentMatcher y_matcher_;
};
|
||||
|
||||
template <typename ValidatorBuilder>
|
||||
auto Float2ValidatorBuilderFactory(const ValidatorBuilder& vb) {
|
||||
return [=](const float2& t, auto&&...) {
|
||||
return std::make_unique<Float2Validator<ValidatorBuilder>>(t, vb);
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Matcher for __half2 results that validates the `data.x` and `data.y` components
 * independently.
 *
 * One scalar matcher per component is obtained from the validator builder `vb`; a __half2
 * matches only when both component matchers accept their respective component.
 */
template <typename VB> class Half2Validator : public MatcherBase<__half2> {
  // Type of the scalar matcher the builder yields for a Float16 target.
  using ComponentMatcher = decltype(std::declval<VB>()(Float16()));

 public:
  Half2Validator(const __half2& target, const VB& vb)
      : x_matcher_{vb(target.data.x)}, y_matcher_{vb(target.data.y)} {}

  bool match(const __half2& val) const override {
    if (!x_matcher_->match(val.data.x)) {
      return false;
    }
    return y_matcher_->match(val.data.y);
  }

  std::string describe() const override {
    const auto x_text = x_matcher_->describe();
    const auto y_text = y_matcher_->describe();
    return "<" + x_text + ", " + y_text + ">";
  }

 private:
  ComponentMatcher x_matcher_;
  ComponentMatcher y_matcher_;
};
|
||||
|
||||
template <typename ValidatorBuilder> auto Half2ValidatorBuilderFactory(const ValidatorBuilder& vb) {
|
||||
return [=](const __half2& t, auto&&...) {
|
||||
return std::make_unique<Half2Validator<ValidatorBuilder>>(t, vb);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,419 @@
|
||||
/*
|
||||
Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "half_precision_common.hh"
|
||||
#include "casting_common.hh"
|
||||
#include "casting_half2_common.hh"
|
||||
|
||||
/**
|
||||
* @addtogroup HalfPrecisionCastingHalf2 HalfPrecisionCastingHalf2
|
||||
* @{
|
||||
* @ingroup MathTest
|
||||
*/
|
||||
|
||||
/********** half -> half2 **********/
|
||||
|
||||
CAST_KERNEL_DEF(__half2half2, __half2, Float16)
|
||||
|
||||
// Host reference for __half2half2: a __half2 whose two initializers are both x.
static __half2 __half2half2_ref(Float16 x) { return __half2{x, x}; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__half2half2` for all possible inputs. The results are compared against
|
||||
* reference function which returns __half2 value created from one __half value.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___half2half2_Accuracy_Positive") {
|
||||
UnaryHalfPrecisionTest(__half2half2_kernel, __half2half2_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
CAST_BINARY_KERNEL_DEF(make_half2, __half2, Float16)
|
||||
|
||||
// Host reference for make_half2: a __half2 initialised from x and y, in that order.
static __half2 make_half2_ref(Float16 x, Float16 y) { return __half2{x, y}; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `make_half2` against a table of difficult values, followed by a large
|
||||
* number of randomly generated values. The results are compared against reference function which
|
||||
* returns __half2 value created from two __half values.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device_make_half2_Accuracy_Positive") {
|
||||
BinaryFloatingPointTest(make_half2_kernel, make_half2_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
CAST_BINARY_KERNEL_DEF(__halves2half2, __half2, Float16)
|
||||
|
||||
// Host reference for __halves2half2: a __half2 initialised from x and y, in that order.
static __half2 __halves2half2_ref(Float16 x, Float16 y) { return __half2{x, y}; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__halves2half2` against a table of difficult values, followed by a large
|
||||
* number of randomly generated values. The results are compared against reference function which
|
||||
* returns __half2 value created from two __half values.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___halves2half2_Accuracy_Positive") {
|
||||
BinaryFloatingPointTest(__halves2half2_kernel, __halves2half2_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
/********** half2 -> half **********/
|
||||
|
||||
|
||||
CAST_HALF2_KERNEL_DEF(__low2half, Float16)
|
||||
|
||||
// Host reference for __low2half. The kernel built by CAST_HALF2_KERNEL_DEF calls the intrinsic
// with __half2{x, -x}; the expected result is the first initializer, x.
static Float16 __low2half_ref(Float16 x) { return x; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__low2half` for all possible inputs. The results are compared against
|
||||
* reference function which returns __half value created from lower __half2 element.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___low2half_Accuracy_Positive") {
|
||||
UnaryHalfPrecisionTest(__low2half_kernel, __low2half_ref, EqValidatorBuilderFactory<Float16>());
|
||||
}
|
||||
|
||||
CAST_HALF2_KERNEL_DEF(__high2half, Float16)
|
||||
|
||||
// Host reference for __high2half. The kernel built by CAST_HALF2_KERNEL_DEF calls the intrinsic
// with __half2{x, -x}; the expected result is the second initializer, -x.
static Float16 __high2half_ref(Float16 x) { return -x; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__high2half` for all possible inputs. The results are compared against
|
||||
* reference function which returns __half value created from higher __half2 element.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___high2half_Accuracy_Positive") {
|
||||
UnaryHalfPrecisionTest(__high2half_kernel, __high2half_ref, EqValidatorBuilderFactory<Float16>());
|
||||
}
|
||||
|
||||
/********** half2 -> half2 **********/
|
||||
|
||||
CAST_HALF2_KERNEL_DEF(__low2half2, __half2)
|
||||
|
||||
// Host reference for __low2half2. The kernel built by CAST_HALF2_KERNEL_DEF calls the intrinsic
// with __half2{x, -x}; the expected result repeats the first initializer: {x, x}.
static __half2 __low2half2_ref(Float16 x) { return __half2{x, x}; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__low2half2` for all possible inputs. The results are compared against
|
||||
* reference function which returns __half2 value created from two lower __half2 elements.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___low2half2_Accuracy_Positive") {
|
||||
UnaryHalfPrecisionTest(__low2half2_kernel, __low2half2_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
CAST_HALF2_KERNEL_DEF(__high2half2, __half2)
|
||||
|
||||
// Host reference for __high2half2. The kernel built by CAST_HALF2_KERNEL_DEF calls the intrinsic
// with __half2{x, -x}; the expected result repeats the second initializer: {-x, -x}.
static __half2 __high2half2_ref(Float16 x) { return __half2{-x, -x}; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__high2half2` for all possible inputs. The results are compared against
|
||||
* reference function which returns __half2 value created from two higher __half2 elements.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___high2half2_Accuracy_Positive") {
|
||||
UnaryHalfPrecisionTest(__high2half2_kernel, __high2half2_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
CAST_HALF2_KERNEL_DEF(__lowhigh2highlow, __half2)
|
||||
|
||||
// Host reference for __lowhigh2highlow. The kernel built by CAST_HALF2_KERNEL_DEF calls the
// intrinsic with __half2{x, -x}; the expected result has the two initializers swapped: {-x, x}.
static __half2 __lowhigh2highlow_ref(Float16 x) { return __half2{-x, x}; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__lowhigh2highlow` for all possible inputs. The results are compared
|
||||
* against reference function which returns __half2 value created from higher and lower __half2
|
||||
* elements.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___lowhigh2highlow_Accuracy_Positive") {
|
||||
UnaryHalfPrecisionTest(__lowhigh2highlow_kernel, __lowhigh2highlow_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
CAST_BINARY_HALF2_KERNEL_DEF(__lows2half2, __half2)
|
||||
|
||||
// Host reference for __lows2half2. The kernel built by CAST_BINARY_HALF2_KERNEL_DEF calls the
// intrinsic with __half2{x, -x} and __half2{y, -y}; the expected result pairs the first
// initializers: {x, y}.
static __half2 __lows2half2_ref(Float16 x, Float16 y) { return __half2{x, y}; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__lows2half2` against a table of difficult values, followed by a large
|
||||
* number of randomly generated values. The results are compared against reference function which
|
||||
* returns __half2 value created from lower elements of two __half2 values.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___lows2half2_Accuracy_Positive") {
|
||||
BinaryFloatingPointTest(__lows2half2_kernel, __lows2half2_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
CAST_BINARY_HALF2_KERNEL_DEF(__highs2half2, __half2)
|
||||
|
||||
// Host reference for __highs2half2. The kernel built by CAST_BINARY_HALF2_KERNEL_DEF calls the
// intrinsic with __half2{x, -x} and __half2{y, -y}; the expected result pairs the second
// initializers: {-x, -y}.
static __half2 __highs2half2_ref(Float16 x, Float16 y) { return __half2{-x, -y}; }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__highs2half2` against a table of difficult values, followed by a large
|
||||
* number of randomly generated values. The results are compared against reference function which
|
||||
* returns __half2 value created from higher elements of two __half2 values.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___highs2half2_Accuracy_Positive") {
|
||||
BinaryFloatingPointTest(__highs2half2_kernel, __highs2half2_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
/********** float -> half2 **********/
|
||||
|
||||
CAST_KERNEL_DEF(__float2half2_rn, __half2, float)
|
||||
|
||||
// Host reference for __float2half2_rn: converts x to Float16 and uses that value for both
// initializers of the returned __half2.
static __half2 __float2half2_rn_ref(float x) {
|
||||
return __half2{static_cast<Float16>(x), static_cast<Float16>(x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__float2half2_rn` for all possible inputs. The results are compared
|
||||
* against reference function which returns __half2 value created from one casted float value.
|
||||
* Both elements of the expected __half2 hold the same converted value.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___float2half2_rn_Accuracy_Positive") {
|
||||
UnarySinglePrecisionTest(__float2half2_rn_kernel, __float2half2_rn_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
CAST_BINARY_KERNEL_DEF(__floats2half2_rn, __half2, float)
|
||||
|
||||
// Host reference for __floats2half2_rn: converts x and y to Float16 and returns them as the
// first and second initializer of a __half2, respectively.
static __half2 __floats2half2_rn_ref(float x, float y) {
|
||||
return __half2{static_cast<Float16>(x), static_cast<Float16>(y)};
|
||||
}
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__floats2half2_rn` against a table of difficult values, followed by a
|
||||
* large number of randomly generated values. The results are compared against reference function
|
||||
* which returns __half2 value created from two casted float values.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___floats2half2_rn_Accuracy_Positive") {
|
||||
BinaryFloatingPointTest(__floats2half2_rn_kernel, __floats2half2_rn_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
/********** float2 -> half2 **********/
|
||||
|
||||
/**
 * Kernel wrapper for `__float22half2_rn`.
 *
 * Every thread starts at its rank within the grid and advances by the grid size. For each
 * visited index the input float x is expanded to make_float2(x, -x), converted, and the result
 * is written to the matching output slot.
 *
 * @param ys Output array of num_xs __half2 results.
 * @param num_xs Number of input elements.
 * @param xs Input array of num_xs floats.
 */
__global__ void __float22half2_rn_kernel(__half2* const ys, const size_t num_xs, float* const xs) {
  const auto first = cg::this_grid().thread_rank();
  const auto step = cg::this_grid().size();

  auto idx = first;
  while (idx < num_xs) {
    const float2 packed = make_float2(xs[idx], -xs[idx]);
    ys[idx] = __float22half2_rn(packed);
    idx += step;
  }
}
|
||||
|
||||
// Host reference for __float22half2_rn. __float22half2_rn_kernel passes make_float2(x, -x), so
// the expected result is x and -x, each converted to Float16.
static __half2 __float22half2_rn_ref(float x) {
|
||||
return __half2{static_cast<Float16>(x), static_cast<Float16>(-x)};
|
||||
}
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__float22half2_rn` for all possible inputs. The results are compared
|
||||
* against reference function which returns __half2 value created from two casted float values.
|
||||
* The kernel passes make_float2(x, -x), so the expected elements are x and -x converted to half.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___float22half2_rn_Accuracy_Positive") {
|
||||
UnarySinglePrecisionTest(__float22half2_rn_kernel, __float22half2_rn_ref,
|
||||
Half2ValidatorBuilderFactory(EqValidatorBuilderFactory<Float16>()));
|
||||
}
|
||||
|
||||
/********** half2 -> float **********/
|
||||
|
||||
CAST_HALF2_KERNEL_DEF(__low2float, float)
|
||||
|
||||
// Host reference for __low2float. The kernel built by CAST_HALF2_KERNEL_DEF calls the intrinsic
// with __half2{x, -x}; the expected result is the first initializer, x, converted to float.
static float __low2float_ref(Float16 x) { return static_cast<float>(x); }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__low2float` for all possible inputs. The results are compared
|
||||
* against reference function which returns float value created from lower __half2 element.
|
||||
* The kernel passes __half2{x, -x}, so the expected value is x converted to float.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___low2float_Accuracy_Positive") {
|
||||
UnaryHalfPrecisionTest(__low2float_kernel, __low2float_ref, EqValidatorBuilderFactory<float>());
|
||||
}
|
||||
|
||||
CAST_HALF2_KERNEL_DEF(__high2float, float)
|
||||
|
||||
// Host reference for __high2float. The kernel built by CAST_HALF2_KERNEL_DEF calls the intrinsic
// with __half2{x, -x}; the expected result is the second initializer, -x, converted to float.
static float __high2float_ref(Float16 x) { return static_cast<float>(-x); }
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__high2float` for all possible inputs. The results are compared
|
||||
* against reference function which returns float value created from higher __half2 element.
|
||||
* The kernel passes __half2{x, -x}, so the expected value is -x converted to float.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___high2float_Accuracy_Positive") {
|
||||
UnaryHalfPrecisionTest(__high2float_kernel, __high2float_ref, EqValidatorBuilderFactory<float>());
|
||||
}
|
||||
|
||||
/********** half2 -> float2 **********/
|
||||
|
||||
CAST_HALF2_KERNEL_DEF(__half22float2, float2)
|
||||
|
||||
// Host reference for __half22float2. The kernel built by CAST_HALF2_KERNEL_DEF calls the
// intrinsic with __half2{x, -x}; the expected float2 holds x and -x, each converted to float.
static float2 __half22float2_ref(Float16 x) {
|
||||
return make_float2(static_cast<float>(x), static_cast<float>(-x));
|
||||
}
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Tests `__half22float2` for all possible inputs. The results are compared against
|
||||
* reference function which returns float2 value created from casted elements of one __half2 value.
|
||||
*
|
||||
* Test source
|
||||
* ------------------------
|
||||
* - unit/math/casting_half2_funcs.cc
|
||||
* Test requirements
|
||||
* ------------------------
|
||||
* - HIP_VERSION >= 5.2
|
||||
*/
|
||||
TEST_CASE("Unit_Device___half22float2_Accuracy_Positive") {
|
||||
UnaryHalfPrecisionTest(__half22float2_kernel, __half22float2_ref,
|
||||
Float2ValidatorBuilderFactory(EqValidatorBuilderFactory<float>()));
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
/*
|
||||
Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <hip_test_common.hh>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
// Empty device-side class used by the kernel shells in this file as an argument or result type
// that is unrelated to the half/float types the intrinsics are declared with.
class Dummy {
|
||||
public:
|
||||
__device__ Dummy() {}
|
||||
__device__ ~Dummy() {}
|
||||
};
|
||||
|
||||
// Expands to three kernels, <func_name>_kernel_v1 .. <func_name>_kernel_v3, each of which
// calls func_name exactly once:
//   v1 - passes a T2* and stores the result through a T1*;
//   v2 - passes a Dummy and stores the result through a T1*;
//   v3 - passes a T2 and stores the result through a Dummy*.
// NOTE(review): presumably every generated kernel is expected to be rejected by the compiler;
// the CMake test for this file passes a count (53) to compileAndCaptureOutput.py, so changing
// the number or shape of these kernels likely requires updating that count - confirm.
#define NEGATIVE_UNARY_KERNELS_SHELL(func_name, T1, T2) \
|
||||
__global__ void func_name##_kernel_v1(T1* result, T2* x) { *result = func_name(x); } \
|
||||
__global__ void func_name##_kernel_v2(T1* result, Dummy x) { *result = func_name(x); } \
|
||||
__global__ void func_name##_kernel_v3(Dummy* result, T2 x) { *result = func_name(x); }
|
||||
|
||||
|
||||
// Expands to four kernels, <func_name>_kernel_v1 .. <func_name>_kernel_v4, each of which
// initializes a local T1 from a single call func_name(x, y) with these argument types:
//   v1 - (T2*, T2)
//   v2 - (T2, T2*)
//   v3 - (Dummy, T2)
//   v4 - (T2, Dummy)
// NOTE(review): presumably every generated kernel is expected to be rejected by the compiler;
// the CMake test for this file passes a count (53) to compileAndCaptureOutput.py, so changing
// the number or shape of these kernels likely requires updating that count - confirm.
#define NEGATIVE_BINARY_KERNELS_SHELL(func_name, T1, T2) \
|
||||
__global__ void func_name##_kernel_v1(T2* x, T2 y) { T1 result = func_name(x, y); } \
|
||||
__global__ void func_name##_kernel_v2(T2 x, T2* y) { T1 result = func_name(x, y); } \
|
||||
__global__ void func_name##_kernel_v3(Dummy x, T2 y) { T1 result = func_name(x, y); } \
|
||||
__global__ void func_name##_kernel_v4(T2 x, Dummy y) { T1 result = func_name(x, y); }
|
||||
|
||||
// Unary intrinsics. Arguments are (function, result type T1, argument type T2);
// each line expands to three kernels.
NEGATIVE_UNARY_KERNELS_SHELL(__half2half2, __half2, __half)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__low2half, __half, __half2)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__high2half, __half, __half2)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__low2half2, __half2, __half2)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__high2half2, __half2, __half2)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__lowhigh2highlow, __half2, __half2)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__float2half2_rn, __half2, float)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__float22half2_rn, __half2, float2)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__low2float, float, __half2)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__high2float, float, __half2)
|
||||
NEGATIVE_UNARY_KERNELS_SHELL(__half22float2, float2, __half2)
|
||||
|
||||
// Binary intrinsics. Arguments are (function, result type T1, type T2 of both arguments);
// each line expands to four kernels.
NEGATIVE_BINARY_KERNELS_SHELL(make_half2, __half2, __half)
|
||||
NEGATIVE_BINARY_KERNELS_SHELL(__halves2half2, __half2, __half)
|
||||
NEGATIVE_BINARY_KERNELS_SHELL(__lows2half2, __half2, __half2)
|
||||
NEGATIVE_BINARY_KERNELS_SHELL(__highs2half2, __half2, __half2)
|
||||
NEGATIVE_BINARY_KERNELS_SHELL(__floats2half2_rn, __half2, float)
|
||||
@@ -0,0 +1,103 @@
|
||||
/*
|
||||
Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "unary_common.hh"
|
||||
#include "binary_common.hh"
|
||||
#include "ternary_common.hh"
|
||||
|
||||
|
||||
/********** Unary **********/
|
||||
|
||||
// Defines the kernel `<func_name>_kernel(ys, num_xs, xs)`. Each thread starts at its rank
// within the whole grid and advances by the grid size, storing func_name(xs[i]) into ys[i]
// for every index i < num_xs that it visits.
#define MATH_UNARY_HP_KERNEL_DEF(func_name)                                                       \
  __global__ void func_name##_kernel(Float16* const ys, const size_t num_xs, Float16* const xs) { \
    const auto first = cg::this_grid().thread_rank();                                             \
    const auto step = cg::this_grid().size();                                                     \
                                                                                                  \
    auto idx = first;                                                                             \
    while (idx < num_xs) {                                                                        \
      ys[idx] = func_name(xs[idx]);                                                               \
      idx += step;                                                                                \
    }                                                                                             \
  }
|
||||
|
||||
// Defines a Catch test case named "Unit_Device_<intrinsic>_Accuracy_Positive" whose body
// passes `<intrinsic>_kernel`, the reference function and the validator builder to
// UnaryHalfPrecisionTest.
#define MATH_UNARY_HP_TEST_DEF_IMPL(intrinsic, reference, builder)  \
  TEST_CASE("Unit_Device_" #intrinsic "_Accuracy_Positive") {       \
    UnaryHalfPrecisionTest(intrinsic##_kernel, reference, builder); \
  }
|
||||
|
||||
// Shorthand for MATH_UNARY_HP_TEST_DEF_IMPL that uses `<intrinsic>_validator_builder`
// as the validator builder.
#define MATH_UNARY_HP_TEST_DEF(intrinsic, reference) \
  MATH_UNARY_HP_TEST_DEF_IMPL(intrinsic, reference, intrinsic##_validator_builder)
|
||||
|
||||
// Expands to the declarator (without a body) of the static function
// `<intrinsic>_validator_builder(float target, float x)`, which returns
// std::unique_ptr<MatcherBase<float>>. The function body is written after the macro invocation.
#define MATH_UNARY_HP_VALIDATOR_BUILDER_DEF(intrinsic) \
  static std::unique_ptr<MatcherBase<float>> intrinsic##_validator_builder(float target, float x)
|
||||
|
||||
|
||||
/********** Binary **********/
|
||||
|
||||
// Defines the kernel `<func_name>_kernel(ys, num_xs, x1s, x2s)`. Each thread starts at its
// rank within the whole grid and advances by the grid size, storing
// func_name(x1s[i], x2s[i]) into ys[i] for every index i < num_xs that it visits.
#define MATH_BINARY_HP_KERNEL_DEF(func_name)                                                       \
  __global__ void func_name##_kernel(Float16* const ys, const size_t num_xs, Float16* const x1s,   \
                                     Float16* const x2s) {                                         \
    const auto first = cg::this_grid().thread_rank();                                              \
    const auto step = cg::this_grid().size();                                                      \
                                                                                                   \
    auto idx = first;                                                                              \
    while (idx < num_xs) {                                                                         \
      ys[idx] = func_name(x1s[idx], x2s[idx]);                                                     \
      idx += step;                                                                                 \
    }                                                                                              \
  }
|
||||
|
||||
// Defines a Catch test case named "Unit_Device_<intrinsic>_Accuracy_Positive" whose body
// passes `<intrinsic>_kernel`, the reference function and the validator builder to
// BinaryFloatingPointTest.
#define MATH_BINARY_HP_TEST_DEF_IMPL(intrinsic, reference, builder)  \
  TEST_CASE("Unit_Device_" #intrinsic "_Accuracy_Positive") {        \
    BinaryFloatingPointTest(intrinsic##_kernel, reference, builder); \
  }
|
||||
|
||||
// Shorthand for MATH_BINARY_HP_TEST_DEF_IMPL that uses `<func_name>_validator_builder`
// as the validator builder.
//
// func_name - intrinsic under test; `<func_name>_kernel` and `<func_name>_validator_builder`
//             must be defined before the macro is used.
// ref_func  - reference implementation forwarded unchanged.
//
// The expansion must target MATH_BINARY_HP_TEST_DEF_IMPL (the name used by the unary and
// ternary counterparts); no macro called MATH_BINARY_HP_TEST_IMPL exists in this header.
#define MATH_BINARY_HP_TEST_DEF(func_name, ref_func) \
  MATH_BINARY_HP_TEST_DEF_IMPL(func_name, ref_func, func_name##_validator_builder)
|
||||
|
||||
// Expands to the declarator (without a body) of the static function
// `<intrinsic>_validator_builder(float target, float x1, float x2)`, which returns
// std::unique_ptr<MatcherBase<float>>. The function body is written after the macro invocation.
#define MATH_BINARY_HP_VALIDATOR_BUILDER_DEF(intrinsic)                     \
  static std::unique_ptr<MatcherBase<float>> intrinsic##_validator_builder( \
      float target, float x1, float x2)
|
||||
|
||||
|
||||
/********** Ternary **********/
|
||||
|
||||
// Defines the kernel `<func_name>_kernel(ys, num_xs, x1s, x2s, x3s)`. Each thread starts at
// its rank within the whole grid and advances by the grid size, storing
// func_name(x1s[i], x2s[i], x3s[i]) into ys[i] for every index i < num_xs that it visits.
#define MATH_TERNARY_HP_KERNEL_DEF(func_name)                                                      \
  __global__ void func_name##_kernel(Float16* const ys, const size_t num_xs, Float16* const x1s,   \
                                     Float16* const x2s, Float16* const x3s) {                     \
    const auto first = cg::this_grid().thread_rank();                                              \
    const auto step = cg::this_grid().size();                                                      \
                                                                                                   \
    auto idx = first;                                                                              \
    while (idx < num_xs) {                                                                         \
      ys[idx] = func_name(x1s[idx], x2s[idx], x3s[idx]);                                           \
      idx += step;                                                                                 \
    }                                                                                              \
  }
|
||||
|
||||
// Defines a Catch test case named "Unit_Device_<intrinsic>_Accuracy_Positive" whose body
// passes `<intrinsic>_kernel`, the reference function and the validator builder to
// TernaryFloatingPointTest.
#define MATH_TERNARY_HP_TEST_DEF_IMPL(intrinsic, reference, builder)  \
  TEST_CASE("Unit_Device_" #intrinsic "_Accuracy_Positive") {         \
    TernaryFloatingPointTest(intrinsic##_kernel, reference, builder); \
  }
|
||||
|
||||
// Shorthand for MATH_TERNARY_HP_TEST_DEF_IMPL that uses `<intrinsic>_validator_builder`
// as the validator builder.
// NOTE(review): unlike the unary and binary variants this macro takes a third argument, but
// the expansion never references it; it is kept so existing three-argument uses still work.
#define MATH_TERNARY_HP_TEST_DEF(intrinsic, reference, ignored_validator_builder) \
  MATH_TERNARY_HP_TEST_DEF_IMPL(intrinsic, reference, intrinsic##_validator_builder)
|
||||
|
||||
// Expands to the declarator (without a body) of the static function
// `<intrinsic>_validator_builder(float target, float x1, float x2, float x3)`, which returns
// std::unique_ptr<MatcherBase<float>>. The function body is written after the macro invocation.
#define MATH_TERNARY_HP_VALIDATOR_BUILDER_DEF(intrinsic)                    \
  static std::unique_ptr<MatcherBase<float>> intrinsic##_validator_builder( \
      float target, float x1, float x2, float x3)
|
||||
@@ -7,15 +7,8 @@ in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
<<<<<<< HEAD
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
=======
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
>>>>>>> c08a2a5d (Merge branch 'develop' into casting_int_tests)
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
@@ -33,6 +26,7 @@ THE SOFTWARE.
|
||||
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
|
||||
#include "Float16.hh"
|
||||
#include "thread_pool.hh"
|
||||
#include "validators.hh"
|
||||
|
||||
@@ -47,6 +41,15 @@ operator<<(std::ostream& os, const std::pair<T, U>& p) {
|
||||
<< std::setprecision(default_prec);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<sizeof(T) / sizeof(decltype(T().x)) == 2 && !std::is_same_v<T, __half2>, std::ostream&>
|
||||
operator<<(std::ostream& os, const T& p) {
|
||||
const auto default_prec = os.precision();
|
||||
return os << "<" << std::setprecision(std::numeric_limits<decltype(T().x)>::max_digits10 - 1) << p.x << ", "
|
||||
<< std::setprecision(std::numeric_limits<decltype(T().x)>::max_digits10 - 1) << p.y << ">"
|
||||
<< std::setprecision(default_prec);
|
||||
}
|
||||
|
||||
// This class represents a generic numerical accuracy math test. Template parameter T is the output
|
||||
// type of the function being tested, and template parameter pack Ts represents the input types. The
|
||||
// constructor takes a kernel with the signature void(T*, const size_t, Ts*...). The first kernel
|
||||
@@ -107,11 +110,7 @@ template <typename T, typename... Ts> class MathTest {
|
||||
template <bool parallel, typename RT, typename ValidatorBuilder, typename... RTs, size_t... I>
|
||||
void RunImpl(const ValidatorBuilder& validator_builder, const size_t grid_dim,
|
||||
const size_t block_dim, RT (*const ref_func)(RTs...), const size_t num_args,
|
||||
<<<<<<< HEAD
|
||||
std::index_sequence<I...> is, const Ts*... xss) {
|
||||
=======
|
||||
std::index_sequence<I...>, const Ts*... xss) {
|
||||
>>>>>>> c08a2a5d (Merge branch 'develop' into casting_int_tests)
|
||||
const auto xss_tup = std::make_tuple(xss...);
|
||||
|
||||
constexpr auto f = [](auto dst, auto src, size_t size) {
|
||||
@@ -196,6 +195,8 @@ template <typename T, typename... Ts> class MathTest {
|
||||
|
||||
template <typename T> struct RefType {};
|
||||
|
||||
template <> struct RefType<Float16> { using type = float; };
|
||||
|
||||
template <> struct RefType<float> { using type = double; };
|
||||
|
||||
template <> struct RefType<double> { using type = long double; };
|
||||
|
||||
@@ -74,8 +74,13 @@ void TernaryFloatingPointBruteForceTest(kernel_sig<T, TArg, TArg, TArg> kernel,
|
||||
thread_pool.Post([=, &x1s, &x2s, &x3s] {
|
||||
const auto generator = [=] {
|
||||
static thread_local std::mt19937 rng(std::random_device{}());
|
||||
std::uniform_real_distribution<RefType_t<TArg>> unif_dist(a, b);
|
||||
return static_cast<TArg>(unif_dist(rng));
|
||||
if constexpr (std::is_same_v<TArg, Float16>) {
|
||||
std::uniform_real_distribution<RefType_t<Float16>> unif_dist(-FLOAT16_MAX, FLOAT16_MAX);
|
||||
return static_cast<Float16>(unif_dist(rng));
|
||||
} else {
|
||||
std::uniform_real_distribution<RefType_t<TArg>> unif_dist(a, b);
|
||||
return static_cast<TArg>(unif_dist(rng));
|
||||
}
|
||||
};
|
||||
std::generate(x1s.ptr() + base_idx, x1s.ptr() + base_idx + sub_batch_size, generator);
|
||||
std::generate(x2s.ptr() + base_idx, x2s.ptr() + base_idx + sub_batch_size, generator);
|
||||
@@ -93,10 +98,11 @@ void TernaryFloatingPointBruteForceTest(kernel_sig<T, TArg, TArg, TArg> kernel,
|
||||
|
||||
template <typename T, typename TArg, typename RT, typename RTArg, typename ValidatorBuilder>
|
||||
void TernaryFloatingPointSpecialValuesTest(kernel_sig<T, TArg, TArg, TArg> kernel,
|
||||
ref_sig<RT, RTArg, RTArg, RTArg> ref_func,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
ref_sig<RT, RTArg, RTArg, RTArg> ref_func,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
const auto [grid_size, block_size] = GetOccupancyMaxPotentialBlockSize(kernel);
|
||||
const auto values = std::get<SpecialVals<TArg>>(kSpecialValRegistry);
|
||||
using SpecialValsType = std::conditional_t<std::is_same_v<TArg, Float16>, float, TArg>;
|
||||
const auto values = std::get<SpecialVals<SpecialValsType>>(kSpecialValRegistry);
|
||||
|
||||
const auto size = values.size * values.size * values.size;
|
||||
LinearAllocGuard<TArg> x1s{LinearAllocs::hipHostMalloc, size * sizeof(TArg)};
|
||||
@@ -119,13 +125,16 @@ void TernaryFloatingPointSpecialValuesTest(kernel_sig<T, TArg, TArg, TArg> kerne
|
||||
}
|
||||
|
||||
template <typename T, typename TArg, typename RT, typename RTArg, typename ValidatorBuilder>
|
||||
void TernaryFloatingPointTest(kernel_sig<T, TArg, TArg, TArg> kernel, ref_sig<RT, RTArg, RTArg, RTArg> ref_func,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
void TernaryFloatingPointTest(kernel_sig<T, TArg, TArg, TArg> kernel,
|
||||
ref_sig<RT, RTArg, RTArg, RTArg> ref_func,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
SECTION("Special values") {
|
||||
TernaryFloatingPointSpecialValuesTest(kernel, ref_func, validator_builder);
|
||||
}
|
||||
|
||||
SECTION("Brute force") { TernaryFloatingPointBruteForceTest(kernel, ref_func, validator_builder); }
|
||||
SECTION("Brute force") {
|
||||
TernaryFloatingPointBruteForceTest(kernel, ref_func, validator_builder);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -138,7 +147,5 @@ void TernaryFloatingPointTest(kernel_sig<T, TArg, TArg, TArg> kernel, ref_sig<RT
|
||||
const auto ulp = std::is_same_v<float, TestType> ? sp_ulp : dp_ulp; \
|
||||
\
|
||||
TernaryFloatingPointTest(kern_name##_kernel<TestType>, ref, \
|
||||
ULPValidatorBuilderFactory<TestType>(ulp)); \
|
||||
\
|
||||
ULPValidatorBuilderFactory<TestType>(ulp)); \
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,49 @@ namespace cg = cooperative_groups;
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T, typename RT, typename RTArg, typename ValidatorBuilder>
|
||||
void UnaryHalfPrecisionBruteForceTest(kernel_sig<T, Float16> kernel, ref_sig<RT, RTArg> ref_func,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
const auto [grid_size, block_size] = GetOccupancyMaxPotentialBlockSize(kernel);
|
||||
uint64_t stop = std::numeric_limits<uint16_t>::max() + 1ul;
|
||||
const auto max_batch_size =
|
||||
std::min(GetMaxAllowedDeviceMemoryUsage() / (sizeof(Float16) + sizeof(T)), stop);
|
||||
LinearAllocGuard<Float16> values{LinearAllocs::hipHostMalloc, max_batch_size * sizeof(Float16)};
|
||||
|
||||
MathTest math_test(kernel, max_batch_size);
|
||||
|
||||
auto batch_size = max_batch_size;
|
||||
const auto num_threads = thread_pool.thread_count();
|
||||
|
||||
for (uint64_t v = 0u; v < stop;) {
|
||||
batch_size = std::min<uint64_t>(max_batch_size, stop - v);
|
||||
|
||||
const auto min_sub_batch_size = batch_size / num_threads;
|
||||
const auto tail = batch_size % num_threads;
|
||||
|
||||
auto base_idx = 0u;
|
||||
for (auto i = 0u; i < num_threads; ++i) {
|
||||
const auto sub_batch_size = min_sub_batch_size + (i < tail);
|
||||
|
||||
thread_pool.Post([=, &values] {
|
||||
auto t = v;
|
||||
uint16_t val;
|
||||
for (auto j = 0u; j < sub_batch_size; ++j) {
|
||||
val = static_cast<uint16_t>(t++);
|
||||
values.ptr()[base_idx + j] = *reinterpret_cast<Float16*>(&val);
|
||||
}
|
||||
});
|
||||
|
||||
v += sub_batch_size;
|
||||
base_idx += sub_batch_size;
|
||||
}
|
||||
|
||||
thread_pool.Wait();
|
||||
|
||||
math_test.Run(validator_builder, grid_size, block_size, ref_func, batch_size, values.ptr());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename RT, typename RTArg, typename ValidatorBuilder>
|
||||
void UnarySinglePrecisionBruteForceTest(kernel_sig<T, float> kernel, ref_sig<RT, RTArg> ref_func,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
@@ -163,6 +206,12 @@ void UnaryDoublePrecisionSpecialValuesTest(kernel_sig<T, double> kernel,
|
||||
values.data);
|
||||
}
|
||||
|
||||
template <typename T, typename RT, typename RTArg, typename ValidatorBuilder>
|
||||
void UnaryHalfPrecisionTest(kernel_sig<T, Float16> kernel, ref_sig<RT, RTArg> ref,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
SECTION("Brute force") { UnaryHalfPrecisionBruteForceTest(kernel, ref, validator_builder); }
|
||||
}
|
||||
|
||||
template <typename T, typename RT, typename RTArg, typename ValidatorBuilder>
|
||||
void UnarySinglePrecisionTest(kernel_sig<T, float> kernel, ref_sig<RT, RTArg> ref,
|
||||
const ValidatorBuilder& validator_builder) {
|
||||
|
||||
@@ -61,21 +61,21 @@ template <typename T, typename Matcher> class ValidatorBase : public MatcherBase
|
||||
};
|
||||
|
||||
template <typename T> auto ULPValidatorBuilderFactory(int64_t ulps) {
|
||||
return [=](T target, auto&&... args) {
|
||||
return [=](T target, auto&&...) {
|
||||
return std::make_unique<ValidatorBase<T, Catch::Matchers::Floating::WithinUlpsMatcher>>(
|
||||
target, Catch::WithinULP(target, ulps));
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T> auto AbsValidatorBuilderFactory(double margin) {
|
||||
return [=](T target, auto&&... args) {
|
||||
return [=](T target, auto&&...) {
|
||||
return std::make_unique<ValidatorBase<T, Catch::Matchers::Floating::WithinAbsMatcher>>(
|
||||
target, Catch::WithinAbs(target, margin));
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T> auto RelValidatorBuilderFactory(T margin) {
|
||||
return [=](T target, auto&&... args) {
|
||||
return [=](T target, auto&&...) {
|
||||
return std::make_unique<ValidatorBase<T, Catch::Matchers::Floating::WithinRelMatcher>>(
|
||||
target, Catch::WithinRel(target, margin));
|
||||
};
|
||||
@@ -104,7 +104,7 @@ template <typename T> class EqValidator : public MatcherBase<T> {
|
||||
};
|
||||
|
||||
template <typename T> auto EqValidatorBuilderFactory() {
|
||||
return [](T val, auto&&... args) { return std::make_unique<EqValidator<T>>(val); };
|
||||
return [](T val, auto&&...) { return std::make_unique<EqValidator<T>>(val); };
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename VBF, typename VBS>
|
||||
@@ -128,25 +128,25 @@ class PairValidator : public MatcherBase<std::pair<T, U>> {
|
||||
|
||||
template <typename T, typename ValidatorBuilder>
|
||||
auto PairValidatorBuilderFactory(const ValidatorBuilder& vb) {
|
||||
return [=](const std::pair<T, T>& t, auto&&... args) {
|
||||
return [=](const std::pair<T, T>& t, auto&&...) {
|
||||
return std::make_unique<PairValidator<T, T, ValidatorBuilder, ValidatorBuilder>>(t, vb, vb);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename VBF, typename VBS>
|
||||
auto PairValidatorBuilderFactory(const VBF& vbf, const VBS& vbs) {
|
||||
return [=](const std::pair<T, U>& t, auto&&... args) {
|
||||
return [=](const std::pair<T, U>& t, auto&&...) {
|
||||
return std::make_unique<PairValidator<T, U, VBF, VBS>>(t, vbf, vbs);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T> class NopValidator : public MatcherBase<T> {
|
||||
public:
|
||||
bool match(const T& val) const override { return true; }
|
||||
bool match(const T&) const override { return true; }
|
||||
|
||||
std::string describe() const override { return ""; }
|
||||
};
|
||||
|
||||
template <typename T> auto NopValidatorBuilderFactory() {
|
||||
return [](auto&&... args) { return std::make_unique<NopValidator<T>>(); };
|
||||
return [](auto&&...) { return std::make_unique<NopValidator<T>>(); };
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user