diff --git a/projects/roctracer/script/hsaap.py b/projects/roctracer/script/hsaap.py index dd2ce504bf..54d75712eb 100755 --- a/projects/roctracer/script/hsaap.py +++ b/projects/roctracer/script/hsaap.py @@ -411,9 +411,7 @@ class API_DescrParser: self.content += ' api_data.args.' + call + '.' + var + ' = ' + var + ';\n' if call == 'hsa_amd_memory_async_copy_rect' and var == 'range': self.content += ' api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n' - self.content += ' activity_rtapi_callback_t api_callback_fun = NULL;\n' - self.content += ' void* api_callback_arg = NULL;\n' - self.content += ' cb_table.Get(' + call_id + ', &api_callback_fun, &api_callback_arg);\n' + self.content += ' auto [ api_callback_fun, api_callback_arg ] = cb_table.Get(' + call_id + ');\n' self.content += ' api_data.phase = 0;\n' self.content += ' if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n' if ret_type != 'void': diff --git a/projects/roctracer/src/core/callback_table.h b/projects/roctracer/src/core/callback_table.h index 608154ce0d..1273140102 100644 --- a/projects/roctracer/src/core/callback_table.h +++ b/projects/roctracer/src/core/callback_table.h @@ -37,21 +37,20 @@ template class CallbackTable { // callback is enabled. : callbacks_() {} - void Set(uint32_t id, activity_rtapi_callback_t callback, void* arg) { - assert(id < N && "id is out of range"); + void Set(uint32_t callback_id, activity_rtapi_callback_t callback_function, void* user_arg) { + assert(callback_id < N && "callback_id is out of range"); std::lock_guard lock(mutex_); - callbacks_[id] = {callback, arg}; + callbacks_[callback_id] = {callback_function, user_arg}; } - void Get(uint32_t id, activity_rtapi_callback_t* callback, void** arg) const { - assert(id < N && "id is out of range"); - assert(callback != nullptr && arg != nullptr && "invalid arguments"); + std::pair Get(uint32_t callback_id) const { + assert(callback_id < N && "id is out of range"); std::lock_guard lock(mutex_); - std::tie(*callback, *arg) = callbacks_[id]; + return callbacks_[callback_id]; } private: - std::array, N> callbacks_; + std::array, N> callbacks_; mutable std::mutex mutex_; }; diff --git a/projects/roctracer/src/roctx/roctx.cpp b/projects/roctracer/src/roctx/roctx.cpp index 780975135d..1695f3e10e 100644 --- a/projects/roctracer/src/roctx/roctx.cpp +++ b/projects/roctracer/src/roctx/roctx.cpp @@ -77,11 +77,8 @@ typedef enum { // namespace roctx { -// ROCTX callbacks table type -typedef roctracer::CallbackTable cb_table_t; - -// callbacks table -cb_table_t cb_table; +// ROCTX callbacks table +roctracer::CallbackTable callbacks; typedef std::stack message_stack_t; @@ -105,8 +102,6 @@ void thread_data_init() { thread_map[tid] = message_stack; } -// callbacks table -extern cb_table_t cb_table; } // namespace roctx // Logger instantiation @@ -130,9 +125,7 @@ PUBLIC_API void roctxMarkA(const char* message) { API_METHOD_PREFIX roctx_api_data_t api_data{}; api_data.args.roctxMarkA.message = strdup(message); - activity_rtapi_callback_t api_callback_fun = NULL; - void* api_callback_arg = NULL; - roctx::cb_table.Get(ROCTX_API_ID_roctxMarkA, &api_callback_fun, &api_callback_arg); + auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxMarkA); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxMarkA, &api_data, api_callback_arg); API_METHOD_SUFFIX_NRET @@ -144,9 +137,7 @@ PUBLIC_API int roctxRangePushA(const char* message) { roctx_api_data_t api_data{}; api_data.args.roctxRangePushA.message = strdup(message); - activity_rtapi_callback_t api_callback_fun = NULL; - void* api_callback_arg = NULL; - roctx::cb_table.Get(ROCTX_API_ID_roctxRangePushA, &api_callback_fun, &api_callback_arg); + auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePushA); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, &api_data, api_callback_arg); @@ -161,9 +152,7 @@ PUBLIC_API int roctxRangePop() { if (roctx::message_stack == NULL) roctx::thread_data_init(); roctx_api_data_t api_data{}; - activity_rtapi_callback_t api_callback_fun = NULL; - void* api_callback_arg = NULL; - roctx::cb_table.Get(ROCTX_API_ID_roctxRangePop, &api_callback_fun, &api_callback_arg); + auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePop); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, &api_data, api_callback_arg); @@ -182,10 +171,7 @@ PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) { roctx_api_data_t api_data{}; api_data.args.roctxRangeStartA.message = strdup(message); - api_data.args.roctxRangeStartA.id = roctx_range_counter; - activity_rtapi_callback_t api_callback_fun = NULL; - void* api_callback_arg = NULL; - roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStartA, &api_callback_fun, &api_callback_arg); + auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStartA); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data, api_callback_arg); @@ -198,9 +184,7 @@ PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) { API_METHOD_PREFIX roctx_api_data_t api_data{}; api_data.args.roctxRangeStop.id = rangeId; - activity_rtapi_callback_t api_callback_fun = NULL; - void* api_callback_arg = NULL; - roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStop, &api_callback_fun, &api_callback_arg); + auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStop); if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data, api_callback_arg); @@ -222,13 +206,13 @@ PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg) PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) { if (op >= ROCTX_API_ID_NUMBER) return false; - roctx::cb_table.Set(op, reinterpret_cast(callback), arg); + roctx::callbacks.Set(op, reinterpret_cast(callback), arg); return true; } PUBLIC_API bool RemoveApiCallback(uint32_t op) { if (op >= ROCTX_API_ID_NUMBER) return false; - roctx::cb_table.Set(op, nullptr, nullptr); + roctx::callbacks.Set(op, nullptr, nullptr); return true; }