// Review preamble. Everything from the "diff --git" header onward is the original patch text,
// unchanged.
//
// This is a git diff against src/core/memory_pool.h (one hunk, @@ -21,186 +21,172 @@) that
// rewrites class roctracer::MemoryPool. As stored here the diff's line breaks are collapsed and
// text that would sit between angle brackets is missing (bare `#include`, `std::promise ready`,
// `std::lock_guard producer_lock`, `reinterpret_cast(malloc(size))`, `template void Write`), so
// it cannot be applied or compiled as-is.
//
// What the '+' side adds:
//   * Constructor: copies the properties into properties_, requests 2 * buffer_size bytes through
//     AllocateMemory(), points buffer_begin_ and write_ptr_ at the start of the pool, starts the
//     consumer thread and blocks on a future until that thread has called ready.set_value().
//   * Write(): under producer_mutex_, memcpy's the record to write_ptr_ and advances write_ptr_.
//     When the record does not fit, buffer_begin_ and write_ptr_ are passed to
//     NotifyConsumerThread(), writing moves to the other half of the pool, and EXC_ABORT is
//     raised if the record still does not fit.
//   * Flush(): returns immediately when write_ptr_ == buffer_begin_; otherwise hands the current
//     range to NotifyConsumerThread(), switches halves, then waits on consumer_cond_ until
//     consumer_arg_.valid is false.
//   * ConsumerThreadLoop(): waits for consumer_arg_.valid, exits when begin and end are both
//     nullptr, otherwise calls properties_.buffer_callback_fun, clears valid and calls
//     notify_one(). consumer_mutex_ is held the whole time except while blocked in
//     consumer_cond_.wait(), so the callback runs with that mutex held.
//   * NotifyConsumerThread(): waits until valid is false, stores begin/end, sets valid = true and
//     calls notify_all().
//   * ~MemoryPool(): Flush(), then NotifyConsumerThread(nullptr, nullptr) as the exit request,
//     join the thread, then AllocateMemory(&pool_begin_, 0) to release the pool.
//   * AllocateMemory(): forwards to properties_.alloc_fun when it is non-null, otherwise uses
//     malloc / realloc / free.
//   * Copy construction and copy assignment are deleted.
//
// What the '-' side removes: the PTHREAD_CALL macro, the pthread-based reader thread that was
// stopped with pthread_cancel, the sched_yield() polling in sync_reader, the consumer_arg_t
// struct, and the cached alloc_fun_ / alloc_arg_ / buffer_size_ / read_callback_* members.
//
// Behavior differences visible in the diff:
//   * The pool is now released through AllocateMemory(), which uses the custom allocator when
//     one is configured; the removed destructor always called allocator_default.
//   * The removed getRecord() raised the "buffer size(...)" error before handing anything to the
//     reader (it tested write_ptr_ == buffer_begin_ first); the new Write() performs the hand-off
//     and buffer switch first and checks the size afterwards.
//   * Write() takes Record&& where the removed version took const Record&.
//
// NOTE(review): "less then" in the EXC_ABORT message looks like a typo for "less than"; it is
// carried over from the removed code.
// NOTE(review): the thread is created with std::thread(ConsumerThreadLoop, this, ...) while
// ConsumerThreadLoop is declared as a non-static member; presumably
// &MemoryPool::ConsumerThreadLoop is intended -- confirm against the upstream file.
diff --git a/src/core/memory_pool.h b/src/core/memory_pool.h index 784565cec5..6d84305eae 100644 --- a/src/core/memory_pool.h +++ b/src/core/memory_pool.h @@ -21,186 +21,172 @@ #ifndef MEMORY_POOL_H_ #define MEMORY_POOL_H_ -#include -#include - -#include -#include - #include "util/exception.h" -#define PTHREAD_CALL(call) \ - do { \ - int err = call; \ - if (err != 0) { \ - errno = err; \ - perror(#call); \ - abort(); \ - } \ - } while (0) +#include +#include +#include +#include +#include namespace roctracer { class MemoryPool { public: - typedef std::mutex mutex_t; + MemoryPool(const roctracer_properties_t& properties) : properties_(properties) { + // Pool definition: The memory pool is split in 2 buffers of equal size. When first initialized, + // the write pointer points to the first element of the first buffer. When a buffer is full, or + // when Flush() is called, the write pointer moves to the other buffer. + const size_t allocation_size = 2 * properties_.buffer_size; + pool_begin_ = nullptr; + AllocateMemory(&pool_begin_, allocation_size); + if (pool_begin_ == nullptr) EXC_ABORT(ROCTRACER_STATUS_ERROR, "pool allocator failed"); - static void allocator_default(char** ptr, size_t size, void* arg) { - (void)arg; - if (*ptr == NULL) { + pool_end_ = pool_begin_ + allocation_size; + buffer_begin_ = pool_begin_; + buffer_end_ = buffer_begin_ + properties_.buffer_size; + write_ptr_ = buffer_begin_; + + // Create a consumer thread and wait for it to be ready to accept work. + std::promise ready; + std::future future = ready.get_future(); + consumer_thread_ = std::thread(ConsumerThreadLoop, this, std::move(ready)); + future.wait(); + } + + ~MemoryPool() { + Flush(); + + // Wait for the previous flush to complete, then send the exit signal. + NotifyConsumerThread(nullptr, nullptr); + consumer_thread_.join(); + + // Free the pool's buffer memory. 
+ AllocateMemory(&pool_begin_, 0); + } + + MemoryPool(const MemoryPool&) = delete; + MemoryPool& operator=(const MemoryPool&) = delete; + + template void Write(Record&& record) { + std::lock_guard producer_lock(producer_mutex_); + char* next = write_ptr_ + sizeof(record); + if (next > buffer_end_) { + NotifyConsumerThread(buffer_begin_, write_ptr_); + + // Switch buffers + buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_; + buffer_end_ = buffer_begin_ + properties_.buffer_size; + write_ptr_ = buffer_begin_; + + next = write_ptr_ + sizeof(record); + if (next > buffer_end_) + EXC_ABORT(ROCTRACER_STATUS_ERROR, + "buffer size(" << properties_.buffer_size << ") is less then the record(" + << sizeof(record) << ")"); + } + + // Store the record into the buffer, and increment the write pointer. + ::memcpy(write_ptr_, &record, sizeof(record)); + write_ptr_ = next; + } + + // Flush the records and block until they are all made visible to the client. + void Flush() { + std::lock_guard producer_lock(producer_mutex_); + if (write_ptr_ == buffer_begin_) return; + + NotifyConsumerThread(buffer_begin_, write_ptr_); + + // Switch buffers + buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_; + buffer_end_ = buffer_begin_ + properties_.buffer_size; + write_ptr_ = buffer_begin_; + + // Wait for the current operation to complete. + std::unique_lock consumer_lock(consumer_mutex_); + consumer_cond_.wait(consumer_lock, [this]() { return !consumer_arg_.valid; }); + } + + private: + void ConsumerThreadLoop(std::promise ready) { + std::unique_lock consumer_lock(consumer_mutex_); + + // This consumer is now ready to accept work. + ready.set_value(); + + while (true) { + consumer_cond_.wait(consumer_lock, [this]() { return consumer_arg_.valid; }); + + // begin == end == nullptr means the thread needs to exit. 
+ if (consumer_arg_.begin == nullptr && consumer_arg_.end == nullptr) break; + + properties_.buffer_callback_fun(consumer_arg_.begin, consumer_arg_.end, + properties_.buffer_callback_arg); + + // Mark this operation as complete (valid=false) and notify a producer, if any, that may be + // waiting to start a new operation. See comment below in NotifyConsumerThread(). + consumer_arg_.valid = false; + consumer_cond_.notify_one(); + } + } + + void NotifyConsumerThread(const char* data_begin, const char* data_end) { + std::unique_lock consumer_lock(consumer_mutex_); + + // If consumer_arg_ is still in use (valid=true), then wait for the consumer thread to finish + // processing the current operation. Multiple producers may wait here, one will be allowed to + // continue once the consumer thread is idle and valid=false. This prevents a race condition + // where operations would be lost if multiple producers could enter this critical section + // (sequentially) before the consumer thread could re-acquire the consumer_mutex_ lock. + consumer_cond_.wait(consumer_lock, [this]() { return !consumer_arg_.valid; }); + + consumer_arg_.begin = data_begin; + consumer_arg_.end = data_end; + + consumer_arg_.valid = true; + consumer_cond_.notify_all(); + } + + void AllocateMemory(char** ptr, size_t size) const { + if (properties_.alloc_fun != nullptr) { + // Use the custom allocator provided in the properties. + properties_.alloc_fun(ptr, size, properties_.alloc_arg); + return; + } + + // No custom allocator was provided so use the default malloc/realloc/free allocator. 
+ if (*ptr == nullptr) { *ptr = reinterpret_cast(malloc(size)); } else if (size != 0) { *ptr = reinterpret_cast(realloc(*ptr, size)); } else { free(*ptr); - *ptr = NULL; + *ptr = nullptr; } } - MemoryPool(const roctracer_properties_t& properties) { - // Assigning pool allocator - alloc_fun_ = allocator_default; - alloc_arg_ = NULL; - if (properties.alloc_fun != NULL) { - alloc_fun_ = properties.alloc_fun; - alloc_arg_ = properties.alloc_arg; - } - - // Pool definition - buffer_size_ = properties.buffer_size; - const size_t pool_size = 2 * buffer_size_; - pool_begin_ = NULL; - alloc_fun_(&pool_begin_, pool_size, alloc_arg_); - if (pool_begin_ == NULL) EXC_ABORT(ROCTRACER_STATUS_ERROR, "pool allocator failed"); - pool_end_ = pool_begin_ + pool_size; - buffer_begin_ = pool_begin_; - buffer_end_ = buffer_begin_ + buffer_size_; - write_ptr_ = buffer_begin_; - - // Consuming read thread - read_callback_fun_ = properties.buffer_callback_fun; - read_callback_arg_ = properties.buffer_callback_arg; - consumer_arg_.set(this, NULL, NULL, true); - PTHREAD_CALL(pthread_mutex_init(&read_mutex_, NULL)); - PTHREAD_CALL(pthread_cond_init(&read_cond_, NULL)); - PTHREAD_CALL(pthread_create(&consumer_thread_, NULL, reader_fun, &consumer_arg_)); - } - - ~MemoryPool() { - Flush(); - PTHREAD_CALL(pthread_cancel(consumer_thread_)); - void* res; - PTHREAD_CALL(pthread_join(consumer_thread_, &res)); - if (res != PTHREAD_CANCELED) - EXC_ABORT(ROCTRACER_STATUS_ERROR, "consumer thread wasn't stopped correctly"); - allocator_default(&pool_begin_, 0, alloc_arg_); - } - - template void Write(const Record& record) { - std::lock_guard lock(write_mutex_); - getRecord(record); - } - - void Flush() { - std::lock_guard lock(write_mutex_); - if (write_ptr_ > buffer_begin_) { - spawn_reader(buffer_begin_, write_ptr_); - sync_reader(&consumer_arg_); - buffer_begin_ = (buffer_end_ == pool_end_) ? 
pool_begin_ : buffer_end_; - buffer_end_ = buffer_begin_ + buffer_size_; - write_ptr_ = buffer_begin_; - } - } - - private: - struct consumer_arg_t { - MemoryPool* obj; - const char* begin; - const char* end; - volatile std::atomic valid; - void set(MemoryPool* obj_p, const char* begin_p, const char* end_p, bool valid_p) { - obj = obj_p; - begin = begin_p; - end = end_p; - valid.store(valid_p); - } - }; - - template Record* getRecord(const Record& init) { - char* next = write_ptr_ + sizeof(Record); - if (next > buffer_end_) { - if (write_ptr_ == buffer_begin_) - EXC_ABORT(ROCTRACER_STATUS_ERROR, - "buffer size(" << buffer_size_ << ") is less then the record(" << sizeof(Record) - << ")"); - spawn_reader(buffer_begin_, write_ptr_); - buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_; - buffer_end_ = buffer_begin_ + buffer_size_; - write_ptr_ = buffer_begin_; - next = write_ptr_ + sizeof(Record); - } - - Record* ptr = reinterpret_cast(write_ptr_); - write_ptr_ = next; - - *ptr = init; - return ptr; - } - - static void reset_reader(consumer_arg_t* arg) { arg->valid.store(false); } - - static void sync_reader(const consumer_arg_t* arg) { - while (arg->valid.load() == true) PTHREAD_CALL(sched_yield()); - } - - static void* reader_fun(void* consumer_arg) { - consumer_arg_t* arg = reinterpret_cast(consumer_arg); - roctracer::MemoryPool* obj = arg->obj; - - reset_reader(arg); - - while (1) { - PTHREAD_CALL(pthread_mutex_lock(&(obj->read_mutex_))); - while (arg->valid.load() == false) { - PTHREAD_CALL(pthread_cond_wait(&(obj->read_cond_), &(obj->read_mutex_))); - } - - obj->read_callback_fun_(arg->begin, arg->end, obj->read_callback_arg_); - reset_reader(arg); - PTHREAD_CALL(pthread_mutex_unlock(&(obj->read_mutex_))); - } - - return NULL; - } - - void spawn_reader(const char* data_begin, const char* data_end) { - sync_reader(&consumer_arg_); - PTHREAD_CALL(pthread_mutex_lock(&read_mutex_)); - consumer_arg_.set(this, data_begin, data_end, true); - 
PTHREAD_CALL(pthread_cond_signal(&read_cond_)); - PTHREAD_CALL(pthread_mutex_unlock(&read_mutex_)); - } - - // pool allocator - roctracer_allocator_t alloc_fun_; - void* alloc_arg_; + // Properties used to create the memory pool. + const roctracer_properties_t properties_; // Pool definition - size_t buffer_size_; - char* pool_begin_; + char* pool_begin_; // FIXME: shouldn't these be void*? char* pool_end_; char* buffer_begin_; char* buffer_end_; char* write_ptr_; - mutex_t write_mutex_; + std::mutex producer_mutex_; - // Consuming read thread - roctracer_buffer_callback_t read_callback_fun_; - void* read_callback_arg_; - consumer_arg_t consumer_arg_; - pthread_t consumer_thread_; - pthread_mutex_t read_mutex_; - pthread_cond_t read_cond_; + // Consumer thread + std::thread consumer_thread_; + struct { + const char* begin; + const char* end; + bool valid = false; + } consumer_arg_; + + std::mutex consumer_mutex_; + std::condition_variable consumer_cond_; }; } // namespace roctracer