SWDEV-325249 - hipGraphAddKernelNode incompatible with hipFunction_t
If params.func isn't a host function, assume it's a hipFunction_t.
Change-Id: I43361ec49a8dd579225f30e31722977ca9a82378
[ROCm/clr commit: 29851c0d34]
Этот коммит содержится в:
коммит произвёл
Christophe Paquot
родитель
bb53dfb244
Коммит
318ff0000d
@@ -45,9 +45,8 @@ hipError_t ihipValidateKernelParams(const hipKernelNodeParams* pNodeParams) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
hipFunction_t func = nullptr;
|
||||
hipError_t status =
|
||||
PlatformState::instance().getStatFunc(&func, pNodeParams->func, ihipGetDevice());
|
||||
if ((status != hipSuccess) || (func == nullptr)) {
|
||||
hipError_t status = hipGraphKernelNode::getFunc(&func, *pNodeParams, ihipGetDevice());
|
||||
if (status != hipSuccess) {
|
||||
return hipErrorInvalidDeviceFunction;
|
||||
}
|
||||
size_t globalWorkSizeX = static_cast<size_t>(pNodeParams->gridDim.x) * pNodeParams->blockDim.x;
|
||||
@@ -85,8 +84,8 @@ hipError_t ihipGraphAddKernelNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
|
||||
return status;
|
||||
}
|
||||
hipFunction_t func = nullptr;
|
||||
status = PlatformState::instance().getStatFunc(&func, pNodeParams->func, ihipGetDevice());
|
||||
if ((status != hipSuccess) || (func == nullptr)) {
|
||||
status = hipGraphKernelNode::getFunc(&func, *pNodeParams, ihipGetDevice());
|
||||
if (status != hipSuccess) {
|
||||
return hipErrorInvalidDeviceFunction;
|
||||
}
|
||||
*pGraphNode = new hipGraphKernelNode(pNodeParams, func);
|
||||
|
||||
@@ -444,6 +444,16 @@ class hipGraphKernelNode : public hipGraphNode {
|
||||
hipFunction_t func_;
|
||||
|
||||
public:
|
||||
static hipError_t getFunc(hipFunction_t* func, const hipKernelNodeParams& params, unsigned int device) {
|
||||
hipError_t status = PlatformState::instance().getStatFunc(func, params.func, device);
|
||||
if (status != hipSuccess) {
|
||||
*func = reinterpret_cast<hipFunction_t>(params.func);
|
||||
}
|
||||
if (*func == nullptr) {
|
||||
return hipErrorInvalidDeviceFunction;
|
||||
}
|
||||
return hipSuccess;
|
||||
}
|
||||
hipGraphKernelNode(const hipKernelNodeParams* pNodeParams, const hipFunction_t func)
|
||||
: hipGraphNode(hipGraphNodeTypeKernel) {
|
||||
pKernelParams_ = new hipKernelNodeParams(*pNodeParams);
|
||||
@@ -485,9 +495,8 @@ class hipGraphKernelNode : public hipGraphNode {
|
||||
}
|
||||
if (params->func != pKernelParams_->func) {
|
||||
hipFunction_t func = nullptr;
|
||||
hipError_t status =
|
||||
PlatformState::instance().getStatFunc(&func, params->func, ihipGetDevice());
|
||||
if ((status != hipSuccess) || (func == nullptr)) {
|
||||
hipError_t status = hipGraphKernelNode::getFunc(&func, *params, ihipGetDevice());
|
||||
if (status != hipSuccess) {
|
||||
return hipErrorInvalidDeviceFunction;
|
||||
}
|
||||
func_ = func;
|
||||
@@ -499,9 +508,8 @@ class hipGraphKernelNode : public hipGraphNode {
|
||||
hipError_t SetCommandParams(const hipKernelNodeParams* params) {
|
||||
if (params->func != pKernelParams_->func) {
|
||||
hipFunction_t func = nullptr;
|
||||
hipError_t status =
|
||||
PlatformState::instance().getStatFunc(&func, params->func, ihipGetDevice());
|
||||
if ((status != hipSuccess) || (func == nullptr)) {
|
||||
hipError_t status = hipGraphKernelNode::getFunc(&func, *params, ihipGetDevice());
|
||||
if (status != hipSuccess) {
|
||||
return hipErrorInvalidDeviceFunction;
|
||||
}
|
||||
func_ = func;
|
||||
|
||||
Ссылка в новой задаче
Block a user