[vdi] Revise the symbol management.
- As different modules may have symbols with the same name, each symbol needs identifying with a pair of the module handle and the symbol name. Change-Id: I85650a787d9a424545154cc40ebd59e706fa358f
This commit is contained in:
@@ -265,8 +265,11 @@ private:
|
||||
|
||||
std::unordered_map<const void*, DeviceFunction > functions_;
|
||||
std::unordered_multimap<std::string, DeviceVar > vars_;
|
||||
// Map from the host shadow symbol to its device name.
|
||||
std::unordered_map<const void*, std::string> symbols_;
|
||||
// Map from the host shadow symbol to its device name. As different modules
|
||||
// may have the same name, each symbol is uniquely identified by a pair of
|
||||
// module handle and its name.
|
||||
std::unordered_map<const void*,
|
||||
std::pair<hipModule_t, std::string>> symbols_;
|
||||
|
||||
static PlatformState* platform_;
|
||||
|
||||
@@ -286,9 +289,9 @@ public:
|
||||
std::vector< std::pair<hipModule_t, bool> >* unregisterVar(hipModule_t hmod);
|
||||
|
||||
|
||||
bool findSymbol(const void *hostVar, std::string &devName);
|
||||
bool findSymbol(const void *hostVar, hipModule_t &hmod, std::string &devName);
|
||||
PlatformState::DeviceVar* findVar(std::string hostVar, int deviceId, hipModule_t hmod);
|
||||
void registerVarSym(const void *hostVar, const char *symbolName);
|
||||
void registerVarSym(const void *hostVar, hipModule_t hmod, const char *symbolName);
|
||||
void registerVar(const char* symbolName, const DeviceVar& var);
|
||||
void registerFunction(const void* hostFunction, const DeviceFunction& func);
|
||||
|
||||
|
||||
+12
-8
@@ -716,13 +716,14 @@ hipError_t hipMemcpyToSymbol(const void* symbol, const void* src, size_t count,
|
||||
size_t sym_size = 0;
|
||||
hipDeviceptr_t device_ptr = nullptr;
|
||||
|
||||
hipModule_t hmod;
|
||||
std::string symbolName;
|
||||
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
|
||||
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
|
||||
DevLogPrintfError("cannot find symbol 0x%x \n", symbolName.c_str());
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
}
|
||||
/* Get address and size for the global symbol */
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
|
||||
&device_ptr, &sym_size)) {
|
||||
DevLogPrintfError("Cannot get global var: %s at device: %d \n", symbolName.c_str(), ihipGetDevice());
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
@@ -748,13 +749,14 @@ hipError_t hipMemcpyFromSymbol(void* dst, const void* symbol, size_t count,
|
||||
size_t sym_size = 0;
|
||||
hipDeviceptr_t device_ptr = nullptr;
|
||||
|
||||
hipModule_t hmod;
|
||||
std::string symbolName;
|
||||
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
|
||||
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
|
||||
DevLogPrintfError("cannot find symbol: 0x%x \n", symbol);
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
}
|
||||
/* Get address and size for the global symbol */
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
|
||||
&device_ptr, &sym_size)) {
|
||||
DevLogPrintfError("Cannot find symbol Name: %s \n", symbolName.c_str());
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
@@ -780,13 +782,14 @@ hipError_t hipMemcpyToSymbolAsync(const void* symbol, const void* src, size_t co
|
||||
size_t sym_size = 0;
|
||||
hipDeviceptr_t device_ptr = nullptr;
|
||||
|
||||
hipModule_t hmod;
|
||||
std::string symbolName;
|
||||
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
|
||||
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
|
||||
DevLogPrintfError("cannot find symbol: 0x%x \n", symbol);
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
}
|
||||
/* Get address and size for the global symbol */
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
|
||||
&device_ptr, &sym_size)) {
|
||||
DevLogPrintfError("Cannot find symbol Name: %s \n", symbolName.c_str());
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
@@ -812,13 +815,14 @@ hipError_t hipMemcpyFromSymbolAsync(void* dst, const void* symbol, size_t count,
|
||||
size_t sym_size = 0;
|
||||
hipDeviceptr_t device_ptr = nullptr;
|
||||
|
||||
hipModule_t hmod;
|
||||
std::string symbolName;
|
||||
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
|
||||
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
|
||||
DevLogPrintfError("cannot find symbol: 0x%x \n", symbol);
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
}
|
||||
/* Get address and size for the global symbol */
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
|
||||
&device_ptr, &sym_size)) {
|
||||
DevLogPrintfError("Cannot find symbol Name: %s \n", symbolName.c_str());
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
|
||||
+19
-14
@@ -268,19 +268,21 @@ PlatformState::DeviceVar* PlatformState::findVar(std::string hostVar, int device
|
||||
return dvar;
|
||||
}
|
||||
|
||||
bool PlatformState::findSymbol(const void *hostVar, std::string &symbolName) {
|
||||
bool PlatformState::findSymbol(const void *hostVar,
|
||||
hipModule_t &hmod, std::string &symbolName) {
|
||||
auto it = symbols_.find(hostVar);
|
||||
if (it != symbols_.end()) {
|
||||
symbolName = it->second;
|
||||
hmod = it->second.first;
|
||||
symbolName = it->second.second;
|
||||
return true;
|
||||
}
|
||||
DevLogPrintfError("Could not find the Symbol: %s \n", symbolName.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
void PlatformState::registerVarSym(const void *hostVar, const char *symbolName) {
|
||||
void PlatformState::registerVarSym(const void* hostVar, hipModule_t hmod, const char* symbolName) {
|
||||
amd::ScopedLock lock(lock_);
|
||||
symbols_.insert(std::make_pair(hostVar, std::string(symbolName)));
|
||||
symbols_.insert(std::make_pair(hostVar, std::make_pair(hmod, std::string(symbolName))));
|
||||
}
|
||||
|
||||
void PlatformState::registerVar(const char* hostvar,
|
||||
@@ -494,7 +496,7 @@ bool PlatformState::getTexRef(const char* hostVar, hipModule_t hmod, textureRefe
|
||||
dvar->shadowAllocated = true;
|
||||
}
|
||||
*texRef = reinterpret_cast<textureReference *>(dvar->shadowVptr);
|
||||
registerVarSym(dvar->shadowVptr, hostVar);
|
||||
registerVarSym(dvar->shadowVptr, hmod, hostVar);
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -541,12 +543,13 @@ bool PlatformState::getGlobalVar(const char* hostVar, int deviceId, hipModule_t
|
||||
bool PlatformState::getGlobalVarFromSymbol(const void* hostVar, int deviceId,
|
||||
hipDeviceptr_t* dev_ptr,
|
||||
size_t* size_ptr) {
|
||||
hipModule_t hmod;
|
||||
std::string symbolName;
|
||||
if (!PlatformState::instance().findSymbol(hostVar, symbolName)) {
|
||||
if (!PlatformState::instance().findSymbol(hostVar, hmod, symbolName)) {
|
||||
return false;
|
||||
}
|
||||
return PlatformState::instance().getGlobalVar(symbolName.c_str(),
|
||||
ihipGetDevice(), nullptr,
|
||||
ihipGetDevice(), hmod,
|
||||
dev_ptr, size_ptr);
|
||||
}
|
||||
|
||||
@@ -615,7 +618,7 @@ extern "C" void __hipRegisterVar(
|
||||
/*norm*/ 0};
|
||||
|
||||
PlatformState::instance().registerVar(hostVar, dvar);
|
||||
PlatformState::instance().registerVarSym(var, deviceVar);
|
||||
PlatformState::instance().registerVarSym(var, nullptr, deviceVar);
|
||||
}
|
||||
|
||||
extern "C" void __hipRegisterSurface(std::vector<std::pair<hipModule_t, bool>>*
|
||||
@@ -634,7 +637,7 @@ extern "C" void __hipRegisterSurface(std::vector<std::pair<hipModule_t, bool>>*
|
||||
type,
|
||||
/*norm*/ 0};
|
||||
PlatformState::instance().registerVar(hostVar, dvar);
|
||||
PlatformState::instance().registerVarSym(var, deviceVar);
|
||||
PlatformState::instance().registerVarSym(var, nullptr, deviceVar);
|
||||
}
|
||||
|
||||
extern "C" void __hipRegisterTexture(std::vector<std::pair<hipModule_t, bool>>*
|
||||
@@ -653,7 +656,7 @@ extern "C" void __hipRegisterTexture(std::vector<std::pair<hipModule_t, bool>>*
|
||||
type,
|
||||
norm};
|
||||
PlatformState::instance().registerVar(hostVar, dvar);
|
||||
PlatformState::instance().registerVarSym(var, deviceVar);
|
||||
PlatformState::instance().registerVarSym(var, nullptr, deviceVar);
|
||||
}
|
||||
|
||||
extern "C" void __hipUnregisterFatBinary(std::vector< std::pair<hipModule_t, bool> >* modules)
|
||||
@@ -760,13 +763,14 @@ extern "C" hipError_t hipLaunchByPtr(const void *hostFunction)
|
||||
hipError_t hipGetSymbolAddress(void** devPtr, const void* symbol) {
|
||||
HIP_INIT_API(hipGetSymbolAddress, devPtr, symbol);
|
||||
|
||||
hipModule_t hmod;
|
||||
std::string symbolName;
|
||||
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
|
||||
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
|
||||
DevLogPrintfError("Cannot find symbol: %s \n", symbolName.c_str());
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
}
|
||||
size_t size = 0;
|
||||
if(!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
|
||||
if(!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
|
||||
devPtr, &size)) {
|
||||
DevLogPrintfError("Cannot find global variable device ptr for symbol: %s at device: %d \n",
|
||||
symbolName.c_str(), ihipGetDevice());
|
||||
@@ -778,13 +782,14 @@ hipError_t hipGetSymbolAddress(void** devPtr, const void* symbol) {
|
||||
hipError_t hipGetSymbolSize(size_t* sizePtr, const void* symbol) {
|
||||
HIP_INIT_API(hipGetSymbolSize, sizePtr, symbol);
|
||||
|
||||
hipModule_t hmod;
|
||||
std::string symbolName;
|
||||
if (!PlatformState::instance().findSymbol(symbol, symbolName)) {
|
||||
if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) {
|
||||
DevLogPrintfError("Cannot find symbol: %s \n", symbolName.c_str());
|
||||
HIP_RETURN(hipErrorInvalidSymbol);
|
||||
}
|
||||
hipDeviceptr_t devPtr = nullptr;
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr,
|
||||
if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod,
|
||||
&devPtr, sizePtr)) {
|
||||
DevLogPrintfError("Cannot find global variable device ptr for symbol: %s at device: %d \n",
|
||||
symbolName.c_str(), ihipGetDevice());
|
||||
|
||||
Fai riferimento in un nuovo problema
Block a user