From 2bb2446d8ff314d26b008cbee36a00c3ced49a1c Mon Sep 17 00:00:00 2001 From: Anusha GodavarthySurya Date: Wed, 13 Dec 2023 07:11:36 +0000 Subject: [PATCH] SWDEV-422207 - Fix graph catch tests with graph optimizations(DEBUG_CLR_GRAPH_PACKET_CAPTURE enabled) Change-Id: I16297e0ddde286bf1798c90f2bf846e69819010d --- hipamd/src/hip_graph.cpp | 16 +++++++++------- hipamd/src/hip_graph_internal.cpp | 6 ++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp index ae40a89664..bfeeceb410 100644 --- a/hipamd/src/hip_graph.cpp +++ b/hipamd/src/hip_graph.cpp @@ -1211,13 +1211,13 @@ hipError_t ihipGraphInstantiate(hip::GraphExec** pGraphExec, hip::Graph* graph, if (clonedGraph == nullptr) { return hipErrorInvalidValue; } - std::vector> parallelLists; - std::unordered_map> nodeWaitLists; - clonedGraph->GetRunList(parallelLists, nodeWaitLists); std::vector graphNodes; if (false == clonedGraph->TopologicalOrder(graphNodes)) { return hipErrorInvalidValue; } + std::vector> parallelLists; + std::unordered_map> nodeWaitLists; + clonedGraph->GetRunList(parallelLists, nodeWaitLists); *pGraphExec = new hip::GraphExec(graphNodes, parallelLists, nodeWaitLists, clonedGraph, clonedNodes, flags); @@ -1237,10 +1237,12 @@ hipError_t hipGraphInstantiate(hipGraphExec_t* pGraphExec, hipGraph_t graph, } hip::GraphExec* ge; hipError_t status = ihipGraphInstantiate(&ge, reinterpret_cast(graph)); - *pGraphExec = reinterpret_cast(ge); - if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { - // For graph nodes capture AQL packets to dispatch them directly during graph launch. - status = ge->CaptureAQLPackets(); + if (status == hipSuccess) { + *pGraphExec = reinterpret_cast(ge); + if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { + // For graph nodes capture AQL packets to dispatch them directly during graph launch. + status = ge->CaptureAQLPackets(); + } } HIP_RETURN(status); } diff --git a/hipamd/src/hip_graph_internal.cpp b/hipamd/src/hip_graph_internal.cpp index 9150b281ad..ee70f8b1b9 100644 --- a/hipamd/src/hip_graph_internal.cpp +++ b/hipamd/src/hip_graph_internal.cpp @@ -549,8 +549,10 @@ hipError_t GraphExec::Run(hipStream_t stream) { for (int i = 0; i < topoOrder_.size() - 1; i++) { if (DEBUG_CLR_GRAPH_PACKET_CAPTURE && topoOrder_[i]->GetType() == hipGraphNodeTypeKernel) { - hip_stream->vdev()->dispatchAqlPacket(topoOrder_[i]->GetAqlPacket(), accumulate); - accumulate->addKernelName(topoOrder_[i]->GetKernelName()); + if (topoOrder_[i]->GetEnabled()) { + hip_stream->vdev()->dispatchAqlPacket(topoOrder_[i]->GetAqlPacket(), accumulate); + accumulate->addKernelName(topoOrder_[i]->GetKernelName()); + } } else { topoOrder_[i]->SetStream(hip_stream, this); status = topoOrder_[i]->CreateCommand(topoOrder_[i]->GetQueue());