diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index 0b578d373c..55f35053c6 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -2146,9 +2146,12 @@ hipError_t hipGraphExecUpdate(hipGraphExec_t hGraphExec, hipGraph_t hGraph, reinterpret_cast(hGraphExec)->GetNodes(); if (newGraphNodes.size() != oldGraphExecNodes.size()) { *updateResult_out = hipGraphExecUpdateErrorTopologyChanged; + *hErrorNode_out = nullptr; HIP_RETURN(hipErrorGraphExecUpdateFailure); } + for (std::vector::size_type i = 0; i != newGraphNodes.size(); i++) { + // Checks if all the node types are same before updating if (newGraphNodes[i]->GetType() == oldGraphExecNodes[i]->GetType()) { if (newGraphNodes[i]->GetType() != hipGraphNodeTypeHost && newGraphNodes[i]->GetType() != hipGraphNodeTypeEmpty) { @@ -2159,6 +2162,35 @@ hipError_t hipGraphExecUpdate(hipGraphExec_t hGraphExec, hipGraph_t hGraph, return hipErrorGraphExecUpdateFailure; } } + + switch(newGraphNodes[i]->GetType()) { + case hipGraphNodeTypeMemcpy: { + // Checks if the memcpy node's parameters are same + const hip::GraphMemcpyNode* newMemcpyNode = + static_cast(newGraphNodes[i]); + const hip::GraphMemcpyNode* oldMemcpyNode = + static_cast(oldGraphExecNodes[i]); + hipMemcpyKind newKind, oldKind; + newKind = newMemcpyNode->GetMemcpyKind(); + oldKind = oldMemcpyNode->GetMemcpyKind(); + if (newKind != oldKind) { + *hErrorNode_out = reinterpret_cast(newGraphNodes[i]); + *updateResult_out = hipGraphExecUpdateErrorParametersChanged; + HIP_RETURN(hipErrorGraphExecUpdateFailure); + } + } + } + // Checks if all the node's dependencies are same + const std::vector& newGraphDependencies = + newGraphNodes[i]->GetDependencies(); + const std::vector& oldGraphDependencies = + oldGraphExecNodes[i]->GetDependencies(); + if (newGraphDependencies.size() != oldGraphDependencies.size()) { + *hErrorNode_out = reinterpret_cast(newGraphNodes[i]); + *updateResult_out = hipGraphExecUpdateErrorTopologyChanged; + HIP_RETURN(hipErrorGraphExecUpdateFailure); + } + hipError_t status = oldGraphExecNodes[i]->SetParams(newGraphNodes[i]); if (status != hipSuccess) { *hErrorNode_out = reinterpret_cast(newGraphNodes[i]); diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index 0c079f058a..005e1194be 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -1101,12 +1101,15 @@ class GraphKernelNode : public GraphNode { }; class GraphMemcpyNode : public GraphNode { + protected: hipMemcpy3DParms copyParams_; public: GraphMemcpyNode(const hipMemcpy3DParms* pCopyParams) : GraphNode(hipGraphNodeTypeMemcpy, "solid", "trapezium", "MEMCPY") { - copyParams_ = *pCopyParams; + if (pCopyParams) { + copyParams_ = *pCopyParams; + } } ~GraphMemcpyNode() {} @@ -1118,7 +1121,7 @@ class GraphMemcpyNode : public GraphNode { return new GraphMemcpyNode(static_cast(*this)); } - hipError_t CreateCommand(hip::Stream* stream) { + virtual hipError_t CreateCommand(hip::Stream* stream) { if (IsHtoHMemcpy(copyParams_.dstPtr.ptr, copyParams_.srcPtr.ptr, copyParams_.kind)) { return hipSuccess; } @@ -1133,7 +1136,7 @@ class GraphMemcpyNode : public GraphNode { return status; } - void EnqueueCommands(hipStream_t stream) override { + virtual void EnqueueCommands(hipStream_t stream) override { if (isEnabled_ && IsHtoHMemcpy(copyParams_.dstPtr.ptr, copyParams_.srcPtr.ptr, copyParams_.kind)) { ihipHtoHMemcpy(copyParams_.dstPtr.ptr, copyParams_.srcPtr.ptr, copyParams_.extent.width * copyParams_.extent.height * @@ -1146,6 +1149,9 @@ class GraphMemcpyNode : public GraphNode { void GetParams(hipMemcpy3DParms* params) { std::memcpy(params, ©Params_, sizeof(hipMemcpy3DParms)); } + + virtual hipMemcpyKind GetMemcpyKind() const { return hipMemcpyDefault; }; + hipError_t SetParams(const hipMemcpy3DParms* params) { hipError_t status = ValidateParams(params); if (status != hipSuccess) { @@ -1154,7 +1160,8 @@ class GraphMemcpyNode : public GraphNode { std::memcpy(©Params_, params, sizeof(hipMemcpy3DParms)); return hipSuccess; } - hipError_t SetParams(GraphNode* node) { + + virtual hipError_t SetParams(GraphNode* node) { const GraphMemcpyNode* memcpyNode = static_cast(node); return SetParams(&memcpyNode->copyParams_); } @@ -1245,7 +1252,7 @@ class GraphMemcpyNode : public GraphNode { } }; -class GraphMemcpyNode1D : public GraphNode { +class GraphMemcpyNode1D : public GraphMemcpyNode { protected: void* dst_; const void* src_; @@ -1255,7 +1262,7 @@ class GraphMemcpyNode1D : public GraphNode { public: GraphMemcpyNode1D(void* dst, const void* src, size_t count, hipMemcpyKind kind, hipGraphNodeType type = hipGraphNodeTypeMemcpy) - : GraphNode(type, "solid", "trapezium", "MEMCPY"), + : GraphMemcpyNode(nullptr), dst_(dst), src_(src), count_(count), @@ -1282,7 +1289,7 @@ class GraphMemcpyNode1D : public GraphNode { return status; } - void EnqueueCommands(hipStream_t stream) { + virtual void EnqueueCommands(hipStream_t stream) { bool isH2H = IsHtoHMemcpy(dst_, src_, kind_); if (!isH2H) { if (commands_.empty()) return; @@ -1340,6 +1347,10 @@ class GraphMemcpyNode1D : public GraphNode { } } + hipMemcpyKind GetMemcpyKind() const { + return kind_; + } + hipError_t SetParams(void* dst, const void* src, size_t count, hipMemcpyKind kind) { hipError_t status = ValidateParams(dst, src, count, kind); if (status != hipSuccess) { @@ -1352,7 +1363,7 @@ class GraphMemcpyNode1D : public GraphNode { return hipSuccess; } - hipError_t SetParams(GraphNode* node) { + virtual hipError_t SetParams(GraphNode* node) { const GraphMemcpyNode1D* memcpy1DNode = static_cast(node); return SetParams(memcpy1DNode->dst_, memcpy1DNode->src_, memcpy1DNode->count_, memcpy1DNode->kind_); @@ -1429,7 +1440,7 @@ class GraphMemcpyNodeFromSymbol : public GraphMemcpyNode1D { static_cast(*this)); } - hipError_t CreateCommand(hip::Stream* stream) { + virtual hipError_t CreateCommand(hip::Stream* stream) { hipError_t status = GraphNode::CreateCommand(stream); if (status != hipSuccess) { return status; @@ -1496,7 +1507,7 @@ class GraphMemcpyNodeFromSymbol : public GraphMemcpyNode1D { return hipSuccess; } - hipError_t SetParams(GraphNode* node) { + virtual hipError_t SetParams(GraphNode* node) { const GraphMemcpyNodeFromSymbol* memcpyNode = static_cast(node); return SetParams(memcpyNode->dst_, memcpyNode->symbol_, memcpyNode->count_, memcpyNode->offset_, @@ -1520,7 +1531,7 @@ class GraphMemcpyNodeToSymbol : public GraphMemcpyNode1D { return new GraphMemcpyNodeToSymbol(static_cast(*this)); } - hipError_t CreateCommand(hip::Stream* stream) { + virtual hipError_t CreateCommand(hip::Stream* stream) { hipError_t status = GraphNode::CreateCommand(stream); if (status != hipSuccess) { return status; @@ -1585,7 +1596,7 @@ class GraphMemcpyNodeToSymbol : public GraphMemcpyNode1D { return hipSuccess; } - hipError_t SetParams(GraphNode* node) { + virtual hipError_t SetParams(GraphNode* node) { const GraphMemcpyNodeToSymbol* memcpyNode = static_cast(node); return SetParams(memcpyNode->src_, memcpyNode->symbol_, memcpyNode->count_, memcpyNode->offset_,