diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index 8db0fd0e19..edd05ea13d 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -1200,7 +1200,9 @@ hipError_t hipGraphChildGraphNodeGetGraph(hipGraphNode_t node, hipGraph_t* pGrap HIP_RETURN(hipErrorInvalidValue); } *pGraph = reinterpret_cast(node)->GetChildGraph(); - if (pGraph == nullptr) { + + //if the node count is larger than 0, the current node is a parent + if (*pGraph == nullptr || reinterpret_cast(pGraph)->GetNodeCount() > 0) { HIP_RETURN(hipErrorInvalidValue); } HIP_RETURN(hipSuccess); diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index 5b6a3e719d..eeb5c05cf5 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -224,6 +224,7 @@ struct hipGraphNode { } } ihipGraph* GetParentGraph() { return parentGraph_; } + virtual ihipGraph* GetChildGraph() { return nullptr; } void SetParentGraph(ihipGraph* graph) { parentGraph_ = graph; } virtual hipError_t SetParams(hipGraphNode* node) { return hipSuccess; } };