SWDEV-508225 - Improve fat binary handling
Change-Id: I78a9951f2f4c4c743c1205b1e40aac215054e27d
Este cometimento está contido em:
@@ -1284,6 +1284,7 @@ hipError_t StatCO::digestFatBinary(const void* data, FatBinaryInfo*& programs) {
|
||||
|
||||
FatBinaryInfo** StatCO::addFatBinary(const void* data, bool initialized, bool& success) {
|
||||
amd::ScopedLock lock(sclock_);
|
||||
module_to_hostModule_.insert(std::make_pair(&modules_[data], data));
|
||||
|
||||
if (initialized == false) {
|
||||
success = true;
|
||||
@@ -1291,6 +1292,7 @@ FatBinaryInfo** StatCO::addFatBinary(const void* data, bool initialized, bool& s
|
||||
}
|
||||
|
||||
hipError_t err = digestFatBinary(data, modules_[data]);
|
||||
|
||||
success = (err == hipSuccess);
|
||||
return &modules_[data];
|
||||
}
|
||||
@@ -1298,56 +1300,76 @@ FatBinaryInfo** StatCO::addFatBinary(const void* data, bool initialized, bool& s
|
||||
hipError_t StatCO::removeFatBinary(FatBinaryInfo** module) {
|
||||
amd::ScopedLock lock(sclock_);
|
||||
|
||||
auto vit = vars_.begin();
|
||||
while (vit != vars_.end()) {
|
||||
if (vit->second->moduleInfo() == module) {
|
||||
delete vit->second;
|
||||
vit = vars_.erase(vit);
|
||||
} else {
|
||||
++vit;
|
||||
auto hostVarsIter = module_to_hostVars_.find(module);
|
||||
if (hostVarsIter != module_to_hostVars_.end()) {
|
||||
for (auto& hostVar : hostVarsIter->second) {
|
||||
auto varIter = vars_.find(hostVar);
|
||||
if (varIter == vars_.end()) {
|
||||
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostVar 0x%x",
|
||||
module, hostVar);
|
||||
} else {
|
||||
delete varIter->second;
|
||||
vars_.erase(varIter);
|
||||
}
|
||||
}
|
||||
module_to_hostVars_.erase(hostVarsIter);
|
||||
} else {
|
||||
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostVars", module);
|
||||
}
|
||||
|
||||
auto it = managedVars_.begin();
|
||||
while (it != managedVars_.end()) {
|
||||
if ((*it)->moduleInfo() == module) {
|
||||
auto managedVarsIter = managedVars_.find(module);
|
||||
if (managedVarsIter != managedVars_.end()) {
|
||||
for (auto& managedVar : managedVarsIter->second) {
|
||||
hipError_t err;
|
||||
for (auto dev : g_devices) {
|
||||
DeviceVar* dvar = nullptr;
|
||||
IHIP_RETURN_ONFAIL((*it)->getDeviceVarPtr(&dvar, dev->deviceId()));
|
||||
IHIP_RETURN_ONFAIL(managedVar->getDeviceVarPtr(&dvar, dev->deviceId()));
|
||||
if (dvar != nullptr) {
|
||||
// free also deletes the device ptr
|
||||
err = ihipFree(dvar->device_ptr());
|
||||
assert(err == hipSuccess);
|
||||
}
|
||||
}
|
||||
err = ihipFree(*(static_cast<void**>((*it)->getManagedVarPtr())));
|
||||
err = ihipFree(*(static_cast<void**>(managedVar->getManagedVarPtr())));
|
||||
assert(err == hipSuccess);
|
||||
delete *it;
|
||||
it = managedVars_.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
delete managedVar;
|
||||
}
|
||||
managedVars_.erase(managedVarsIter);
|
||||
} else {
|
||||
LogPrintfError("removeFatBinary: Unable to find module 0x%x managedVars", module);
|
||||
}
|
||||
|
||||
auto fit = functions_.begin();
|
||||
while (fit != functions_.end()) {
|
||||
if (fit->second->moduleInfo() == module) {
|
||||
delete fit->second;
|
||||
fit = functions_.erase(fit);
|
||||
} else {
|
||||
++fit;
|
||||
auto hostFuncsIter = module_to_hostFunctions_.find(module);
|
||||
if (hostFuncsIter != module_to_hostFunctions_.end()) {
|
||||
for (auto& hostFunc : hostFuncsIter->second) {
|
||||
auto funcIter = functions_.find(hostFunc);
|
||||
if (funcIter == functions_.end()) {
|
||||
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostFunc 0x%x",
|
||||
module, hostFunc);
|
||||
} else {
|
||||
delete funcIter->second;
|
||||
functions_.erase(funcIter);
|
||||
}
|
||||
}
|
||||
module_to_hostFunctions_.erase(hostFuncsIter);
|
||||
} else {
|
||||
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostFuncs", module);
|
||||
}
|
||||
|
||||
auto mit = modules_.begin();
|
||||
while (mit != modules_.end()) {
|
||||
if (&mit->second == module) {
|
||||
delete mit->second;
|
||||
mit = modules_.erase(mit);
|
||||
auto hostModuleIter = module_to_hostModule_.find(module);
|
||||
if (hostModuleIter != module_to_hostModule_.end()) {
|
||||
auto hostModule = hostModuleIter->second;
|
||||
auto moduleIter = modules_.find(hostModule);
|
||||
if (moduleIter != modules_.end()) {
|
||||
delete moduleIter->second;
|
||||
modules_.erase(moduleIter);
|
||||
} else {
|
||||
++mit;
|
||||
LogPrintfError("removeFatBinary: Unable to find module 0x%x via hostModule 0x%x",
|
||||
module, hostModule);
|
||||
}
|
||||
module_to_hostModule_.erase(hostModuleIter);
|
||||
} else {
|
||||
LogPrintfError("removeFatBinary: Unable to find module 0x%x hostModule", module);
|
||||
}
|
||||
|
||||
return hipSuccess;
|
||||
@@ -1361,6 +1383,7 @@ hipError_t StatCO::registerStatFunction(const void* hostFunction, Function* func
|
||||
delete func;
|
||||
} else {
|
||||
functions_.insert(std::make_pair(hostFunction, func));
|
||||
module_to_hostFunctions_[func->moduleInfo()].push_back(hostFunction);
|
||||
}
|
||||
|
||||
return hipSuccess;
|
||||
@@ -1381,6 +1404,17 @@ hipError_t StatCO::getStatFunc(hipFunction_t* hfunc, const void* hostFunction, i
|
||||
if (it == functions_.end()) {
|
||||
return hipErrorInvalidSymbol;
|
||||
}
|
||||
|
||||
// Lazy load
|
||||
FatBinaryInfo **module = it->second->moduleInfo();
|
||||
if (*(module) == nullptr) {
|
||||
amd::ScopedLock lock(sclock_);
|
||||
if (*(module) == nullptr) {
|
||||
hipError_t err = digestFatBinary(module_to_hostModule_[module], *module);
|
||||
assert(err == hipSuccess);
|
||||
}
|
||||
}
|
||||
|
||||
return it->second->getStatFunc(hfunc, deviceId);
|
||||
}
|
||||
|
||||
@@ -1393,6 +1427,13 @@ hipError_t StatCO::getStatFuncAttr(hipFuncAttributes* func_attr, const void* hos
|
||||
return hipErrorInvalidSymbol;
|
||||
}
|
||||
|
||||
// Lazy load
|
||||
FatBinaryInfo **module = it->second->moduleInfo();
|
||||
if (*(module) == nullptr) {
|
||||
hipError_t err = digestFatBinary(module_to_hostModule_[module], *module);
|
||||
assert(err == hipSuccess);
|
||||
}
|
||||
|
||||
return it->second->getStatFuncAttr(func_attr, deviceId);
|
||||
}
|
||||
|
||||
@@ -1405,6 +1446,7 @@ hipError_t StatCO::registerStatGlobalVar(const void* hostVar, Var* var) {
|
||||
}
|
||||
|
||||
vars_.insert(std::make_pair(hostVar, var));
|
||||
module_to_hostVars_[var->moduleInfo()].push_back(hostVar);
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
@@ -1417,6 +1459,13 @@ hipError_t StatCO::getStatGlobalVar(const void* hostVar, int deviceId, hipDevice
|
||||
return hipErrorInvalidSymbol;
|
||||
}
|
||||
|
||||
// Lazy load
|
||||
FatBinaryInfo **module = it->second->moduleInfo();
|
||||
if (*(module) == nullptr) {
|
||||
hipError_t err = digestFatBinary(module_to_hostModule_[module], *module);
|
||||
assert(err == hipSuccess);
|
||||
}
|
||||
|
||||
DeviceVar* dvar = nullptr;
|
||||
IHIP_RETURN_ONFAIL(it->second->getStatDeviceVar(&dvar, deviceId));
|
||||
|
||||
@@ -1426,7 +1475,7 @@ hipError_t StatCO::getStatGlobalVar(const void* hostVar, int deviceId, hipDevice
|
||||
}
|
||||
|
||||
hipError_t StatCO::registerStatManagedVar(Var* var) {
|
||||
managedVars_.emplace_back(var);
|
||||
managedVars_[var->moduleInfo()].push_back(var);
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
@@ -1435,17 +1484,26 @@ hipError_t StatCO::initStatManagedVarDevicePtr(int deviceId) {
|
||||
hipError_t err = hipSuccess;
|
||||
if (managedVarsDevicePtrInitalized_.find(deviceId) == managedVarsDevicePtrInitalized_.end() ||
|
||||
!managedVarsDevicePtrInitalized_[deviceId]) {
|
||||
for (auto var : managedVars_) {
|
||||
DeviceVar* dvar = nullptr;
|
||||
IHIP_RETURN_ONFAIL(var->getStatDeviceVar(&dvar, deviceId));
|
||||
for (auto& vecIter : managedVars_) {
|
||||
for (auto& var : vecIter.second) {
|
||||
// Lazy load
|
||||
FatBinaryInfo **module = var->moduleInfo();
|
||||
if (*(module) == nullptr) {
|
||||
hipError_t err = digestFatBinary(module_to_hostModule_[module], *module);
|
||||
assert(err == hipSuccess);
|
||||
}
|
||||
|
||||
hip::Stream* stream = g_devices.at(deviceId)->NullStream();
|
||||
if (stream != nullptr) {
|
||||
err = ihipMemcpy(reinterpret_cast<address>(dvar->device_ptr()), var->getManagedVarPtr(),
|
||||
dvar->size(), hipMemcpyHostToDevice, *stream);
|
||||
} else {
|
||||
ClPrint(amd::LOG_ERROR, amd::LOG_API, "Host Queue is NULL");
|
||||
return hipErrorInvalidResourceHandle;
|
||||
DeviceVar* dvar = nullptr;
|
||||
IHIP_RETURN_ONFAIL(var->getStatDeviceVar(&dvar, deviceId));
|
||||
|
||||
hip::Stream* stream = g_devices.at(deviceId)->NullStream();
|
||||
if (stream != nullptr) {
|
||||
err = ihipMemcpy(reinterpret_cast<address>(dvar->device_ptr()), var->getManagedVarPtr(),
|
||||
dvar->size(), hipMemcpyHostToDevice, *stream);
|
||||
} else {
|
||||
ClPrint(amd::LOG_ERROR, amd::LOG_API, "Host Queue is NULL");
|
||||
return hipErrorInvalidResourceHandle;
|
||||
}
|
||||
}
|
||||
}
|
||||
managedVarsDevicePtrInitalized_[deviceId] = true;
|
||||
|
||||
Criar uma nova questão referindo esta
Bloquear um utilizador