diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp index 5489c72b25..8244bdb018 100644 --- a/hipamd/src/hip_graph.cpp +++ b/hipamd/src/hip_graph.cpp @@ -64,7 +64,8 @@ hipError_t ihipGraphAddKernelNode(hipGraphNode_t* pGraphNode, hipGraph_t graph, const hipGraphNode_t* pDependencies, size_t numDependencies, const hipKernelNodeParams* pNodeParams) { if (pGraphNode == nullptr || graph == nullptr || - (numDependencies > 0 && pDependencies == nullptr) || pNodeParams == nullptr) { + (numDependencies > 0 && pDependencies == nullptr) || pNodeParams == nullptr || + pNodeParams->func == nullptr) { return hipErrorInvalidValue; } hipError_t status = ihipValidateKernelParams(pNodeParams); @@ -629,6 +630,7 @@ hipError_t capturehipEventRecord(hipStream_t& stream, hipEvent_t& event) { hip::Event* e = reinterpret_cast(event); e->StartCapture(stream); hip::Stream* s = reinterpret_cast(stream); + s->SetCaptureEvent(event); std::vector lastCapturedNodes = s->GetLastCapturedNodes(); if (!lastCapturedNodes.empty()) { e->SetNodesPrevToRecorded(lastCapturedNodes); @@ -653,6 +655,7 @@ hipError_t capturehipStreamWaitEvent(hipEvent_t& event, hipStream_t& stream, uns s->SetCaptureGraph(reinterpret_cast(e->GetCaptureStream())->GetCaptureGraph()); s->SetCaptureMode(reinterpret_cast(e->GetCaptureStream())->GetCaptureMode()); s->SetParentStream(e->GetCaptureStream()); + s->SetParallelCaptureStream(stream); } s->AddCrossCapturedNode(e->GetNodesPrevToRecorded()); g_captureStreams.push_back(stream); @@ -709,7 +712,7 @@ hipError_t hipStreamBeginCapture(hipStream_t stream, hipStreamCaptureMode mode) hipError_t hipStreamEndCapture(hipStream_t stream, hipGraph_t* pGraph) { HIP_INIT_API(hipStreamEndCapture, stream, pGraph); - if (!hip::isValid(stream)) { + if (pGraph == nullptr || stream == nullptr || !hip::isValid(stream)) { HIP_RETURN(hipErrorInvalidValue); } hip::Stream* s = reinterpret_cast(stream); @@ -972,7 +975,7 @@ hipError_t hipGraphKernelNodeGetParams(hipGraphNode_t node, hipKernelNodeParams* hipError_t hipGraphKernelNodeSetParams(hipGraphNode_t node, const hipKernelNodeParams* pNodeParams) { HIP_INIT_API(hipGraphKernelNodeSetParams, node, pNodeParams); - if (node == nullptr || pNodeParams == nullptr) { + if (node == nullptr || pNodeParams == nullptr || pNodeParams->func == nullptr) { HIP_RETURN(hipErrorInvalidValue); } HIP_RETURN(reinterpret_cast(node)->SetParams(pNodeParams)); @@ -1047,6 +1050,17 @@ hipError_t hipGraphAddDependencies(hipGraph_t graph, const hipGraphNode_t* from, HIP_RETURN(hipErrorInvalidValue); } for (size_t i = 0; i < numDependencies; i++) { + // When the same node is specified for both from and to + if (from[i] == to[i]) { + HIP_RETURN(hipErrorInvalidValue); + } + // When the same edge added from->to return invalid value + const std::vector& edges = from[i]->GetEdges(); + for (auto edge : edges) { + if (edge == to[i]) { + HIP_RETURN(hipErrorInvalidValue); + } + } from[i]->AddEdge(to[i]); } HIP_RETURN(hipSuccess); @@ -1187,7 +1201,8 @@ hipError_t hipGraphGetEdges(hipGraph_t graph, hipGraphNode_t* from, hipGraphNode from[i] = edges[i].first; to[i] = edges[i].second; } - // If numEdges > actual number of edges, the remaining entries in from and to will be set to NULL + // If numEdges > actual number of edges, the remaining entries in from and to will be set to + // NULL for (int i = edges.size(); i < *numEdges; i++) { from[i] = nullptr; to[i] = nullptr; @@ -1244,8 +1259,8 @@ hipError_t hipGraphNodeGetDependentNodes(hipGraphNode_t node, hipGraphNode_t* pD for (int i = 0; i < dependents.size(); i++) { pDependentNodes[i] = dependents[i]; } - // pNumDependentNodes > actual number of dependents, the remaining entries in pDependentNodes will - // be set to NULL + // pNumDependentNodes > actual number of dependents, the remaining entries in pDependentNodes + // will be set to NULL for (int i = dependents.size(); i < *pNumDependentNodes; i++) { pDependentNodes[i] = nullptr; } diff --git a/hipamd/src/hip_hcc.def.in b/hipamd/src/hip_hcc.def.in index c6bda1aa79..1d2ab341de 100644 --- a/hipamd/src/hip_hcc.def.in +++ b/hipamd/src/hip_hcc.def.in @@ -348,3 +348,5 @@ hipGraphExecMemsetNodeSetParams amd_dbgapi_get_build_name amd_dbgapi_get_git_hash amd_dbgapi_get_build_id +hipStreamGetCaptureInfo +hipStreamGetCaptureInfo_v2 diff --git a/hipamd/src/hip_hcc.map.in b/hipamd/src/hip_hcc.map.in index bf43f4386d..71f56c5fbe 100644 --- a/hipamd/src/hip_hcc.map.in +++ b/hipamd/src/hip_hcc.map.in @@ -380,6 +380,8 @@ global: amd_dbgapi_get_build_name; amd_dbgapi_get_git_hash; amd_dbgapi_get_build_id; + hipStreamGetCaptureInfo; + hipStreamGetCaptureInfo_v2; local: *; } hip_4.4; diff --git a/hipamd/src/hip_internal.hpp b/hipamd/src/hip_internal.hpp index f6df46a462..63cbee6d3a 100644 --- a/hipamd/src/hip_internal.hpp +++ b/hipamd/src/hip_internal.hpp @@ -308,6 +308,8 @@ namespace hip { } /// Get Capture ID int GetCaptureID() { return captureID_; } + void SetCaptureEvent(hipEvent_t e) { captureEvents_.push_back(e); } + void SetParallelCaptureStream(hipStream_t s) { parallelCaptureStreams_.push_back(s); } }; /// HIP Device class