This commit is contained in:
Evgeny
2018-06-22 19:02:42 +00:00
rodzic 5c5cc0c63f
commit ab9f15454f
+53 -55
Wyświetl plik
@@ -129,22 +129,16 @@ class MemoryPool {
buffer_end_ = buffer_begin_ + buffer_size_;
write_ptr_ = buffer_begin_;
// Pool references
buffer_refs_ = new uint32_t[buffer_refs_count_];
memset(buffer_refs_, 0, sizeof(uint32_t) * buffer_refs_count_);
// Consuming read thread
read_callback_fun_ = properties.buffer_callback_fun;
read_callback_arg_ = properties.buffer_callback_arg;
consumer_arg_ = consumer_arg_t{this, true, NULL, NULL};
consumer_arg_.set(this, NULL, NULL, true);
PTHREAD_CALL(pthread_mutex_init(&read_mutex_, NULL));
PTHREAD_CALL(pthread_cond_init(&read_cond_, NULL));
PTHREAD_CALL(pthread_create(&consumer_thread_, NULL, reader_fun, &consumer_arg_));
}
~MemoryPool() {
std::lock_guard<mutex_t> lock(write_mutex_);
Flush();
PTHREAD_CALL(pthread_cancel(consumer_thread_));
void *res;
@@ -154,9 +148,38 @@ class MemoryPool {
}
template <typename Record>
Record* getRecord() {
void Write(const Record& record) {
std::lock_guard<mutex_t> lock(write_mutex_);
getRecord<Record>(record);
}
void Flush() {
std::lock_guard<mutex_t> lock(write_mutex_);
if (write_ptr_ > buffer_begin_) {
spawn_reader(buffer_begin_, write_ptr_);
sync_reader(&consumer_arg_);
buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
buffer_end_ = buffer_begin_ + buffer_size_;
write_ptr_ = buffer_begin_;
}
}
private:
struct consumer_arg_t {
MemoryPool* obj;
const char* begin;
const char* end;
std::atomic<bool> valid;
void set(MemoryPool* obj_p, const char* begin_p, const char* end_p, bool valid_p) {
obj = obj_p;
begin = begin_p;
end = end_p;
valid.store(valid_p);
}
};
template <typename Record>
Record* getRecord(const Record& init) {
char* next = write_ptr_ + sizeof(Record);
if (next > buffer_end_) {
if (write_ptr_ == buffer_begin_) EXC_ABORT(ROCTRACER_STATUS_ERROR, "buffer size(" << buffer_size_ << ") is less then the record(" << sizeof(Record) << ")");
@@ -170,40 +193,16 @@ class MemoryPool {
Record* ptr = reinterpret_cast<Record*>(write_ptr_);
write_ptr_ = next;
*ptr = {};
*ptr = init;
return ptr;
}
template <typename Record>
void Write(const Record& record) {
*getRecord<Record>() = record;
}
void Flush() {
if (write_ptr_ > buffer_begin_) {
spawn_reader(buffer_begin_, write_ptr_);
sync_reader(&consumer_arg_);
buffer_begin_ = write_ptr_;
}
}
void incrementRef(void* ptr) { buffer_refs_[calc_buffer_index(ptr)] += 1; }
void decrementRef(void* ptr) { buffer_refs_[calc_buffer_index(ptr)] -= 1; }
private:
struct consumer_arg_t {
MemoryPool* obj;
bool valid;
const char* begin;
const char* end;
};
static void reset_reader(consumer_arg_t* arg) {
reinterpret_cast<std::atomic<bool>*>(&(arg->valid))->store(false, std::memory_order_release);
arg->valid.store(false);
}
static void sync_reader(const consumer_arg_t* arg) {
while(arg->valid) PTHREAD_CALL(pthread_yield());
while(arg->valid.load() == true) PTHREAD_CALL(pthread_yield());
}
static void* reader_fun(void* consumer_arg) {
@@ -214,13 +213,10 @@ class MemoryPool {
while (1) {
PTHREAD_CALL(pthread_mutex_lock(&(obj->read_mutex_)));
while (arg->valid == false) {
while (arg->valid.load() == false) {
PTHREAD_CALL(pthread_cond_wait(&(obj->read_cond_), &(obj->read_mutex_)));
}
const uint32_t buffer_index = obj->calc_buffer_index(arg->begin);
while(obj->buffer_refs_[buffer_index] != 0) PTHREAD_CALL(pthread_yield());
obj->read_callback_fun_(arg->begin, arg->end, obj->read_callback_arg_);
reset_reader(arg);
PTHREAD_CALL(pthread_mutex_unlock(&(obj->read_mutex_)));
@@ -232,7 +228,7 @@ class MemoryPool {
void spawn_reader(const char* data_begin, const char* data_end) {
sync_reader(&consumer_arg_);
PTHREAD_CALL(pthread_mutex_lock(&read_mutex_));
consumer_arg_ = consumer_arg_t{this, true, data_begin, data_end};
consumer_arg_.set(this, data_begin, data_end, true);
PTHREAD_CALL(pthread_cond_signal(&read_cond_));
PTHREAD_CALL(pthread_mutex_unlock(&read_mutex_));
}
@@ -253,10 +249,6 @@ class MemoryPool {
char* write_ptr_;
mutex_t write_mutex_;
// Pool references
uint32_t* buffer_refs_;
static const uint32_t buffer_refs_count_ = 2;
// Consuming read thread
roctracer_buffer_callback_t read_callback_fun_;
void* read_callback_arg_;
@@ -298,9 +290,9 @@ DESTRUCTOR_API void destructor() {
util::Logger::Destroy();
}
void ActivityCallback(
roctracer_record_t* ActivityCallback(
uint32_t activity_kind,
roctracer_record_t** record,
roctracer_record_t* record,
const void* callback_data,
void* arg)
{
@@ -310,23 +302,25 @@ void ActivityCallback(
MemoryPool* pool = reinterpret_cast<MemoryPool*>(arg);
if (pool == NULL) EXC_ABORT(ROCTRACER_STATUS_ERROR, "ActivityCallback pool is NULL");
if (data->phase == ROCTRACER_API_PHASE_ENTER) {
*record = pool->getRecord<roctracer_record_t>();
(*record)->domain = ROCTRACER_DOMAIN_HIP_API;
(*record)->activity_kind = activity_kind;
(*record)->begin_ns = timer.timestamp_ns();
record->domain = ROCTRACER_DOMAIN_HIP_API;
record->activity_kind = activity_kind;
record->begin_ns = timer.timestamp_ns();
// Correlation ID generating
uint64_t correlation_id = data->correlation_id;
if (correlation_id == 0) {
correlation_id = GlobalCounter::Increment();
const_cast<hip_cb_data_t*>(data)->correlation_id = correlation_id;
}
(*record)->correlation_id = correlation_id;
record->correlation_id = correlation_id;
// Passing record to HCC
HSAOp_set_activity_record(correlation_id);
return record;
} else {
(*record)->end_ns = timer.timestamp_ns();
record->end_ns = timer.timestamp_ns();
pool->Write(*record);
// Clearing record in HCC
HSAOp_set_activity_record(0);
return NULL;
}
}
@@ -336,14 +330,15 @@ void ActivityAsyncCallback(
void* arg)
{
MemoryPool* pool = reinterpret_cast<MemoryPool*>(arg);
roctracer_async_record_t* record_ptr = pool->getRecord<roctracer_async_record_t>();
*record_ptr = *reinterpret_cast<roctracer_async_record_t*>(record);
roctracer_async_record_t* record_ptr = reinterpret_cast<roctracer_async_record_t*>(record);
record_ptr->domain = ROCTRACER_DOMAIN_HCC_OPS;
pool->Write(*record_ptr);
}
util::Logger::mutex_t util::Logger::mutex_;
util::Logger* util::Logger::instance_ = NULL;
MemoryPool* memory_pool = NULL;
std::mutex memory_pool_mutex;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -430,9 +425,10 @@ PUBLIC_API roctracer_status_t roctracer_disable_api_callback(
// Return default pool and set new one if parameter pool is not NULL.
PUBLIC_API roctracer_pool_t* roctracer_default_pool(roctracer_pool_t* pool) {
std::lock_guard<std::mutex> lock(roctracer::memory_pool_mutex);
roctracer_pool_t* p = reinterpret_cast<roctracer_pool_t*>(roctracer::memory_pool);
if (pool != NULL) roctracer::memory_pool = reinterpret_cast<roctracer::MemoryPool*>(pool);
if (p == NULL) EXC_RAISING(ROCTRACER_STATUS_UNINIT, "default pool is not initialized");
//if (p == NULL) EXC_RAISING(ROCTRACER_STATUS_UNINIT, "default pool is not initialized");
return p;
}
@@ -442,6 +438,7 @@ PUBLIC_API roctracer_status_t roctracer_open_pool(
roctracer_pool_t** pool)
{
API_METHOD_PREFIX
std::lock_guard<std::mutex> lock(roctracer::memory_pool_mutex);
if ((pool == NULL) && (roctracer::memory_pool != NULL)) {
EXC_RAISING(ROCTRACER_STATUS_ERROR, "default pool already set");
}
@@ -455,6 +452,7 @@ PUBLIC_API roctracer_status_t roctracer_open_pool(
// Close memory pool
PUBLIC_API roctracer_status_t roctracer_close_pool(roctracer_pool_t* pool) {
API_METHOD_PREFIX
std::lock_guard<std::mutex> lock(roctracer::memory_pool_mutex);
roctracer_pool_t* ptr = (pool == NULL) ? roctracer_default_pool() : pool;
roctracer::MemoryPool* memory_pool = reinterpret_cast<roctracer::MemoryPool*>(ptr);
delete(memory_pool);