This introduces correct support for agent global variables, and implements hipModuleGetGlobal as an actual equivalent for cuModuleGetGlobal.

[ROCm/hip commit: 328c18b886]
Этот коммит содержится в:
Alex Voicu
2017-11-03 01:44:48 +00:00
родитель da749d453f
Коммит 28eb8e2c3e
+151 -5
Просмотреть файл
@@ -27,6 +27,7 @@ THE SOFTWARE.
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include <map>
@@ -613,6 +614,125 @@ hipError_t hipHccModuleLaunchKernel(hipFunction_t f,
sharedMemBytes, hStream, kernelParams, extra, startEvent, stopEvent));
}
namespace
{
// Record describing one variable symbol found while iterating the symbols of
// an executable for the current agent. Instances are built by
// copy_agent_global_variables() via aggregate initialization, so the member
// order (name, address, byte_cnt) is part of the contract.
struct Agent_global {
// Symbol name, as returned by the HSA_EXECUTABLE_SYMBOL_INFO_NAME query.
std::string name;
// Variable address, as returned by the
// HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ADDRESS query.
hipDeviceptr_t address;
// Variable size, as returned by the HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_SIZE
// query (presumably bytes, given the field name -- confirm against the HSA
// specification).
std::uint32_t byte_cnt;
};
// Asks the HSA runtime for the address of the variable behind symbol `x`.
// The status returned by the query is not inspected; the result starts out as
// nullptr and is returned as-is if the runtime does not overwrite it.
inline
void* address(hsa_executable_symbol_t x)
{
    void* variable_address = nullptr;

    hsa_executable_symbol_get_info(
        x,
        HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ADDRESS,
        &variable_address);

    return variable_address;
}
// Returns the name of symbol `x` as reported by the HSA runtime.
//
// The name length is queried first and used to size the result buffer, then
// the runtime writes the name directly into that buffer. The statuses returned
// by the two queries are not inspected.
//
// If the reported length is zero an empty string is returned without issuing
// the second query: taking front() of an empty std::string is undefined
// behaviour, and there would be no storage for the runtime to write into.
inline
std::string name(hsa_executable_symbol_t x)
{
    uint32_t sz = 0u;
    hsa_executable_symbol_get_info(
        x, HSA_EXECUTABLE_SYMBOL_INFO_NAME_LENGTH, &sz);

    if (sz == 0u) return std::string{};

    std::string r(sz, '\0');
    // r is non-empty here, so &r[0] points at sz writable characters.
    hsa_executable_symbol_get_info(
        x, HSA_EXECUTABLE_SYMBOL_INFO_NAME, &r[0]);

    return r;
}
// Asks the HSA runtime for the size of the variable behind symbol `x`.
// The status returned by the query is not inspected; the result starts out as
// 0 and is returned as-is if the runtime does not overwrite it.
inline
std::uint32_t size(hsa_executable_symbol_t x)
{
    std::uint32_t variable_size = 0;

    hsa_executable_symbol_get_info(
        x,
        HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_SIZE,
        &variable_size);

    return variable_size;
}
inline
void track(const Agent_global& x)
{
tprintf(
DB_MEM,
" add variable '%s' with ptr=%p size=%u to tracker\n",
x.name.c_str(),
x.address,
x.byte_cnt);
auto device = ihipGetTlsDefaultCtx()->getWriteableDevice();
hc::AmPointerInfo ptr_info(
nullptr,
x.address,
x.address,
x.byte_cnt,
device->_acc,
true,
false);
hc::am_memtracker_add(x.address, ptr_info);
hc::am_memtracker_update(x.address, device->_deviceId, 0u);
}
// Symbol-iteration callback: appends variable symbols to the container that
// `out` points to and registers each one with the memory tracker.
//
// `out` must be a non-null pointer to a `Container` (asserted). Symbols whose
// kind is not HSA_SYMBOL_KIND_VARIABLE are ignored. The executable and agent
// arguments are unused. Always returns HSA_STATUS_SUCCESS.
template<typename Container = std::vector<Agent_global>>
inline
hsa_status_t copy_agent_global_variables(
    hsa_executable_t, hsa_agent_t, hsa_executable_symbol_t x, void* out)
{
    assert(out);

    hsa_symbol_kind_t kind = {};
    hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_TYPE, &kind);

    if (kind != HSA_SYMBOL_KIND_VARIABLE) return HSA_STATUS_SUCCESS;

    auto& globals = *static_cast<Container*>(out);
    globals.push_back(Agent_global{name(x), address(x), size(x)});
    track(globals.back());

    return HSA_STATUS_SUCCESS;
}
// Returns the HSA agent of the device behind the calling thread's default HIP
// context.
//
// Throws std::runtime_error if there is no default context, if the context has
// no device, or if the device id cannot be resolved by ihipGetDevice().
inline
hsa_agent_t this_agent()
{
    auto context = ihipGetTlsDefaultCtx();
    if (!context) {
        throw std::runtime_error{"No active HIP context."};
    }

    auto context_device = context->getDevice();
    if (!context_device) {
        throw std::runtime_error{"No device available for HIP."};
    }

    ihipDevice_t *resolved = ihipGetDevice(context_device->_deviceId);
    if (!resolved) throw std::runtime_error{"No active device for HIP"};

    return resolved->_hsaAgent;
}
// Collects every variable symbol that `hmodule`'s executable exposes for the
// current agent.
//
// Iterates the executable's agent symbols with copy_agent_global_variables()
// as the callback, which also registers each variable with the memory tracker.
// The status returned by the iteration is not inspected. Exceptions thrown by
// this_agent() propagate to the caller.
inline
std::vector<Agent_global> read_agent_globals(hipModule_t hmodule)
{
    std::vector<Agent_global> globals;

    const auto agent = this_agent();
    hsa_executable_iterate_agent_symbols(
        hmodule->executable, agent, copy_agent_global_variables, &globals);

    return globals;
}
}
hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes,
hipModule_t hmod, const char* name)
{
@@ -625,11 +745,37 @@ hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes,
return ihipLogStatus(hipErrorNotInitialized);
}
else{
hipFunction_t func;
ret = ihipModuleGetSymbol(&func, hmod, name);
*bytes = PrintSymbolSizes(hmod->ptr, name) + sizeof(amd_kernel_code_t);
*dptr = reinterpret_cast<void*>(func->_object);
return ihipLogStatus(ret);
static std::unordered_map<
hipModule_t, std::vector<Agent_global>> agent_globals;
// TODO: this is not particularly robust.
if (agent_globals.count(hmod) == 0) {
static std::mutex mtx;
std::lock_guard<std::mutex> lck{mtx};
if (agent_globals.count(hmod) == 0) {
agent_globals.emplace(hmod, read_agent_globals(hmod));
}
}
// TODO: This is unsafe iff some other emplacement triggers rehashing.
// It will have to be properly fleshed out in the future.
const auto it0 = agent_globals.find(hmod);
if (it0 == agent_globals.cend()) {
throw std::runtime_error{"agent_globals data structure corrupted."};
}
const auto it1 = std::find_if(
it0->second.cbegin(),
it0->second.cend(),
[=](const Agent_global& x) { return x.name == name; });
if (it1 == it0->second.cend()) return ihipLogStatus(hipErrorNotFound);
*dptr = it1->address;
*bytes = it1->byte_cnt;
return ihipLogStatus(hipSuccess);
}
}