Merge pull request #669 from ROCm-Developer-Tools/feature_automatic_cast
Remove potential for mismatch between runtime passed actuals and defined formals
[ROCm/hip commit: aed5ad31ba]
Этот коммит содержится в:
@@ -51,30 +51,53 @@ inline T round_up_to_next_multiple_nonnegative(T x, T y) {
|
||||
return tmp - tmp % y;
|
||||
}
|
||||
|
||||
inline std::vector<std::uint8_t> make_kernarg() { return {}; }
|
||||
|
||||
inline std::vector<std::uint8_t> make_kernarg(std::vector<std::uint8_t> kernarg) { return kernarg; }
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<std::uint8_t> make_kernarg(std::vector<uint8_t> kernarg, T x) {
|
||||
kernarg.resize(round_up_to_next_multiple_nonnegative(kernarg.size(), alignof(T)) + sizeof(T));
|
||||
|
||||
new (kernarg.data() + kernarg.size() - sizeof(T)) T{std::move(x)};
|
||||
|
||||
template <
|
||||
std::size_t n,
|
||||
typename... Ts,
|
||||
typename std::enable_if<n == sizeof...(Ts)>::type* = nullptr>
|
||||
inline std::vector<std::uint8_t> make_kernarg(
|
||||
std::vector<std::uint8_t> kernarg, const std::tuple<Ts...>&) {
|
||||
return kernarg;
|
||||
}
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
inline std::vector<std::uint8_t> make_kernarg(std::vector<std::uint8_t> kernarg, T x, Ts... xs) {
|
||||
return make_kernarg(make_kernarg(std::move(kernarg), std::move(x)), std::move(xs)...);
|
||||
template <
|
||||
std::size_t n,
|
||||
typename... Ts,
|
||||
typename std::enable_if<n != sizeof...(Ts)>::type* = nullptr>
|
||||
inline std::vector<std::uint8_t> make_kernarg(
|
||||
std::vector<std::uint8_t> kernarg, const std::tuple<Ts...>& formals) {
|
||||
using T = typename std::tuple_element<n, std::tuple<Ts...>>::type;
|
||||
|
||||
static_assert(
|
||||
!std::is_reference<T>{},
|
||||
"A __global__ function cannot have a reference as one of its "
|
||||
"arguments.");
|
||||
#if defined(HIP_STRICT)
|
||||
static_assert(
|
||||
std::is_trivially_copyable<T>{},
|
||||
"Only TriviallyCopyable types can be arguments to a __global__ "
|
||||
"function");
|
||||
#endif
|
||||
|
||||
kernarg.resize(round_up_to_next_multiple_nonnegative(
|
||||
kernarg.size(), alignof(T)) + sizeof(T));
|
||||
|
||||
new (kernarg.data() + kernarg.size() - sizeof(T)) T{std::get<n>(formals)};
|
||||
|
||||
return make_kernarg<n + 1>(std::move(kernarg), formals);
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
inline std::vector<std::uint8_t> make_kernarg(Ts... xs) {
|
||||
std::vector<std::uint8_t> kernarg;
|
||||
kernarg.reserve(sizeof(std::tuple<Ts...>));
|
||||
template <typename... Formals, typename... Actuals>
|
||||
inline std::vector<std::uint8_t> make_kernarg(
|
||||
void (*)(Formals...), std::tuple<Actuals...> actuals) {
|
||||
static_assert(sizeof...(Formals) == sizeof...(Actuals),
|
||||
"The count of formal arguments must match the count of actuals.");
|
||||
|
||||
return make_kernarg(std::move(kernarg), std::move(xs)...);
|
||||
std::tuple<Formals...> to_formals{std::move(actuals)};
|
||||
std::vector<std::uint8_t> kernarg;
|
||||
kernarg.reserve(sizeof(to_formals));
|
||||
|
||||
return make_kernarg<0>(std::move(kernarg), to_formals);
|
||||
}
|
||||
|
||||
void hipLaunchKernelGGLImpl(std::uintptr_t function_address, const dim3& numBlocks,
|
||||
@@ -85,7 +108,8 @@ void hipLaunchKernelGGLImpl(std::uintptr_t function_address, const dim3& numBloc
|
||||
template <typename... Args, typename F = void (*)(Args...)>
|
||||
inline void hipLaunchKernelGGL(F kernel, const dim3& numBlocks, const dim3& dimBlocks,
|
||||
std::uint32_t sharedMemBytes, hipStream_t stream, Args... args) {
|
||||
auto kernarg = hip_impl::make_kernarg(std::move(args)...);
|
||||
auto kernarg = hip_impl::make_kernarg(
|
||||
kernel, std::tuple<Args...>{std::move(args)...});
|
||||
std::size_t kernarg_size = kernarg.size();
|
||||
|
||||
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, kernarg.data(), HIP_LAUNCH_PARAM_BUFFER_SIZE,
|
||||
@@ -100,4 +124,4 @@ inline void hipLaunchKernel(F kernel, const dim3& numBlocks, const dim3& dimBloc
|
||||
std::uint32_t groupMemBytes, hipStream_t stream, Args... args) {
|
||||
hipLaunchKernelGGL(kernel, numBlocks, dimBlocks, groupMemBytes, stream, hipLaunchParm{},
|
||||
std::move(args)...);
|
||||
}
|
||||
}
|
||||
Ссылка в новой задаче
Block a user