[vdi] Revise the symbol management.

- As different modules may have symbols with the same name, each symbol
  needs identifying with a pair of the module handle and the symbol
  name.

Change-Id: I85650a787d9a424545154cc40ebd59e706fa358f
This commit is contained in:
Michael LIAO
2020-04-24 22:18:23 -04:00
parent f7f7337bae
commit a0acf7bdaa
3 ha cambiato i file con 38 aggiunte e 26 eliminazioni
+7 -4
Vedi File
@@ -265,8 +265,11 @@ private:
std::unordered_map<const void*, DeviceFunction > functions_;
std::unordered_multimap<std::string, DeviceVar > vars_;
// Map from the host shadow symbol to its device name.
std::unordered_map<const void*, std::string> symbols_;
// Map from the host shadow symbol to its device name. As different modules
// may have the same name, each symbol is uniquely identified by a pair of
// module handle and its name.
std::unordered_map<const void*,
std::pair<hipModule_t, std::string>> symbols_;
static PlatformState* platform_;
@@ -286,9 +289,9 @@ public:
std::vector< std::pair<hipModule_t, bool> >* unregisterVar(hipModule_t hmod);
bool findSymbol(const void *hostVar, std::string &devName);
bool findSymbol(const void *hostVar, hipModule_t &hmod, std::string &devName);
PlatformState::DeviceVar* findVar(std::string hostVar, int deviceId, hipModule_t hmod);
void registerVarSym(const void *hostVar, const char *symbolName);
void registerVarSym(const void *hostVar, hipModule_t hmod, const char *symbolName);
void registerVar(const char* symbolName, const DeviceVar& var);
void registerFunction(const void* hostFunction, const DeviceFunction& func);
+12 -8
Vedi File
@@ -716,13 +716,14 @@ hipError_t hipMemcpyToSymbol(const void* symbol, const void* src, size_t count,
size_t sym_size = 0;
hipDeviceptr_t device_ptr = nullptr;
hipModule_t hmod;
std::string symbolName;
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
DevLogPrintfError("cannot find symbol 0x%x \n", symbolName.c_str());
HIP_RETURN(hipErrorInvalidSymbol);
}
/* Get address and size for the global symbol */
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
&device_ptr, &sym_size)) {
DevLogPrintfError("Cannot get global var: %s at device: %d \n", symbolName.c_str(), ihipGetDevice());
HIP_RETURN(hipErrorInvalidSymbol);
@@ -748,13 +749,14 @@ hipError_t hipMemcpyFromSymbol(void* dst, const void* symbol, size_t count,
size_t sym_size = 0;
hipDeviceptr_t device_ptr = nullptr;
hipModule_t hmod;
std::string symbolName;
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
DevLogPrintfError("cannot find symbol: 0x%x \n", symbol);
HIP_RETURN(hipErrorInvalidSymbol);
}
/* Get address and size for the global symbol */
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
&device_ptr, &sym_size)) {
DevLogPrintfError("Cannot find symbol Name: %s \n", symbolName.c_str());
HIP_RETURN(hipErrorInvalidSymbol);
@@ -780,13 +782,14 @@ hipError_t hipMemcpyToSymbolAsync(const void* symbol, const void* src, size_t co
size_t sym_size = 0;
hipDeviceptr_t device_ptr = nullptr;
hipModule_t hmod;
std::string symbolName;
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
DevLogPrintfError("cannot find symbol: 0x%x \n", symbol);
HIP_RETURN(hipErrorInvalidSymbol);
}
/* Get address and size for the global symbol */
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
&device_ptr, &sym_size)) {
DevLogPrintfError("Cannot find symbol Name: %s \n", symbolName.c_str());
HIP_RETURN(hipErrorInvalidSymbol);
@@ -812,13 +815,14 @@ hipError_t hipMemcpyFromSymbolAsync(void* dst, const void* symbol, size_t count,
size_t sym_size = 0;
hipDeviceptr_t device_ptr = nullptr;
hipModule_t hmod;
std::string symbolName;
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
DevLogPrintfError("cannot find symbol: 0x%x \n", symbol);
HIP_RETURN(hipErrorInvalidSymbol);
}
/* Get address and size for the global symbol */
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
&device_ptr, &sym_size)) {
DevLogPrintfError("Cannot find symbol Name: %s \n", symbolName.c_str());
HIP_RETURN(hipErrorInvalidSymbol);
+19 -14
Vedi File
@@ -268,19 +268,21 @@ PlatformState::DeviceVar* PlatformState::findVar(std::string hostVar, int device
return dvar;
}
bool PlatformState::findSymbol(const void *hostVar, std::string &symbolName) {
bool PlatformState::findSymbol(const void *hostVar,
hipModule_t &hmod, std::string &symbolName) {
auto it = symbols_.find(hostVar);
if (it != symbols_.end()) {
symbolName = it->second;
hmod = it->second.first;
symbolName = it->second.second;
return true;
}
DevLogPrintfError("Could not find the Symbol: %s \n", symbolName.c_str());
return false;
}
void PlatformState::registerVarSym(const void *hostVar, const char *symbolName) {
void PlatformState::registerVarSym(const void* hostVar, hipModule_t hmod, const char* symbolName) {
amd::ScopedLock lock(lock_);
symbols_.insert(std::make_pair(hostVar, std::string(symbolName)));
symbols_.insert(std::make_pair(hostVar, std::make_pair(hmod, std::string(symbolName))));
}
void PlatformState::registerVar(const char* hostvar,
@@ -494,7 +496,7 @@ bool PlatformState::getTexRef(const char* hostVar, hipModule_t hmod, textureRefe
dvar->shadowAllocated = true;
}
*texRef = reinterpret_cast<textureReference *>(dvar->shadowVptr);
registerVarSym(dvar->shadowVptr, hostVar);
registerVarSym(dvar->shadowVptr, hmod, hostVar);
return true;
}
@@ -541,12 +543,13 @@ bool PlatformState::getGlobalVar(const char* hostVar, int deviceId, hipModule_t
bool PlatformState::getGlobalVarFromSymbol(const void* hostVar, int deviceId,
hipDeviceptr_t* dev_ptr,
size_t* size_ptr) {
hipModule_t hmod;
std::string symbolName;
if (!PlatformState::instance().findSymbol(hostVar, symbolName)) {
if (!PlatformState::instance().findSymbol(hostVar, hmod, symbolName)) {
return false;
}
return PlatformState::instance().getGlobalVar(symbolName.c_str(),
ihipGetDevice(), nullptr,
ihipGetDevice(), hmod,
dev_ptr, size_ptr);
}
@@ -615,7 +618,7 @@ extern "C" void __hipRegisterVar(
/*norm*/ 0};
PlatformState::instance().registerVar(hostVar, dvar);
PlatformState::instance().registerVarSym(var, deviceVar);
PlatformState::instance().registerVarSym(var, nullptr, deviceVar);
}
extern "C" void __hipRegisterSurface(std::vector<std::pair<hipModule_t, bool>>*
@@ -634,7 +637,7 @@ extern "C" void __hipRegisterSurface(std::vector<std::pair<hipModule_t, bool>>*
type,
/*norm*/ 0};
PlatformState::instance().registerVar(hostVar, dvar);
PlatformState::instance().registerVarSym(var, deviceVar);
PlatformState::instance().registerVarSym(var, nullptr, deviceVar);
}
extern "C" void __hipRegisterTexture(std::vector<std::pair<hipModule_t, bool>>*
@@ -653,7 +656,7 @@ extern "C" void __hipRegisterTexture(std::vector<std::pair<hipModule_t, bool>>*
type,
norm};
PlatformState::instance().registerVar(hostVar, dvar);
PlatformState::instance().registerVarSym(var, deviceVar);
PlatformState::instance().registerVarSym(var, nullptr, deviceVar);
}
extern "C" void __hipUnregisterFatBinary(std::vector< std::pair<hipModule_t, bool> >* modules)
@@ -760,13 +763,14 @@ extern "C" hipError_t hipLaunchByPtr(const void *hostFunction)
hipError_t hipGetSymbolAddress(void** devPtr, const void* symbol) {
HIP_INIT_API(hipGetSymbolAddress, devPtr, symbol);
hipModule_t hmod;
std::string symbolName;
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
DevLogPrintfError("Cannot find symbol: %s \n", symbolName.c_str());
HIP_RETURN(hipErrorInvalidSymbol);
}
size_t size = 0;
if(!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
if(!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
devPtr, &size)) {
DevLogPrintfError("Cannot find global variable device ptr for symbol: %s at device: %d \n",
symbolName.c_str(), ihipGetDevice());
@@ -778,13 +782,14 @@ hipError_t hipGetSymbolAddress(void** devPtr, const void* symbol) {
hipError_t hipGetSymbolSize(size_t* sizePtr, const void* symbol) {
HIP_INIT_API(hipGetSymbolSize, sizePtr, symbol);
hipModule_t hmod;
std::string symbolName;
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
DevLogPrintfError("Cannot find symbol: %s \n", symbolName.c_str());
HIP_RETURN(hipErrorInvalidSymbol);
}
hipDeviceptr_t devPtr = nullptr;
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
&devPtr, sizePtr)) {
DevLogPrintfError("Cannot find global variable device ptr for symbol: %s at device: %d \n",
symbolName.c_str(), ihipGetDevice());