This introduces correct support for agent global variables, and implements hipModuleGetGlobal as an actual equivalent for cuModuleGetGlobal.
[ROCm/hip commit: 328c18b886]
Этот коммит содержится в:
@@ -27,6 +27,7 @@ THE SOFTWARE.
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
@@ -613,6 +614,125 @@ hipError_t hipHccModuleLaunchKernel(hipFunction_t f,
|
||||
sharedMemBytes, hStream, kernelParams, extra, startEvent, stopEvent));
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
struct Agent_global {
|
||||
std::string name;
|
||||
hipDeviceptr_t address;
|
||||
std::uint32_t byte_cnt;
|
||||
};
|
||||
|
||||
inline
|
||||
void* address(hsa_executable_symbol_t x)
|
||||
{
|
||||
void* r = nullptr;
|
||||
hsa_executable_symbol_get_info(
|
||||
x, HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ADDRESS, &r);
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
inline
|
||||
std::string name(hsa_executable_symbol_t x)
|
||||
{
|
||||
uint32_t sz = 0u;
|
||||
hsa_executable_symbol_get_info(
|
||||
x, HSA_EXECUTABLE_SYMBOL_INFO_NAME_LENGTH, &sz);
|
||||
|
||||
std::string r(sz, '\0');
|
||||
hsa_executable_symbol_get_info(
|
||||
x, HSA_EXECUTABLE_SYMBOL_INFO_NAME, &r.front());
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
inline
|
||||
std::uint32_t size(hsa_executable_symbol_t x)
|
||||
{
|
||||
std::uint32_t r = 0;
|
||||
hsa_executable_symbol_get_info(
|
||||
x, HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_SIZE, &r);
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
inline
|
||||
void track(const Agent_global& x)
|
||||
{
|
||||
tprintf(
|
||||
DB_MEM,
|
||||
" add variable '%s' with ptr=%p size=%u to tracker\n",
|
||||
x.name.c_str(),
|
||||
x.address,
|
||||
x.byte_cnt);
|
||||
|
||||
auto device = ihipGetTlsDefaultCtx()->getWriteableDevice();
|
||||
|
||||
hc::AmPointerInfo ptr_info(
|
||||
nullptr,
|
||||
x.address,
|
||||
x.address,
|
||||
x.byte_cnt,
|
||||
device->_acc,
|
||||
true,
|
||||
false);
|
||||
hc::am_memtracker_add(x.address, ptr_info);
|
||||
hc::am_memtracker_update(x.address, device->_deviceId, 0u);
|
||||
}
|
||||
|
||||
template<typename Container = std::vector<Agent_global>>
|
||||
inline
|
||||
hsa_status_t copy_agent_global_variables(
|
||||
hsa_executable_t, hsa_agent_t, hsa_executable_symbol_t x, void* out)
|
||||
{
|
||||
assert(out);
|
||||
|
||||
hsa_symbol_kind_t t = {};
|
||||
hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_TYPE, &t);
|
||||
|
||||
if (t == HSA_SYMBOL_KIND_VARIABLE) {
|
||||
static_cast<Container*>(out)->push_back(
|
||||
Agent_global{name(x), address(x), size(x)});
|
||||
|
||||
track(static_cast<Container*>(out)->back());
|
||||
}
|
||||
|
||||
return HSA_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
inline
|
||||
hsa_agent_t this_agent()
|
||||
{
|
||||
auto ctx = ihipGetTlsDefaultCtx();
|
||||
|
||||
if (!ctx) throw std::runtime_error{"No active HIP context."};
|
||||
|
||||
auto device = ctx->getDevice();
|
||||
|
||||
if (!device) throw std::runtime_error{"No device available for HIP."};
|
||||
|
||||
ihipDevice_t *currentDevice = ihipGetDevice(device->_deviceId);
|
||||
|
||||
if (!currentDevice) {
|
||||
throw std::runtime_error{"No active device for HIP"};
|
||||
}
|
||||
|
||||
return currentDevice->_hsaAgent;
|
||||
}
|
||||
|
||||
inline
|
||||
std::vector<Agent_global> read_agent_globals(hipModule_t hmodule)
|
||||
{
|
||||
std::vector<Agent_global> r;
|
||||
|
||||
|
||||
hsa_executable_iterate_agent_symbols(
|
||||
hmodule->executable, this_agent(), copy_agent_global_variables, &r);
|
||||
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes,
|
||||
hipModule_t hmod, const char* name)
|
||||
{
|
||||
@@ -625,11 +745,37 @@ hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes,
|
||||
return ihipLogStatus(hipErrorNotInitialized);
|
||||
}
|
||||
else{
|
||||
hipFunction_t func;
|
||||
ret = ihipModuleGetSymbol(&func, hmod, name);
|
||||
*bytes = PrintSymbolSizes(hmod->ptr, name) + sizeof(amd_kernel_code_t);
|
||||
*dptr = reinterpret_cast<void*>(func->_object);
|
||||
return ihipLogStatus(ret);
|
||||
static std::unordered_map<
|
||||
hipModule_t, std::vector<Agent_global>> agent_globals;
|
||||
|
||||
// TODO: this is not particularly robust.
|
||||
if (agent_globals.count(hmod) == 0) {
|
||||
static std::mutex mtx;
|
||||
std::lock_guard<std::mutex> lck{mtx};
|
||||
|
||||
if (agent_globals.count(hmod) == 0) {
|
||||
agent_globals.emplace(hmod, read_agent_globals(hmod));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: This is unsafe iff some other emplacement triggers rehashing.
|
||||
// It will have to be properly fleshed out in the future.
|
||||
const auto it0 = agent_globals.find(hmod);
|
||||
if (it0 == agent_globals.cend()) {
|
||||
throw std::runtime_error{"agent_globals data structure corrupted."};
|
||||
}
|
||||
|
||||
const auto it1 = std::find_if(
|
||||
it0->second.cbegin(),
|
||||
it0->second.cend(),
|
||||
[=](const Agent_global& x) { return x.name == name; });
|
||||
|
||||
if (it1 == it0->second.cend()) return ihipLogStatus(hipErrorNotFound);
|
||||
|
||||
*dptr = it1->address;
|
||||
*bytes = it1->byte_cnt;
|
||||
|
||||
return ihipLogStatus(hipSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ссылка в новой задаче
Block a user