Create one hipModule_t per device

Этот коммит содержится в:
Laurent Morichetti
2018-04-10 08:57:34 -07:00
родитель 7f03ff1012
Коммит f4cde23be3
+36 -25
Просмотреть файл
@@ -53,7 +53,7 @@ struct __CudaFatBinaryWrapper {
};
extern "C" std::unordered_map<std::string, hipModule_t>*
extern "C" std::vector<hipModule_t>*
__hipRegisterFatBinary(const void* data)
{
HIP_INIT();
@@ -69,7 +69,7 @@ __hipRegisterFatBinary(const void* data)
return nullptr;
}
auto modules = new std::unordered_map<std::string, hipModule_t>{};
auto modules = new std::vector<hipModule_t>{g_deviceCnt};
if (!modules) {
return nullptr;
}
@@ -83,12 +83,33 @@ __hipRegisterFatBinary(const void* data)
if (triple.compare(AMDGCN_AMDHSA_TRIPLE))
continue;
hipModule_t module;
if (hipSuccess == hipModuleLoadData(&module, reinterpret_cast<const void*>(
reinterpret_cast<uintptr_t>(header) + desc->offset))) {
modules->emplace(std::string{&desc->triple[sizeof(AMDGCN_AMDHSA_TRIPLE)],
desc->tripleSize - sizeof(AMDGCN_AMDHSA_TRIPLE)},
module);
std::string target{&desc->triple[sizeof(AMDGCN_AMDHSA_TRIPLE)],
desc->tripleSize - sizeof(AMDGCN_AMDHSA_TRIPLE)};
for (int deviceId = 0; deviceId < g_deviceCnt; ++deviceId) {
hsa_agent_t agent = g_allAgents[deviceId + 1];
char name[64] = {};
hsa_agent_get_info(agent, HSA_AGENT_INFO_NAME, name);
if (target.compare(name)) {
continue;
}
ihipModule_t* module = new ihipModule_t;
if (!module) {
continue;
}
hsa_executable_create_alt(HSA_PROFILE_FULL, HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, nullptr,
&module->executable);
std::string image{reinterpret_cast<const char*>(
reinterpret_cast<uintptr_t>(header) + desc->offset), desc->size};
module->executable = hip_impl::load_executable(image, module->executable, agent);
if (module->executable.handle) {
modules->at(deviceId) = module;
}
}
}
@@ -98,7 +119,7 @@ __hipRegisterFatBinary(const void* data)
// Global table keyed by a `const void*`, each value being a vector of
// hipFunction_t handles.
// NOTE(review): presumably the key is the `hostFunction` pointer given to
// __hipRegisterFunction and the vector holds one slot per device; the code that
// inserts into this map is not visible in this excerpt — confirm.
std::map<const void*, std::vector<hipFunction_t>> g_functions;
extern "C" void __hipRegisterFunction(
std::unordered_map<std::string, hipModule_t>* modules,
std::vector<hipModule_t>* modules,
const void* hostFunction,
char* deviceFunction,
const char* deviceName,
@@ -111,18 +132,10 @@ extern "C" void __hipRegisterFunction(
{
std::vector<hipFunction_t> functions{g_deviceCnt};
for (auto&& it : *modules) {
for (int deviceId = 0; deviceId < g_deviceCnt; ++deviceId) {
hipFunction_t function;
if (hipSuccess != hipModuleGetFunction(&function, it.second, deviceName)) {
continue;
}
for (int deviceId = 0; deviceId < g_deviceCnt; ++deviceId) {
char name[64] = {};
hsa_agent_get_info(g_allAgents[deviceId + 1], HSA_AGENT_INFO_NAME, name);
if (!it.first.compare(name)) {
functions[deviceId] = function;
}
if (hipSuccess == hipModuleGetFunction(&function, modules->at(deviceId), deviceName)) {
functions[deviceId] = function;
}
}
@@ -130,7 +143,7 @@ extern "C" void __hipRegisterFunction(
}
extern "C" void __hipRegisterVar(
hipModule_t module,
std::vector<hipModule_t>* modules,
char* hostVar,
char* deviceVar,
const char* deviceName,
@@ -141,11 +154,9 @@ extern "C" void __hipRegisterVar(
{
}
// Tears down the module container handed out for a fat binary: deletes every
// module handle stored in `modules`, then deletes the container itself.
//
// NOTE(review): this listing is a rendered diff, so the pre-change lines (the
// unordered_map signature and the range-for loop over `it.second`) appear
// directly next to their post-change replacements (the vector signature and
// the std::for_each call). Only one of each pair exists in the real file.
//
// NOTE(review): `modules` is dereferenced without a null check, while the
// visible parts of __hipRegisterFatBinary contain `return nullptr;` paths —
// confirm that a null pointer can never reach this function.
extern "C" void __hipUnregisterFatBinary(std::unordered_map<std::string, hipModule_t>* modules)
extern "C" void __hipUnregisterFatBinary(std::vector<hipModule_t>* modules)
{
for (auto&& it : *modules) {
delete it.second;
}
// Each element is deleted individually; an element that is a null pointer is
// harmless here because `delete` on a null pointer does nothing.
// NOTE(review): assumes any slot that never received a module holds nullptr
// rather than an indeterminate value — confirm against how the vector is built.
std::for_each(modules->begin(), modules->end(), [](hipModule_t module){ delete module; });
delete modules;
}