diff --git a/projects/hip/api/hip/hip_internal.hpp b/projects/hip/api/hip/hip_internal.hpp index 9606cd47d1..25fe6f9537 100644 --- a/projects/hip/api/hip/hip_internal.hpp +++ b/projects/hip/api/hip/hip_internal.hpp @@ -140,6 +140,7 @@ public: void registerFunction(const void* hostFunction, const DeviceFunction& func); hipFunction_t getFunc(const void* hostFunction, int deviceId); + bool getFuncAttr(const void* hostFunction, hipFuncAttributes* func_attr); bool getGlobalVar(const void* hostVar, int deviceId, hipDeviceptr_t* dev_ptr, size_t* size_ptr); void setupArgument(const void *arg, size_t size, size_t offset); diff --git a/projects/hip/api/hip/hip_module.cpp b/projects/hip/api/hip/hip_module.cpp index b482f39b77..79c5733c74 100644 --- a/projects/hip/api/hip/hip_module.cpp +++ b/projects/hip/api/hip/hip_module.cpp @@ -149,7 +149,11 @@ hipError_t hipFuncGetAttributes(hipFuncAttributes* attr, const void* func) { HIP_INIT_API(attr, func); - HIP_RETURN(hipErrorInvalidDeviceFunction); + if (!PlatformState::instance().getFuncAttr(func, attr)) { + HIP_RETURN(hipErrorUnknown); + } + + HIP_RETURN(hipSuccess); } diff --git a/projects/hip/api/hip/hip_platform.cpp b/projects/hip/api/hip/hip_platform.cpp index 54a8806d54..ec34db8c0a 100644 --- a/projects/hip/api/hip/hip_platform.cpp +++ b/projects/hip/api/hip/hip_platform.cpp @@ -136,6 +136,24 @@ void PlatformState::registerFunction(const void* hostFunction, functions_.insert(std::make_pair(hostFunction, func)); } +bool ihipGetFuncAttributes(const char* func_name, amd::Program* program, hipFuncAttributes* func_attr) { + device::Program* dev_program + = program->getDeviceProgram(*hip::getCurrentContext()->devices()[0]); + + const auto it = dev_program->kernels().find(std::string(func_name)); + if (it == dev_program->kernels().cend()) { + return false; + } + + const device::Kernel::WorkGroupInfo* wginfo = it->second->workGroupInfo(); + func_attr->localSizeBytes = wginfo->localMemSize_; + func_attr->sharedSizeBytes = wginfo->size_; + func_attr->maxThreadsPerBlock = wginfo->wavefrontSize_; + func_attr->numRegs = wginfo->usedVGPRs_; + + return true; +} + hipFunction_t PlatformState::getFunc(const void* hostFunction, int deviceId) { amd::ScopedLock lock(lock_); const auto it = functions_.find(hostFunction); @@ -165,6 +183,36 @@ hipFunction_t PlatformState::getFunc(const void* hostFunction, int deviceId) { return nullptr; } +bool PlatformState::getFuncAttr(const void* hostFunction, + hipFuncAttributes* func_attr) { + + if (func_attr == nullptr) { + return false; + } + + const auto it = functions_.find(hostFunction); + if (it == functions_.cend()) { + return false; + } + + PlatformState::DeviceFunction& devFunc = it->second; + int deviceId = ihipGetDevice(); + + /* If module has not been initialized yet, build the kernel now*/ + if (!(*devFunc.modules)[deviceId].second) { + if (nullptr == PlatformState::instance().getFunc(hostFunction, deviceId)) { + return false; + } + } + + amd::Program* program = as_amd(reinterpret_cast((*devFunc.modules)[deviceId].first)); + if (!ihipGetFuncAttributes(devFunc.deviceName.c_str(), program, func_attr)) { + return false; + } + return true; +} + + bool PlatformState::getGlobalVar(const void* hostVar, int deviceId, hipDeviceptr_t* dev_ptr, size_t* size_ptr) { amd::ScopedLock lock(lock_); @@ -235,9 +283,7 @@ extern "C" void __hipRegisterFunction( int* wSize) { HIP_INIT(); - PlatformState::DeviceFunction func{ std::string{deviceName}, modules, std::vector{ g_devices.size() }}; - PlatformState::instance().registerFunction(hostFunction, func); }