@@ -129,22 +129,16 @@ class MemoryPool {
|
||||
buffer_end_ = buffer_begin_ + buffer_size_;
|
||||
write_ptr_ = buffer_begin_;
|
||||
|
||||
// Pool references
|
||||
buffer_refs_ = new uint32_t[buffer_refs_count_];
|
||||
memset(buffer_refs_, 0, sizeof(uint32_t) * buffer_refs_count_);
|
||||
|
||||
// Consuming read thread
|
||||
read_callback_fun_ = properties.buffer_callback_fun;
|
||||
read_callback_arg_ = properties.buffer_callback_arg;
|
||||
consumer_arg_ = consumer_arg_t{this, true, NULL, NULL};
|
||||
consumer_arg_.set(this, NULL, NULL, true);
|
||||
PTHREAD_CALL(pthread_mutex_init(&read_mutex_, NULL));
|
||||
PTHREAD_CALL(pthread_cond_init(&read_cond_, NULL));
|
||||
PTHREAD_CALL(pthread_create(&consumer_thread_, NULL, reader_fun, &consumer_arg_));
|
||||
}
|
||||
|
||||
~MemoryPool() {
|
||||
std::lock_guard<mutex_t> lock(write_mutex_);
|
||||
|
||||
Flush();
|
||||
PTHREAD_CALL(pthread_cancel(consumer_thread_));
|
||||
void *res;
|
||||
@@ -154,9 +148,38 @@ class MemoryPool {
|
||||
}
|
||||
|
||||
template <typename Record>
|
||||
Record* getRecord() {
|
||||
void Write(const Record& record) {
|
||||
std::lock_guard<mutex_t> lock(write_mutex_);
|
||||
getRecord<Record>(record);
|
||||
}
|
||||
|
||||
void Flush() {
|
||||
std::lock_guard<mutex_t> lock(write_mutex_);
|
||||
if (write_ptr_ > buffer_begin_) {
|
||||
spawn_reader(buffer_begin_, write_ptr_);
|
||||
sync_reader(&consumer_arg_);
|
||||
buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
|
||||
buffer_end_ = buffer_begin_ + buffer_size_;
|
||||
write_ptr_ = buffer_begin_;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
struct consumer_arg_t {
|
||||
MemoryPool* obj;
|
||||
const char* begin;
|
||||
const char* end;
|
||||
std::atomic<bool> valid;
|
||||
void set(MemoryPool* obj_p, const char* begin_p, const char* end_p, bool valid_p) {
|
||||
obj = obj_p;
|
||||
begin = begin_p;
|
||||
end = end_p;
|
||||
valid.store(valid_p);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Record>
|
||||
Record* getRecord(const Record& init) {
|
||||
char* next = write_ptr_ + sizeof(Record);
|
||||
if (next > buffer_end_) {
|
||||
if (write_ptr_ == buffer_begin_) EXC_ABORT(ROCTRACER_STATUS_ERROR, "buffer size(" << buffer_size_ << ") is less then the record(" << sizeof(Record) << ")");
|
||||
@@ -170,40 +193,16 @@ class MemoryPool {
|
||||
Record* ptr = reinterpret_cast<Record*>(write_ptr_);
|
||||
write_ptr_ = next;
|
||||
|
||||
*ptr = {};
|
||||
*ptr = init;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
template <typename Record>
|
||||
void Write(const Record& record) {
|
||||
*getRecord<Record>() = record;
|
||||
}
|
||||
|
||||
void Flush() {
|
||||
if (write_ptr_ > buffer_begin_) {
|
||||
spawn_reader(buffer_begin_, write_ptr_);
|
||||
sync_reader(&consumer_arg_);
|
||||
buffer_begin_ = write_ptr_;
|
||||
}
|
||||
}
|
||||
|
||||
void incrementRef(void* ptr) { buffer_refs_[calc_buffer_index(ptr)] += 1; }
|
||||
void decrementRef(void* ptr) { buffer_refs_[calc_buffer_index(ptr)] -= 1; }
|
||||
|
||||
private:
|
||||
struct consumer_arg_t {
|
||||
MemoryPool* obj;
|
||||
bool valid;
|
||||
const char* begin;
|
||||
const char* end;
|
||||
};
|
||||
|
||||
static void reset_reader(consumer_arg_t* arg) {
|
||||
reinterpret_cast<std::atomic<bool>*>(&(arg->valid))->store(false, std::memory_order_release);
|
||||
arg->valid.store(false);
|
||||
}
|
||||
|
||||
static void sync_reader(const consumer_arg_t* arg) {
|
||||
while(arg->valid) PTHREAD_CALL(pthread_yield());
|
||||
while(arg->valid.load() == true) PTHREAD_CALL(pthread_yield());
|
||||
}
|
||||
|
||||
static void* reader_fun(void* consumer_arg) {
|
||||
@@ -214,13 +213,10 @@ class MemoryPool {
|
||||
|
||||
while (1) {
|
||||
PTHREAD_CALL(pthread_mutex_lock(&(obj->read_mutex_)));
|
||||
while (arg->valid == false) {
|
||||
while (arg->valid.load() == false) {
|
||||
PTHREAD_CALL(pthread_cond_wait(&(obj->read_cond_), &(obj->read_mutex_)));
|
||||
}
|
||||
|
||||
const uint32_t buffer_index = obj->calc_buffer_index(arg->begin);
|
||||
while(obj->buffer_refs_[buffer_index] != 0) PTHREAD_CALL(pthread_yield());
|
||||
|
||||
obj->read_callback_fun_(arg->begin, arg->end, obj->read_callback_arg_);
|
||||
reset_reader(arg);
|
||||
PTHREAD_CALL(pthread_mutex_unlock(&(obj->read_mutex_)));
|
||||
@@ -232,7 +228,7 @@ class MemoryPool {
|
||||
void spawn_reader(const char* data_begin, const char* data_end) {
|
||||
sync_reader(&consumer_arg_);
|
||||
PTHREAD_CALL(pthread_mutex_lock(&read_mutex_));
|
||||
consumer_arg_ = consumer_arg_t{this, true, data_begin, data_end};
|
||||
consumer_arg_.set(this, data_begin, data_end, true);
|
||||
PTHREAD_CALL(pthread_cond_signal(&read_cond_));
|
||||
PTHREAD_CALL(pthread_mutex_unlock(&read_mutex_));
|
||||
}
|
||||
@@ -253,10 +249,6 @@ class MemoryPool {
|
||||
char* write_ptr_;
|
||||
mutex_t write_mutex_;
|
||||
|
||||
// Pool references
|
||||
uint32_t* buffer_refs_;
|
||||
static const uint32_t buffer_refs_count_ = 2;
|
||||
|
||||
// Consuming read thread
|
||||
roctracer_buffer_callback_t read_callback_fun_;
|
||||
void* read_callback_arg_;
|
||||
@@ -298,9 +290,9 @@ DESTRUCTOR_API void destructor() {
|
||||
util::Logger::Destroy();
|
||||
}
|
||||
|
||||
void ActivityCallback(
|
||||
roctracer_record_t* ActivityCallback(
|
||||
uint32_t activity_kind,
|
||||
roctracer_record_t** record,
|
||||
roctracer_record_t* record,
|
||||
const void* callback_data,
|
||||
void* arg)
|
||||
{
|
||||
@@ -310,23 +302,25 @@ void ActivityCallback(
|
||||
MemoryPool* pool = reinterpret_cast<MemoryPool*>(arg);
|
||||
if (pool == NULL) EXC_ABORT(ROCTRACER_STATUS_ERROR, "ActivityCallback pool is NULL");
|
||||
if (data->phase == ROCTRACER_API_PHASE_ENTER) {
|
||||
*record = pool->getRecord<roctracer_record_t>();
|
||||
(*record)->domain = ROCTRACER_DOMAIN_HIP_API;
|
||||
(*record)->activity_kind = activity_kind;
|
||||
(*record)->begin_ns = timer.timestamp_ns();
|
||||
record->domain = ROCTRACER_DOMAIN_HIP_API;
|
||||
record->activity_kind = activity_kind;
|
||||
record->begin_ns = timer.timestamp_ns();
|
||||
// Correlation ID generating
|
||||
uint64_t correlation_id = data->correlation_id;
|
||||
if (correlation_id == 0) {
|
||||
correlation_id = GlobalCounter::Increment();
|
||||
const_cast<hip_cb_data_t*>(data)->correlation_id = correlation_id;
|
||||
}
|
||||
(*record)->correlation_id = correlation_id;
|
||||
record->correlation_id = correlation_id;
|
||||
// Passing record to HCC
|
||||
HSAOp_set_activity_record(correlation_id);
|
||||
return record;
|
||||
} else {
|
||||
(*record)->end_ns = timer.timestamp_ns();
|
||||
record->end_ns = timer.timestamp_ns();
|
||||
pool->Write(*record);
|
||||
// Clearing record in HCC
|
||||
HSAOp_set_activity_record(0);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -336,14 +330,15 @@ void ActivityAsyncCallback(
|
||||
void* arg)
|
||||
{
|
||||
MemoryPool* pool = reinterpret_cast<MemoryPool*>(arg);
|
||||
roctracer_async_record_t* record_ptr = pool->getRecord<roctracer_async_record_t>();
|
||||
*record_ptr = *reinterpret_cast<roctracer_async_record_t*>(record);
|
||||
roctracer_async_record_t* record_ptr = reinterpret_cast<roctracer_async_record_t*>(record);
|
||||
record_ptr->domain = ROCTRACER_DOMAIN_HCC_OPS;
|
||||
pool->Write(*record_ptr);
|
||||
}
|
||||
|
||||
util::Logger::mutex_t util::Logger::mutex_;
|
||||
util::Logger* util::Logger::instance_ = NULL;
|
||||
MemoryPool* memory_pool = NULL;
|
||||
std::mutex memory_pool_mutex;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -430,9 +425,10 @@ PUBLIC_API roctracer_status_t roctracer_disable_api_callback(
|
||||
|
||||
// Return default pool and set new one if parameter pool is not NULL.
|
||||
PUBLIC_API roctracer_pool_t* roctracer_default_pool(roctracer_pool_t* pool) {
|
||||
std::lock_guard<std::mutex> lock(roctracer::memory_pool_mutex);
|
||||
roctracer_pool_t* p = reinterpret_cast<roctracer_pool_t*>(roctracer::memory_pool);
|
||||
if (pool != NULL) roctracer::memory_pool = reinterpret_cast<roctracer::MemoryPool*>(pool);
|
||||
if (p == NULL) EXC_RAISING(ROCTRACER_STATUS_UNINIT, "default pool is not initialized");
|
||||
//if (p == NULL) EXC_RAISING(ROCTRACER_STATUS_UNINIT, "default pool is not initialized");
|
||||
return p;
|
||||
}
|
||||
|
||||
@@ -442,6 +438,7 @@ PUBLIC_API roctracer_status_t roctracer_open_pool(
|
||||
roctracer_pool_t** pool)
|
||||
{
|
||||
API_METHOD_PREFIX
|
||||
std::lock_guard<std::mutex> lock(roctracer::memory_pool_mutex);
|
||||
if ((pool == NULL) && (roctracer::memory_pool != NULL)) {
|
||||
EXC_RAISING(ROCTRACER_STATUS_ERROR, "default pool already set");
|
||||
}
|
||||
@@ -455,6 +452,7 @@ PUBLIC_API roctracer_status_t roctracer_open_pool(
|
||||
// Close memory pool
|
||||
PUBLIC_API roctracer_status_t roctracer_close_pool(roctracer_pool_t* pool) {
|
||||
API_METHOD_PREFIX
|
||||
std::lock_guard<std::mutex> lock(roctracer::memory_pool_mutex);
|
||||
roctracer_pool_t* ptr = (pool == NULL) ? roctracer_default_pool() : pool;
|
||||
roctracer::MemoryPool* memory_pool = reinterpret_cast<roctracer::MemoryPool*>(ptr);
|
||||
delete(memory_pool);
|
||||
|
||||
Reference in New Issue
Block a user