SWDEV-351980 - Consolidate registration tables in the roctracer

Change-Id: I44cd1cc81cf6a529aed89ee8db1377c0aa67f0dc


[ROCm/roctracer commit: 2673bf5e2c]
Этот коммит содержится в:
Laurent Morichetti
2022-09-02 12:40:15 -07:00
родитель a7700afbf6
Коммит 3d46d2d5cb
11 изменённых файлов: 662 добавлений и 704 удалений
+6 -1
Просмотреть файл
@@ -41,7 +41,12 @@ inline static std::ostream& operator<<(std::ostream& out, const char& v) {
#include <roctracer.h>
enum { HIP_OP_ID_DISPATCH = 0, HIP_OP_ID_COPY = 1, HIP_OP_ID_BARRIER = 2, HIP_OP_ID_NUMBER = 3 };
// Identifiers of the HIP operation kinds. Every value is assigned explicitly.
typedef enum {
HIP_OP_ID_DISPATCH = 0,
HIP_OP_ID_COPY = 1,
HIP_OP_ID_BARRIER = 2,
// Equals the number of identifiers listed above; presumably used as a
// count/bound rather than as an operation of its own - confirm with callers.
HIP_OP_ID_NUMBER = 3
} hip_op_id_t;
#ifdef __cplusplus
extern "C" {
+29 -18
Просмотреть файл
@@ -332,7 +332,6 @@ class API_DescrParser:
self.cpp_content += "/* Generated by " + os.path.basename(__file__) + " */\n" + license + "\n\n"
self.cpp_content += '#include <hsa/hsa_api_trace.h>\n'
self.cpp_content += '#include \"util/callback_table.h\"\n\n'
self.cpp_content += '#include <atomic>\n'
self.cpp_content += 'namespace roctracer::hsa_support::detail {\n'
@@ -409,40 +408,52 @@ class API_DescrParser:
def gen_callbacks(self, n, name, call, struct):
# Build and return the C++ source text of the interception callback for one
# HSA API entry.
#   n      - entry index; -1 marks the start of the section, where the static
#            declarations are emitted.
#   name   - API table name; used to form '<name>_saved_before_cb', the table
#            through which the generated code calls the original function.
#   call   - HSA function name, or '-' when there is no function to wrap.
#   struct - parsed description of the call: 'ret' (return type text),
#            'args' (parameter list text), 'alst' (argument names) and
#            'astr' (per-argument text, presumably the declaration - it is
#            matched against 'char* ' to detect string arguments).
# The generated '<call>_callback' function captures the arguments, invokes the
# original function and forwards its return value when the return type is not
# 'void'.
content = ''
if n == -1:
content += 'static util::CallbackTable<ACTIVITY_DOMAIN_HSA_API, HSA_API_ID_NUMBER> cb_table;\n'
content += '/* section: Static declarations */\n'
content += '\n'
if call != '-':
call_id = self.api_id[call];
ret_type = struct['ret']
content += 'static ' + ret_type + ' ' + call + '_callback(' + struct['args'] + ') {\n'
content += ' hsa_api_data_t api_data{};\n'
content += ' hsa_trace_data_t trace_data;\n'
content += ' bool enabled{false};\n'
content += '\n'
content += ' if (auto function = report_activity.load(std::memory_order_relaxed); function &&\n'
content += ' (enabled =\n'
content += ' function(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &trace_data) == 0)) {\n'
content += ' if (trace_data.phase_enter != nullptr) {\n'
for var in struct['alst']:
item = struct['astr'][var];
if re.search(r'char\* ', item):
content += ' api_data.args.' + call + '.' + var + ' = ' + '(' + var + ' != NULL) ? strdup(' + var + ')' + ' : NULL;\n'
# FIXME: we should not strdup the char* arguments here, as the callback will not outlive the scope of this function. Instead, we
# should generate a helper function to capture the content of the arguments similar to hipApiArgsInit for HIP. We also need a
# helper to free the memory that is allocated to capture the content.
content += ' trace_data.api_data.args.' + call + '.' + var + ' = ' + '(' + var + ' != NULL) ? strdup(' + var + ')' + ' : NULL;\n'
else:
content += ' api_data.args.' + call + '.' + var + ' = ' + var + ';\n'
content += ' trace_data.api_data.args.' + call + '.' + var + ' = ' + var + ';\n'
if call == 'hsa_amd_memory_async_copy_rect' and var == 'range':
content += ' api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n'
content += ' auto [ api_callback_fun, api_callback_arg ] = cb_table.Get(' + call_id + ');\n'
content += ' if (api_callback_fun) {\n'
content += ' api_data.phase = ACTIVITY_API_PHASE_ENTER;\n'
content += ' api_data.correlation_id = CorrelationIdPush();\n'
content += ' api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n'
content += ' trace_data.api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n'
content += ' trace_data.phase_enter(' + call_id + ', &trace_data);\n'
content += ' }\n'
content += ' }\n'
content += '\n'
if ret_type != 'void':
# FIXME: we should capture the return value and store it in the api_data
content += ' ' + ret_type + ' ret ='
content += ' ' + name + '_saved_before_cb.' + call + '_fn(' + ', '.join(struct['alst']) + ');\n'
content += ' if (api_callback_fun) {\n'
if ret_type != 'void':
content += ' api_data.' + ret_type + '_retval = ret;\n'
content += ' api_data.phase = ACTIVITY_API_PHASE_EXIT;\n'
content += ' api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n'
content += ' CorrelationIdPop();\n'
content += ' }\n'
content += '\n'
content += ' if (enabled && trace_data.phase_exit != nullptr)\n'
content += ' trace_data.phase_exit(' + call_id + ', &trace_data);\n'
if ret_type != 'void':
content += '\n'
content += ' return ret;\n'
content += '}\n'
return content
# generate API intercepting code
+66 -165
Просмотреть файл
@@ -26,7 +26,6 @@
#include "memory_pool.h"
#include "roctracer.h"
#include "roctracer_hsa.h"
#include "callback_table.h"
#include <hsa/hsa.h>
#include <hsa/amd_hsa_signal.h>
@@ -35,23 +34,32 @@
#include <optional>
#include <mutex>
namespace {
std::atomic<int (*)(activity_domain_t domain, uint32_t operation_id, void* data)> report_activity;
// Asks the registered tracer callback whether {domain, operation_id} is
// enabled. The callback is queried with a null data pointer and a return
// value of 0 means "enabled". Returns false when no callback is registered.
bool IsEnabled(activity_domain_t domain, uint32_t operation_id) {
  const auto tracer_callback = report_activity.load(std::memory_order_relaxed);
  if (tracer_callback == nullptr) return false;
  return tracer_callback(domain, operation_id, nullptr) == 0;
}
// Forwards `data` for {domain, operation_id} to the registered tracer
// callback. Does nothing when no callback is registered; the callback's
// return value is ignored.
void ReportActivity(activity_domain_t domain, uint32_t operation_id, void* data) {
  const auto tracer_callback = report_activity.load(std::memory_order_relaxed);
  if (tracer_callback == nullptr) return;
  tracer_callback(domain, operation_id, data);
}
} // namespace
#include "hsa_prof_str.inline.h"
namespace roctracer::hsa_support {
namespace {
util::CallbackTable<ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_NUMBER> hsa_evt_cb_table;
CoreApiTable saved_core_api{};
AmdExtTable saved_amd_ext_api{};
hsa_ven_amd_loader_1_01_pfn_t hsa_loader_api{};
// async copy activity callback
std::mutex init_mutex;
bool async_copy_callback_enabled = false;
MemoryPool* async_copy_callback_memory_pool = nullptr;
struct AgentInfo {
int index;
hsa_device_type_t type;
@@ -81,7 +89,6 @@ class Tracker {
hsa_signal_t orig;
hsa_signal_t signal;
void (*handler)(const entry_t*);
MemoryPool* pool;
union {
struct {
} copy;
@@ -182,7 +189,7 @@ hsa_status_t HSA_API MemoryAllocateIntercept(hsa_region_t region, size_t size, v
hsa_status_t status = saved_core_api.hsa_memory_allocate_fn(region, size, ptr);
if (status != HSA_STATUS_SUCCESS) return status;
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_ALLOCATE); callback_fun) {
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE)) {
hsa_evt_data_t data{};
data.allocate.ptr = *ptr;
data.allocate.size = size;
@@ -192,7 +199,7 @@ hsa_status_t HSA_API MemoryAllocateIntercept(hsa_region_t region, size_t size, v
&data.allocate.global_flag) != HSA_STATUS_SUCCESS)
fatal("hsa_region_get_info failed");
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data, callback_arg);
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data);
}
return HSA_STATUS_SUCCESS;
@@ -203,14 +210,14 @@ hsa_status_t MemoryAssignAgentIntercept(void* ptr, hsa_agent_t agent,
hsa_status_t status = saved_core_api.hsa_memory_assign_agent_fn(ptr, agent, access);
if (status != HSA_STATUS_SUCCESS) return status;
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_DEVICE); callback_fun) {
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE)) {
hsa_evt_data_t data{};
data.device.ptr = ptr;
if (saved_core_api.hsa_agent_get_info_fn(agent, HSA_AGENT_INFO_DEVICE, &data.device.type) !=
HSA_STATUS_SUCCESS)
fatal("hsa_agent_get_info failed");
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data, callback_arg);
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data);
}
return HSA_STATUS_SUCCESS;
@@ -220,13 +227,13 @@ hsa_status_t MemoryCopyIntercept(void* dst, const void* src, size_t size) {
hsa_status_t status = saved_core_api.hsa_memory_copy_fn(dst, src, size);
if (status != HSA_STATUS_SUCCESS) return status;
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_MEMCOPY); callback_fun) {
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_MEMCOPY)) {
hsa_evt_data_t data{};
data.memcopy.dst = dst;
data.memcopy.src = src;
data.memcopy.size = size;
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_MEMCOPY, &data, callback_arg);
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_MEMCOPY, &data);
}
return HSA_STATUS_SUCCESS;
@@ -237,7 +244,7 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size
hsa_status_t status = saved_amd_ext_api.hsa_amd_memory_pool_allocate_fn(pool, size, flags, ptr);
if (size == 0 || status != HSA_STATUS_SUCCESS) return status;
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_ALLOCATE); callback_fun) {
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE)) {
hsa_evt_data_t data{};
data.allocate.ptr = *ptr;
data.allocate.size = size;
@@ -249,17 +256,13 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size
HSA_STATUS_SUCCESS)
fatal("hsa_region_get_info failed");
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data, callback_arg);
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data);
}
if (std::tie(callback_fun, callback_arg) = hsa_evt_cb_table.Get(HSA_EVT_ID_DEVICE);
!callback_fun)
return HSA_STATUS_SUCCESS;
// FIXME: Why is this only reported if HSA_EVT_ID_ALLOCATE is also set?
auto callback_data = std::make_tuple(callback_fun, callback_arg, pool, ptr);
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE)) {
auto callback_data = std::make_pair(pool, ptr);
auto agent_callback = [](hsa_agent_t agent, void* iterate_agent_callback_data) {
auto [callback_fun, callback_arg, pool, ptr] =
*reinterpret_cast<decltype(callback_data)*>(iterate_agent_callback_data);
auto [pool, ptr] = *reinterpret_cast<decltype(callback_data)*>(iterate_agent_callback_data);
if (hsa_amd_memory_pool_access_t value;
saved_amd_ext_api.hsa_amd_agent_memory_pool_get_info_fn(
@@ -276,7 +279,7 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size
data.device.agent = agent;
data.device.ptr = ptr;
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data, callback_arg);
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data);
return HSA_STATUS_SUCCESS;
};
saved_core_api.hsa_iterate_agents_fn(agent_callback, &callback_data);
@@ -286,11 +289,11 @@ hsa_status_t MemoryPoolAllocateIntercept(hsa_amd_memory_pool_t pool, size_t size
}
hsa_status_t MemoryPoolFreeIntercept(void* ptr) {
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_ALLOCATE); callback_fun) {
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE)) {
hsa_evt_data_t data{};
data.allocate.ptr = ptr;
data.allocate.size = 0;
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data, callback_arg);
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_ALLOCATE, &data);
}
return saved_amd_ext_api.hsa_amd_memory_pool_free_fn(ptr);
@@ -303,7 +306,7 @@ hsa_status_t AgentsAllowAccessIntercept(uint32_t num_agents, const hsa_agent_t*
saved_amd_ext_api.hsa_amd_agents_allow_access_fn(num_agents, agents, flags, ptr);
if (status != HSA_STATUS_SUCCESS) return status;
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_DEVICE); callback_fun) {
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE)) {
while (num_agents--) {
hsa_agent_t agent = *agents++;
auto it = agent_info_map.find(agent.handle);
@@ -315,7 +318,7 @@ hsa_status_t AgentsAllowAccessIntercept(uint32_t num_agents, const hsa_agent_t*
data.device.agent = agent;
data.device.ptr = ptr;
callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data, callback_arg);
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_DEVICE, &data);
}
}
return HSA_STATUS_SUCCESS;
@@ -329,7 +332,6 @@ struct CodeObjectCallbackArg {
hsa_status_t CodeObjectCallback(hsa_executable_t executable,
hsa_loaded_code_object_t loaded_code_object, void* arg) {
auto* code_object_callback_arg = static_cast<CodeObjectCallbackArg*>(arg);
hsa_evt_data_t data{};
if (hsa_loader_api.hsa_ven_amd_loader_loaded_code_object_get_info(
@@ -384,9 +386,8 @@ hsa_status_t CodeObjectCallback(hsa_executable_t executable,
fatal("hsa_ven_amd_loader_loaded_code_object_get_info failed");
data.codeobj.uri = uri_str.c_str();
data.codeobj.unload = code_object_callback_arg->unload ? 1 : 0;
code_object_callback_arg->callback_fun(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ, &data,
code_object_callback_arg->callback_arg);
data.codeobj.unload = *static_cast<bool*>(arg) ? 1 : 0;
ReportActivity(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ, &data);
return HSA_STATUS_SUCCESS;
}
@@ -395,20 +396,20 @@ hsa_status_t ExecutableFreezeIntercept(hsa_executable_t executable, const char*
hsa_status_t status = saved_core_api.hsa_executable_freeze_fn(executable, options);
if (status != HSA_STATUS_SUCCESS) return status;
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_CODEOBJ); callback_fun) {
CodeObjectCallbackArg arg = {callback_fun, callback_arg, false};
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ)) {
bool unload = false;
hsa_loader_api.hsa_ven_amd_loader_executable_iterate_loaded_code_objects(
executable, CodeObjectCallback, &arg);
executable, CodeObjectCallback, &unload);
}
return HSA_STATUS_SUCCESS;
}
hsa_status_t ExecutableDestroyIntercept(hsa_executable_t executable) {
if (auto [callback_fun, callback_arg] = hsa_evt_cb_table.Get(HSA_EVT_ID_CODEOBJ); callback_fun) {
CodeObjectCallbackArg arg = {callback_fun, callback_arg, true};
if (IsEnabled(ACTIVITY_DOMAIN_HSA_EVT, HSA_EVT_ID_CODEOBJ)) {
bool unload = true;
hsa_loader_api.hsa_ven_amd_loader_executable_iterate_loaded_code_objects(
executable, CodeObjectCallback, &arg);
executable, CodeObjectCallback, &unload);
}
return saved_core_api.hsa_executable_destroy_fn(executable);
@@ -422,25 +423,31 @@ void MemoryASyncCopyHandler(const Tracker::entry_t* entry) {
record.end_ns = entry->end;
record.device_id = 0;
record.correlation_id = entry->correlation_id;
entry->pool->Write(record);
ReportActivity(ACTIVITY_DOMAIN_HSA_OPS, HSA_OP_ID_COPY, &record);
}
hsa_status_t MemoryASyncCopyIntercept(void* dst, hsa_agent_t dst_agent, const void* src,
hsa_agent_t src_agent, size_t size, uint32_t num_dep_signals,
const hsa_signal_t* dep_signals,
hsa_signal_t completion_signal) {
if (!async_copy_callback_enabled) {
bool is_enabled = IsEnabled(ACTIVITY_DOMAIN_HSA_OPS, HSA_OP_ID_COPY);
// FIXME: what happens if the state changes before returning?
[[maybe_unused]] hsa_status_t status =
saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(is_enabled);
assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
if (!is_enabled) {
return saved_amd_ext_api.hsa_amd_memory_async_copy_fn(
dst, dst_agent, src, src_agent, size, num_dep_signals, dep_signals, completion_signal);
}
Tracker::entry_t* entry = new Tracker::entry_t();
entry->handler = MemoryASyncCopyHandler;
entry->pool = async_copy_callback_memory_pool;
entry->correlation_id = CorrelationId();
Tracker::Enable(Tracker::COPY_ENTRY_TYPE, hsa_agent_t{}, completion_signal, entry);
hsa_status_t status = saved_amd_ext_api.hsa_amd_memory_async_copy_fn(
status = saved_amd_ext_api.hsa_amd_memory_async_copy_fn(
dst, dst_agent, src, src_agent, size, num_dep_signals, dep_signals, entry->signal);
if (status != HSA_STATUS_SUCCESS) Tracker::Disable(entry);
@@ -454,7 +461,14 @@ hsa_status_t MemoryASyncCopyRectIntercept(const hsa_pitched_ptr_t* dst,
hsa_agent_t copy_agent, hsa_amd_copy_direction_t dir,
uint32_t num_dep_signals, const hsa_signal_t* dep_signals,
hsa_signal_t completion_signal) {
if (!async_copy_callback_enabled) {
bool is_enabled = IsEnabled(ACTIVITY_DOMAIN_HSA_OPS, HSA_OP_ID_COPY);
// FIXME: what happens if the state changes before returning?
[[maybe_unused]] hsa_status_t status =
saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(is_enabled);
assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
if (!is_enabled) {
return saved_amd_ext_api.hsa_amd_memory_async_copy_rect_fn(
dst, dst_offset, src, src_offset, range, copy_agent, dir, num_dep_signals, dep_signals,
completion_signal);
@@ -462,11 +476,10 @@ hsa_status_t MemoryASyncCopyRectIntercept(const hsa_pitched_ptr_t* dst,
Tracker::entry_t* entry = new Tracker::entry_t();
entry->handler = MemoryASyncCopyHandler;
entry->pool = async_copy_callback_memory_pool;
entry->correlation_id = CorrelationId();
Tracker::Enable(Tracker::COPY_ENTRY_TYPE, hsa_agent_t{}, completion_signal, entry);
hsa_status_t status = saved_amd_ext_api.hsa_amd_memory_async_copy_rect_fn(
status = saved_amd_ext_api.hsa_amd_memory_async_copy_rect_fn(
dst, dst_offset, src, src_offset, range, copy_agent, dir, num_dep_signals, dep_signals,
entry->signal);
if (status != HSA_STATUS_SUCCESS) Tracker::Disable(entry);
@@ -502,8 +515,6 @@ roctracer_timestamp_t timestamp_ns() {
}
void Initialize(HsaApiTable* table) {
std::scoped_lock lock(init_mutex);
// Save the HSA core api and amd_ext api.
saved_core_api = *table->core_;
saved_amd_ext_api = *table->amd_ext_;
@@ -558,20 +569,12 @@ void Initialize(HsaApiTable* table) {
detail::InstallCoreApiWrappers(table->core_);
detail::InstallAmdExtWrappers(table->amd_ext_);
detail::InstallImageExtWrappers(table->image_ext_);
if (async_copy_callback_enabled) {
[[maybe_unused]] hsa_status_t status =
saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(true);
assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
}
}
void Finalize() {
if (hsa_support::async_copy_callback_enabled) {
[[maybe_unused]] hsa_status_t status =
hsa_support::saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(false);
assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
}
if (hsa_status_t status = saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn(false);
status != HSA_STATUS_SUCCESS)
assert(!"hsa_amd_profiling_async_copy_enable failed");
}
const char* GetApiName(uint32_t id) { return detail::GetApiName(id); }
@@ -612,111 +615,9 @@ const char* GetOpsName(uint32_t id) {
uint32_t GetApiCode(const char* str) { return detail::GetApiCode(str); }
// Enables activity collection for operation `op` of `domain`.
//
// Only ACTIVITY_DOMAIN_HSA_OPS does any work here; every other domain returns
// without effect. For HSA_OP_ID_COPY the runtime's async-copy profiling is
// switched on (when the saved entry point is available) and `pool` is
// remembered as the destination of the async-copy records.
// HSA_OP_ID_RESERVED1 is accepted and ignored; any other operation is
// reported through EXC_RAISING as not implemented.
void EnableActivity(roctracer_domain_t domain, uint32_t op, roctracer_pool_t* pool) {
  if (domain != ACTIVITY_DOMAIN_HSA_OPS) return;

  if (op == HSA_OP_ID_COPY) {
    std::scoped_lock lock(init_mutex);
    if (auto enable_fn = saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn;
        enable_fn != nullptr) {
      [[maybe_unused]] hsa_status_t status = enable_fn(true);
      assert(status == HSA_STATUS_SUCCESS && "hsa_amd_profiling_async_copy_enable failed");
    }
    async_copy_callback_enabled = true;
    async_copy_callback_memory_pool = reinterpret_cast<MemoryPool*>(pool);
    return;
  }

  /* Place holder for PC sampling. */
  if (op == HSA_OP_ID_RESERVED1) return;

  EXC_RAISING(ROCTRACER_STATUS_ERROR_NOT_IMPLEMENTED,
              "HSA OPS operation ID(" << op << ") is not currently implemented");
}
// Registers `callback` (invoked with `user_data`) for operation `cid` of
// `domain`.
//
// ACTIVITY_DOMAIN_HSA_API and ACTIVITY_DOMAIN_HSA_EVT store the callback in
// their respective callback tables after range-checking `cid`; an
// out-of-range id is reported through EXC_RAISING with
// ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT. ACTIVITY_DOMAIN_HSA_OPS and any
// other domain are accepted and ignored.
void EnableCallback(roctracer_domain_t domain, uint32_t cid, roctracer_rtapi_callback_t callback,
                    void* user_data) {
  switch (domain) {
    case ACTIVITY_DOMAIN_HSA_OPS:
      break;
    case ACTIVITY_DOMAIN_HSA_API:
      if (cid >= HSA_API_ID_NUMBER)
        EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT,
                    "invalid HSA API operation ID(" << cid << ")");
      detail::cb_table.Set(cid, callback, user_data);
      break;
    case ACTIVITY_DOMAIN_HSA_EVT:
      // The message names the EVT domain, matching DisableCallback, so that the
      // diagnostic identifies the table whose range was exceeded.
      if (cid >= HSA_EVT_ID_NUMBER)
        EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT,
                    "invalid HSA EVT operation ID(" << cid << ")");
      hsa_evt_cb_table.Set(cid, callback, user_data);
      break;
    default:
      break;
  }
}
// Disables activity collection for operation `op` of `domain`.
//
// Only ACTIVITY_DOMAIN_HSA_OPS does any work here; every other domain returns
// without effect. For HSA_OP_ID_COPY the async-copy state is cleared first and
// then the runtime's async-copy profiling is switched off (when the saved
// entry point is available); HSA_STATUS_ERROR_NOT_INITIALIZED is tolerated by
// the assertion. HSA_OP_ID_RESERVED1 is accepted and ignored; any other
// operation is reported through EXC_RAISING as not implemented.
void DisableActivity(roctracer_domain_t domain, uint32_t op) {
  if (domain != ACTIVITY_DOMAIN_HSA_OPS) return;

  if (op == HSA_OP_ID_COPY) {
    std::scoped_lock lock(init_mutex);
    async_copy_callback_enabled = false;
    async_copy_callback_memory_pool = nullptr;

    auto enable_fn = saved_amd_ext_api.hsa_amd_profiling_async_copy_enable_fn;
    if (enable_fn == nullptr) return;

    [[maybe_unused]] hsa_status_t status = enable_fn(false);
    assert(status == HSA_STATUS_SUCCESS || status == HSA_STATUS_ERROR_NOT_INITIALIZED ||
           !"hsa_amd_profiling_async_copy_enable failed");
    return;
  }

  /* Place holder for PC sampling. */
  if (op == HSA_OP_ID_RESERVED1) return;

  EXC_RAISING(ROCTRACER_STATUS_ERROR_NOT_IMPLEMENTED,
              "HSA OPS operation ID(" << op << ") is not currently implemented");
}
void DisableCallback(roctracer_domain_t domain, uint32_t cid) {
switch (domain) {
case ACTIVITY_DOMAIN_HSA_OPS:
break;
case ACTIVITY_DOMAIN_HSA_API:
if (cid >= HSA_API_ID_NUMBER)
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT,
"invalid HSA API operation ID(" << cid << ")");
detail::cb_table.Set(cid, nullptr, nullptr);
break;
case ACTIVITY_DOMAIN_HSA_EVT:
if (cid >= HSA_EVT_ID_NUMBER)
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT,
"invalid HSA EVT operation ID(" << cid << ")");
hsa_evt_cb_table.Set(cid, nullptr, nullptr);
break;
default:
break;
}
// Installs `function` as the tracer callback consulted by this module.
// The pointer is published with a relaxed atomic store; passing nullptr
// leaves no callback registered.
void RegisterTracerCallback(int (*function)(activity_domain_t domain, uint32_t operation_id,
                                            void* data)) {
  constexpr auto kOrder = std::memory_order_relaxed;
  report_activity.store(function, kOrder);
}
} // namespace roctracer::hsa_support
+11 -7
Просмотреть файл
@@ -28,6 +28,15 @@
namespace roctracer::hsa_support {
// Per-call tracing record handed to the registered tracer callback by the
// generated HSA API wrappers.
struct hsa_trace_data_t {
// Captured API call data; the generated wrappers fill api_data.args before
// invoking phase_enter.
hsa_api_data_t api_data;
// NOTE(review): presumably the timestamp taken when the enter phase ran; not
// written by the code visible here - confirm against the tracer.
uint64_t phase_enter_timestamp;
// NOTE(review): presumably opaque storage owned by the phase hooks - confirm.
uint64_t phase_data;
// Optional hook called by the generated wrapper before the original HSA
// function; skipped when null.
void (*phase_enter)(hsa_api_id_t operation_id, hsa_trace_data_t* data);
// Optional hook called by the generated wrapper after the original HSA
// function returns; skipped when null.
void (*phase_exit)(hsa_api_id_t operation_id, hsa_trace_data_t* data);
};
void Initialize(HsaApiTable* table);
void Finalize();
@@ -36,13 +45,8 @@ const char* GetEvtName(uint32_t id);
const char* GetOpsName(uint32_t id);
uint32_t GetApiCode(const char* str);
void EnableActivity(roctracer_domain_t domain, uint32_t op, roctracer_pool_t* pool);
void EnableCallback(roctracer_domain_t domain, uint32_t cid, roctracer_rtapi_callback_t callback,
void* user_data);
void DisableCallback(roctracer_domain_t domain, uint32_t cid);
void DisableActivity(roctracer_domain_t domain, uint32_t op);
void RegisterTracerCallback(int (*function)(activity_domain_t domain, uint32_t operation_id,
void* data));
uint64_t timestamp_ns();
} // namespace roctracer::hsa_support
-65
Просмотреть файл
@@ -1,65 +0,0 @@
/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. */
#ifndef SRC_CORE_JOURNAL_H_
#define SRC_CORE_JOURNAL_H_
#include "ext/prof_protocol.h"
#include <mutex>
#include <type_traits>
#include <unordered_map>
namespace roctracer {
template <typename Data> class Journal {
public:
/* Insert { domain, op } into the journal. Return false if the insertion
updated an existing entry. */
template <typename T = Data, std::enable_if_t<std::is_constructible_v<Data, T>, int> = 0>
bool Insert(roctracer_domain_t domain, uint32_t op, T&& data) {
std::lock_guard lock(mutex_);
auto result = map_[domain].try_emplace(op, std::forward<T>(data));
if (!result.second) result.first->second = std::forward<T>(data);
return result.second;
}
/* Remove { domain, op } from the journal. Return false if the entry did not exist. */
bool Remove(roctracer_domain_t domain, uint32_t op) {
std::lock_guard lock(mutex_);
return map_[domain].erase(op) == 1;
}
template <typename Functor> void ForEach(Functor&& func) {
std::lock_guard lock(mutex_);
for (auto&& domain : map_)
for (auto&& operation : domain.second)
if (!func(domain.first /* domain */, operation.first /* op */, operation.second /* data */))
break; /* FIXME: what are we breaking out of? */
}
private:
std::mutex mutex_;
std::unordered_map<roctracer_domain_t, std::unordered_map<uint32_t, Data>> map_;
};
} // namespace roctracer
#endif // SRC_CORE_JOURNAL_H_
+29 -62
Просмотреть файл
@@ -122,13 +122,6 @@ __attribute__((weak)) const char* hipKernelNameRefByPtr(const void* hostFunction
__attribute__((weak)) int hipGetStreamDeviceId(hipStream_t stream) { return 0; }
__attribute__((weak)) const char* hipApiName(uint32_t id) { return NULL; }
__attribute__((weak)) hipError_t hipRegisterAsyncActivityCallback_t(uint32_t op, void* fun,
void* arg) {
return hipErrorUnknown;
}
__attribute__((weak)) hipError_t hipRemoveAsyncActivityCallback_t(uint32_t op) {
return hipErrorUnknown;
}
__attribute__((weak)) const char* hipGetCmdName(unsigned op) { return NULL; }
class HipLoaderStatic {
@@ -158,12 +151,8 @@ class HipLoaderStatic {
GetStreamDeviceId_t* GetStreamDeviceId;
ApiName_t* ApiName;
typedef hipError_t(hipRegisterAsyncActivityCallback_t)(uint32_t op, void* fun, void* arg);
typedef hipError_t(hipRemoveAsyncActivityCallback_t)(uint32_t op);
typedef const char*(hipGetOpName_t)(unsigned op);
hipRegisterAsyncActivityCallback_t* RegisterActivityCallback;
hipRemoveAsyncActivityCallback_t* RemoveActivityCallback;
hipGetOpName_t* GetOpName;
static inline loader_t& Instance() {
@@ -191,8 +180,6 @@ class HipLoaderStatic {
GetStreamDeviceId = hipGetStreamDeviceId;
ApiName = hipApiName;
RegisterAsyncActivityCallback = hipRegisterAsyncActivityCallback;
RemoveAsyncActivityCallback = hipRemoveAsyncActivityCallback;
GetOpName = hipGetCmdName;
}
@@ -204,53 +191,36 @@ class HipApi {
public:
typedef BaseLoader<HipApi> Loader;
typedef decltype(hipRegisterApiCallback) RegisterApiCallback_t;
typedef decltype(hipRemoveApiCallback) RemoveApiCallback_t;
typedef decltype(hipRegisterActivityCallback) RegisterActivityCallback_t;
typedef decltype(hipRemoveActivityCallback) RemoveActivityCallback_t;
typedef decltype(hipKernelNameRef) KernelNameRef_t;
typedef decltype(hipKernelNameRefByPtr) KernelNameRefByPtr_t;
typedef decltype(hipGetStreamDeviceId) GetStreamDeviceId_t;
typedef decltype(hipApiName) ApiName_t;
typedef int(hipGetStreamDeviceId_t)(hipStream_t stream);
typedef const char*(hipKernelNameRef_t)(const hipFunction_t function);
typedef const char*(hipKernelNameRefByPtr_t)(const void* host_function, hipStream_t stream);
typedef const char*(hipApiName_t)(uint32_t id);
typedef const char*(hipGetCmdName_t)(uint32_t op);
typedef void(hipRegisterTracerCallback_t)(int (*function)(activity_domain_t domain,
uint32_t operation_id, void* data));
RegisterApiCallback_t* RegisterApiCallback;
RemoveApiCallback_t* RemoveApiCallback;
RegisterActivityCallback_t* RegisterActivityCallback;
RemoveActivityCallback_t* RemoveActivityCallback;
KernelNameRef_t* KernelNameRef;
KernelNameRefByPtr_t* KernelNameRefByPtr_;
hipKernelNameRef_t* KernelNameRef;
const char* KernelNameRefByPtr(const void* function, hipStream_t stream = nullptr) const {
return KernelNameRefByPtr_(function, stream);
}
GetStreamDeviceId_t* GetStreamDeviceId;
ApiName_t* ApiName;
typedef hipError_t(hipRegisterAsyncActivityCallback_t)(uint32_t op, void* fun, void* arg);
typedef hipError_t(hipRemoveAsyncActivityCallback_t)(uint32_t op);
typedef const char*(hipGetOpName_t)(unsigned op);
hipRegisterAsyncActivityCallback_t* RegisterAsyncActivityCallback;
hipRemoveAsyncActivityCallback_t* RemoveAsyncActivityCallback;
hipGetOpName_t* GetOpName;
hipGetStreamDeviceId_t* GetStreamDeviceId;
hipGetCmdName_t* GetOpName;
hipApiName_t* ApiName;
hipRegisterTracerCallback_t* RegisterTracerCallback;
protected:
void init(Loader* loader) {
RegisterApiCallback = loader->GetFun<RegisterApiCallback_t>("hipRegisterApiCallback");
RemoveApiCallback = loader->GetFun<RemoveApiCallback_t>("hipRemoveApiCallback");
RegisterActivityCallback =
loader->GetFun<RegisterActivityCallback_t>("hipRegisterActivityCallback");
RemoveActivityCallback = loader->GetFun<RemoveActivityCallback_t>("hipRemoveActivityCallback");
KernelNameRef = loader->GetFun<KernelNameRef_t>("hipKernelNameRef");
KernelNameRefByPtr_ = loader->GetFun<KernelNameRefByPtr_t>("hipKernelNameRefByPtr");
GetStreamDeviceId = loader->GetFun<GetStreamDeviceId_t>("hipGetStreamDeviceId");
ApiName = loader->GetFun<ApiName_t>("hipApiName");
RegisterAsyncActivityCallback =
loader->GetFun<hipRegisterAsyncActivityCallback_t>("hipRegisterAsyncActivityCallback");
RemoveAsyncActivityCallback =
loader->GetFun<hipRemoveAsyncActivityCallback_t>("hipRemoveAsyncActivityCallback");
GetOpName = loader->GetFun<hipGetOpName_t>("hipGetCmdName");
GetStreamDeviceId = loader->GetFun<hipGetStreamDeviceId_t>("hipGetStreamDeviceId");
KernelNameRef = loader->GetFun<hipKernelNameRef_t>("hipKernelNameRef");
KernelNameRefByPtr_ = loader->GetFun<hipKernelNameRefByPtr_t>("hipKernelNameRefByPtr");
GetOpName = loader->GetFun<hipGetCmdName_t>("hipGetCmdName");
ApiName = loader->GetFun<hipApiName_t>("hipApiName");
RegisterTracerCallback =
loader->GetFun<hipRegisterTracerCallback_t>("hipRegisterTracerCallback");
}
private:
hipKernelNameRefByPtr_t* KernelNameRefByPtr_;
};
#endif
@@ -260,16 +230,14 @@ class RocTxApi {
public:
typedef BaseLoader<RocTxApi> Loader;
typedef bool(RegisterApiCallback_t)(uint32_t op, void* callback, void* arg);
typedef bool(RemoveApiCallback_t)(uint32_t op);
RegisterApiCallback_t* RegisterApiCallback;
RemoveApiCallback_t* RemoveApiCallback;
typedef void(roctxRegisterTracerCallback_t)(int (*function)(activity_domain_t domain,
uint32_t operation_id, void* data));
roctxRegisterTracerCallback_t* RegisterTracerCallback;
protected:
void init(Loader* loader) {
RegisterApiCallback = loader->GetFun<RegisterApiCallback_t>("RegisterApiCallback");
RemoveApiCallback = loader->GetFun<RemoveApiCallback_t>("RemoveApiCallback");
RegisterTracerCallback =
loader->GetFun<roctxRegisterTracerCallback_t>("roctxRegisterTracerCallback");
}
};
@@ -278,8 +246,7 @@ typedef BaseLoader<RocTxApi> RocTxLoader;
#if STATIC_BUILD
typedef HipLoaderStatic HipLoader;
#else
typedef BaseLoader<HipApi> HipLoaderShared;
typedef HipLoaderShared HipLoader;
using HipLoader = BaseLoader<HipApi>;
#endif
} // namespace roctracer
@@ -299,7 +266,7 @@ typedef HipLoaderShared HipLoader;
roctracer::HipLoaderStatic::instance_t roctracer::HipLoaderStatic::instance_{};
#else
#define LOADER_INSTANTIATE_HIP() \
template <> const char* roctracer::HipLoaderShared::lib_name_ = "libamdhip64.so";
template <> const char* roctracer::HipLoader::lib_name_ = "libamdhip64.so";
#endif
#define LOADER_INSTANTIATE() \
+84
Просмотреть файл
@@ -0,0 +1,84 @@
/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. */
#ifndef UTIL_CALLBACK_TABLE_H_
#define UTIL_CALLBACK_TABLE_H_
#include "ext/prof_protocol.h"
#include <array>
#include <atomic>
#include <cassert>
#include <optional>
#include <shared_mutex>
#include <utility>
namespace roctracer::util {
namespace detail {
// Default "is tracing stopped?" predicate: never reports the stopped state.
struct False {
  // const + noexcept so the predicate is also usable through a const object.
  constexpr bool operator()() const noexcept { return false; }
};
}  // namespace detail

// Fixed-size, thread-safe table mapping an operation ID in [0, N) to an optional payload of
// type T (for example a user callback or a memory pool).
//
// Template parameters:
//   T         - payload stored for each registered operation; copied out by Get().
//   N         - number of slots; valid operation IDs are 0 .. N-1.
//   IsStopped - default-constructible predicate; while it returns true, Get() behaves as if
//               nothing were registered (registrations themselves are preserved).
//
// Each slot is protected by its own shared mutex: Register()/Unregister() take it exclusively,
// Get() takes it shared, and only after a lock-free check of the slot's enabled flag.
template <typename T, uint32_t N, typename IsStopped = detail::False> class RegistrationTable {
 public:
  // Stores T{args...} in the slot for operation_id and marks the slot as registered.
  // Registering an already registered operation replaces its payload.
  //
  // The payload is built before the slot is touched and the slot is only enabled once the
  // payload is in place, so an exception thrown while constructing or assigning T leaves the
  // slot state and the registration count exactly as they were.
  template <typename... Args> void Register(uint32_t operation_id, Args&&... args) {
    assert(operation_id < N && "operation_id is out of range");
    T value{std::forward<Args>(args)...};

    auto& entry = table_[operation_id];
    std::unique_lock lock(entry.mutex);
    entry.data = std::move(value);
    if (!entry.enabled.exchange(true, std::memory_order_relaxed))
      registered_count_.fetch_add(1, std::memory_order_relaxed);
  }

  // Marks the slot for operation_id as unregistered. Unregistering an operation that is not
  // registered is a no-op.
  void Unregister(uint32_t operation_id) {
    assert(operation_id < N && "id is out of range");
    auto& entry = table_[operation_id];
    std::unique_lock lock(entry.mutex);
    if (entry.enabled.exchange(false, std::memory_order_relaxed))
      registered_count_.fetch_sub(1, std::memory_order_relaxed);
  }

  // Returns a copy of the payload registered for operation_id, or std::nullopt if the operation
  // is not registered or IsStopped reports true.
  std::optional<T> Get(uint32_t operation_id) const {
    assert(operation_id < N && "id is out of range");
    const auto& entry = table_[operation_id];

    // Fast path: skip the lock entirely when the slot is disabled or tracing is stopped.
    if (!entry.enabled.load(std::memory_order_relaxed) || IsStopped{}()) return std::nullopt;

    // Re-check under the lock: the slot may have been unregistered since the fast-path check,
    // and the lock orders this read of the payload after the writer's update.
    std::shared_lock lock(entry.mutex);
    if (!entry.enabled.load(std::memory_order_relaxed)) return std::nullopt;
    return entry.data;
  }

  // Returns true when no operation is currently registered.
  bool IsEmpty() const { return registered_count_.load(std::memory_order_relaxed) == 0; }

 private:
  struct Entry {
    std::atomic<bool> enabled{false};
    mutable std::shared_mutex mutex;
    T data;
  };

  std::atomic<size_t> registered_count_{0};
  Entry table_[N]{};
};
} // namespace roctracer::util
#endif // UTIL_CALLBACK_TABLE_H_
+392 -276
Просмотреть файл
@@ -34,17 +34,18 @@
#include <atomic>
#include <mutex>
#include <stack>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "correlation_id.h"
#include "debug.h"
#include "journal.h"
#include "loader.h"
#include "hsa_support.h"
#include "memory_pool.h"
#include "exception.h"
#include "hsa_support.h"
#include "loader.h"
#include "logger.h"
#include "memory_pool.h"
#include "registration_table.h"
#define API_METHOD_PREFIX \
roctracer_status_t err = ROCTRACER_STATUS_SUCCESS; \
@@ -74,126 +75,38 @@ static inline uint32_t GetTid() {
return tid;
}
using namespace roctracer;
namespace {
///////////////////////////////////////////////////////////////////////////////////////////////////
// Internal library methods
//
namespace roctracer {
namespace ext_support {
roctracer_start_cb_t roctracer_start_cb = nullptr;
roctracer_stop_cb_t roctracer_stop_cb = nullptr;
} // namespace ext_support
struct CallbackJournalData {
roctracer_rtapi_callback_t callback;
void* user_data;
};
static Journal<CallbackJournalData> cb_journal;
struct ActivityJournalData {
roctracer_pool_t* pool;
};
static Journal<ActivityJournalData> act_journal;
roctracer_status_t GetExcStatus(const std::exception& e) {
const ApiError* roctracer_exc_ptr = dynamic_cast<const ApiError*>(&e);
return (roctracer_exc_ptr) ? roctracer_exc_ptr->status() : ROCTRACER_STATUS_ERROR;
}
std::mutex hip_activity_mutex;
std::mutex registration_mutex;
enum { API_CB_MASK = 0x1, API_ACT_MASK = 0x2 };
// Memory pool routines and primitives
std::recursive_mutex memory_pool_mutex;
MemoryPool* default_memory_pool = nullptr;
class HIPActivityCallbackTracker {
public:
uint32_t enable_check(uint32_t op, uint32_t mask) { return data_[op] |= mask; }
uint32_t disable_check(uint32_t op, uint32_t mask) { return data_[op] &= ~mask; }
} // namespace
private:
std::unordered_map<uint32_t, uint32_t> data_;
};
static HIPActivityCallbackTracker hip_act_cb_tracker;
void HIP_ApiCallback(uint32_t op_id, roctracer_record_t* record, void* callback_data, void* arg) {
hip_api_data_t* data = static_cast<hip_api_data_t*>(callback_data);
MemoryPool* pool = static_cast<MemoryPool*>(arg);
if (data->phase == ACTIVITY_API_PHASE_ENTER) {
// Generate a new correlation ID.
uint64_t correlation_id = CorrelationIdPush();
data->correlation_id = correlation_id;
if (pool != nullptr) {
// Filing record info
record->domain = ACTIVITY_DOMAIN_HIP_API;
record->kind = 0;
record->op = op_id;
record->process_id = GetPid();
record->thread_id = GetTid();
record->begin_ns = hsa_support::timestamp_ns();
record->correlation_id = correlation_id;
}
} else {
if (pool != nullptr) {
record->end_ns = hsa_support::timestamp_ns();
if (auto external_id = ExternalCorrelationId()) {
roctracer_record_t ext_record{};
ext_record.domain = ACTIVITY_DOMAIN_EXT_API;
ext_record.op = ACTIVITY_EXT_OP_EXTERN_ID;
ext_record.correlation_id = record->correlation_id;
ext_record.external_id = *external_id;
// Write the external correlation id record directly followed by the activity record.
pool->Write(std::array<roctracer_record_t, 2>{ext_record, *record});
} else {
// Write record to the buffer.
pool->Write(*record);
}
}
CorrelationIdPop();
}
}
void HIP_AsyncActivityCallback(uint32_t op_id, void* record_ptr, void* arg) {
MemoryPool* pool = reinterpret_cast<MemoryPool*>(arg);
roctracer_record_t& record = *reinterpret_cast<roctracer_record_t*>(record_ptr);
record.domain = ACTIVITY_DOMAIN_HIP_OPS;
// If the record is for a kernel dispatch, write the kernel name in the pool's data,
// and make the record point to it. Older HIP runtimes do not provide a kernel
// name, so record.kernel_name might be null.
if (record.op == HIP_OP_ID_DISPATCH && record.kernel_name != nullptr)
pool->Write(record, record.kernel_name, strlen(record.kernel_name) + 1,
[](auto& record, const void* data) {
record.kernel_name = static_cast<const char*>(data);
});
else
pool->Write(record);
}
namespace roctracer {
// Logger routines and primitives
util::Logger::mutex_t util::Logger::mutex_;
std::atomic<util::Logger*> util::Logger::instance_{};
// Memory pool routines and primitives
MemoryPool* default_memory_pool = nullptr;
std::recursive_mutex memory_pool_mutex;
// Stop status routines and primitives
unsigned stop_status_value = 0;
std::mutex stop_status_mutex;
unsigned set_stopped(unsigned val) {
std::lock_guard lock(stop_status_mutex);
const unsigned ret = (stop_status_value ^ val);
stop_status_value = val;
return ret;
}
} // namespace roctracer
using namespace roctracer;
LOADER_INSTANTIATE();
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -227,7 +140,7 @@ ROCTRACER_API const char* roctracer_op_string(uint32_t domain, uint32_t op, uint
case ACTIVITY_DOMAIN_EXT_API:
return "EXT_API";
default:
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
throw roctracer::ApiError(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID");
}
API_METHOD_CATCH(nullptr)
}
@@ -261,95 +174,350 @@ ROCTRACER_API roctracer_status_t roctracer_op_code(uint32_t domain, const char*
API_METHOD_SUFFIX
}
static inline uint32_t get_op_begin(uint32_t domain) {
namespace {
// Compile-time description of a tracing domain. Each specialization provides:
//   OperationId - the enum type identifying operations in the domain,
//   kOpIdBegin  - the first valid operation ID,
//   kOpIdEnd    - one past the last valid operation ID (also used as the registration table size),
//   ApiData     - the per-call data structure, for the domains that define one.
// The primary template is intentionally left undefined.
template <activity_domain_t> struct DomainTraits;
// HIP API: operation IDs span HIP_API_ID_FIRST .. HIP_API_ID_LAST inclusive.
template <> struct DomainTraits<ACTIVITY_DOMAIN_HIP_API> {
using ApiData = hip_api_data_t;
using OperationId = hip_api_id_t;
static constexpr size_t kOpIdBegin = HIP_API_ID_FIRST;
static constexpr size_t kOpIdEnd = HIP_API_ID_LAST + 1;
};
// HSA API: operation IDs span 0 .. HSA_API_ID_NUMBER - 1.
template <> struct DomainTraits<ACTIVITY_DOMAIN_HSA_API> {
using ApiData = hsa_api_data_t;
using OperationId = hsa_api_id_t;
static constexpr size_t kOpIdBegin = 0;
static constexpr size_t kOpIdEnd = HSA_API_ID_NUMBER;
};
// ROCTX API: operation IDs span 0 .. ROCTX_API_ID_NUMBER - 1.
template <> struct DomainTraits<ACTIVITY_DOMAIN_ROCTX> {
using ApiData = roctx_api_data_t;
using OperationId = roctx_api_id_t;
static constexpr size_t kOpIdBegin = 0;
static constexpr size_t kOpIdEnd = ROCTX_API_ID_NUMBER;
};
// HIP operations: no ApiData is defined for this domain.
template <> struct DomainTraits<ACTIVITY_DOMAIN_HIP_OPS> {
using OperationId = hip_op_id_t;
static constexpr size_t kOpIdBegin = 0;
static constexpr size_t kOpIdEnd = HIP_OP_ID_NUMBER;
};
// HSA operations: no ApiData is defined for this domain.
template <> struct DomainTraits<ACTIVITY_DOMAIN_HSA_OPS> {
using OperationId = hsa_op_id_t;
static constexpr size_t kOpIdBegin = 0;
static constexpr size_t kOpIdEnd = HSA_OP_ID_NUMBER;
};
// HSA events: operation IDs span 0 .. HSA_EVT_ID_NUMBER - 1.
template <> struct DomainTraits<ACTIVITY_DOMAIN_HSA_EVT> {
using ApiData = hsa_evt_data_t;
using OperationId = hsa_evt_id_t;
static constexpr size_t kOpIdBegin = 0;
static constexpr size_t kOpIdEnd = HSA_EVT_ID_NUMBER;
};
constexpr uint32_t get_op_begin(activity_domain_t domain) {
switch (domain) {
case ACTIVITY_DOMAIN_HSA_OPS:
return 0;
return DomainTraits<ACTIVITY_DOMAIN_HSA_OPS>::kOpIdBegin;
case ACTIVITY_DOMAIN_HSA_API:
return 0;
return DomainTraits<ACTIVITY_DOMAIN_HSA_API>::kOpIdBegin;
case ACTIVITY_DOMAIN_HSA_EVT:
return 0;
return DomainTraits<ACTIVITY_DOMAIN_HSA_EVT>::kOpIdBegin;
case ACTIVITY_DOMAIN_HIP_OPS:
return 0;
return DomainTraits<ACTIVITY_DOMAIN_HIP_OPS>::kOpIdBegin;
case ACTIVITY_DOMAIN_HIP_API:
return HIP_API_ID_FIRST;
return DomainTraits<ACTIVITY_DOMAIN_HIP_API>::kOpIdBegin;
case ACTIVITY_DOMAIN_ROCTX:
return DomainTraits<ACTIVITY_DOMAIN_ROCTX>::kOpIdBegin;
case ACTIVITY_DOMAIN_EXT_API:
return 0;
case ACTIVITY_DOMAIN_ROCTX:
return 0;
default:
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
throw roctracer::ApiError(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID");
}
return 0;
}
static inline uint32_t get_op_end(uint32_t domain) {
constexpr uint32_t get_op_end(activity_domain_t domain) {
switch (domain) {
case ACTIVITY_DOMAIN_HSA_OPS:
return HSA_OP_ID_NUMBER;
return DomainTraits<ACTIVITY_DOMAIN_HSA_OPS>::kOpIdEnd;
case ACTIVITY_DOMAIN_HSA_API:
return HSA_API_ID_NUMBER;
return DomainTraits<ACTIVITY_DOMAIN_HSA_API>::kOpIdEnd;
case ACTIVITY_DOMAIN_HSA_EVT:
return HSA_EVT_ID_NUMBER;
return DomainTraits<ACTIVITY_DOMAIN_HSA_EVT>::kOpIdEnd;
case ACTIVITY_DOMAIN_HIP_OPS:
return HIP_OP_ID_NUMBER;
return DomainTraits<ACTIVITY_DOMAIN_HIP_OPS>::kOpIdEnd;
case ACTIVITY_DOMAIN_HIP_API:
return HIP_API_ID_LAST + 1;
case ACTIVITY_DOMAIN_EXT_API:
return 0;
return DomainTraits<ACTIVITY_DOMAIN_HIP_API>::kOpIdEnd;
case ACTIVITY_DOMAIN_ROCTX:
return ROCTX_API_ID_NUMBER;
return DomainTraits<ACTIVITY_DOMAIN_ROCTX>::kOpIdEnd;
case ACTIVITY_DOMAIN_EXT_API:
return get_op_begin(ACTIVITY_DOMAIN_EXT_API);
default:
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
throw roctracer::ApiError(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID");
}
return 0;
}
// Enable runtime API callbacks
static void roctracer_enable_callback_fun(roctracer_domain_t domain, uint32_t op,
roctracer_rtapi_callback_t callback, void* user_data) {
// Global "tracing is stopped" flag: set to true by roctracer_stop() and back to false by
// roctracer_start().
std::atomic<bool> stopped_status{false};
// Predicate handed to the registration tables: reports the current value of stopped_status, so
// that lookups return nothing while tracing is stopped.
struct IsStopped {
bool operator()() const { return stopped_status.load(std::memory_order_relaxed); }
};
// Predicate for registration tables that must ignore the global stopped state: it always
// reports "not stopped".
struct NeverStopped {
  // const + noexcept, matching IsStopped, so the predicate is also usable through a const object.
  constexpr bool operator()() const noexcept { return false; }
};
// A user callback registration: the callback function and the opaque argument passed back to it.
using UserCallback = std::pair<activity_rtapi_callback_t, void*>;
// Per-domain table of user callbacks, with one slot for each operation ID below the domain's
// kOpIdEnd.
template <activity_domain_t domain, typename IsStopped>
using CallbackRegistrationTable =
util::RegistrationTable<UserCallback, DomainTraits<domain>::kOpIdEnd, IsStopped>;
// Per-domain table of memory pools receiving activity records, with one slot for each operation
// ID below the domain's kOpIdEnd.
template <activity_domain_t domain, typename IsStopped>
using ActivityRegistrationTable =
util::RegistrationTable<MemoryPool*, DomainTraits<domain>::kOpIdEnd, IsStopped>;
// Tracer for one API domain. It owns the domain's user-callback and activity registration
// tables and provides the enter/exit handlers that are installed in a call's TraceData.
template <activity_domain_t domain> struct ApiTracer {
using ApiData = typename DomainTraits<domain>::ApiData;
using OperationId = typename DomainTraits<domain>::OperationId;
// Per-call state handed to Enter() and then to the phase_enter/phase_exit handlers.
struct TraceData {
ApiData api_data; // API specific data (for example, function arguments).
uint64_t phase_enter_timestamp; // timestamp when phase_enter was executed.
uint64_t phase_data; // data that can be shared between phase_enter and phase_exit.
void (*phase_enter)(OperationId operation_id, TraceData* data);
void (*phase_exit)(OperationId operation_id, TraceData* data);
};
// Exit handler. If an activity pool is registered for the operation, writes an activity record
// for the call to that pool, preceded by an external-correlation-id record when an external ID
// is available. Always pops the correlation ID.
static void Exit(OperationId operation_id, TraceData* trace_data) {
if (auto pool = activity_table.Get(operation_id)) {
assert(trace_data != nullptr);
activity_record_t record{};
record.domain = domain;
record.op = operation_id;
record.correlation_id = trace_data->api_data.correlation_id;
record.begin_ns = trace_data->phase_enter_timestamp;
record.end_ns = hsa_support::timestamp_ns();
record.process_id = GetPid();
record.thread_id = GetTid();
if (auto external_id = ExternalCorrelationId()) {
roctracer_record_t ext_record{};
ext_record.domain = ACTIVITY_DOMAIN_EXT_API;
ext_record.op = ACTIVITY_EXT_OP_EXTERN_ID;
ext_record.correlation_id = record.correlation_id;
ext_record.external_id = *external_id;
// Write the external correlation id record directly followed by the activity record.
(*pool)->Write(std::array<roctracer_record_t, 2>{ext_record, record});
} else {
// Write record to the buffer.
(*pool)->Write(record);
}
}
CorrelationIdPop();
}
// Exit handler used when the user callback ran at phase enter: invokes the user callback (if it
// is still registered) with the phase set to ACTIVITY_API_PHASE_EXIT, then runs Exit().
static void Exit_UserCallback(OperationId operation_id, TraceData* trace_data) {
if (auto user_callback = callback_table.Get(operation_id)) {
assert(trace_data != nullptr);
trace_data->api_data.phase = ACTIVITY_API_PHASE_EXIT;
user_callback->first(domain, operation_id, &trace_data->api_data, user_callback->second);
}
Exit(operation_id, trace_data);
}
// Enter handler: invokes the user callback with the phase set to ACTIVITY_API_PHASE_ENTER and
// selects the matching exit handler. If the callback is no longer registered at this point,
// the plain Exit() handler is installed instead.
static void Enter_UserCallback(OperationId operation_id, TraceData* trace_data) {
if (auto user_callback = callback_table.Get(operation_id)) {
assert(trace_data != nullptr);
trace_data->api_data.phase = ACTIVITY_API_PHASE_ENTER;
user_callback->first(domain, operation_id, &trace_data->api_data, user_callback->second);
trace_data->phase_exit = Exit_UserCallback;
} else {
trace_data->phase_exit = Exit;
}
}
// Entry point for an intercepted call. Returns -1 when neither a user callback nor an activity
// pool is registered for the operation, 0 otherwise. When trace_data is non-null, pushes a new
// correlation ID and fills in the phase handlers for the call.
static int Enter(OperationId operation_id, TraceData* trace_data) {
bool callback_enabled = callback_table.Get(operation_id).has_value(),
activity_enabled = activity_table.Get(operation_id).has_value();
if (!callback_enabled && !activity_enabled) return -1;
if (trace_data != nullptr) {
// Generate a new correlation ID.
trace_data->api_data.correlation_id = CorrelationIdPush();
// NOTE(review): phase_enter_timestamp is only written when an activity is registered here, but
// Exit() reads it whenever an activity is registered at exit time - confirm callers
// zero-initialize TraceData or that this window is acceptable.
if (activity_enabled) {
trace_data->phase_enter_timestamp = hsa_support::timestamp_ns();
trace_data->phase_enter = nullptr;
trace_data->phase_exit = Exit;
}
// The callback setup must come last: it overrides the handlers set for the activity case.
// Enter_UserCallback replaces the placeholder phase_exit when it runs.
if (callback_enabled) {
trace_data->phase_enter = Enter_UserCallback;
trace_data->phase_exit = [](OperationId, TraceData*) { fatal("should not reach here"); };
}
}
return 0;
}
// Registration tables for this domain; both honor the global stopped state through IsStopped.
static CallbackRegistrationTable<domain, IsStopped> callback_table;
static ActivityRegistrationTable<domain, IsStopped> activity_table;
};
// Out-of-class definitions of the per-domain static tables.
template <activity_domain_t domain>
CallbackRegistrationTable<domain, IsStopped> ApiTracer<domain>::callback_table;
template <activity_domain_t domain>
ActivityRegistrationTable<domain, IsStopped> ApiTracer<domain>::activity_table;
// API tracers for the two domains that have enter/exit phases.
using HIP_ApiTracer = ApiTracer<ACTIVITY_DOMAIN_HIP_API>;
using HSA_ApiTracer = ApiTracer<ACTIVITY_DOMAIN_HSA_API>;
// Registration tables for the remaining domains. The ROCTX callback table uses NeverStopped, so
// its lookups ignore the global stopped state; the other tables use IsStopped.
CallbackRegistrationTable<ACTIVITY_DOMAIN_ROCTX, NeverStopped> roctx_api_callback_table;
ActivityRegistrationTable<ACTIVITY_DOMAIN_HIP_OPS, IsStopped> hip_ops_activity_table;
ActivityRegistrationTable<ACTIVITY_DOMAIN_HSA_OPS, IsStopped> hsa_ops_activity_table;
CallbackRegistrationTable<ACTIVITY_DOMAIN_HSA_EVT, IsStopped> hsa_evt_callback_table;
int TracerCallback(activity_domain_t domain, uint32_t operation_id, void* data) {
switch (domain) {
case ACTIVITY_DOMAIN_HSA_OPS:
case ACTIVITY_DOMAIN_HSA_API:
case ACTIVITY_DOMAIN_HSA_EVT:
hsa_support::EnableCallback(domain, op, callback, user_data);
break;
return HSA_ApiTracer::Enter(static_cast<HSA_ApiTracer::OperationId>(operation_id),
static_cast<HSA_ApiTracer::TraceData*>(data));
case ACTIVITY_DOMAIN_HIP_API:
return HIP_ApiTracer::Enter(static_cast<HIP_ApiTracer::OperationId>(operation_id),
static_cast<HIP_ApiTracer::TraceData*>(data));
case ACTIVITY_DOMAIN_HIP_OPS:
break;
case ACTIVITY_DOMAIN_HIP_API: {
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);
if (hipError_t err =
HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data);
err != hipSuccess)
fatal("HIP::RegisterApiCallback(%d) failed (err=%d)", op, err);
if ((hip_act_cb_tracker.enable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) {
if (hipError_t err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
err != hipSuccess)
fatal("HIP::RegisterActivityCallback(%d) failed (err=%d)", op, err);
if (auto pool = hip_ops_activity_table.Get(operation_id)) {
if (auto record = static_cast<activity_record_t*>(data)) {
// If the record is for a kernel dispatch, write the kernel name in the pool's data,
// and make the record point to it. Older HIP runtimes do not provide a kernel
// name, so record.kernel_name might be null.
if (operation_id == HIP_OP_ID_DISPATCH && record->kernel_name != nullptr)
(*pool)->Write(*record, record->kernel_name, strlen(record->kernel_name) + 1,
[](auto& record, const void* data) {
record.kernel_name = static_cast<const char*>(data);
});
else
(*pool)->Write(*record);
}
return 0;
}
break;
}
case ACTIVITY_DOMAIN_ROCTX: {
if (RocTxLoader::Instance().Enabled() &&
!RocTxLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data))
fatal("ROCTX::RegisterApiCallback(%d) failed", op);
case ACTIVITY_DOMAIN_ROCTX:
if (auto user_callback = roctx_api_callback_table.Get(operation_id)) {
if (auto api_data = static_cast<DomainTraits<ACTIVITY_DOMAIN_ROCTX>::ApiData*>(data))
user_callback->first(ACTIVITY_DOMAIN_ROCTX, operation_id, api_data,
user_callback->second);
return 0;
}
break;
case ACTIVITY_DOMAIN_HSA_OPS:
if (auto pool = hsa_ops_activity_table.Get(operation_id)) {
if (auto record = static_cast<activity_record_t*>(data)) (*pool)->Write(*record);
return 0;
}
break;
case ACTIVITY_DOMAIN_HSA_EVT:
if (auto user_callback = hsa_evt_callback_table.Get(operation_id)) {
if (auto api_data = static_cast<DomainTraits<ACTIVITY_DOMAIN_HSA_EVT>::ApiData*>(data))
user_callback->first(ACTIVITY_DOMAIN_HSA_EVT, operation_id, api_data,
user_callback->second);
return 0;
}
break;
default:
break;
}
return -1;
}
// Ties a set of registration tables to a pair of engage/disengage actions.
//
// engage_tracer is invoked by Register() when every table of the group is empty just before the
// registration; disengage_tracer is invoked by Unregister() when every table of the group is
// empty just after the removal. The group only keeps references to the tables, so they must
// outlive it.
template <typename... Tables> struct RegistrationTableGroup {
 public:
  template <typename Functor1, typename Functor2>
  RegistrationTableGroup(Functor1&& engage_tracer, Functor2&& disengage_tracer, Tables&... tables)
      : engage_tracer_(std::forward<Functor1>(engage_tracer)),
        disengage_tracer_(std::forward<Functor2>(disengage_tracer)),
        tables_(tables...) {}

  // Registers args for operation_id in table, engaging the tracer first if the whole group was
  // empty.
  template <typename T, typename... Args>
  void Register(T& table, uint32_t operation_id, Args... args) const {
    const bool group_was_empty = AllEmpty();
    if (group_was_empty) engage_tracer_();
    table.Register(operation_id, std::forward<Args>(args)...);
  }

  // Removes operation_id from table, then disengages the tracer if the whole group is empty.
  template <typename T> void Unregister(T& table, uint32_t operation_id) const {
    table.Unregister(operation_id);
    if (!AllEmpty()) return;
    disengage_tracer_();
  }

 private:
  // True when every table of the group reports IsEmpty().
  bool AllEmpty() const {
    const auto every_table_is_empty = [](const auto&... table) {
      return (... && table.IsEmpty());
    };
    return std::apply(every_table_is_empty, tables_);
  }

  const std::function<void()> engage_tracer_, disengage_tracer_;
  const std::tuple<const Tables&...> tables_;
};
// HSA group: installs TracerCallback through hsa_support::RegisterTracerCallback when engaged
// and clears it (nullptr) when disengaged. Covers the HSA API callback/activity tables, the HSA
// ops activity table and the HSA event callback table.
RegistrationTableGroup HSA_registration_group(
[]() { hsa_support::RegisterTracerCallback(TracerCallback); },
[]() { hsa_support::RegisterTracerCallback(nullptr); }, HSA_ApiTracer::callback_table,
HSA_ApiTracer::activity_table, hsa_ops_activity_table, hsa_evt_callback_table);
// HIP group: installs/clears TracerCallback through the HIP loader. Covers the HIP API
// callback/activity tables and the HIP ops activity table.
RegistrationTableGroup HIP_registration_group(
[]() { HipLoader::Instance().RegisterTracerCallback(TracerCallback); },
[]() { HipLoader::Instance().RegisterTracerCallback(nullptr); }, HIP_ApiTracer::callback_table,
HIP_ApiTracer::activity_table, hip_ops_activity_table);
// ROCTX group: installs/clears TracerCallback through the ROCTX loader. Covers the ROCTX
// callback table only.
RegistrationTableGroup ROCTX_registration_group(
[]() { RocTxLoader::Instance().RegisterTracerCallback(TracerCallback); },
[]() { RocTxLoader::Instance().RegisterTracerCallback(nullptr); }, roctx_api_callback_table);
} // namespace
// Enable runtime API callbacks
static void roctracer_enable_callback_impl(roctracer_domain_t domain, uint32_t operation_id,
roctracer_rtapi_callback_t callback, void* user_data) {
std::lock_guard lock(registration_mutex);
switch (domain) {
case ACTIVITY_DOMAIN_HSA_EVT:
HSA_registration_group.Register(hsa_evt_callback_table, operation_id, callback, user_data);
break;
case ACTIVITY_DOMAIN_HSA_API:
HSA_registration_group.Register(HSA_ApiTracer::callback_table, operation_id, callback,
user_data);
break;
case ACTIVITY_DOMAIN_HSA_OPS:
break;
case ACTIVITY_DOMAIN_HIP_API:
if (HipLoader::Instance().Enabled())
HIP_registration_group.Register(HIP_ApiTracer::callback_table, operation_id, callback,
user_data);
break;
case ACTIVITY_DOMAIN_HIP_OPS:
break;
case ACTIVITY_DOMAIN_ROCTX:
if (RocTxLoader::Instance().Enabled())
ROCTX_registration_group.Register(roctx_api_callback_table, operation_id, callback,
user_data);
break;
}
default:
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
}
}
static void roctracer_enable_callback_impl(roctracer_domain_t domain, uint32_t op,
roctracer_rtapi_callback_t callback, void* user_data) {
cb_journal.Insert(domain, op, {callback, user_data});
roctracer_enable_callback_fun(domain, op, callback, user_data);
}
ROCTRACER_API roctracer_status_t roctracer_enable_op_callback(roctracer_domain_t domain,
uint32_t op,
roctracer_rtapi_callback_t callback,
@@ -369,43 +537,33 @@ ROCTRACER_API roctracer_status_t roctracer_enable_domain_callback(
}
// Disable runtime API callbacks
static void roctracer_disable_callback_fun(roctracer_domain_t domain, uint32_t op) {
static void roctracer_disable_callback_impl(roctracer_domain_t domain, uint32_t operation_id) {
std::lock_guard lock(registration_mutex);
switch (domain) {
case ACTIVITY_DOMAIN_HSA_OPS:
case ACTIVITY_DOMAIN_HSA_API:
case ACTIVITY_DOMAIN_HSA_EVT:
hsa_support::DisableCallback(domain, op);
HSA_registration_group.Unregister(hsa_evt_callback_table, operation_id);
break;
case ACTIVITY_DOMAIN_HSA_API:
HSA_registration_group.Unregister(HSA_ApiTracer::callback_table, operation_id);
break;
case ACTIVITY_DOMAIN_HSA_OPS:
break;
case ACTIVITY_DOMAIN_HIP_API:
if (HipLoader::Instance().Enabled())
HIP_registration_group.Unregister(HIP_ApiTracer::callback_table, operation_id);
break;
case ACTIVITY_DOMAIN_HIP_OPS:
break;
case ACTIVITY_DOMAIN_HIP_API: {
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);
if (hipError_t err = HipLoader::Instance().RemoveApiCallback(op); err != hipSuccess)
fatal("HIP::RemoveApiCallback(%d) failed (err=%d)", op, err);
if ((hip_act_cb_tracker.disable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) {
if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess)
fatal("HIP::RemoveActivityCallback(%d) failed (err=%d)", op, err);
}
case ACTIVITY_DOMAIN_ROCTX:
if (RocTxLoader::Instance().Enabled())
ROCTX_registration_group.Unregister(roctx_api_callback_table, operation_id);
break;
}
case ACTIVITY_DOMAIN_ROCTX: {
if (RocTxLoader::Instance().Enabled() && !RocTxLoader::Instance().RemoveApiCallback(op))
fatal("ROCTX::RemoveApiCallback(%d) failed", op);
break;
}
default:
EXC_RAISING(ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, "invalid domain ID(" << domain << ")");
}
}
static void roctracer_disable_callback_impl(roctracer_domain_t domain, uint32_t op) {
cb_journal.Remove(domain, op);
roctracer_disable_callback_fun(domain, op);
}
ROCTRACER_API roctracer_status_t roctracer_disable_op_callback(roctracer_domain_t domain,
uint32_t op) {
API_METHOD_PREFIX
@@ -470,33 +628,32 @@ ROCTRACER_API roctracer_status_t roctracer_next_record(const activity_record_t*
}
// Enable activity records logging
static void roctracer_enable_activity_fun(roctracer_domain_t domain, uint32_t op,
roctracer_pool_t* pool) {
assert(pool != nullptr);
switch (domain) {
case ACTIVITY_DOMAIN_HSA_OPS:
case ACTIVITY_DOMAIN_HSA_API:
case ACTIVITY_DOMAIN_HSA_EVT:
hsa_support::EnableActivity(domain, op, pool);
break;
case ACTIVITY_DOMAIN_HIP_OPS: {
if (HipLoader::Instance().Enabled() &&
HipLoader::Instance().RegisterAsyncActivityCallback(op, (void*)HIP_AsyncActivityCallback,
pool) != hipSuccess)
fatal("HIP::EnableActivityCallback error");
break;
}
case ACTIVITY_DOMAIN_HIP_API: {
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);
static void roctracer_enable_activity_impl(roctracer_domain_t domain, uint32_t op,
roctracer_pool_t* pool) {
std::lock_guard lock(registration_mutex);
hip_act_cb_tracker.enable_check(op, API_ACT_MASK);
if (hipError_t err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, pool);
err != hipSuccess)
fatal("HIP::RegisterActivityCallback(%d) (err=%d)", op, err);
MemoryPool* memory_pool = reinterpret_cast<MemoryPool*>(pool);
if (memory_pool == nullptr) memory_pool = default_memory_pool;
if (memory_pool == nullptr)
EXC_RAISING(ROCTRACER_STATUS_ERROR_DEFAULT_POOL_UNDEFINED, "no default pool");
switch (domain) {
case ACTIVITY_DOMAIN_HSA_EVT:
break;
case ACTIVITY_DOMAIN_HSA_API:
HSA_registration_group.Register(HSA_ApiTracer::activity_table, op, memory_pool);
break;
case ACTIVITY_DOMAIN_HSA_OPS:
HSA_registration_group.Register(hsa_ops_activity_table, op, memory_pool);
break;
case ACTIVITY_DOMAIN_HIP_API:
if (HipLoader::Instance().Enabled())
HIP_registration_group.Register(HIP_ApiTracer::activity_table, op, memory_pool);
break;
case ACTIVITY_DOMAIN_HIP_OPS:
if (HipLoader::Instance().Enabled())
HIP_registration_group.Register(hip_ops_activity_table, op, memory_pool);
break;
}
case ACTIVITY_DOMAIN_ROCTX:
break;
default:
@@ -504,15 +661,6 @@ static void roctracer_enable_activity_fun(roctracer_domain_t domain, uint32_t op
}
}
static void roctracer_enable_activity_impl(roctracer_domain_t domain, uint32_t op,
roctracer_pool_t* pool) {
if (pool == nullptr) pool = default_memory_pool;
if (pool == nullptr)
EXC_RAISING(ROCTRACER_STATUS_ERROR_DEFAULT_POOL_UNDEFINED, "no default pool");
act_journal.Insert(domain, op, {pool});
roctracer_enable_activity_fun(domain, op, pool);
}
ROCTRACER_API roctracer_status_t roctracer_enable_op_activity_expl(roctracer_domain_t domain,
uint32_t op,
roctracer_pool_t* pool) {
@@ -552,34 +700,26 @@ ROCTRACER_API roctracer_status_t roctracer_enable_domain_activity(activity_domai
}
// Disable activity records logging
static void roctracer_disable_activity_fun(roctracer_domain_t domain, uint32_t op) {
switch (domain) {
case ACTIVITY_DOMAIN_HSA_OPS:
case ACTIVITY_DOMAIN_HSA_API:
case ACTIVITY_DOMAIN_HSA_EVT:
hsa_support::DisableActivity(domain, op);
break;
case ACTIVITY_DOMAIN_HIP_OPS: {
if (HipLoader::Instance().Enabled() &&
HipLoader::Instance().RemoveAsyncActivityCallback(op) != hipSuccess)
fatal("HIP::EnableActivityCallback(%d)", op);
break;
}
case ACTIVITY_DOMAIN_HIP_API: {
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);
static void roctracer_disable_activity_impl(roctracer_domain_t domain, uint32_t op) {
std::lock_guard lock(registration_mutex);
if ((hip_act_cb_tracker.disable_check(op, API_ACT_MASK) & API_CB_MASK) == 0) {
if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess)
fatal("HIP::RemoveActivityCallback(%d) failed (err=%d)", op, err);
} else {
if (hipError_t err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
err != hipSuccess)
fatal("HIP::RegisterActivityCallback(%d) failed (err=%d)", op, err);
}
switch (domain) {
case ACTIVITY_DOMAIN_HSA_EVT:
break;
case ACTIVITY_DOMAIN_HSA_API:
HSA_registration_group.Unregister(HSA_ApiTracer::activity_table, op);
break;
case ACTIVITY_DOMAIN_HSA_OPS:
HSA_registration_group.Unregister(hsa_ops_activity_table, op);
break;
case ACTIVITY_DOMAIN_HIP_API:
if (HipLoader::Instance().Enabled())
HIP_registration_group.Unregister(HIP_ApiTracer::activity_table, op);
break;
case ACTIVITY_DOMAIN_HIP_OPS:
if (HipLoader::Instance().Enabled())
HIP_registration_group.Unregister(hip_ops_activity_table, op);
break;
}
case ACTIVITY_DOMAIN_ROCTX:
break;
default:
@@ -587,11 +727,6 @@ static void roctracer_disable_activity_fun(roctracer_domain_t domain, uint32_t o
}
}
static void roctracer_disable_activity_impl(roctracer_domain_t domain, uint32_t op) {
act_journal.Remove(domain, op);
roctracer_disable_activity_fun(domain, op);
}
ROCTRACER_API roctracer_status_t roctracer_disable_op_activity(roctracer_domain_t domain,
uint32_t op) {
API_METHOD_PREFIX
@@ -622,6 +757,7 @@ static void roctracer_close_pool_impl(roctracer_pool_t* pool) {
MemoryPool* p = reinterpret_cast<MemoryPool*>(pool);
if (p == default_memory_pool) default_memory_pool = nullptr;
#if 0
// Disable any activities that specify the pool being deleted.
std::vector<std::pair<roctracer_domain_t, uint32_t>> ops;
act_journal.ForEach(
@@ -630,6 +766,7 @@ static void roctracer_close_pool_impl(roctracer_pool_t* pool) {
return true;
});
for (auto&& [domain, op] : ops) roctracer_disable_activity_impl(domain, op);
#endif
delete (p);
}
@@ -693,35 +830,14 @@ roctracer_activity_pop_external_correlation_id(activity_correlation_id_t* last_i
// Start API
ROCTRACER_API void roctracer_start() {
if (set_stopped(0)) {
if (ext_support::roctracer_start_cb) ext_support::roctracer_start_cb();
cb_journal.ForEach([](roctracer_domain_t domain, uint32_t op, const CallbackJournalData& data) {
roctracer_enable_callback_fun(domain, op, data.callback, data.user_data);
return true;
});
act_journal.ForEach(
[](roctracer_domain_t domain, uint32_t op, const ActivityJournalData& data) {
roctracer_enable_activity_fun(domain, op, data.pool);
return true;
});
}
if (stopped_status.exchange(false, std::memory_order_relaxed) && roctracer_start_cb)
roctracer_start_cb();
}
// Stop API
ROCTRACER_API void roctracer_stop() {
if (set_stopped(1)) {
// Must disable the activity first as the spawner checks for the activity being NULL
// to indicate that there is no callback.
act_journal.ForEach([](roctracer_domain_t domain, uint32_t op, const ActivityJournalData&) {
roctracer_disable_activity_fun(domain, op);
return true;
});
cb_journal.ForEach([](roctracer_domain_t domain, uint32_t op, const CallbackJournalData&) {
roctracer_disable_callback_fun(domain, op);
return true;
});
if (ext_support::roctracer_stop_cb) ext_support::roctracer_stop_cb();
}
if (!stopped_status.exchange(true, std::memory_order_relaxed) && roctracer_stop_cb)
roctracer_stop_cb();
}
ROCTRACER_API roctracer_status_t roctracer_get_timestamp(roctracer_timestamp_t* timestamp) {
@@ -745,8 +861,8 @@ ROCTRACER_API roctracer_status_t roctracer_set_properties(roctracer_domain_t dom
case ACTIVITY_DOMAIN_EXT_API: {
roctracer_ext_properties_t* ops_properties =
reinterpret_cast<roctracer_ext_properties_t*>(properties);
ext_support::roctracer_start_cb = ops_properties->start_cb;
ext_support::roctracer_stop_cb = ops_properties->stop_cb;
roctracer_start_cb = ops_properties->start_cb;
roctracer_stop_cb = ops_properties->stop_cb;
break;
}
default:
+2 -3
Просмотреть файл
@@ -1,11 +1,10 @@
ROCTX_4.1 {
global: RegisterApiCallback;
RemoveApiCallback;
roctxMarkA;
global: roctxMarkA;
roctxRangePop;
roctxRangePushA;
roctxRangeStartA;
roctxRangeStop;
roctxRegisterTracerCallback;
roctx_version_major;
roctx_version_minor;
local: *;
+43 -37
Просмотреть файл
@@ -22,66 +22,72 @@
#include "roctracer_roctx.h"
#include "ext/prof_protocol.h"
#include "util/callback_table.h"
#include <atomic>
#include <cassert>
namespace {
roctracer::util::CallbackTable<ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_NUMBER> callbacks;
thread_local int nested_range_level(0);
std::atomic<int (*)(activity_domain_t domain, uint32_t operation_id, void* data)> report_activity;
thread_local int nested_range_level{0};
// Reports one ROCTX API call to the registered tracer callback.
//
// Does nothing when no tracer callback is registered. Otherwise packs the arguments relevant to
// operation_id (message and/or range id) into a zero-initialized roctx_api_data_t and hands it
// to the callback under ACTIVITY_DOMAIN_ROCTX.
void ReportActivity(roctx_api_id_t operation_id, const char* message = nullptr,
                    roctx_range_id_t id = {}) {
  const auto tracer_callback = report_activity.load(std::memory_order_relaxed);
  if (tracer_callback == nullptr) return;

  roctx_api_data_t api_data{};
  auto& call_args = api_data.args;

  if (operation_id == ROCTX_API_ID_roctxMarkA) {
    call_args.roctxMarkA.message = message;
  } else if (operation_id == ROCTX_API_ID_roctxRangePushA) {
    call_args.roctxRangePushA.message = message;
  } else if (operation_id == ROCTX_API_ID_roctxRangeStartA) {
    call_args.roctxRangeStartA.message = message;
    call_args.roctxRangeStartA.id = id;
  } else if (operation_id == ROCTX_API_ID_roctxRangeStop) {
    call_args.roctxRangeStop.id = id;
  } else if (operation_id != ROCTX_API_ID_roctxRangePop) {
    // roctxRangePop carries no arguments; any other ID is unexpected.
    assert(!"should not reach here");
  }

  tracer_callback(ACTIVITY_DOMAIN_ROCTX, operation_id, &api_data);
}
} // namespace
ROCTX_API uint32_t roctx_version_major() { return ROCTX_VERSION_MAJOR; }
ROCTX_API uint32_t roctx_version_minor() { return ROCTX_VERSION_MINOR; }
ROCTX_API void roctxMarkA(const char* message) {
roctx_api_data_t api_data{};
api_data.args.roctxMarkA.message = message;
callbacks.Invoke(ROCTX_API_ID_roctxMarkA, &api_data);
}
ROCTX_API void roctxMarkA(const char* message) { ReportActivity(ROCTX_API_ID_roctxMarkA, message); }
ROCTX_API int roctxRangePushA(const char* message) {
roctx_api_data_t api_data{};
api_data.args.roctxRangePushA.message = message;
callbacks.Invoke(ROCTX_API_ID_roctxRangePushA, &api_data);
ReportActivity(ROCTX_API_ID_roctxRangePushA, message);
return nested_range_level++;
}
ROCTX_API int roctxRangePop() {
roctx_api_data_t api_data{};
callbacks.Invoke(ROCTX_API_ID_roctxRangePop, &api_data);
ReportActivity(ROCTX_API_ID_roctxRangePop);
if (nested_range_level == 0) return -1;
return --nested_range_level;
}
// Takes the next range id from a function-local atomic counter, reports a roctxRangeStartA
// call carrying `message` and that id, and returns the id.
//
// NOTE(review): everything after the first `return id;` is unreachable. The trailing three
// statements repeat the same work through ReportActivity, which looks like before/after lines
// of a change shown together; confirm which half is meant to stay.
ROCTX_API roctx_range_id_t roctxRangeStartA(const char* message) {
// The counter starts at 1 and is shared by all threads.
static std::atomic<roctx_range_id_t> start_stop_range_id(1);
// Post-increment: this call gets the current value and the counter moves on.
auto id = start_stop_range_id++;
roctx_api_data_t api_data{};
api_data.args.roctxRangeStartA.message = message;
api_data.args.roctxRangeStartA.id = id;
callbacks.Invoke(ROCTX_API_ID_roctxRangeStartA, &api_data);
return id;
// Unreachable from here on (see the note at the top of the function).
auto range_id = start_stop_range_id++;
ReportActivity(ROCTX_API_ID_roctxRangeStartA, message, range_id);
return range_id;
}
// Reports a roctxRangeStop call for the given range id.
//
// NOTE(review): this span holds two function headers for roctxRangeStop but only one closing
// brace. The first header/body goes through callbacks.Invoke, the second through
// ReportActivity. This looks like before/after lines of a change shown together; as written
// it is not well-formed, so confirm which version is meant to stay.
ROCTX_API void roctxRangeStop(roctx_range_id_t rangeId) {
roctx_api_data_t api_data{};
api_data.args.roctxRangeStop.id = rangeId;
callbacks.Invoke(ROCTX_API_ID_roctxRangeStop, &api_data);
// Second variant of the same entry point: no message, only the range id is passed on.
ROCTX_API void roctxRangeStop(roctx_range_id_t range_id) {
ReportActivity(ROCTX_API_ID_roctxRangeStop, nullptr, range_id);
}
// NOTE(review): two different registration entry points are run together in this span, with
// only one closing brace between them. This looks like before/after lines of a change shown
// together; as written it is not well-formed, so confirm which function(s) are meant to stay.
//
// RegisterApiCallback: stores `callback` and `arg` in the `callbacks` table slot for `op`.
// Returns false when `op` is not below ROCTX_API_ID_NUMBER, true otherwise.
extern "C" ROCTX_EXPORT bool RegisterApiCallback(uint32_t op, void* callback, void* arg) {
if (op >= ROCTX_API_ID_NUMBER) return false;
callbacks.Set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
return true;
// roctxRegisterTracerCallback: stores `function` in `report_activity`, after casting it to
// the function-pointer type that `report_activity` holds.
extern "C" ROCTX_EXPORT void roctxRegisterTracerCallback(const void* function) {
report_activity.store(reinterpret_cast<decltype(report_activity.load())>(function),
std::memory_order_relaxed);
}
// Clears the `callbacks` table slot for `op` by storing a null function and a null argument.
// Returns false, leaving the table untouched, when `op` is not below ROCTX_API_ID_NUMBER;
// returns true otherwise.
extern "C" ROCTX_EXPORT bool RemoveApiCallback(uint32_t op) {
  const bool op_in_range = !(op >= ROCTX_API_ID_NUMBER);
  if (op_in_range) {
    callbacks.Set(op, nullptr, nullptr);
  }
  return op_in_range;
}
-70
Просмотреть файл
@@ -1,70 +0,0 @@
/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. */
#ifndef UTIL_CALLBACK_TABLE_H_
#define UTIL_CALLBACK_TABLE_H_
#include "ext/prof_protocol.h"
#include <array>
#include <atomic>
#include <cassert>
#include <mutex>
#include <utility>
namespace roctracer::util {
// Generic callbacks table.
//
// Holds N slots for activity domain `Domain`. Each slot pairs a callback function pointer
// with the opaque user argument handed back to that callback when it is invoked. A null
// function pointer marks the slot as disabled.
//
// Set() and Get() serialize on an internal mutex so that a slot's function pointer and user
// argument are always read and written as a matching pair.
template <activity_domain_t Domain, uint32_t N> class CallbackTable {
 public:
  CallbackTable()
      // Zero initialize the callbacks array as the function pointer is used to determine if the
      // callback is enabled.
      : callbacks_() {}

  // Stores `callback_function` and `user_arg` in slot `callback_id`. Passing a null
  // `callback_function` disables the slot. `callback_id` must be less than N.
  void Set(uint32_t callback_id, activity_rtapi_callback_t callback_function, void* user_arg) {
    assert(callback_id < N && "callback_id is out of range");
    std::lock_guard lock(mutex_);
    auto& callback = callbacks_[callback_id];
    callback.first.store(callback_function, std::memory_order_relaxed);
    callback.second = user_arg;
  }

  // Returns a copy of the {function pointer, user argument} pair stored in slot
  // `callback_id`. `callback_id` must be less than N.
  auto Get(uint32_t callback_id) const {
    assert(callback_id < N && "id is out of range");
    std::lock_guard lock(mutex_);
    auto& callback = callbacks_[callback_id];
    return std::make_pair(callback.first.load(std::memory_order_relaxed), callback.second);
  }

  // Calls the callback stored in slot `callback_id`, if any, as
  // callback(Domain, callback_id, args..., user_arg). `callback_id` must be less than N.
  template <typename... Args> void Invoke(uint32_t callback_id, Args... args) {
    // Same precondition as Set() and Get(): the slot is indexed below before Get() would
    // have a chance to check it.
    assert(callback_id < N && "callback_id is out of range");

    // Fast path: skip taking the mutex when the slot is disabled.
    if (callbacks_[callback_id].first.load(std::memory_order_relaxed) == nullptr) return;

    // Re-read the slot through Get() to obtain a matching function/argument pair. Get()
    // returns a copy and releases the mutex on return, so the callback runs unlocked.
    if (auto [callback_function, user_arg] = Get(callback_id); callback_function != nullptr)
      callback_function(Domain, callback_id, std::forward<Args>(args)..., user_arg);
  }

 private:
  // first: callback function pointer (null when disabled); second: user argument.
  std::array<std::pair<std::atomic<activity_rtapi_callback_t>, void*>, N> callbacks_;
  mutable std::mutex mutex_;
};
} // namespace roctracer::util
#endif // UTIL_CALLBACK_TABLE_H_