Merge "Cleanup CallbackTable::Get" into amd-staging

[ROCm/roctracer commit: 6e4055503c]
This commit is contained in:
Laurent Morichetti
2022-05-10 14:55:20 -04:00
zatwierdzone przez Gerrit Code Review
3 zmienionych plików z 17 dodań i 36 usunięć
+1 -3
Wyświetl plik
@@ -411,9 +411,7 @@ class API_DescrParser:
self.content += ' api_data.args.' + call + '.' + var + ' = ' + var + ';\n'
if call == 'hsa_amd_memory_async_copy_rect' and var == 'range':
self.content += ' api_data.args.' + call + '.' + var + '__val = ' + '*(' + var + ');\n'
self.content += ' activity_rtapi_callback_t api_callback_fun = NULL;\n'
self.content += ' void* api_callback_arg = NULL;\n'
self.content += ' cb_table.Get(' + call_id + ', &api_callback_fun, &api_callback_arg);\n'
self.content += ' auto [ api_callback_fun, api_callback_arg ] = cb_table.Get(' + call_id + ');\n'
self.content += ' api_data.phase = 0;\n'
self.content += ' if (api_callback_fun) api_callback_fun(ACTIVITY_DOMAIN_HSA_API, ' + call_id + ', &api_data, api_callback_arg);\n'
if ret_type != 'void':
@@ -37,21 +37,20 @@ template <uint32_t N> class CallbackTable {
// callback is enabled.
: callbacks_() {}
void Set(uint32_t id, activity_rtapi_callback_t callback, void* arg) {
assert(id < N && "id is out of range");
void Set(uint32_t callback_id, activity_rtapi_callback_t callback_function, void* user_arg) {
assert(callback_id < N && "callback_id is out of range");
std::lock_guard lock(mutex_);
callbacks_[id] = {callback, arg};
callbacks_[callback_id] = {callback_function, user_arg};
}
void Get(uint32_t id, activity_rtapi_callback_t* callback, void** arg) const {
assert(id < N && "id is out of range");
assert(callback != nullptr && arg != nullptr && "invalid arguments");
std::pair<activity_rtapi_callback_t, void*> Get(uint32_t callback_id) const {
assert(callback_id < N && "id is out of range");
std::lock_guard lock(mutex_);
std::tie(*callback, *arg) = callbacks_[id];
return callbacks_[callback_id];
}
private:
std::array<std::pair<activity_rtapi_callback_t /* callback */, void* /* arg */>, N> callbacks_;
std::array<std::pair<activity_rtapi_callback_t, void*>, N> callbacks_;
mutable std::mutex mutex_;
};
+9 -25
Wyświetl plik
@@ -77,11 +77,8 @@ typedef enum {
//
namespace roctx {
// ROCTX callbacks table type
typedef roctracer::CallbackTable<ROCTX_API_ID_NUMBER> cb_table_t;
// callbacks table
cb_table_t cb_table;
// ROCTX callbacks table
roctracer::CallbackTable<ROCTX_API_ID_NUMBER> callbacks;
typedef std::stack<std::string> message_stack_t;
@@ -105,8 +102,6 @@ void thread_data_init() {
thread_map[tid] = message_stack;
}
// callbacks table
extern cb_table_t cb_table;
} // namespace roctx
// Logger instantiation
@@ -130,9 +125,7 @@ PUBLIC_API void roctxMarkA(const char* message) {
API_METHOD_PREFIX
roctx_api_data_t api_data{};
api_data.args.roctxMarkA.message = strdup(message);
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.Get(ROCTX_API_ID_roctxMarkA, &api_callback_fun, &api_callback_arg);
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxMarkA);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxMarkA, &api_data, api_callback_arg);
API_METHOD_SUFFIX_NRET
@@ -144,9 +137,7 @@ PUBLIC_API int roctxRangePushA(const char* message) {
roctx_api_data_t api_data{};
api_data.args.roctxRangePushA.message = strdup(message);
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.Get(ROCTX_API_ID_roctxRangePushA, &api_callback_fun, &api_callback_arg);
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePushA);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, &api_data,
api_callback_arg);
@@ -161,9 +152,7 @@ PUBLIC_API int roctxRangePop() {
if (roctx::message_stack == NULL) roctx::thread_data_init();
roctx_api_data_t api_data{};
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.Get(ROCTX_API_ID_roctxRangePop, &api_callback_fun, &api_callback_arg);
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePop);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, &api_data,
api_callback_arg);
@@ -182,10 +171,7 @@ PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) {
roctx_api_data_t api_data{};
api_data.args.roctxRangeStartA.message = strdup(message);
api_data.args.roctxRangeStartA.id = roctx_range_counter;
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStartA, &api_callback_fun, &api_callback_arg);
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStartA);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data,
api_callback_arg);
@@ -198,9 +184,7 @@ PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) {
API_METHOD_PREFIX
roctx_api_data_t api_data{};
api_data.args.roctxRangeStop.id = rangeId;
activity_rtapi_callback_t api_callback_fun = NULL;
void* api_callback_arg = NULL;
roctx::cb_table.Get(ROCTX_API_ID_roctxRangeStop, &api_callback_fun, &api_callback_arg);
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStop);
if (api_callback_fun)
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data,
api_callback_arg);
@@ -222,13 +206,13 @@ PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg)
PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) {
if (op >= ROCTX_API_ID_NUMBER) return false;
roctx::cb_table.Set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
roctx::callbacks.Set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
return true;
}
PUBLIC_API bool RemoveApiCallback(uint32_t op) {
if (op >= ROCTX_API_ID_NUMBER) return false;
roctx::cb_table.Set(op, nullptr, nullptr);
roctx::callbacks.Set(op, nullptr, nullptr);
return true;
}