diff --git a/projects/hip/include/hip/hcc_detail/functional_grid_launch.hpp b/projects/hip/include/hip/hcc_detail/functional_grid_launch.hpp index de943f310d..66e5873f3a 100644 --- a/projects/hip/include/hip/hcc_detail/functional_grid_launch.hpp +++ b/projects/hip/include/hip/hcc_detail/functional_grid_launch.hpp @@ -51,30 +51,53 @@ inline T round_up_to_next_multiple_nonnegative(T x, T y) { return tmp - tmp % y; } -inline std::vector make_kernarg() { return {}; } - -inline std::vector make_kernarg(std::vector kernarg) { return kernarg; } - -template -inline std::vector make_kernarg(std::vector kernarg, T x) { - kernarg.resize(round_up_to_next_multiple_nonnegative(kernarg.size(), alignof(T)) + sizeof(T)); - - new (kernarg.data() + kernarg.size() - sizeof(T)) T{std::move(x)}; - +template < + std::size_t n, + typename... Ts, + typename std::enable_if::type* = nullptr> +inline std::vector make_kernarg( + std::vector kernarg, const std::tuple&) { return kernarg; } -template -inline std::vector make_kernarg(std::vector kernarg, T x, Ts... xs) { - return make_kernarg(make_kernarg(std::move(kernarg), std::move(x)), std::move(xs)...); +template < + std::size_t n, + typename... Ts, + typename std::enable_if::type* = nullptr> +inline std::vector make_kernarg( + std::vector kernarg, const std::tuple& formals) { + using T = typename std::tuple_element>::type; + + static_assert( + !std::is_reference{}, + "A __global__ function cannot have a reference as one of its " + "arguments."); + #if defined(HIP_STRICT) + static_assert( + std::is_trivially_copyable{}, + "Only TriviallyCopyable types can be arguments to a __global__ " + "function"); + #endif + + kernarg.resize(round_up_to_next_multiple_nonnegative( + kernarg.size(), alignof(T)) + sizeof(T)); + + new (kernarg.data() + kernarg.size() - sizeof(T)) T{std::get(formals)}; + + return make_kernarg(std::move(kernarg), formals); } -template -inline std::vector make_kernarg(Ts... xs) { - std::vector kernarg; - kernarg.reserve(sizeof(std::tuple)); +template +inline std::vector make_kernarg( + void (*)(Formals...), std::tuple actuals) { + static_assert(sizeof...(Formals) == sizeof...(Actuals), + "The count of formal arguments must match the count of actuals."); - return make_kernarg(std::move(kernarg), std::move(xs)...); + std::tuple to_formals{std::move(actuals)}; + std::vector kernarg; + kernarg.reserve(sizeof(to_formals)); + + return make_kernarg<0>(std::move(kernarg), to_formals); } void hipLaunchKernelGGLImpl(std::uintptr_t function_address, const dim3& numBlocks, @@ -85,7 +108,8 @@ void hipLaunchKernelGGLImpl(std::uintptr_t function_address, const dim3& numBloc template inline void hipLaunchKernelGGL(F kernel, const dim3& numBlocks, const dim3& dimBlocks, std::uint32_t sharedMemBytes, hipStream_t stream, Args... args) { - auto kernarg = hip_impl::make_kernarg(std::move(args)...); + auto kernarg = hip_impl::make_kernarg( + kernel, std::tuple{std::move(args)...}); std::size_t kernarg_size = kernarg.size(); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, kernarg.data(), HIP_LAUNCH_PARAM_BUFFER_SIZE, @@ -100,4 +124,4 @@ inline void hipLaunchKernel(F kernel, const dim3& numBlocks, const dim3& dimBloc std::uint32_t groupMemBytes, hipStream_t stream, Args... args) { hipLaunchKernelGGL(kernel, numBlocks, dimBlocks, groupMemBytes, stream, hipLaunchParm{}, std::move(args)...); -} +} \ No newline at end of file