From 30cb2d0e67d12ec349151e8f71e38aedb2b034d3 Mon Sep 17 00:00:00 2001
From: Vladana Stojiljkovic
Date: Wed, 13 Nov 2024 14:03:23 +0200
Subject: [PATCH] SWDEV-498061 - Add capture support for hipModuleLaunchCooperativeKernel

Change-Id: I5ed188e046c680c2785b3952391f59ed1d0c21b8
---
 hipamd/src/hip_graph.cpp          | 40 +++++++++++++++++++++++++++++++
 hipamd/src/hip_graph_capture.hpp  |  6 +++++
 hipamd/src/hip_graph_internal.hpp | 14 ++++++++++-
 hipamd/src/hip_module.cpp         |  3 +++
 4 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp
index ae2da97b81..b517ddbb7f 100644
--- a/hipamd/src/hip_graph.cpp
+++ b/hipamd/src/hip_graph.cpp
@@ -348,6 +348,46 @@ hipError_t capturehipModuleLaunchKernel(hipStream_t& stream, hipFunction_t& f, u
   return hipSuccess;
 }
 
+// Stream-capture handler for hipModuleLaunchCooperativeKernel: adds a kernel node built from the
+// launch arguments to the capture graph of `stream` and makes it the stream's last captured node.
+// The node is created through ihipGraphAddKernelNode with the stream's last captured nodes and
+// amd::NDRangeKernelCommand::CooperativeGroups. Returns hipErrorContextIsDestroyed when `stream`
+// fails hip::isValid, the ihipGraphAddKernelNode status when that call fails, and hipSuccess
+// otherwise.
+hipError_t capturehipModuleLaunchCooperativeKernel(hipStream_t& stream, hipFunction_t& f,
+                                                   uint32_t& gridDimX, uint32_t& gridDimY,
+                                                   uint32_t& gridDimZ, uint32_t& blockDimX,
+                                                   uint32_t& blockDimY, uint32_t& blockDimZ,
+                                                   uint32_t& sharedMemBytes, void**& kernelParams) {
+  ClPrint(amd::LOG_INFO, amd::LOG_API,
+          "[hipGraph] Current capture node ModuleLaunchCooperativeKernel on stream : %p", stream);
+
+  if (!hip::isValid(stream)) {
+    return hipErrorContextIsDestroyed;
+  }
+
+  hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
+  hipKernelNodeParams nodeParams = {};
+  nodeParams.func = f;
+  nodeParams.blockDim = {blockDimX, blockDimY, blockDimZ};
+  nodeParams.gridDim = {gridDimX, gridDimY, gridDimZ};
+  nodeParams.kernelParams = kernelParams;
+  nodeParams.sharedMemBytes = sharedMemBytes;
+  nodeParams.extra = nullptr;
+
+  hip::GraphNode* pGraphNode = nullptr;
+  hipError_t status =
+      ihipGraphAddKernelNode(&pGraphNode, s->GetCaptureGraph(), s->GetLastCapturedNodes().data(),
+                             s->GetLastCapturedNodes().size(), &nodeParams, nullptr, true,
+                             amd::NDRangeKernelCommand::CooperativeGroups);
+  if (status != hipSuccess) {
+    return status;
+  }
+  s->SetLastCapturedNode(pGraphNode);
+
+  return hipSuccess;
+}
+
 hipError_t capturehipLaunchByPtr(hipStream_t& stream, hipFunction_t func, dim3 blockDim,
                                  dim3 gridDim, unsigned int sharedMemBytes, void** extra) {
   ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] Current capture node LaunchByPtr on stream : %p",
diff --git a/hipamd/src/hip_graph_capture.hpp b/hipamd/src/hip_graph_capture.hpp
index ac62bee750..5e65569846 100644
--- a/hipamd/src/hip_graph_capture.hpp
+++ b/hipamd/src/hip_graph_capture.hpp
@@ -42,6 +42,12 @@ hipError_t capturehipModuleLaunchKernel(hipStream_t& stream, hipFunction_t& f, u
                                         uint32_t& sharedMemBytes, void**& kernelParams,
                                         void**& extra);
 
+hipError_t capturehipModuleLaunchCooperativeKernel(hipStream_t& stream, hipFunction_t& f,
+                                                   uint32_t& gridDimX, uint32_t& gridDimY,
+                                                   uint32_t& gridDimZ, uint32_t& blockDimX,
+                                                   uint32_t& blockDimY, uint32_t& blockDimZ,
+                                                   uint32_t& sharedMemBytes, void**& kernelParams);
+
 hipError_t capturehipLaunchByPtr(hipStream_t& stream, hipFunction_t func, dim3 blockDim,
                                  dim3 gridDim, unsigned int sharedMemBytes, void** extra);
 
diff --git a/hipamd/src/hip_graph_internal.hpp b/hipamd/src/hip_graph_internal.hpp
index 32b081fc7d..9019b2aa6d 100644
--- a/hipamd/src/hip_graph_internal.hpp
+++ b/hipamd/src/hip_graph_internal.hpp
@@ -439,7 +439,6 @@ struct GraphNode : public hipGraphNodeDOTAttribute {
     bool isGraphCapture = false;
     if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) {
       switch (GetType()) {
-        case hipGraphNodeTypeKernel:
         case hipGraphNodeTypeMemset:
           isGraphCapture = true;
           break;
@@ -1325,6 +1324,19 @@ class GraphKernelNode : public GraphNode {
     }
     return hipSuccess;
   }
+
+  // Returns true only when DEBUG_CLR_GRAPH_PACKET_CAPTURE is set and coopKernel_ is not set;
+  // returns false in every other case.
+  virtual bool GraphCaptureEnabled() override {
+    bool isGraphCapture = false;
+    if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) {
+      // Disable capture for cooperative kernels
+      if (!coopKernel_) {
+        isGraphCapture = true;
+      }
+    }
+    return isGraphCapture;
+  }
 };
 
 class GraphMemcpyNode : public GraphNode {
diff --git a/hipamd/src/hip_module.cpp b/hipamd/src/hip_module.cpp
index 5527a761e6..6d219ad761 100644
--- a/hipamd/src/hip_module.cpp
+++ b/hipamd/src/hip_module.cpp
@@ -572,6 +572,9 @@ hipError_t hipModuleLaunchCooperativeKernel(hipFunction_t f, unsigned int gridDi
     HIP_RETURN(hipErrorInvalidValue);
   }
 
+  STREAM_CAPTURE(hipModuleLaunchCooperativeKernel, stream, f, gridDimX, gridDimY, gridDimZ,
+                 blockDimX, blockDimY, blockDimZ, sharedMemBytes, kernelParams);
+
   size_t globalWorkSizeX = static_cast<size_t>(gridDimX) * blockDimX;
   size_t globalWorkSizeY = static_cast<size_t>(gridDimY) * blockDimY;
   size_t globalWorkSizeZ = static_cast<size_t>(gridDimZ) * blockDimZ;