SWDEV-375607 - Fix hipGraphExecUpdate behavior
Addresses the below scenarios like parameters mismatch
in memcopy node, difference in the count of nodes, difference
in the dependencies of the nodes.
Change-Id: I31c6516fb27cc1007809f1b50306fdb0c2356ccc
[ROCm/clr commit: f16d336e32]
Этот коммит содержится в:
@@ -2146,9 +2146,12 @@ hipError_t hipGraphExecUpdate(hipGraphExec_t hGraphExec, hipGraph_t hGraph,
|
||||
reinterpret_cast<hip::GraphExec*>(hGraphExec)->GetNodes();
|
||||
if (newGraphNodes.size() != oldGraphExecNodes.size()) {
|
||||
*updateResult_out = hipGraphExecUpdateErrorTopologyChanged;
|
||||
*hErrorNode_out = nullptr;
|
||||
HIP_RETURN(hipErrorGraphExecUpdateFailure);
|
||||
}
|
||||
|
||||
for (std::vector<hip::GraphNode*>::size_type i = 0; i != newGraphNodes.size(); i++) {
|
||||
// Checks if all the node types are same before updating
|
||||
if (newGraphNodes[i]->GetType() == oldGraphExecNodes[i]->GetType()) {
|
||||
if (newGraphNodes[i]->GetType() != hipGraphNodeTypeHost &&
|
||||
newGraphNodes[i]->GetType() != hipGraphNodeTypeEmpty) {
|
||||
@@ -2159,6 +2162,35 @@ hipError_t hipGraphExecUpdate(hipGraphExec_t hGraphExec, hipGraph_t hGraph,
|
||||
return hipErrorGraphExecUpdateFailure;
|
||||
}
|
||||
}
|
||||
|
||||
switch(newGraphNodes[i]->GetType()) {
|
||||
case hipGraphNodeTypeMemcpy: {
|
||||
// Checks if the memcpy node's parameters are same
|
||||
const hip::GraphMemcpyNode* newMemcpyNode =
|
||||
static_cast<hip::GraphMemcpyNode const*>(newGraphNodes[i]);
|
||||
const hip::GraphMemcpyNode* oldMemcpyNode =
|
||||
static_cast<hip::GraphMemcpyNode const*>(oldGraphExecNodes[i]);
|
||||
hipMemcpyKind newKind, oldKind;
|
||||
newKind = newMemcpyNode->GetMemcpyKind();
|
||||
oldKind = oldMemcpyNode->GetMemcpyKind();
|
||||
if (newKind != oldKind) {
|
||||
*hErrorNode_out = reinterpret_cast<hipGraphNode_t>(newGraphNodes[i]);
|
||||
*updateResult_out = hipGraphExecUpdateErrorParametersChanged;
|
||||
HIP_RETURN(hipErrorGraphExecUpdateFailure);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Checks if all the node's dependencies are same
|
||||
const std::vector<hip::GraphNode*>& newGraphDependencies =
|
||||
newGraphNodes[i]->GetDependencies();
|
||||
const std::vector<hip::GraphNode*>& oldGraphDependencies =
|
||||
oldGraphExecNodes[i]->GetDependencies();
|
||||
if (newGraphDependencies.size() != oldGraphDependencies.size()) {
|
||||
*hErrorNode_out = reinterpret_cast<hipGraphNode_t>(newGraphNodes[i]);
|
||||
*updateResult_out = hipGraphExecUpdateErrorTopologyChanged;
|
||||
HIP_RETURN(hipErrorGraphExecUpdateFailure);
|
||||
}
|
||||
|
||||
hipError_t status = oldGraphExecNodes[i]->SetParams(newGraphNodes[i]);
|
||||
if (status != hipSuccess) {
|
||||
*hErrorNode_out = reinterpret_cast<hipGraphNode_t>(newGraphNodes[i]);
|
||||
|
||||
@@ -1101,12 +1101,15 @@ class GraphKernelNode : public GraphNode {
|
||||
};
|
||||
|
||||
class GraphMemcpyNode : public GraphNode {
|
||||
protected:
|
||||
hipMemcpy3DParms copyParams_;
|
||||
|
||||
public:
|
||||
GraphMemcpyNode(const hipMemcpy3DParms* pCopyParams)
|
||||
: GraphNode(hipGraphNodeTypeMemcpy, "solid", "trapezium", "MEMCPY") {
|
||||
copyParams_ = *pCopyParams;
|
||||
if (pCopyParams) {
|
||||
copyParams_ = *pCopyParams;
|
||||
}
|
||||
}
|
||||
~GraphMemcpyNode() {}
|
||||
|
||||
@@ -1118,7 +1121,7 @@ class GraphMemcpyNode : public GraphNode {
|
||||
return new GraphMemcpyNode(static_cast<GraphMemcpyNode const&>(*this));
|
||||
}
|
||||
|
||||
hipError_t CreateCommand(hip::Stream* stream) {
|
||||
virtual hipError_t CreateCommand(hip::Stream* stream) {
|
||||
if (IsHtoHMemcpy(copyParams_.dstPtr.ptr, copyParams_.srcPtr.ptr, copyParams_.kind)) {
|
||||
return hipSuccess;
|
||||
}
|
||||
@@ -1133,7 +1136,7 @@ class GraphMemcpyNode : public GraphNode {
|
||||
return status;
|
||||
}
|
||||
|
||||
void EnqueueCommands(hipStream_t stream) override {
|
||||
virtual void EnqueueCommands(hipStream_t stream) override {
|
||||
if (isEnabled_ && IsHtoHMemcpy(copyParams_.dstPtr.ptr, copyParams_.srcPtr.ptr, copyParams_.kind)) {
|
||||
ihipHtoHMemcpy(copyParams_.dstPtr.ptr, copyParams_.srcPtr.ptr,
|
||||
copyParams_.extent.width * copyParams_.extent.height *
|
||||
@@ -1146,6 +1149,9 @@ class GraphMemcpyNode : public GraphNode {
|
||||
void GetParams(hipMemcpy3DParms* params) {
|
||||
std::memcpy(params, ©Params_, sizeof(hipMemcpy3DParms));
|
||||
}
|
||||
|
||||
virtual hipMemcpyKind GetMemcpyKind() const { return hipMemcpyDefault; };
|
||||
|
||||
hipError_t SetParams(const hipMemcpy3DParms* params) {
|
||||
hipError_t status = ValidateParams(params);
|
||||
if (status != hipSuccess) {
|
||||
@@ -1154,7 +1160,8 @@ class GraphMemcpyNode : public GraphNode {
|
||||
std::memcpy(©Params_, params, sizeof(hipMemcpy3DParms));
|
||||
return hipSuccess;
|
||||
}
|
||||
hipError_t SetParams(GraphNode* node) {
|
||||
|
||||
virtual hipError_t SetParams(GraphNode* node) {
|
||||
const GraphMemcpyNode* memcpyNode = static_cast<GraphMemcpyNode const*>(node);
|
||||
return SetParams(&memcpyNode->copyParams_);
|
||||
}
|
||||
@@ -1245,7 +1252,7 @@ class GraphMemcpyNode : public GraphNode {
|
||||
}
|
||||
};
|
||||
|
||||
class GraphMemcpyNode1D : public GraphNode {
|
||||
class GraphMemcpyNode1D : public GraphMemcpyNode {
|
||||
protected:
|
||||
void* dst_;
|
||||
const void* src_;
|
||||
@@ -1255,7 +1262,7 @@ class GraphMemcpyNode1D : public GraphNode {
|
||||
public:
|
||||
GraphMemcpyNode1D(void* dst, const void* src, size_t count, hipMemcpyKind kind,
|
||||
hipGraphNodeType type = hipGraphNodeTypeMemcpy)
|
||||
: GraphNode(type, "solid", "trapezium", "MEMCPY"),
|
||||
: GraphMemcpyNode(nullptr),
|
||||
dst_(dst),
|
||||
src_(src),
|
||||
count_(count),
|
||||
@@ -1282,7 +1289,7 @@ class GraphMemcpyNode1D : public GraphNode {
|
||||
return status;
|
||||
}
|
||||
|
||||
void EnqueueCommands(hipStream_t stream) {
|
||||
virtual void EnqueueCommands(hipStream_t stream) {
|
||||
bool isH2H = IsHtoHMemcpy(dst_, src_, kind_);
|
||||
if (!isH2H) {
|
||||
if (commands_.empty()) return;
|
||||
@@ -1340,6 +1347,10 @@ class GraphMemcpyNode1D : public GraphNode {
|
||||
}
|
||||
}
|
||||
|
||||
hipMemcpyKind GetMemcpyKind() const {
|
||||
return kind_;
|
||||
}
|
||||
|
||||
hipError_t SetParams(void* dst, const void* src, size_t count, hipMemcpyKind kind) {
|
||||
hipError_t status = ValidateParams(dst, src, count, kind);
|
||||
if (status != hipSuccess) {
|
||||
@@ -1352,7 +1363,7 @@ class GraphMemcpyNode1D : public GraphNode {
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
hipError_t SetParams(GraphNode* node) {
|
||||
virtual hipError_t SetParams(GraphNode* node) {
|
||||
const GraphMemcpyNode1D* memcpy1DNode = static_cast<GraphMemcpyNode1D const*>(node);
|
||||
return SetParams(memcpy1DNode->dst_, memcpy1DNode->src_, memcpy1DNode->count_,
|
||||
memcpy1DNode->kind_);
|
||||
@@ -1429,7 +1440,7 @@ class GraphMemcpyNodeFromSymbol : public GraphMemcpyNode1D {
|
||||
static_cast<GraphMemcpyNodeFromSymbol const&>(*this));
|
||||
}
|
||||
|
||||
hipError_t CreateCommand(hip::Stream* stream) {
|
||||
virtual hipError_t CreateCommand(hip::Stream* stream) {
|
||||
hipError_t status = GraphNode::CreateCommand(stream);
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
@@ -1496,7 +1507,7 @@ class GraphMemcpyNodeFromSymbol : public GraphMemcpyNode1D {
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
hipError_t SetParams(GraphNode* node) {
|
||||
virtual hipError_t SetParams(GraphNode* node) {
|
||||
const GraphMemcpyNodeFromSymbol* memcpyNode =
|
||||
static_cast<GraphMemcpyNodeFromSymbol const*>(node);
|
||||
return SetParams(memcpyNode->dst_, memcpyNode->symbol_, memcpyNode->count_, memcpyNode->offset_,
|
||||
@@ -1520,7 +1531,7 @@ class GraphMemcpyNodeToSymbol : public GraphMemcpyNode1D {
|
||||
return new GraphMemcpyNodeToSymbol(static_cast<GraphMemcpyNodeToSymbol const&>(*this));
|
||||
}
|
||||
|
||||
hipError_t CreateCommand(hip::Stream* stream) {
|
||||
virtual hipError_t CreateCommand(hip::Stream* stream) {
|
||||
hipError_t status = GraphNode::CreateCommand(stream);
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
@@ -1585,7 +1596,7 @@ class GraphMemcpyNodeToSymbol : public GraphMemcpyNode1D {
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
hipError_t SetParams(GraphNode* node) {
|
||||
virtual hipError_t SetParams(GraphNode* node) {
|
||||
const GraphMemcpyNodeToSymbol* memcpyNode =
|
||||
static_cast<GraphMemcpyNodeToSymbol const*>(node);
|
||||
return SetParams(memcpyNode->src_, memcpyNode->symbol_, memcpyNode->count_, memcpyNode->offset_,
|
||||
|
||||
Ссылка в новой задаче
Block a user