Cleanup ROCTX's implementation

Remove thread_data_init. The C++ standard guarantees that the thread
local variable is initialized before its first odr-use and destructed
when the thread exits. Use a global initializer to set the reference
from the message stack instance in the map.

Remove roctracer_error_string. This does not belong to this library.
ROCTX does not expose errors to the application. The only functions
returning errors are returning -1 (Push/Pop).

Remove memory leaks due to strdup on the ranges messages. The memory
for the messages is guaranteed to be valid for the duration of the
callback, and it is the application's responsibility to strdup the
strings if it needs to extend the message's lifetime.

Add a lock to the RegisterApiCallback implementation. Iterating the
message stack map must be synchronized as a new thread could be adding
a new value to the map.

Change-Id: Iaf5b07ebc9efe4061cb01327d4c7034888727816


[ROCm/roctracer commit: 713db1fce5]
This commit is contained in:
Laurent Morichetti
2022-05-06 12:29:50 -07:00
rodzic bac7f1c162
commit 4a04400f85
2 zmienionych plików z 60 dodań i 100 usunięć
-4
Wyświetl plik
@@ -45,10 +45,6 @@ extern "C" {
uint32_t roctx_version_major();
uint32_t roctx_version_minor();
////////////////////////////////////////////////////////////////////////////////
// Returning the last error
const char* roctracer_error_string();
////////////////////////////////////////////////////////////////////////////////
// Markers annotating API
+60 -96
Wyświetl plik
@@ -21,10 +21,12 @@
#include "inc/roctx.h"
#include "inc/roctracer_roctx.h"
#include <string.h>
#include <map>
#include <cassert>
#include <cstring>
#include <unordered_map>
#include <mutex>
#include <stack>
#include <string>
#include "inc/ext/prof_protocol.h"
#include "core/callback_table.h"
@@ -32,38 +34,21 @@
#include "util/logger.h"
#define PUBLIC_API __attribute__((visibility("default")))
#define CONSTRUCTOR_API __attribute__((constructor))
#define DESTRUCTOR_API __attribute__((destructor))
#define API_METHOD_PREFIX \
roctx_status_t err = ROCTX_STATUS_SUCCESS; \
try {
#define API_METHOD_SUFFIX \
} \
catch (std::exception & e) { \
ERR_LOGGING(__FUNCTION__ << "(), " << e.what()); \
err = roctx::GetExcStatus(e); \
} \
return (err == ROCTX_STATUS_SUCCESS) ? 0 : -1;
#define API_METHOD_PREFIX try {
#define API_METHOD_SUFFIX_NRET \
} \
catch (std::exception & e) { \
catch (const std::exception& e) { \
ERR_LOGGING(__FUNCTION__ << "(), " << e.what()); \
err = roctx::GetExcStatus(e); \
} \
(void)err;
}
#define API_METHOD_CATCH(X) \
} \
catch (std::exception & e) { \
catch (const std::exception& e) { \
ERR_LOGGING(__FUNCTION__ << "(), " << e.what()); \
return X; \
} \
(void)err; \
return X;
inline uint32_t GetPid() { return syscall(__NR_getpid); }
inline uint32_t GetTid() { return syscall(__NR_gettid); }
assert(false && "should not reach here");
////////////////////////////////////////////////////////////////////////////////
// Library errors enumeration
@@ -75,129 +60,110 @@ typedef enum {
///////////////////////////////////////////////////////////////////////////////////////////////////
// Library implementation
//
namespace roctx {
namespace {
// ROCTX callbacks table
roctracer::CallbackTable<ROCTX_API_ID_NUMBER> callbacks;
std::unordered_map<uint32_t, std::stack<std::string>> message_stack_map;
std::mutex message_stack_mutex;
thread_local auto& message_stack = []() -> decltype(message_stack_map)::mapped_type& {
const auto tid = syscall(__NR_gettid);
std::lock_guard lock(message_stack_mutex);
return message_stack_map[tid];
}();
typedef std::stack<std::string> message_stack_t;
typedef std::map<uint32_t, message_stack_t*> thread_map_t;
typedef std::mutex map_mutex_t;
map_mutex_t map_mutex;
thread_map_t thread_map;
static thread_local message_stack_t* message_stack = NULL;
roctx_status_t GetExcStatus(const std::exception& e) {
const roctracer::util::exception<roctx_status_t>* roctx_exc_ptr =
dynamic_cast<const roctracer::util::exception<roctx_status_t>*>(&e);
return (roctx_exc_ptr) ? roctx_exc_ptr->status() : ROCTX_STATUS_ERROR;
}
void thread_data_init() {
message_stack = new message_stack_t;
const auto tid = GetTid();
std::lock_guard<map_mutex_t> lck(map_mutex);
thread_map[tid] = message_stack;
}
} // namespace roctx
} // namespace
// Logger instantiation
roctracer::util::Logger::mutex_t roctracer::util::Logger::mutex_;
std::atomic<roctracer::util::Logger*> roctracer::util::Logger::instance_{};
std::atomic<int> roctx_range_counter(0);
///////////////////////////////////////////////////////////////////////////////////////////////////
// Public library methods
//
extern "C" {
PUBLIC_API uint32_t roctx_version_major() { return ROCTX_VERSION_MAJOR; }
PUBLIC_API uint32_t roctx_version_minor() { return ROCTX_VERSION_MINOR; }
PUBLIC_API const char* roctracer_error_string() {
return strdup(roctracer::util::Logger::LastMessage().c_str());
}
PUBLIC_API void roctxMarkA(const char* message) {
API_METHOD_PREFIX
roctx_api_data_t api_data{};
api_data.args.roctxMarkA.message = strdup(message);
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxMarkA);
if (api_callback_fun)
if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxMarkA);
api_callback_fun != nullptr) {
roctx_api_data_t api_data{};
api_data.args.roctxMarkA.message = message;
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxMarkA, &api_data, api_callback_arg);
}
API_METHOD_SUFFIX_NRET
}
PUBLIC_API int roctxRangePushA(const char* message) {
API_METHOD_PREFIX
if (roctx::message_stack == NULL) roctx::thread_data_init();
roctx_api_data_t api_data{};
api_data.args.roctxRangePushA.message = strdup(message);
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePushA);
if (api_callback_fun)
if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangePushA);
api_callback_fun != nullptr) {
roctx_api_data_t api_data{};
api_data.args.roctxRangePushA.message = message;
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePushA, &api_data,
api_callback_arg);
roctx::message_stack->push(strdup(message));
}
return roctx::message_stack->size() - 1;
message_stack.emplace(message);
return message_stack.size() - 1;
API_METHOD_CATCH(-1);
}
PUBLIC_API int roctxRangePop() {
API_METHOD_PREFIX
if (roctx::message_stack == NULL) roctx::thread_data_init();
roctx_api_data_t api_data{};
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangePop);
if (api_callback_fun)
if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangePop);
api_callback_fun != nullptr) {
roctx_api_data_t api_data{};
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangePop, &api_data,
api_callback_arg);
if (roctx::message_stack->empty()) {
}
if (message_stack.empty()) {
EXC_RAISING(ROCTX_STATUS_ERROR, "Pop from empty stack!");
}
roctx::message_stack->pop();
return roctx::message_stack->size();
message_stack.pop();
return message_stack.size();
API_METHOD_CATCH(-1)
}
PUBLIC_API roctx_range_id_t roctxRangeStartA(const char* message) {
API_METHOD_PREFIX
roctx_range_counter++;
static std::atomic<roctx_range_id_t> roctx_range_counter(1);
roctx_api_data_t api_data{};
api_data.args.roctxRangeStartA.message = strdup(message);
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStartA);
if (api_callback_fun)
if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangeStartA);
api_callback_fun != nullptr) {
roctx_api_data_t api_data{};
api_data.args.roctxRangeStartA.message = message;
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStartA, &api_data,
api_callback_arg);
}
return roctx_range_counter;
API_METHOD_CATCH(-1);
return roctx_range_counter++;
API_METHOD_CATCH(-1)
}
PUBLIC_API void roctxRangeStop(roctx_range_id_t rangeId) {
API_METHOD_PREFIX
roctx_api_data_t api_data{};
api_data.args.roctxRangeStop.id = rangeId;
auto [api_callback_fun, api_callback_arg] = roctx::callbacks.Get(ROCTX_API_ID_roctxRangeStop);
if (api_callback_fun)
if (auto [api_callback_fun, api_callback_arg] = callbacks.Get(ROCTX_API_ID_roctxRangeStop);
api_callback_fun != nullptr) {
roctx_api_data_t api_data{};
api_data.args.roctxRangeStop.id = rangeId;
api_callback_fun(ACTIVITY_DOMAIN_ROCTX, ROCTX_API_ID_roctxRangeStop, &api_data,
api_callback_arg);
}
API_METHOD_SUFFIX_NRET
}
PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg) {
for (const auto& entry : roctx::thread_map) {
const auto tid = entry.first;
for (roctx::message_stack_t stack = *(entry.second); !stack.empty(); stack.pop()) {
std::string message = stack.top();
std::lock_guard lock(message_stack_mutex);
for (auto&& [tid, message_stack] : message_stack_map) {
// Since we can't iterate a std::stack, we must first make a copy and then unwind it.
for (auto stack_copy = message_stack; !stack_copy.empty(); stack_copy.pop()) {
roctx_range_data_t data{};
data.message = message.c_str();
data.message = stack_copy.top().c_str();
data.tid = tid;
callback(&data, arg);
}
@@ -206,14 +172,12 @@ PUBLIC_API void RangeStackIterate(roctx_range_iterate_cb_t callback, void* arg)
PUBLIC_API bool RegisterApiCallback(uint32_t op, void* callback, void* arg) {
if (op >= ROCTX_API_ID_NUMBER) return false;
roctx::callbacks.Set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
callbacks.Set(op, reinterpret_cast<activity_rtapi_callback_t>(callback), arg);
return true;
}
PUBLIC_API bool RemoveApiCallback(uint32_t op) {
if (op >= ROCTX_API_ID_NUMBER) return false;
roctx::callbacks.Set(op, nullptr, nullptr);
callbacks.Set(op, nullptr, nullptr);
return true;
}
} // extern "C"
}