diff --git a/src/core/roctracer.cpp b/src/core/roctracer.cpp index 816315ebf4..e42a943418 100644 --- a/src/core/roctracer.cpp +++ b/src/core/roctracer.cpp @@ -129,22 +129,16 @@ class MemoryPool { buffer_end_ = buffer_begin_ + buffer_size_; write_ptr_ = buffer_begin_; - // Pool references - buffer_refs_ = new uint32_t[buffer_refs_count_]; - memset(buffer_refs_, 0, sizeof(uint32_t) * buffer_refs_count_); - // Consuming read thread read_callback_fun_ = properties.buffer_callback_fun; read_callback_arg_ = properties.buffer_callback_arg; - consumer_arg_ = consumer_arg_t{this, true, NULL, NULL}; + consumer_arg_.set(this, NULL, NULL, true); PTHREAD_CALL(pthread_mutex_init(&read_mutex_, NULL)); PTHREAD_CALL(pthread_cond_init(&read_cond_, NULL)); PTHREAD_CALL(pthread_create(&consumer_thread_, NULL, reader_fun, &consumer_arg_)); } ~MemoryPool() { - std::lock_guard lock(write_mutex_); - Flush(); PTHREAD_CALL(pthread_cancel(consumer_thread_)); void *res; @@ -154,9 +148,38 @@ class MemoryPool { } template - Record* getRecord() { + void Write(const Record& record) { std::lock_guard lock(write_mutex_); + getRecord(record); + } + void Flush() { + std::lock_guard lock(write_mutex_); + if (write_ptr_ > buffer_begin_) { + spawn_reader(buffer_begin_, write_ptr_); + sync_reader(&consumer_arg_); + buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_; + buffer_end_ = buffer_begin_ + buffer_size_; + write_ptr_ = buffer_begin_; + } + } + + private: + struct consumer_arg_t { + MemoryPool* obj; + const char* begin; + const char* end; + std::atomic valid; + void set(MemoryPool* obj_p, const char* begin_p, const char* end_p, bool valid_p) { + obj = obj_p; + begin = begin_p; + end = end_p; + valid.store(valid_p); + } + }; + + template + Record* getRecord(const Record& init) { char* next = write_ptr_ + sizeof(Record); if (next > buffer_end_) { if (write_ptr_ == buffer_begin_) EXC_ABORT(ROCTRACER_STATUS_ERROR, "buffer size(" << buffer_size_ << ") is less then the record(" << sizeof(Record) << ")"); @@ -170,40 +193,16 @@ class MemoryPool { Record* ptr = reinterpret_cast(write_ptr_); write_ptr_ = next; - *ptr = {}; + *ptr = init; return ptr; } - template - void Write(const Record& record) { - *getRecord() = record; - } - - void Flush() { - if (write_ptr_ > buffer_begin_) { - spawn_reader(buffer_begin_, write_ptr_); - sync_reader(&consumer_arg_); - buffer_begin_ = write_ptr_; - } - } - - void incrementRef(void* ptr) { buffer_refs_[calc_buffer_index(ptr)] += 1; } - void decrementRef(void* ptr) { buffer_refs_[calc_buffer_index(ptr)] -= 1; } - - private: - struct consumer_arg_t { - MemoryPool* obj; - bool valid; - const char* begin; - const char* end; - }; - static void reset_reader(consumer_arg_t* arg) { - reinterpret_cast*>(&(arg->valid))->store(false, std::memory_order_release); + arg->valid.store(false); } static void sync_reader(const consumer_arg_t* arg) { - while(arg->valid) PTHREAD_CALL(pthread_yield()); + while(arg->valid.load() == true) PTHREAD_CALL(pthread_yield()); } static void* reader_fun(void* consumer_arg) { @@ -214,13 +213,10 @@ class MemoryPool { while (1) { PTHREAD_CALL(pthread_mutex_lock(&(obj->read_mutex_))); - while (arg->valid == false) { + while (arg->valid.load() == false) { PTHREAD_CALL(pthread_cond_wait(&(obj->read_cond_), &(obj->read_mutex_))); } - const uint32_t buffer_index = obj->calc_buffer_index(arg->begin); - while(obj->buffer_refs_[buffer_index] != 0) PTHREAD_CALL(pthread_yield()); - obj->read_callback_fun_(arg->begin, arg->end, obj->read_callback_arg_); reset_reader(arg); PTHREAD_CALL(pthread_mutex_unlock(&(obj->read_mutex_))); @@ -232,7 +228,7 @@ class MemoryPool { void spawn_reader(const char* data_begin, const char* data_end) { sync_reader(&consumer_arg_); PTHREAD_CALL(pthread_mutex_lock(&read_mutex_)); - consumer_arg_ = consumer_arg_t{this, true, data_begin, data_end}; + consumer_arg_.set(this, data_begin, data_end, true); PTHREAD_CALL(pthread_cond_signal(&read_cond_)); PTHREAD_CALL(pthread_mutex_unlock(&read_mutex_)); } @@ -253,10 +249,6 @@ class MemoryPool { char* write_ptr_; mutex_t write_mutex_; - // Pool references - uint32_t* buffer_refs_; - static const uint32_t buffer_refs_count_ = 2; - // Consuming read thread roctracer_buffer_callback_t read_callback_fun_; void* read_callback_arg_; @@ -298,9 +290,9 @@ DESTRUCTOR_API void destructor() { util::Logger::Destroy(); } -void ActivityCallback( +roctracer_record_t* ActivityCallback( uint32_t activity_kind, - roctracer_record_t** record, + roctracer_record_t* record, const void* callback_data, void* arg) { @@ -310,23 +302,25 @@ void ActivityCallback( MemoryPool* pool = reinterpret_cast(arg); if (pool == NULL) EXC_ABORT(ROCTRACER_STATUS_ERROR, "ActivityCallback pool is NULL"); if (data->phase == ROCTRACER_API_PHASE_ENTER) { - *record = pool->getRecord(); - (*record)->domain = ROCTRACER_DOMAIN_HIP_API; - (*record)->activity_kind = activity_kind; - (*record)->begin_ns = timer.timestamp_ns(); + record->domain = ROCTRACER_DOMAIN_HIP_API; + record->activity_kind = activity_kind; + record->begin_ns = timer.timestamp_ns(); // Correlation ID generating uint64_t correlation_id = data->correlation_id; if (correlation_id == 0) { correlation_id = GlobalCounter::Increment(); const_cast(data)->correlation_id = correlation_id; } - (*record)->correlation_id = correlation_id; + record->correlation_id = correlation_id; // Passing record to HCC HSAOp_set_activity_record(correlation_id); + return record; } else { - (*record)->end_ns = timer.timestamp_ns(); + record->end_ns = timer.timestamp_ns(); + pool->Write(*record); // Clearing record in HCC HSAOp_set_activity_record(0); + return NULL; } } @@ -336,14 +330,15 @@ void ActivityAsyncCallback( void* arg) { MemoryPool* pool = reinterpret_cast(arg); - roctracer_async_record_t* record_ptr = pool->getRecord(); - *record_ptr = *reinterpret_cast(record); + roctracer_async_record_t* record_ptr = reinterpret_cast(record); record_ptr->domain = ROCTRACER_DOMAIN_HCC_OPS; + pool->Write(*record_ptr); } util::Logger::mutex_t util::Logger::mutex_; util::Logger* util::Logger::instance_ = NULL; MemoryPool* memory_pool = NULL; +std::mutex memory_pool_mutex; } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -430,9 +425,10 @@ PUBLIC_API roctracer_status_t roctracer_disable_api_callback( // Return default pool and set new one if parameter pool is not NULL. PUBLIC_API roctracer_pool_t* roctracer_default_pool(roctracer_pool_t* pool) { + std::lock_guard lock(roctracer::memory_pool_mutex); roctracer_pool_t* p = reinterpret_cast(roctracer::memory_pool); if (pool != NULL) roctracer::memory_pool = reinterpret_cast(pool); - if (p == NULL) EXC_RAISING(ROCTRACER_STATUS_UNINIT, "default pool is not initialized"); + //if (p == NULL) EXC_RAISING(ROCTRACER_STATUS_UNINIT, "default pool is not initialized"); return p; } @@ -442,6 +438,7 @@ PUBLIC_API roctracer_status_t roctracer_open_pool( roctracer_pool_t** pool) { API_METHOD_PREFIX + std::lock_guard lock(roctracer::memory_pool_mutex); if ((pool == NULL) && (roctracer::memory_pool != NULL)) { EXC_RAISING(ROCTRACER_STATUS_ERROR, "default pool already set"); } @@ -455,6 +452,7 @@ PUBLIC_API roctracer_status_t roctracer_open_pool( // Close memory pool PUBLIC_API roctracer_status_t roctracer_close_pool(roctracer_pool_t* pool) { API_METHOD_PREFIX + std::lock_guard lock(roctracer::memory_pool_mutex); roctracer_pool_t* ptr = (pool == NULL) ? roctracer_default_pool() : pool; roctracer::MemoryPool* memory_pool = reinterpret_cast(ptr); delete(memory_pool);