diff --git a/catch/unit/warp/warp_common.hh b/catch/unit/warp/warp_common.hh index d09e96837e..875af88e54 100644 --- a/catch/unit/warp/warp_common.hh +++ b/catch/unit/warp/warp_common.hh @@ -21,6 +21,16 @@ THE SOFTWARE. #include #include +#include + +static bool operator==(__half x, __half y) { + // __heq doesn't have a __host__ version + return static_cast<__half_raw>(x).data == static_cast<__half_raw>(y).data; +} +static bool operator!=(__half x, __half y) { return !(x == y); } + +static bool operator==(__half2 x, __half2 y) { return __hbeq2(x, y); } +static bool operator!=(__half2 x, __half2 y) { return !(x == y); } static __device__ bool deactivate_thread(const uint64_t* const active_masks) { const auto warp = diff --git a/catch/unit/warp/warp_shfl.cc b/catch/unit/warp/warp_shfl.cc index babb814fe4..73913ef672 100644 --- a/catch/unit/warp/warp_shfl.cc +++ b/catch/unit/warp/warp_shfl.cc @@ -100,7 +100,7 @@ template class WarpShfl : public WarpShflTest, T> { * - Device supports warp shuffle */ TEMPLATE_TEST_CASE("Unit_Warp_Shfl_Positive_Basic", "", int, unsigned int, long, unsigned long, - long long, unsigned long long, float, double) { + long long, unsigned long long, float, double, __half, __half2) { int device; hipDeviceProp_t device_properties; HIP_CHECK(hipGetDevice(&device)); @@ -111,11 +111,7 @@ TEMPLATE_TEST_CASE("Unit_Warp_Shfl_Positive_Basic", "", int, unsigned int, long, return; } - SECTION("Shfl with specified active mask and input values") { - WarpShfl().run(false); - } + SECTION("Shfl with specified active mask and input values") { WarpShfl().run(false); } - SECTION("Shfl with random active mask and input values") { - WarpShfl().run(true); - } + SECTION("Shfl with random active mask and input values") { WarpShfl().run(true); } } diff --git a/catch/unit/warp/warp_shfl_common.hh b/catch/unit/warp/warp_shfl_common.hh index 97b2677f31..44097c8f0a 100644 --- a/catch/unit/warp/warp_shfl_common.hh +++ b/catch/unit/warp/warp_shfl_common.hh @@ -82,6 +82,16 @@ template class WarpShflTest { return static_cast( GenerateRandomReal(std::numeric_limits().min(), std::numeric_limits().max())); }); + } else if constexpr (std::is_same_v<__half, T>) { + std::generate_n(input, grid_.thread_count_, [] { + return __float2half(GenerateRandomReal(std::numeric_limits().min(), + std::numeric_limits().max())); + }); + } else if constexpr (std::is_same_v<__half2, T>) { + std::generate_n(input, grid_.thread_count_, [] { + return __float2half2_rn(GenerateRandomReal(std::numeric_limits().min(), + std::numeric_limits().max())); + }); } else { std::generate_n(input, grid_.thread_count_, [] { return static_cast(GenerateRandomInteger(std::numeric_limits().min(), diff --git a/catch/unit/warp/warp_shfl_xor.cc b/catch/unit/warp/warp_shfl_xor.cc index 3edbca1b3a..267bc91119 100644 --- a/catch/unit/warp/warp_shfl_xor.cc +++ b/catch/unit/warp/warp_shfl_xor.cc @@ -97,7 +97,7 @@ template class WarpShflXOR : public WarpShflTest, T> * - Device supports warp shuffle */ TEMPLATE_TEST_CASE("Unit_Warp_Shfl_XOR_Positive_Basic", "", int, unsigned int, long, unsigned long, - long long, unsigned long long, float, double) { + long long, unsigned long long, float, double, __half, __half2) { int device; hipDeviceProp_t device_properties; HIP_CHECK(hipGetDevice(&device));