diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp index 17890e3dda..a4b64f9e1b 100644 --- a/hipamd/src/hip_graph.cpp +++ b/hipamd/src/hip_graph.cpp @@ -1102,7 +1102,7 @@ hipError_t hipGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo // Check if pNodeParams passed is a empty struct if (((pNodeParams->srcArray == 0) && (pNodeParams->srcPtr.ptr == nullptr)) || ((pNodeParams->dstArray == 0) && (pNodeParams->dstPtr.ptr == nullptr))) { - return hipErrorInvalidValue; + HIP_RETURN(hipErrorInvalidValue); } hipGraphNode_t clonedNode = hGraphExec->GetClonedNode(node); if (clonedNode == nullptr) { @@ -1202,9 +1202,23 @@ hipError_t hipGraphChildGraphNodeGetGraph(hipGraphNode_t node, hipGraph_t* pGrap hipError_t hipGraphExecChildGraphNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t node, hipGraph_t childGraph) { HIP_INIT_API(hipGraphExecChildGraphNodeSetParams, hGraphExec, node, childGraph); - if (hGraphExec == nullptr || node == nullptr || childGraph == nullptr) { + if (hGraphExec == nullptr || node == nullptr || childGraph == nullptr || + !ihipGraph::isGraphValid(childGraph)) { HIP_RETURN(hipErrorInvalidValue); } + + if (childGraph == node->GetParentGraph()) { + HIP_RETURN(hipErrorUnknown); + } + + hipGraphNode_t hipErrorNode_out; + hipGraphExecUpdateResult updateResult_out; + // Check if this instantiated graph is updatable. All restrictions in hipGraphExecUpdate() apply. + if (hipGraphExecUpdate(hGraphExec, childGraph, &hipErrorNode_out, &updateResult_out) == + hipErrorGraphExecUpdateFailure) { + HIP_RETURN(hipErrorUnknown); + } + hipGraphNode_t clonedNode = hGraphExec->GetClonedNode(node); if (clonedNode == nullptr) { HIP_RETURN(hipErrorInvalidValue);