From e0c11624e572e39b3afe7e8d73b3cda8a39be944 Mon Sep 17 00:00:00 2001 From: Rahul Manocha Date: Thu, 21 Nov 2024 15:50:47 -0800 Subject: [PATCH] SWDEV-497288 - Enable hipGraphExecSetParams for Ext SemWait and SemSignal Nodes Change-Id: I7184a3a04ac17d3d841222ae1559db66d73a429c --- hipamd/src/hip_graph.cpp | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp index bc0c5a9908..76b9a10a7d 100644 --- a/hipamd/src/hip_graph.cpp +++ b/hipamd/src/hip_graph.cpp @@ -1231,7 +1231,7 @@ hipError_t hipGraphExecMemcpyNodeSetParams1D(hipGraphExec_t hGraphExec, hipGraph HIP_INIT_API(hipGraphExecMemcpyNodeSetParams1D, hGraphExec, node, dst, src, count, kind); hip::GraphNode* n = reinterpret_cast(node); if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || dst == nullptr || - src == nullptr || count == 0 || src == dst) { + src == nullptr || count == 0 || src == dst || n->GetType() != hipGraphNodeTypeMemcpy) { HIP_RETURN(hipErrorInvalidValue); } hip::GraphNode* clonedNode = reinterpret_cast( @@ -1636,7 +1636,8 @@ hipError_t hipGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo HIP_INIT_API(hipGraphExecMemcpyNodeSetParams, hGraphExec, node, pNodeParams); hip::GraphNode* n = reinterpret_cast(node); if (hGraphExec == nullptr || - !hip::GraphNode::isNodeValid(reinterpret_cast(n))) { + !hip::GraphNode::isNodeValid(reinterpret_cast(n)) || + n->GetType() != hipGraphNodeTypeMemcpy) { HIP_RETURN(hipErrorInvalidValue); } if (ihipMemcpy3D_validate(pNodeParams) != hipSuccess) { @@ -1697,7 +1698,7 @@ hipError_t hipGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo hip::GraphNode* n = reinterpret_cast(node); if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || pNodeParams == nullptr || - pNodeParams->dst == nullptr) { + pNodeParams->dst == nullptr || n->GetType() != hipGraphNodeTypeMemset) { HIP_RETURN(hipErrorInvalidValue); } if (ihipGraphMemsetParams_validate(pNodeParams) != hipSuccess) { @@ -1762,7 +1763,7 @@ hipError_t hipGraphExecKernelNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo hip::GraphNode* n = reinterpret_cast(node); if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || - pNodeParams == nullptr || pNodeParams->func == nullptr) { + pNodeParams == nullptr || pNodeParams->func == nullptr || n->GetType() != hipGraphNodeTypeKernel) { HIP_RETURN(hipErrorInvalidValue); } hip::GraphNode* clonedNode = reinterpret_cast(hGraphExec)->GetClonedNode(n); @@ -1799,7 +1800,7 @@ hipError_t hipGraphExecChildGraphNodeSetParams(hipGraphExec_t hGraphExec, hipGra hip::GraphNode* n = reinterpret_cast(node); hip::Graph* cg = reinterpret_cast(childGraph); if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || childGraph == nullptr || - !hip::Graph::isGraphValid(cg)) { + !hip::Graph::isGraphValid(cg) || n->GetType() != hipGraphNodeTypeGraph) { HIP_RETURN(hipErrorInvalidValue); } @@ -2428,7 +2429,7 @@ hipError_t hipGraphExecEventWaitNodeSetEvent(hipGraphExec_t hGraphExec, hipGraph hip::GraphNode* n = reinterpret_cast(hNode); if (hGraphExec == nullptr || hNode == nullptr || event == nullptr || - (n->GetType() != hipGraphNodeTypeWaitEvent)) { + (n->GetType() != hipGraphNodeTypeWaitEvent) || n->GetType() != hipGraphNodeTypeWaitEvent) { HIP_RETURN(hipErrorInvalidValue); } hip::GraphNode* clonedNode = reinterpret_cast(hGraphExec)->GetClonedNode(n); @@ -2479,7 +2480,7 @@ hipError_t hipGraphExecHostNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode HIP_INIT_API(hipGraphExecHostNodeSetParams, hGraphExec, node, pNodeParams); hip::GraphNode* n = reinterpret_cast(node); if (hGraphExec == nullptr || pNodeParams == nullptr || pNodeParams->fn == nullptr || - !hip::GraphNode::isNodeValid(n)) { + !hip::GraphNode::isNodeValid(n) || n->GetType() != hipGraphNodeTypeHost) { HIP_RETURN(hipErrorInvalidValue); } hip::GraphNode* clonedNode = reinterpret_cast(hGraphExec)->GetClonedNode(n); @@ -3165,7 +3166,8 @@ hipError_t hipGraphExecExternalSemaphoresSignalNodeSetParams(hipGraphExec_t hGra hip::GraphNode* n = reinterpret_cast(hNode); hip::GraphExec* graphExec = reinterpret_cast(hGraphExec); if (hGraphExec == nullptr || hNode == nullptr || !hip::GraphExec::isGraphExecValid(graphExec) || - !hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) { + !hip::GraphNode::isNodeValid(n) || nodeParams == nullptr || + n->GetType() != hipGraphNodeTypeExtSemaphoreSignal) { HIP_RETURN(hipErrorInvalidValue); } hip::GraphNode* clonedNode = graphExec->GetClonedNode(n); @@ -3183,7 +3185,8 @@ hipError_t hipGraphExecExternalSemaphoresWaitNodeSetParams(hipGraphExec_t hGraph hip::GraphNode* n = reinterpret_cast(hNode); hip::GraphExec* graphExec = reinterpret_cast(hGraphExec); if (hGraphExec == nullptr || hNode == nullptr || !hip::GraphExec::isGraphExecValid(graphExec) || - !hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) { + !hip::GraphNode::isNodeValid(n) || nodeParams == nullptr || + n->GetType() != hipGraphNodeTypeExtSemaphoreWait) { HIP_RETURN(hipErrorInvalidValue); } hip::GraphNode* clonedNode = graphExec->GetClonedNode(n); @@ -3326,12 +3329,12 @@ hipError_t ihipGraphNodeSetParams(hip::GraphNode* n, hipGraphNodeParams *nodePar nodeParams->eventRecord.event); break; case hipGraphNodeTypeExtSemaphoreSignal: - status = hipErrorNotSupported; - // to be added. + status = reinterpret_cast(n)->SetParams( + &nodeParams->extSemSignal); break; case hipGraphNodeTypeExtSemaphoreWait: - status = hipErrorNotSupported; - // to be added. + status = reinterpret_cast(n)->SetParams( + &nodeParams->extSemWait); break; case hipGraphNodeTypeMemAlloc: status = hipErrorNotSupported;