diff --git a/projects/clr/hipamd/src/hip_code_object.cpp b/projects/clr/hipamd/src/hip_code_object.cpp index 15f5722deb..b345323e34 100644 --- a/projects/clr/hipamd/src/hip_code_object.cpp +++ b/projects/clr/hipamd/src/hip_code_object.cpp @@ -1284,6 +1284,7 @@ hipError_t StatCO::digestFatBinary(const void* data, FatBinaryInfo*& programs) { FatBinaryInfo** StatCO::addFatBinary(const void* data, bool initialized, bool& success) { amd::ScopedLock lock(sclock_); + module_to_hostModule_.insert(std::make_pair(&modules_[data], data)); if (initialized == false) { success = true; @@ -1291,6 +1292,7 @@ FatBinaryInfo** StatCO::addFatBinary(const void* data, bool initialized, bool& s } hipError_t err = digestFatBinary(data, modules_[data]); + success = (err == hipSuccess); return &modules_[data]; } @@ -1298,56 +1300,76 @@ FatBinaryInfo** StatCO::addFatBinary(const void* data, bool initialized, bool& s hipError_t StatCO::removeFatBinary(FatBinaryInfo** module) { amd::ScopedLock lock(sclock_); - auto vit = vars_.begin(); - while (vit != vars_.end()) { - if (vit->second->moduleInfo() == module) { - delete vit->second; - vit = vars_.erase(vit); - } else { - ++vit; + auto hostVarsIter = module_to_hostVars_.find(module); + if (hostVarsIter != module_to_hostVars_.end()) { + for (auto& hostVar : hostVarsIter->second) { + auto varIter = vars_.find(hostVar); + if (varIter == vars_.end()) { + LogPrintfError("removeFatBinary: Unable to find module 0x%x hostVar 0x%x", + module, hostVar); + } else { + delete varIter->second; + vars_.erase(varIter); + } } + module_to_hostVars_.erase(hostVarsIter); + } else { + LogPrintfError("removeFatBinary: Unable to find module 0x%x hostVars", module); } - auto it = managedVars_.begin(); - while (it != managedVars_.end()) { - if ((*it)->moduleInfo() == module) { + auto managedVarsIter = managedVars_.find(module); + if (managedVarsIter != managedVars_.end()) { + for (auto& managedVar : managedVarsIter->second) { hipError_t err; for (auto dev : g_devices) { DeviceVar* dvar = nullptr; - IHIP_RETURN_ONFAIL((*it)->getDeviceVarPtr(&dvar, dev->deviceId())); + IHIP_RETURN_ONFAIL(managedVar->getDeviceVarPtr(&dvar, dev->deviceId())); if (dvar != nullptr) { // free also deletes the device ptr err = ihipFree(dvar->device_ptr()); assert(err == hipSuccess); } } - err = ihipFree(*(static_cast((*it)->getManagedVarPtr()))); + err = ihipFree(*(static_cast(managedVar->getManagedVarPtr()))); assert(err == hipSuccess); - delete *it; - it = managedVars_.erase(it); - } else { - ++it; + delete managedVar; } + managedVars_.erase(managedVarsIter); + } else { + LogPrintfError("removeFatBinary: Unable to find module 0x%x managedVars", module); } - auto fit = functions_.begin(); - while (fit != functions_.end()) { - if (fit->second->moduleInfo() == module) { - delete fit->second; - fit = functions_.erase(fit); - } else { - ++fit; + auto hostFuncsIter = module_to_hostFunctions_.find(module); + if (hostFuncsIter != module_to_hostFunctions_.end()) { + for (auto& hostFunc : hostFuncsIter->second) { + auto funcIter = functions_.find(hostFunc); + if (funcIter == functions_.end()) { + LogPrintfError("removeFatBinary: Unable to find module 0x%x hostFunc 0x%x", + module, hostFunc); + } else { + delete funcIter->second; + functions_.erase(funcIter); + } } + module_to_hostFunctions_.erase(hostFuncsIter); + } else { + LogPrintfError("removeFatBinary: Unable to find module 0x%x hostFuncs", module); } - auto mit = modules_.begin(); - while (mit != modules_.end()) { - if (&mit->second == module) { - delete mit->second; - mit = modules_.erase(mit); + auto hostModuleIter = module_to_hostModule_.find(module); + if (hostModuleIter != module_to_hostModule_.end()) { + auto hostModule = hostModuleIter->second; + auto moduleIter = modules_.find(hostModule); + if (moduleIter != modules_.end()) { + delete moduleIter->second; + modules_.erase(moduleIter); } else { - ++mit; + LogPrintfError("removeFatBinary: Unable to find module 0x%x via hostModule 0x%x", + module, hostModule); } + module_to_hostModule_.erase(hostModuleIter); + } else { + LogPrintfError("removeFatBinary: Unable to find module 0x%x hostModule", module); } return hipSuccess; @@ -1361,6 +1383,7 @@ hipError_t StatCO::registerStatFunction(const void* hostFunction, Function* func delete func; } else { functions_.insert(std::make_pair(hostFunction, func)); + module_to_hostFunctions_[func->moduleInfo()].push_back(hostFunction); } return hipSuccess; @@ -1381,6 +1404,17 @@ hipError_t StatCO::getStatFunc(hipFunction_t* hfunc, const void* hostFunction, i if (it == functions_.end()) { return hipErrorInvalidSymbol; } + + // Lazy load + FatBinaryInfo **module = it->second->moduleInfo(); + if (*(module) == nullptr) { + amd::ScopedLock lock(sclock_); + if (*(module) == nullptr) { + hipError_t err = digestFatBinary(module_to_hostModule_[module], *module); + assert(err == hipSuccess); + } + } + return it->second->getStatFunc(hfunc, deviceId); } @@ -1393,6 +1427,13 @@ hipError_t StatCO::getStatFuncAttr(hipFuncAttributes* func_attr, const void* hos return hipErrorInvalidSymbol; } + // Lazy load + FatBinaryInfo **module = it->second->moduleInfo(); + if (*(module) == nullptr) { + hipError_t err = digestFatBinary(module_to_hostModule_[module], *module); + assert(err == hipSuccess); + } + return it->second->getStatFuncAttr(func_attr, deviceId); } @@ -1405,6 +1446,7 @@ hipError_t StatCO::registerStatGlobalVar(const void* hostVar, Var* var) { } vars_.insert(std::make_pair(hostVar, var)); + module_to_hostVars_[var->moduleInfo()].push_back(hostVar); return hipSuccess; } @@ -1417,6 +1459,13 @@ hipError_t StatCO::getStatGlobalVar(const void* hostVar, int deviceId, hipDevice return hipErrorInvalidSymbol; } + // Lazy load + FatBinaryInfo **module = it->second->moduleInfo(); + if (*(module) == nullptr) { + hipError_t err = digestFatBinary(module_to_hostModule_[module], *module); + assert(err == hipSuccess); + } + DeviceVar* dvar = nullptr; IHIP_RETURN_ONFAIL(it->second->getStatDeviceVar(&dvar, deviceId)); @@ -1426,7 +1475,7 @@ hipError_t StatCO::getStatGlobalVar(const void* hostVar, int deviceId, hipDevice } hipError_t StatCO::registerStatManagedVar(Var* var) { - managedVars_.emplace_back(var); + managedVars_[var->moduleInfo()].push_back(var); return hipSuccess; } @@ -1435,17 +1484,26 @@ hipError_t StatCO::initStatManagedVarDevicePtr(int deviceId) { hipError_t err = hipSuccess; if (managedVarsDevicePtrInitalized_.find(deviceId) == managedVarsDevicePtrInitalized_.end() || !managedVarsDevicePtrInitalized_[deviceId]) { - for (auto var : managedVars_) { - DeviceVar* dvar = nullptr; - IHIP_RETURN_ONFAIL(var->getStatDeviceVar(&dvar, deviceId)); + for (auto& vecIter : managedVars_) { + for (auto& var : vecIter.second) { + // Lazy load + FatBinaryInfo **module = var->moduleInfo(); + if (*(module) == nullptr) { + hipError_t err = digestFatBinary(module_to_hostModule_[module], *module); + assert(err == hipSuccess); + } - hip::Stream* stream = g_devices.at(deviceId)->NullStream(); - if (stream != nullptr) { - err = ihipMemcpy(reinterpret_cast
(dvar->device_ptr()), var->getManagedVarPtr(), - dvar->size(), hipMemcpyHostToDevice, *stream); - } else { - ClPrint(amd::LOG_ERROR, amd::LOG_API, "Host Queue is NULL"); - return hipErrorInvalidResourceHandle; + DeviceVar* dvar = nullptr; + IHIP_RETURN_ONFAIL(var->getStatDeviceVar(&dvar, deviceId)); + + hip::Stream* stream = g_devices.at(deviceId)->NullStream(); + if (stream != nullptr) { + err = ihipMemcpy(reinterpret_cast
(dvar->device_ptr()), var->getManagedVarPtr(), + dvar->size(), hipMemcpyHostToDevice, *stream); + } else { + ClPrint(amd::LOG_ERROR, amd::LOG_API, "Host Queue is NULL"); + return hipErrorInvalidResourceHandle; + } } } managedVarsDevicePtrInitalized_[deviceId] = true; diff --git a/projects/clr/hipamd/src/hip_code_object.hpp b/projects/clr/hipamd/src/hip_code_object.hpp index d682ee4f54..7a4e6e74e7 100644 --- a/projects/clr/hipamd/src/hip_code_object.hpp +++ b/projects/clr/hipamd/src/hip_code_object.hpp @@ -190,7 +190,11 @@ private: //Populated during __hipRegisterVars std::unordered_map vars_; //Populated during __hipRegisterManagedVar - std::vector managedVars_; + std::unordered_map > managedVars_; + //Reverse mapping of modules to speed up removal + std::unordered_map module_to_hostModule_; + std::unordered_map > module_to_hostFunctions_; + std::unordered_map > module_to_hostVars_; std::unordered_map managedVarsDevicePtrInitalized_; }; diff --git a/projects/clr/hipamd/src/hip_fatbin.cpp b/projects/clr/hipamd/src/hip_fatbin.cpp index 2f25c2c6a3..b9d0434673 100644 --- a/projects/clr/hipamd/src/hip_fatbin.cpp +++ b/projects/clr/hipamd/src/hip_fatbin.cpp @@ -569,6 +569,7 @@ hipError_t FatBinaryInfo::ExtractFatBinaryUsingCOMGR(const std::vector& devices) { + amd::ScopedLock lock(FatBinaryLock()); if (!HIP_USE_RUNTIME_UNBUNDLER) { bool containGenericTarget = false; hipError_t status = ExtractFatBinaryUsingCOMGR(devices, containGenericTarget); diff --git a/projects/clr/hipamd/src/hip_platform.cpp b/projects/clr/hipamd/src/hip_platform.cpp index 1a68955834..137702afb2 100644 --- a/projects/clr/hipamd/src/hip_platform.cpp +++ b/projects/clr/hipamd/src/hip_platform.cpp @@ -734,12 +734,6 @@ void PlatformState::init() { return; } initialized_ = true; - for (auto& it : statCO_.modules_) { - hipError_t err = digestFatBinary(it.first, it.second); - if (err != hipSuccess) { - HIP_ERROR_PRINT(err, "continue parsing remaining modules"); - } - } for (auto& it : statCO_.vars_) { it.second->resize_dVar(g_devices.size()); }