diff --git a/hipamd/src/hip_graph_internal.cpp b/hipamd/src/hip_graph_internal.cpp index 20f60a4c3f..4839b18af3 100644 --- a/hipamd/src/hip_graph_internal.cpp +++ b/hipamd/src/hip_graph_internal.cpp @@ -271,10 +271,8 @@ bool Graph::TopologicalOrder(std::vector& TopoOrder) { // ================================================================================================ void Graph::clone(Graph* newGraph, bool cloneNodes) const { newGraph->pOriginalGraph_ = this; - auto curDevId = ihipGetDevice(); for (hip::GraphNode* entry : vertices_) { GraphNode* node = entry->clone(); - node->SetDeviceId(curDevId); node->SetParentGraph(newGraph); newGraph->vertices_.push_back(node); newGraph->clonedNodes_[entry] = node; diff --git a/hipamd/src/hip_graph_internal.hpp b/hipamd/src/hip_graph_internal.hpp index 51b75f59f4..5934a16f5f 100644 --- a/hipamd/src/hip_graph_internal.hpp +++ b/hipamd/src/hip_graph_internal.hpp @@ -1200,6 +1200,8 @@ class GraphKernelNode : public GraphNode { void GetParams(hipKernelNodeParams* params) { *params = kernelParams_; } hipError_t SetParams(const hipKernelNodeParams* params) { + // Update device ID since new params may require validation for the current device. + dev_id_ = ihipGetDevice(); hipFunction_t func = getFunc(kernelParams_, dev_id_); if (!func) { return hipErrorInvalidDeviceFunction; @@ -1226,6 +1228,8 @@ class GraphKernelNode : public GraphNode { hipError_t SetAttrParams(hipKernelNodeAttrID attr, const hipKernelNodeAttrValue* params) { hipDeviceProp_t prop = {0}; + // Update device ID since new params may require validation for the current device. + dev_id_ = ihipGetDevice(); hipError_t status = ihipGetDeviceProperties(&prop, dev_id_); if (hipSuccess != status){ return status;