diff --git a/include/hip/hcc_detail/hip_vector_types.h b/include/hip/hcc_detail/hip_vector_types.h index 3a10a06372..706b9aa3a7 100644 --- a/include/hip/hcc_detail/hip_vector_types.h +++ b/include/hip/hcc_detail/hip_vector_types.h @@ -40,6 +40,7 @@ THE SOFTWARE. #endif #if defined(__cplusplus) && defined(__clang__) + #include #include namespace hip_impl { @@ -68,6 +69,23 @@ THE SOFTWARE. } }; + friend + inline + std::ostream& operator<<(std::ostream& os, + const Scalar_accessor& x) noexcept { + return os << x.data[idx]; + } + friend + inline + std::istream& operator>>(std::istream& is, + Scalar_accessor& x) noexcept { + T tmp; + is >> tmp; + x.data[idx] = tmp; + + return is; + } + // Idea from https://t0rakka.silvrback.com/simd-scalar-accessor Vector data; @@ -76,6 +94,17 @@ THE SOFTWARE. __host__ __device__ operator T() const volatile noexcept { return data[idx]; } + __host__ __device__ + operator T&() noexcept { + return reinterpret_cast< + T (&)[sizeof(Vector) / sizeof(T)]>(data)[idx]; + } + __host__ __device__ + operator volatile T&() volatile noexcept { + return reinterpret_cast< + volatile T (&)[sizeof(Vector) / sizeof(T)]>(data)[idx]; + } + __host__ __device__ Address operator&() const noexcept { return Address{this}; } @@ -198,6 +227,8 @@ THE SOFTWARE. Native_vec_ data; hip_impl::Scalar_accessor x; }; + + using value_type = T; }; template @@ -209,6 +240,8 @@ THE SOFTWARE. hip_impl::Scalar_accessor x; hip_impl::Scalar_accessor y; }; + + using value_type = T; }; template @@ -367,6 +400,8 @@ THE SOFTWARE. T z; }; }; + + using value_type = T; }; template @@ -380,6 +415,8 @@ THE SOFTWARE. hip_impl::Scalar_accessor z; hip_impl::Scalar_accessor w; }; + + using value_type = T; }; template diff --git a/tests/src/deviceLib/hipVectorTypes.cpp b/tests/src/deviceLib/hipVectorTypes.cpp index 14479881ff..70c8320073 100644 --- a/tests/src/deviceLib/hipVectorTypes.cpp +++ b/tests/src/deviceLib/hipVectorTypes.cpp @@ -34,6 +34,7 @@ THE SOFTWARE. #include #include #include +#include #include using namespace std; @@ -157,6 +158,24 @@ bool TestVectorType() { if (f1 == f2) return false; if (!(f1 != f2)) return false; + using T = typename V::value_type; + + const T& x = f1.x; + T& y = f2.x; + const volatile T& z = f3.x; + volatile T& w = f2.x; + + if (x != T{3}) return false; + if (y != T{4}) return false; + if (z != T{3}) return false; + if (w != T{4}) return false; + + stringstream str; + str << f1.x; + str >> f2.x; + + if (f1.x != f2.x) return false; + return true; } diff --git a/tests/src/deviceLib/hipVectorTypesDevice.cpp b/tests/src/deviceLib/hipVectorTypesDevice.cpp index ba23931ee5..4bf5d2c87d 100644 --- a/tests/src/deviceLib/hipVectorTypesDevice.cpp +++ b/tests/src/deviceLib/hipVectorTypesDevice.cpp @@ -149,6 +149,20 @@ bool TestVectorType() { if (f1 == f2) return false; if (!(f1 != f2)) return false; + #if 0 // TODO: investigate on GFX8 + using T = typename V::value_type; + + const T& x = f1.x; + T& y = f2.x; + const volatile T& z = f3.x; + volatile T& w = f2.x; + + if (x != T{3}) return false; + if (y != T{4}) return false; + if (z != T{3}) return false; + if (w != T{4}) return false; + #endif + return true; }