From ca3c2ac185cb2182cd6d8434740ca9c16bd220ff Mon Sep 17 00:00:00 2001 From: Jaydeep Patel Date: Mon, 22 Apr 2024 19:26:50 +0000 Subject: [PATCH] SWDEV-457316 - Some validations related to Graph Node. Free node should be added in same graph and once. Graph clone containing mem alloc/mem free node not supported. Destroy mem alloc/mem free node is not supported if already added in graph. Change-Id: I40459e66d7dd84f3b5298617990313b41458c804 --- hipamd/src/hip_graph.cpp | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp index 6975119a7f..6c063b9609 100644 --- a/hipamd/src/hip_graph.cpp +++ b/hipamd/src/hip_graph.cpp @@ -2048,6 +2048,11 @@ hipError_t hipGraphDestroyNode(hipGraphNode_t node) { if (!hip::GraphNode::isNodeValid(reinterpret_cast(n))) { HIP_RETURN(hipErrorInvalidValue); } + + if (n->GetType() == hipGraphNodeTypeMemAlloc || + n->GetType() == hipGraphNodeTypeMemFree) { + HIP_RETURN(hipErrorNotSupported); + } // First remove all the edges both incoming and outgoing from node. for (auto& edge : n->GetEdges()) { n->RemoveUpdateEdge(edge); @@ -2072,6 +2077,12 @@ hipError_t hipGraphClone(hipGraph_t* pGraphClone, hipGraph_t originalGraph) { if (!hip::Graph::isGraphValid(g)) { HIP_RETURN(hipErrorInvalidValue); } + for (auto n : g->vertices_) { + if (n->GetType() == hipGraphNodeTypeMemAlloc || + n->GetType() == hipGraphNodeTypeMemFree) { + HIP_RETURN(hipErrorNotSupported); + } + } *pGraphClone = reinterpret_cast(g->clone()); HIP_RETURN(hipSuccess); } @@ -2569,6 +2580,26 @@ hipError_t hipGraphAddMemFreeNode(hipGraphNode_t* pGraphNode, hipGraph_t graph, HIP_RETURN(hipErrorInvalidValue); } hip::GraphNode* pNode; + bool AllocNodeFound = false; + hip::Graph* g = reinterpret_cast(graph); + for (auto n : g->vertices_) { + if (n->GetType() == hipGraphNodeTypeMemAlloc) { + hipMemAllocNodeParams param = {}; + reinterpret_cast(n)->GetParams(¶m); + if (param.dptr == dev_ptr) { + AllocNodeFound = true; + } + } else if (n->GetType() == hipGraphNodeTypeMemFree) { + void* param; + reinterpret_cast(n)->GetParams(¶m); + if (param == dev_ptr) { + HIP_RETURN(hipErrorInvalidValue); + } + } + } + if (!AllocNodeFound) { + HIP_RETURN(hipErrorInvalidValue); + } auto status = ihipGraphAddMemFreeNode(&pNode, reinterpret_cast(graph),