diff --git a/projects/hip/src/hip_module.cpp b/projects/hip/src/hip_module.cpp index 7799ad86c8..38411f2347 100644 --- a/projects/hip/src/hip_module.cpp +++ b/projects/hip/src/hip_module.cpp @@ -27,6 +27,7 @@ THE SOFTWARE. #include #include #include +#include #include #include @@ -613,6 +614,125 @@ hipError_t hipHccModuleLaunchKernel(hipFunction_t f, sharedMemBytes, hStream, kernelParams, extra, startEvent, stopEvent)); } +namespace +{ + /* One variable symbol read from an executable: its name, its address and its size in bytes. Filled in by copy_agent_global_variables from name(x), address(x) and size(x). */ struct Agent_global { + std::string name; + hipDeviceptr_t address; + std::uint32_t byte_cnt; + }; + + /* Queries HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ADDRESS for symbol x and returns the result. The status returned by the query is not inspected; r keeps its nullptr initializer unless the call writes to it. */ inline + void* address(hsa_executable_symbol_t x) + { + void* r = nullptr; + hsa_executable_symbol_get_info( + x, HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ADDRESS, &r); + + return r; + } + + /* Queries the name length of symbol x, then reads the name into a std::string of exactly that length. Neither query status is inspected. NOTE(review): &r.front() is evaluated even when sz is 0 - confirm that a zero-length name cannot occur, or guard it. */ inline + std::string name(hsa_executable_symbol_t x) + { + uint32_t sz = 0u; + hsa_executable_symbol_get_info( + x, HSA_EXECUTABLE_SYMBOL_INFO_NAME_LENGTH, &sz); + + std::string r(sz, '\0'); + hsa_executable_symbol_get_info( + x, HSA_EXECUTABLE_SYMBOL_INFO_NAME, &r.front()); + + return r; + } + + /* Queries HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_SIZE for symbol x into a 32-bit integer and returns it. The query status is not inspected; r stays 0 unless the call writes to it. */ inline + std::uint32_t size(hsa_executable_symbol_t x) + { + std::uint32_t r = 0; + hsa_executable_symbol_get_info( + x, HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_SIZE, &r); + + return r; + } + + /* Logs the variable under DB_MEM, then calls hc::am_memtracker_add and hc::am_memtracker_update for x.address, using the device returned by ihipGetTlsDefaultCtx()->getWriteableDevice(). NOTE(review): the context pointer is dereferenced without a null check here, whereas this_agent() checks it - confirm a context always exists on this path. */ inline + void track(const Agent_global& x) + { + tprintf( + DB_MEM, + " add variable '%s' with ptr=%p size=%u to tracker\n", + x.name.c_str(), + x.address, + x.byte_cnt); + + auto device = ihipGetTlsDefaultCtx()->getWriteableDevice(); + + hc::AmPointerInfo ptr_info( + nullptr, + x.address, + x.address, + x.byte_cnt, + device->_acc, + true, + false); + hc::am_memtracker_add(x.address, ptr_info); + hc::am_memtracker_update(x.address, device->_deviceId, 0u); + } + + /* Callback handed to hsa_executable_iterate_agent_symbols by read_agent_globals. Asserts that out is non-null, reads the symbol kind of x, and only for HSA_SYMBOL_KIND_VARIABLE appends an Agent_global to the container behind out and passes the new element to track(). Always returns HSA_STATUS_SUCCESS; the executable and agent arguments are unused. */ template> + inline + hsa_status_t copy_agent_global_variables( + hsa_executable_t, hsa_agent_t, hsa_executable_symbol_t x, void* out) + { + assert(out); + + hsa_symbol_kind_t t = {}; + hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_TYPE, &t); + + if (t == HSA_SYMBOL_KIND_VARIABLE) { + 
static_cast(out)->push_back( + Agent_global{name(x), address(x), size(x)}); + + track(static_cast(out)->back()); + } + + return HSA_STATUS_SUCCESS; + } + + /* Returns the _hsaAgent of the device found via ihipGetTlsDefaultCtx(), getDevice() and ihipGetDevice(). Throws std::runtime_error if the context, its device, or the ihipGetDevice result is null. */ inline + hsa_agent_t this_agent() + { + auto ctx = ihipGetTlsDefaultCtx(); + + if (!ctx) throw std::runtime_error{"No active HIP context."}; + + auto device = ctx->getDevice(); + + if (!device) throw std::runtime_error{"No device available for HIP."}; + + ihipDevice_t *currentDevice = ihipGetDevice(device->_deviceId); + + if (!currentDevice) { + throw std::runtime_error{"No active device for HIP"}; + } + + return currentDevice->_hsaAgent; + } + + /* Iterates the agent symbols of hmodule->executable for this_agent(), collecting variable symbols into a vector through copy_agent_global_variables, and returns that vector. The status returned by the iterate call is not inspected. May throw std::runtime_error via this_agent(). */ inline + std::vector read_agent_globals(hipModule_t hmodule) + { + std::vector r; + + + hsa_executable_iterate_agent_symbols( + hmodule->executable, this_agent(), copy_agent_global_variables, &r); + + return r; + } +} + hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes, hipModule_t hmod, const char* name) { @@ -625,11 +745,37 @@ hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes, return ihipLogStatus(hipErrorNotInitialized); } else{ - hipFunction_t func; - ret = ihipModuleGetSymbol(&func, hmod, name); - *bytes = PrintSymbolSizes(hmod->ptr, name) + sizeof(amd_kernel_code_t); - *dptr = reinterpret_cast(func->_object); - return ihipLogStatus(ret); + /* Function-local static cache keyed by module handle; the entry for a module is created by read_agent_globals the first time that module is seen, while holding mtx. No erase is visible in this hunk. NOTE(review): the outer count() below is evaluated without holding mtx, and the later find() is also unlocked - confirm these cannot overlap an emplace from another thread. */ static std::unordered_map< + hipModule_t, std::vector> agent_globals; + + // TODO: this is not particularly robust. + if (agent_globals.count(hmod) == 0) { + static std::mutex mtx; + std::lock_guard lck{mtx}; + + if (agent_globals.count(hmod) == 0) { + agent_globals.emplace(hmod, read_agent_globals(hmod)); + } + } + + // TODO: This is unsafe iff some other emplacement triggers rehashing. + // It will have to be properly fleshed out in the future. 
+ /* Look the module up again in the cache; a miss at this point throws std::runtime_error. */ const auto it0 = agent_globals.find(hmod); + if (it0 == agent_globals.cend()) { + throw std::runtime_error{"agent_globals data structure corrupted."}; + } + + /* Linear search of this module's globals for an entry whose name compares equal to the requested name; returns hipErrorNotFound when there is none. */ const auto it1 = std::find_if( + it0->second.cbegin(), + it0->second.cend(), + [=](const Agent_global& x) { return x.name == name; }); + + if (it1 == it0->second.cend()) return ihipLogStatus(hipErrorNotFound); + + /* Report the address and byte count of the matching entry through the out-parameters. */ *dptr = it1->address; + *bytes = it1->byte_cnt; + + return ihipLogStatus(hipSuccess); } }