From 90a00bc681ab962a50762e8b93a6ddee5ce4fa55 Mon Sep 17 00:00:00 2001 From: Jaydeep Patel Date: Mon, 20 Mar 2023 09:26:16 +0000 Subject: [PATCH] SWDEV-388926 - Original and new memory should be from same device for hipGraphExecMemsetNodeSetParams. Change-Id: I32bd56db0b80d748e3ae0737a38ea975738bdab7 [ROCm/clr commit: ec227d560a67dc24a711192893f30976161104bf] --- projects/clr/hipamd/src/hip_graph.cpp | 2 +- projects/clr/hipamd/src/hip_graph_internal.hpp | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index d788073f59..787ba9c69d 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -1479,7 +1479,7 @@ hipError_t hipGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo if (clonedNode == nullptr) { HIP_RETURN(hipErrorInvalidValue); } - HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(pNodeParams)); + HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(pNodeParams, true)); } hipError_t hipGraphAddDependencies(hipGraph_t graph, const hipGraphNode_t* from, diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index af3a10857b..f033ec01c2 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -1626,7 +1626,7 @@ class hipGraphMemsetNode : public hipGraphNode { std::memcpy(params, pMemsetParams_, sizeof(hipMemsetParams)); } - hipError_t SetParams(const hipMemsetParams* params) { + hipError_t SetParams(const hipMemsetParams* params, bool isExec = false) { hipError_t hip_error = hipSuccess; hipMemsetParams origParams = {}; GetParams(&origParams); @@ -1634,6 +1634,18 @@ class hipGraphMemsetNode : public hipGraphNode { if (hip_error != hipSuccess) { return hip_error; } + if (isExec) { + size_t discardOffset = 0; + amd::Memory *memObj = getMemoryObject(params->dst, discardOffset); + if (memObj != nullptr) { + amd::Memory *memObjOri = getMemoryObject(pMemsetParams_->dst, discardOffset); + if (memObjOri != nullptr) { + if (memObjOri->getUserData().deviceId != memObj->getUserData().deviceId) { + return hipErrorInvalidValue; + } + } + } + } size_t sizeBytes; if (params->height == 1) { sizeBytes = params->width * params->elementSize;