From 2673bf5e2c569387050bf54a67bb0f93e8894264 Mon Sep 17 00:00:00 2001 From: Laurent Morichetti Date: Fri, 2 Sep 2022 12:40:15 -0700 Subject: [PATCH] SWDEV-351980 - Consolidate registration tables in the roctracer Change-Id: I44cd1cc81cf6a529aed89ee8db1377c0aa67f0dc --- inc/roctracer_hip.h | 7 +- script/hsaap.py | 47 +- src/roctracer/hsa_support.cpp | 231 +++------- src/roctracer/hsa_support.h | 18 +- src/roctracer/journal.h | 65 --- src/roctracer/loader.h | 91 ++-- src/roctracer/registration_table.h | 84 ++++ src/roctracer/roctracer.cpp | 668 +++++++++++++++++------------ src/roctx/exportmap | 5 +- src/roctx/roctx.cpp | 80 ++-- src/util/callback_table.h | 70 --- 11 files changed, 662 insertions(+), 704 deletions(-) delete mode 100644 src/roctracer/journal.h create mode 100644 src/roctracer/registration_table.h delete mode 100644 src/util/callback_table.h diff --git a/inc/roctracer_hip.h b/inc/roctracer_hip.h index 5ac2d1a7cf..8b0434944c 100644 --- a/inc/roctracer_hip.h +++ b/inc/roctracer_hip.h @@ -41,7 +41,12 @@ inline static std::ostream& operator<<(std::ostream& out, const char& v) { #include -enum { HIP_OP_ID_DISPATCH = 0, HIP_OP_ID_COPY = 1, HIP_OP_ID_BARRIER = 2, HIP_OP_ID_NUMBER = 3 }; +typedef enum { + HIP_OP_ID_DISPATCH = 0, + HIP_OP_ID_COPY = 1, + HIP_OP_ID_BARRIER = 2, + HIP_OP_ID_NUMBER = 3 +} hip_op_id_t; #ifdef __cplusplus extern "C" { diff --git a/script/hsaap.py b/script/hsaap.py index 866f1b569b..2aa0742c8f 100755 --- a/script/hsaap.py +++ b/script/hsaap.py @@ -332,7 +332,6 @@ class API_DescrParser: self.cpp_content += "/* Generated by " + os.path.basename(__file__) + " */\n" + license + "\n\n" self.cpp_content += '#include \n' - self.cpp_content += '#include \"util/callback_table.h\"\n\n' self.cpp_content += '#include \n' self.cpp_content += 'namespace roctracer::hsa_support::detail {\n' @@ -409,40 +408,52 @@ class API_DescrParser: def gen_callbacks(self, n, name, call, struct): content = '' if n == -1: - content += 'static util::CallbackTable cb_table;\n' + content += '/* section: Static declarations */\n' content += '\n' if call != '-': call_id = self.api_id[call]; ret_type = struct['ret'] content += 'static ' + ret_type + ' ' + call + '_callback(' + struct['args'] + ') {\n' - content += ' hsa_api_data_t api_data{};\n' + + content += ' hsa_trace_data_t trace_data;\n' + content += ' bool enabled{false};\n' + content += '\n' + content += ' if (auto function = report_activity.load(std::memory_order_relaxed); function &&\n' + content += ' (enabled =\n' + content += ' function(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &trace_data) == 0)) {\n' + content += ' if (trace_data.phase_enter != nullptr) {\n' + for var in struct['alst']: item = struct['astr'][var]; if re.search(r'char\* ', item): - content += ' api_data.args.' + call + '.' + var + ' = ' + '(' + var + ' != NULL) ? strdup(' + var + ')' + ' : NULL;\n' + # FIXME: we should not strdup the char* arguments here, as the callback will not outlive the scope of this function. Instead, we + # should generate a helper function to capture the content of the arguments similar to hipApiArgsInit for HIP. We also need a + # helper to free the memory that is allocated to capture the content. + content += ' trace_data.api_data.args.' + call + '.' + var + ' = ' + '(' + var + ' != NULL) ? strdup(' + var + ')' + ' : NULL;\n' else: - content += ' api_data.args.' + call + '.' + var + ' = ' + var + ';\n' + content += ' trace_data.api_data.args.' + call + '.' + var + ' = ' + var + ';\n' if call == 'hsa_amd_memory_async_copy_rect' and var == 'range': - content += ' api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n' - content += ' auto [ api_callback_fun, api_callback_arg ] = cb_table.Get(' + call_id + ');\n' - content += ' if (api_callback_fun) {\n' - content += ' api_data.phase = ACTIVITY_API_PHASE_ENTER;\n' - content += ' api_data.correlation_id = CorrelationIdPush();\n' - content += ' api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n' + content += ' trace_data.api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n' + + content += ' trace_data.phase_enter(' + call_id + ', &trace_data);\n' + content += ' }\n' content += ' }\n' + content += '\n' + if ret_type != 'void': + # FIXME: we should capture the return value and store it in the api_data content += ' ' + ret_type + ' ret =' content += ' ' + name + '_saved_before_cb.' + call + '_fn(' + ', '.join(struct['alst']) + ');\n' - content += ' if (api_callback_fun) {\n' - if ret_type != 'void': - content += ' api_data.' + ret_type + '_retval = ret;\n' - content += ' api_data.phase = ACTIVITY_API_PHASE_EXIT;\n' - content += ' api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n' - content += ' CorrelationIdPop();\n' - content += ' }\n' + + content += '\n' + content += ' if (enabled && trace_data.phase_exit != nullptr)\n' + content += ' trace_data.phase_exit(' + call_id + ', &trace_data);\n' + if ret_type != 'void': + content += '\n' content += ' return ret;\n' content += '}\n' + return content # generate API intercepting code diff --git a/src/roctracer/hsa_support.cpp b/src/roctracer/hsa_support.cpp index b4c6d39c72..c92e273516 100644 --- a/src/roctracer/hsa_support.cpp +++ b/src/roctracer/hsa_support.cpp @@ -26,7 +26,6 @@ #include "memory_pool.h" #include "roctracer.h" #include "roctracer_hsa.h" -#include "callback_table.h" #include #include @@ -35,23 +34,32 @@ #include #include +namespace { + +std::atomic report_activity; + +bool IsEnabled(activity_domain_t domain, uint32_t operation_id) { + auto report = report_activity.load(std::memory_order_relaxed); + return report && report(domain, operation_id, nullptr) == 0; +} + +void ReportActivity(activity_domain_t domain, uint32_t operation_id, void* data) { + if (auto report = report_activity.load(std::memory_order_relaxed)) + report(domain, operation_id, data); +} + +} // namespace + #include "hsa_prof_str.inline.h" namespace roctracer::hsa_support { namespace { -util::CallbackTable hsa_evt_cb_table; - CoreApiTable saved_core_api{}; AmdExtTable saved_amd_ext_api{}; hsa_ven_amd_loader_1_01_pfn_t hsa_loader_api{}; -// async copy activity callback -std::mutex init_mutex; -bool async_copy_callback_enabled = false; -MemoryPool* async_copy_callback_memory_pool = nullptr; - struct AgentInfo { int index; hsa_device_type_t type; @@ -81,7 +89,6 @@ class Tracker { hsa_signal_t orig; hsa_signal_t signal; void (*handler)(const entry_t*); - MemoryPool* pool; union { struct { } copy; @@ -182,7 +189,7 @@ hsa_status_t HSA_API MemoryAllocateIntercept(hsa_region_t region, size_t size, v hsa_status_t status = saved_core_api.hsa_memory_allocate_fn(region, size, ptr); if (status != HSA_STATUS_SUCCESS) return status; - if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_ALLOCATE); callback_fun) { + if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE)) { hsa_evt_data_t data{}; data.allocate.ptr = *ptr; data.allocate.size = size; @@ -192,7 +199,7 @@ hsa_status_t HSA_API MemoryAllocateIntercept(hsa_region_t region, size_t size, v &data.allocate.global_flag) != HSA_STATUS_SUCCESS) fatal("hsa_region_get_info failed"); - callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data, callback_arg); + ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data); } return HSA_STATUS_SUCCESS; @@ -203,14 +210,14 @@ hsa_status_t MemoryAssignAgentIntercept(void* ptr, hsa_agent_t agent, hsa_status_t status = saved_core_api.hsa_memory_assign_agent_fn(ptr, agent, access); if (status != HSA_STATUS_SUCCESS) return status; - if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_DEVICE); callback_fun) { + if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE)) { hsa_evt_data_t data{}; data.device.ptr = ptr; if (saved_core_api.hsa_agent_get_info_fn(agent, HSA_AGENT_INFO_DEVICE, &data.device.type) != HSA_STATUS_SUCCESS) fatal("hsa_agent_get_info failed"); - callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data, callback_arg); + ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data); } return HSA_STATUS_SUCCESS; @@ -220,13 +227,13 @@ hsa_status_t MemoryCopyIntercept(void* dst, const void* src, size_t size) { hsa_status_t status = saved_core_api.hsa_memory_copy_fn(dst, src, size); if (status != HSA_STATUS_SUCCESS) return status; - if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_MEMCOPY); callback_fun) { + if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_MEMCOPY)) { hsa_evt_data_t data{}; data.memcopy.dst = dst; data.memcopy.src = src; data.memcopy.size = size; - callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_MEMCOPY, &data, callback_arg); + ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_MEMCOPY, &data); } return HSA_STATUS_SUCCESS; @@ -237,7 +244,7 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size hsa_status_t status = saved_amd_ext_api.hsa_amd_memory_pool_allocate_fn(pool, size, flags, ptr); if (size == 0 || status != HSA_STATUS_SUCCESS) return status; - if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_ALLOCATE); callback_fun) { + if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE)) { hsa_evt_data_t data{}; data.allocate.ptr = *ptr; data.allocate.size = size; @@ -249,17 +256,13 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size HSA_STATUS_SUCCESS) fatal("hsa_region_get_info failed"); - callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data, callback_arg); + ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data); + } - if (std::tie(callback_fun, callback_arg) = hsa_evt_cb_table.Get(HSA_EVT_ID_DEVICE); - !callback_fun) - return HSA_STATUS_SUCCESS; - - // FIXME: Why is this only reported if HSA_EVT_ID_ALLOCATE is also set? - auto callback_data = std::make_tuple(callback_fun, callback_arg, pool, ptr); + if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE)) { + auto callback_data = std::make_pair(pool, ptr); auto agent_callback = [](hsa_agent_t agent, void* iterate_agent_callback_data) { - auto [callback_fun, callback_arg, pool, ptr] = - *reinterpret_cast(iterate_agent_callback_data); + auto [pool, ptr] = *reinterpret_cast(iterate_agent_callback_data); if (hsa_amd_memory_pool_access_t value; saved_amd_ext_api.hsa_amd_agent_memory_pool_get_info_fn( @@ -276,7 +279,7 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size data.device.agent = agent; data.device.ptr = ptr; - callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data, callback_arg); + ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data); return HSA_STATUS_SUCCESS; }; saved_core_api.hsa_iterate_agents_fn(agent_callback, &callback_data); @@ -286,11 +289,11 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size } hsa_status_t MemoryPoolFreeIntercept(void* ptr) { - if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_ALLOCATE); callback_fun) { + if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE)) { hsa_evt_data_t data{}; data.allocate.ptr = ptr; data.allocate.size = 0; - callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data, callback_arg); + ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data); } return saved_amd_ext_api.hsa_amd_memory_pool_free_fn(ptr); @@ -303,7 +306,7 @@ hsa_status_t AgentsAllowAccessIntercept(uint32_t num_agents, const hsa_agent_t* saved_amd_ext_api.hsa_amd_agents_allow_access_fn(num_agents, agents, flags, ptr); if (status != HSA_STATUS_SUCCESS) return status; - if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_DEVICE); callback_fun) { + if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE)) { while (num_agents--) { hsa_agent_t agent = *agents++; auto it = agent_info_map.find(agent.handle); @@ -315,7 +318,7 @@ hsa_status_t AgentsAllowAccessIntercept(uint32_t num_agents, const hsa_agent_t* data.device.agent = agent; data.device.ptr = ptr; - callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data, callback_arg); + ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data); } } return HSA_STATUS_SUCCESS; @@ -329,7 +332,6 @@ struct CodeObjectCallbackArg { hsa_status_t CodeObjectCallback(hsa_executable_t executable, hsa_loaded_code_object_t loaded_code_object, void* arg) { - auto* code_object_callback_arg = static_cast(arg); hsa_evt_data_t data{}; if (hsa_loader_api.hsa_ven_amd_loader_loaded_code_object_get_info( @@ -384,9 +386,8 @@ hsa_status_t CodeObjectCallback(hsa_executable_t executable, fatal("hsa_ven_amd_loader_loaded_code_object_get_info failed"); data.codeobj.uri = uri_str.c_str(); - data.codeobj.unload = code_object_callback_arg->unload ? 1 : 0; - code_object_callback_arg->callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ, &data, - code_object_callback_arg->callback_arg); + data.codeobj.unload = *static_cast(arg) ? 1 : 0; + ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ, &data); return HSA_STATUS_SUCCESS; } @@ -395,20 +396,20 @@ hsa_status_t ExecutableFreezeIntercept(hsa_executable_t executable, const char* hsa_status_t status = saved_core_api.hsa_executable_freeze_fn(executable, options); if (status != HSA_STATUS_SUCCESS) return status; - if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_CODEOBJ); callback_fun) { - CodeObjectCallbackArg arg = {callback_fun, callback_arg, false}; + if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ)) { + bool unload = false; hsa_loader_api.hsa_ven_amd_loader_executable_iterate_loaded_code_objects( - executable, CodeObjectCallback, &arg); + executable, CodeObjectCallback, &unload); } return HSA_STATUS_SUCCESS; } hsa_status_t ExecutableDestroyIntercept(hsa_executable_t executable) { - if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_CODEOBJ); callback_fun) { - CodeObjectCallbackArg arg = {callback_fun, callback_arg, true}; + if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ)) { + bool unload = true; hsa_loader_api.hsa_ven_amd_loader_executable_iterate_loaded_code_objects( - executable, CodeObjectCallback, &arg); + executable, CodeObjectCallback, &unload); } return saved_core_api.hsa_executable_destroy_fn(executable); @@ -422,25 +423,31 @@ void MemoryASyncCopyHandler(const Tracker::entry_t* entry) { record.end_ns = entry->end; record.device_id = 0; record.correlation_id = entry->correlation_id; - entry->pool->Write(record); + ReportActivity(ACTIVITY_DOMAIN_HSA_OPS, HSA_OP_ID_COPY, &record); } hsa_status_t MemoryASyncCopyIntercept(void* dst, hsa_agent_t dst_agent, const void* src, hsa_agent_t src_agent, size_t size, uint32_t num_dep_signals, const hsa_signal_t* dep_signals, hsa_signal_t completion_signal) { - if (!async_copy_callback_enabled) { + bool is_enabled = IsEnabled(ACTIVITY_DOMAIN_HSA_OPS, HSA_OP_ID_COPY); + + // FIXME: what happens if the state changes before returning? + [[maybe_unused]] hsa_status_t status = + saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(is_enabled); + assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed"); + + if (!is_enabled) { return saved_amd_ext_api.hsa_amd_memory_async_copy_fn( dst, dst_agent, src, src_agent, size, num_dep_signals, dep_signals, completion_signal); } Tracker::entry_t* entry = new Tracker::entry_t(); entry->handler = MemoryASyncCopyHandler; - entry->pool = async_copy_callback_memory_pool; entry->correlation_id = CorrelationId(); Tracker::Enable(Tracker::COPY_ENTRY_TYPE, hsa_agent_t{}, completion_signal, entry); - hsa_status_t status = saved_amd_ext_api.hsa_amd_memory_async_copy_fn( + status = saved_amd_ext_api.hsa_amd_memory_async_copy_fn( dst, dst_agent, src, src_agent, size, num_dep_signals, dep_signals, entry->signal); if (status != HSA_STATUS_SUCCESS) Tracker::Disable(entry); @@ -454,7 +461,14 @@ hsa_status_t MemoryASyncCopyRectIntercept(const hsa_pitched_ptr_t* dst, hsa_agent_t copy_agent, hsa_amd_copy_direction_t dir, uint32_t num_dep_signals, const hsa_signal_t* dep_signals, hsa_signal_t completion_signal) { - if (!async_copy_callback_enabled) { + bool is_enabled = IsEnabled(ACTIVITY_DOMAIN_HSA_OPS, HSA_OP_ID_COPY); + + // FIXME: what happens if the state changes before returning? + [[maybe_unused]] hsa_status_t status = + saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(is_enabled); + assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed"); + + if (!is_enabled) { return saved_amd_ext_api.hsa_amd_memory_async_copy_rect_fn( dst, dst_offset, src, src_offset, range, copy_agent, dir, num_dep_signals, dep_signals, completion_signal); @@ -462,11 +476,10 @@ hsa_status_t MemoryASyncCopyRectIntercept(const hsa_pitched_ptr_t* dst, Tracker::entry_t* entry = new Tracker::entry_t(); entry->handler = MemoryASyncCopyHandler; - entry->pool = async_copy_callback_memory_pool; entry->correlation_id = CorrelationId(); Tracker::Enable(Tracker::COPY_ENTRY_TYPE, hsa_agent_t{}, completion_signal, entry); - hsa_status_t status = saved_amd_ext_api.hsa_amd_memory_async_copy_rect_fn( + status = saved_amd_ext_api.hsa_amd_memory_async_copy_rect_fn( dst, dst_offset, src, src_offset, range, copy_agent, dir, num_dep_signals, dep_signals, entry->signal); if (status != HSA_STATUS_SUCCESS) Tracker::Disable(entry); @@ -502,8 +515,6 @@ roctracer_timestamp_t timestamp_ns() { } void Initialize(HsaApiTable* table) { - std::scoped_lock lock(init_mutex); - // Save the HSA core api and amd_ext api. saved_core_api = *table->core_; saved_amd_ext_api = *table->amd_ext_; @@ -558,20 +569,12 @@ void Initialize(HsaApiTable* table) { detail::InstallCoreApiWrappers(table->core_); detail::InstallAmdExtWrappers(table->amd_ext_); detail::InstallImageExtWrappers(table->image_ext_); - - if (async_copy_callback_enabled) { - [[maybe_unused]] hsa_status_t status = - saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(true); - assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed"); - } } void Finalize() { - if (hsa_support::async_copy_callback_enabled) { - [[maybe_unused]] hsa_status_t status = - hsa_support::saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(false); - assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed"); - } + if (hsa_status_t status = saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(false); + status != HSA_STATUS_SUCCESS) + assert(!"hsa_amd_profiling_async_copy_enable failed"); } const char* GetApiName(uint32_t id) { return detail::GetApiName(id); } @@ -612,111 +615,9 @@ const char* GetOpsName(uint32_t id) { uint32_t GetApiCode(const char* str) { return detail::GetApiCode(str); } -void EnableActivity(roctracer_domain_t domain, uint32_t op, roctracer_pool_t* pool) { - switch (domain) { - case ACTIVITY_DOMAIN_HSA_OPS: - if (op == HSA_OP_ID_COPY) { - std::scoped_lock lock(init_mutex); - - if (saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn != nullptr) { - [[maybe_unused]] hsa_status_t status = - saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(true); - assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed"); - } - async_copy_callback_enabled = true; - async_copy_callback_memory_pool = reinterpret_cast(pool); - } else if (op == HSA_OP_ID_RESERVED1) { - /* Place holder for PC sampling. */ - } else { - EXC_RAISING(ROCTRACER_STATUS_ERROR_NOT_IMPLEMENTED, - "HSA OPS operation ID(" << op << ") is not currently implemented"); - } - break; - case ACTIVITY_DOMAIN_HSA_API: - // FIXME: Add HSA api activities. - break; - case ACTIVITY_DOMAIN_HSA_EVT: - break; - default: - break; - } -} - -void EnableCallback(roctracer_domain_t domain, uint32_t cid, roctracer_rtapi_callback_t callback, - void* user_data) { - switch (domain) { - case ACTIVITY_DOMAIN_HSA_OPS: - break; - case ACTIVITY_DOMAIN_HSA_API: - if (cid >= HSA_API_ID_NUMBER) - EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT, - "invalid HSA API operation ID(" << cid << ")"); - - detail::cb_table.Set(cid, callback, user_data); - break; - case ACTIVITY_DOMAIN_HSA_EVT: - if (cid >= HSA_EVT_ID_NUMBER) - EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT, - "invalid HSA API operation ID(" << cid << ")"); - - hsa_evt_cb_table.Set(cid, callback, user_data); - break; - default: - break; - } -} - -void DisableActivity(roctracer_domain_t domain, uint32_t op) { - switch (domain) { - case ACTIVITY_DOMAIN_HSA_OPS: - if (op == HSA_OP_ID_COPY) { - std::scoped_lock lock(init_mutex); - - async_copy_callback_enabled = false; - async_copy_callback_memory_pool = nullptr; - - if (saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn != nullptr) { - [[maybe_unused]] hsa_status_t status = - saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(false); - assert(status == HSA_STATUS_SUCCESS || status == HSA_STATUS_ERROR_NOT_INITIALIZED || - !"hsa_amd_profiling_async_copy_enable failed"); - } - } else if (op == HSA_OP_ID_RESERVED1) { - /* Place holder for PC sampling. */ - } else { - EXC_RAISING(ROCTRACER_STATUS_ERROR_NOT_IMPLEMENTED, - "HSA OPS operation ID(" << op << ") is not currently implemented"); - } - break; - case ACTIVITY_DOMAIN_HSA_API: - // FIXME: Add HSA api activities. - break; - case ACTIVITY_DOMAIN_HSA_EVT: - break; - default: - break; - } -} - -void DisableCallback(roctracer_domain_t domain, uint32_t cid) { - switch (domain) { - case ACTIVITY_DOMAIN_HSA_OPS: - break; - case ACTIVITY_DOMAIN_HSA_API: - if (cid >= HSA_API_ID_NUMBER) - EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT, - "invalid HSA API operation ID(" << cid << ")"); - detail::cb_table.Set(cid, nullptr, nullptr); - break; - case ACTIVITY_DOMAIN_HSA_EVT: - if (cid >= HSA_EVT_ID_NUMBER) - EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT, - "invalid HSA EVT operation ID(" << cid << ")"); - hsa_evt_cb_table.Set(cid, nullptr, nullptr); - break; - default: - break; - } +void RegisterTracerCallback(int (*function)(activity_domain_t domain, uint32_t operation_id, + void* data)) { + report_activity.store(function, std::memory_order_relaxed); } } // namespace roctracer::hsa_support diff --git a/src/roctracer/hsa_support.h b/src/roctracer/hsa_support.h index 563df51712..3e9922eaca 100644 --- a/src/roctracer/hsa_support.h +++ b/src/roctracer/hsa_support.h @@ -28,6 +28,15 @@ namespace roctracer::hsa_support { +struct hsa_trace_data_t { + hsa_api_data_t api_data; + uint64_t phase_enter_timestamp; + uint64_t phase_data; + + void (*phase_enter)(hsa_api_id_t operation_id, hsa_trace_data_t* data); + void (*phase_exit)(hsa_api_id_t operation_id, hsa_trace_data_t* data); +}; + void Initialize(HsaApiTable* table); void Finalize(); @@ -36,13 +45,8 @@ const char* GetEvtName(uint32_t id); const char* GetOpsName(uint32_t id); uint32_t GetApiCode(const char* str); -void EnableActivity(roctracer_domain_t domain, uint32_t op, roctracer_pool_t* pool); -void EnableCallback(roctracer_domain_t domain, uint32_t cid, roctracer_rtapi_callback_t callback, - void* user_data); - -void DisableCallback(roctracer_domain_t domain, uint32_t cid); -void DisableActivity(roctracer_domain_t domain, uint32_t op); - +void RegisterTracerCallback(int (*function)(activity_domain_t domain, uint32_t operation_id, + void* data)); uint64_t timestamp_ns(); } // namespace roctracer::hsa_support diff --git a/src/roctracer/journal.h b/src/roctracer/journal.h deleted file mode 100644 index 0bb844e76e..0000000000 --- a/src/roctracer/journal.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - THE SOFTWARE. */ - -#ifndef SRC_CORE_JOURNAL_H_ -#define SRC_CORE_JOURNAL_H_ - -#include "ext/prof_protocol.h" - -#include -#include -#include - -namespace roctracer { - -template class Journal { - public: - /* Insert { domain, op } into the journal. Return false if the insertion - updated an existing entry. */ - template , int> = 0> - bool Insert(roctracer_domain_t domain, uint32_t op, T&& data) { - std::lock_guard lock(mutex_); - auto result = map_[domain].try_emplace(op, std::forward(data)); - if (!result.second) result.first->second = std::forward(data); - return result.second; - } - - /* Remove { domain, op } from the journal. Return false if the entry did not exist. */ - bool Remove(roctracer_domain_t domain, uint32_t op) { - std::lock_guard lock(mutex_); - return map_[domain].erase(op) == 1; - } - - template void ForEach(Functor&& func) { - std::lock_guard lock(mutex_); - for (auto&& domain : map_) - for (auto&& operation : domain.second) - if (!func(domain.first /* domain */, operation.first /* op */, operation.second /* data */)) - break; /* FIXME: what are we breaking out of? */ - } - - private: - std::mutex mutex_; - std::unordered_map> map_; -}; - -} // namespace roctracer - -#endif // SRC_CORE_JOURNAL_H_ diff --git a/src/roctracer/loader.h b/src/roctracer/loader.h index bdf9f52675..2ef7d24643 100644 --- a/src/roctracer/loader.h +++ b/src/roctracer/loader.h @@ -122,13 +122,6 @@ __attribute__((weak)) const char* hipKernelNameRefByPtr(const void* hostFunction __attribute__((weak)) int hipGetStreamDeviceId(hipStream_t stream) { return 0; } __attribute__((weak)) const char* hipApiName(uint32_t id) { return NULL; } -__attribute__((weak)) hipError_t hipRegisterAsyncActivityCallback_t(uint32_t op, void* fun, - void* arg) { - return hipErrorUnknown; -} -__attribute__((weak)) hipError_t hipRemoveAsyncActivityCallback_t(uint32_t op) { - return hipErrorUnknown; -} __attribute__((weak)) const char* hipGetCmdName(unsigned op) { return NULL; } class HipLoaderStatic { @@ -158,12 +151,8 @@ class HipLoaderStatic { GetStreamDeviceId_t* GetStreamDeviceId; ApiName_t* ApiName; - typedef hipError_t(hipRegisterAsyncActivityCallback_t)(uint32_t op, void* fun, void* arg); - typedef hipError_t(hipRemoveAsyncActivityCallback_t)(uint32_t op); typedef const char*(hipGetOpName_t)(unsigned op); - hipRegisterAsyncActivityCallback_t* RegisterActivityCallback; - hipRemoveAsyncActivityCallback_t* RemoveActivityCallback; hipGetOpName_t* GetOpName; static inline loader_t& Instance() { @@ -191,8 +180,6 @@ class HipLoaderStatic { GetStreamDeviceId = hipGetStreamDeviceId; ApiName = hipApiName; - RegisterAsyncActivityCallback = hipRegisterAsyncActivityCallback; - RemoveAsyncActivityCallback = hipRemoveAsyncActivityCallback; GetOpName = hipGetCmdName; } @@ -204,53 +191,36 @@ class HipApi { public: typedef BaseLoader Loader; - typedef decltype(hipRegisterApiCallback) RegisterApiCallback_t; - typedef decltype(hipRemoveApiCallback) RemoveApiCallback_t; - typedef decltype(hipRegisterActivityCallback) RegisterActivityCallback_t; - typedef decltype(hipRemoveActivityCallback) RemoveActivityCallback_t; - typedef decltype(hipKernelNameRef) KernelNameRef_t; - typedef decltype(hipKernelNameRefByPtr) KernelNameRefByPtr_t; - typedef decltype(hipGetStreamDeviceId) GetStreamDeviceId_t; - typedef decltype(hipApiName) ApiName_t; + typedef int(hipGetStreamDeviceId_t)(hipStream_t stream); + typedef const char*(hipKernelNameRef_t)(const hipFunction_t function); + typedef const char*(hipKernelNameRefByPtr_t)(const void* host_function, hipStream_t stream); + typedef const char*(hipApiName_t)(uint32_t id); + typedef const char*(hipGetCmdName_t)(uint32_t op); + typedef void(hipRegisterTracerCallback_t)(int (*function)(activity_domain_t domain, + uint32_t operation_id, void* data)); - RegisterApiCallback_t* RegisterApiCallback; - RemoveApiCallback_t* RemoveApiCallback; - RegisterActivityCallback_t* RegisterActivityCallback; - RemoveActivityCallback_t* RemoveActivityCallback; - KernelNameRef_t* KernelNameRef; - KernelNameRefByPtr_t* KernelNameRefByPtr_; + hipKernelNameRef_t* KernelNameRef; const char* KernelNameRefByPtr(const void* function, hipStream_t stream = nullptr) const { return KernelNameRefByPtr_(function, stream); } - GetStreamDeviceId_t* GetStreamDeviceId; - ApiName_t* ApiName; - - typedef hipError_t(hipRegisterAsyncActivityCallback_t)(uint32_t op, void* fun, void* arg); - typedef hipError_t(hipRemoveAsyncActivityCallback_t)(uint32_t op); - typedef const char*(hipGetOpName_t)(unsigned op); - - hipRegisterAsyncActivityCallback_t* RegisterAsyncActivityCallback; - hipRemoveAsyncActivityCallback_t* RemoveAsyncActivityCallback; - hipGetOpName_t* GetOpName; + hipGetStreamDeviceId_t* GetStreamDeviceId; + hipGetCmdName_t* GetOpName; + hipApiName_t* ApiName; + hipRegisterTracerCallback_t* RegisterTracerCallback; protected: void init(Loader* loader) { - RegisterApiCallback = loader->GetFun("hipRegisterApiCallback"); - RemoveApiCallback = loader->GetFun("hipRemoveApiCallback"); - RegisterActivityCallback = - loader->GetFun("hipRegisterActivityCallback"); - RemoveActivityCallback = loader->GetFun("hipRemoveActivityCallback"); - KernelNameRef = loader->GetFun("hipKernelNameRef"); - KernelNameRefByPtr_ = loader->GetFun("hipKernelNameRefByPtr"); - GetStreamDeviceId = loader->GetFun("hipGetStreamDeviceId"); - ApiName = loader->GetFun("hipApiName"); - - RegisterAsyncActivityCallback = - loader->GetFun("hipRegisterAsyncActivityCallback"); - RemoveAsyncActivityCallback = - loader->GetFun("hipRemoveAsyncActivityCallback"); - GetOpName = loader->GetFun("hipGetCmdName"); + GetStreamDeviceId = loader->GetFun("hipGetStreamDeviceId"); + KernelNameRef = loader->GetFun("hipKernelNameRef"); + KernelNameRefByPtr_ = loader->GetFun("hipKernelNameRefByPtr"); + GetOpName = loader->GetFun("hipGetCmdName"); + ApiName = loader->GetFun("hipApiName"); + RegisterTracerCallback = + loader->GetFun("hipRegisterTracerCallback"); } + + private: + hipKernelNameRefByPtr_t* KernelNameRefByPtr_; }; #endif @@ -260,16 +230,14 @@ class RocTxApi { public: typedef BaseLoader Loader; - typedef bool(RegisterApiCallback_t)(uint32_t op, void* callback, void* arg); - typedef bool(RemoveApiCallback_t)(uint32_t op); - - RegisterApiCallback_t* RegisterApiCallback; - RemoveApiCallback_t* RemoveApiCallback; + typedef void(roctxRegisterTracerCallback_t)(int (*function)(activity_domain_t domain, + uint32_t operation_id, void* data)); + roctxRegisterTracerCallback_t* RegisterTracerCallback; protected: void init(Loader* loader) { - RegisterApiCallback = loader->GetFun("RegisterApiCallback"); - RemoveApiCallback = loader->GetFun("RemoveApiCallback"); + RegisterTracerCallback = + loader->GetFun("roctxRegisterTracerCallback"); } }; @@ -278,8 +246,7 @@ typedef BaseLoader RocTxLoader; #if STATIC_BUILD typedef HipLoaderStatic HipLoader; #else -typedef BaseLoader HipLoaderShared; -typedef HipLoaderShared HipLoader; +using HipLoader = BaseLoader; #endif } // namespace roctracer @@ -299,7 +266,7 @@ typedef HipLoaderShared HipLoader; roctracer::HipLoaderStatic::instance_t roctracer::HipLoaderStatic::instance_{}; #else #define LOADER_INSTANTIATE_HIP() \ - template <> const char* roctracer::HipLoaderShared::lib_name_ = "libamdhip64.so"; + template <> const char* roctracer::HipLoader::lib_name_ = "libamdhip64.so"; #endif #define LOADER_INSTANTIATE() \ diff --git a/src/roctracer/registration_table.h b/src/roctracer/registration_table.h new file mode 100644 index 0000000000..aa25779bd6 --- /dev/null +++ b/src/roctracer/registration_table.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +#ifndef UTIL_CALLBACK_TABLE_H_ +#define UTIL_CALLBACK_TABLE_H_ + +#include "ext/prof_protocol.h" + +#include +#include +#include +#include +#include +#include + +namespace roctracer::util { + +namespace detail { +struct False { + constexpr bool operator()() { return false; } +}; +} // namespace detail + +// Generic callbacks table +template class RegistrationTable { + public: + template void Register(uint32_t operation_id, Args... args) { + assert(operation_id < N && "operation_id is out of range"); + auto& entry = table_[operation_id]; + std::unique_lock lock(entry.mutex); + if (!entry.enabled.exchange(true, std::memory_order_relaxed)) + registered_count_.fetch_add(1, std::memory_order_relaxed); + entry.data = T{std::forward(args)...}; + } + + void Unregister(uint32_t operation_id) { + assert(operation_id < N && "id is out of range"); + auto& entry = table_[operation_id]; + std::unique_lock lock(entry.mutex); + if (entry.enabled.exchange(false, std::memory_order_relaxed)) + registered_count_.fetch_sub(1, std::memory_order_relaxed); + } + + std::optional Get(uint32_t operation_id) const { + assert(operation_id < N && "id is out of range"); + auto& entry = table_[operation_id]; + if (!entry.enabled.load(std::memory_order_relaxed) || IsStopped{}()) return std::nullopt; + std::shared_lock lock(entry.mutex); + return entry.enabled.load(std::memory_order_relaxed) ? std::make_optional(entry.data) + : std::nullopt; + } + + bool IsEmpty() const { return registered_count_.load(std::memory_order_relaxed) == 0; } + + private: + std::atomic registered_count_{0}; + struct { + std::atomic enabled{false}; + mutable std::shared_mutex mutex; + T data; + } table_[N]{}; +}; + + +} // namespace roctracer::util + +#endif // UTIL_CALLBACK_TABLE_H_ diff --git a/src/roctracer/roctracer.cpp b/src/roctracer/roctracer.cpp index 2924fcac86..1799298d5c 100644 --- a/src/roctracer/roctracer.cpp +++ b/src/roctracer/roctracer.cpp @@ -34,17 +34,18 @@ #include #include #include +#include #include #include #include "correlation_id.h" #include "debug.h" -#include "journal.h" -#include "loader.h" -#include "hsa_support.h" -#include "memory_pool.h" #include "exception.h" +#include "hsa_support.h" +#include "loader.h" #include "logger.h" +#include "memory_pool.h" +#include "registration_table.h" #define API_METHOD_PREFIX \ roctracer_status_t err = ROCTRACER_STATUS_SUCCESS; \ @@ -74,126 +75,38 @@ static inline uint32_t GetTid() { return tid; } +using namespace roctracer; + +namespace { + /////////////////////////////////////////////////////////////////////////////////////////////////// // Internal library methods // -namespace roctracer { -namespace ext_support { roctracer_start_cb_t roctracer_start_cb = nullptr; roctracer_stop_cb_t roctracer_stop_cb = nullptr; -} // namespace ext_support - -struct CallbackJournalData { - roctracer_rtapi_callback_t callback; - void* user_data; -}; -static Journal cb_journal; - -struct ActivityJournalData { - roctracer_pool_t* pool; -}; -static Journal act_journal; roctracer_status_t GetExcStatus(const std::exception& e) { const ApiError* roctracer_exc_ptr = dynamic_cast(&e); return (roctracer_exc_ptr) ? roctracer_exc_ptr->status() : ROCTRACER_STATUS_ERROR; } -std::mutex hip_activity_mutex; +std::mutex registration_mutex; -enum { API_CB_MASK = 0x1, API_ACT_MASK = 0x2 }; +// Memory pool routines and primitives +std::recursive_mutex memory_pool_mutex; +MemoryPool* default_memory_pool = nullptr; -class HIPActivityCallbackTracker { - public: - uint32_t enable_check(uint32_t op, uint32_t mask) { return data_[op] |= mask; } - uint32_t disable_check(uint32_t op, uint32_t mask) { return data_[op] &= ~mask; } +} // namespace - private: - std::unordered_map data_; -}; - -static HIPActivityCallbackTracker hip_act_cb_tracker; - -void HIP_ApiCallback(uint32_t op_id, roctracer_record_t* record, void* callback_data, void* arg) { - hip_api_data_t* data = static_cast(callback_data); - MemoryPool* pool = static_cast(arg); - - if (data->phase == ACTIVITY_API_PHASE_ENTER) { - // Generate a new correlation ID. - uint64_t correlation_id = CorrelationIdPush(); - data->correlation_id = correlation_id; - - if (pool != nullptr) { - // Filing record info - record->domain = ACTIVITY_DOMAIN_HIP_API; - record->kind = 0; - record->op = op_id; - record->process_id = GetPid(); - record->thread_id = GetTid(); - record->begin_ns = hsa_support::timestamp_ns(); - record->correlation_id = correlation_id; - } - } else { - if (pool != nullptr) { - record->end_ns = hsa_support::timestamp_ns(); - - if (auto external_id = ExternalCorrelationId()) { - roctracer_record_t ext_record{}; - ext_record.domain = ACTIVITY_DOMAIN_EXT_API; - ext_record.op = ACTIVITY_EXT_OP_EXTERN_ID; - ext_record.correlation_id = record->correlation_id; - ext_record.external_id = *external_id; - // Write the external correlation id record directly followed by the activity record. - pool->Write(std::array{ext_record, *record}); - } else { - // Write record to the buffer. - pool->Write(*record); - } - } - CorrelationIdPop(); - } -} - -void HIP_AsyncActivityCallback(uint32_t op_id, void* record_ptr, void* arg) { - MemoryPool* pool = reinterpret_cast(arg); - roctracer_record_t& record = *reinterpret_cast(record_ptr); - record.domain = ACTIVITY_DOMAIN_HIP_OPS; - - // If the record is for a kernel dispatch, write the kernel name in the pool's data, - // and make the record point to it. Older HIP runtimes do not provide a kernel - // name, so record.kernel_name might be null. - if (record.op == HIP_OP_ID_DISPATCH && record.kernel_name != nullptr) - pool->Write(record, record.kernel_name, strlen(record.kernel_name) + 1, - [](auto& record, const void* data) { - record.kernel_name = static_cast(data); - }); - else - pool->Write(record); -} +namespace roctracer { // Logger routines and primitives util::Logger::mutex_t util::Logger::mutex_; std::atomic util::Logger::instance_{}; -// Memory pool routines and primitives -MemoryPool* default_memory_pool = nullptr; -std::recursive_mutex memory_pool_mutex; - -// Stop status routines and primitives -unsigned stop_status_value = 0; -std::mutex stop_status_mutex; -unsigned set_stopped(unsigned val) { - std::lock_guard lock(stop_status_mutex); - const unsigned ret = (stop_status_value ^ val); - stop_status_value = val; - return ret; -} - } // namespace roctracer -using namespace roctracer; - LOADER_INSTANTIATE(); /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -227,7 +140,7 @@ ROCTRACER_API const char* roctracer_op_string(uint32_t domain, uint32_t op, uint case ACTIVITY_DOMAIN_EXT_API: return "EXT_API"; default: - EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")"); + throw roctracer::ApiError(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID"); } API_METHOD_CATCH(nullptr) } @@ -261,95 +174,350 @@ ROCTRACER_API roctracer_status_t roctracer_op_code(uint32_t domain, const char* API_METHOD_SUFFIX } -static inline uint32_t get_op_begin(uint32_t domain) { +namespace { + +template struct DomainTraits; + +template <> struct DomainTraits { + using ApiData = hip_api_data_t; + using OperationId = hip_api_id_t; + static constexpr size_t kOpIdBegin = HIP_API_ID_FIRST; + static constexpr size_t kOpIdEnd = HIP_API_ID_LAST + 1; +}; + +template <> struct DomainTraits { + using ApiData = hsa_api_data_t; + using OperationId = hsa_api_id_t; + static constexpr size_t kOpIdBegin = 0; + static constexpr size_t kOpIdEnd = HSA_API_ID_NUMBER; +}; + +template <> struct DomainTraits { + using ApiData = roctx_api_data_t; + using OperationId = roctx_api_id_t; + static constexpr size_t kOpIdBegin = 0; + static constexpr size_t kOpIdEnd = ROCTX_API_ID_NUMBER; +}; + +template <> struct DomainTraits { + using OperationId = hip_op_id_t; + static constexpr size_t kOpIdBegin = 0; + static constexpr size_t kOpIdEnd = HIP_OP_ID_NUMBER; +}; + +template <> struct DomainTraits { + using OperationId = hsa_op_id_t; + static constexpr size_t kOpIdBegin = 0; + static constexpr size_t kOpIdEnd = HSA_OP_ID_NUMBER; +}; + +template <> struct DomainTraits { + using ApiData = hsa_evt_data_t; + using OperationId = hsa_evt_id_t; + static constexpr size_t kOpIdBegin = 0; + static constexpr size_t kOpIdEnd = HSA_EVT_ID_NUMBER; +}; + +constexpr uint32_t get_op_begin(activity_domain_t domain) { switch (domain) { case ACTIVITY_DOMAIN_HSA_OPS: - return 0; + return DomainTraits::kOpIdBegin; case ACTIVITY_DOMAIN_HSA_API: - return 0; + return DomainTraits::kOpIdBegin; case ACTIVITY_DOMAIN_HSA_EVT: - return 0; + return DomainTraits::kOpIdBegin; case ACTIVITY_DOMAIN_HIP_OPS: - return 0; + return DomainTraits::kOpIdBegin; case ACTIVITY_DOMAIN_HIP_API: - return HIP_API_ID_FIRST; + return DomainTraits::kOpIdBegin; + case ACTIVITY_DOMAIN_ROCTX: + return DomainTraits::kOpIdBegin; case ACTIVITY_DOMAIN_EXT_API: return 0; - case ACTIVITY_DOMAIN_ROCTX: - return 0; default: - EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")"); + throw roctracer::ApiError(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID"); } - return 0; } -static inline uint32_t get_op_end(uint32_t domain) { +constexpr uint32_t get_op_end(activity_domain_t domain) { switch (domain) { case ACTIVITY_DOMAIN_HSA_OPS: - return HSA_OP_ID_NUMBER; + return DomainTraits::kOpIdEnd; case ACTIVITY_DOMAIN_HSA_API: - return HSA_API_ID_NUMBER; + return DomainTraits::kOpIdEnd; case ACTIVITY_DOMAIN_HSA_EVT: - return HSA_EVT_ID_NUMBER; + return DomainTraits::kOpIdEnd; case ACTIVITY_DOMAIN_HIP_OPS: - return HIP_OP_ID_NUMBER; + return DomainTraits::kOpIdEnd; case ACTIVITY_DOMAIN_HIP_API: - return HIP_API_ID_LAST + 1; - case ACTIVITY_DOMAIN_EXT_API: - return 0; + return DomainTraits::kOpIdEnd; case ACTIVITY_DOMAIN_ROCTX: - return ROCTX_API_ID_NUMBER; + return DomainTraits::kOpIdEnd; + case ACTIVITY_DOMAIN_EXT_API: + return get_op_begin(ACTIVITY_DOMAIN_EXT_API); default: - EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")"); + throw roctracer::ApiError(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID"); } - return 0; } -// Enable runtime API callbacks -static void roctracer_enable_callback_fun(roctracer_domain_t domain, uint32_t op, - roctracer_rtapi_callback_t callback, void* user_data) { +std::atomic stopped_status{false}; + +struct IsStopped { + bool operator()() const { return stopped_status.load(std::memory_order_relaxed); } +}; + +struct NeverStopped { + constexpr bool operator()() { return false; } +}; + +using UserCallback = std::pair; + +template +using CallbackRegistrationTable = + util::RegistrationTable::kOpIdEnd, IsStopped>; + +template +using ActivityRegistrationTable = + util::RegistrationTable::kOpIdEnd, IsStopped>; + +template struct ApiTracer { + using ApiData = typename DomainTraits::ApiData; + using OperationId = typename DomainTraits::OperationId; + + struct TraceData { + ApiData api_data; // API specific data (for example, function arguments). + uint64_t phase_enter_timestamp; // timestamp when phase_enter was executed. + uint64_t phase_data; // data that can be shared between phase_enter and phase_exit. + + void (*phase_enter)(OperationId operation_id, TraceData* data); + void (*phase_exit)(OperationId operation_id, TraceData* data); + }; + + static void Exit(OperationId operation_id, TraceData* trace_data) { + if (auto pool = activity_table.Get(operation_id)) { + assert(trace_data != nullptr); + activity_record_t record{}; + + record.domain = domain; + record.op = operation_id; + record.correlation_id = trace_data->api_data.correlation_id; + record.begin_ns = trace_data->phase_enter_timestamp; + record.end_ns = hsa_support::timestamp_ns(); + record.process_id = GetPid(); + record.thread_id = GetTid(); + + if (auto external_id = ExternalCorrelationId()) { + roctracer_record_t ext_record{}; + ext_record.domain = ACTIVITY_DOMAIN_EXT_API; + ext_record.op = ACTIVITY_EXT_OP_EXTERN_ID; + ext_record.correlation_id = record.correlation_id; + ext_record.external_id = *external_id; + // Write the external correlation id record directly followed by the activity record. + (*pool)->Write(std::array{ext_record, record}); + } else { + // Write record to the buffer. + (*pool)->Write(record); + } + } + CorrelationIdPop(); + } + + static void Exit_UserCallback(OperationId operation_id, TraceData* trace_data) { + if (auto user_callback = callback_table.Get(operation_id)) { + assert(trace_data != nullptr); + trace_data->api_data.phase = ACTIVITY_API_PHASE_EXIT; + user_callback->first(domain, operation_id, &trace_data->api_data, user_callback->second); + } + Exit(operation_id, trace_data); + } + + static void Enter_UserCallback(OperationId operation_id, TraceData* trace_data) { + if (auto user_callback = callback_table.Get(operation_id)) { + assert(trace_data != nullptr); + trace_data->api_data.phase = ACTIVITY_API_PHASE_ENTER; + user_callback->first(domain, operation_id, &trace_data->api_data, user_callback->second); + trace_data->phase_exit = Exit_UserCallback; + } else { + trace_data->phase_exit = Exit; + } + } + + static int Enter(OperationId operation_id, TraceData* trace_data) { + bool callback_enabled = callback_table.Get(operation_id).has_value(), + activity_enabled = activity_table.Get(operation_id).has_value(); + if (!callback_enabled && !activity_enabled) return -1; + + if (trace_data != nullptr) { + // Generate a new correlation ID. + trace_data->api_data.correlation_id = CorrelationIdPush(); + + if (activity_enabled) { + trace_data->phase_enter_timestamp = hsa_support::timestamp_ns(); + trace_data->phase_enter = nullptr; + trace_data->phase_exit = Exit; + } + if (callback_enabled) { + trace_data->phase_enter = Enter_UserCallback; + trace_data->phase_exit = [](OperationId, TraceData*) { fatal("should not reach here"); }; + } + } + return 0; + } + + static CallbackRegistrationTable callback_table; + static ActivityRegistrationTable activity_table; +}; + +template +CallbackRegistrationTable ApiTracer::callback_table; + +template +ActivityRegistrationTable ApiTracer::activity_table; + +using HIP_ApiTracer = ApiTracer; +using HSA_ApiTracer = ApiTracer; + +CallbackRegistrationTable roctx_api_callback_table; +ActivityRegistrationTable hip_ops_activity_table; +ActivityRegistrationTable hsa_ops_activity_table; +CallbackRegistrationTable hsa_evt_callback_table; + +int TracerCallback(activity_domain_t domain, uint32_t operation_id, void* data) { switch (domain) { - case ACTIVITY_DOMAIN_HSA_OPS: case ACTIVITY_DOMAIN_HSA_API: - case ACTIVITY_DOMAIN_HSA_EVT: - hsa_support::EnableCallback(domain, op, callback, user_data); - break; + return HSA_ApiTracer::Enter(static_cast(operation_id), + static_cast(data)); + + case ACTIVITY_DOMAIN_HIP_API: + return HIP_ApiTracer::Enter(static_cast(operation_id), + static_cast(data)); + case ACTIVITY_DOMAIN_HIP_OPS: - break; - case ACTIVITY_DOMAIN_HIP_API: { - if (!HipLoader::Instance().Enabled()) break; - std::lock_guard lock(hip_activity_mutex); - - if (hipError_t err = - HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data); - err != hipSuccess) - fatal("HIP::RegisterApiCallback(%d) failed (err=%d)", op, err); - - if ((hip_act_cb_tracker.enable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) { - if (hipError_t err = - HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr); - err != hipSuccess) - fatal("HIP::RegisterActivityCallback(%d) failed (err=%d)", op, err); + if (auto pool = hip_ops_activity_table.Get(operation_id)) { + if (auto record = static_cast(data)) { + // If the record is for a kernel dispatch, write the kernel name in the pool's data, + // and make the record point to it. Older HIP runtimes do not provide a kernel + // name, so record.kernel_name might be null. + if (operation_id == HIP_OP_ID_DISPATCH && record->kernel_name != nullptr) + (*pool)->Write(*record, record->kernel_name, strlen(record->kernel_name) + 1, + [](auto& record, const void* data) { + record.kernel_name = static_cast(data); + }); + else + (*pool)->Write(*record); + } + return 0; } break; - } - case ACTIVITY_DOMAIN_ROCTX: { - if (RocTxLoader::Instance().Enabled() && - !RocTxLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data)) - fatal("ROCTX::RegisterApiCallback(%d) failed", op); + + case ACTIVITY_DOMAIN_ROCTX: + if (auto user_callback = roctx_api_callback_table.Get(operation_id)) { + if (auto api_data = static_cast::ApiData*>(data)) + user_callback->first(ACTIVITY_DOMAIN_ROCTX, operation_id, api_data, + user_callback->second); + return 0; + } + break; + + case ACTIVITY_DOMAIN_HSA_OPS: + if (auto pool = hsa_ops_activity_table.Get(operation_id)) { + if (auto record = static_cast(data)) (*pool)->Write(*record); + return 0; + } + break; + + case ACTIVITY_DOMAIN_HSA_EVT: + if (auto user_callback = hsa_evt_callback_table.Get(operation_id)) { + if (auto api_data = static_cast::ApiData*>(data)) + user_callback->first(ACTIVITY_DOMAIN_HSA_EVT, operation_id, api_data, + user_callback->second); + return 0; + } + break; + + default: + break; + } + return -1; +} + +template struct RegistrationTableGroup { + private: + bool AllEmpty() const { + return std::apply([](auto&&... tables) { return (tables.IsEmpty() && ...); }, tables_); + } + + public: + template + RegistrationTableGroup(Functor1&& engage_tracer, Functor2&& disengage_tracer, Tables&... tables) + : engage_tracer_(std::forward(engage_tracer)), + disengage_tracer_(std::forward(disengage_tracer)), + tables_(tables...) {} + + template + void Register(T& table, uint32_t operation_id, Args... args) const { + if (AllEmpty()) engage_tracer_(); + table.Register(operation_id, std::forward(args)...); + } + + template void Unregister(T& table, uint32_t operation_id) const { + table.Unregister(operation_id); + if (AllEmpty()) disengage_tracer_(); + } + + private: + const std::function engage_tracer_, disengage_tracer_; + const std::tuple tables_; +}; + +RegistrationTableGroup HSA_registration_group( + []() { hsa_support::RegisterTracerCallback(TracerCallback); }, + []() { hsa_support::RegisterTracerCallback(nullptr); }, HSA_ApiTracer::callback_table, + HSA_ApiTracer::activity_table, hsa_ops_activity_table, hsa_evt_callback_table); + +RegistrationTableGroup HIP_registration_group( + []() { HipLoader::Instance().RegisterTracerCallback(TracerCallback); }, + []() { HipLoader::Instance().RegisterTracerCallback(nullptr); }, HIP_ApiTracer::callback_table, + HIP_ApiTracer::activity_table, hip_ops_activity_table); + +RegistrationTableGroup ROCTX_registration_group( + []() { RocTxLoader::Instance().RegisterTracerCallback(TracerCallback); }, + []() { RocTxLoader::Instance().RegisterTracerCallback(nullptr); }, roctx_api_callback_table); + +} // namespace + +// Enable runtime API callbacks +static void roctracer_enable_callback_impl(roctracer_domain_t domain, uint32_t operation_id, + roctracer_rtapi_callback_t callback, void* user_data) { + std::lock_guard lock(registration_mutex); + + switch (domain) { + case ACTIVITY_DOMAIN_HSA_EVT: + HSA_registration_group.Register(hsa_evt_callback_table, operation_id, callback, user_data); + break; + case ACTIVITY_DOMAIN_HSA_API: + HSA_registration_group.Register(HSA_ApiTracer::callback_table, operation_id, callback, + user_data); + break; + case ACTIVITY_DOMAIN_HSA_OPS: + break; + case ACTIVITY_DOMAIN_HIP_API: + if (HipLoader::Instance().Enabled()) + HIP_registration_group.Register(HIP_ApiTracer::callback_table, operation_id, callback, + user_data); + break; + case ACTIVITY_DOMAIN_HIP_OPS: + break; + case ACTIVITY_DOMAIN_ROCTX: + if (RocTxLoader::Instance().Enabled()) + ROCTX_registration_group.Register(roctx_api_callback_table, operation_id, callback, + user_data); break; - } default: EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")"); } } -static void roctracer_enable_callback_impl(roctracer_domain_t domain, uint32_t op, - roctracer_rtapi_callback_t callback, void* user_data) { - cb_journal.Insert(domain, op, {callback, user_data}); - roctracer_enable_callback_fun(domain, op, callback, user_data); -} - ROCTRACER_API roctracer_status_t roctracer_enable_op_callback(roctracer_domain_t domain, uint32_t op, roctracer_rtapi_callback_t callback, @@ -369,43 +537,33 @@ ROCTRACER_API roctracer_status_t roctracer_enable_domain_callback( } // Disable runtime API callbacks -static void roctracer_disable_callback_fun(roctracer_domain_t domain, uint32_t op) { +static void roctracer_disable_callback_impl(roctracer_domain_t domain, uint32_t operation_id) { + std::lock_guard lock(registration_mutex); + switch (domain) { - case ACTIVITY_DOMAIN_HSA_OPS: - case ACTIVITY_DOMAIN_HSA_API: case ACTIVITY_DOMAIN_HSA_EVT: - hsa_support::DisableCallback(domain, op); + HSA_registration_group.Unregister(hsa_evt_callback_table, operation_id); + break; + case ACTIVITY_DOMAIN_HSA_API: + HSA_registration_group.Unregister(HSA_ApiTracer::callback_table, operation_id); + break; + case ACTIVITY_DOMAIN_HSA_OPS: + break; + case ACTIVITY_DOMAIN_HIP_API: + if (HipLoader::Instance().Enabled()) + HIP_registration_group.Unregister(HIP_ApiTracer::callback_table, operation_id); break; case ACTIVITY_DOMAIN_HIP_OPS: break; - case ACTIVITY_DOMAIN_HIP_API: { - if (!HipLoader::Instance().Enabled()) break; - std::lock_guard lock(hip_activity_mutex); - - if (hipError_t err = HipLoader::Instance().RemoveApiCallback(op); err != hipSuccess) - fatal("HIP::RemoveApiCallback(%d) failed (err=%d)", op, err); - - if ((hip_act_cb_tracker.disable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) { - if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess) - fatal("HIP::RemoveActivityCallback(%d) failed (err=%d)", op, err); - } + case ACTIVITY_DOMAIN_ROCTX: + if (RocTxLoader::Instance().Enabled()) + ROCTX_registration_group.Unregister(roctx_api_callback_table, operation_id); break; - } - case ACTIVITY_DOMAIN_ROCTX: { - if (RocTxLoader::Instance().Enabled() && !RocTxLoader::Instance().RemoveApiCallback(op)) - fatal("ROCTX::RemoveApiCallback(%d) failed", op); - break; - } default: EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")"); } } -static void roctracer_disable_callback_impl(roctracer_domain_t domain, uint32_t op) { - cb_journal.Remove(domain, op); - roctracer_disable_callback_fun(domain, op); -} - ROCTRACER_API roctracer_status_t roctracer_disable_op_callback(roctracer_domain_t domain, uint32_t op) { API_METHOD_PREFIX @@ -470,33 +628,32 @@ ROCTRACER_API roctracer_status_t roctracer_next_record(const activity_record_t* } // Enable activity records logging -static void roctracer_enable_activity_fun(roctracer_domain_t domain, uint32_t op, - roctracer_pool_t* pool) { - assert(pool != nullptr); - switch (domain) { - case ACTIVITY_DOMAIN_HSA_OPS: - case ACTIVITY_DOMAIN_HSA_API: - case ACTIVITY_DOMAIN_HSA_EVT: - hsa_support::EnableActivity(domain, op, pool); - break; - case ACTIVITY_DOMAIN_HIP_OPS: { - if (HipLoader::Instance().Enabled() && - HipLoader::Instance().RegisterAsyncActivityCallback(op, (void*)HIP_AsyncActivityCallback, - pool) != hipSuccess) - fatal("HIP::EnableActivityCallback error"); - break; - } - case ACTIVITY_DOMAIN_HIP_API: { - if (!HipLoader::Instance().Enabled()) break; - std::lock_guard lock(hip_activity_mutex); +static void roctracer_enable_activity_impl(roctracer_domain_t domain, uint32_t op, + roctracer_pool_t* pool) { + std::lock_guard lock(registration_mutex); - hip_act_cb_tracker.enable_check(op, API_ACT_MASK); - if (hipError_t err = - HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, pool); - err != hipSuccess) - fatal("HIP::RegisterActivityCallback(%d) (err=%d)", op, err); + MemoryPool* memory_pool = reinterpret_cast(pool); + if (memory_pool == nullptr) memory_pool = default_memory_pool; + if (memory_pool == nullptr) + EXC_RAISING(ROCTRACER_STATUS_ERROR_DEFAULT_POOL_UNDEFINED, "no default pool"); + + switch (domain) { + case ACTIVITY_DOMAIN_HSA_EVT: + break; + case ACTIVITY_DOMAIN_HSA_API: + HSA_registration_group.Register(HSA_ApiTracer::activity_table, op, memory_pool); + break; + case ACTIVITY_DOMAIN_HSA_OPS: + HSA_registration_group.Register(hsa_ops_activity_table, op, memory_pool); + break; + case ACTIVITY_DOMAIN_HIP_API: + if (HipLoader::Instance().Enabled()) + HIP_registration_group.Register(HIP_ApiTracer::activity_table, op, memory_pool); + break; + case ACTIVITY_DOMAIN_HIP_OPS: + if (HipLoader::Instance().Enabled()) + HIP_registration_group.Register(hip_ops_activity_table, op, memory_pool); break; - } case ACTIVITY_DOMAIN_ROCTX: break; default: @@ -504,15 +661,6 @@ static void roctracer_enable_activity_fun(roctracer_domain_t domain, uint32_t op } } -static void roctracer_enable_activity_impl(roctracer_domain_t domain, uint32_t op, - roctracer_pool_t* pool) { - if (pool == nullptr) pool = default_memory_pool; - if (pool == nullptr) - EXC_RAISING(ROCTRACER_STATUS_ERROR_DEFAULT_POOL_UNDEFINED, "no default pool"); - act_journal.Insert(domain, op, {pool}); - roctracer_enable_activity_fun(domain, op, pool); -} - ROCTRACER_API roctracer_status_t roctracer_enable_op_activity_expl(roctracer_domain_t domain, uint32_t op, roctracer_pool_t* pool) { @@ -552,34 +700,26 @@ ROCTRACER_API roctracer_status_t roctracer_enable_domain_activity(activity_domai } // Disable activity records logging -static void roctracer_disable_activity_fun(roctracer_domain_t domain, uint32_t op) { - switch (domain) { - case ACTIVITY_DOMAIN_HSA_OPS: - case ACTIVITY_DOMAIN_HSA_API: - case ACTIVITY_DOMAIN_HSA_EVT: - hsa_support::DisableActivity(domain, op); - break; - case ACTIVITY_DOMAIN_HIP_OPS: { - if (HipLoader::Instance().Enabled() && - HipLoader::Instance().RemoveAsyncActivityCallback(op) != hipSuccess) - fatal("HIP::EnableActivityCallback(%d)", op); - break; - } - case ACTIVITY_DOMAIN_HIP_API: { - if (!HipLoader::Instance().Enabled()) break; - std::lock_guard lock(hip_activity_mutex); +static void roctracer_disable_activity_impl(roctracer_domain_t domain, uint32_t op) { + std::lock_guard lock(registration_mutex); - if ((hip_act_cb_tracker.disable_check(op, API_ACT_MASK) & API_CB_MASK) == 0) { - if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess) - fatal("HIP::RemoveActivityCallback(%d) failed (err=%d)", op, err); - } else { - if (hipError_t err = - HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr); - err != hipSuccess) - fatal("HIP::RegisterActivityCallback(%d) failed (err=%d)", op, err); - } + switch (domain) { + case ACTIVITY_DOMAIN_HSA_EVT: + break; + case ACTIVITY_DOMAIN_HSA_API: + HSA_registration_group.Unregister(HSA_ApiTracer::activity_table, op); + break; + case ACTIVITY_DOMAIN_HSA_OPS: + HSA_registration_group.Unregister(hsa_ops_activity_table, op); + break; + case ACTIVITY_DOMAIN_HIP_API: + if (HipLoader::Instance().Enabled()) + HIP_registration_group.Unregister(HIP_ApiTracer::activity_table, op); + break; + case ACTIVITY_DOMAIN_HIP_OPS: + if (HipLoader::Instance().Enabled()) + HIP_registration_group.Unregister(hip_ops_activity_table, op); break; - } case ACTIVITY_DOMAIN_ROCTX: break; default: @@ -587,11 +727,6 @@ static void roctracer_disable_activity_fun(roctracer_domain_t domain, uint32_t o } } -static void roctracer_disable_activity_impl(roctracer_domain_t domain, uint32_t op) { - act_journal.Remove(domain, op); - roctracer_disable_activity_fun(domain, op); -} - ROCTRACER_API roctracer_status_t roctracer_disable_op_activity(roctracer_domain_t domain, uint32_t op) { API_METHOD_PREFIX @@ -622,6 +757,7 @@ static void roctracer_close_pool_impl(roctracer_pool_t* pool) { MemoryPool* p = reinterpret_cast(pool); if (p == default_memory_pool) default_memory_pool = nullptr; +#if 0 // Disable any activities that specify the pool being deleted. std::vector> ops; act_journal.ForEach( @@ -630,6 +766,7 @@ static void roctracer_close_pool_impl(roctracer_pool_t* pool) { return true; }); for (auto&& [domain, op] : ops) roctracer_disable_activity_impl(domain, op); +#endif delete (p); } @@ -693,35 +830,14 @@ roctracer_activity_pop_external_correlation_id(activity_correlation_id_t* last_i // Start API ROCTRACER_API void roctracer_start() { - if (set_stopped(0)) { - if (ext_support::roctracer_start_cb) ext_support::roctracer_start_cb(); - cb_journal.ForEach([](roctracer_domain_t domain, uint32_t op, const CallbackJournalData& data) { - roctracer_enable_callback_fun(domain, op, data.callback, data.user_data); - return true; - }); - act_journal.ForEach( - [](roctracer_domain_t domain, uint32_t op, const ActivityJournalData& data) { - roctracer_enable_activity_fun(domain, op, data.pool); - return true; - }); - } + if (stopped_status.exchange(false, std::memory_order_relaxed) && roctracer_start_cb) + roctracer_start_cb(); } // Stop API ROCTRACER_API void roctracer_stop() { - if (set_stopped(1)) { - // Must disable the activity first as the spawner checks for the activity being NULL - // to indicate that there is no callback. - act_journal.ForEach([](roctracer_domain_t domain, uint32_t op, const ActivityJournalData&) { - roctracer_disable_activity_fun(domain, op); - return true; - }); - cb_journal.ForEach([](roctracer_domain_t domain, uint32_t op, const CallbackJournalData&) { - roctracer_disable_callback_fun(domain, op); - return true; - }); - if (ext_support::roctracer_stop_cb) ext_support::roctracer_stop_cb(); - } + if (!stopped_status.exchange(true, std::memory_order_relaxed) && roctracer_stop_cb) + roctracer_stop_cb(); } ROCTRACER_API roctracer_status_t roctracer_get_timestamp(roctracer_timestamp_t* timestamp) { @@ -745,8 +861,8 @@ ROCTRACER_API roctracer_status_t roctracer_set_properties(roctracer_domain_t dom case ACTIVITY_DOMAIN_EXT_API: { roctracer_ext_properties_t* ops_properties = reinterpret_cast(properties); - ext_support::roctracer_start_cb = ops_properties->start_cb; - ext_support::roctracer_stop_cb = ops_properties->stop_cb; + roctracer_start_cb = ops_properties->start_cb; + roctracer_stop_cb = ops_properties->stop_cb; break; } default: diff --git a/src/roctx/exportmap b/src/roctx/exportmap index de57516313..9018c824bf 100644 --- a/src/roctx/exportmap +++ b/src/roctx/exportmap @@ -1,11 +1,10 @@ ROCTX_4.1 { -global: RegisterApiCallback; - RemoveApiCallback; - roctxMarkA; +global: roctxMarkA; roctxRangePop; roctxRangePushA; roctxRangeStartA; roctxRangeStop; + roctxRegisterTracerCallback; roctx_version_major; roctx_version_minor; local: *; diff --git a/src/roctx/roctx.cpp b/src/roctx/roctx.cpp index 9da9195862..c7baf57a32 100644 --- a/src/roctx/roctx.cpp +++ b/src/roctx/roctx.cpp @@ -22,66 +22,72 @@ #include "roctracer_roctx.h" #include "ext/prof_protocol.h" -#include "util/callback_table.h" +#include +#include namespace { -roctracer::util::CallbackTable callbacks; -thread_local int nested_range_level(0); +std::atomic report_activity; +thread_local int nested_range_level{0}; + +void ReportActivity(roctx_api_id_t operation_id, const char* message = nullptr, + roctx_range_id_t id = {}) { + auto function = report_activity.load(std::memory_order_relaxed); + if (!function) return; + + roctx_api_data_t api_data{}; + switch (operation_id) { + case ROCTX_API_ID_roctxMarkA: + api_data.args.roctxMarkA.message = message; + break; + case ROCTX_API_ID_roctxRangePushA: + api_data.args.roctxRangePushA.message = message; + break; + case ROCTX_API_ID_roctxRangePop: + break; + case ROCTX_API_ID_roctxRangeStartA: + api_data.args.roctxRangeStartA.message = message; + api_data.args.roctxRangeStartA.id = id; + break; + case ROCTX_API_ID_roctxRangeStop: + api_data.args.roctxRangeStop.id = id; + break; + default: + assert(!"should not reach here"); + } + function(ACTIVITY_DOMAIN_ROCTX, operation_id, &api_data); +} } // namespace ROCTX_API uint32_t roctx_version_major() { return ROCTX_VERSION_MAJOR; } ROCTX_API uint32_t roctx_version_minor() { return ROCTX_VERSION_MINOR; } -ROCTX_API void roctxMarkA(const char* message) { - roctx_api_data_t api_data{}; - api_data.args.roctxMarkA.message = message; - callbacks.Invoke(ROCTX_API_ID_roctxMarkA, &api_data); -} +ROCTX_API void roctxMarkA(const char* message) { ReportActivity(ROCTX_API_ID_roctxMarkA, message); } ROCTX_API int roctxRangePushA(const char* message) { - roctx_api_data_t api_data{}; - api_data.args.roctxRangePushA.message = message; - callbacks.Invoke(ROCTX_API_ID_roctxRangePushA, &api_data); - + ReportActivity(ROCTX_API_ID_roctxRangePushA, message); return nested_range_level++; } ROCTX_API int roctxRangePop() { - roctx_api_data_t api_data{}; - callbacks.Invoke(ROCTX_API_ID_roctxRangePop, &api_data); - + ReportActivity(ROCTX_API_ID_roctxRangePop); if (nested_range_level == 0) return -1; return --nested_range_level; } ROCTX_API roctx_range_id_t roctxRangeStartA(const char* message) { static std::atomic start_stop_range_id(1); - auto id = start_stop_range_id++; - - roctx_api_data_t api_data{}; - api_data.args.roctxRangeStartA.message = message; - api_data.args.roctxRangeStartA.id = id; - callbacks.Invoke(ROCTX_API_ID_roctxRangeStartA, &api_data); - - return id; + auto range_id = start_stop_range_id++; + ReportActivity(ROCTX_API_ID_roctxRangeStartA, message, range_id); + return range_id; } -ROCTX_API void roctxRangeStop(roctx_range_id_t rangeId) { - roctx_api_data_t api_data{}; - api_data.args.roctxRangeStop.id = rangeId; - callbacks.Invoke(ROCTX_API_ID_roctxRangeStop, &api_data); +ROCTX_API void roctxRangeStop(roctx_range_id_t range_id) { + ReportActivity(ROCTX_API_ID_roctxRangeStop, nullptr, range_id); } -extern "C" ROCTX_EXPORT bool RegisterApiCallback(uint32_t op, void* callback, void* arg) { - if (op >= ROCTX_API_ID_NUMBER) return false; - callbacks.Set(op, reinterpret_cast(callback), arg); - return true; +extern "C" ROCTX_EXPORT void roctxRegisterTracerCallback(const void* function) { + report_activity.store(reinterpret_cast(function), + std::memory_order_relaxed); } - -extern "C" ROCTX_EXPORT bool RemoveApiCallback(uint32_t op) { - if (op >= ROCTX_API_ID_NUMBER) return false; - callbacks.Set(op, nullptr, nullptr); - return true; -} \ No newline at end of file diff --git a/src/util/callback_table.h b/src/util/callback_table.h deleted file mode 100644 index 5165ef549f..0000000000 --- a/src/util/callback_table.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - THE SOFTWARE. */ - -#ifndef UTIL_CALLBACK_TABLE_H_ -#define UTIL_CALLBACK_TABLE_H_ - -#include "ext/prof_protocol.h" - -#include -#include -#include -#include -#include - -namespace roctracer::util { - -// Generic callbacks table -template class CallbackTable { - public: - CallbackTable() - // Zero initialize the callbacks array as the function pointer is used to determine if the - // callback is enabled. - : callbacks_() {} - - void Set(uint32_t callback_id, activity_rtapi_callback_t callback_function, void* user_arg) { - assert(callback_id < N && "callback_id is out of range"); - std::lock_guard lock(mutex_); - auto& callback = callbacks_[callback_id]; - callback.first.store(callback_function, std::memory_order_relaxed); - callback.second = user_arg; - } - - auto Get(uint32_t callback_id) const { - assert(callback_id < N && "id is out of range"); - std::lock_guard lock(mutex_); - auto& callback = callbacks_[callback_id]; - return std::make_pair(callback.first.load(std::memory_order_relaxed), callback.second); - } - - template void Invoke(uint32_t callback_id, Args... args) { - if (callbacks_[callback_id].first.load(std::memory_order_relaxed) == nullptr) return; - if (auto [callback_function, user_arg] = Get(callback_id); callback_function != nullptr) - callback_function(Domain, callback_id, std::forward(args)..., user_arg); - } - - private: - std::array, void*>, N> callbacks_; - mutable std::mutex mutex_; -}; - -} // namespace roctracer::util - -#endif // UTIL_CALLBACK_TABLE_H_