diff --git a/inc/roctracer.h b/inc/roctracer.h index 05d3b84295..17cda4446d 100644 --- a/inc/roctracer.h +++ b/inc/roctracer.h @@ -222,7 +222,7 @@ bool roctracer_load( uint64_t failed_tool_count, const char* const* failed_tool_names); -void roctracer_unload(); +void roctracer_unload(bool destruct); #ifdef __cplusplus } // extern "C" block diff --git a/src/core/roctracer.cpp b/src/core/roctracer.cpp index f4d078760d..6416ed124b 100644 --- a/src/core/roctracer.cpp +++ b/src/core/roctracer.cpp @@ -161,7 +161,13 @@ namespace roctracer { decltype(hsa_amd_memory_async_copy)* hsa_amd_memory_async_copy_fn; decltype(hsa_amd_memory_async_copy_rect)* hsa_amd_memory_async_copy_rect_fn; -TraceBuffer trace_buffer(0x200000); +void hsa_async_copy_handler(::proxy::Tracker::entry_t* entry); +void hsa_kernel_handler(::proxy::Tracker::entry_t* entry); +TraceBuffer::flush_prm_t trace_buffer_prm[] = { + {roctracer::COPY_ENTRY_TYPE, hsa_async_copy_handler}, + {roctracer::KERNEL_ENTRY_TYPE, hsa_kernel_handler} +}; +TraceBuffer trace_buffer(0x200000, trace_buffer_prm, 2); namespace hsa_support { // callbacks table @@ -975,13 +981,12 @@ PUBLIC_API bool roctracer_load(HsaApiTable* table, uint64_t runtime_version, uin return true; } -PUBLIC_API void roctracer_unload() { +PUBLIC_API void roctracer_unload(bool destruct) { static bool is_unloaded = false; if (is_unloaded) return; is_unloaded = true; - roctracer::trace_buffer.Flush(roctracer::COPY_ENTRY_TYPE, roctracer::hsa_async_copy_handler); - roctracer::trace_buffer.Flush(roctracer::KERNEL_ENTRY_TYPE, roctracer::hsa_kernel_handler); + if (destruct == false) roctracer::trace_buffer.Flush(); if ((roctracer::hsa_support::output_prefix != NULL) && (roctracer::kernel_file_handle != NULL)) fclose(roctracer::kernel_file_handle); } @@ -989,14 +994,14 @@ PUBLIC_API bool OnLoad(HsaApiTable* table, uint64_t runtime_version, uint64_t fa const char* const* failed_tool_names) { return roctracer_load(table, runtime_version, failed_tool_count, failed_tool_names); } -PUBLIC_API void OnUnload() { roctracer_unload(); } +PUBLIC_API void OnUnload() { roctracer_unload(false); } CONSTRUCTOR_API void constructor() { roctracer::util::Logger::Create(); } DESTRUCTOR_API void destructor() { - roctracer_unload(); + roctracer_unload(true); util::HsaRsrcFactory::Destroy(); roctracer::util::Logger::Destroy(); } diff --git a/src/core/trace_buffer.h b/src/core/trace_buffer.h index 0e0eaf80b2..d0bec76e2c 100644 --- a/src/core/trace_buffer.h +++ b/src/core/trace_buffer.h @@ -1,12 +1,27 @@ #ifndef SRC_CORE_TRACE_BUFFER_H_ #define SRC_CORE_TRACE_BUFFER_H_ +#include +#include +#include + +#define PTHREAD_CALL(call) \ + do { \ + int err = call; \ + if (err != 0) { \ + errno = err; \ + perror(#call); \ + abort(); \ + } \ + } while (0) + namespace roctracer { enum { TRACE_ENTRY_INV = 0, TRACE_ENTRY_INIT = 1, TRACE_ENTRY_COMPL = 2 }; + enum { API_ENTRY_TYPE, COPY_ENTRY_TYPE, @@ -37,42 +52,131 @@ struct trace_entry_t { template class TraceBuffer { public: - typedef void (*callabck_t)(Entry*); + typedef void (*callback_t)(Entry*); + typedef TraceBuffer Obj; + typedef uint64_t pointer_t; - TraceBuffer(uint32_t size) { + struct flush_prm_t { + uint32_t type; + callback_t fun; + }; + + TraceBuffer(uint32_t size, flush_prm_t* flush_prm_arr, uint32_t flush_prm_count) { size_ = size; - data_ = (Entry*) calloc(size, sizeof(Entry)); - memset(data_, 0, size * sizeof(Entry)); - read_pointer_ = data_; + data_ = allocate_fun(); + next_ = NULL; + read_pointer_ = 0; + end_pointer_ = size; + buf_list_.push_back(data_); + + flush_prm_arr_ = flush_prm_arr; + flush_prm_count_ = flush_prm_count; + is_flushed_ = false; + + PTHREAD_CALL(pthread_mutex_init(&work_mutex_, NULL)); + PTHREAD_CALL(pthread_cond_init(&work_cond_, NULL)); + PTHREAD_CALL(pthread_create(&work_thread_, NULL, allocate_worker, this)); } + ~TraceBuffer() { + PTHREAD_CALL(pthread_cancel(work_thread_)); + void *res; + PTHREAD_CALL(pthread_join(work_thread_, &res)); + if (res != PTHREAD_CANCELED) abort_run("~TraceBuffer: consumer thread wasn't stopped correctly"); + + if (is_flushed_ == false) flush_buf(); + } + + Entry* GetEntry() { - Entry* ptr = read_pointer_.fetch_add(1); - if (ptr >= (data_ + size_)) { - fprintf(stderr, "GetEntry: trace buffer is out of range\n"); - abort(); - } - return ptr; + const pointer_t pointer = read_pointer_.fetch_add(1); + if (pointer >= end_pointer_) wrap_buffer(pointer); + return data_ + pointer; } - void Flush(uint32_t type, callabck_t fun) { - Entry* ptr = data_; - for (; (ptr < read_pointer_) && (ptr < (data_ + size_)); ptr++) { - if (ptr->type == type) { - if (ptr->valid == TRACE_ENTRY_COMPL) { - fun(ptr); - } - } - } - if (ptr >= (data_ + size_)) { - fprintf(stderr, "Flush: trace buffer is out of range\n"); - } + void Flush() { + PTHREAD_CALL(pthread_mutex_lock(&work_mutex_)); + flush_buf(); + PTHREAD_CALL(pthread_mutex_unlock(&work_mutex_)); } private: - Entry* data_; + void flush_buf() { + is_flushed_ = true; + for (flush_prm_t* prm = flush_prm_arr_; prm < flush_prm_arr_ + flush_prm_count_; prm++) { + uint32_t type = prm->type; + callback_t fun = prm->fun; + pointer_t pointer = 0; + for (Entry* ptr : buf_list_) { + Entry* end = ptr + size_; + while ((ptr < end) && (pointer < read_pointer_)) { + if (ptr->type == type) { + if (ptr->valid == TRACE_ENTRY_COMPL) { + fun(ptr); + } + } + ptr++; + pointer++; + } + } + } + } + + inline Entry* allocate_fun() { + Entry* ptr = (Entry*) calloc(size_, sizeof(Entry)); + if (ptr == NULL) abort_run("TraceBuffer::allocate_fun: calloc failed"); + //memset(ptr, 0, size_ * sizeof(Entry)); + return ptr; + } + + static void* allocate_worker(void* arg) { + Obj* obj = (Obj*)arg; + + while (1) { + PTHREAD_CALL(pthread_mutex_lock(&(obj->work_mutex_))); + while (obj->next_ != NULL) { + PTHREAD_CALL(pthread_cond_wait(&(obj->work_cond_), &(obj->work_mutex_))); + } + obj->next_ = obj->allocate_fun(); + PTHREAD_CALL(pthread_mutex_unlock(&(obj->work_mutex_))); + } + + return NULL; + } + + void wrap_buffer(const pointer_t pointer) { + PTHREAD_CALL(pthread_mutex_lock(&work_mutex_)); + if (pointer >= end_pointer_) { + data_ = next_; + next_ = NULL; + PTHREAD_CALL(pthread_cond_signal(&work_cond_)); + end_pointer_ += size_; + if (end_pointer_ == 0) abort_run("TraceBuffer::wrap_buffer: pointer overflow"); + buf_list_.push_back(data_); + } + PTHREAD_CALL(pthread_mutex_unlock(&work_mutex_)); + } + + void abort_run(const char* str) { + fprintf(stderr, "%s\n", str); + fflush(stderr); + abort(); + } + uint32_t size_; - std::atomic read_pointer_; + Entry* data_; + Entry* next_; + std::atomic read_pointer_; + pointer_t end_pointer_; + std::list buf_list_; + + flush_prm_t* flush_prm_arr_; + uint32_t flush_prm_count_; + bool is_flushed_; + + pthread_t work_thread_; + pthread_mutex_t work_mutex_; + pthread_cond_t work_cond_; }; } // namespace roctracer diff --git a/test/tool/tracer_tool.cpp b/test/tool/tracer_tool.cpp index 105672d334..44b511149d 100644 --- a/test/tool/tracer_tool.cpp +++ b/test/tool/tracer_tool.cpp @@ -99,7 +99,9 @@ struct hsa_api_trace_entry_t { hsa_api_data_t data; }; -roctracer::TraceBuffer hsa_api_trace_buffer(0x200000); +void hsa_api_flush_cb(hsa_api_trace_entry_t* entry); +roctracer::TraceBuffer::flush_prm_t hsa_flush_prm[1] = {{0, hsa_api_flush_cb}}; +roctracer::TraceBuffer hsa_api_trace_buffer(0x200000, hsa_flush_prm, 1); // HSA API callback function void hsa_api_callback( @@ -153,7 +155,9 @@ struct hip_api_trace_entry_t { const char* name; }; -roctracer::TraceBuffer hip_api_trace_buffer(0x200000); +void hip_api_flush_cb(hip_api_trace_entry_t* entry); +roctracer::TraceBuffer::flush_prm_t hip_flush_prm[1] = {{0, hip_api_flush_cb}}; +roctracer::TraceBuffer hip_api_trace_buffer(0x200000, hip_flush_prm, 1); void hip_api_callback( uint32_t domain, @@ -425,20 +429,20 @@ extern "C" PUBLIC_API bool OnLoad(HsaApiTable* table, uint64_t runtime_version, return roctracer_load(table, runtime_version, failed_tool_count, failed_tool_names); } -// HSA-runtime tool on-unload method -extern "C" PUBLIC_API void OnUnload() { +// tool unload method +void tool_unload(bool destruct) { static bool is_unloaded = false; if (is_unloaded) { return; } is_unloaded = true; - roctracer_unload(); + roctracer_unload(destruct); if (trace_hsa) { ROCTRACER_CALL(roctracer_disable_domain_callback(ACTIVITY_DOMAIN_HSA_API)); ROCTRACER_CALL(roctracer_disable_domain_activity(ACTIVITY_DOMAIN_HSA_OPS)); - hsa_api_trace_buffer.Flush(0, hsa_api_flush_cb); + if (destruct == false) hsa_api_trace_buffer.Flush(); fclose(hsa_api_file_handle); fclose(hsa_async_copy_file_handle); @@ -450,12 +454,15 @@ extern "C" PUBLIC_API void OnUnload() { ROCTRACER_CALL(roctracer_flush_activity()); ROCTRACER_CALL(roctracer_close_pool()); - hip_api_trace_buffer.Flush(0, hip_api_flush_cb); + if (destruct == false) hip_api_trace_buffer.Flush(); fclose(hip_api_file_handle); fclose(hcc_activity_file_handle); } } +// HSA-runtime on-unload method +extern "C" PUBLIC_API void OnUnload() { tool_unload(false); } + extern "C" CONSTRUCTOR_API void constructor() {} -extern "C" DESTRUCTOR_API void destructor() { OnUnload(); } +extern "C" DESTRUCTOR_API void destructor() { tool_unload(true); }