From 4a04400f8524fd3d6710947fa4a9732d79f445db Mon Sep 17 00:00:00 2001 From: Laurent Morichetti Date: Fri, 6 May 2022 12:29:50 -0700 Subject: [PATCH] Cleanup ROCTX's implementation Remove thread_data_init. The C++ standard guarantees that the thread local variable is initialized before its first odr-use and destructed when the thread exits. Use a global initializer to set the reference from the message stack instance in the map. Remove roctracer_error_string. This does not belong to this library. ROCTX does not expose errors to the application. The only functions returning errors are returning -1 (Push/Pop). Remove memory leaks due to strdup on the ranges messages. The memory for the messages is guaranteed to be valid for the duration of the callback, and it is the application's responsibility to strdup the strings if it needs to extend the message's lifetime. Add a lock to the RegisterApiCallback implementation. Iterating the message stack map must be synchronized as a new thread could be adding a new value to the map. Change-Id: Iaf5b07ebc9efe4061cb01327d4c7034888727816 [ROCm/roctracer commit: 713db1fce5869523db6a1f02318c0767cc221e18] --- projects/roctracer/inc/roctx.h | 4 - projects/roctracer/src/roctx/roctx.cpp | 156 ++++++++++--------------- 2 files changed, 60 insertions(+), 100 deletions(-) diff --git a/projects/roctracer/inc/roctx.h b/projects/roctracer/inc/roctx.h index 1f7f6f74b4..d0631e4998 100644 --- a/projects/roctracer/inc/roctx.h +++ b/projects/roctracer/inc/roctx.h @@ -45,10 +45,6 @@ extern "C" { uint32_t roctx_version_major(); uint32_t roctx_version_minor(); -//////////////////////////////////////////////////////////////////////////////// -// Returning the last error -const char* roctracer_error_string(); - //////////////////////////////////////////////////////////////////////////////// // Markers annotating API diff --git a/projects/roctracer/src/roctx/roctx.cpp b/projects/roctracer/src/roctx/roctx.cpp index 1695f3e10e..7c049c3f7e 100644 --- a/projects/roctracer/src/roctx/roctx.cpp +++ b/projects/roctracer/src/roctx/roctx.cpp @@ -21,10 +21,12 @@ #include "inc/roctx.h" #include "inc/roctracer_roctx.h" -#include -#include +#include +#include +#include #include #include +#include #include "inc/ext/prof_protocol.h" #include "core/callback_table.h" @@ -32,38 +34,21 @@ #include "util/logger.h" #define PUBLIC_API __attribute__((visibility("default"))) -#define CONSTRUCTOR_API __attribute__((constructor)) -#define DESTRUCTOR_API __attribute__((destructor)) - -#define API_METHOD_PREFIX \ - roctx_status_t err = ROCTX_STATUS_SUCCESS; \ - try { -#define API_METHOD_SUFFIX \ - } \ - catch (std::exception & e) { \ - ERR_LOGGING(__FUNCTION__ << "(), " << e.what()); \ - err = roctx::GetExcStatus(e); \ - } \ - return (err == ROCTX_STATUS_SUCCESS) ? 0 : -1; +#define API_METHOD_PREFIX try { #define API_METHOD_SUFFIX_NRET \ } \ - catch (std::exception & e) { \ + catch (const std::exception& e) { \ ERR_LOGGING(__FUNCTION__ << "(), " << e.what()); \ - err = roctx::GetExcStatus(e); \ - } \ - (void)err; + } #define API_METHOD_CATCH(X) \ } \ - catch (std::exception & e) { \ + catch (const std::exception& e) { \ ERR_LOGGING(__FUNCTION__ << "(), " << e.what()); \ + return X; \ } \ - (void)err; \ - return X; - -inline uint32_t GetPid() { return syscall(__NR_getpid); } -inline uint32_t GetTid() { return syscall(__NR_gettid); } + assert(false && "should not reach here"); //////////////////////////////////////////////////////////////////////////////// // Library errors enumeration @@ -75,129 +60,110 @@ typedef enum { /////////////////////////////////////////////////////////////////////////////////////////////////// // Library implementation // -namespace roctx { +namespace { -// ROCTX callbacks table roctracer::CallbackTable callbacks; +std::unordered_map> message_stack_map; +std::mutex message_stack_mutex; +thread_local auto& message_stack = []() -> decltype(message_stack_map)::mapped_type& { + const auto tid = syscall(__NR_gettid); + std::lock_guard lock(message_stack_mutex); + return message_stack_map[tid]; +}(); -typedef std::stack message_stack_t; -typedef std::map thread_map_t; -typedef std::mutex map_mutex_t; -map_mutex_t map_mutex; -thread_map_t thread_map; -static thread_local message_stack_t* message_stack = NULL; - -roctx_status_t GetExcStatus(const std::exception& e) { - const roctracer::util::exception* roctx_exc_ptr = - dynamic_cast*>(&e); - return (roctx_exc_ptr) ? roctx_exc_ptr->status() : ROCTX_STATUS_ERROR; -} - -void thread_data_init() { - message_stack = new message_stack_t; - const auto tid = GetTid(); - - std::lock_guard lck(map_mutex); - thread_map[tid] = message_stack; -} - -} // namespace roctx +} // namespace // Logger instantiation roctracer::util::Logger::mutex_t roctracer::util::Logger::mutex_; std::atomic roctracer::util::Logger::instance_{}; -std::atomic roctx_range_counter(0); /////////////////////////////////////////////////////////////////////////////////////////////////// // Public library methods // -extern "C" { PUBLIC_API uint32_t roctx_version_major() { return ROCTX_VERSION_MAJOR; } PUBLIC_API uint32_t roctx_version_minor() { return ROCTX_VERSION_MINOR; } -PUBLIC_API const char* roctracer_error_string() { - return strdup(roctracer::util::Logger::LastMessage().c_str()); -} - PUBLIC_API void roctxMarkA(const char* message) { API_METHOD_PREFIX - roctx_api_data_t api_data{}; - api_data.args.roctxMarkA.message = strdup(message); - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxMarkA); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxMarkA); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; + api_data.args.roctxMarkA.message = message; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxMarkA, &api_data, api_callback_arg); + } API_METHOD_SUFFIX_NRET } PUBLIC_API int roctxRangePushA(const char* message) { API_METHOD_PREFIX - if (roctx::message_stack == NULL) roctx::thread_data_init(); - - roctx_api_data_t api_data{}; - api_data.args.roctxRangePushA.message = strdup(message); - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePushA); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangePushA); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; + api_data.args.roctxRangePushA.message = message; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, &api_data, api_callback_arg); - roctx::message_stack->push(strdup(message)); + } - return roctx::message_stack->size() - 1; + message_stack.emplace(message); + return message_stack.size() - 1; API_METHOD_CATCH(-1); } PUBLIC_API int roctxRangePop() { API_METHOD_PREFIX - if (roctx::message_stack == NULL) roctx::thread_data_init(); - - roctx_api_data_t api_data{}; - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePop); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangePop); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, &api_data, api_callback_arg); - if (roctx::message_stack->empty()) { + } + + if (message_stack.empty()) { EXC_RAISING(ROCTX_STATUS_ERROR, "Pop from empty stack!"); } - roctx::message_stack->pop(); - return roctx::message_stack->size(); + message_stack.pop(); + return message_stack.size(); API_METHOD_CATCH(-1) } PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) { API_METHOD_PREFIX - roctx_range_counter++; + static std::atomic roctx_range_counter(1); - roctx_api_data_t api_data{}; - api_data.args.roctxRangeStartA.message = strdup(message); - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStartA); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangeStartA); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; + api_data.args.roctxRangeStartA.message = message; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data, api_callback_arg); + } - return roctx_range_counter; - API_METHOD_CATCH(-1); + return roctx_range_counter++; + API_METHOD_CATCH(-1) } PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) { API_METHOD_PREFIX - roctx_api_data_t api_data{}; - api_data.args.roctxRangeStop.id = rangeId; - auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStop); - if (api_callback_fun) + if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangeStop); + api_callback_fun != nullptr) { + roctx_api_data_t api_data{}; + api_data.args.roctxRangeStop.id = rangeId; api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data, api_callback_arg); + } API_METHOD_SUFFIX_NRET } PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg) { - for (const auto& entry : roctx::thread_map) { - const auto tid = entry.first; - for (roctx::message_stack_t stack = *(entry.second); !stack.empty(); stack.pop()) { - std::string message = stack.top(); + std::lock_guard lock(message_stack_mutex); + for (auto&& [tid, message_stack] : message_stack_map) { + // Since we can't iterate a std::stack, we must first make a copy and then unwind it. + for (auto stack_copy = message_stack; !stack_copy.empty(); stack_copy.pop()) { roctx_range_data_t data{}; - data.message = message.c_str(); + data.message = stack_copy.top().c_str(); data.tid = tid; callback(&data, arg); } @@ -206,14 +172,12 @@ PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg) PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) { if (op >= ROCTX_API_ID_NUMBER) return false; - roctx::callbacks.Set(op, reinterpret_cast(callback), arg); + callbacks.Set(op, reinterpret_cast(callback), arg); return true; } PUBLIC_API bool RemoveApiCallback(uint32_t op) { if (op >= ROCTX_API_ID_NUMBER) return false; - roctx::callbacks.Set(op, nullptr, nullptr); + callbacks.Set(op, nullptr, nullptr); return true; -} - -} // extern "C" +} \ No newline at end of file