diff --git a/projects/roctracer/inc/roctx.h b/projects/roctracer/inc/roctx.h index 1f7f6f74b4..d0631e4998 100644 --- a/projects/roctracer/inc/roctx.h +++ b/projects/roctracer/inc/roctx.h @@ -45,10 +45,6 @@ extern "C" { uint32_t roctx_version_major(); uint32_t roctx_version_minor(); -//////////////////////////////////////////////////////////////////////////////// -// Returning the last error -const char* roctracer_error_string(); - //////////////////////////////////////////////////////////////////////////////// // Markers annotating API diff --git a/projects/roctracer/src/roctx/roctx.cpp b/projects/roctracer/src/roctx/roctx.cpp index 1695f3e10e..7c049c3f7e 100644 --- a/projects/roctracer/src/roctx/roctx.cpp +++ b/projects/roctracer/src/roctx/roctx.cpp @@ -21,10 +21,12 @@ #include "inc/roctx.h" #include "inc/roctracer_roctx.h" -#include -#include +#include +#include +#include #include #include +#include #include "inc/ext/prof_protocol.h" #include "core/callback_table.h" @@ -32,38 +34,21 @@ #include "util/logger.h" #define PUBLIC_API __attribute__((visibility("default"))) -#define CONSTRUCTOR_API __attribute__((constructor)) -#define DESTRUCTOR_API __attribute__((destructor)) - -#define API_METHOD_PREFIX \ - roctx_status_t err = ROCTX_STATUS_SUCCESS; \ - try { -#define API_METHOD_SUFFIX \ - } \ - catch (std::exception & e) { \ - ERR_LOGGING(__FUNCTION__ << "(), " << e.what()); \ - err = roctx::GetExcStatus(e); \ - } \ - return (err == ROCTX_STATUS_SUCCESS) ? 0 : -1; +#define API_METHOD_PREFIX try { #define API_METHOD_SUFFIX_NRET \ } \ - catch (std::exception & e) { \ + catch (const std::exception& e) { \ ERR_LOGGING(__FUNCTION__ << "(), " << e.what()); \ - err = roctx::GetExcStatus(e); \ - } \ - (void)err; + } #define API_METHOD_CATCH(X) \ } \ - catch (std::exception & e) { \ + catch (const std::exception& e) { \ ERR_LOGGING(__FUNCTION__ << "(), " << e.what()); \ + return X; \ } \ - (void)err; \ - return X; - -inline uint32_t GetPid() { return syscall(__NR_getpid); } -inline uint32_t GetTid() { return syscall(__NR_gettid); } + assert(false && "should not reach here"); //////////////////////////////////////////////////////////////////////////////// // Library errors enumeration @@ -75,129 +60,110 @@ typedef enum { /////////////////////////////////////////////////////////////////////////////////////////////////// // Library implementation // -namespace roctx { +namespace { -// ROCTX callbacks table roctracer::CallbackTable callbacks; +std::unordered_map> message_stack_map; +std::mutex message_stack_mutex; +thread_local auto& message_stack = []() -> decltype(message_stack_map)::mapped_type& { + const auto tid = syscall(__NR_gettid); + std::lock_guard lock(message_stack_mutex); + return message_stack_map[tid]; +}(); -typedef std::stack message_stack_t; -typedef std::map thread_map_t; -typedef std::mutex map_mutex_t; -map_mutex_t map_mutex; -thread_map_t thread_map; -static thread_local message_stack_t* message_stack = NULL; - -roctx_status_t GetExcStatus(const std::exception& e) { - const roctracer::util::exception* roctx_exc_ptr = - dynamic_cast*>(&e); - return (roctx_exc_ptr) ? roctx_exc_ptr->status() : ROCTX_STATUS_ERROR; -} - -void thread_data_init() { - message_stack = new message_stack_t; - const auto tid = GetTid(); - - std::lock_guard lck(map_mutex); - thread_map[tid] = message_stack; -} - -} // namespace roctx +} // namespace // Logger instantiation roctracer::util::Logger::mutex_t roctracer::util::Logger::mutex_; std::atomic roctracer::util::Logger::instance_{}; -std::atomic roctx_range_counter(0); /////////////////////////////////////////////////////////////////////////////////////////////////// // Public library methods // -extern "C" { PUBLIC_API uint32_t roctx_version_major() { return ROCTX_VERSION_MAJOR; } PUBLIC_API uint32_t roctx_version_minor() { return ROCTX_VERSION_MINOR; } -PUBLIC_API const char* roctracer_error_string() { - return strdup(roctracer::util::Logger::LastMessage().c_str()); -} - PUBLIC_API void roctxMarkA(const char* message) { API_METHOD_PREFIX - roctx_api_data_t api_data{}; - api_data.args.roctxMarkA.message = strdup(message); - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxMarkA); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxMarkA); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; + api_data.args.roctxMarkA.message = message; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxMarkA, &api_data, api_callback_arg); + } API_METHOD_SUFFIX_NRET } PUBLIC_API int roctxRangePushA(const char* message) { API_METHOD_PREFIX - if (roctx::message_stack == NULL) roctx::thread_data_init(); - - roctx_api_data_t api_data{}; - api_data.args.roctxRangePushA.message = strdup(message); - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePushA); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangePushA); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; + api_data.args.roctxRangePushA.message = message; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, &api_data, api_callback_arg); - roctx::message_stack->push(strdup(message)); + } - return roctx::message_stack->size() - 1; + message_stack.emplace(message); + return message_stack.size() - 1; API_METHOD_CATCH(-1); } PUBLIC_API int roctxRangePop() { API_METHOD_PREFIX - if (roctx::message_stack == NULL) roctx::thread_data_init(); - - roctx_api_data_t api_data{}; - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePop); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangePop); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, &api_data, api_callback_arg); - if (roctx::message_stack->empty()) { + } + + if (message_stack.empty()) { EXC_RAISING(ROCTX_STATUS_ERROR, "Pop from empty stack!"); } - roctx::message_stack->pop(); - return roctx::message_stack->size(); + message_stack.pop(); + return message_stack.size(); API_METHOD_CATCH(-1) } PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) { API_METHOD_PREFIX - roctx_range_counter++; + static std::atomic roctx_range_counter(1); - roctx_api_data_t api_data{}; - api_data.args.roctxRangeStartA.message = strdup(message); - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStartA); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangeStartA); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; + api_data.args.roctxRangeStartA.message = message; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data, api_callback_arg); + } - return roctx_range_counter; - API_METHOD_CATCH(-1); + return roctx_range_counter++; + API_METHOD_CATCH(-1) } PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) { API_METHOD_PREFIX - roctx_api_data_t api_data{}; - api_data.args.roctxRangeStop.id = rangeId; - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStop); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangeStop); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; + api_data.args.roctxRangeStop.id = rangeId; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data, api_callback_arg); + } API_METHOD_SUFFIX_NRET } PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg) { - for (const auto& entry : roctx::thread_map) { - const auto tid = entry.first; - for (roctx::message_stack_t stack = *(entry.second); !stack.empty(); stack.pop()) { - std::string message = stack.top(); + std::lock_guard lock(message_stack_mutex); + for (auto&& [tid, message_stack] : message_stack_map) { + // Since we can't iterate a std::stack, we must first make a copy and then unwind it. + for (auto stack_copy = message_stack; !stack_copy.empty(); stack_copy.pop()) { roctx_range_data_t data{}; - data.message = message.c_str(); + data.message = stack_copy.top().c_str(); data.tid = tid; callback(&data, arg); } @@ -206,14 +172,12 @@ PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg) PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) { if (op >= ROCTX_API_ID_NUMBER) return false; - roctx::callbacks.Set(op, reinterpret_cast(callback), arg); + callbacks.Set(op, reinterpret_cast(callback), arg); return true; } PUBLIC_API bool RemoveApiCallback(uint32_t op) { if (op >= ROCTX_API_ID_NUMBER) return false; - roctx::callbacks.Set(op, nullptr, nullptr); + callbacks.Set(op, nullptr, nullptr); return true; -} - -} // extern "C" +} \ No newline at end of file