SWDEV-375607 - Fix hipGraphExecUpdate behavior

Addresses the below scenarios like parameters mismatch
in memcopy node, difference in the count of nodes, difference
in the dependencies of the nodes.

Change-Id: I31c6516fb27cc1007809f1b50306fdb0c2356ccc


[ROCm/clr commit: f16d336e32]
This commit is contained in:
Satyanvesh Dittakavi
2023-08-28 12:02:06 +00:00
parent 78a3dc739d
commit be8dbcf736
2 changed files with 55 additions and 12 deletions
+32
View File
@@ -2146,9 +2146,12 @@ hipError_t hipGraphExecUpdate(hipGraphExec_t hGraphExec, hipGraph_t hGraph,
reinterpret_cast<hip::GraphExec*>(hGraphExec)->GetNodes();
if (newGraphNodes.size() != oldGraphExecNodes.size()) {
*updateResult_out = hipGraphExecUpdateErrorTopologyChanged;
*hErrorNode_out = nullptr;
HIP_RETURN(hipErrorGraphExecUpdateFailure);
}
for (std::vector<hip::GraphNode*>::size_type i = 0; i != newGraphNodes.size(); i++) {
// Checks if all the node types are same before updating
if (newGraphNodes[i]->GetType() == oldGraphExecNodes[i]->GetType()) {
if (newGraphNodes[i]->GetType() != hipGraphNodeTypeHost &&
newGraphNodes[i]->GetType() != hipGraphNodeTypeEmpty) {
@@ -2159,6 +2162,35 @@ hipError_t hipGraphExecUpdate(hipGraphExec_t hGraphExec, hipGraph_t hGraph,
return hipErrorGraphExecUpdateFailure;
}
}
switch(newGraphNodes[i]->GetType()) {
case hipGraphNodeTypeMemcpy: {
// Checks if the memcpy node's parameters are same
const hip::GraphMemcpyNode* newMemcpyNode =
static_cast<hip::GraphMemcpyNode const*>(newGraphNodes[i]);
const hip::GraphMemcpyNode* oldMemcpyNode =
static_cast<hip::GraphMemcpyNode const*>(oldGraphExecNodes[i]);
hipMemcpyKind newKind, oldKind;
newKind = newMemcpyNode->GetMemcpyKind();
oldKind = oldMemcpyNode->GetMemcpyKind();
if (newKind != oldKind) {
*hErrorNode_out = reinterpret_cast<hipGraphNode_t>(newGraphNodes[i]);
*updateResult_out = hipGraphExecUpdateErrorParametersChanged;
HIP_RETURN(hipErrorGraphExecUpdateFailure);
}
}
}
// Checks if all the node's dependencies are same
const std::vector<hip::GraphNode*>& newGraphDependencies =
newGraphNodes[i]->GetDependencies();
const std::vector<hip::GraphNode*>& oldGraphDependencies =
oldGraphExecNodes[i]->GetDependencies();
if (newGraphDependencies.size() != oldGraphDependencies.size()) {
*hErrorNode_out = reinterpret_cast<hipGraphNode_t>(newGraphNodes[i]);
*updateResult_out = hipGraphExecUpdateErrorTopologyChanged;
HIP_RETURN(hipErrorGraphExecUpdateFailure);
}
hipError_t status = oldGraphExecNodes[i]->SetParams(newGraphNodes[i]);
if (status != hipSuccess) {
*hErrorNode_out = reinterpret_cast<hipGraphNode_t>(newGraphNodes[i]);
+23 -12
View File
@@ -1101,12 +1101,15 @@ class GraphKernelNode : public GraphNode {
};
class GraphMemcpyNode : public GraphNode {
protected:
hipMemcpy3DParms copyParams_;
public:
GraphMemcpyNode(const hipMemcpy3DParms* pCopyParams)
: GraphNode(hipGraphNodeTypeMemcpy, "solid", "trapezium", "MEMCPY") {
copyParams_ = *pCopyParams;
if (pCopyParams) {
copyParams_ = *pCopyParams;
}
}
~GraphMemcpyNode() {}
@@ -1118,7 +1121,7 @@ class GraphMemcpyNode : public GraphNode {
return new GraphMemcpyNode(static_cast<GraphMemcpyNode const&>(*this));
}
hipError_t CreateCommand(hip::Stream* stream) {
virtual hipError_t CreateCommand(hip::Stream* stream) {
if (IsHtoHMemcpy(copyParams_.dstPtr.ptr, copyParams_.srcPtr.ptr, copyParams_.kind)) {
return hipSuccess;
}
@@ -1133,7 +1136,7 @@ class GraphMemcpyNode : public GraphNode {
return status;
}
void EnqueueCommands(hipStream_t stream) override {
virtual void EnqueueCommands(hipStream_t stream) override {
if (isEnabled_ && IsHtoHMemcpy(copyParams_.dstPtr.ptr, copyParams_.srcPtr.ptr, copyParams_.kind)) {
ihipHtoHMemcpy(copyParams_.dstPtr.ptr, copyParams_.srcPtr.ptr,
copyParams_.extent.width * copyParams_.extent.height *
@@ -1146,6 +1149,9 @@ class GraphMemcpyNode : public GraphNode {
void GetParams(hipMemcpy3DParms* params) {
std::memcpy(params, &copyParams_, sizeof(hipMemcpy3DParms));
}
virtual hipMemcpyKind GetMemcpyKind() const { return hipMemcpyDefault; };
hipError_t SetParams(const hipMemcpy3DParms* params) {
hipError_t status = ValidateParams(params);
if (status != hipSuccess) {
@@ -1154,7 +1160,8 @@ class GraphMemcpyNode : public GraphNode {
std::memcpy(&copyParams_, params, sizeof(hipMemcpy3DParms));
return hipSuccess;
}
hipError_t SetParams(GraphNode* node) {
virtual hipError_t SetParams(GraphNode* node) {
const GraphMemcpyNode* memcpyNode = static_cast<GraphMemcpyNode const*>(node);
return SetParams(&memcpyNode->copyParams_);
}
@@ -1245,7 +1252,7 @@ class GraphMemcpyNode : public GraphNode {
}
};
class GraphMemcpyNode1D : public GraphNode {
class GraphMemcpyNode1D : public GraphMemcpyNode {
protected:
void* dst_;
const void* src_;
@@ -1255,7 +1262,7 @@ class GraphMemcpyNode1D : public GraphNode {
public:
GraphMemcpyNode1D(void* dst, const void* src, size_t count, hipMemcpyKind kind,
hipGraphNodeType type = hipGraphNodeTypeMemcpy)
: GraphNode(type, "solid", "trapezium", "MEMCPY"),
: GraphMemcpyNode(nullptr),
dst_(dst),
src_(src),
count_(count),
@@ -1282,7 +1289,7 @@ class GraphMemcpyNode1D : public GraphNode {
return status;
}
void EnqueueCommands(hipStream_t stream) {
virtual void EnqueueCommands(hipStream_t stream) {
bool isH2H = IsHtoHMemcpy(dst_, src_, kind_);
if (!isH2H) {
if (commands_.empty()) return;
@@ -1340,6 +1347,10 @@ class GraphMemcpyNode1D : public GraphNode {
}
}
hipMemcpyKind GetMemcpyKind() const {
return kind_;
}
hipError_t SetParams(void* dst, const void* src, size_t count, hipMemcpyKind kind) {
hipError_t status = ValidateParams(dst, src, count, kind);
if (status != hipSuccess) {
@@ -1352,7 +1363,7 @@ class GraphMemcpyNode1D : public GraphNode {
return hipSuccess;
}
hipError_t SetParams(GraphNode* node) {
virtual hipError_t SetParams(GraphNode* node) {
const GraphMemcpyNode1D* memcpy1DNode = static_cast<GraphMemcpyNode1D const*>(node);
return SetParams(memcpy1DNode->dst_, memcpy1DNode->src_, memcpy1DNode->count_,
memcpy1DNode->kind_);
@@ -1429,7 +1440,7 @@ class GraphMemcpyNodeFromSymbol : public GraphMemcpyNode1D {
static_cast<GraphMemcpyNodeFromSymbol const&>(*this));
}
hipError_t CreateCommand(hip::Stream* stream) {
virtual hipError_t CreateCommand(hip::Stream* stream) {
hipError_t status = GraphNode::CreateCommand(stream);
if (status != hipSuccess) {
return status;
@@ -1496,7 +1507,7 @@ class GraphMemcpyNodeFromSymbol : public GraphMemcpyNode1D {
return hipSuccess;
}
hipError_t SetParams(GraphNode* node) {
virtual hipError_t SetParams(GraphNode* node) {
const GraphMemcpyNodeFromSymbol* memcpyNode =
static_cast<GraphMemcpyNodeFromSymbol const*>(node);
return SetParams(memcpyNode->dst_, memcpyNode->symbol_, memcpyNode->count_, memcpyNode->offset_,
@@ -1520,7 +1531,7 @@ class GraphMemcpyNodeToSymbol : public GraphMemcpyNode1D {
return new GraphMemcpyNodeToSymbol(static_cast<GraphMemcpyNodeToSymbol const&>(*this));
}
hipError_t CreateCommand(hip::Stream* stream) {
virtual hipError_t CreateCommand(hip::Stream* stream) {
hipError_t status = GraphNode::CreateCommand(stream);
if (status != hipSuccess) {
return status;
@@ -1585,7 +1596,7 @@ class GraphMemcpyNodeToSymbol : public GraphMemcpyNode1D {
return hipSuccess;
}
hipError_t SetParams(GraphNode* node) {
virtual hipError_t SetParams(GraphNode* node) {
const GraphMemcpyNodeToSymbol* memcpyNode =
static_cast<GraphMemcpyNodeToSymbol const*>(node);
return SetParams(memcpyNode->src_, memcpyNode->symbol_, memcpyNode->count_, memcpyNode->offset_,