diff --git a/vdi/hip_internal.hpp b/vdi/hip_internal.hpp index 18521c4890..529ca7abdb 100644 --- a/vdi/hip_internal.hpp +++ b/vdi/hip_internal.hpp @@ -265,8 +265,11 @@ private: std::unordered_map functions_; std::unordered_multimap vars_; - // Map from the host shadow symbol to its device name. - std::unordered_map symbols_; + // Map from the host shadow symbol to its device name. As different modules + // may have the same name, each symbol is uniquely identified by a pair of + // module handle and its name. + std::unordered_map> symbols_; static PlatformState* platform_; @@ -286,9 +289,9 @@ public: std::vector< std::pair >* unregisterVar(hipModule_t hmod); - bool findSymbol(const void *hostVar, std::string &devName); + bool findSymbol(const void *hostVar, hipModule_t &hmod, std::string &devName); PlatformState::DeviceVar* findVar(std::string hostVar, int deviceId, hipModule_t hmod); - void registerVarSym(const void *hostVar, const char *symbolName); + void registerVarSym(const void *hostVar, hipModule_t hmod, const char *symbolName); void registerVar(const char* symbolName, const DeviceVar& var); void registerFunction(const void* hostFunction, const DeviceFunction& func); diff --git a/vdi/hip_memory.cpp b/vdi/hip_memory.cpp index 2b3daab894..967952f03f 100755 --- a/vdi/hip_memory.cpp +++ b/vdi/hip_memory.cpp @@ -716,13 +716,14 @@ hipError_t hipMemcpyToSymbol(const void* symbol, const void* src, size_t count, size_t sym_size = 0; hipDeviceptr_t device_ptr = nullptr; + hipModule_t hmod; std::string symbolName; - if (!PlatformState::instance().findSymbol(symbol, symbolName)) { + if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) { DevLogPrintfError("cannot find symbol 0x%x \n", symbolName.c_str()); HIP_RETURN(hipErrorInvalidSymbol); } /* Get address and size for the global symbol */ - if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr, + if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod, &device_ptr, &sym_size)) { DevLogPrintfError("Cannot get global var: %s at device: %d \n", symbolName.c_str(), ihipGetDevice()); HIP_RETURN(hipErrorInvalidSymbol); @@ -748,13 +749,14 @@ hipError_t hipMemcpyFromSymbol(void* dst, const void* symbol, size_t count, size_t sym_size = 0; hipDeviceptr_t device_ptr = nullptr; + hipModule_t hmod; std::string symbolName; - if (!PlatformState::instance().findSymbol(symbol, symbolName)) { + if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) { DevLogPrintfError("cannot find symbol: 0x%x \n", symbol); HIP_RETURN(hipErrorInvalidSymbol); } /* Get address and size for the global symbol */ - if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr, + if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod, &device_ptr, &sym_size)) { DevLogPrintfError("Cannot find symbol Name: %s \n", symbolName.c_str()); HIP_RETURN(hipErrorInvalidSymbol); @@ -780,13 +782,14 @@ hipError_t hipMemcpyToSymbolAsync(const void* symbol, const void* src, size_t co size_t sym_size = 0; hipDeviceptr_t device_ptr = nullptr; + hipModule_t hmod; std::string symbolName; - if (!PlatformState::instance().findSymbol(symbol, symbolName)) { + if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) { DevLogPrintfError("cannot find symbol: 0x%x \n", symbol); HIP_RETURN(hipErrorInvalidSymbol); } /* Get address and size for the global symbol */ - if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr, + if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod, &device_ptr, &sym_size)) { DevLogPrintfError("Cannot find symbol Name: %s \n", symbolName.c_str()); HIP_RETURN(hipErrorInvalidSymbol); @@ -812,13 +815,14 @@ hipError_t hipMemcpyFromSymbolAsync(void* dst, const void* symbol, size_t count, size_t sym_size = 0; hipDeviceptr_t device_ptr = nullptr; + hipModule_t hmod; std::string symbolName; - if (!PlatformState::instance().findSymbol(symbol, symbolName)) { + if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) { DevLogPrintfError("cannot find symbol: 0x%x \n", symbol); HIP_RETURN(hipErrorInvalidSymbol); } /* Get address and size for the global symbol */ - if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr, + if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod, &device_ptr, &sym_size)) { DevLogPrintfError("Cannot find symbol Name: %s \n", symbolName.c_str()); HIP_RETURN(hipErrorInvalidSymbol); diff --git a/vdi/hip_platform.cpp b/vdi/hip_platform.cpp index c5a9099bf5..217d152688 100755 --- a/vdi/hip_platform.cpp +++ b/vdi/hip_platform.cpp @@ -268,19 +268,21 @@ PlatformState::DeviceVar* PlatformState::findVar(std::string hostVar, int device return dvar; } -bool PlatformState::findSymbol(const void *hostVar, std::string &symbolName) { +bool PlatformState::findSymbol(const void *hostVar, + hipModule_t &hmod, std::string &symbolName) { auto it = symbols_.find(hostVar); if (it != symbols_.end()) { - symbolName = it->second; + hmod = it->second.first; + symbolName = it->second.second; return true; } DevLogPrintfError("Could not find the Symbol: %s \n", symbolName.c_str()); return false; } -void PlatformState::registerVarSym(const void *hostVar, const char *symbolName) { +void PlatformState::registerVarSym(const void* hostVar, hipModule_t hmod, const char* symbolName) { amd::ScopedLock lock(lock_); - symbols_.insert(std::make_pair(hostVar, std::string(symbolName))); + symbols_.insert(std::make_pair(hostVar, std::make_pair(hmod, std::string(symbolName)))); } void PlatformState::registerVar(const char* hostvar, @@ -494,7 +496,7 @@ bool PlatformState::getTexRef(const char* hostVar, hipModule_t hmod, textureRefe dvar->shadowAllocated = true; } *texRef = reinterpret_cast(dvar->shadowVptr); - registerVarSym(dvar->shadowVptr, hostVar); + registerVarSym(dvar->shadowVptr, hmod, hostVar); return true; } @@ -541,12 +543,13 @@ bool PlatformState::getGlobalVar(const char* hostVar, int deviceId, hipModule_t bool PlatformState::getGlobalVarFromSymbol(const void* hostVar, int deviceId, hipDeviceptr_t* dev_ptr, size_t* size_ptr) { + hipModule_t hmod; std::string symbolName; - if (!PlatformState::instance().findSymbol(hostVar, symbolName)) { + if (!PlatformState::instance().findSymbol(hostVar, hmod, symbolName)) { return false; } return PlatformState::instance().getGlobalVar(symbolName.c_str(), - ihipGetDevice(), nullptr, + ihipGetDevice(), hmod, dev_ptr, size_ptr); } @@ -615,7 +618,7 @@ extern "C" void __hipRegisterVar( /*norm*/ 0}; PlatformState::instance().registerVar(hostVar, dvar); - PlatformState::instance().registerVarSym(var, deviceVar); + PlatformState::instance().registerVarSym(var, nullptr, deviceVar); } extern "C" void __hipRegisterSurface(std::vector>* @@ -634,7 +637,7 @@ extern "C" void __hipRegisterSurface(std::vector>* type, /*norm*/ 0}; PlatformState::instance().registerVar(hostVar, dvar); - PlatformState::instance().registerVarSym(var, deviceVar); + PlatformState::instance().registerVarSym(var, nullptr, deviceVar); } extern "C" void __hipRegisterTexture(std::vector>* @@ -653,7 +656,7 @@ extern "C" void __hipRegisterTexture(std::vector>* type, norm}; PlatformState::instance().registerVar(hostVar, dvar); - PlatformState::instance().registerVarSym(var, deviceVar); + PlatformState::instance().registerVarSym(var, nullptr, deviceVar); } extern "C" void __hipUnregisterFatBinary(std::vector< std::pair >* modules) @@ -760,13 +763,14 @@ extern "C" hipError_t hipLaunchByPtr(const void *hostFunction) hipError_t hipGetSymbolAddress(void** devPtr, const void* symbol) { HIP_INIT_API(hipGetSymbolAddress, devPtr, symbol); + hipModule_t hmod; std::string symbolName; - if (!PlatformState::instance().findSymbol(symbol, symbolName)) { + if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) { DevLogPrintfError("Cannot find symbol: %s \n", symbolName.c_str()); HIP_RETURN(hipErrorInvalidSymbol); } size_t size = 0; - if(!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr, + if(!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod, devPtr, &size)) { DevLogPrintfError("Cannot find global variable device ptr for symbol: %s at device: %d \n", symbolName.c_str(), ihipGetDevice()); @@ -778,13 +782,14 @@ hipError_t hipGetSymbolAddress(void** devPtr, const void* symbol) { hipError_t hipGetSymbolSize(size_t* sizePtr, const void* symbol) { HIP_INIT_API(hipGetSymbolSize, sizePtr, symbol); + hipModule_t hmod; std::string symbolName; - if (!PlatformState::instance().findSymbol(symbol, symbolName)) { + if (!PlatformState::instance().findSymbol(symbol, hmod, symbolName)) { DevLogPrintfError("Cannot find symbol: %s \n", symbolName.c_str()); HIP_RETURN(hipErrorInvalidSymbol); } hipDeviceptr_t devPtr = nullptr; - if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), nullptr, + if (!PlatformState::instance().getGlobalVar(symbolName.c_str(), ihipGetDevice(), hmod, &devPtr, sizePtr)) { DevLogPrintfError("Cannot find global variable device ptr for symbol: %s at device: %d \n", symbolName.c_str(), ihipGetDevice());