diff --git a/src/program_state.cpp b/src/program_state.cpp index dbd7d3ebc4..5e9f9976be 100644 --- a/src/program_state.cpp +++ b/src/program_state.cpp @@ -61,7 +61,7 @@ namespace hip_impl { if (it == impl->get_globals().end()) return nullptr; else - return it->second; + return it->second.first; } hsa_executable_t program_state::load_executable(const char* data, diff --git a/src/program_state.inl b/src/program_state.inl index 993418de96..8861558e04 100644 --- a/src/program_state.inl +++ b/src/program_state.inl @@ -18,6 +18,7 @@ #include #include #include +#include "hc.hpp" #include @@ -193,7 +194,8 @@ public: std::tuple< std::once_flag, std::mutex, - std::unordered_map> globals; + // map from string to pair + std::unordered_map>> globals; using RAII_code_reader = std::unique_ptr& get_globals() { + std::unordered_map>& get_globals() { std::call_once(std::get<0>(globals), [this]() { std::get<2>(globals).reserve(get_symbol_addresses().size()); }); @@ -349,30 +351,52 @@ public: auto& g_mutex = get_globals_mutex(); for (auto&& x : undefined_symbols) { - if (g.find(x) != g.cend()) return; - const auto it1 = get_symbol_addresses().find(x); - if (it1 == get_symbol_addresses().cend()) { hip_throw(std::runtime_error{ "Global symbol: " + x + " is undefined."}); } - std::lock_guard lck{g_mutex}; + hsa_status_t status; + auto check_hsa_global_var_define_error = [&x](hsa_status_t s) { + if (s != HSA_STATUS_SUCCESS) { + const char* es; + hsa_status_string(s, &es); + hip_throw(std::runtime_error{ "Error when defining symbol " + x + " : " + es}); + } + }; - if (g.find(x) != g.cend()) return; + auto retrieve_pinned_address_from_cache = [](decltype(g) g, decltype(x) x) { + const auto& global_addr = g.find(x); + if (global_addr != g.cend()) { + return global_addr->second.second; + } + return (void*)nullptr; + }; - g.emplace(x, (void*)(it1->second.first)); - void* p = nullptr; - hsa_amd_memory_lock( - reinterpret_cast(it1->second.first), - it1->second.second, - nullptr, // All agents. - 0, - &p); - - hsa_executable_agent_global_variable_define( - executable, agent, x.c_str(), p); + void* p = retrieve_pinned_address_from_cache(g, x); + if (p == nullptr) { + std::lock_guard lck{g_mutex}; + p = retrieve_pinned_address_from_cache(g, x); + if (p == nullptr) { + if (x == "_ZN2hc13printf_bufferE") { + // This is the printf buffer, get the pinned address from HCC + p = Kalmar::getContext()->getPrintfBufferPointerVA(); + } + else { + status = hsa_amd_memory_lock(reinterpret_cast(it1->second.first), + it1->second.second, + nullptr, // All agents. + 0, &p); + check_hsa_global_var_define_error(status); + } + // cache the global address and its pinned address + g.emplace(x, std::make_pair(reinterpret_cast(it1->second.first), p)); + } + } + status = hsa_executable_agent_global_variable_define( + executable, agent, x.c_str(), p); + check_hsa_global_var_define_error(status); } } diff --git a/tests/src/kernel/hipPrintfKernel.cpp b/tests/src/kernel/hipPrintfKernel.cpp index 1d4fa5fe30..5675f2e6bd 100644 --- a/tests/src/kernel/hipPrintfKernel.cpp +++ b/tests/src/kernel/hipPrintfKernel.cpp @@ -30,7 +30,12 @@ THE SOFTWARE. __global__ void run_printf() { printf("Hello World\n"); } int main() { - hipLaunchKernelGGL(HIP_KERNEL_NAME(run_printf), dim3(1), dim3(1), 0, 0); - hipDeviceSynchronize(); + int device_count = 0; + hipGetDeviceCount(&device_count); + for (int i = 0; i < device_count; ++i) { + hipSetDevice(i); + hipLaunchKernelGGL(HIP_KERNEL_NAME(run_printf), dim3(1), dim3(1), 0, 0); + hipDeviceSynchronize(); + } passed(); }