Cleanup CallbackTable::Get
Make CallbackTable::Get return the callback_function/user_arg pair as an actual return value instead of returning it through arguments pointers. Change-Id: Ia2dfcdad8c237a09620518ad67af94add47220da
Этот коммит содержится в:
@@ -411,9 +411,7 @@ class API_DescrParser:
|
||||
self.content += ' api_data.args.' + call + '.' + var + ' = ' + var + ';\n'
|
||||
if call == 'hsa_amd_memory_async_copy_rect' and var == 'range':
|
||||
self.content += ' api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n'
|
||||
self.content += ' activity_rtapi_callback_t api_callback_fun = NULL;\n'
|
||||
self.content += ' void* api_callback_arg = NULL;\n'
|
||||
self.content += ' cb_table.Get(' + call_id + ', &api_callback_fun, &api_callback_arg);\n'
|
||||
self.content += ' auto [ api_callback_fun, api_callback_arg ] = cb_table.Get(' + call_id + ');\n'
|
||||
self.content += ' api_data.phase = 0;\n'
|
||||
self.content += ' if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n'
|
||||
if ret_type != 'void':
|
||||
|
||||
@@ -37,21 +37,20 @@ template <uint32_t N> class CallbackTable {
|
||||
// callback is enabled.
|
||||
: callbacks_() {}
|
||||
|
||||
void Set(uint32_t id, activity_rtapi_callback_t callback, void* arg) {
|
||||
assert(id < N && "id is out of range");
|
||||
void Set(uint32_t callback_id, activity_rtapi_callback_t callback_function, void* user_arg) {
|
||||
assert(callback_id < N && "callback_id is out of range");
|
||||
std::lock_guard lock(mutex_);
|
||||
callbacks_[id] = {callback, arg};
|
||||
callbacks_[callback_id] = {callback_function, user_arg};
|
||||
}
|
||||
|
||||
void Get(uint32_t id, activity_rtapi_callback_t* callback, void** arg) const {
|
||||
assert(id < N && "id is out of range");
|
||||
assert(callback != nullptr && arg != nullptr && "invalid arguments");
|
||||
std::pair<activity_rtapi_callback_t, void*> Get(uint32_t callback_id) const {
|
||||
assert(callback_id < N && "id is out of range");
|
||||
std::lock_guard lock(mutex_);
|
||||
std::tie(*callback, *arg) = callbacks_[id];
|
||||
return callbacks_[callback_id];
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<std::pair<activity_rtapi_callback_t /* callback */, void* /* arg */>, N> callbacks_;
|
||||
std::array<std::pair<activity_rtapi_callback_t, void*>, N> callbacks_;
|
||||
mutable std::mutex mutex_;
|
||||
};
|
||||
|
||||
|
||||
@@ -77,11 +77,8 @@ typedef enum {
|
||||
//
|
||||
namespace roctx {
|
||||
|
||||
// ROCTX callbacks table type
|
||||
typedef roctracer::CallbackTable<ROCTX_API_ID_NUMBER> cb_table_t;
|
||||
|
||||
// callbacks table
|
||||
cb_table_t cb_table;
|
||||
// ROCTX callbacks table
|
||||
roctracer::CallbackTable<ROCTX_API_ID_NUMBER> callbacks;
|
||||
|
||||
|
||||
typedef std::stack<std::string> message_stack_t;
|
||||
@@ -105,8 +102,6 @@ void thread_data_init() {
|
||||
thread_map[tid] = message_stack;
|
||||
}
|
||||
|
||||
// callbacks table
|
||||
extern cb_table_t cb_table;
|
||||
} // namespace roctx
|
||||
|
||||
// Logger instantiation
|
||||
@@ -130,9 +125,7 @@ PUBLIC_API void roctxMarkA(const char* message) {
|
||||
API_METHOD_PREFIX
|
||||
roctx_api_data_t api_data{};
|
||||
api_data.args.roctxMarkA.message = strdup(message);
|
||||
activity_rtapi_callback_t api_callback_fun = NULL;
|
||||
void* api_callback_arg = NULL;
|
||||
roctx::cb_table.Get(ROCTX_API_ID_roctxMarkA, &api_callback_fun, &api_callback_arg);
|
||||
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxMarkA);
|
||||
if (api_callback_fun)
|
||||
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxMarkA, &api_data, api_callback_arg);
|
||||
API_METHOD_SUFFIX_NRET
|
||||
@@ -144,9 +137,7 @@ PUBLIC_API int roctxRangePushA(const char* message) {
|
||||
|
||||
roctx_api_data_t api_data{};
|
||||
api_data.args.roctxRangePushA.message = strdup(message);
|
||||
activity_rtapi_callback_t api_callback_fun = NULL;
|
||||
void* api_callback_arg = NULL;
|
||||
roctx::cb_table.Get(ROCTX_API_ID_roctxRangePushA, &api_callback_fun, &api_callback_arg);
|
||||
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePushA);
|
||||
if (api_callback_fun)
|
||||
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, &api_data,
|
||||
api_callback_arg);
|
||||
@@ -161,9 +152,7 @@ PUBLIC_API int roctxRangePop() {
|
||||
if (roctx::message_stack == NULL) roctx::thread_data_init();
|
||||
|
||||
roctx_api_data_t api_data{};
|
||||
activity_rtapi_callback_t api_callback_fun = NULL;
|
||||
void* api_callback_arg = NULL;
|
||||
roctx::cb_table.Get(ROCTX_API_ID_roctxRangePop, &api_callback_fun, &api_callback_arg);
|
||||
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePop);
|
||||
if (api_callback_fun)
|
||||
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, &api_data,
|
||||
api_callback_arg);
|
||||
@@ -182,10 +171,7 @@ PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) {
|
||||
|
||||
roctx_api_data_t api_data{};
|
||||
api_data.args.roctxRangeStartA.message = strdup(message);
|
||||
api_data.args.roctxRangeStartA.id = roctx_range_counter;
|
||||
activity_rtapi_callback_t api_callback_fun = NULL;
|
||||
void* api_callback_arg = NULL;
|
||||
roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStartA, &api_callback_fun, &api_callback_arg);
|
||||
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStartA);
|
||||
if (api_callback_fun)
|
||||
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data,
|
||||
api_callback_arg);
|
||||
@@ -198,9 +184,7 @@ PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) {
|
||||
API_METHOD_PREFIX
|
||||
roctx_api_data_t api_data{};
|
||||
api_data.args.roctxRangeStop.id = rangeId;
|
||||
activity_rtapi_callback_t api_callback_fun = NULL;
|
||||
void* api_callback_arg = NULL;
|
||||
roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStop, &api_callback_fun, &api_callback_arg);
|
||||
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStop);
|
||||
if (api_callback_fun)
|
||||
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data,
|
||||
api_callback_arg);
|
||||
@@ -222,13 +206,13 @@ PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg)
|
||||
|
||||
PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) {
|
||||
if (op >= ROCTX_API_ID_NUMBER) return false;
|
||||
roctx::cb_table.Set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
|
||||
roctx::callbacks.Set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
|
||||
return true;
|
||||
}
|
||||
|
||||
PUBLIC_API bool RemoveApiCallback(uint32_t op) {
|
||||
if (op >= ROCTX_API_ID_NUMBER) return false;
|
||||
roctx::cb_table.Set(op, nullptr, nullptr);
|
||||
roctx::callbacks.Set(op, nullptr, nullptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
Ссылка в новой задаче
Block a user