SWDEV-351980 - Consolidate registration tables in the roctracer
Change-Id: I44cd1cc81cf6a529aed89ee8db1377c0aa67f0dc
[ROCm/roctracer commit: 2673bf5e2c]
Этот коммит содержится в:
@@ -41,7 +41,12 @@ inline static std::ostream& operator<<(std::ostream& out, const char& v) {
|
||||
|
||||
#include <roctracer.h>
|
||||
|
||||
enum { HIP_OP_ID_DISPATCH = 0, HIP_OP_ID_COPY = 1, HIP_OP_ID_BARRIER = 2, HIP_OP_ID_NUMBER = 3 };
|
||||
// Identifiers for HIP operation kinds.
// NOTE(review): judging by the enumerator names these are kernel dispatch,
// memory copy and barrier operations -- confirm against the HIP runtime.
typedef enum {
|
||||
HIP_OP_ID_DISPATCH = 0,
|
||||
HIP_OP_ID_COPY = 1,
|
||||
HIP_OP_ID_BARRIER = 2,
|
||||
// One past the last operation ID above, i.e. the number of IDs defined here.
HIP_OP_ID_NUMBER = 3
|
||||
} hip_op_id_t;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
||||
@@ -332,7 +332,6 @@ class API_DescrParser:
|
||||
self.cpp_content += "/* Generated by " + os.path.basename(__file__) + " */\n" + license + "\n\n"
|
||||
|
||||
self.cpp_content += '#include <hsa/hsa_api_trace.h>\n'
|
||||
self.cpp_content += '#include \"util/callback_table.h\"\n\n'
|
||||
self.cpp_content += '#include <atomic>\n'
|
||||
self.cpp_content += 'namespace roctracer::hsa_support::detail {\n'
|
||||
|
||||
@@ -409,40 +408,52 @@ class API_DescrParser:
|
||||
def gen_callbacks(self, n, name, call, struct):
|
||||
content = ''
|
||||
if n == -1:
|
||||
content += 'static util::CallbackTable<ACTIVITY_DOMAIN_HSA_API, HSA_API_ID_NUMBER> cb_table;\n'
|
||||
content += '/* section: Static declarations */\n'
|
||||
content += '\n'
|
||||
if call != '-':
|
||||
call_id = self.api_id[call];
|
||||
ret_type = struct['ret']
|
||||
content += 'static ' + ret_type + ' ' + call + '_callback(' + struct['args'] + ') {\n'
|
||||
content += ' hsa_api_data_t api_data{};\n'
|
||||
|
||||
content += ' hsa_trace_data_t trace_data;\n'
|
||||
content += ' bool enabled{false};\n'
|
||||
content += '\n'
|
||||
content += ' if (auto function = report_activity.load(std::memory_order_relaxed); function &&\n'
|
||||
content += ' (enabled =\n'
|
||||
content += ' function(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &trace_data) == 0)) {\n'
|
||||
content += ' if (trace_data.phase_enter != nullptr) {\n'
|
||||
|
||||
for var in struct['alst']:
|
||||
item = struct['astr'][var];
|
||||
if re.search(r'char\* ', item):
|
||||
content += ' api_data.args.' + call + '.' + var + ' = ' + '(' + var + ' != NULL) ? strdup(' + var + ')' + ' : NULL;\n'
|
||||
# FIXME: we should not strdup the char* arguments here, as the callback will not outlive the scope of this function. Instead, we
|
||||
# should generate a helper function to capture the content of the arguments similar to hipApiArgsInit for HIP. We also need a
|
||||
# helper to free the memory that is allocated to capture the content.
|
||||
content += ' trace_data.api_data.args.' + call + '.' + var + ' = ' + '(' + var + ' != NULL) ? strdup(' + var + ')' + ' : NULL;\n'
|
||||
else:
|
||||
content += ' api_data.args.' + call + '.' + var + ' = ' + var + ';\n'
|
||||
content += ' trace_data.api_data.args.' + call + '.' + var + ' = ' + var + ';\n'
|
||||
if call == 'hsa_amd_memory_async_copy_rect' and var == 'range':
|
||||
content += ' api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n'
|
||||
content += ' auto [ api_callback_fun, api_callback_arg ] = cb_table.Get(' + call_id + ');\n'
|
||||
content += ' if (api_callback_fun) {\n'
|
||||
content += ' api_data.phase = ACTIVITY_API_PHASE_ENTER;\n'
|
||||
content += ' api_data.correlation_id = CorrelationIdPush();\n'
|
||||
content += ' api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n'
|
||||
content += ' trace_data.api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n'
|
||||
|
||||
content += ' trace_data.phase_enter(' + call_id + ', &trace_data);\n'
|
||||
content += ' }\n'
|
||||
content += ' }\n'
|
||||
content += '\n'
|
||||
|
||||
if ret_type != 'void':
|
||||
# FIXME: we should capture the return value and store it in the api_data
|
||||
content += ' ' + ret_type + ' ret ='
|
||||
content += ' ' + name + '_saved_before_cb.' + call + '_fn(' + ', '.join(struct['alst']) + ');\n'
|
||||
content += ' if (api_callback_fun) {\n'
|
||||
if ret_type != 'void':
|
||||
content += ' api_data.' + ret_type + '_retval = ret;\n'
|
||||
content += ' api_data.phase = ACTIVITY_API_PHASE_EXIT;\n'
|
||||
content += ' api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n'
|
||||
content += ' CorrelationIdPop();\n'
|
||||
content += ' }\n'
|
||||
|
||||
content += '\n'
|
||||
content += ' if (enabled && trace_data.phase_exit != nullptr)\n'
|
||||
content += ' trace_data.phase_exit(' + call_id + ', &trace_data);\n'
|
||||
|
||||
if ret_type != 'void':
|
||||
content += '\n'
|
||||
content += ' return ret;\n'
|
||||
content += '}\n'
|
||||
|
||||
return content
|
||||
|
||||
# generate API intercepting code
|
||||
|
||||
@@ -26,7 +26,6 @@
|
||||
#include "memory_pool.h"
|
||||
#include "roctracer.h"
|
||||
#include "roctracer_hsa.h"
|
||||
#include "callback_table.h"
|
||||
|
||||
#include <hsa/hsa.h>
|
||||
#include <hsa/amd_hsa_signal.h>
|
||||
@@ -35,23 +34,32 @@
|
||||
#include <optional>
|
||||
#include <mutex>
|
||||
|
||||
namespace {
|
||||
|
||||
// Tracer hook shared by IsEnabled() and ReportActivity() below. It is written
// by RegisterTracerCallback() and read with relaxed atomic loads; both readers
// treat a null pointer as "no tracer registered".
std::atomic<int (*)(activity_domain_t domain, uint32_t operation_id, void* data)> report_activity;
|
||||
|
||||
// Asks the registered tracer hook whether (domain, operation_id) should be
// reported. The hook is called with a null data pointer, and a return value of
// 0 is taken to mean "enabled".
// Returns false when no hook is registered or the hook returns non-zero.
bool IsEnabled(activity_domain_t domain, uint32_t operation_id) {
|
||||
// Load the hook once so the null check and the call use the same pointer.
auto report = report_activity.load(std::memory_order_relaxed);
|
||||
return report && report(domain, operation_id, nullptr) == 0;
|
||||
}
|
||||
|
||||
// Forwards `data` for (domain, operation_id) to the registered tracer hook.
// Does nothing when no hook is registered; the hook's return value is ignored.
void ReportActivity(activity_domain_t domain, uint32_t operation_id, void* data) {
|
||||
// Load the hook once so the null check and the call use the same pointer.
if (auto report = report_activity.load(std::memory_order_relaxed))
|
||||
report(domain, operation_id, data);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#include "hsa_prof_str.inline.h"
|
||||
|
||||
namespace roctracer::hsa_support {
|
||||
|
||||
namespace {
|
||||
|
||||
util::CallbackTable<ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_NUMBER> hsa_evt_cb_table;
|
||||
|
||||
CoreApiTable saved_core_api{};
|
||||
AmdExtTable saved_amd_ext_api{};
|
||||
hsa_ven_amd_loader_1_01_pfn_t hsa_loader_api{};
|
||||
|
||||
// async copy activity callback
|
||||
std::mutex init_mutex;
|
||||
bool async_copy_callback_enabled = false;
|
||||
MemoryPool* async_copy_callback_memory_pool = nullptr;
|
||||
|
||||
struct AgentInfo {
|
||||
int index;
|
||||
hsa_device_type_t type;
|
||||
@@ -81,7 +89,6 @@ class Tracker {
|
||||
hsa_signal_t orig;
|
||||
hsa_signal_t signal;
|
||||
void (*handler)(const entry_t*);
|
||||
MemoryPool* pool;
|
||||
union {
|
||||
struct {
|
||||
} copy;
|
||||
@@ -182,7 +189,7 @@ hsa_status_t HSA_API MemoryAllocateIntercept(hsa_region_t region, size_t size, v
|
||||
hsa_status_t status = saved_core_api.hsa_memory_allocate_fn(region, size, ptr);
|
||||
if (status != HSA_STATUS_SUCCESS) return status;
|
||||
|
||||
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_ALLOCATE); callback_fun) {
|
||||
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE)) {
|
||||
hsa_evt_data_t data{};
|
||||
data.allocate.ptr = *ptr;
|
||||
data.allocate.size = size;
|
||||
@@ -192,7 +199,7 @@ hsa_status_t HSA_API MemoryAllocateIntercept(hsa_region_t region, size_t size, v
|
||||
&data.allocate.global_flag) != HSA_STATUS_SUCCESS)
|
||||
fatal("hsa_region_get_info failed");
|
||||
|
||||
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data, callback_arg);
|
||||
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data);
|
||||
}
|
||||
|
||||
return HSA_STATUS_SUCCESS;
|
||||
@@ -203,14 +210,14 @@ hsa_status_t MemoryAssignAgentIntercept(void* ptr, hsa_agent_t agent,
|
||||
hsa_status_t status = saved_core_api.hsa_memory_assign_agent_fn(ptr, agent, access);
|
||||
if (status != HSA_STATUS_SUCCESS) return status;
|
||||
|
||||
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_DEVICE); callback_fun) {
|
||||
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE)) {
|
||||
hsa_evt_data_t data{};
|
||||
data.device.ptr = ptr;
|
||||
if (saved_core_api.hsa_agent_get_info_fn(agent, HSA_AGENT_INFO_DEVICE, &data.device.type) !=
|
||||
HSA_STATUS_SUCCESS)
|
||||
fatal("hsa_agent_get_info failed");
|
||||
|
||||
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data, callback_arg);
|
||||
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data);
|
||||
}
|
||||
|
||||
return HSA_STATUS_SUCCESS;
|
||||
@@ -220,13 +227,13 @@ hsa_status_t MemoryCopyIntercept(void* dst, const void* src, size_t size) {
|
||||
hsa_status_t status = saved_core_api.hsa_memory_copy_fn(dst, src, size);
|
||||
if (status != HSA_STATUS_SUCCESS) return status;
|
||||
|
||||
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_MEMCOPY); callback_fun) {
|
||||
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_MEMCOPY)) {
|
||||
hsa_evt_data_t data{};
|
||||
data.memcopy.dst = dst;
|
||||
data.memcopy.src = src;
|
||||
data.memcopy.size = size;
|
||||
|
||||
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_MEMCOPY, &data, callback_arg);
|
||||
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_MEMCOPY, &data);
|
||||
}
|
||||
|
||||
return HSA_STATUS_SUCCESS;
|
||||
@@ -237,7 +244,7 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size
|
||||
hsa_status_t status = saved_amd_ext_api.hsa_amd_memory_pool_allocate_fn(pool, size, flags, ptr);
|
||||
if (size == 0 || status != HSA_STATUS_SUCCESS) return status;
|
||||
|
||||
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_ALLOCATE); callback_fun) {
|
||||
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE)) {
|
||||
hsa_evt_data_t data{};
|
||||
data.allocate.ptr = *ptr;
|
||||
data.allocate.size = size;
|
||||
@@ -249,17 +256,13 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size
|
||||
HSA_STATUS_SUCCESS)
|
||||
fatal("hsa_region_get_info failed");
|
||||
|
||||
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data, callback_arg);
|
||||
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data);
|
||||
}
|
||||
|
||||
if (std::tie(callback_fun, callback_arg) = hsa_evt_cb_table.Get(HSA_EVT_ID_DEVICE);
|
||||
!callback_fun)
|
||||
return HSA_STATUS_SUCCESS;
|
||||
|
||||
// FIXME: Why is this only reported if HSA_EVT_ID_ALLOCATE is also set?
|
||||
auto callback_data = std::make_tuple(callback_fun, callback_arg, pool, ptr);
|
||||
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE)) {
|
||||
auto callback_data = std::make_pair(pool, ptr);
|
||||
auto agent_callback = [](hsa_agent_t agent, void* iterate_agent_callback_data) {
|
||||
auto [callback_fun, callback_arg, pool, ptr] =
|
||||
*reinterpret_cast<decltype(callback_data)*>(iterate_agent_callback_data);
|
||||
auto [pool, ptr] = *reinterpret_cast<decltype(callback_data)*>(iterate_agent_callback_data);
|
||||
|
||||
if (hsa_amd_memory_pool_access_t value;
|
||||
saved_amd_ext_api.hsa_amd_agent_memory_pool_get_info_fn(
|
||||
@@ -276,7 +279,7 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size
|
||||
data.device.agent = agent;
|
||||
data.device.ptr = ptr;
|
||||
|
||||
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data, callback_arg);
|
||||
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data);
|
||||
return HSA_STATUS_SUCCESS;
|
||||
};
|
||||
saved_core_api.hsa_iterate_agents_fn(agent_callback, &callback_data);
|
||||
@@ -286,11 +289,11 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size
|
||||
}
|
||||
|
||||
hsa_status_t MemoryPoolFreeIntercept(void* ptr) {
|
||||
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_ALLOCATE); callback_fun) {
|
||||
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE)) {
|
||||
hsa_evt_data_t data{};
|
||||
data.allocate.ptr = ptr;
|
||||
data.allocate.size = 0;
|
||||
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data, callback_arg);
|
||||
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data);
|
||||
}
|
||||
|
||||
return saved_amd_ext_api.hsa_amd_memory_pool_free_fn(ptr);
|
||||
@@ -303,7 +306,7 @@ hsa_status_t AgentsAllowAccessIntercept(uint32_t num_agents, const hsa_agent_t*
|
||||
saved_amd_ext_api.hsa_amd_agents_allow_access_fn(num_agents, agents, flags, ptr);
|
||||
if (status != HSA_STATUS_SUCCESS) return status;
|
||||
|
||||
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_DEVICE); callback_fun) {
|
||||
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE)) {
|
||||
while (num_agents--) {
|
||||
hsa_agent_t agent = *agents++;
|
||||
auto it = agent_info_map.find(agent.handle);
|
||||
@@ -315,7 +318,7 @@ hsa_status_t AgentsAllowAccessIntercept(uint32_t num_agents, const hsa_agent_t*
|
||||
data.device.agent = agent;
|
||||
data.device.ptr = ptr;
|
||||
|
||||
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data, callback_arg);
|
||||
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data);
|
||||
}
|
||||
}
|
||||
return HSA_STATUS_SUCCESS;
|
||||
@@ -329,7 +332,6 @@ struct CodeObjectCallbackArg {
|
||||
|
||||
hsa_status_t CodeObjectCallback(hsa_executable_t executable,
|
||||
hsa_loaded_code_object_t loaded_code_object, void* arg) {
|
||||
auto* code_object_callback_arg = static_cast<CodeObjectCallbackArg*>(arg);
|
||||
hsa_evt_data_t data{};
|
||||
|
||||
if (hsa_loader_api.hsa_ven_amd_loader_loaded_code_object_get_info(
|
||||
@@ -384,9 +386,8 @@ hsa_status_t CodeObjectCallback(hsa_executable_t executable,
|
||||
fatal("hsa_ven_amd_loader_loaded_code_object_get_info failed");
|
||||
|
||||
data.codeobj.uri = uri_str.c_str();
|
||||
data.codeobj.unload = code_object_callback_arg->unload ? 1 : 0;
|
||||
code_object_callback_arg->callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ, &data,
|
||||
code_object_callback_arg->callback_arg);
|
||||
data.codeobj.unload = *static_cast<bool*>(arg) ? 1 : 0;
|
||||
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ, &data);
|
||||
|
||||
return HSA_STATUS_SUCCESS;
|
||||
}
|
||||
@@ -395,20 +396,20 @@ hsa_status_t ExecutableFreezeIntercept(hsa_executable_t executable, const char*
|
||||
hsa_status_t status = saved_core_api.hsa_executable_freeze_fn(executable, options);
|
||||
if (status != HSA_STATUS_SUCCESS) return status;
|
||||
|
||||
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_CODEOBJ); callback_fun) {
|
||||
CodeObjectCallbackArg arg = {callback_fun, callback_arg, false};
|
||||
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ)) {
|
||||
bool unload = false;
|
||||
hsa_loader_api.hsa_ven_amd_loader_executable_iterate_loaded_code_objects(
|
||||
executable, CodeObjectCallback, &arg);
|
||||
executable, CodeObjectCallback, &unload);
|
||||
}
|
||||
|
||||
return HSA_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
hsa_status_t ExecutableDestroyIntercept(hsa_executable_t executable) {
|
||||
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_CODEOBJ); callback_fun) {
|
||||
CodeObjectCallbackArg arg = {callback_fun, callback_arg, true};
|
||||
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ)) {
|
||||
bool unload = true;
|
||||
hsa_loader_api.hsa_ven_amd_loader_executable_iterate_loaded_code_objects(
|
||||
executable, CodeObjectCallback, &arg);
|
||||
executable, CodeObjectCallback, &unload);
|
||||
}
|
||||
|
||||
return saved_core_api.hsa_executable_destroy_fn(executable);
|
||||
@@ -422,25 +423,31 @@ void MemoryASyncCopyHandler(const Tracker::entry_t* entry) {
|
||||
record.end_ns = entry->end;
|
||||
record.device_id = 0;
|
||||
record.correlation_id = entry->correlation_id;
|
||||
entry->pool->Write(record);
|
||||
ReportActivity(ACTIVITY_DOMAIN_HSA_OPS, HSA_OP_ID_COPY, &record);
|
||||
}
|
||||
|
||||
hsa_status_t MemoryASyncCopyIntercept(void* dst, hsa_agent_t dst_agent, const void* src,
|
||||
hsa_agent_t src_agent, size_t size, uint32_t num_dep_signals,
|
||||
const hsa_signal_t* dep_signals,
|
||||
hsa_signal_t completion_signal) {
|
||||
if (!async_copy_callback_enabled) {
|
||||
bool is_enabled = IsEnabled(ACTIVITY_DOMAIN_HSA_OPS, HSA_OP_ID_COPY);
|
||||
|
||||
// FIXME: what happens if the state changes before returning?
|
||||
[[maybe_unused]] hsa_status_t status =
|
||||
saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(is_enabled);
|
||||
assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
|
||||
|
||||
if (!is_enabled) {
|
||||
return saved_amd_ext_api.hsa_amd_memory_async_copy_fn(
|
||||
dst, dst_agent, src, src_agent, size, num_dep_signals, dep_signals, completion_signal);
|
||||
}
|
||||
|
||||
Tracker::entry_t* entry = new Tracker::entry_t();
|
||||
entry->handler = MemoryASyncCopyHandler;
|
||||
entry->pool = async_copy_callback_memory_pool;
|
||||
entry->correlation_id = CorrelationId();
|
||||
Tracker::Enable(Tracker::COPY_ENTRY_TYPE, hsa_agent_t{}, completion_signal, entry);
|
||||
|
||||
hsa_status_t status = saved_amd_ext_api.hsa_amd_memory_async_copy_fn(
|
||||
status = saved_amd_ext_api.hsa_amd_memory_async_copy_fn(
|
||||
dst, dst_agent, src, src_agent, size, num_dep_signals, dep_signals, entry->signal);
|
||||
if (status != HSA_STATUS_SUCCESS) Tracker::Disable(entry);
|
||||
|
||||
@@ -454,7 +461,14 @@ hsa_status_t MemoryASyncCopyRectIntercept(const hsa_pitched_ptr_t* dst,
|
||||
hsa_agent_t copy_agent, hsa_amd_copy_direction_t dir,
|
||||
uint32_t num_dep_signals, const hsa_signal_t* dep_signals,
|
||||
hsa_signal_t completion_signal) {
|
||||
if (!async_copy_callback_enabled) {
|
||||
bool is_enabled = IsEnabled(ACTIVITY_DOMAIN_HSA_OPS, HSA_OP_ID_COPY);
|
||||
|
||||
// FIXME: what happens if the state changes before returning?
|
||||
[[maybe_unused]] hsa_status_t status =
|
||||
saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(is_enabled);
|
||||
assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
|
||||
|
||||
if (!is_enabled) {
|
||||
return saved_amd_ext_api.hsa_amd_memory_async_copy_rect_fn(
|
||||
dst, dst_offset, src, src_offset, range, copy_agent, dir, num_dep_signals, dep_signals,
|
||||
completion_signal);
|
||||
@@ -462,11 +476,10 @@ hsa_status_t MemoryASyncCopyRectIntercept(const hsa_pitched_ptr_t* dst,
|
||||
|
||||
Tracker::entry_t* entry = new Tracker::entry_t();
|
||||
entry->handler = MemoryASyncCopyHandler;
|
||||
entry->pool = async_copy_callback_memory_pool;
|
||||
entry->correlation_id = CorrelationId();
|
||||
Tracker::Enable(Tracker::COPY_ENTRY_TYPE, hsa_agent_t{}, completion_signal, entry);
|
||||
|
||||
hsa_status_t status = saved_amd_ext_api.hsa_amd_memory_async_copy_rect_fn(
|
||||
status = saved_amd_ext_api.hsa_amd_memory_async_copy_rect_fn(
|
||||
dst, dst_offset, src, src_offset, range, copy_agent, dir, num_dep_signals, dep_signals,
|
||||
entry->signal);
|
||||
if (status != HSA_STATUS_SUCCESS) Tracker::Disable(entry);
|
||||
@@ -502,8 +515,6 @@ roctracer_timestamp_t timestamp_ns() {
|
||||
}
|
||||
|
||||
void Initialize(HsaApiTable* table) {
|
||||
std::scoped_lock lock(init_mutex);
|
||||
|
||||
// Save the HSA core api and amd_ext api.
|
||||
saved_core_api = *table->core_;
|
||||
saved_amd_ext_api = *table->amd_ext_;
|
||||
@@ -558,20 +569,12 @@ void Initialize(HsaApiTable* table) {
|
||||
detail::InstallCoreApiWrappers(table->core_);
|
||||
detail::InstallAmdExtWrappers(table->amd_ext_);
|
||||
detail::InstallImageExtWrappers(table->image_ext_);
|
||||
|
||||
if (async_copy_callback_enabled) {
|
||||
[[maybe_unused]] hsa_status_t status =
|
||||
saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(true);
|
||||
assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
|
||||
}
|
||||
}
|
||||
|
||||
void Finalize() {
|
||||
if (hsa_support::async_copy_callback_enabled) {
|
||||
[[maybe_unused]] hsa_status_t status =
|
||||
hsa_support::saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(false);
|
||||
assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
|
||||
}
|
||||
if (hsa_status_t status = saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(false);
|
||||
status != HSA_STATUS_SUCCESS)
|
||||
assert(!"hsa_amd_profiling_async_copy_enable failed");
|
||||
}
|
||||
|
||||
const char* GetApiName(uint32_t id) { return detail::GetApiName(id); }
|
||||
@@ -612,111 +615,9 @@ const char* GetOpsName(uint32_t id) {
|
||||
|
||||
uint32_t GetApiCode(const char* str) { return detail::GetApiCode(str); }
|
||||
|
||||
void EnableActivity(roctracer_domain_t domain, uint32_t op, roctracer_pool_t* pool) {
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
if (op == HSA_OP_ID_COPY) {
|
||||
std::scoped_lock lock(init_mutex);
|
||||
|
||||
if (saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn != nullptr) {
|
||||
[[maybe_unused]] hsa_status_t status =
|
||||
saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(true);
|
||||
assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
|
||||
}
|
||||
async_copy_callback_enabled = true;
|
||||
async_copy_callback_memory_pool = reinterpret_cast<MemoryPool*>(pool);
|
||||
} else if (op == HSA_OP_ID_RESERVED1) {
|
||||
/* Place holder for PC sampling. */
|
||||
} else {
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_NOT_IMPLEMENTED,
|
||||
"HSA OPS operation ID(" << op << ") is not currently implemented");
|
||||
}
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
// FIXME: Add HSA api activities.
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void EnableCallback(roctracer_domain_t domain, uint32_t cid, roctracer_rtapi_callback_t callback,
|
||||
void* user_data) {
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
if (cid >= HSA_API_ID_NUMBER)
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT,
|
||||
"invalid HSA API operation ID(" << cid << ")");
|
||||
|
||||
detail::cb_table.Set(cid, callback, user_data);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
if (cid >= HSA_EVT_ID_NUMBER)
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT,
|
||||
"invalid HSA API operation ID(" << cid << ")");
|
||||
|
||||
hsa_evt_cb_table.Set(cid, callback, user_data);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void DisableActivity(roctracer_domain_t domain, uint32_t op) {
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
if (op == HSA_OP_ID_COPY) {
|
||||
std::scoped_lock lock(init_mutex);
|
||||
|
||||
async_copy_callback_enabled = false;
|
||||
async_copy_callback_memory_pool = nullptr;
|
||||
|
||||
if (saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn != nullptr) {
|
||||
[[maybe_unused]] hsa_status_t status =
|
||||
saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(false);
|
||||
assert(status == HSA_STATUS_SUCCESS || status == HSA_STATUS_ERROR_NOT_INITIALIZED ||
|
||||
!"hsa_amd_profiling_async_copy_enable failed");
|
||||
}
|
||||
} else if (op == HSA_OP_ID_RESERVED1) {
|
||||
/* Place holder for PC sampling. */
|
||||
} else {
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_NOT_IMPLEMENTED,
|
||||
"HSA OPS operation ID(" << op << ") is not currently implemented");
|
||||
}
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
// FIXME: Add HSA api activities.
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void DisableCallback(roctracer_domain_t domain, uint32_t cid) {
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
if (cid >= HSA_API_ID_NUMBER)
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT,
|
||||
"invalid HSA API operation ID(" << cid << ")");
|
||||
detail::cb_table.Set(cid, nullptr, nullptr);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
if (cid >= HSA_EVT_ID_NUMBER)
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT,
|
||||
"invalid HSA EVT operation ID(" << cid << ")");
|
||||
hsa_evt_cb_table.Set(cid, nullptr, nullptr);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
// Installs `function` as the tracer hook stored in report_activity, replacing
// any previously stored value. Passing nullptr leaves the hook unset, in which
// case IsEnabled() returns false and ReportActivity() is a no-op.
// The store is relaxed: it is atomic but imposes no ordering on other memory
// operations.
void RegisterTracerCallback(int (*function)(activity_domain_t domain, uint32_t operation_id,
|
||||
void* data)) {
|
||||
report_activity.store(function, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
} // namespace roctracer::hsa_support
|
||||
|
||||
@@ -28,6 +28,15 @@
|
||||
|
||||
namespace roctracer::hsa_support {
|
||||
|
||||
// Per-call record exchanged between an HSA API wrapper and the tracer.
// NOTE(review): the generated wrappers appear to pass a pointer to this struct
// to the registered tracer callback and then invoke phase_enter / phase_exit
// when they are non-null -- confirm against the generated hsa_prof_str code.
struct hsa_trace_data_t {
|
||||
// API call data (arguments etc.) for the traced call.
hsa_api_data_t api_data;
|
||||
// NOTE(review): presumably the timestamp taken when the enter phase ran;
// no use is visible here -- confirm with the tracer implementation.
uint64_t phase_enter_timestamp;
|
||||
// NOTE(review): meaning not visible here; looks like tracer-owned scratch
// data carried between the two phases -- confirm.
uint64_t phase_data;
|
||||
|
||||
// Optional hook for the enter phase of the call; receives the API ID and
// this record.
void (*phase_enter)(hsa_api_id_t operation_id, hsa_trace_data_t* data);
|
||||
// Optional hook for the exit phase of the call; same parameters as phase_enter.
void (*phase_exit)(hsa_api_id_t operation_id, hsa_trace_data_t* data);
|
||||
};
|
||||
|
||||
void Initialize(HsaApiTable* table);
|
||||
void Finalize();
|
||||
|
||||
@@ -36,13 +45,8 @@ const char* GetEvtName(uint32_t id);
|
||||
const char* GetOpsName(uint32_t id);
|
||||
uint32_t GetApiCode(const char* str);
|
||||
|
||||
void EnableActivity(roctracer_domain_t domain, uint32_t op, roctracer_pool_t* pool);
|
||||
void EnableCallback(roctracer_domain_t domain, uint32_t cid, roctracer_rtapi_callback_t callback,
|
||||
void* user_data);
|
||||
|
||||
void DisableCallback(roctracer_domain_t domain, uint32_t cid);
|
||||
void DisableActivity(roctracer_domain_t domain, uint32_t op);
|
||||
|
||||
void RegisterTracerCallback(int (*function)(activity_domain_t domain, uint32_t operation_id,
|
||||
void* data));
|
||||
uint64_t timestamp_ns();
|
||||
|
||||
} // namespace roctracer::hsa_support
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE. */
|
||||
|
||||
#ifndef SRC_CORE_JOURNAL_H_
|
||||
#define SRC_CORE_JOURNAL_H_
|
||||
|
||||
#include "ext/prof_protocol.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace roctracer {
|
||||
|
||||
// Mutex-protected two-level map from (domain, op) to a value of type Data.
// Every public member function holds mutex_ for its whole duration.
template <typename Data> class Journal {
|
||||
public:
|
||||
/* Insert { domain, op } into the journal. Return false if the insertion
|
||||
updated an existing entry. */
|
||||
// Only participates in overload resolution when Data is constructible from T.
template <typename T = Data, std::enable_if_t<std::is_constructible_v<Data, T>, int> = 0>
|
||||
bool Insert(roctracer_domain_t domain, uint32_t op, T&& data) {
|
||||
std::lock_guard lock(mutex_);
|
||||
// try_emplace leaves `data` untouched when the key already exists, so
// forwarding it a second time on the next line is safe.
auto result = map_[domain].try_emplace(op, std::forward<T>(data));
|
||||
// Key was already present: overwrite the stored value instead.
if (!result.second) result.first->second = std::forward<T>(data);
|
||||
return result.second;
|
||||
}
|
||||
|
||||
/* Remove { domain, op } from the journal. Return false if the entry did not exist. */
|
||||
bool Remove(roctracer_domain_t domain, uint32_t op) {
|
||||
std::lock_guard lock(mutex_);
|
||||
// operator[] creates an empty inner map for `domain` if there was none;
// that empty map stays in map_ afterwards.
return map_[domain].erase(op) == 1;
|
||||
}
|
||||
|
||||
// Calls func(domain, op, data) for the stored entries while holding mutex_.
// Because the mutex is not recursive, `func` must not call back into
// Insert/Remove/ForEach on the same Journal.
template <typename Functor> void ForEach(Functor&& func) {
|
||||
std::lock_guard lock(mutex_);
|
||||
for (auto&& domain : map_)
|
||||
for (auto&& operation : domain.second)
|
||||
// A false result from func stops only the current domain's operations:
// the break below exits the inner loop and the outer loop moves on to
// the next domain.
if (!func(domain.first /* domain */, operation.first /* op */, operation.second /* data */))
|
||||
break; /* FIXME: what are we breaking out of? */
|
||||
}
|
||||
|
||||
private:
|
||||
// Guards every access to map_.
std::mutex mutex_;
|
||||
// domain -> (operation id -> Data).
std::unordered_map<roctracer_domain_t, std::unordered_map<uint32_t, Data>> map_;
|
||||
};
|
||||
|
||||
} // namespace roctracer
|
||||
|
||||
#endif // SRC_CORE_JOURNAL_H_
|
||||
@@ -122,13 +122,6 @@ __attribute__((weak)) const char* hipKernelNameRefByPtr(const void* hostFunction
|
||||
__attribute__((weak)) int hipGetStreamDeviceId(hipStream_t stream) { return 0; }
|
||||
__attribute__((weak)) const char* hipApiName(uint32_t id) { return NULL; }
|
||||
|
||||
__attribute__((weak)) hipError_t hipRegisterAsyncActivityCallback_t(uint32_t op, void* fun,
|
||||
void* arg) {
|
||||
return hipErrorUnknown;
|
||||
}
|
||||
__attribute__((weak)) hipError_t hipRemoveAsyncActivityCallback_t(uint32_t op) {
|
||||
return hipErrorUnknown;
|
||||
}
|
||||
__attribute__((weak)) const char* hipGetCmdName(unsigned op) { return NULL; }
|
||||
|
||||
class HipLoaderStatic {
|
||||
@@ -158,12 +151,8 @@ class HipLoaderStatic {
|
||||
GetStreamDeviceId_t* GetStreamDeviceId;
|
||||
ApiName_t* ApiName;
|
||||
|
||||
typedef hipError_t(hipRegisterAsyncActivityCallback_t)(uint32_t op, void* fun, void* arg);
|
||||
typedef hipError_t(hipRemoveAsyncActivityCallback_t)(uint32_t op);
|
||||
typedef const char*(hipGetOpName_t)(unsigned op);
|
||||
|
||||
hipRegisterAsyncActivityCallback_t* RegisterActivityCallback;
|
||||
hipRemoveAsyncActivityCallback_t* RemoveActivityCallback;
|
||||
hipGetOpName_t* GetOpName;
|
||||
|
||||
static inline loader_t& Instance() {
|
||||
@@ -191,8 +180,6 @@ class HipLoaderStatic {
|
||||
GetStreamDeviceId = hipGetStreamDeviceId;
|
||||
ApiName = hipApiName;
|
||||
|
||||
RegisterAsyncActivityCallback = hipRegisterAsyncActivityCallback;
|
||||
RemoveAsyncActivityCallback = hipRemoveAsyncActivityCallback;
|
||||
GetOpName = hipGetCmdName;
|
||||
}
|
||||
|
||||
@@ -204,53 +191,36 @@ class HipApi {
|
||||
public:
|
||||
typedef BaseLoader<HipApi> Loader;
|
||||
|
||||
typedef decltype(hipRegisterApiCallback) RegisterApiCallback_t;
|
||||
typedef decltype(hipRemoveApiCallback) RemoveApiCallback_t;
|
||||
typedef decltype(hipRegisterActivityCallback) RegisterActivityCallback_t;
|
||||
typedef decltype(hipRemoveActivityCallback) RemoveActivityCallback_t;
|
||||
typedef decltype(hipKernelNameRef) KernelNameRef_t;
|
||||
typedef decltype(hipKernelNameRefByPtr) KernelNameRefByPtr_t;
|
||||
typedef decltype(hipGetStreamDeviceId) GetStreamDeviceId_t;
|
||||
typedef decltype(hipApiName) ApiName_t;
|
||||
typedef int(hipGetStreamDeviceId_t)(hipStream_t stream);
|
||||
typedef const char*(hipKernelNameRef_t)(const hipFunction_t function);
|
||||
typedef const char*(hipKernelNameRefByPtr_t)(const void* host_function, hipStream_t stream);
|
||||
typedef const char*(hipApiName_t)(uint32_t id);
|
||||
typedef const char*(hipGetCmdName_t)(uint32_t op);
|
||||
typedef void(hipRegisterTracerCallback_t)(int (*function)(activity_domain_t domain,
|
||||
uint32_t operation_id, void* data));
|
||||
|
||||
RegisterApiCallback_t* RegisterApiCallback;
|
||||
RemoveApiCallback_t* RemoveApiCallback;
|
||||
RegisterActivityCallback_t* RegisterActivityCallback;
|
||||
RemoveActivityCallback_t* RemoveActivityCallback;
|
||||
KernelNameRef_t* KernelNameRef;
|
||||
KernelNameRefByPtr_t* KernelNameRefByPtr_;
|
||||
hipKernelNameRef_t* KernelNameRef;
|
||||
const char* KernelNameRefByPtr(const void* function, hipStream_t stream = nullptr) const {
|
||||
return KernelNameRefByPtr_(function, stream);
|
||||
}
|
||||
GetStreamDeviceId_t* GetStreamDeviceId;
|
||||
ApiName_t* ApiName;
|
||||
|
||||
typedef hipError_t(hipRegisterAsyncActivityCallback_t)(uint32_t op, void* fun, void* arg);
|
||||
typedef hipError_t(hipRemoveAsyncActivityCallback_t)(uint32_t op);
|
||||
typedef const char*(hipGetOpName_t)(unsigned op);
|
||||
|
||||
hipRegisterAsyncActivityCallback_t* RegisterAsyncActivityCallback;
|
||||
hipRemoveAsyncActivityCallback_t* RemoveAsyncActivityCallback;
|
||||
hipGetOpName_t* GetOpName;
|
||||
hipGetStreamDeviceId_t* GetStreamDeviceId;
|
||||
hipGetCmdName_t* GetOpName;
|
||||
hipApiName_t* ApiName;
|
||||
hipRegisterTracerCallback_t* RegisterTracerCallback;
|
||||
|
||||
protected:
|
||||
void init(Loader* loader) {
|
||||
RegisterApiCallback = loader->GetFun<RegisterApiCallback_t>("hipRegisterApiCallback");
|
||||
RemoveApiCallback = loader->GetFun<RemoveApiCallback_t>("hipRemoveApiCallback");
|
||||
RegisterActivityCallback =
|
||||
loader->GetFun<RegisterActivityCallback_t>("hipRegisterActivityCallback");
|
||||
RemoveActivityCallback = loader->GetFun<RemoveActivityCallback_t>("hipRemoveActivityCallback");
|
||||
KernelNameRef = loader->GetFun<KernelNameRef_t>("hipKernelNameRef");
|
||||
KernelNameRefByPtr_ = loader->GetFun<KernelNameRefByPtr_t>("hipKernelNameRefByPtr");
|
||||
GetStreamDeviceId = loader->GetFun<GetStreamDeviceId_t>("hipGetStreamDeviceId");
|
||||
ApiName = loader->GetFun<ApiName_t>("hipApiName");
|
||||
|
||||
RegisterAsyncActivityCallback =
|
||||
loader->GetFun<hipRegisterAsyncActivityCallback_t>("hipRegisterAsyncActivityCallback");
|
||||
RemoveAsyncActivityCallback =
|
||||
loader->GetFun<hipRemoveAsyncActivityCallback_t>("hipRemoveAsyncActivityCallback");
|
||||
GetOpName = loader->GetFun<hipGetOpName_t>("hipGetCmdName");
|
||||
GetStreamDeviceId = loader->GetFun<hipGetStreamDeviceId_t>("hipGetStreamDeviceId");
|
||||
KernelNameRef = loader->GetFun<hipKernelNameRef_t>("hipKernelNameRef");
|
||||
KernelNameRefByPtr_ = loader->GetFun<hipKernelNameRefByPtr_t>("hipKernelNameRefByPtr");
|
||||
GetOpName = loader->GetFun<hipGetCmdName_t>("hipGetCmdName");
|
||||
ApiName = loader->GetFun<hipApiName_t>("hipApiName");
|
||||
RegisterTracerCallback =
|
||||
loader->GetFun<hipRegisterTracerCallback_t>("hipRegisterTracerCallback");
|
||||
}
|
||||
|
||||
private:
|
||||
hipKernelNameRefByPtr_t* KernelNameRefByPtr_;
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -260,16 +230,14 @@ class RocTxApi {
|
||||
public:
|
||||
typedef BaseLoader<RocTxApi> Loader;
|
||||
|
||||
typedef bool(RegisterApiCallback_t)(uint32_t op, void* callback, void* arg);
|
||||
typedef bool(RemoveApiCallback_t)(uint32_t op);
|
||||
|
||||
RegisterApiCallback_t* RegisterApiCallback;
|
||||
RemoveApiCallback_t* RemoveApiCallback;
|
||||
typedef void(roctxRegisterTracerCallback_t)(int (*function)(activity_domain_t domain,
|
||||
uint32_t operation_id, void* data));
|
||||
roctxRegisterTracerCallback_t* RegisterTracerCallback;
|
||||
|
||||
protected:
|
||||
void init(Loader* loader) {
|
||||
RegisterApiCallback = loader->GetFun<RegisterApiCallback_t>("RegisterApiCallback");
|
||||
RemoveApiCallback = loader->GetFun<RemoveApiCallback_t>("RemoveApiCallback");
|
||||
RegisterTracerCallback =
|
||||
loader->GetFun<roctxRegisterTracerCallback_t>("roctxRegisterTracerCallback");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -278,8 +246,7 @@ typedef BaseLoader<RocTxApi> RocTxLoader;
|
||||
#if STATIC_BUILD
|
||||
typedef HipLoaderStatic HipLoader;
|
||||
#else
|
||||
typedef BaseLoader<HipApi> HipLoaderShared;
|
||||
typedef HipLoaderShared HipLoader;
|
||||
using HipLoader = BaseLoader<HipApi>;
|
||||
#endif
|
||||
|
||||
} // namespace roctracer
|
||||
@@ -299,7 +266,7 @@ typedef HipLoaderShared HipLoader;
|
||||
roctracer::HipLoaderStatic::instance_t roctracer::HipLoaderStatic::instance_{};
|
||||
#else
|
||||
#define LOADER_INSTANTIATE_HIP() \
|
||||
template <> const char* roctracer::HipLoaderShared::lib_name_ = "libamdhip64.so";
|
||||
template <> const char* roctracer::HipLoader::lib_name_ = "libamdhip64.so";
|
||||
#endif
|
||||
|
||||
#define LOADER_INSTANTIATE() \
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE. */
|
||||
|
||||
#ifndef UTIL_CALLBACK_TABLE_H_
|
||||
#define UTIL_CALLBACK_TABLE_H_
|
||||
|
||||
#include "ext/prof_protocol.h"
|
||||
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
|
||||
#include <shared_mutex>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
|
||||
namespace roctracer::util {

namespace detail {
// Default "is tracing stopped?" predicate: always answers "no".
struct False {
  // const-qualified so the predicate can also be invoked through a const object.
  constexpr bool operator()() const noexcept { return false; }
};
}  // namespace detail

// Fixed-size, thread-safe table holding at most one registration of type T for
// each operation ID in the range [0, N).
//
// IsStopped is a default-constructible predicate. While IsStopped{}() returns
// true, Get() reports every slot as empty; the registrations themselves are kept.
template <typename T, uint32_t N, typename IsStopped = detail::False> class RegistrationTable {
 public:
  // Stores T{args...} in the slot of `operation_id`, replacing any previous
  // registration for that ID.
  //
  // Throws std::out_of_range when `operation_id` is not below N. The check is
  // performed in every build type so that an invalid ID can never index past the
  // end of the table.
  template <typename... Args> void Register(uint32_t operation_id, Args... args) {
    if (operation_id >= N) throw std::out_of_range("operation_id is out of range");

    // Build the value before touching the slot: if T's constructor throws, the
    // slot and the registration count are left exactly as they were.
    T value{std::forward<Args>(args)...};

    auto& entry = table_[operation_id];
    std::unique_lock lock(entry.mutex);
    // Only the transition disabled -> enabled changes the number of registrations.
    if (!entry.enabled.exchange(true, std::memory_order_relaxed))
      registered_count_.fetch_add(1, std::memory_order_relaxed);
    entry.data = std::move(value);
  }

  // Removes the registration for `operation_id`, if there is one. An ID outside
  // [0, N) can never hold a registration, so it is ignored.
  void Unregister(uint32_t operation_id) {
    if (operation_id >= N) return;

    auto& entry = table_[operation_id];
    std::unique_lock lock(entry.mutex);
    // Only the transition enabled -> disabled changes the number of registrations.
    if (entry.enabled.exchange(false, std::memory_order_relaxed))
      registered_count_.fetch_sub(1, std::memory_order_relaxed);
  }

  // Returns a copy of the value registered for `operation_id`, or std::nullopt
  // when nothing is registered, the ID is outside [0, N), or IsStopped{}() is true.
  std::optional<T> Get(uint32_t operation_id) const {
    if (operation_id >= N) return std::nullopt;

    const auto& entry = table_[operation_id];
    // Cheap check without the lock: most lookups hit a disabled slot.
    if (!entry.enabled.load(std::memory_order_relaxed) || IsStopped{}()) return std::nullopt;

    std::shared_lock lock(entry.mutex);
    // Re-check under the lock: the slot may have been unregistered in between.
    if (!entry.enabled.load(std::memory_order_relaxed)) return std::nullopt;
    return entry.data;
  }

  // True when no operation ID currently has a registration.
  bool IsEmpty() const { return registered_count_.load(std::memory_order_relaxed) == 0; }

 private:
  // One slot per operation ID.
  struct Entry {
    std::atomic<bool> enabled{false};
    mutable std::shared_mutex mutex;  // guards `data`; writers also flip `enabled` under it
    T data;
  };

  std::atomic<size_t> registered_count_{0};  // number of slots with enabled == true
  Entry table_[N]{};
};

}  // namespace roctracer::util
|
||||
|
||||
#endif // UTIL_CALLBACK_TABLE_H_
|
||||
@@ -34,17 +34,18 @@
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <stack>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "correlation_id.h"
|
||||
#include "debug.h"
|
||||
#include "journal.h"
|
||||
#include "loader.h"
|
||||
#include "hsa_support.h"
|
||||
#include "memory_pool.h"
|
||||
#include "exception.h"
|
||||
#include "hsa_support.h"
|
||||
#include "loader.h"
|
||||
#include "logger.h"
|
||||
#include "memory_pool.h"
|
||||
#include "registration_table.h"
|
||||
|
||||
#define API_METHOD_PREFIX \
|
||||
roctracer_status_t err = ROCTRACER_STATUS_SUCCESS; \
|
||||
@@ -74,126 +75,38 @@ static inline uint32_t GetTid() {
|
||||
return tid;
|
||||
}
|
||||
|
||||
using namespace roctracer;
|
||||
|
||||
namespace {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Internal library methods
|
||||
//
|
||||
namespace roctracer {
|
||||
|
||||
namespace ext_support {
|
||||
roctracer_start_cb_t roctracer_start_cb = nullptr;
|
||||
roctracer_stop_cb_t roctracer_stop_cb = nullptr;
|
||||
} // namespace ext_support
|
||||
|
||||
struct CallbackJournalData {
|
||||
roctracer_rtapi_callback_t callback;
|
||||
void* user_data;
|
||||
};
|
||||
static Journal<CallbackJournalData> cb_journal;
|
||||
|
||||
struct ActivityJournalData {
|
||||
roctracer_pool_t* pool;
|
||||
};
|
||||
static Journal<ActivityJournalData> act_journal;
|
||||
|
||||
// Translates a caught exception into a roctracer status code: an ApiError
// reports its own status, any other exception maps to ROCTRACER_STATUS_ERROR.
roctracer_status_t GetExcStatus(const std::exception& e) {
  if (const auto* api_error = dynamic_cast<const ApiError*>(&e)) {
    return api_error->status();
  }
  return ROCTRACER_STATUS_ERROR;
}
|
||||
|
||||
std::mutex hip_activity_mutex;
|
||||
std::mutex registration_mutex;
|
||||
|
||||
enum { API_CB_MASK = 0x1, API_ACT_MASK = 0x2 };
|
||||
// Memory pool routines and primitives
|
||||
std::recursive_mutex memory_pool_mutex;
|
||||
MemoryPool* default_memory_pool = nullptr;
|
||||
|
||||
class HIPActivityCallbackTracker {
|
||||
public:
|
||||
uint32_t enable_check(uint32_t op, uint32_t mask) { return data_[op] |= mask; }
|
||||
uint32_t disable_check(uint32_t op, uint32_t mask) { return data_[op] &= ~mask; }
|
||||
} // namespace
|
||||
|
||||
private:
|
||||
std::unordered_map<uint32_t, uint32_t> data_;
|
||||
};
|
||||
|
||||
static HIPActivityCallbackTracker hip_act_cb_tracker;
|
||||
|
||||
void HIP_ApiCallback(uint32_t op_id, roctracer_record_t* record, void* callback_data, void* arg) {
|
||||
hip_api_data_t* data = static_cast<hip_api_data_t*>(callback_data);
|
||||
MemoryPool* pool = static_cast<MemoryPool*>(arg);
|
||||
|
||||
if (data->phase == ACTIVITY_API_PHASE_ENTER) {
|
||||
// Generate a new correlation ID.
|
||||
uint64_t correlation_id = CorrelationIdPush();
|
||||
data->correlation_id = correlation_id;
|
||||
|
||||
if (pool != nullptr) {
|
||||
// Filing record info
|
||||
record->domain = ACTIVITY_DOMAIN_HIP_API;
|
||||
record->kind = 0;
|
||||
record->op = op_id;
|
||||
record->process_id = GetPid();
|
||||
record->thread_id = GetTid();
|
||||
record->begin_ns = hsa_support::timestamp_ns();
|
||||
record->correlation_id = correlation_id;
|
||||
}
|
||||
} else {
|
||||
if (pool != nullptr) {
|
||||
record->end_ns = hsa_support::timestamp_ns();
|
||||
|
||||
if (auto external_id = ExternalCorrelationId()) {
|
||||
roctracer_record_t ext_record{};
|
||||
ext_record.domain = ACTIVITY_DOMAIN_EXT_API;
|
||||
ext_record.op = ACTIVITY_EXT_OP_EXTERN_ID;
|
||||
ext_record.correlation_id = record->correlation_id;
|
||||
ext_record.external_id = *external_id;
|
||||
// Write the external correlation id record directly followed by the activity record.
|
||||
pool->Write(std::array<roctracer_record_t, 2>{ext_record, *record});
|
||||
} else {
|
||||
// Write record to the buffer.
|
||||
pool->Write(*record);
|
||||
}
|
||||
}
|
||||
CorrelationIdPop();
|
||||
}
|
||||
}
|
||||
|
||||
void HIP_AsyncActivityCallback(uint32_t op_id, void* record_ptr, void* arg) {
|
||||
MemoryPool* pool = reinterpret_cast<MemoryPool*>(arg);
|
||||
roctracer_record_t& record = *reinterpret_cast<roctracer_record_t*>(record_ptr);
|
||||
record.domain = ACTIVITY_DOMAIN_HIP_OPS;
|
||||
|
||||
// If the record is for a kernel dispatch, write the kernel name in the pool's data,
|
||||
// and make the record point to it. Older HIP runtimes do not provide a kernel
|
||||
// name, so record.kernel_name might be null.
|
||||
if (record.op == HIP_OP_ID_DISPATCH && record.kernel_name != nullptr)
|
||||
pool->Write(record, record.kernel_name, strlen(record.kernel_name) + 1,
|
||||
[](auto& record, const void* data) {
|
||||
record.kernel_name = static_cast<const char*>(data);
|
||||
});
|
||||
else
|
||||
pool->Write(record);
|
||||
}
|
||||
namespace roctracer {
|
||||
|
||||
// Logger routines and primitives
|
||||
util::Logger::mutex_t util::Logger::mutex_;
|
||||
std::atomic<util::Logger*> util::Logger::instance_{};
|
||||
|
||||
// Memory pool routines and primitives
|
||||
MemoryPool* default_memory_pool = nullptr;
|
||||
std::recursive_mutex memory_pool_mutex;
|
||||
|
||||
// Stop status routines and primitives
|
||||
// Most recent value handed to set_stopped(); accessed under stop_status_mutex.
unsigned stop_status_value = 0;
std::mutex stop_status_mutex;

// Records `val` as the current stop status and returns the bitwise difference
// (XOR) between the previous status and `val`; zero means nothing changed.
unsigned set_stopped(unsigned val) {
  std::lock_guard<std::mutex> guard(stop_status_mutex);
  const unsigned previous = stop_status_value;
  stop_status_value = val;
  return previous ^ val;
}
|
||||
|
||||
} // namespace roctracer
|
||||
|
||||
using namespace roctracer;
|
||||
|
||||
LOADER_INSTANTIATE();
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -227,7 +140,7 @@ ROCTRACER_API const char* roctracer_op_string(uint32_t domain, uint32_t op, uint
|
||||
case ACTIVITY_DOMAIN_EXT_API:
|
||||
return "EXT_API";
|
||||
default:
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
|
||||
throw roctracer::ApiError(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID");
|
||||
}
|
||||
API_METHOD_CATCH(nullptr)
|
||||
}
|
||||
@@ -261,95 +174,350 @@ ROCTRACER_API roctracer_status_t roctracer_op_code(uint32_t domain, const char*
|
||||
API_METHOD_SUFFIX
|
||||
}
|
||||
|
||||
static inline uint32_t get_op_begin(uint32_t domain) {
|
||||
namespace {
|
||||
|
||||
template <activity_domain_t> struct DomainTraits;
|
||||
|
||||
template <> struct DomainTraits<ACTIVITY_DOMAIN_HIP_API> {
|
||||
using ApiData = hip_api_data_t;
|
||||
using OperationId = hip_api_id_t;
|
||||
static constexpr size_t kOpIdBegin = HIP_API_ID_FIRST;
|
||||
static constexpr size_t kOpIdEnd = HIP_API_ID_LAST + 1;
|
||||
};
|
||||
|
||||
template <> struct DomainTraits<ACTIVITY_DOMAIN_HSA_API> {
|
||||
using ApiData = hsa_api_data_t;
|
||||
using OperationId = hsa_api_id_t;
|
||||
static constexpr size_t kOpIdBegin = 0;
|
||||
static constexpr size_t kOpIdEnd = HSA_API_ID_NUMBER;
|
||||
};
|
||||
|
||||
template <> struct DomainTraits<ACTIVITY_DOMAIN_ROCTX> {
|
||||
using ApiData = roctx_api_data_t;
|
||||
using OperationId = roctx_api_id_t;
|
||||
static constexpr size_t kOpIdBegin = 0;
|
||||
static constexpr size_t kOpIdEnd = ROCTX_API_ID_NUMBER;
|
||||
};
|
||||
|
||||
template <> struct DomainTraits<ACTIVITY_DOMAIN_HIP_OPS> {
|
||||
using OperationId = hip_op_id_t;
|
||||
static constexpr size_t kOpIdBegin = 0;
|
||||
static constexpr size_t kOpIdEnd = HIP_OP_ID_NUMBER;
|
||||
};
|
||||
|
||||
template <> struct DomainTraits<ACTIVITY_DOMAIN_HSA_OPS> {
|
||||
using OperationId = hsa_op_id_t;
|
||||
static constexpr size_t kOpIdBegin = 0;
|
||||
static constexpr size_t kOpIdEnd = HSA_OP_ID_NUMBER;
|
||||
};
|
||||
|
||||
template <> struct DomainTraits<ACTIVITY_DOMAIN_HSA_EVT> {
|
||||
using ApiData = hsa_evt_data_t;
|
||||
using OperationId = hsa_evt_id_t;
|
||||
static constexpr size_t kOpIdBegin = 0;
|
||||
static constexpr size_t kOpIdEnd = HSA_EVT_ID_NUMBER;
|
||||
};
|
||||
|
||||
constexpr uint32_t get_op_begin(activity_domain_t domain) {
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
return 0;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HSA_OPS>::kOpIdBegin;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
return 0;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HSA_API>::kOpIdBegin;
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
return 0;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HSA_EVT>::kOpIdBegin;
|
||||
case ACTIVITY_DOMAIN_HIP_OPS:
|
||||
return 0;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HIP_OPS>::kOpIdBegin;
|
||||
case ACTIVITY_DOMAIN_HIP_API:
|
||||
return HIP_API_ID_FIRST;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HIP_API>::kOpIdBegin;
|
||||
case ACTIVITY_DOMAIN_ROCTX:
|
||||
return DomainTraits<ACTIVITY_DOMAIN_ROCTX>::kOpIdBegin;
|
||||
case ACTIVITY_DOMAIN_EXT_API:
|
||||
return 0;
|
||||
case ACTIVITY_DOMAIN_ROCTX:
|
||||
return 0;
|
||||
default:
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
|
||||
throw roctracer::ApiError(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static inline uint32_t get_op_end(uint32_t domain) {
|
||||
constexpr uint32_t get_op_end(activity_domain_t domain) {
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
return HSA_OP_ID_NUMBER;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HSA_OPS>::kOpIdEnd;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
return HSA_API_ID_NUMBER;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HSA_API>::kOpIdEnd;
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
return HSA_EVT_ID_NUMBER;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HSA_EVT>::kOpIdEnd;
|
||||
case ACTIVITY_DOMAIN_HIP_OPS:
|
||||
return HIP_OP_ID_NUMBER;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HIP_OPS>::kOpIdEnd;
|
||||
case ACTIVITY_DOMAIN_HIP_API:
|
||||
return HIP_API_ID_LAST + 1;
|
||||
case ACTIVITY_DOMAIN_EXT_API:
|
||||
return 0;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_HIP_API>::kOpIdEnd;
|
||||
case ACTIVITY_DOMAIN_ROCTX:
|
||||
return ROCTX_API_ID_NUMBER;
|
||||
return DomainTraits<ACTIVITY_DOMAIN_ROCTX>::kOpIdEnd;
|
||||
case ACTIVITY_DOMAIN_EXT_API:
|
||||
return get_op_begin(ACTIVITY_DOMAIN_EXT_API);
|
||||
default:
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
|
||||
throw roctracer::ApiError(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Enable runtime API callbacks
|
||||
static void roctracer_enable_callback_fun(roctracer_domain_t domain, uint32_t op,
|
||||
roctracer_rtapi_callback_t callback, void* user_data) {
|
||||
// Flag read by the IsStopped predicate below.
std::atomic<bool> stopped_status{false};

// Predicate that reports the current value of stopped_status.
struct IsStopped {
  bool operator()() const { return stopped_status.load(std::memory_order_relaxed); }
};

// Predicate that always reports "not stopped", independent of stopped_status.
struct NeverStopped {
  // const-qualified like IsStopped::operator(), so both predicates can be
  // invoked through a const object.
  constexpr bool operator()() const noexcept { return false; }
};
|
||||
|
||||
// A registered user callback: `first` is the callback function, `second` is
// the opaque user pointer handed back to it as its last argument.
using UserCallback = std::pair<activity_rtapi_callback_t, void*>;
|
||||
|
||||
// Registration table holding one UserCallback per operation ID of `domain`.
// The table has DomainTraits<domain>::kOpIdEnd slots.
template <activity_domain_t domain, typename IsStopped>
|
||||
using CallbackRegistrationTable =
|
||||
util::RegistrationTable<UserCallback, DomainTraits<domain>::kOpIdEnd, IsStopped>;
|
||||
|
||||
// Registration table holding one MemoryPool pointer per operation ID of
// `domain`. The table has DomainTraits<domain>::kOpIdEnd slots.
template <activity_domain_t domain, typename IsStopped>
|
||||
using ActivityRegistrationTable =
|
||||
util::RegistrationTable<MemoryPool*, DomainTraits<domain>::kOpIdEnd, IsStopped>;
|
||||
|
||||
template <activity_domain_t domain> struct ApiTracer {
|
||||
using ApiData = typename DomainTraits<domain>::ApiData;
|
||||
using OperationId = typename DomainTraits<domain>::OperationId;
|
||||
|
||||
struct TraceData {
|
||||
ApiData api_data; // API specific data (for example, function arguments).
|
||||
uint64_t phase_enter_timestamp; // timestamp when phase_enter was executed.
|
||||
uint64_t phase_data; // data that can be shared between phase_enter and phase_exit.
|
||||
|
||||
void (*phase_enter)(OperationId operation_id, TraceData* data);
|
||||
void (*phase_exit)(OperationId operation_id, TraceData* data);
|
||||
};
|
||||
|
||||
static void Exit(OperationId operation_id, TraceData* trace_data) {
|
||||
if (auto pool = activity_table.Get(operation_id)) {
|
||||
assert(trace_data != nullptr);
|
||||
activity_record_t record{};
|
||||
|
||||
record.domain = domain;
|
||||
record.op = operation_id;
|
||||
record.correlation_id = trace_data->api_data.correlation_id;
|
||||
record.begin_ns = trace_data->phase_enter_timestamp;
|
||||
record.end_ns = hsa_support::timestamp_ns();
|
||||
record.process_id = GetPid();
|
||||
record.thread_id = GetTid();
|
||||
|
||||
if (auto external_id = ExternalCorrelationId()) {
|
||||
roctracer_record_t ext_record{};
|
||||
ext_record.domain = ACTIVITY_DOMAIN_EXT_API;
|
||||
ext_record.op = ACTIVITY_EXT_OP_EXTERN_ID;
|
||||
ext_record.correlation_id = record.correlation_id;
|
||||
ext_record.external_id = *external_id;
|
||||
// Write the external correlation id record directly followed by the activity record.
|
||||
(*pool)->Write(std::array<roctracer_record_t, 2>{ext_record, record});
|
||||
} else {
|
||||
// Write record to the buffer.
|
||||
(*pool)->Write(record);
|
||||
}
|
||||
}
|
||||
CorrelationIdPop();
|
||||
}
|
||||
|
||||
static void Exit_UserCallback(OperationId operation_id, TraceData* trace_data) {
|
||||
if (auto user_callback = callback_table.Get(operation_id)) {
|
||||
assert(trace_data != nullptr);
|
||||
trace_data->api_data.phase = ACTIVITY_API_PHASE_EXIT;
|
||||
user_callback->first(domain, operation_id, &trace_data->api_data, user_callback->second);
|
||||
}
|
||||
Exit(operation_id, trace_data);
|
||||
}
|
||||
|
||||
static void Enter_UserCallback(OperationId operation_id, TraceData* trace_data) {
|
||||
if (auto user_callback = callback_table.Get(operation_id)) {
|
||||
assert(trace_data != nullptr);
|
||||
trace_data->api_data.phase = ACTIVITY_API_PHASE_ENTER;
|
||||
user_callback->first(domain, operation_id, &trace_data->api_data, user_callback->second);
|
||||
trace_data->phase_exit = Exit_UserCallback;
|
||||
} else {
|
||||
trace_data->phase_exit = Exit;
|
||||
}
|
||||
}
|
||||
|
||||
static int Enter(OperationId operation_id, TraceData* trace_data) {
|
||||
bool callback_enabled = callback_table.Get(operation_id).has_value(),
|
||||
activity_enabled = activity_table.Get(operation_id).has_value();
|
||||
if (!callback_enabled && !activity_enabled) return -1;
|
||||
|
||||
if (trace_data != nullptr) {
|
||||
// Generate a new correlation ID.
|
||||
trace_data->api_data.correlation_id = CorrelationIdPush();
|
||||
|
||||
if (activity_enabled) {
|
||||
trace_data->phase_enter_timestamp = hsa_support::timestamp_ns();
|
||||
trace_data->phase_enter = nullptr;
|
||||
trace_data->phase_exit = Exit;
|
||||
}
|
||||
if (callback_enabled) {
|
||||
trace_data->phase_enter = Enter_UserCallback;
|
||||
trace_data->phase_exit = [](OperationId, TraceData*) { fatal("should not reach here"); };
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static CallbackRegistrationTable<domain, IsStopped> callback_table;
|
||||
static ActivityRegistrationTable<domain, IsStopped> activity_table;
|
||||
};
|
||||
|
||||
template <activity_domain_t domain>
|
||||
CallbackRegistrationTable<domain, IsStopped> ApiTracer<domain>::callback_table;
|
||||
|
||||
template <activity_domain_t domain>
|
||||
ActivityRegistrationTable<domain, IsStopped> ApiTracer<domain>::activity_table;
|
||||
|
||||
using HIP_ApiTracer = ApiTracer<ACTIVITY_DOMAIN_HIP_API>;
|
||||
using HSA_ApiTracer = ApiTracer<ACTIVITY_DOMAIN_HSA_API>;
|
||||
|
||||
CallbackRegistrationTable<ACTIVITY_DOMAIN_ROCTX, NeverStopped> roctx_api_callback_table;
|
||||
ActivityRegistrationTable<ACTIVITY_DOMAIN_HIP_OPS, IsStopped> hip_ops_activity_table;
|
||||
ActivityRegistrationTable<ACTIVITY_DOMAIN_HSA_OPS, IsStopped> hsa_ops_activity_table;
|
||||
CallbackRegistrationTable<ACTIVITY_DOMAIN_HSA_EVT, IsStopped> hsa_evt_callback_table;
|
||||
|
||||
int TracerCallback(activity_domain_t domain, uint32_t operation_id, void* data) {
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
hsa_support::EnableCallback(domain, op, callback, user_data);
|
||||
break;
|
||||
return HSA_ApiTracer::Enter(static_cast<HSA_ApiTracer::OperationId>(operation_id),
|
||||
static_cast<HSA_ApiTracer::TraceData*>(data));
|
||||
|
||||
case ACTIVITY_DOMAIN_HIP_API:
|
||||
return HIP_ApiTracer::Enter(static_cast<HIP_ApiTracer::OperationId>(operation_id),
|
||||
static_cast<HIP_ApiTracer::TraceData*>(data));
|
||||
|
||||
case ACTIVITY_DOMAIN_HIP_OPS:
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_API: {
|
||||
if (!HipLoader::Instance().Enabled()) break;
|
||||
std::lock_guard lock(hip_activity_mutex);
|
||||
|
||||
if (hipError_t err =
|
||||
HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data);
|
||||
err != hipSuccess)
|
||||
fatal("HIP::RegisterApiCallback(%d) failed (err=%d)", op, err);
|
||||
|
||||
if ((hip_act_cb_tracker.enable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) {
|
||||
if (hipError_t err =
|
||||
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
|
||||
err != hipSuccess)
|
||||
fatal("HIP::RegisterActivityCallback(%d) failed (err=%d)", op, err);
|
||||
if (auto pool = hip_ops_activity_table.Get(operation_id)) {
|
||||
if (auto record = static_cast<activity_record_t*>(data)) {
|
||||
// If the record is for a kernel dispatch, write the kernel name in the pool's data,
|
||||
// and make the record point to it. Older HIP runtimes do not provide a kernel
|
||||
// name, so record.kernel_name might be null.
|
||||
if (operation_id == HIP_OP_ID_DISPATCH && record->kernel_name != nullptr)
|
||||
(*pool)->Write(*record, record->kernel_name, strlen(record->kernel_name) + 1,
|
||||
[](auto& record, const void* data) {
|
||||
record.kernel_name = static_cast<const char*>(data);
|
||||
});
|
||||
else
|
||||
(*pool)->Write(*record);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ACTIVITY_DOMAIN_ROCTX: {
|
||||
if (RocTxLoader::Instance().Enabled() &&
|
||||
!RocTxLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data))
|
||||
fatal("ROCTX::RegisterApiCallback(%d) failed", op);
|
||||
|
||||
case ACTIVITY_DOMAIN_ROCTX:
|
||||
if (auto user_callback = roctx_api_callback_table.Get(operation_id)) {
|
||||
if (auto api_data = static_cast<DomainTraits<ACTIVITY_DOMAIN_ROCTX>::ApiData*>(data))
|
||||
user_callback->first(ACTIVITY_DOMAIN_ROCTX, operation_id, api_data,
|
||||
user_callback->second);
|
||||
return 0;
|
||||
}
|
||||
break;
|
||||
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
if (auto pool = hsa_ops_activity_table.Get(operation_id)) {
|
||||
if (auto record = static_cast<activity_record_t*>(data)) (*pool)->Write(*record);
|
||||
return 0;
|
||||
}
|
||||
break;
|
||||
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
if (auto user_callback = hsa_evt_callback_table.Get(operation_id)) {
|
||||
if (auto api_data = static_cast<DomainTraits<ACTIVITY_DOMAIN_HSA_EVT>::ApiData*>(data))
|
||||
user_callback->first(ACTIVITY_DOMAIN_HSA_EVT, operation_id, api_data,
|
||||
user_callback->second);
|
||||
return 0;
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Ties a set of registration tables to a pair of engage/disengage actions.
// Register() runs the engage action when every table of the group is empty
// just before the registration; Unregister() runs the disengage action when
// every table is empty right after the removal.
template <typename... Tables> struct RegistrationTableGroup {
 public:
  // engage_tracer / disengage_tracer: callables taking no arguments.
  // tables: the tables forming the group; only references are kept.
  template <typename EngageFn, typename DisengageFn>
  RegistrationTableGroup(EngageFn&& engage_tracer, DisengageFn&& disengage_tracer,
                         Tables&... tables)
      : engage_tracer_(std::forward<EngageFn>(engage_tracer)),
        disengage_tracer_(std::forward<DisengageFn>(disengage_tracer)),
        tables_(tables...) {}

  // Registers `args` for `operation_id` in `table`, engaging the tracer first
  // if the whole group is currently empty.
  template <typename T, typename... Args>
  void Register(T& table, uint32_t operation_id, Args... args) const {
    const bool group_was_empty = AllEmpty();
    if (group_was_empty) engage_tracer_();
    table.Register(operation_id, std::forward<Args>(args)...);
  }

  // Removes the registration of `operation_id` from `table`, then disengages
  // the tracer if the whole group is empty.
  template <typename T> void Unregister(T& table, uint32_t operation_id) const {
    table.Unregister(operation_id);
    const bool group_is_empty = AllEmpty();
    if (group_is_empty) disengage_tracer_();
  }

 private:
  // True when IsEmpty() holds for every table of the group.
  bool AllEmpty() const {
    return std::apply([](const auto&... each) { return (each.IsEmpty() && ...); }, tables_);
  }

  const std::function<void()> engage_tracer_;
  const std::function<void()> disengage_tracer_;
  const std::tuple<const Tables&...> tables_;
};
|
||||
|
||||
// Group of all HSA registration tables (API callbacks, API activities, HSA ops
// activities, HSA event callbacks). The first lambda installs TracerCallback
// through hsa_support::RegisterTracerCallback, the second removes it by
// passing nullptr.
RegistrationTableGroup HSA_registration_group(
|
||||
[]() { hsa_support::RegisterTracerCallback(TracerCallback); },
|
||||
[]() { hsa_support::RegisterTracerCallback(nullptr); }, HSA_ApiTracer::callback_table,
|
||||
HSA_ApiTracer::activity_table, hsa_ops_activity_table, hsa_evt_callback_table);
|
||||
|
||||
// Group of all HIP registration tables (API callbacks, API activities, HIP ops
// activities). The lambdas install and remove TracerCallback through the HIP
// loader's RegisterTracerCallback entry point.
RegistrationTableGroup HIP_registration_group(
|
||||
[]() { HipLoader::Instance().RegisterTracerCallback(TracerCallback); },
|
||||
[]() { HipLoader::Instance().RegisterTracerCallback(nullptr); }, HIP_ApiTracer::callback_table,
|
||||
HIP_ApiTracer::activity_table, hip_ops_activity_table);
|
||||
|
||||
// Group containing the single rocTX callback table. The lambdas install and
// remove TracerCallback through the rocTX loader's RegisterTracerCallback
// entry point.
RegistrationTableGroup ROCTX_registration_group(
|
||||
[]() { RocTxLoader::Instance().RegisterTracerCallback(TracerCallback); },
|
||||
[]() { RocTxLoader::Instance().RegisterTracerCallback(nullptr); }, roctx_api_callback_table);
|
||||
|
||||
} // namespace
|
||||
|
||||
// Enable runtime API callbacks
|
||||
static void roctracer_enable_callback_impl(roctracer_domain_t domain, uint32_t operation_id,
|
||||
roctracer_rtapi_callback_t callback, void* user_data) {
|
||||
std::lock_guard lock(registration_mutex);
|
||||
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
HSA_registration_group.Register(hsa_evt_callback_table, operation_id, callback, user_data);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
HSA_registration_group.Register(HSA_ApiTracer::callback_table, operation_id, callback,
|
||||
user_data);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_API:
|
||||
if (HipLoader::Instance().Enabled())
|
||||
HIP_registration_group.Register(HIP_ApiTracer::callback_table, operation_id, callback,
|
||||
user_data);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_OPS:
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_ROCTX:
|
||||
if (RocTxLoader::Instance().Enabled())
|
||||
ROCTX_registration_group.Register(roctx_api_callback_table, operation_id, callback,
|
||||
user_data);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
|
||||
}
|
||||
}
|
||||
|
||||
static void roctracer_enable_callback_impl(roctracer_domain_t domain, uint32_t op,
|
||||
roctracer_rtapi_callback_t callback, void* user_data) {
|
||||
cb_journal.Insert(domain, op, {callback, user_data});
|
||||
roctracer_enable_callback_fun(domain, op, callback, user_data);
|
||||
}
|
||||
|
||||
ROCTRACER_API roctracer_status_t roctracer_enable_op_callback(roctracer_domain_t domain,
|
||||
uint32_t op,
|
||||
roctracer_rtapi_callback_t callback,
|
||||
@@ -369,43 +537,33 @@ ROCTRACER_API roctracer_status_t roctracer_enable_domain_callback(
|
||||
}
|
||||
|
||||
// Disable runtime API callbacks
|
||||
static void roctracer_disable_callback_fun(roctracer_domain_t domain, uint32_t op) {
|
||||
static void roctracer_disable_callback_impl(roctracer_domain_t domain, uint32_t operation_id) {
|
||||
std::lock_guard lock(registration_mutex);
|
||||
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
hsa_support::DisableCallback(domain, op);
|
||||
HSA_registration_group.Unregister(hsa_evt_callback_table, operation_id);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
HSA_registration_group.Unregister(HSA_ApiTracer::callback_table, operation_id);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_API:
|
||||
if (HipLoader::Instance().Enabled())
|
||||
HIP_registration_group.Unregister(HIP_ApiTracer::callback_table, operation_id);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_OPS:
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_API: {
|
||||
if (!HipLoader::Instance().Enabled()) break;
|
||||
std::lock_guard lock(hip_activity_mutex);
|
||||
|
||||
if (hipError_t err = HipLoader::Instance().RemoveApiCallback(op); err != hipSuccess)
|
||||
fatal("HIP::RemoveApiCallback(%d) failed (err=%d)", op, err);
|
||||
|
||||
if ((hip_act_cb_tracker.disable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) {
|
||||
if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess)
|
||||
fatal("HIP::RemoveActivityCallback(%d) failed (err=%d)", op, err);
|
||||
}
|
||||
case ACTIVITY_DOMAIN_ROCTX:
|
||||
if (RocTxLoader::Instance().Enabled())
|
||||
ROCTX_registration_group.Unregister(roctx_api_callback_table, operation_id);
|
||||
break;
|
||||
}
|
||||
case ACTIVITY_DOMAIN_ROCTX: {
|
||||
if (RocTxLoader::Instance().Enabled() && !RocTxLoader::Instance().RemoveApiCallback(op))
|
||||
fatal("ROCTX::RemoveApiCallback(%d) failed", op);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
|
||||
}
|
||||
}
|
||||
|
||||
static void roctracer_disable_callback_impl(roctracer_domain_t domain, uint32_t op) {
|
||||
cb_journal.Remove(domain, op);
|
||||
roctracer_disable_callback_fun(domain, op);
|
||||
}
|
||||
|
||||
ROCTRACER_API roctracer_status_t roctracer_disable_op_callback(roctracer_domain_t domain,
|
||||
uint32_t op) {
|
||||
API_METHOD_PREFIX
|
||||
@@ -470,33 +628,32 @@ ROCTRACER_API roctracer_status_t roctracer_next_record(const activity_record_t*
|
||||
}
|
||||
|
||||
// Enable activity records logging
|
||||
static void roctracer_enable_activity_fun(roctracer_domain_t domain, uint32_t op,
|
||||
roctracer_pool_t* pool) {
|
||||
assert(pool != nullptr);
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
hsa_support::EnableActivity(domain, op, pool);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_OPS: {
|
||||
if (HipLoader::Instance().Enabled() &&
|
||||
HipLoader::Instance().RegisterAsyncActivityCallback(op, (void*)HIP_AsyncActivityCallback,
|
||||
pool) != hipSuccess)
|
||||
fatal("HIP::EnableActivityCallback error");
|
||||
break;
|
||||
}
|
||||
case ACTIVITY_DOMAIN_HIP_API: {
|
||||
if (!HipLoader::Instance().Enabled()) break;
|
||||
std::lock_guard lock(hip_activity_mutex);
|
||||
static void roctracer_enable_activity_impl(roctracer_domain_t domain, uint32_t op,
|
||||
roctracer_pool_t* pool) {
|
||||
std::lock_guard lock(registration_mutex);
|
||||
|
||||
hip_act_cb_tracker.enable_check(op, API_ACT_MASK);
|
||||
if (hipError_t err =
|
||||
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, pool);
|
||||
err != hipSuccess)
|
||||
fatal("HIP::RegisterActivityCallback(%d) (err=%d)", op, err);
|
||||
MemoryPool* memory_pool = reinterpret_cast<MemoryPool*>(pool);
|
||||
if (memory_pool == nullptr) memory_pool = default_memory_pool;
|
||||
if (memory_pool == nullptr)
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_DEFAULT_POOL_UNDEFINED, "no default pool");
|
||||
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
HSA_registration_group.Register(HSA_ApiTracer::activity_table, op, memory_pool);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
HSA_registration_group.Register(hsa_ops_activity_table, op, memory_pool);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_API:
|
||||
if (HipLoader::Instance().Enabled())
|
||||
HIP_registration_group.Register(HIP_ApiTracer::activity_table, op, memory_pool);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_OPS:
|
||||
if (HipLoader::Instance().Enabled())
|
||||
HIP_registration_group.Register(hip_ops_activity_table, op, memory_pool);
|
||||
break;
|
||||
}
|
||||
case ACTIVITY_DOMAIN_ROCTX:
|
||||
break;
|
||||
default:
|
||||
@@ -504,15 +661,6 @@ static void roctracer_enable_activity_fun(roctracer_domain_t domain, uint32_t op
|
||||
}
|
||||
}
|
||||
|
||||
static void roctracer_enable_activity_impl(roctracer_domain_t domain, uint32_t op,
|
||||
roctracer_pool_t* pool) {
|
||||
if (pool == nullptr) pool = default_memory_pool;
|
||||
if (pool == nullptr)
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR_DEFAULT_POOL_UNDEFINED, "no default pool");
|
||||
act_journal.Insert(domain, op, {pool});
|
||||
roctracer_enable_activity_fun(domain, op, pool);
|
||||
}
|
||||
|
||||
ROCTRACER_API roctracer_status_t roctracer_enable_op_activity_expl(roctracer_domain_t domain,
|
||||
uint32_t op,
|
||||
roctracer_pool_t* pool) {
|
||||
@@ -552,34 +700,26 @@ ROCTRACER_API roctracer_status_t roctracer_enable_domain_activity(activity_domai
|
||||
}
|
||||
|
||||
// Disable activity records logging
|
||||
static void roctracer_disable_activity_fun(roctracer_domain_t domain, uint32_t op) {
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
hsa_support::DisableActivity(domain, op);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_OPS: {
|
||||
if (HipLoader::Instance().Enabled() &&
|
||||
HipLoader::Instance().RemoveAsyncActivityCallback(op) != hipSuccess)
|
||||
fatal("HIP::EnableActivityCallback(%d)", op);
|
||||
break;
|
||||
}
|
||||
case ACTIVITY_DOMAIN_HIP_API: {
|
||||
if (!HipLoader::Instance().Enabled()) break;
|
||||
std::lock_guard lock(hip_activity_mutex);
|
||||
static void roctracer_disable_activity_impl(roctracer_domain_t domain, uint32_t op) {
|
||||
std::lock_guard lock(registration_mutex);
|
||||
|
||||
if ((hip_act_cb_tracker.disable_check(op, API_ACT_MASK) & API_CB_MASK) == 0) {
|
||||
if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess)
|
||||
fatal("HIP::RemoveActivityCallback(%d) failed (err=%d)", op, err);
|
||||
} else {
|
||||
if (hipError_t err =
|
||||
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
|
||||
err != hipSuccess)
|
||||
fatal("HIP::RegisterActivityCallback(%d) failed (err=%d)", op, err);
|
||||
}
|
||||
switch (domain) {
|
||||
case ACTIVITY_DOMAIN_HSA_EVT:
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_API:
|
||||
HSA_registration_group.Unregister(HSA_ApiTracer::activity_table, op);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HSA_OPS:
|
||||
HSA_registration_group.Unregister(hsa_ops_activity_table, op);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_API:
|
||||
if (HipLoader::Instance().Enabled())
|
||||
HIP_registration_group.Unregister(HIP_ApiTracer::activity_table, op);
|
||||
break;
|
||||
case ACTIVITY_DOMAIN_HIP_OPS:
|
||||
if (HipLoader::Instance().Enabled())
|
||||
HIP_registration_group.Unregister(hip_ops_activity_table, op);
|
||||
break;
|
||||
}
|
||||
case ACTIVITY_DOMAIN_ROCTX:
|
||||
break;
|
||||
default:
|
||||
@@ -587,11 +727,6 @@ static void roctracer_disable_activity_fun(roctracer_domain_t domain, uint32_t o
|
||||
}
|
||||
}
|
||||
|
||||
static void roctracer_disable_activity_impl(roctracer_domain_t domain, uint32_t op) {
|
||||
act_journal.Remove(domain, op);
|
||||
roctracer_disable_activity_fun(domain, op);
|
||||
}
|
||||
|
||||
ROCTRACER_API roctracer_status_t roctracer_disable_op_activity(roctracer_domain_t domain,
|
||||
uint32_t op) {
|
||||
API_METHOD_PREFIX
|
||||
@@ -622,6 +757,7 @@ static void roctracer_close_pool_impl(roctracer_pool_t* pool) {
|
||||
MemoryPool* p = reinterpret_cast<MemoryPool*>(pool);
|
||||
if (p == default_memory_pool) default_memory_pool = nullptr;
|
||||
|
||||
#if 0
|
||||
// Disable any activities that specify the pool being deleted.
|
||||
std::vector<std::pair<roctracer_domain_t, uint32_t>> ops;
|
||||
act_journal.ForEach(
|
||||
@@ -630,6 +766,7 @@ static void roctracer_close_pool_impl(roctracer_pool_t* pool) {
|
||||
return true;
|
||||
});
|
||||
for (auto&& [domain, op] : ops) roctracer_disable_activity_impl(domain, op);
|
||||
#endif
|
||||
|
||||
delete (p);
|
||||
}
|
||||
@@ -693,35 +830,14 @@ roctracer_activity_pop_external_correlation_id(activity_correlation_id_t* last_i
|
||||
|
||||
// Start API
|
||||
ROCTRACER_API void roctracer_start() {
|
||||
if (set_stopped(0)) {
|
||||
if (ext_support::roctracer_start_cb) ext_support::roctracer_start_cb();
|
||||
cb_journal.ForEach([](roctracer_domain_t domain, uint32_t op, const CallbackJournalData& data) {
|
||||
roctracer_enable_callback_fun(domain, op, data.callback, data.user_data);
|
||||
return true;
|
||||
});
|
||||
act_journal.ForEach(
|
||||
[](roctracer_domain_t domain, uint32_t op, const ActivityJournalData& data) {
|
||||
roctracer_enable_activity_fun(domain, op, data.pool);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
if (stopped_status.exchange(false, std::memory_order_relaxed) && roctracer_start_cb)
|
||||
roctracer_start_cb();
|
||||
}
|
||||
|
||||
// Stop API
|
||||
ROCTRACER_API void roctracer_stop() {
|
||||
if (set_stopped(1)) {
|
||||
// Must disable the activity first as the spawner checks for the activity being NULL
|
||||
// to indicate that there is no callback.
|
||||
act_journal.ForEach([](roctracer_domain_t domain, uint32_t op, const ActivityJournalData&) {
|
||||
roctracer_disable_activity_fun(domain, op);
|
||||
return true;
|
||||
});
|
||||
cb_journal.ForEach([](roctracer_domain_t domain, uint32_t op, const CallbackJournalData&) {
|
||||
roctracer_disable_callback_fun(domain, op);
|
||||
return true;
|
||||
});
|
||||
if (ext_support::roctracer_stop_cb) ext_support::roctracer_stop_cb();
|
||||
}
|
||||
if (!stopped_status.exchange(true, std::memory_order_relaxed) && roctracer_stop_cb)
|
||||
roctracer_stop_cb();
|
||||
}
|
||||
|
||||
ROCTRACER_API roctracer_status_t roctracer_get_timestamp(roctracer_timestamp_t* timestamp) {
|
||||
@@ -745,8 +861,8 @@ ROCTRACER_API roctracer_status_t roctracer_set_properties(roctracer_domain_t dom
|
||||
case ACTIVITY_DOMAIN_EXT_API: {
|
||||
roctracer_ext_properties_t* ops_properties =
|
||||
reinterpret_cast<roctracer_ext_properties_t*>(properties);
|
||||
ext_support::roctracer_start_cb = ops_properties->start_cb;
|
||||
ext_support::roctracer_stop_cb = ops_properties->stop_cb;
|
||||
roctracer_start_cb = ops_properties->start_cb;
|
||||
roctracer_stop_cb = ops_properties->stop_cb;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
ROCTX_4.1 {
|
||||
global: RegisterApiCallback;
|
||||
RemoveApiCallback;
|
||||
roctxMarkA;
|
||||
global: roctxMarkA;
|
||||
roctxRangePop;
|
||||
roctxRangePushA;
|
||||
roctxRangeStartA;
|
||||
roctxRangeStop;
|
||||
roctxRegisterTracerCallback;
|
||||
roctx_version_major;
|
||||
roctx_version_minor;
|
||||
local: *;
|
||||
|
||||
@@ -22,66 +22,72 @@
|
||||
#include "roctracer_roctx.h"
|
||||
#include "ext/prof_protocol.h"
|
||||
|
||||
#include "util/callback_table.h"
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
|
||||
namespace {
|
||||
|
||||
roctracer::util::CallbackTable<ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_NUMBER> callbacks;
|
||||
thread_local int nested_range_level(0);
|
||||
std::atomic<int (*)(activity_domain_t domain, uint32_t operation_id, void* data)> report_activity;
|
||||
thread_local int nested_range_level{0};
|
||||
|
||||
void ReportActivity(roctx_api_id_t operation_id, const char* message = nullptr,
|
||||
roctx_range_id_t id = {}) {
|
||||
auto function = report_activity.load(std::memory_order_relaxed);
|
||||
if (!function) return;
|
||||
|
||||
roctx_api_data_t api_data{};
|
||||
switch (operation_id) {
|
||||
case ROCTX_API_ID_roctxMarkA:
|
||||
api_data.args.roctxMarkA.message = message;
|
||||
break;
|
||||
case ROCTX_API_ID_roctxRangePushA:
|
||||
api_data.args.roctxRangePushA.message = message;
|
||||
break;
|
||||
case ROCTX_API_ID_roctxRangePop:
|
||||
break;
|
||||
case ROCTX_API_ID_roctxRangeStartA:
|
||||
api_data.args.roctxRangeStartA.message = message;
|
||||
api_data.args.roctxRangeStartA.id = id;
|
||||
break;
|
||||
case ROCTX_API_ID_roctxRangeStop:
|
||||
api_data.args.roctxRangeStop.id = id;
|
||||
break;
|
||||
default:
|
||||
assert(!"should not reach here");
|
||||
}
|
||||
function(ACTIVITY_DOMAIN_ROCTX, operation_id, &api_data);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ROCTX_API uint32_t roctx_version_major() { return ROCTX_VERSION_MAJOR; }
|
||||
ROCTX_API uint32_t roctx_version_minor() { return ROCTX_VERSION_MINOR; }
|
||||
|
||||
ROCTX_API void roctxMarkA(const char* message) {
|
||||
roctx_api_data_t api_data{};
|
||||
api_data.args.roctxMarkA.message = message;
|
||||
callbacks.Invoke(ROCTX_API_ID_roctxMarkA, &api_data);
|
||||
}
|
||||
ROCTX_API void roctxMarkA(const char* message) { ReportActivity(ROCTX_API_ID_roctxMarkA, message); }
|
||||
|
||||
ROCTX_API int roctxRangePushA(const char* message) {
|
||||
roctx_api_data_t api_data{};
|
||||
api_data.args.roctxRangePushA.message = message;
|
||||
callbacks.Invoke(ROCTX_API_ID_roctxRangePushA, &api_data);
|
||||
|
||||
ReportActivity(ROCTX_API_ID_roctxRangePushA, message);
|
||||
return nested_range_level++;
|
||||
}
|
||||
|
||||
ROCTX_API int roctxRangePop() {
|
||||
roctx_api_data_t api_data{};
|
||||
callbacks.Invoke(ROCTX_API_ID_roctxRangePop, &api_data);
|
||||
|
||||
ReportActivity(ROCTX_API_ID_roctxRangePop);
|
||||
if (nested_range_level == 0) return -1;
|
||||
return --nested_range_level;
|
||||
}
|
||||
|
||||
ROCTX_API roctx_range_id_t roctxRangeStartA(const char* message) {
|
||||
static std::atomic<roctx_range_id_t> start_stop_range_id(1);
|
||||
auto id = start_stop_range_id++;
|
||||
|
||||
roctx_api_data_t api_data{};
|
||||
api_data.args.roctxRangeStartA.message = message;
|
||||
api_data.args.roctxRangeStartA.id = id;
|
||||
callbacks.Invoke(ROCTX_API_ID_roctxRangeStartA, &api_data);
|
||||
|
||||
return id;
|
||||
auto range_id = start_stop_range_id++;
|
||||
ReportActivity(ROCTX_API_ID_roctxRangeStartA, message, range_id);
|
||||
return range_id;
|
||||
}
|
||||
|
||||
ROCTX_API void roctxRangeStop(roctx_range_id_t rangeId) {
|
||||
roctx_api_data_t api_data{};
|
||||
api_data.args.roctxRangeStop.id = rangeId;
|
||||
callbacks.Invoke(ROCTX_API_ID_roctxRangeStop, &api_data);
|
||||
ROCTX_API void roctxRangeStop(roctx_range_id_t range_id) {
|
||||
ReportActivity(ROCTX_API_ID_roctxRangeStop, nullptr, range_id);
|
||||
}
|
||||
|
||||
extern "C" ROCTX_EXPORT bool RegisterApiCallback(uint32_t op, void* callback, void* arg) {
|
||||
if (op >= ROCTX_API_ID_NUMBER) return false;
|
||||
callbacks.Set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
|
||||
return true;
|
||||
extern "C" ROCTX_EXPORT void roctxRegisterTracerCallback(const void* function) {
|
||||
report_activity.store(reinterpret_cast<decltype(report_activity.load())>(function),
|
||||
std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
extern "C" ROCTX_EXPORT bool RemoveApiCallback(uint32_t op) {
|
||||
if (op >= ROCTX_API_ID_NUMBER) return false;
|
||||
callbacks.Set(op, nullptr, nullptr);
|
||||
return true;
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE. */
|
||||
|
||||
#ifndef UTIL_CALLBACK_TABLE_H_
|
||||
#define UTIL_CALLBACK_TABLE_H_
|
||||
|
||||
#include "ext/prof_protocol.h"
|
||||
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
|
||||
namespace roctracer::util {
|
||||
|
||||
// Generic callbacks table
|
||||
template <activity_domain_t Domain, uint32_t N> class CallbackTable {
|
||||
public:
|
||||
CallbackTable()
|
||||
// Zero initialize the callbacks array as the function pointer is used to determine if the
|
||||
// callback is enabled.
|
||||
: callbacks_() {}
|
||||
|
||||
void Set(uint32_t callback_id, activity_rtapi_callback_t callback_function, void* user_arg) {
|
||||
assert(callback_id < N && "callback_id is out of range");
|
||||
std::lock_guard lock(mutex_);
|
||||
auto& callback = callbacks_[callback_id];
|
||||
callback.first.store(callback_function, std::memory_order_relaxed);
|
||||
callback.second = user_arg;
|
||||
}
|
||||
|
||||
auto Get(uint32_t callback_id) const {
|
||||
assert(callback_id < N && "id is out of range");
|
||||
std::lock_guard lock(mutex_);
|
||||
auto& callback = callbacks_[callback_id];
|
||||
return std::make_pair(callback.first.load(std::memory_order_relaxed), callback.second);
|
||||
}
|
||||
|
||||
template <typename... Args> void Invoke(uint32_t callback_id, Args... args) {
|
||||
if (callbacks_[callback_id].first.load(std::memory_order_relaxed) == nullptr) return;
|
||||
if (auto [callback_function, user_arg] = Get(callback_id); callback_function != nullptr)
|
||||
callback_function(Domain, callback_id, std::forward<Args>(args)..., user_arg);
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<std::pair<std::atomic<activity_rtapi_callback_t>, void*>, N> callbacks_;
|
||||
mutable std::mutex mutex_;
|
||||
};
|
||||
|
||||
} // namespace roctracer::util
|
||||
|
||||
#endif // UTIL_CALLBACK_TABLE_H_
|
||||
Ссылка в новой задаче
Block a user