From f4cde23be36552da5d19e85f3c83601a4fffbb05 Mon Sep 17 00:00:00 2001 From: Laurent Morichetti Date: Tue, 10 Apr 2018 08:57:34 -0700 Subject: [PATCH] Create one hipModule_t per device --- hipamd/src/hip_clang.cpp | 61 ++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/hipamd/src/hip_clang.cpp b/hipamd/src/hip_clang.cpp index 5c4da0b39d..80b6111fc2 100644 --- a/hipamd/src/hip_clang.cpp +++ b/hipamd/src/hip_clang.cpp @@ -53,7 +53,7 @@ struct __CudaFatBinaryWrapper { }; -extern "C" std::unordered_map* +extern "C" std::vector* __hipRegisterFatBinary(const void* data) { HIP_INIT(); @@ -69,7 +69,7 @@ __hipRegisterFatBinary(const void* data) return nullptr; } - auto modules = new std::unordered_map{}; + auto modules = new std::vector{g_deviceCnt}; if (!modules) { return nullptr; } @@ -83,12 +83,33 @@ __hipRegisterFatBinary(const void* data) if (triple.compare(AMDGCN_AMDHSA_TRIPLE)) continue; - hipModule_t module; - if (hipSuccess == hipModuleLoadData(&module, reinterpret_cast( - reinterpret_cast(header) + desc->offset))) { - modules->emplace(std::string{&desc->triple[sizeof(AMDGCN_AMDHSA_TRIPLE)], - desc->tripleSize - sizeof(AMDGCN_AMDHSA_TRIPLE)}, - module); + std::string target{&desc->triple[sizeof(AMDGCN_AMDHSA_TRIPLE)], + desc->tripleSize - sizeof(AMDGCN_AMDHSA_TRIPLE)}; + + for (int deviceId = 0; deviceId < g_deviceCnt; ++deviceId) { + hsa_agent_t agent = g_allAgents[deviceId + 1]; + + char name[64] = {}; + hsa_agent_get_info(agent, HSA_AGENT_INFO_NAME, name); + if (target.compare(name)) { + continue; + } + + ihipModule_t* module = new ihipModule_t; + if (!module) { + continue; + } + + hsa_executable_create_alt(HSA_PROFILE_FULL, HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, nullptr, + &module->executable); + + std::string image{reinterpret_cast( + reinterpret_cast(header) + desc->offset), desc->size}; + module->executable = hip_impl::load_executable(image, module->executable, agent); + + if (module->executable.handle) { + modules->at(deviceId) = module; + } } } @@ -98,7 +119,7 @@ __hipRegisterFatBinary(const void* data) std::map> g_functions; extern "C" void __hipRegisterFunction( - std::unordered_map* modules, + std::vector* modules, const void* hostFunction, char* deviceFunction, const char* deviceName, @@ -111,18 +132,10 @@ extern "C" void __hipRegisterFunction( { std::vector functions{g_deviceCnt}; - for (auto&& it : *modules) { + for (int deviceId = 0; deviceId < g_deviceCnt; ++deviceId) { hipFunction_t function; - if (hipSuccess != hipModuleGetFunction(&function, it.second, deviceName)) { - continue; - } - - for (int deviceId = 0; deviceId < g_deviceCnt; ++deviceId) { - char name[64] = {}; - hsa_agent_get_info(g_allAgents[deviceId + 1], HSA_AGENT_INFO_NAME, name); - if (!it.first.compare(name)) { - functions[deviceId] = function; - } + if (hipSuccess == hipModuleGetFunction(&function, modules->at(deviceId), deviceName)) { + functions[deviceId] = function; } } @@ -130,7 +143,7 @@ extern "C" void __hipRegisterFunction( } extern "C" void __hipRegisterVar( - hipModule_t module, + std::vector* modules, char* hostVar, char* deviceVar, const char* deviceName, @@ -141,11 +154,9 @@ extern "C" void __hipRegisterVar( { } -extern "C" void __hipUnregisterFatBinary(std::unordered_map* modules) +extern "C" void __hipUnregisterFatBinary(std::vector* modules) { - for (auto&& it : *modules) { - delete it.second; - } + std::for_each(modules->begin(), modules->end(), [](hipModule_t module){ delete module; }); delete modules; }