EXSWHTEC-333 - Extend tests for warp shlf and shfl_xor functions to support half-precision types #420
Change-Id: I1da47a0a4b8d15b0d2d569eb4769aa40207aade2
This commit is contained in:
کامیت شده توسط
Rakesh Roy
والد
7659470dbc
کامیت
91cff794b8
@@ -21,6 +21,16 @@ THE SOFTWARE.
|
||||
|
||||
#include <hip_test_common.hh>
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
static bool operator==(__half x, __half y) {
|
||||
// __heq doesn't have a __host__ version
|
||||
return static_cast<__half_raw>(x).data == static_cast<__half_raw>(y).data;
|
||||
}
|
||||
static bool operator!=(__half x, __half y) { return !(x == y); }
|
||||
|
||||
static bool operator==(__half2 x, __half2 y) { return __hbeq2(x, y); }
|
||||
static bool operator!=(__half2 x, __half2 y) { return !(x == y); }
|
||||
|
||||
static __device__ bool deactivate_thread(const uint64_t* const active_masks) {
|
||||
const auto warp =
|
||||
|
||||
@@ -100,7 +100,7 @@ template <typename T> class WarpShfl : public WarpShflTest<WarpShfl<T>, T> {
|
||||
* - Device supports warp shuffle
|
||||
*/
|
||||
TEMPLATE_TEST_CASE("Unit_Warp_Shfl_Positive_Basic", "", int, unsigned int, long, unsigned long,
|
||||
long long, unsigned long long, float, double) {
|
||||
long long, unsigned long long, float, double, __half, __half2) {
|
||||
int device;
|
||||
hipDeviceProp_t device_properties;
|
||||
HIP_CHECK(hipGetDevice(&device));
|
||||
@@ -111,11 +111,7 @@ TEMPLATE_TEST_CASE("Unit_Warp_Shfl_Positive_Basic", "", int, unsigned int, long,
|
||||
return;
|
||||
}
|
||||
|
||||
SECTION("Shfl with specified active mask and input values") {
|
||||
WarpShfl<TestType>().run(false);
|
||||
}
|
||||
SECTION("Shfl with specified active mask and input values") { WarpShfl<TestType>().run(false); }
|
||||
|
||||
SECTION("Shfl with random active mask and input values") {
|
||||
WarpShfl<TestType>().run(true);
|
||||
}
|
||||
SECTION("Shfl with random active mask and input values") { WarpShfl<TestType>().run(true); }
|
||||
}
|
||||
|
||||
@@ -82,6 +82,16 @@ template <typename Derived, typename T> class WarpShflTest {
|
||||
return static_cast<T>(
|
||||
GenerateRandomReal(std::numeric_limits<T>().min(), std::numeric_limits<T>().max()));
|
||||
});
|
||||
} else if constexpr (std::is_same_v<__half, T>) {
|
||||
std::generate_n(input, grid_.thread_count_, [] {
|
||||
return __float2half(GenerateRandomReal(std::numeric_limits<float>().min(),
|
||||
std::numeric_limits<float>().max()));
|
||||
});
|
||||
} else if constexpr (std::is_same_v<__half2, T>) {
|
||||
std::generate_n(input, grid_.thread_count_, [] {
|
||||
return __float2half2_rn(GenerateRandomReal(std::numeric_limits<float>().min(),
|
||||
std::numeric_limits<float>().max()));
|
||||
});
|
||||
} else {
|
||||
std::generate_n(input, grid_.thread_count_, [] {
|
||||
return static_cast<T>(GenerateRandomInteger(std::numeric_limits<T>().min(),
|
||||
|
||||
@@ -97,7 +97,7 @@ template <typename T> class WarpShflXOR : public WarpShflTest<WarpShflXOR<T>, T>
|
||||
* - Device supports warp shuffle
|
||||
*/
|
||||
TEMPLATE_TEST_CASE("Unit_Warp_Shfl_XOR_Positive_Basic", "", int, unsigned int, long, unsigned long,
|
||||
long long, unsigned long long, float, double) {
|
||||
long long, unsigned long long, float, double, __half, __half2) {
|
||||
int device;
|
||||
hipDeviceProp_t device_properties;
|
||||
HIP_CHECK(hipGetDevice(&device));
|
||||
|
||||
مرجع در شماره جدید
Block a user