Simplify memory_pool.h

Use the standard concurrency support library (std::thread, std::mutex,
std::condition_variable) instead of pthread.

Fix a mismatched memory allocation/deallocation when a custom allocator
is provided. The MemoryPool destructor was always using the default
allocator (using malloc/realloc/free) even if the pool memory was
allocated with the custom allocator.

Fix various thread safety issues and inefficiencies (spin loops).

Change-Id: I97592caa947f63463041bf43e00af9ebb5ff5886


[ROCm/roctracer commit: 9d728f74a1]
这个提交包含在:
Laurent Morichetti
2022-04-18 18:58:33 -07:00
父节点 45deedf43a
当前提交 3c18fb9f01
+143 -157
查看文件
@@ -21,186 +21,172 @@
#ifndef MEMORY_POOL_H_
#define MEMORY_POOL_H_
#include <pthread.h>
#include <stdlib.h>
#include <atomic>
#include <mutex>
#include "util/exception.h"
// Evaluates a pthread-style call exactly once. A zero result is success; any
// other result is stored in errno, reported with perror() using the text of
// the call, and the process is aborted.
#define PTHREAD_CALL(call)                  \
  do {                                      \
    const int pthread_call_status = (call); \
    if (pthread_call_status == 0) break;    \
    errno = pthread_call_status;            \
    perror(#call);                          \
    abort();                                \
  } while (0)
#include <condition_variable>
#include <cstdlib>
#include <cstring>
#include <future>
#include <mutex>
#include <thread>
namespace roctracer {
class MemoryPool {
public:
typedef std::mutex mutex_t;
MemoryPool(const roctracer_properties_t& properties) : properties_(properties) {
// Pool definition: The memory pool is split in 2 buffers of equal size. When first initialized,
// the write pointer points to the first element of the first buffer. When a buffer is full, or
// when Flush() is called, the write pointer moves to the other buffer.
const size_t allocation_size = 2 * properties_.buffer_size;
pool_begin_ = nullptr;
AllocateMemory(&pool_begin_, allocation_size);
if (pool_begin_ == nullptr) EXC_ABORT(ROCTRACER_STATUS_ERROR, "pool allocator failed");
static void allocator_default(char** ptr, size_t size, void* arg) {
(void)arg;
if (*ptr == NULL) {
pool_end_ = pool_begin_ + allocation_size;
buffer_begin_ = pool_begin_;
buffer_end_ = buffer_begin_ + properties_.buffer_size;
write_ptr_ = buffer_begin_;
// Create a consumer thread and wait for it to be ready to accept work.
std::promise<void> ready;
std::future<void> future = ready.get_future();
consumer_thread_ = std::thread(ConsumerThreadLoop, this, std::move(ready));
future.wait();
}
// Tears the pool down in a fixed order: deliver any pending records, tell the consumer thread to
// exit and join it, and only then release the pool memory (the consumer reads from that memory).
~MemoryPool() {
Flush();
// Wait for the previous flush to complete, then send the exit signal.
// A (nullptr, nullptr) range is the exit request recognised by ConsumerThreadLoop().
NotifyConsumerThread(nullptr, nullptr);
consumer_thread_.join();
// Free the pool's buffer memory.
// A size of 0 makes the default allocator in AllocateMemory() free the block.
AllocateMemory(&pool_begin_, 0);
}
// Copying is disabled: the destructor frees pool_begin_ and joins consumer_thread_, so two
// objects must never refer to the same buffer memory or thread.
MemoryPool(const MemoryPool&) = delete;
MemoryPool& operator=(const MemoryPool&) = delete;
// Copies `record` byte-wise into the current buffer while holding producer_mutex_. When the
// current buffer has no room left, its content is handed to the consumer thread and writing
// continues in the other buffer. A record that is larger than a whole buffer is reported through
// EXC_ABORT before any pool state is modified.
template <typename Record> void Write(Record&& record) {
  std::lock_guard producer_lock(producer_mutex_);

  // Reject oversized records up front so that a failed write neither hands a (possibly empty)
  // range to the consumer nor switches buffers.
  if (sizeof(record) > properties_.buffer_size)
    EXC_ABORT(ROCTRACER_STATUS_ERROR,
              "buffer size(" << properties_.buffer_size << ") is less then the record("
                             << sizeof(record) << ")");

  // Compare sizes instead of pointers so that no pointer beyond the end of the pool is formed.
  const size_t bytes_left = static_cast<size_t>(buffer_end_ - write_ptr_);
  if (bytes_left < sizeof(record)) {
    NotifyConsumerThread(buffer_begin_, write_ptr_);

    // Switch buffers
    buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
    buffer_end_ = buffer_begin_ + properties_.buffer_size;
    write_ptr_ = buffer_begin_;
  }

  // Store the record into the buffer, and increment the write pointer.
  ::memcpy(write_ptr_, &record, sizeof(record));
  write_ptr_ += sizeof(record);
}
// Hands all records written so far to the consumer thread and returns only after the consumer
// has finished processing them. Does nothing when the current buffer is empty.
void Flush() {
  std::lock_guard producer_lock(producer_mutex_);

  const bool has_pending_records = (write_ptr_ != buffer_begin_);
  if (has_pending_records) {
    NotifyConsumerThread(buffer_begin_, write_ptr_);

    // Continue writing in the other half of the pool.
    char* const other_buffer = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
    buffer_begin_ = other_buffer;
    buffer_end_ = other_buffer + properties_.buffer_size;
    write_ptr_ = other_buffer;

    // Block until the consumer marks the operation as complete (valid == false).
    std::unique_lock consumer_lock(consumer_mutex_);
    while (consumer_arg_.valid) {
      consumer_cond_.wait(consumer_lock);
    }
  }
}
private:
// Body of the consumer thread. Repeatedly waits for an operation published in consumer_arg_ and
// passes its [begin, end) range to properties_.buffer_callback_fun. The loop ends when an
// operation with begin == end == nullptr is received.
// `ready` is fulfilled once consumer_mutex_ is held, which tells the constructor that the thread
// has started. The callback is invoked while consumer_mutex_ is held.
void ConsumerThreadLoop(std::promise<void> ready) {
std::unique_lock consumer_lock(consumer_mutex_);
// This consumer is now ready to accept work.
ready.set_value();
while (true) {
// Sleeps (releasing consumer_mutex_) until a producer sets valid to true.
consumer_cond_.wait(consumer_lock, [this]() { return consumer_arg_.valid; });
// begin == end == nullptr means the thread needs to exit.
if (consumer_arg_.begin == nullptr && consumer_arg_.end == nullptr) break;
properties_.buffer_callback_fun(consumer_arg_.begin, consumer_arg_.end,
properties_.buffer_callback_arg);
// Mark this operation as complete (valid=false) and notify a producer, if any, that may be
// waiting to start a new operation. See comment below in NotifyConsumerThread().
consumer_arg_.valid = false;
consumer_cond_.notify_one();
}
}
// Publishes the range [data_begin, data_end) as the next operation for the consumer thread and
// wakes it up. Passing nullptr for both pointers asks the consumer thread to exit.
void NotifyConsumerThread(const char* data_begin, const char* data_end) {
  std::unique_lock consumer_lock(consumer_mutex_);

  // consumer_arg_ holds a single operation. While the previous one is still in use (valid=true)
  // the caller sleeps here; once the consumer thread is idle (valid=false) one waiting producer
  // may continue. Without this wait, producers entering this critical section one after the other
  // could overwrite an operation before the consumer thread re-acquires consumer_mutex_, and that
  // operation would be lost.
  while (consumer_arg_.valid) {
    consumer_cond_.wait(consumer_lock);
  }

  consumer_arg_.valid = true;
  consumer_arg_.end = data_end;
  consumer_arg_.begin = data_begin;
  consumer_cond_.notify_all();
}
// Allocates, resizes or frees the block referenced by `*ptr`.
// When properties_.alloc_fun is set, the request is forwarded to it unchanged together with
// properties_.alloc_arg. Otherwise the default allocator is used:
//   *ptr == nullptr            -> malloc(size)
//   *ptr != nullptr, size != 0 -> realloc(*ptr, size)
//   *ptr != nullptr, size == 0 -> free(*ptr), and *ptr is reset to nullptr
// The result is returned through `*ptr`; callers check it for nullptr to detect a failure.
void AllocateMemory(char** ptr, size_t size) const {
  if (properties_.alloc_fun != nullptr) {
    // Use the custom allocator provided in the properties.
    properties_.alloc_fun(ptr, size, properties_.alloc_arg);
    return;
  }

  // No custom allocator was provided so use the default malloc/realloc/free allocator.
  if (*ptr == nullptr) {
    *ptr = static_cast<char*>(malloc(size));
  } else if (size != 0) {
    *ptr = static_cast<char*>(realloc(*ptr, size));
  } else {
    free(*ptr);
    *ptr = nullptr;
  }
}
// pthread-based constructor: selects the allocator, allocates a pool made of two buffers of
// properties.buffer_size bytes each, and starts the consumer thread running reader_fun().
MemoryPool(const roctracer_properties_t& properties) {
// Assigning pool allocator
// allocator_default is used unless the properties provide an allocator.
alloc_fun_ = allocator_default;
alloc_arg_ = NULL;
if (properties.alloc_fun != NULL) {
alloc_fun_ = properties.alloc_fun;
alloc_arg_ = properties.alloc_arg;
}
// Pool definition
buffer_size_ = properties.buffer_size;
const size_t pool_size = 2 * buffer_size_;
pool_begin_ = NULL;
alloc_fun_(&pool_begin_, pool_size, alloc_arg_);
if (pool_begin_ == NULL) EXC_ABORT(ROCTRACER_STATUS_ERROR, "pool allocator failed");
pool_end_ = pool_begin_ + pool_size;
// Writing starts at the beginning of the first buffer.
buffer_begin_ = pool_begin_;
buffer_end_ = buffer_begin_ + buffer_size_;
write_ptr_ = buffer_begin_;
// Consuming read thread
read_callback_fun_ = properties.buffer_callback_fun;
read_callback_arg_ = properties.buffer_callback_arg;
// valid starts as true; reader_fun() clears it as its first action, so sync_reader() keeps
// spinning until the consumer thread has actually started.
consumer_arg_.set(this, NULL, NULL, true);
PTHREAD_CALL(pthread_mutex_init(&read_mutex_, NULL));
PTHREAD_CALL(pthread_cond_init(&read_cond_, NULL));
PTHREAD_CALL(pthread_create(&consumer_thread_, NULL, reader_fun, &consumer_arg_));
}
// pthread-based destructor: flushes pending records, cancels and joins the consumer thread, and
// releases the pool memory. A consumer thread that did not end by cancellation is reported
// through EXC_ABORT.
~MemoryPool() {
  Flush();

  PTHREAD_CALL(pthread_cancel(consumer_thread_));
  void* res = nullptr;
  PTHREAD_CALL(pthread_join(consumer_thread_, &res));
  if (res != PTHREAD_CANCELED)
    EXC_ABORT(ROCTRACER_STATUS_ERROR, "consumer thread wasn't stopped correctly");

  // Release the pool with the allocator that allocated it. alloc_fun_ is either the custom
  // allocator from the properties or allocator_default, so this is correct in both cases;
  // calling allocator_default unconditionally would free custom-allocated memory with free().
  alloc_fun_(&pool_begin_, 0, alloc_arg_);
}
// Stores a copy of `record` in the pool. Writers are serialized on write_mutex_; the copy itself
// is done by getRecord(), whose returned pointer is not needed here.
template <typename Record> void Write(const Record& record) {
  const std::lock_guard<mutex_t> writer_guard(write_mutex_);
  (void)getRecord<Record>(record);
}
// Passes the records of the current buffer to the consumer thread, waits until the consumer has
// processed them, and then moves on to the other buffer. Does nothing when no record is pending.
void Flush() {
  const std::lock_guard<mutex_t> writer_guard(write_mutex_);
  if (write_ptr_ <= buffer_begin_) return;

  spawn_reader(buffer_begin_, write_ptr_);
  sync_reader(&consumer_arg_);

  // Continue writing in the other half of the pool.
  char* const other_buffer = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
  buffer_begin_ = other_buffer;
  buffer_end_ = other_buffer + buffer_size_;
  write_ptr_ = other_buffer;
}
private:
// Work item shared between the producers and the consumer thread.
//   obj        - pool the consumer thread belongs to
//   begin/end  - range of record bytes to pass to the read callback
//   valid      - true while an operation is pending or being processed
struct consumer_arg_t {
  MemoryPool* obj;
  const char* begin;
  const char* end;
  // std::atomic already provides the required inter-thread visibility; `volatile` adds nothing
  // for synchronization and is therefore not used.
  std::atomic<bool> valid;

  // Assigns all fields; `valid` is written last so the other fields are set before it is stored.
  void set(MemoryPool* obj_p, const char* begin_p, const char* end_p, bool valid_p) {
    obj = obj_p;
    begin = begin_p;
    end = end_p;
    valid.store(valid_p);
  }
};
// Reserves room for one Record in the current buffer, copies `init` into it and returns a
// pointer to the stored copy. When the current buffer has no room left, its content is handed to
// the consumer thread and the other buffer is used. A Record larger than a whole buffer is
// reported through EXC_ABORT. The caller is expected to hold write_mutex_ (Write() does).
template <typename Record> Record* getRecord(const Record& init) {
  // Check the record size against the buffer size first. Checking only when the current buffer
  // is empty lets an oversized record slip through after a buffer switch and overflow the buffer.
  if (sizeof(Record) > buffer_size_)
    EXC_ABORT(ROCTRACER_STATUS_ERROR,
              "buffer size(" << buffer_size_ << ") is less then the record(" << sizeof(Record)
                             << ")");

  // Compare sizes instead of pointers so that no pointer beyond the end of the pool is formed.
  const size_t bytes_left = static_cast<size_t>(buffer_end_ - write_ptr_);
  if (bytes_left < sizeof(Record)) {
    spawn_reader(buffer_begin_, write_ptr_);
    buffer_begin_ = (buffer_end_ == pool_end_) ? pool_begin_ : buffer_end_;
    buffer_end_ = buffer_begin_ + buffer_size_;
    write_ptr_ = buffer_begin_;
  }

  Record* ptr = reinterpret_cast<Record*>(write_ptr_);
  write_ptr_ += sizeof(Record);
  *ptr = init;
  return ptr;
}
// Marks the operation described by `arg` as finished by clearing its `valid` flag.
static void reset_reader(consumer_arg_t* arg) {
  arg->valid.store(false);
}
// Spins until the `valid` flag of `arg` is cleared, yielding the processor between checks.
static void sync_reader(const consumer_arg_t* arg) {
  for (;;) {
    if (!arg->valid.load()) return;
    PTHREAD_CALL(sched_yield());
  }
}
// Entry point of the pthread consumer thread; `consumer_arg` points to the pool's consumer_arg_t.
// The loop never exits on its own: the trailing return is not reached, and the destructor stops
// the thread with pthread_cancel().
static void* reader_fun(void* consumer_arg) {
consumer_arg_t* arg = reinterpret_cast<consumer_arg_t*>(consumer_arg);
roctracer::MemoryPool* obj = arg->obj;
// Clears the initial valid=true set by the constructor, signalling that the thread is running.
reset_reader(arg);
while (1) {
PTHREAD_CALL(pthread_mutex_lock(&(obj->read_mutex_)));
// Sleep until spawn_reader() publishes an operation (valid becomes true).
// NOTE(review): no cancellation cleanup handler is installed, so a cancellation inside
// pthread_cond_wait presumably leaves read_mutex_ locked - confirm.
while (arg->valid.load() == false) {
PTHREAD_CALL(pthread_cond_wait(&(obj->read_cond_), &(obj->read_mutex_)));
}
// The callback runs while read_mutex_ is held.
obj->read_callback_fun_(arg->begin, arg->end, obj->read_callback_arg_);
// Mark the operation as done; sync_reader() callers spin on this flag.
reset_reader(arg);
PTHREAD_CALL(pthread_mutex_unlock(&(obj->read_mutex_)));
}
return NULL;
}
// Hands the range [data_begin, data_end) to the consumer thread.
void spawn_reader(const char* data_begin, const char* data_end) {
// Wait (spinning) until the previous operation has been processed, so consumer_arg_ is free.
sync_reader(&consumer_arg_);
// Publish the new operation and signal the condition under read_mutex_, the same mutex
// reader_fun() holds while checking `valid`.
PTHREAD_CALL(pthread_mutex_lock(&read_mutex_));
consumer_arg_.set(this, data_begin, data_end, true);
PTHREAD_CALL(pthread_cond_signal(&read_cond_));
PTHREAD_CALL(pthread_mutex_unlock(&read_mutex_));
}
// NOTE(review): this member list contains the members of both the pthread-based and the
// std::thread-based implementation; pool_begin_, consumer_thread_ and consumer_arg_ are each
// declared twice.
// pool allocator
// Allocator selected by the pthread-based constructor, and the argument passed to it.
roctracer_allocator_t alloc_fun_;
void* alloc_arg_;
// Properties used to create the memory pool.
const roctracer_properties_t properties_;
// Pool definition
// Size in bytes of one of the two buffers.
size_t buffer_size_;
char* pool_begin_;
char* pool_begin_; // FIXME: shouldn't these be void*?
// One past the last byte of the pool.
char* pool_end_;
// [buffer_begin_, buffer_end_) is the buffer currently written to.
char* buffer_begin_;
char* buffer_end_;
// Position where the next record is stored.
char* write_ptr_;
// Serializes writers in the pthread-based Write()/Flush().
mutex_t write_mutex_;
// Serializes writers in the std::thread-based Write()/Flush().
std::mutex producer_mutex_;
// Consuming read thread
// Callback, and its argument, invoked by reader_fun() for each range of records.
roctracer_buffer_callback_t read_callback_fun_;
void* read_callback_arg_;
consumer_arg_t consumer_arg_;
pthread_t consumer_thread_;
// Mutex and condition used by spawn_reader() and reader_fun().
pthread_mutex_t read_mutex_;
pthread_cond_t read_cond_;
// Consumer thread
std::thread consumer_thread_;
// Operation handed to ConsumerThreadLoop(); valid is true while an operation is pending or
// being processed.
struct {
const char* begin;
const char* end;
bool valid = false;
} consumer_arg_;
// Mutex and condition protecting consumer_arg_ in the std::thread-based implementation.
std::mutex consumer_mutex_;
std::condition_variable consumer_cond_;
};
} // namespace roctracer