SWDEV-325249 - hipGraphAddKernelNode incompatible with hipFunction_t

If params.func isn't a host function, assume it's a hipFunction_t.

Change-Id: I43361ec49a8dd579225f30e31722977ca9a82378


[ROCm/clr commit: 29851c0d34]
Этот коммит содержится в:
Christophe Paquot
2022-03-03 15:00:28 -08:00
коммит произвёл Christophe Paquot
родитель bb53dfb244
Коммит 318ff0000d
2 изменённых файлов: 18 добавлений и 11 удалений
+4 -5
Просмотреть файл
@@ -45,9 +45,8 @@ hipError_t ihipValidateKernelParams(const hipKernelNodeParams* pNodeParams) {
return hipErrorInvalidValue;
}
hipFunction_t func = nullptr;
hipError_t status =
PlatformState::instance().getStatFunc(&func, pNodeParams->func, ihipGetDevice());
if ((status != hipSuccess) || (func == nullptr)) {
hipError_t status = hipGraphKernelNode::getFunc(&func, *pNodeParams, ihipGetDevice());
if (status != hipSuccess) {
return hipErrorInvalidDeviceFunction;
}
size_t globalWorkSizeX = static_cast<size_t>(pNodeParams->gridDim.x) * pNodeParams->blockDim.x;
@@ -85,8 +84,8 @@ hipError_t ihipGraphAddKernelNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
return status;
}
hipFunction_t func = nullptr;
status = PlatformState::instance().getStatFunc(&func, pNodeParams->func, ihipGetDevice());
if ((status != hipSuccess) || (func == nullptr)) {
status = hipGraphKernelNode::getFunc(&func, *pNodeParams, ihipGetDevice());
if (status != hipSuccess) {
return hipErrorInvalidDeviceFunction;
}
*pGraphNode = new hipGraphKernelNode(pNodeParams, func);
+14 -6
Просмотреть файл
@@ -444,6 +444,16 @@ class hipGraphKernelNode : public hipGraphNode {
hipFunction_t func_;
public:
static hipError_t getFunc(hipFunction_t* func, const hipKernelNodeParams& params, unsigned int device) {
hipError_t status = PlatformState::instance().getStatFunc(func, params.func, device);
if (status != hipSuccess) {
*func = reinterpret_cast<hipFunction_t>(params.func);
}
if (*func == nullptr) {
return hipErrorInvalidDeviceFunction;
}
return hipSuccess;
}
hipGraphKernelNode(const hipKernelNodeParams* pNodeParams, const hipFunction_t func)
: hipGraphNode(hipGraphNodeTypeKernel) {
pKernelParams_ = new hipKernelNodeParams(*pNodeParams);
@@ -485,9 +495,8 @@ class hipGraphKernelNode : public hipGraphNode {
}
if (params->func != pKernelParams_->func) {
hipFunction_t func = nullptr;
hipError_t status =
PlatformState::instance().getStatFunc(&func, params->func, ihipGetDevice());
if ((status != hipSuccess) || (func == nullptr)) {
hipError_t status = hipGraphKernelNode::getFunc(&func, *params, ihipGetDevice());
if (status != hipSuccess) {
return hipErrorInvalidDeviceFunction;
}
func_ = func;
@@ -499,9 +508,8 @@ class hipGraphKernelNode : public hipGraphNode {
hipError_t SetCommandParams(const hipKernelNodeParams* params) {
if (params->func != pKernelParams_->func) {
hipFunction_t func = nullptr;
hipError_t status =
PlatformState::instance().getStatFunc(&func, params->func, ihipGetDevice());
if ((status != hipSuccess) || (func == nullptr)) {
hipError_t status = hipGraphKernelNode::getFunc(&func, *params, ihipGetDevice());
if (status != hipSuccess) {
return hipErrorInvalidDeviceFunction;
}
func_ = func;