Simplify implementation of journal.h.

Change-Id: I9e2e93fd3cd3391fdf182249f5c4c5ef3debae03


[ROCm/roctracer commit: 89f6880371]
Этот коммит содержится в:
Laurent Morichetti
2022-04-18 15:02:22 -07:00
родитель 80c01a27c0
Коммит b9bbce0017
2 изменённых файлов: 67 добавлений и 101 удалений
+24 -62
Просмотреть файл
@@ -21,78 +21,40 @@
#ifndef SRC_CORE_JOURNAL_H_
#define SRC_CORE_JOURNAL_H_
#include <map>
#include "ext/prof_protocol.h"
#include <mutex>
#include <type_traits>
#include <unordered_map>
namespace roctracer {
template <class Data>
class Journal {
public:
typedef std::mutex mutex_t;
typedef std::map<uint32_t, Data> domain_map_t;
typedef std::map<uint32_t, domain_map_t*> journal_map_t;
struct record_t {
uint32_t domain;
uint32_t op;
Data data;
};
Journal() {
domain_mask_ = 0;
map_ = new journal_map_t;
template <typename Data> class Journal {
public:
/* Insert { domain, op } into the journal. Return false if the insertion failed. */
template <typename T = Data, std::enable_if_t<std::is_constructible_v<Data, T>, int> = 0>
bool Insert(roctracer_domain_t domain, uint32_t op, T&& data) {
std::lock_guard lock(mutex_);
return map_[domain].emplace(op, std::forward<T>(data)).second;
}
~Journal() {
for (auto& val : *map_) delete val.second;
delete map_;
/* Remove { domain, op } from the journal. Return false if the entry did not exist. */
bool Remove(roctracer_domain_t domain, uint32_t op) {
std::lock_guard lock(mutex_);
return map_[domain].erase(op) == 1;
}
void registr(const record_t& record) {
std::lock_guard<mutex_t> lck(mutex_);
auto* map = get_domain_map(record.domain);
map->insert({record.op, record.data});
template <typename Functor> void ForEach(Functor&& func) {
std::lock_guard lock(mutex_);
for (auto&& domain : map_)
for (auto&& operation : domain.second)
if (!func(domain.first /* domain */, operation.first /* op */, operation.second /* data */))
break; /* FIXME: what are we breaking out of? */
}
void remove(const record_t& record) {
std::lock_guard<mutex_t> lck(mutex_);
auto* map = get_domain_map(record.domain);
map->erase(record.op);
}
template <class F>
F foreach(const F& f_i) {
std::lock_guard<mutex_t> lck(mutex_);
F f = f_i;
for (uint32_t domain = 0, mask = domain_mask_; mask != 0; ++domain, mask >>= 1) {
if (mask & 1) {
auto map = get_domain_map(domain);
auto begin = map->begin();
auto end = map->end();
for (auto it = begin; it != end; ++it) {
if (f.fun({domain, it->first, it->second}) == false) break;
}
}
}
return f;
}
private:
domain_map_t* get_domain_map(const uint32_t& domain) {
domain_mask_ |= 1u << domain;
auto domain_it = map_->find(domain);
if (domain_it == map_->end()) {
auto* domain_map = new domain_map_t;
auto ret = map_->insert({domain, domain_map});
domain_it = ret.first;
}
return domain_it->second;
}
mutex_t mutex_;
journal_map_t* map_;
uint32_t domain_mask_;
private:
std::mutex mutex_;
std::unordered_map<roctracer_domain_t, std::unordered_map<uint32_t, Data>> map_;
};
} // namespace roctracer
+43 -39
Просмотреть файл
@@ -183,35 +183,39 @@ struct cb_journal_data_t {
roctracer_rtapi_callback_t callback;
void* user_data;
};
typedef Journal<cb_journal_data_t> CbJournal;
using CbJournal = Journal<cb_journal_data_t>;
CbJournal* cb_journal;
struct act_journal_data_t {
roctracer_pool_t* pool;
};
typedef Journal<act_journal_data_t> ActJournal;
using ActJournal = Journal<act_journal_data_t>;
ActJournal* act_journal;
template <class T, class F>
struct journal_functor_t {
typedef typename T::record_t record_t;
F f_;
journal_functor_t(F f) : f_(f) {}
bool fun(const record_t& record) {
f_((activity_domain_t)record.domain, record.op);
template <typename Functor> struct journal_functor_t {
Functor func_;
journal_functor_t(Functor&& f) : func_(std::forward<Functor>(f)) {}
template <typename Data> bool operator ()(activity_domain_t domain, uint32_t op, Data&& /* data */) const {
func_(domain, op);
return true;
}
};
typedef journal_functor_t<CbJournal, roctracer_enable_op_callback_t> cb_en_functor_t;
typedef journal_functor_t<CbJournal, roctracer_disable_op_callback_t> cb_dis_functor_t;
typedef journal_functor_t<ActJournal, roctracer_enable_op_activity_t> act_en_functor_t;
typedef journal_functor_t<ActJournal, roctracer_disable_op_activity_t> act_dis_functor_t;
template<> bool cb_en_functor_t::fun(const cb_en_functor_t::record_t& record) {
f_((activity_domain_t)record.domain, record.op, record.data.callback, record.data.user_data);
using cb_en_functor_t = journal_functor_t<roctracer_enable_op_callback_t>;
using cb_dis_functor_t = journal_functor_t<roctracer_disable_op_callback_t>;
using act_en_functor_t = journal_functor_t<roctracer_enable_op_activity_t>;
using act_dis_functor_t = journal_functor_t<roctracer_disable_op_activity_t>;
template <>
template <typename Data>
bool cb_en_functor_t::operator ()(activity_domain_t domain, uint32_t op, Data&& data) const {
func_(domain, op, data.callback, data.user_data);
return true;
}
template<> bool act_en_functor_t::fun(const act_en_functor_t::record_t& record) {
f_((activity_domain_t)record.domain, record.op, record.data.pool);
template <>
template <typename Data>
bool act_en_functor_t::operator ()(activity_domain_t domain, uint32_t op, Data&& data) const {
func_(domain, op, data.pool);
return true;
}
@@ -420,7 +424,7 @@ void* HIP_SyncActivityCallback(
if (data == NULL) EXC_ABORT(ROCTRACER_STATUS_ERROR, "ActivityCallback: data is NULL");
phase = data->phase;
} else if (pool != NULL) {
phase = ACTIVITY_API_PHASE_EXIT;
phase = ACTIVITY_API_PHASE_EXIT;
}
if (phase == ACTIVITY_API_PHASE_ENTER) {
@@ -820,13 +824,13 @@ static roctracer_status_t roctracer_enable_callback_fun(
}
static void roctracer_enable_callback_impl(
uint32_t domain,
roctracer_domain_t domain,
uint32_t op,
roctracer_rtapi_callback_t callback,
void* user_data)
{
roctracer::cb_journal->registr({domain, op, {callback, user_data}});
roctracer_enable_callback_fun((roctracer_domain_t)domain, op, callback, user_data);
roctracer::cb_journal->Insert(domain, op, {callback, user_data});
roctracer_enable_callback_fun(domain, op, callback, user_data);
}
PUBLIC_API roctracer_status_t roctracer_enable_op_callback(
@@ -860,7 +864,7 @@ PUBLIC_API roctracer_status_t roctracer_enable_callback(
for (uint32_t domain = 0; domain < ACTIVITY_DOMAIN_NUMBER; ++domain) {
const uint32_t op_end = get_op_end(domain);
for (uint32_t op = get_op_begin(domain); op < op_end; ++op)
roctracer_enable_callback_impl(domain, op, callback, user_data);
roctracer_enable_callback_impl((roctracer_domain_t)domain, op, callback, user_data);
}
API_METHOD_SUFFIX
}
@@ -916,11 +920,11 @@ static roctracer_status_t roctracer_disable_callback_fun(
}
static void roctracer_disable_callback_impl(
uint32_t domain,
roctracer_domain_t domain,
uint32_t op)
{
roctracer::cb_journal->remove({domain, op, {}});
roctracer_disable_callback_fun((roctracer_domain_t)domain, op);
roctracer::cb_journal->Remove(domain, op);
roctracer_disable_callback_fun(domain, op);
}
PUBLIC_API roctracer_status_t roctracer_disable_op_callback(
@@ -948,7 +952,7 @@ PUBLIC_API roctracer_status_t roctracer_disable_callback()
for (uint32_t domain = 0; domain < ACTIVITY_DOMAIN_NUMBER; ++domain) {
const uint32_t op_end = get_op_end(domain);
for (uint32_t op = get_op_begin(domain); op < op_end; ++op)
roctracer_disable_callback_impl(domain, op);
roctracer_disable_callback_impl((roctracer_domain_t)domain, op);
}
API_METHOD_SUFFIX
}
@@ -1045,12 +1049,12 @@ static roctracer_status_t roctracer_enable_activity_fun(
}
static void roctracer_enable_activity_impl(
uint32_t domain,
roctracer_domain_t domain,
uint32_t op,
roctracer_pool_t* pool)
{
roctracer::act_journal->registr({domain, op, {pool}});
roctracer_enable_activity_fun((roctracer_domain_t)domain, op, pool);
roctracer::act_journal->Insert(domain, op, {pool});
roctracer_enable_activity_fun(domain, op, pool);
}
PUBLIC_API roctracer_status_t roctracer_enable_op_activity_expl(
@@ -1081,7 +1085,7 @@ PUBLIC_API roctracer_status_t roctracer_enable_activity_expl(
for (uint32_t domain = 0; domain < ACTIVITY_DOMAIN_NUMBER; ++domain) {
const uint32_t op_end = get_op_end(domain);
for (uint32_t op = get_op_begin(domain); op < op_end; ++op)
roctracer_enable_activity_impl(domain, op, pool);
roctracer_enable_activity_impl((roctracer_domain_t)domain, op, pool);
}
API_METHOD_SUFFIX
}
@@ -1132,11 +1136,11 @@ static roctracer_status_t roctracer_disable_activity_fun(
}
static void roctracer_disable_activity_impl(
uint32_t domain,
roctracer_domain_t domain,
uint32_t op)
{
roctracer::act_journal->remove({domain, op, {}});
roctracer_disable_activity_fun((roctracer_domain_t)domain, op);
roctracer::act_journal->Remove(domain, op);
roctracer_disable_activity_fun(domain, op);
}
PUBLIC_API roctracer_status_t roctracer_disable_op_activity(
@@ -1164,7 +1168,7 @@ PUBLIC_API roctracer_status_t roctracer_disable_activity()
for (uint32_t domain = 0; domain < ACTIVITY_DOMAIN_NUMBER; ++domain) {
const uint32_t op_end = get_op_end(domain);
for (uint32_t op = get_op_begin(domain); op < op_end; ++op)
roctracer_disable_activity_impl(domain, op);
roctracer_disable_activity_impl((roctracer_domain_t)domain, op);
}
API_METHOD_SUFFIX
}
@@ -1218,18 +1222,18 @@ PUBLIC_API void roctracer_mark(const char* str) {
PUBLIC_API void roctracer_start() {
if (roctracer::set_stopped(0)) {
if (roctracer::ext_support::roctracer_start_cb) roctracer::ext_support::roctracer_start_cb();
roctracer::cb_journal->foreach(roctracer::cb_en_functor_t(roctracer_enable_callback_fun));
roctracer::act_journal->foreach(roctracer::act_en_functor_t(roctracer_enable_activity_fun));
roctracer::cb_journal->ForEach(roctracer::cb_en_functor_t(roctracer_enable_callback_fun));
roctracer::act_journal->ForEach(roctracer::act_en_functor_t(roctracer_enable_activity_fun));
}
}
// Stop API
PUBLIC_API void roctracer_stop() {
if (roctracer::set_stopped(1)) {
// Must disable the activity first as the spawner checks for the activity being NULL
// Must disable the activity first as the spawner checks for the activity being NULL
// to indicate that there is no callback.
roctracer::act_journal->foreach(roctracer::act_dis_functor_t(roctracer_disable_activity_fun));
roctracer::cb_journal->foreach(roctracer::cb_dis_functor_t(roctracer_disable_callback_fun));
roctracer::act_journal->ForEach(roctracer::act_dis_functor_t(roctracer_disable_activity_fun));
roctracer::cb_journal->ForEach(roctracer::cb_dis_functor_t(roctracer_disable_callback_fun));
if (roctracer::ext_support::roctracer_stop_cb) roctracer::ext_support::roctracer_stop_cb();
}
}