Merge "Fix failure to get global variables" into amd-master-next
Tento commit je obsažen v:
@@ -157,16 +157,215 @@ extern "C" void __hipRegisterFunction(
|
||||
g_functions.insert(std::make_pair(hostFunction, std::move(functions)));
|
||||
}
|
||||
|
||||
static inline const char* hsa_strerror(hsa_status_t status) {
|
||||
const char* str = nullptr;
|
||||
if (hsa_status_string(status, &str) == HSA_STATUS_SUCCESS) {
|
||||
return str;
|
||||
}
|
||||
return "Unknown error";
|
||||
}
|
||||
|
||||
struct RegisteredVar {
|
||||
public:
|
||||
RegisteredVar(): size_(0), devicePtr_(nullptr) {}
|
||||
~RegisteredVar() {}
|
||||
|
||||
static inline const char* hsa_strerror(hsa_status_t status) {
|
||||
const char* str = nullptr;
|
||||
if (hsa_status_string(status, &str) == HSA_STATUS_SUCCESS) {
|
||||
return str;
|
||||
}
|
||||
return "Unknown error";
|
||||
}
|
||||
|
||||
hipDeviceptr_t getdeviceptr() const { return devicePtr_; };
|
||||
size_t getvarsize() const { return size_; };
|
||||
|
||||
size_t size_; // Size of the variable
|
||||
hipDeviceptr_t devicePtr_; //Device Memory Address of the variable.
|
||||
};
|
||||
|
||||
struct DeviceVar {
|
||||
void* shadowVptr;
|
||||
std::string hostVar;
|
||||
size_t size;
|
||||
std::vector<hipModule_t>* modules;
|
||||
std::vector<RegisteredVar> rvars;
|
||||
bool dyn_undef;
|
||||
};
|
||||
|
||||
std::unordered_multimap<std::string, DeviceVar > g_vars;
|
||||
|
||||
//The logic follows PlatformState::getGlobalVar in VDI RT
|
||||
static DeviceVar* findVar(std::string hostVar, int deviceId, hipModule_t hmod) {
|
||||
DeviceVar* dvar = nullptr;
|
||||
if (hmod != nullptr) {
|
||||
// If module is provided, then get the var only from that module
|
||||
auto var_range = g_vars.equal_range(hostVar);
|
||||
for (auto it = var_range.first; it != var_range.second; ++it) {
|
||||
if ((*it->second.modules)[deviceId] == hmod) {
|
||||
dvar = &(it->second);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If var count is < 2, return the var
|
||||
if (g_vars.count(hostVar) < 2) {
|
||||
auto it = g_vars.find(hostVar);
|
||||
dvar = ((it == g_vars.end()) ? nullptr : &(it->second));
|
||||
} else {
|
||||
// If var count is > 2, return the original var,
|
||||
// if original var count != 1, return g_vars.end()/Invalid
|
||||
size_t orig_global_count = 0;
|
||||
auto var_range = g_vars.equal_range(hostVar);
|
||||
for (auto it = var_range.first; it != var_range.second; ++it) {
|
||||
// when dyn_undef is set, it is a shadow var
|
||||
if (it->second.dyn_undef == false) {
|
||||
++orig_global_count;
|
||||
dvar = &(it->second);
|
||||
}
|
||||
}
|
||||
dvar = ((orig_global_count == 1) ? dvar : nullptr);
|
||||
}
|
||||
}
|
||||
return dvar;
|
||||
}
|
||||
|
||||
hipError_t ihipGetGlobalVar(hipDeviceptr_t* dev_ptr, size_t* size_ptr,
|
||||
const char* hostVar, hipModule_t hmod) {
|
||||
GET_TLS();
|
||||
auto ctx = ihipGetTlsDefaultCtx();
|
||||
|
||||
if (!ctx) return hipErrorInvalidValue;
|
||||
|
||||
auto device = ctx->getDevice();
|
||||
|
||||
if (!device) return hipErrorInvalidValue;
|
||||
|
||||
ihipDevice_t* currentDevice = ihipGetDevice(device->_deviceId);
|
||||
|
||||
if (!currentDevice) return hipErrorInvalidValue;
|
||||
|
||||
int deviceId = device->_deviceId;
|
||||
|
||||
DeviceVar* dvar = findVar(std::string(hostVar), deviceId, hmod);
|
||||
if (dvar == nullptr) return hipErrorInvalidValue;
|
||||
|
||||
if (dvar->rvars[deviceId].getdeviceptr() == nullptr) return hipErrorInvalidValue;
|
||||
|
||||
*size_ptr = dvar->rvars[deviceId].getvarsize();
|
||||
*dev_ptr = dvar->rvars[deviceId].getdeviceptr();
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
static bool createGlobalVarObj(const hsa_executable_t& hsaExecutable, const hsa_agent_t& hasAgent,
|
||||
const char* global_name, void** device_pptr, size_t* bytes) {
|
||||
hsa_status_t status = HSA_STATUS_SUCCESS;
|
||||
hsa_symbol_kind_t sym_type;
|
||||
hsa_executable_symbol_t global_symbol;
|
||||
std::string buildLog;
|
||||
|
||||
/* Find HSA Symbol by name */
|
||||
status = hsa_executable_get_symbol_by_name(hsaExecutable, global_name, &hasAgent,
|
||||
&global_symbol);
|
||||
if (status != HSA_STATUS_SUCCESS) {
|
||||
buildLog += "Error: Failed to find the Symbol by Name: ";
|
||||
buildLog += hsa_strerror(status);
|
||||
tprintf(DB_FB, "createGlobalVarObj: %s\n", buildLog.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
/* Find HSA Symbol Type */
|
||||
status = hsa_executable_symbol_get_info(global_symbol, HSA_EXECUTABLE_SYMBOL_INFO_TYPE,
|
||||
&sym_type);
|
||||
if (status != HSA_STATUS_SUCCESS) {
|
||||
buildLog += "Error: Failed to find the Symbol Type : ";
|
||||
buildLog += hsa_strerror(status);
|
||||
tprintf(DB_FB, "createGlobalVarObj: %s\n", buildLog.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
/* Make sure symbol type is VARIABLE */
|
||||
if (sym_type != HSA_SYMBOL_KIND_VARIABLE) {
|
||||
buildLog += "Error: Symbol is not of type VARIABLE : ";
|
||||
buildLog += hsa_strerror(status);
|
||||
tprintf(DB_FB, "createGlobalVarObj: %s\n", buildLog.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
/* Retrieve the size of the variable */
|
||||
status = hsa_executable_symbol_get_info(global_symbol, HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_SIZE, bytes);
|
||||
|
||||
if (status != HSA_STATUS_SUCCESS) {
|
||||
buildLog += "Error: Failed to retrieve the Symbol Size : ";
|
||||
buildLog += hsa_strerror(status);
|
||||
tprintf(DB_FB, "createGlobalVarObj: %s\n", buildLog.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
/* Find HSA Symbol Address */
|
||||
status = hsa_executable_symbol_get_info(global_symbol,
|
||||
HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ADDRESS, device_pptr);
|
||||
if (status != HSA_STATUS_SUCCESS) {
|
||||
buildLog += "Error: Failed to find the Symbol Address : ";
|
||||
buildLog += hsa_strerror(status);
|
||||
tprintf(DB_FB, "createGlobalVarObj: %s\n", buildLog.c_str());
|
||||
return false;
|
||||
} else {
|
||||
tprintf(DB_FB, "createGlobalVarObj: var %s : device=%p, size=%zu\n", global_name, *device_pptr, *bytes);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Registers a device-side global variable.
|
||||
// For each global variable in device code, there is a corresponding shadow
|
||||
// global variable in host code. The shadow host variable is used to keep
|
||||
// track of the value of the device side global variable between kernel
|
||||
// executions.
|
||||
// The basic logic is taken from VDI RT, but there is much difference.
|
||||
extern "C" void __hipRegisterVar(
|
||||
std::vector<hipModule_t>* modules,
|
||||
char* hostVar,
|
||||
char* deviceVar,
|
||||
const char* deviceName,
|
||||
int ext,
|
||||
int size,
|
||||
int constant,
|
||||
int global)
|
||||
std::vector<hipModule_t>* modules, // The device modules containing code object
|
||||
char* var, // The shadow variable in host code
|
||||
char* hostVar, // Variable name in host code
|
||||
const char* deviceVar, // Variable name in device code
|
||||
int ext, // Whether this variable is external
|
||||
int size, // Size of the variable
|
||||
int constant, // Whether this variable is constant
|
||||
int global) // Unknown, always 0
|
||||
{
|
||||
HIP_INIT_API(__hipRegisterVar, modules, var, hostVar, deviceVar, ext, size, constant, global);
|
||||
|
||||
DeviceVar dvar{var, std::string{ hostVar }, static_cast<size_t>(size), modules,
|
||||
std::vector<RegisteredVar>{ g_deviceCnt }, false };
|
||||
|
||||
for (int deviceId = 0; deviceId < g_deviceCnt; deviceId++) {
|
||||
auto device = ihipGetDevice(deviceId);
|
||||
if(!device) {
|
||||
continue;
|
||||
}
|
||||
hsa_executable_t& executable = (*modules)[deviceId]->executable;
|
||||
hsa_agent_t& agent = g_allAgents[deviceId + 1];
|
||||
size_t bytes = 0;
|
||||
hipDeviceptr_t devicePtr = nullptr;
|
||||
|
||||
bool success = createGlobalVarObj(executable, agent, hostVar, &devicePtr, &bytes);
|
||||
if(!success) {
|
||||
return;
|
||||
}
|
||||
dvar.rvars[deviceId].devicePtr_ = devicePtr;
|
||||
dvar.rvars[deviceId].size_ = bytes;
|
||||
|
||||
hc::AmPointerInfo ptrInfo(nullptr, devicePtr, devicePtr, bytes, device->_acc, true, false);
|
||||
hc::am_memtracker_add(devicePtr, ptrInfo);
|
||||
|
||||
#if USE_APP_PTR_FOR_CTX
|
||||
hc::am_memtracker_update(devicePtr, device->_deviceId, 0u, ihipGetTlsDefaultCtx());
|
||||
#else
|
||||
hc::am_memtracker_update(devicePtr, device->_deviceId, 0u);
|
||||
#endif
|
||||
}
|
||||
g_vars.insert(std::make_pair(std::string(hostVar), dvar));
|
||||
}
|
||||
|
||||
extern "C" void __hipUnregisterFatBinary(std::vector<hipModule_t>* modules)
|
||||
|
||||
@@ -982,6 +982,18 @@ hipStream_t ihipSyncAndResolveStream(hipStream_t, bool lockAcquired = 0);
|
||||
hipError_t ihipStreamSynchronize(TlsData *tls, hipStream_t stream);
|
||||
void ihipStreamCallbackHandler(ihipStreamCallback_t* cb);
|
||||
|
||||
/**
|
||||
* @brief Copies the memory address and size of symbol @p symbolName
|
||||
*
|
||||
* @param[in] symbolName - Symbol on device
|
||||
* @param[out] devPtr - Pointer to a pointer to the memory referred to by the symbol
|
||||
* @param[out] size - Pointer to the size of the symbol
|
||||
* @return #hipSuccess, #hipErrorNotInitialized, #hipErrorNotFound, #hipErrorInvalidValue
|
||||
*
|
||||
*/
|
||||
hipError_t ihipGetGlobalVar(hipDeviceptr_t* dev_ptr, size_t* size_ptr, const char* hostVar,
|
||||
hipModule_t hmod = nullptr);
|
||||
|
||||
// Stream printf functions:
|
||||
inline std::ostream& operator<<(std::ostream& os, const ihipStream_t& s) {
|
||||
os << "stream:";
|
||||
|
||||
@@ -879,7 +879,12 @@ namespace hip_impl {
|
||||
|
||||
hipError_t agent_globals::read_agent_global_from_process(hipDeviceptr_t* dptr, size_t* bytes,
|
||||
const char* name) {
|
||||
return impl->read_agent_global_from_process(dptr, bytes, name);
|
||||
hipError_t result = impl->read_agent_global_from_process(dptr, bytes, name);
|
||||
if(result != hipSuccess) {
|
||||
// For Clang Compiler + Hcc Rt
|
||||
result = ihipGetGlobalVar(dptr, bytes, name);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // Namespace hip_impl.
|
||||
|
||||
Odkázat v novém úkolu
Zablokovat Uživatele