From 91cff794b812cb5f87bc1a286383ccb57c0ec56c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mirza=20Halil=C4=8Devi=C4=87?= <109971222+mirza-halilcevic@users.noreply.github.com> Date: Thu, 28 Dec 2023 15:09:55 +0100 Subject: [PATCH] EXSWHTEC-333 - Extend tests for warp shlf and shfl_xor functions to support half-precision types #420 Change-Id: I1da47a0a4b8d15b0d2d569eb4769aa40207aade2 --- catch/unit/warp/warp_common.hh | 10 ++++++++++ catch/unit/warp/warp_shfl.cc | 10 +++------- catch/unit/warp/warp_shfl_common.hh | 10 ++++++++++ catch/unit/warp/warp_shfl_xor.cc | 2 +- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/catch/unit/warp/warp_common.hh b/catch/unit/warp/warp_common.hh index d09e96837e..875af88e54 100644 --- a/catch/unit/warp/warp_common.hh +++ b/catch/unit/warp/warp_common.hh @@ -21,6 +21,16 @@ THE SOFTWARE. #include #include +#include + +static bool operator==(__half x, __half y) { + // __heq doesn't have a __host__ version + return static_cast<__half_raw>(x).data == static_cast<__half_raw>(y).data; +} +static bool operator!=(__half x, __half y) { return !(x == y); } + +static bool operator==(__half2 x, __half2 y) { return __hbeq2(x, y); } +static bool operator!=(__half2 x, __half2 y) { return !(x == y); } static __device__ bool deactivate_thread(const uint64_t* const active_masks) { const auto warp = diff --git a/catch/unit/warp/warp_shfl.cc b/catch/unit/warp/warp_shfl.cc index babb814fe4..73913ef672 100644 --- a/catch/unit/warp/warp_shfl.cc +++ b/catch/unit/warp/warp_shfl.cc @@ -100,7 +100,7 @@ template class WarpShfl : public WarpShflTest, T> { * - Device supports warp shuffle */ TEMPLATE_TEST_CASE("Unit_Warp_Shfl_Positive_Basic", "", int, unsigned int, long, unsigned long, - long long, unsigned long long, float, double) { + long long, unsigned long long, float, double, __half, __half2) { int device; hipDeviceProp_t device_properties; HIP_CHECK(hipGetDevice(&device)); @@ -111,11 +111,7 @@ TEMPLATE_TEST_CASE("Unit_Warp_Shfl_Positive_Basic", "", int, unsigned int, long, return; } - SECTION("Shfl with specified active mask and input values") { - WarpShfl().run(false); - } + SECTION("Shfl with specified active mask and input values") { WarpShfl().run(false); } - SECTION("Shfl with random active mask and input values") { - WarpShfl().run(true); - } + SECTION("Shfl with random active mask and input values") { WarpShfl().run(true); } } diff --git a/catch/unit/warp/warp_shfl_common.hh b/catch/unit/warp/warp_shfl_common.hh index 97b2677f31..44097c8f0a 100644 --- a/catch/unit/warp/warp_shfl_common.hh +++ b/catch/unit/warp/warp_shfl_common.hh @@ -82,6 +82,16 @@ template class WarpShflTest { return static_cast( GenerateRandomReal(std::numeric_limits().min(), std::numeric_limits().max())); }); + } else if constexpr (std::is_same_v<__half, T>) { + std::generate_n(input, grid_.thread_count_, [] { + return __float2half(GenerateRandomReal(std::numeric_limits().min(), + std::numeric_limits().max())); + }); + } else if constexpr (std::is_same_v<__half2, T>) { + std::generate_n(input, grid_.thread_count_, [] { + return __float2half2_rn(GenerateRandomReal(std::numeric_limits().min(), + std::numeric_limits().max())); + }); } else { std::generate_n(input, grid_.thread_count_, [] { return static_cast(GenerateRandomInteger(std::numeric_limits().min(), diff --git a/catch/unit/warp/warp_shfl_xor.cc b/catch/unit/warp/warp_shfl_xor.cc index 3edbca1b3a..267bc91119 100644 --- a/catch/unit/warp/warp_shfl_xor.cc +++ b/catch/unit/warp/warp_shfl_xor.cc @@ -97,7 +97,7 @@ template class WarpShflXOR : public WarpShflTest, T> * - Device supports warp shuffle */ TEMPLATE_TEST_CASE("Unit_Warp_Shfl_XOR_Positive_Basic", "", int, unsigned int, long, unsigned long, - long long, unsigned long long, float, double) { + long long, unsigned long long, float, double, __half, __half2) { int device; hipDeviceProp_t device_properties; HIP_CHECK(hipGetDevice(&device));