Simplify memory_pool.h
Use the standard concurrent support library (std::thread, std::mutex,
st::condition_variable) instead of pthread.
Fix a mismatched memory allocation/deallocation when a custom allocator
is provided. The MemoryPool destructor was always using the default
allocator (using malloc/realloc/free) even if the pool memory was
allocated with the custom allocator.
Fix various thread safety issues and inefficiencies (spin loops).
Change-Id: I97592caa947f63463041bf43e00af9ebb5ff5886
[ROCm/roctracer commit: 9d728f74a1]
This commit is contained in:
@@ -21,186 +21,172 @@
|
||||
#ifndef MEMORY_POOL_H_
|
||||
#define MEMORY_POOL_H_
|
||||
|
||||
#include <pthread.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
|
||||
#include "util/exception.h"
|
||||
|
||||
#define PTHREAD_CALL(call) \
|
||||
do { \
|
||||
int err = call; \
|
||||
if (err != 0) { \
|
||||
errno = err; \
|
||||
perror(#call); \
|
||||
abort(); \
|
||||
} \
|
||||
} while (0)
|
||||
#include <condition_variable>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <future>
|
||||
#include <mutex>
|
||||
|
||||
namespace roctracer {
|
||||
|
||||
class MemoryPool {
|
||||
public:
|
||||
typedef std::mutex mutex_t;
|
||||
MemoryPool(const roctracer_properties_t& properties) : properties_(properties) {
|
||||
// Pool definition: The memory pool is split in 2 buffers of equal size. When first initialized,
|
||||
// the write pointer points to the first element of the first buffer. When a buffer is full, or
|
||||
// when Flush() is called, the write pointer moves to the other buffer.
|
||||
const size_t allocation_size = 2 * properties_.buffer_size;
|
||||
pool_begin_ = nullptr;
|
||||
AllocateMemory(&pool_begin_, allocation_size);
|
||||
if (pool_begin_ == nullptr) EXC_ABORT(ROCTRACER_STATUS_ERROR, "pool allocator failed");
|
||||
|
||||
static void allocator_default(char** ptr, size_t size, void* arg) {
|
||||
(void)arg;
|
||||
if (*ptr == NULL) {
|
||||
pool_end_ = pool_begin_ + allocation_size;
|
||||
buffer_begin_ = pool_begin_;
|
||||
buffer_end_ = buffer_begin_ + properties_.buffer_size;
|
||||
write_ptr_ = buffer_begin_;
|
||||
|
||||
// Create a consumer thread and wait for it to be ready to accept work.
|
||||
std::promise<void> ready;
|
||||
std::future<void> future = ready.get_future();
|
||||
consumer_thread_ = std::thread(ConsumerThreadLoop, this, std::move(ready));
|
||||
future.wait();
|
||||
}
|
||||
|
||||
~MemoryPool() {
|
||||
Flush();
|
||||
|
||||
// Wait for the previous flush to complete, then send the exit signal.
|
||||
NotifyConsumerThread(nullptr, nullptr);
|
||||
consumer_thread_.join();
|
||||
|
||||
// Free the pool's buffer memory.
|
||||
AllocateMemory(&pool_begin_, 0);
|
||||
}
|
||||
|
||||
MemoryPool(const MemoryPool&) = delete;
|
||||
MemoryPool& operator=(const MemoryPool&) = delete;
|
||||
|
||||
template <typename Record> void Write(Record&& record) {
|
||||
std::lock_guard producer_lock(producer_mutex_);
|
||||
char* next = write_ptr_ + sizeof(record);
|
||||
if (next > buffer_end_) {
|
||||
NotifyConsumerThread(buffer_begin_, write_ptr_);
|
||||
|
||||
// Switch buffers
|
||||
buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
|
||||
buffer_end_ = buffer_begin_ + properties_.buffer_size;
|
||||
write_ptr_ = buffer_begin_;
|
||||
|
||||
next = write_ptr_ + sizeof(record);
|
||||
if (next > buffer_end_)
|
||||
EXC_ABORT(ROCTRACER_STATUS_ERROR,
|
||||
"buffer size(" << properties_.buffer_size << ") is less then the record("
|
||||
<< sizeof(record) << ")");
|
||||
}
|
||||
|
||||
// Store the record into the buffer, and increment the write pointer.
|
||||
::memcpy(write_ptr_, &record, sizeof(record));
|
||||
write_ptr_ = next;
|
||||
}
|
||||
|
||||
// Flush the records and block until they are all made visible to the client.
|
||||
void Flush() {
|
||||
std::lock_guard producer_lock(producer_mutex_);
|
||||
if (write_ptr_ == buffer_begin_) return;
|
||||
|
||||
NotifyConsumerThread(buffer_begin_, write_ptr_);
|
||||
|
||||
// Switch buffers
|
||||
buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
|
||||
buffer_end_ = buffer_begin_ + properties_.buffer_size;
|
||||
write_ptr_ = buffer_begin_;
|
||||
|
||||
// Wait for the current operation to complete.
|
||||
std::unique_lock consumer_lock(consumer_mutex_);
|
||||
consumer_cond_.wait(consumer_lock, [this]() { return !consumer_arg_.valid; });
|
||||
}
|
||||
|
||||
private:
|
||||
void ConsumerThreadLoop(std::promise<void> ready) {
|
||||
std::unique_lock consumer_lock(consumer_mutex_);
|
||||
|
||||
// This consumer is now ready to accept work.
|
||||
ready.set_value();
|
||||
|
||||
while (true) {
|
||||
consumer_cond_.wait(consumer_lock, [this]() { return consumer_arg_.valid; });
|
||||
|
||||
// begin == end == nullptr means the thread needs to exit.
|
||||
if (consumer_arg_.begin == nullptr && consumer_arg_.end == nullptr) break;
|
||||
|
||||
properties_.buffer_callback_fun(consumer_arg_.begin, consumer_arg_.end,
|
||||
properties_.buffer_callback_arg);
|
||||
|
||||
// Mark this operation as complete (valid=false) and notify a producer, if any, that may be
|
||||
// waiting to start a new operation. See comment below in NotifyConsumerThread().
|
||||
consumer_arg_.valid = false;
|
||||
consumer_cond_.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
void NotifyConsumerThread(const char* data_begin, const char* data_end) {
|
||||
std::unique_lock consumer_lock(consumer_mutex_);
|
||||
|
||||
// If consumer_arg_ is still in use (valid=true), then wait for the consumer thread to finish
|
||||
// processing the current operation. Multiple producers may wait here, one will be allowed to
|
||||
// continue once the consumer thread is idle and valid=false. This prevents a race condition
|
||||
// where operations would be lost if multiple producers could enter this critical section
|
||||
// (sequentially) before the consumer thread could re-acquire the consumer_mutex_ lock.
|
||||
consumer_cond_.wait(consumer_lock, [this]() { return !consumer_arg_.valid; });
|
||||
|
||||
consumer_arg_.begin = data_begin;
|
||||
consumer_arg_.end = data_end;
|
||||
|
||||
consumer_arg_.valid = true;
|
||||
consumer_cond_.notify_all();
|
||||
}
|
||||
|
||||
void AllocateMemory(char** ptr, size_t size) const {
|
||||
if (properties_.alloc_fun != nullptr) {
|
||||
// Use the custom allocator provided in the properties.
|
||||
properties_.alloc_fun(ptr, size, properties_.alloc_arg);
|
||||
return;
|
||||
}
|
||||
|
||||
// No custom allocator was provided so use the default malloc/realloc/free allocator.
|
||||
if (*ptr == nullptr) {
|
||||
*ptr = reinterpret_cast<char*>(malloc(size));
|
||||
} else if (size != 0) {
|
||||
*ptr = reinterpret_cast<char*>(realloc(*ptr, size));
|
||||
} else {
|
||||
free(*ptr);
|
||||
*ptr = NULL;
|
||||
*ptr = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
MemoryPool(const roctracer_properties_t& properties) {
|
||||
// Assigning pool allocator
|
||||
alloc_fun_ = allocator_default;
|
||||
alloc_arg_ = NULL;
|
||||
if (properties.alloc_fun != NULL) {
|
||||
alloc_fun_ = properties.alloc_fun;
|
||||
alloc_arg_ = properties.alloc_arg;
|
||||
}
|
||||
|
||||
// Pool definition
|
||||
buffer_size_ = properties.buffer_size;
|
||||
const size_t pool_size = 2 * buffer_size_;
|
||||
pool_begin_ = NULL;
|
||||
alloc_fun_(&pool_begin_, pool_size, alloc_arg_);
|
||||
if (pool_begin_ == NULL) EXC_ABORT(ROCTRACER_STATUS_ERROR, "pool allocator failed");
|
||||
pool_end_ = pool_begin_ + pool_size;
|
||||
buffer_begin_ = pool_begin_;
|
||||
buffer_end_ = buffer_begin_ + buffer_size_;
|
||||
write_ptr_ = buffer_begin_;
|
||||
|
||||
// Consuming read thread
|
||||
read_callback_fun_ = properties.buffer_callback_fun;
|
||||
read_callback_arg_ = properties.buffer_callback_arg;
|
||||
consumer_arg_.set(this, NULL, NULL, true);
|
||||
PTHREAD_CALL(pthread_mutex_init(&read_mutex_, NULL));
|
||||
PTHREAD_CALL(pthread_cond_init(&read_cond_, NULL));
|
||||
PTHREAD_CALL(pthread_create(&consumer_thread_, NULL, reader_fun, &consumer_arg_));
|
||||
}
|
||||
|
||||
~MemoryPool() {
|
||||
Flush();
|
||||
PTHREAD_CALL(pthread_cancel(consumer_thread_));
|
||||
void* res;
|
||||
PTHREAD_CALL(pthread_join(consumer_thread_, &res));
|
||||
if (res != PTHREAD_CANCELED)
|
||||
EXC_ABORT(ROCTRACER_STATUS_ERROR, "consumer thread wasn't stopped correctly");
|
||||
allocator_default(&pool_begin_, 0, alloc_arg_);
|
||||
}
|
||||
|
||||
template <typename Record> void Write(const Record& record) {
|
||||
std::lock_guard<mutex_t> lock(write_mutex_);
|
||||
getRecord<Record>(record);
|
||||
}
|
||||
|
||||
void Flush() {
|
||||
std::lock_guard<mutex_t> lock(write_mutex_);
|
||||
if (write_ptr_ > buffer_begin_) {
|
||||
spawn_reader(buffer_begin_, write_ptr_);
|
||||
sync_reader(&consumer_arg_);
|
||||
buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
|
||||
buffer_end_ = buffer_begin_ + buffer_size_;
|
||||
write_ptr_ = buffer_begin_;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
struct consumer_arg_t {
|
||||
MemoryPool* obj;
|
||||
const char* begin;
|
||||
const char* end;
|
||||
volatile std::atomic<bool> valid;
|
||||
void set(MemoryPool* obj_p, const char* begin_p, const char* end_p, bool valid_p) {
|
||||
obj = obj_p;
|
||||
begin = begin_p;
|
||||
end = end_p;
|
||||
valid.store(valid_p);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Record> Record* getRecord(const Record& init) {
|
||||
char* next = write_ptr_ + sizeof(Record);
|
||||
if (next > buffer_end_) {
|
||||
if (write_ptr_ == buffer_begin_)
|
||||
EXC_ABORT(ROCTRACER_STATUS_ERROR,
|
||||
"buffer size(" << buffer_size_ << ") is less then the record(" << sizeof(Record)
|
||||
<< ")");
|
||||
spawn_reader(buffer_begin_, write_ptr_);
|
||||
buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
|
||||
buffer_end_ = buffer_begin_ + buffer_size_;
|
||||
write_ptr_ = buffer_begin_;
|
||||
next = write_ptr_ + sizeof(Record);
|
||||
}
|
||||
|
||||
Record* ptr = reinterpret_cast<Record*>(write_ptr_);
|
||||
write_ptr_ = next;
|
||||
|
||||
*ptr = init;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
static void reset_reader(consumer_arg_t* arg) { arg->valid.store(false); }
|
||||
|
||||
static void sync_reader(const consumer_arg_t* arg) {
|
||||
while (arg->valid.load() == true) PTHREAD_CALL(sched_yield());
|
||||
}
|
||||
|
||||
static void* reader_fun(void* consumer_arg) {
|
||||
consumer_arg_t* arg = reinterpret_cast<consumer_arg_t*>(consumer_arg);
|
||||
roctracer::MemoryPool* obj = arg->obj;
|
||||
|
||||
reset_reader(arg);
|
||||
|
||||
while (1) {
|
||||
PTHREAD_CALL(pthread_mutex_lock(&(obj->read_mutex_)));
|
||||
while (arg->valid.load() == false) {
|
||||
PTHREAD_CALL(pthread_cond_wait(&(obj->read_cond_), &(obj->read_mutex_)));
|
||||
}
|
||||
|
||||
obj->read_callback_fun_(arg->begin, arg->end, obj->read_callback_arg_);
|
||||
reset_reader(arg);
|
||||
PTHREAD_CALL(pthread_mutex_unlock(&(obj->read_mutex_)));
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void spawn_reader(const char* data_begin, const char* data_end) {
|
||||
sync_reader(&consumer_arg_);
|
||||
PTHREAD_CALL(pthread_mutex_lock(&read_mutex_));
|
||||
consumer_arg_.set(this, data_begin, data_end, true);
|
||||
PTHREAD_CALL(pthread_cond_signal(&read_cond_));
|
||||
PTHREAD_CALL(pthread_mutex_unlock(&read_mutex_));
|
||||
}
|
||||
|
||||
// pool allocator
|
||||
roctracer_allocator_t alloc_fun_;
|
||||
void* alloc_arg_;
|
||||
// Properties used to create the memory pool.
|
||||
const roctracer_properties_t properties_;
|
||||
|
||||
// Pool definition
|
||||
size_t buffer_size_;
|
||||
char* pool_begin_;
|
||||
char* pool_begin_; // FIXME: shouldn't these be void*?
|
||||
char* pool_end_;
|
||||
char* buffer_begin_;
|
||||
char* buffer_end_;
|
||||
char* write_ptr_;
|
||||
mutex_t write_mutex_;
|
||||
std::mutex producer_mutex_;
|
||||
|
||||
// Consuming read thread
|
||||
roctracer_buffer_callback_t read_callback_fun_;
|
||||
void* read_callback_arg_;
|
||||
consumer_arg_t consumer_arg_;
|
||||
pthread_t consumer_thread_;
|
||||
pthread_mutex_t read_mutex_;
|
||||
pthread_cond_t read_cond_;
|
||||
// Consumer thread
|
||||
std::thread consumer_thread_;
|
||||
struct {
|
||||
const char* begin;
|
||||
const char* end;
|
||||
bool valid = false;
|
||||
} consumer_arg_;
|
||||
|
||||
std::mutex consumer_mutex_;
|
||||
std::condition_variable consumer_cond_;
|
||||
};
|
||||
|
||||
} // namespace roctracer
|
||||
|
||||
Reference in New Issue
Block a user