diff --git a/hipamd/src/hip_code_object.cpp b/hipamd/src/hip_code_object.cpp index b2635c6b7e..8ad95b7aa5 100644 --- a/hipamd/src/hip_code_object.cpp +++ b/hipamd/src/hip_code_object.cpp @@ -1011,6 +1011,12 @@ hipError_t DynCO::getDynFunc(hipFunction_t* hfunc, std::string func_name) { return it->second->getDynFunc(hfunc, module()); } +bool DynCO::isValidDynFunc(const void* hfunc) { + amd::ScopedLock lock(dclock_); + return std::any_of(functions_.begin(), functions_.end(), + [&](auto& it) { return it.second->isValidDynFunc(hfunc); }); +} + hipError_t DynCO::initDynManagedVars(const std::string& managedVar) { amd::ScopedLock lock(dclock_); DeviceVar* dvar; diff --git a/hipamd/src/hip_code_object.hpp b/hipamd/src/hip_code_object.hpp index 795bee004d..846280d1b9 100644 --- a/hipamd/src/hip_code_object.hpp +++ b/hipamd/src/hip_code_object.hpp @@ -116,6 +116,7 @@ public: //Gets GlobalVar/Functions from a dynamically loaded code object hipError_t getDynFunc(hipFunction_t* hfunc, std::string func_name); + bool isValidDynFunc(const void* hfunc); hipError_t getDeviceVar(DeviceVar** dvar, std::string var_name); hipError_t getManagedVarPointer(std::string name, void** pointer, size_t* size_ptr) const { diff --git a/hipamd/src/hip_global.cpp b/hipamd/src/hip_global.cpp index 0e5abed174..c63831b498 100644 --- a/hipamd/src/hip_global.cpp +++ b/hipamd/src/hip_global.cpp @@ -150,6 +150,10 @@ hipError_t Function::getDynFunc(hipFunction_t* hfunc, hipModule_t hmod) { return hipSuccess; } +bool Function::isValidDynFunc(const void* hfunc) { + return (hfunc == dFunc_[ihipGetDevice()]->asHipFunction()); +} + hipError_t Function::getStatFunc(hipFunction_t* hfunc, int deviceId) { guarantee(modules_ != nullptr, "Module not initialized"); diff --git a/hipamd/src/hip_global.hpp b/hipamd/src/hip_global.hpp index a3f6d29a01..5d28da21be 100644 --- a/hipamd/src/hip_global.hpp +++ b/hipamd/src/hip_global.hpp @@ -84,7 +84,7 @@ public: //Return DeviceFunc for this this dynamically loaded module hipError_t getDynFunc(hipFunction_t* hfunc, hipModule_t hmod); - + bool isValidDynFunc(const void* hfunc); //Return Device Func & attr . Generate/build if not already done so. hipError_t getStatFunc(hipFunction_t *hfunc, int deviceId); hipError_t getStatFuncAttr(hipFuncAttributes* func_attr, int deviceId); diff --git a/hipamd/src/hip_module.cpp b/hipamd/src/hip_module.cpp index 4c1b63172d..187e08bb86 100644 --- a/hipamd/src/hip_module.cpp +++ b/hipamd/src/hip_module.cpp @@ -186,14 +186,21 @@ hipError_t hipFuncSetAttribute(const void* func, hipFuncAttribute attr, int valu HIP_RETURN(hipErrorInvalidValue); } - hipFunction_t h_func; - HIP_RETURN_ONFAIL(PlatformState::instance().getStatFunc(&h_func, func, ihipGetDevice())); + hipFunction_t h_func = nullptr; + const hip::DeviceFunc* function = nullptr; - hip::DeviceFunc* function = hip::DeviceFunc::asFunction(h_func); - if (function == nullptr) { - HIP_RETURN(hipErrorInvalidHandle); + hipError_t err = PlatformState::instance().getStatFunc(&h_func, func, ihipGetDevice()); + if (h_func == nullptr) { + if (PlatformState::instance().isValidDynFunc((func))) { + function = reinterpret_cast(func); + } else { + HIP_RETURN(hipErrorInvalidDeviceFunction); + } + } else { + function = reinterpret_cast(h_func); } - amd::Kernel* kernel = reinterpret_cast(function)->kernel(); + + amd::Kernel* kernel = function->kernel(); if (kernel == nullptr) { HIP_RETURN(hipErrorInvalidDeviceFunction); diff --git a/hipamd/src/hip_platform.cpp b/hipamd/src/hip_platform.cpp index 9c0c99d6ab..465cd54405 100644 --- a/hipamd/src/hip_platform.cpp +++ b/hipamd/src/hip_platform.cpp @@ -805,6 +805,12 @@ hipError_t PlatformState::getDynFunc(hipFunction_t* hfunc, hipModule_t hmod, return it->second->getDynFunc(hfunc, func_name); } +bool PlatformState::isValidDynFunc(const void* hfunc) { + amd::ScopedLock lock(lock_); + return std::any_of(dynCO_map_.begin(), dynCO_map_.end(), + [&](auto& it) { return it.second->isValidDynFunc(hfunc); }); +} + hipError_t PlatformState::getDynGlobalVar(const char* hostVar, hipModule_t hmod, hipDeviceptr_t* dev_ptr, size_t* size_ptr) { amd::ScopedLock lock(lock_); diff --git a/hipamd/src/hip_platform.hpp b/hipamd/src/hip_platform.hpp index 5cc30832a7..46e14143fc 100644 --- a/hipamd/src/hip_platform.hpp +++ b/hipamd/src/hip_platform.hpp @@ -61,7 +61,7 @@ class PlatformState { // Dynamic Code Objects functions hipError_t loadModule(hipModule_t* module, const char* fname, const void* image = nullptr); hipError_t unloadModule(hipModule_t hmod); - + bool isValidDynFunc(const void* hfunc); hipError_t getDynFunc(hipFunction_t* hfunc, hipModule_t hmod, const char* func_name); hipError_t getDynGlobalVar(const char* hostVar, hipModule_t hmod, hipDeviceptr_t* dev_ptr, size_t* size_ptr);