SWDEV-508225 - Improve fat binary handling

Change-Id: I78a9951f2f4c4c743c1205b1e40aac215054e27d


[ROCm/clr commit: 08af3eb484]
Этот коммит содержится в:
Saleel Kudchadker
2025-01-08 17:53:02 +00:00
родитель ae379965dd
Коммит 21ae9ef25e
4 изменённых файлов: 104 добавлений и 47 удалений
+98 -40
Просмотреть файл
@@ -1284,6 +1284,7 @@ hipError_t StatCO::digestFatBinary(const void* data, FatBinaryInfo*& programs) {
FatBinaryInfo** StatCO::addFatBinary(const void* data, bool initialized, bool& success) {
amd::ScopedLock lock(sclock_);
module_to_hostModule_.insert(std::make_pair(&modules_[data], data));
if (initialized == false) {
success = true;
@@ -1291,6 +1292,7 @@ FatBinaryInfo** StatCO::addFatBinary(const void* data, bool initialized, bool& s
}
hipError_t err = digestFatBinary(data, modules_[data]);
success = (err == hipSuccess);
return &modules_[data];
}
@@ -1298,56 +1300,76 @@ FatBinaryInfo** StatCO::addFatBinary(const void* data, bool initialized, bool& s
hipError_t StatCO::removeFatBinary(FatBinaryInfo** module) {
amd::ScopedLock lock(sclock_);
auto vit = vars_.begin();
while (vit != vars_.end()) {
if (vit->second->moduleInfo() == module) {
delete vit->second;
vit = vars_.erase(vit);
} else {
++vit;
auto hostVarsIter = module_to_hostVars_.find(module);
if (hostVarsIter != module_to_hostVars_.end()) {
for (auto& hostVar : hostVarsIter->second) {
auto varIter = vars_.find(hostVar);
if (varIter == vars_.end()) {
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostVar 0x%x",
module, hostVar);
} else {
delete varIter->second;
vars_.erase(varIter);
}
}
module_to_hostVars_.erase(hostVarsIter);
} else {
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostVars", module);
}
auto it = managedVars_.begin();
while (it != managedVars_.end()) {
if ((*it)->moduleInfo() == module) {
auto managedVarsIter = managedVars_.find(module);
if (managedVarsIter != managedVars_.end()) {
for (auto& managedVar : managedVarsIter->second) {
hipError_t err;
for (auto dev : g_devices) {
DeviceVar* dvar = nullptr;
IHIP_RETURN_ONFAIL((*it)->getDeviceVarPtr(&dvar, dev->deviceId()));
IHIP_RETURN_ONFAIL(managedVar->getDeviceVarPtr(&dvar, dev->deviceId()));
if (dvar != nullptr) {
// free also deletes the device ptr
err = ihipFree(dvar->device_ptr());
assert(err == hipSuccess);
}
}
err = ihipFree(*(static_cast<void**>((*it)->getManagedVarPtr())));
err = ihipFree(*(static_cast<void**>(managedVar->getManagedVarPtr())));
assert(err == hipSuccess);
delete *it;
it = managedVars_.erase(it);
} else {
++it;
delete managedVar;
}
managedVars_.erase(managedVarsIter);
} else {
LogPrintfError("removeFatBinary: Unable to find module 0x%x managedVars", module);
}
auto fit = functions_.begin();
while (fit != functions_.end()) {
if (fit->second->moduleInfo() == module) {
delete fit->second;
fit = functions_.erase(fit);
} else {
++fit;
auto hostFuncsIter = module_to_hostFunctions_.find(module);
if (hostFuncsIter != module_to_hostFunctions_.end()) {
for (auto& hostFunc : hostFuncsIter->second) {
auto funcIter = functions_.find(hostFunc);
if (funcIter == functions_.end()) {
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostFunc 0x%x",
module, hostFunc);
} else {
delete funcIter->second;
functions_.erase(funcIter);
}
}
module_to_hostFunctions_.erase(hostFuncsIter);
} else {
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostFuncs", module);
}
auto mit = modules_.begin();
while (mit != modules_.end()) {
if (&mit->second == module) {
delete mit->second;
mit = modules_.erase(mit);
auto hostModuleIter = module_to_hostModule_.find(module);
if (hostModuleIter != module_to_hostModule_.end()) {
auto hostModule = hostModuleIter->second;
auto moduleIter = modules_.find(hostModule);
if (moduleIter != modules_.end()) {
delete moduleIter->second;
modules_.erase(moduleIter);
} else {
++mit;
LogPrintfError("removeFatBinary: Unable to find module 0x%x via hostModule 0x%x",
module, hostModule);
}
module_to_hostModule_.erase(hostModuleIter);
} else {
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostModule", module);
}
return hipSuccess;
@@ -1361,6 +1383,7 @@ hipError_t StatCO::registerStatFunction(const void* hostFunction, Function* func
delete func;
} else {
functions_.insert(std::make_pair(hostFunction, func));
module_to_hostFunctions_[func->moduleInfo()].push_back(hostFunction);
}
return hipSuccess;
@@ -1381,6 +1404,17 @@ hipError_t StatCO::getStatFunc(hipFunction_t* hfunc, const void* hostFunction, i
if (it == functions_.end()) {
return hipErrorInvalidSymbol;
}
// Lazy load
FatBinaryInfo **module = it->second->moduleInfo();
if (*(module) == nullptr) {
amd::ScopedLock lock(sclock_);
if (*(module) == nullptr) {
hipError_t err = digestFatBinary(module_to_hostModule_[module], *module);
assert(err == hipSuccess);
}
}
return it->second->getStatFunc(hfunc, deviceId);
}
@@ -1393,6 +1427,13 @@ hipError_t StatCO::getStatFuncAttr(hipFuncAttributes* func_attr, const void* hos
return hipErrorInvalidSymbol;
}
// Lazy load
FatBinaryInfo **module = it->second->moduleInfo();
if (*(module) == nullptr) {
hipError_t err = digestFatBinary(module_to_hostModule_[module], *module);
assert(err == hipSuccess);
}
return it->second->getStatFuncAttr(func_attr, deviceId);
}
@@ -1405,6 +1446,7 @@ hipError_t StatCO::registerStatGlobalVar(const void* hostVar, Var* var) {
}
vars_.insert(std::make_pair(hostVar, var));
module_to_hostVars_[var->moduleInfo()].push_back(hostVar);
return hipSuccess;
}
@@ -1417,6 +1459,13 @@ hipError_t StatCO::getStatGlobalVar(const void* hostVar, int deviceId, hipDevice
return hipErrorInvalidSymbol;
}
// Lazy load
FatBinaryInfo **module = it->second->moduleInfo();
if (*(module) == nullptr) {
hipError_t err = digestFatBinary(module_to_hostModule_[module], *module);
assert(err == hipSuccess);
}
DeviceVar* dvar = nullptr;
IHIP_RETURN_ONFAIL(it->second->getStatDeviceVar(&dvar, deviceId));
@@ -1426,7 +1475,7 @@ hipError_t StatCO::getStatGlobalVar(const void* hostVar, int deviceId, hipDevice
}
// Records a managed variable in managedVars_ and returns hipSuccess
// unconditionally; no validation of `var` is performed here.
// NOTE(review): this hunk is rendered without +/- markers. The emplace_back
// line appears to be the removed pre-change statement and the
// managedVars_[var->moduleInfo()].push_back(var) line its replacement, which
// files the variable under its owning module (managedVars_ is redeclared as a
// map keyed by FatBinaryInfo** in the header hunk) - confirm against the raw
// diff before treating both statements as live code.
hipError_t StatCO::registerStatManagedVar(Var* var) {
managedVars_.emplace_back(var);
managedVars_[var->moduleInfo()].push_back(var);
return hipSuccess;
}
@@ -1435,17 +1484,26 @@ hipError_t StatCO::initStatManagedVarDevicePtr(int deviceId) {
hipError_t err = hipSuccess;
if (managedVarsDevicePtrInitalized_.find(deviceId) == managedVarsDevicePtrInitalized_.end() ||
!managedVarsDevicePtrInitalized_[deviceId]) {
for (auto var : managedVars_) {
DeviceVar* dvar = nullptr;
IHIP_RETURN_ONFAIL(var->getStatDeviceVar(&dvar, deviceId));
for (auto& vecIter : managedVars_) {
for (auto& var : vecIter.second) {
// Lazy load
FatBinaryInfo **module = var->moduleInfo();
if (*(module) == nullptr) {
hipError_t err = digestFatBinary(module_to_hostModule_[module], *module);
assert(err == hipSuccess);
}
hip::Stream* stream = g_devices.at(deviceId)->NullStream();
if (stream != nullptr) {
err = ihipMemcpy(reinterpret_cast<address>(dvar->device_ptr()), var->getManagedVarPtr(),
dvar->size(), hipMemcpyHostToDevice, *stream);
} else {
ClPrint(amd::LOG_ERROR, amd::LOG_API, "Host Queue is NULL");
return hipErrorInvalidResourceHandle;
DeviceVar* dvar = nullptr;
IHIP_RETURN_ONFAIL(var->getStatDeviceVar(&dvar, deviceId));
hip::Stream* stream = g_devices.at(deviceId)->NullStream();
if (stream != nullptr) {
err = ihipMemcpy(reinterpret_cast<address>(dvar->device_ptr()), var->getManagedVarPtr(),
dvar->size(), hipMemcpyHostToDevice, *stream);
} else {
ClPrint(amd::LOG_ERROR, amd::LOG_API, "Host Queue is NULL");
return hipErrorInvalidResourceHandle;
}
}
}
managedVarsDevicePtrInitalized_[deviceId] = true;
+5 -1
Просмотреть файл
@@ -190,7 +190,11 @@ private:
//Populated during __hipRegisterVars
std::unordered_map<const void*, Var*> vars_;
//Populated during __hipRegisterManagedVar
std::vector<Var*> managedVars_;
std::unordered_map<FatBinaryInfo**, std::vector<Var*> > managedVars_;
//Reverse mapping of modules to speed up removal
std::unordered_map<FatBinaryInfo**, const void*> module_to_hostModule_;
std::unordered_map<FatBinaryInfo**, std::vector<const void*> > module_to_hostFunctions_;
std::unordered_map<FatBinaryInfo**, std::vector<const void*> > module_to_hostVars_;
std::unordered_map<int, bool> managedVarsDevicePtrInitalized_;
};
+1
Просмотреть файл
@@ -569,6 +569,7 @@ hipError_t FatBinaryInfo::ExtractFatBinaryUsingCOMGR(const std::vector<hip::Devi
}
hipError_t FatBinaryInfo::ExtractFatBinary(const std::vector<hip::Device*>& devices) {
amd::ScopedLock lock(FatBinaryLock());
if (!HIP_USE_RUNTIME_UNBUNDLER) {
bool containGenericTarget = false;
hipError_t status = ExtractFatBinaryUsingCOMGR(devices, containGenericTarget);
-6
Просмотреть файл
@@ -734,12 +734,6 @@ void PlatformState::init() {
return;
}
initialized_ = true;
for (auto& it : statCO_.modules_) {
hipError_t err = digestFatBinary(it.first, it.second);
if (err != hipSuccess) {
HIP_ERROR_PRINT(err, "continue parsing remaining modules");
}
}
for (auto& it : statCO_.vars_) {
it.second->resize_dVar(g_devices.size());
}