Optimise the gridDim.n * blockDim.m idiom (#1468)

This commit is contained in:
Alex Voicu
2019-09-30 06:09:23 +01:00
committed by Maneesh Gupta
parent b3e6ba50c3
commit ab8fe8a3d8
+84 -25
View File
@@ -37,6 +37,7 @@ THE SOFTWARE.
//#include <cstring>
#if __cplusplus
#include <cmath>
#include <cstdint>
#else
#include <math.h>
#include <string.h>
@@ -198,35 +199,93 @@ __device__ int __hip_move_dpp_N(int src);
#if defined __HCC__
template <
typename std::common_type<decltype(hc_get_group_id), decltype(hc_get_group_size),
decltype(hc_get_num_groups), decltype(hc_get_workitem_id)>::type f>
class Coordinates {
using R = decltype(f(0));
namespace hip_impl {
// Function object wrapping hc_get_group_id: operator() forwards its argument
// to hc_get_group_id and returns the result as R.
// NOTE(review): the nested X/Y/Z structs below call `f`, which is not declared
// anywhere in this struct. They look like lines carried over from the removed
// side of the diff (the Coordinates template that took a function `f` as its
// template argument) -- confirm against the real header before relying on them.
struct GroupId {
// Result type of hc_get_group_id.
using R = decltype(hc_get_group_id(0));
struct X {
// Implicit conversion to R; yields f(0).
__device__ operator R() const { return f(0); }
// The assigned value `_` is ignored; f(0) is returned instead.
__device__ uint32_t operator=(R _) { return f(0); }
};
struct Y {
// Implicit conversion to R; yields f(1).
__device__ operator R() const { return f(1); }
// The assigned value `_` is ignored; f(1) is returned instead.
__device__ uint32_t operator=(R _) { return f(1); }
};
struct Z {
// Implicit conversion to R; yields f(2).
__device__ operator R() const { return f(2); }
// The assigned value `_` is ignored; f(2) is returned instead.
__device__ uint32_t operator=(R _) { return f(2); }
};
// Device-only call operator: returns hc_get_group_id(x).
__device__
R operator()(std::uint32_t x) const noexcept { return hc_get_group_id(x); }
};
// Function object wrapping hc_get_group_size: operator() forwards its argument
// to hc_get_group_size and returns the result as R.
// NOTE(review): the `public:` label and the x/y/z members below use types
// X, Y and Z that are not declared in this struct. They look like lines carried
// over from the removed side of the diff -- confirm against the real header.
struct GroupSize {
// Result type of hc_get_group_size.
using R = decltype(hc_get_group_size(0));
public:
static constexpr X x{};
static constexpr Y y{};
static constexpr Z z{};
// Device-only call operator: returns hc_get_group_size(x).
__device__
R operator()(std::uint32_t x) const noexcept {
return hc_get_group_size(x);
}
};
// Function object with no data members that forwards its argument to
// hc_get_num_groups.
struct NumGroups {
    // Type produced by hc_get_num_groups.
    using R = decltype(hc_get_num_groups(0));

    // Device-only call operator: returns hc_get_num_groups(dim).
    __device__ R operator()(std::uint32_t dim) const noexcept
    {
        return hc_get_num_groups(dim);
    }
};
// Function object with no data members that forwards its argument to
// hc_get_workitem_id.
struct WorkitemId {
    // Type produced by hc_get_workitem_id.
    using R = decltype(hc_get_workitem_id(0));

    // Device-only call operator: returns hc_get_workitem_id(dim).
    __device__ R operator()(std::uint32_t dim) const noexcept
    {
        return hc_get_workitem_id(dim);
    }
};
} // Namespace hip_impl.
// Exposes a default-constructible callable F through three tag members x, y
// and z. Each tag converts implicitly to R by invoking a freshly constructed F
// with 0, 1 or 2 respectively; the conversions are device-only.
template <typename F>
struct Coordinates {
    // Result type of calling F with an integer argument.
    using R = decltype(F{}(0));

    // Tag whose conversion yields F{}(0).
    struct X {
        __device__ operator R() const noexcept { return F{}(0); }
    };

    // Tag whose conversion yields F{}(1).
    struct Y {
        __device__ operator R() const noexcept { return F{}(1); }
    };

    // Tag whose conversion yields F{}(2).
    struct Z {
        __device__ operator R() const noexcept { return F{}(2); }
    };

    // One shared constant instance of each tag.
    static constexpr X x{};
    static constexpr Y y{};
    static constexpr Z z{};
};
// NOTE(review): these four declarations pass the hc_get_* functions themselves
// as template arguments, which matches a Coordinates template parameterised on
// a function rather than the `template <typename F>` form. They appear to be
// the removed side of the diff. The same four names are declared again further
// down using the hip_impl functor types -- confirm which set the real header
// keeps.
static constexpr Coordinates<hc_get_group_size> blockDim;
static constexpr Coordinates<hc_get_group_id> blockIdx;
static constexpr Coordinates<hc_get_num_groups> gridDim;
static constexpr Coordinates<hc_get_workitem_id> threadIdx;
// Overload selected for `gridDim.x * blockDim.x`. Both operands are empty tag
// objects, so instead of converting each one and multiplying, the result is
// read with a single hc_get_grid_size(0) call.
// NOTE(review): presumably hc_get_grid_size(0) equals that product -- confirm
// against the HC runtime documentation.
__device__ inline std::uint32_t operator*(
    Coordinates<hip_impl::NumGroups>::X /* gridDim.x */,
    Coordinates<hip_impl::GroupSize>::X /* blockDim.x */) noexcept
{
    return hc_get_grid_size(0);
}
// Overload selected for `blockDim.x * gridDim.x` (operands in the opposite
// order to the gridDim-first form). Returns hc_get_grid_size(0) directly
// rather than converting both tag operands and multiplying.
__device__ inline std::uint32_t operator*(
    Coordinates<hip_impl::GroupSize>::X /* blockDim.x */,
    Coordinates<hip_impl::NumGroups>::X /* gridDim.x */) noexcept
{
    return hc_get_grid_size(0);
}
// Overload selected for `gridDim.y * blockDim.y`. Returns hc_get_grid_size(1)
// directly rather than converting both tag operands and multiplying.
// NOTE(review): presumably hc_get_grid_size(1) equals that product -- confirm
// against the HC runtime documentation.
__device__ inline std::uint32_t operator*(
    Coordinates<hip_impl::NumGroups>::Y /* gridDim.y */,
    Coordinates<hip_impl::GroupSize>::Y /* blockDim.y */) noexcept
{
    return hc_get_grid_size(1);
}
// Overload selected for `blockDim.y * gridDim.y` (operands in the opposite
// order to the gridDim-first form). Returns hc_get_grid_size(1) directly
// rather than converting both tag operands and multiplying.
__device__ inline std::uint32_t operator*(
    Coordinates<hip_impl::GroupSize>::Y /* blockDim.y */,
    Coordinates<hip_impl::NumGroups>::Y /* gridDim.y */) noexcept
{
    return hc_get_grid_size(1);
}
// Overload selected for `gridDim.z * blockDim.z`. Returns hc_get_grid_size(2)
// directly rather than converting both tag operands and multiplying.
// NOTE(review): presumably hc_get_grid_size(2) equals that product -- confirm
// against the HC runtime documentation.
__device__ inline std::uint32_t operator*(
    Coordinates<hip_impl::NumGroups>::Z /* gridDim.z */,
    Coordinates<hip_impl::GroupSize>::Z /* blockDim.z */) noexcept
{
    return hc_get_grid_size(2);
}
// Overload selected for `blockDim.z * gridDim.z` (operands in the opposite
// order to the gridDim-first form). Returns hc_get_grid_size(2) directly
// rather than converting both tag operands and multiplying.
__device__ inline std::uint32_t operator*(
    Coordinates<hip_impl::GroupSize>::Z /* blockDim.z */,
    Coordinates<hip_impl::NumGroups>::Z /* gridDim.z */) noexcept
{
    return hc_get_grid_size(2);
}
// Coordinate objects built on the hip_impl functor types. Each is an empty
// constant whose .x/.y/.z members convert to the value returned by the
// corresponding functor when called with 0, 1 or 2.

// Per-work-item and per-group indices.
static constexpr Coordinates<hip_impl::WorkitemId> threadIdx{};
static constexpr Coordinates<hip_impl::GroupId> blockIdx{};

// Group size and number of groups.
static constexpr Coordinates<hip_impl::GroupSize> blockDim{};
static constexpr Coordinates<hip_impl::NumGroups> gridDim{};
#define hipThreadIdx_x (hc_get_workitem_id(0))
#define hipThreadIdx_y (hc_get_workitem_id(1))