EXSWHTEC-333 - Extend tests for warp shlf and shfl_xor functions to support half-precision types #420

Change-Id: I1da47a0a4b8d15b0d2d569eb4769aa40207aade2
This commit is contained in:
Mirza Halilčević
2023-12-28 15:09:55 +01:00
کامیت شده توسط Rakesh Roy
والد 7659470dbc
کامیت 91cff794b8
4فایلهای تغییر یافته به همراه24 افزوده شده و 8 حذف شده
@@ -21,6 +21,16 @@ THE SOFTWARE.
#include <hip_test_common.hh>
#include <hip/hip_cooperative_groups.h>
#include <hip/hip_fp16.h>
static bool operator==(__half x, __half y) {
// __heq doesn't have a __host__ version
return static_cast<__half_raw>(x).data == static_cast<__half_raw>(y).data;
}
static bool operator!=(__half x, __half y) { return !(x == y); }
static bool operator==(__half2 x, __half2 y) { return __hbeq2(x, y); }
static bool operator!=(__half2 x, __half2 y) { return !(x == y); }
static __device__ bool deactivate_thread(const uint64_t* const active_masks) {
const auto warp =
+3 -7
مشاهده پرونده
@@ -100,7 +100,7 @@ template <typename T> class WarpShfl : public WarpShflTest<WarpShfl<T>, T> {
* - Device supports warp shuffle
*/
TEMPLATE_TEST_CASE("Unit_Warp_Shfl_Positive_Basic", "", int, unsigned int, long, unsigned long,
long long, unsigned long long, float, double) {
long long, unsigned long long, float, double, __half, __half2) {
int device;
hipDeviceProp_t device_properties;
HIP_CHECK(hipGetDevice(&device));
@@ -111,11 +111,7 @@ TEMPLATE_TEST_CASE("Unit_Warp_Shfl_Positive_Basic", "", int, unsigned int, long,
return;
}
SECTION("Shfl with specified active mask and input values") {
WarpShfl<TestType>().run(false);
}
SECTION("Shfl with specified active mask and input values") { WarpShfl<TestType>().run(false); }
SECTION("Shfl with random active mask and input values") {
WarpShfl<TestType>().run(true);
}
SECTION("Shfl with random active mask and input values") { WarpShfl<TestType>().run(true); }
}
@@ -82,6 +82,16 @@ template <typename Derived, typename T> class WarpShflTest {
return static_cast<T>(
GenerateRandomReal(std::numeric_limits<T>().min(), std::numeric_limits<T>().max()));
});
} else if constexpr (std::is_same_v<__half, T>) {
std::generate_n(input, grid_.thread_count_, [] {
return __float2half(GenerateRandomReal(std::numeric_limits<float>().min(),
std::numeric_limits<float>().max()));
});
} else if constexpr (std::is_same_v<__half2, T>) {
std::generate_n(input, grid_.thread_count_, [] {
return __float2half2_rn(GenerateRandomReal(std::numeric_limits<float>().min(),
std::numeric_limits<float>().max()));
});
} else {
std::generate_n(input, grid_.thread_count_, [] {
return static_cast<T>(GenerateRandomInteger(std::numeric_limits<T>().min(),
@@ -97,7 +97,7 @@ template <typename T> class WarpShflXOR : public WarpShflTest<WarpShflXOR<T>, T>
* - Device supports warp shuffle
*/
TEMPLATE_TEST_CASE("Unit_Warp_Shfl_XOR_Positive_Basic", "", int, unsigned int, long, unsigned long,
long long, unsigned long long, float, double) {
long long, unsigned long long, float, double, __half, __half2) {
int device;
hipDeviceProp_t device_properties;
HIP_CHECK(hipGetDevice(&device));