From fc014587f8ccc4d83473cc1d19f4bf4e1cbc5ce7 Mon Sep 17 00:00:00 2001 From: Anusha GodavarthySurya Date: Wed, 7 Aug 2024 05:55:50 +0000 Subject: [PATCH] SWDEV-477324 - Graph Capture memcpy D2D Change-Id: Ifaa4d78854c03b3150233142df187c9bbf731cab [ROCm/clr commit: e98179d92446ce352898ac86c04accdc8cc8cd54] --- projects/clr/hipamd/src/hip_graph.cpp | 46 ++++++++++++++++--- .../clr/hipamd/src/hip_graph_internal.hpp | 32 +++++++++++-- 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index 3fe822ff1d..075544c2d4 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -1202,8 +1202,16 @@ hipError_t hipGraphExecMemcpyNodeSetParams1D(hipGraphExec_t hGraphExec, hipGraph if (oldkind != kind) { HIP_RETURN(hipErrorInvalidValue); } - HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(dst, src, - count, kind)); + hipError_t status = reinterpret_cast(clonedNode)->SetParams(dst, src, + count, kind); + if (status != hipSuccess) { + HIP_RETURN(status); + } + if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { + status = reinterpret_cast(hGraphExec) + ->UpdateAQLPacket(reinterpret_cast(clonedNode)); + } + HIP_RETURN(status); } hipError_t hipGraphAddMemsetNode(hipGraphNode_t* pGraphNode, hipGraph_t graph, @@ -1612,7 +1620,15 @@ hipError_t hipGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo if (oldkind != newkind) { HIP_RETURN(hipErrorInvalidValue); } - HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(pNodeParams)); + hipError_t status = reinterpret_cast(clonedNode)->SetParams(pNodeParams); + if (status != hipSuccess) { + HIP_RETURN(status); + } + if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { + status = reinterpret_cast(hGraphExec) + ->UpdateAQLPacket(reinterpret_cast(clonedNode)); + } + HIP_RETURN(status); } hipError_t hipGraphMemsetNodeGetParams(hipGraphNode_t node, hipMemsetParams* pNodeParams) { @@ -2181,8 +2197,16 @@ hipError_t hipGraphExecMemcpyNodeSetParamsFromSymbol(hipGraphExec_t hGraphExec, HIP_RETURN(hipErrorInvalidValue); } constexpr bool kCheckDeviceIsSame = true; - HIP_RETURN(reinterpret_cast(clonedNode) - ->SetParams(dst, symbol, count, offset, kind, kCheckDeviceIsSame)); + hipError_t status = reinterpret_cast(clonedNode) + ->SetParams(dst, symbol, count, offset, kind, kCheckDeviceIsSame); + if (status != hipSuccess) { + HIP_RETURN(status); + } + if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { + status = reinterpret_cast(hGraphExec) + ->UpdateAQLPacket(reinterpret_cast(clonedNode)); + } + HIP_RETURN(status); } hipError_t hipGraphAddMemcpyNodeToSymbol(hipGraphNode_t* pGraphNode, hipGraph_t graph, @@ -2255,8 +2279,16 @@ hipError_t hipGraphExecMemcpyNodeSetParamsToSymbol(hipGraphExec_t hGraphExec, hi HIP_RETURN(hipErrorInvalidValue); } constexpr bool kCheckDeviceIsSame = true; - HIP_RETURN(reinterpret_cast(clonedNode) - ->SetParams(symbol, src, count, offset, kind, kCheckDeviceIsSame)); + hipError_t status = reinterpret_cast(clonedNode) + ->SetParams(symbol, src, count, offset, kind, kCheckDeviceIsSame); + if (status != hipSuccess) { + HIP_RETURN(status); + } + if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { + status = reinterpret_cast(hGraphExec) + ->UpdateAQLPacket(reinterpret_cast(clonedNode)); + } + HIP_RETURN(status); } hipError_t hipGraphAddEventRecordNode(hipGraphNode_t* pGraphNode, hipGraph_t graph, diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index c491088def..87af33000c 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -212,7 +212,7 @@ struct GraphNode : public hipGraphNodeDOTAttribute { std::vector gpuPackets_; //!< GPU Packet to enqueue during graph launch std::string capturedKernelName_; size_t alignedKernArgSize_ = 256; //!< Aligned size required for kernel args - size_t kernargSegmentByteSize_ = 256; //!< Kernel arg segment byte size + size_t kernargSegmentByteSize_ = 512; //!< Kernel arg segment byte size size_t kernargSegmentAlignment_ = 256; //!< Kernel arg segment alignment public: @@ -434,13 +434,11 @@ struct GraphNode : public hipGraphNodeDOTAttribute { unsigned int GetEnabled() const { return isEnabled_; } void SetEnabled(unsigned int isEnabled) { isEnabled_ = isEnabled; } // Returns true if capture is enabled for the current node. - bool GraphCaptureEnabled() { + virtual bool GraphCaptureEnabled() { bool isGraphCapture = false; if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { switch (GetType()) { case hipGraphNodeTypeKernel: - isGraphCapture = true; - break; case hipGraphNodeTypeMemset: isGraphCapture = true; break; @@ -1525,6 +1523,19 @@ class GraphMemcpyNode : public GraphNode { return shape_; } } + virtual bool GraphCaptureEnabled() override { + bool isGraphCapture = false; + if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { + switch (copyParams_.kind) { + case hipMemcpyDeviceToDevice: + isGraphCapture = true; + break; + default: + break; + } + } + return isGraphCapture; + } }; class GraphMemcpyNode1D : public GraphMemcpyNode { @@ -1704,6 +1715,19 @@ class GraphMemcpyNode1D : public GraphMemcpyNode { return shape_; } } + virtual bool GraphCaptureEnabled() override { + bool isGraphCapture = false; + if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { + switch (kind_) { + case hipMemcpyDeviceToDevice: + isGraphCapture = true; + break; + default: + break; + } + } + return isGraphCapture; + } }; class GraphMemcpyNodeFromSymbol : public GraphMemcpyNode1D {