diff --git a/projects/roctracer/script/hsaap.py b/projects/roctracer/script/hsaap.py index 54d75712eb..65a7bb9e47 100755 --- a/projects/roctracer/script/hsaap.py +++ b/projects/roctracer/script/hsaap.py @@ -395,7 +395,7 @@ class API_DescrParser: # generate API callbacks def gen_callbacks(self, n, name, call, struct): if n == -1: - self.content += 'typedef CallbackTable cb_table_t;\n' + self.content += 'typedef CallbackTable cb_table_t;\n' self.content += 'extern cb_table_t cb_table;\n' self.content += '\n' if call != '-': diff --git a/projects/roctracer/src/core/callback_table.h b/projects/roctracer/src/core/callback_table.h index 1273140102..a00db08038 100644 --- a/projects/roctracer/src/core/callback_table.h +++ b/projects/roctracer/src/core/callback_table.h @@ -23,6 +23,8 @@ #include +#include +#include #include #include #include @@ -30,7 +32,7 @@ namespace roctracer { // Generic callbacks table -template class CallbackTable { +template class CallbackTable { public: CallbackTable() // Zero initialize the callbacks array as the function pointer is used to determine if the @@ -40,17 +42,26 @@ template class CallbackTable { void Set(uint32_t callback_id, activity_rtapi_callback_t callback_function, void* user_arg) { assert(callback_id < N && "callback_id is out of range"); std::lock_guard lock(mutex_); - callbacks_[callback_id] = {callback_function, user_arg}; + auto& callback = callbacks_[callback_id]; + callback.first.store(callback_function, std::memory_order_relaxed); + callback.second = user_arg; } - std::pair Get(uint32_t callback_id) const { + auto Get(uint32_t callback_id) const { assert(callback_id < N && "id is out of range"); std::lock_guard lock(mutex_); - return callbacks_[callback_id]; + auto& callback = callbacks_[callback_id]; + return std::make_pair(callback.first.load(std::memory_order_relaxed), callback.second); + } + + template void Invoke(uint32_t callback_id, Args... args) { + if (callbacks_[callback_id].first.load(std::memory_order_relaxed) == nullptr) return; + if (auto [callback_function, user_arg] = Get(callback_id); callback_function != nullptr) + callback_function(Domain, callback_id, std::forward(args)..., user_arg); } private: - std::array, N> callbacks_; + std::array, void*>, N> callbacks_; mutable std::mutex mutex_; }; diff --git a/projects/roctracer/src/roctx/roctx.cpp b/projects/roctracer/src/roctx/roctx.cpp index 8dda086764..92eb84d5d4 100644 --- a/projects/roctracer/src/roctx/roctx.cpp +++ b/projects/roctracer/src/roctx/roctx.cpp @@ -57,7 +57,7 @@ typedef enum { // namespace { -roctracer::CallbackTable callbacks; +roctracer::CallbackTable callbacks; thread_local int range_level(0); } // namespace @@ -75,24 +75,17 @@ PUBLIC_API uint32_t roctx_version_minor() { return ROCTX_VERSION_MINOR; } PUBLIC_API void roctxMarkA(const char* message) { API_METHOD_PREFIX - if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxMarkA); - api_callback_fun != nullptr) { - roctx_api_data_t api_data{}; - api_data.args.roctxMarkA.message = message; - api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxMarkA, &api_data, api_callback_arg); - } + roctx_api_data_t api_data{}; + api_data.args.roctxMarkA.message = message; + callbacks.Invoke(ROCTX_API_ID_roctxMarkA, &api_data); API_METHOD_SUFFIX_NRET } PUBLIC_API int roctxRangePushA(const char* message) { API_METHOD_PREFIX - if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangePushA); - api_callback_fun != nullptr) { - roctx_api_data_t api_data{}; - api_data.args.roctxRangePushA.message = message; - api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, &api_data, - api_callback_arg); - } + roctx_api_data_t api_data{}; + api_data.args.roctxRangePushA.message = message; + callbacks.Invoke(ROCTX_API_ID_roctxRangePushA, &api_data); return range_level++; API_METHOD_CATCH(-1); @@ -100,12 +93,9 @@ PUBLIC_API int roctxRangePushA(const char* message) { PUBLIC_API int roctxRangePop() { API_METHOD_PREFIX - if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangePop); - api_callback_fun != nullptr) { - roctx_api_data_t api_data{}; - api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, &api_data, - api_callback_arg); - } + + roctx_api_data_t api_data{}; + callbacks.Invoke(ROCTX_API_ID_roctxRangePop, &api_data); if (range_level == 0) EXC_RAISING(ROCTX_STATUS_ERROR, "Pop from empty stack!"); return --range_level; @@ -116,13 +106,9 @@ PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) { API_METHOD_PREFIX static std::atomic roctx_range_counter(1); - if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangeStartA); - api_callback_fun != nullptr) { - roctx_api_data_t api_data{}; - api_data.args.roctxRangeStartA.message = message; - api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data, - api_callback_arg); - } + roctx_api_data_t api_data{}; + api_data.args.roctxRangeStartA.message = message; + callbacks.Invoke(ROCTX_API_ID_roctxRangeStartA, &api_data); return roctx_range_counter++; API_METHOD_CATCH(-1) @@ -130,13 +116,9 @@ PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) { PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) { API_METHOD_PREFIX - if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangeStop); - api_callback_fun != nullptr) { - roctx_api_data_t api_data{}; - api_data.args.roctxRangeStop.id = rangeId; - api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data, - api_callback_arg); - } + roctx_api_data_t api_data{}; + api_data.args.roctxRangeStop.id = rangeId; + callbacks.Invoke(ROCTX_API_ID_roctxRangeStop, &api_data); API_METHOD_SUFFIX_NRET }