diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp index b4723282f3..14b98b082f 100644 --- a/hipamd/src/hip_graph.cpp +++ b/hipamd/src/hip_graph.cpp @@ -1620,7 +1620,8 @@ hipError_t hipGraphAddEventRecordNode(hipGraphNode_t* pGraphNode, hipGraph_t gra hipError_t hipGraphEventRecordNodeGetEvent(hipGraphNode_t node, hipEvent_t* event_out) { HIP_INIT_API(hipGraphEventRecordNodeGetEvent, node, event_out); - if (node == nullptr || event_out == nullptr) { + if (node == nullptr || event_out == nullptr || + node->GetType() != hipGraphNodeTypeEventRecord) { HIP_RETURN(hipErrorInvalidValue); } reinterpret_cast(node)->GetParams(event_out); @@ -1629,7 +1630,8 @@ hipError_t hipGraphEventRecordNodeGetEvent(hipGraphNode_t node, hipEvent_t* even hipError_t hipGraphEventRecordNodeSetEvent(hipGraphNode_t node, hipEvent_t event) { HIP_INIT_API(hipGraphEventRecordNodeSetEvent, node, event); - if (node == nullptr || event == nullptr) { + if (node == nullptr || event == nullptr || + node->GetType() != hipGraphNodeTypeEventRecord ) { HIP_RETURN(hipErrorInvalidValue); } HIP_RETURN(reinterpret_cast(node)->SetParams(event));