From 85a372e4ebceab640d75c0a44efa8d7c0eee0e51 Mon Sep 17 00:00:00 2001 From: Rahul Manocha Date: Tue, 9 Apr 2024 16:06:56 -0700 Subject: [PATCH] [SWDEV-454661][SWDEV-454653] - GraphExecMemcpyNodeSetParam to return error on memcpy direction change Change-Id: I2c8f5ea394caeaaa6895003e63cd62a052c491f8 [ROCm/clr commit: 880963346d30ab963024bcc912bc9581f965bb4f] --- projects/clr/hipamd/src/hip_graph.cpp | 19 +++++++++++++++++++ .../clr/hipamd/src/hip_graph_internal.hpp | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index 28335c6d51..bf13f3cafd 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -1190,6 +1190,10 @@ hipError_t hipGraphExecMemcpyNodeSetParams1D(hipGraphExec_t hGraphExec, hipGraph if (clonedNode == nullptr) { HIP_RETURN(hipErrorInvalidValue); } + hipMemcpyKind oldkind = reinterpret_cast(clonedNode)->GetMemcpyKind(); + if (oldkind != kind) { + HIP_RETURN(hipErrorInvalidValue); + } HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(dst, src, count, kind)); } @@ -1579,6 +1583,12 @@ hipError_t hipGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo if (clonedNode == nullptr) { HIP_RETURN(hipErrorInvalidValue); } + + hipMemcpyKind oldkind = reinterpret_cast(clonedNode)->GetMemcpyKind(); + hipMemcpyKind newkind = pNodeParams->kind; + if (oldkind != newkind) { + HIP_RETURN(hipErrorInvalidValue); + } HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(pNodeParams)); } @@ -2112,6 +2122,11 @@ hipError_t hipGraphExecMemcpyNodeSetParamsFromSymbol(hipGraphExec_t hGraphExec, if (clonedNode == nullptr) { HIP_RETURN(hipErrorInvalidValue); } + + hipMemcpyKind oldkind = reinterpret_cast(clonedNode)->GetMemcpyKind(); + if (oldkind != kind) { + HIP_RETURN(hipErrorInvalidValue); + } constexpr bool kCheckDeviceIsSame = true; HIP_RETURN(reinterpret_cast(clonedNode) ->SetParams(dst, symbol, count, offset, kind, kCheckDeviceIsSame)); @@ -2182,6 +2197,10 @@ hipError_t hipGraphExecMemcpyNodeSetParamsToSymbol(hipGraphExec_t hGraphExec, hi if (clonedNode == nullptr) { HIP_RETURN(hipErrorInvalidValue); } + hipMemcpyKind oldkind = reinterpret_cast(clonedNode)->GetMemcpyKind(); + if (oldkind != kind) { + HIP_RETURN(hipErrorInvalidValue); + } constexpr bool kCheckDeviceIsSame = true; HIP_RETURN(reinterpret_cast(clonedNode) ->SetParams(symbol, src, count, offset, kind, kCheckDeviceIsSame)); diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index 2e8a90c441..9c4cd083a1 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -1244,7 +1244,7 @@ class GraphMemcpyNode : public GraphNode { std::memcpy(params, ©Params_, sizeof(hipMemcpy3DParms)); } - virtual hipMemcpyKind GetMemcpyKind() const { return hipMemcpyDefault; }; + virtual hipMemcpyKind GetMemcpyKind() const { return copyParams_.kind; }; hipError_t SetParams(const hipMemcpy3DParms* params) { hipError_t status = ValidateParams(params);