From 318ff0000dc777382b8ea4752feedf8233cccb62 Mon Sep 17 00:00:00 2001 From: Christophe Paquot Date: Thu, 3 Mar 2022 15:00:28 -0800 Subject: [PATCH] SWDEV-325249 - hipGraphAddKernelNode incompatible with hipFunction_t If params.func isn't a host function, assume it's a hipFunction_t. Change-Id: I43361ec49a8dd579225f30e31722977ca9a82378 [ROCm/clr commit: 29851c0d340c3d28681ff542e73f88979ef21000] --- projects/clr/hipamd/src/hip_graph.cpp | 9 ++++----- .../clr/hipamd/src/hip_graph_internal.hpp | 20 +++++++++++++------ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index faf8b5c2ba..fc07149498 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -45,9 +45,8 @@ hipError_t ihipValidateKernelParams(const hipKernelNodeParams* pNodeParams) { return hipErrorInvalidValue; } hipFunction_t func = nullptr; - hipError_t status = - PlatformState::instance().getStatFunc(&func, pNodeParams->func, ihipGetDevice()); - if ((status != hipSuccess) || (func == nullptr)) { + hipError_t status = hipGraphKernelNode::getFunc(&func, *pNodeParams, ihipGetDevice()); + if (status != hipSuccess) { return hipErrorInvalidDeviceFunction; } size_t globalWorkSizeX = static_cast(pNodeParams->gridDim.x) * pNodeParams->blockDim.x; @@ -85,8 +84,8 @@ hipError_t ihipGraphAddKernelNode(hipGraphNode_t* pGraphNode, hipGraph_t graph, return status; } hipFunction_t func = nullptr; - status = PlatformState::instance().getStatFunc(&func, pNodeParams->func, ihipGetDevice()); - if ((status != hipSuccess) || (func == nullptr)) { + status = hipGraphKernelNode::getFunc(&func, *pNodeParams, ihipGetDevice()); + if (status != hipSuccess) { return hipErrorInvalidDeviceFunction; } *pGraphNode = new hipGraphKernelNode(pNodeParams, func); diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index 3981d5ffd2..2dd1cab8e7 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -444,6 +444,16 @@ class hipGraphKernelNode : public hipGraphNode { hipFunction_t func_; public: + static hipError_t getFunc(hipFunction_t* func, const hipKernelNodeParams& params, unsigned int device) { + hipError_t status = PlatformState::instance().getStatFunc(func, params.func, device); + if (status != hipSuccess) { + *func = reinterpret_cast(params.func); + } + if (*func == nullptr) { + return hipErrorInvalidDeviceFunction; + } + return hipSuccess; + } hipGraphKernelNode(const hipKernelNodeParams* pNodeParams, const hipFunction_t func) : hipGraphNode(hipGraphNodeTypeKernel) { pKernelParams_ = new hipKernelNodeParams(*pNodeParams); @@ -485,9 +495,8 @@ class hipGraphKernelNode : public hipGraphNode { } if (params->func != pKernelParams_->func) { hipFunction_t func = nullptr; - hipError_t status = - PlatformState::instance().getStatFunc(&func, params->func, ihipGetDevice()); - if ((status != hipSuccess) || (func == nullptr)) { + hipError_t status = hipGraphKernelNode::getFunc(&func, *params, ihipGetDevice()); + if (status != hipSuccess) { return hipErrorInvalidDeviceFunction; } func_ = func; @@ -499,9 +508,8 @@ class hipGraphKernelNode : public hipGraphNode { hipError_t SetCommandParams(const hipKernelNodeParams* params) { if (params->func != pKernelParams_->func) { hipFunction_t func = nullptr; - hipError_t status = - PlatformState::instance().getStatFunc(&func, params->func, ihipGetDevice()); - if ((status != hipSuccess) || (func == nullptr)) { + hipError_t status = hipGraphKernelNode::getFunc(&func, *params, ihipGetDevice()); + if (status != hipSuccess) { return hipErrorInvalidDeviceFunction; } func_ = func;