diff --git a/include/hip/hcc_detail/hip_runtime.h b/include/hip/hcc_detail/hip_runtime.h index a614e3599d..0b70570352 100644 --- a/include/hip/hcc_detail/hip_runtime.h +++ b/include/hip/hcc_detail/hip_runtime.h @@ -37,6 +37,7 @@ THE SOFTWARE. //#include #if __cplusplus #include +#include #else #include #include @@ -198,35 +199,93 @@ __device__ int __hip_move_dpp_N(int src); #if defined __HCC__ -template < - typename std::common_type::type f> -class Coordinates { - using R = decltype(f(0)); +namespace hip_impl { + struct GroupId { + using R = decltype(hc_get_group_id(0)); - struct X { - __device__ operator R() const { return f(0); } - __device__ uint32_t operator=(R _) { return f(0); } - }; - struct Y { - __device__ operator R() const { return f(1); } - __device__ uint32_t operator=(R _) { return f(1); } - }; - struct Z { - __device__ operator R() const { return f(2); } - __device__ uint32_t operator=(R _) { return f(2); } - }; + __device__ + R operator()(std::uint32_t x) const noexcept { return hc_get_group_id(x); } + }; + struct GroupSize { + using R = decltype(hc_get_group_size(0)); - public: - static constexpr X x{}; - static constexpr Y y{}; - static constexpr Z z{}; + __device__ + R operator()(std::uint32_t x) const noexcept { + return hc_get_group_size(x); + } + }; + struct NumGroups { + using R = decltype(hc_get_num_groups(0)); + + __device__ + R operator()(std::uint32_t x) const noexcept { + return hc_get_num_groups(x); + } + }; + struct WorkitemId { + using R = decltype(hc_get_workitem_id(0)); + + __device__ + R operator()(std::uint32_t x) const noexcept { + return hc_get_workitem_id(x); + } + }; +} // Namespace hip_impl. + +template +struct Coordinates { + using R = decltype(F{}(0)); + + struct X { __device__ operator R() const noexcept { return F{}(0); } }; + struct Y { __device__ operator R() const noexcept { return F{}(1); } }; + struct Z { __device__ operator R() const noexcept { return F{}(2); } }; + + static constexpr X x{}; + static constexpr Y y{}; + static constexpr Z z{}; }; -static constexpr Coordinates blockDim; -static constexpr Coordinates blockIdx; -static constexpr Coordinates gridDim; -static constexpr Coordinates threadIdx; +inline +__device__ +std::uint32_t operator*(Coordinates::X, + Coordinates::X) noexcept { + return hc_get_grid_size(0); +} +inline +__device__ +std::uint32_t operator*(Coordinates::X, + Coordinates::X) noexcept { + return hc_get_grid_size(0); +} +inline +__device__ +std::uint32_t operator*(Coordinates::Y, + Coordinates::Y) noexcept { + return hc_get_grid_size(1); +} +inline +__device__ +std::uint32_t operator*(Coordinates::Y, + Coordinates::Y) noexcept { + return hc_get_grid_size(1); +} +inline +__device__ +std::uint32_t operator*(Coordinates::Z, + Coordinates::Z) noexcept { + return hc_get_grid_size(2); +} +inline +__device__ +std::uint32_t operator*(Coordinates::Z, + Coordinates::Z) noexcept { + return hc_get_grid_size(2); +} + +static constexpr Coordinates blockDim{}; +static constexpr Coordinates blockIdx{}; +static constexpr Coordinates gridDim{}; +static constexpr Coordinates threadIdx{}; #define hipThreadIdx_x (hc_get_workitem_id(0)) #define hipThreadIdx_y (hc_get_workitem_id(1))