fix hipStreamAddCallback, block future work on stream (#1934)

This commit is contained in:
Jeff Daily
2020-03-19 03:46:04 -07:00
committed by GitHub
parent d90a0c05c0
commit 1444f850ac
5 changed files with 587 additions and 33 deletions
-14
View File
@@ -1520,20 +1520,6 @@ hipError_t ihipStreamSynchronize(TlsData *tls, hipStream_t stream) {
return e;
}
void ihipStreamCallbackHandler(ihipStreamCallback_t* cb) {
hipError_t e = hipSuccess;
// Synchronize stream
tprintf(DB_SYNC, "ihipStreamCallbackHandler wait on stream %s\n",
ToString(cb->_stream).c_str());
GET_TLS();
e = ihipStreamSynchronize(tls, cb->_stream);
// Call registered callback function
cb->_callback(cb->_stream, e, cb->_userData);
delete cb;
}
//---
// Get the stream to use for a command submission.
//
-14
View File
@@ -654,19 +654,6 @@ class ihipStream_t {
};
//----
// Internal structure for stream callback handler
class ihipStreamCallback_t {
public:
ihipStreamCallback_t(hipStream_t stream, hipStreamCallback_t callback, void* userData)
: _stream(stream), _callback(callback), _userData(userData) {
};
hipStream_t _stream;
hipStreamCallback_t _callback;
void* _userData;
};
//----
// Internal event structure:
enum hipEventStatus_t {
@@ -980,7 +967,6 @@ hipError_t hipModuleGetFunctionEx(hipFunction_t* hfunc, hipModule_t hmod,
hipStream_t ihipSyncAndResolveStream(hipStream_t, bool lockAcquired = 0);
hipError_t ihipStreamSynchronize(TlsData *tls, hipStream_t stream);
void ihipStreamCallbackHandler(ihipStreamCallback_t* cb);
// Stream printf functions:
inline std::ostream& operator<<(std::ostream& os, const ihipStream_t& s) {
+33 -5
View File
@@ -257,11 +257,39 @@ hipError_t hipStreamGetPriority(hipStream_t stream, int* priority) {
hipError_t hipStreamAddCallback(hipStream_t stream, hipStreamCallback_t callback, void* userData,
unsigned int flags) {
HIP_INIT_API(hipStreamAddCallback, stream, callback, userData, flags);
hipError_t e = hipSuccess;
// Create a thread in detached mode to handle callback
ihipStreamCallback_t* cb = new ihipStreamCallback_t(stream, callback, userData);
std::thread(ihipStreamCallbackHandler, cb).detach();
auto stream_original{stream};
stream = ihipSyncAndResolveStream(stream);
return ihipLogStatus(e);
if (!stream) return hipErrorInvalidValue;
LockedAccessor_StreamCrit_t cs{stream->criticalData()};
// create first marker
auto cf = cs->_av.create_marker(hc::no_scope);
// get its signal
auto signal = *reinterpret_cast<hsa_signal_t*>(cf.get_native_handle());
// increment its signal value
hsa_signal_add_relaxed(signal, 1);
// create callback that can be passed to hsa_amd_signal_async_handler
// this function will call the user's callback, then sets first packet's signal to 0 to indicate completion
auto t{new std::function<void()>{[=]() {
callback(stream_original, hipSuccess, userData);
hsa_signal_store_relaxed(signal, 0);
}}};
// register above callback with HSA runtime to be called when first packet's signal
// is decremented from 2 to 1 by CP (or it is already at 1)
hsa_amd_signal_async_handler(signal, HSA_SIGNAL_CONDITION_EQ, 1,
[](hsa_signal_value_t x, void* p) {
(*static_cast<decltype(t)>(p))();
delete static_cast<decltype(t)>(p);
return false;
}, t);
// create additional marker that blocks on the first one
cs->_av.create_blocking_marker(cf, hc::no_scope);
return ihipLogStatus(hipSuccess);
}
@@ -0,0 +1,145 @@
#include <stdio.h>
#include <hip/hip_runtime.h>
#include <unistd.h>
#include "test_common.h"
#include <atomic>
/* HIT_START
* BUILD: %t %s ../../test_common.cpp NVCC_OPTIONS -std=c++11
* TEST: %t
* HIT_END
*/
enum class ExecState
{
EXEC_NOT_STARTED,
EXEC_STARTED,
EXEC_CB_STARTED,
EXEC_CB_FINISHED,
EXEC_FINISHED
};
struct UserData
{
size_t size;
int* ptr;
};
// Global variable to check exection order
std::atomic<ExecState> gData(ExecState::EXEC_NOT_STARTED);
void myCallback(hipStream_t stream, hipError_t status, void* user_data)
{
if(gData.load() != ExecState::EXEC_STARTED)
return; // Error hence return early
gData.store(ExecState::EXEC_CB_STARTED);
UserData* data = reinterpret_cast<UserData*>(user_data);
printf("Callback started\n");
sleep(1);
printf("Callback ending.\n");
gData.store(ExecState::EXEC_CB_FINISHED);
}
bool test(int count)
{
printf("\n============ Test iteration %d =============\n",count);
// Stream
hipStream_t stream;
bool result = true;
gData.store(ExecState::EXEC_STARTED);
HIPCHECK(hipStreamCreate(&stream));
// Array size
size_t size = 10000;
// Device array
int *data = NULL;
HIPCHECK(hipMalloc((void**)&data, sizeof(int) * size));
// Initialize device array to -1
HIPCHECK(hipMemset(data, -1, sizeof(int) * size));
// Host array
int *host = NULL;
HIPCHECK(hipHostMalloc((void**)&host, sizeof(int) * size));
// Print host ptr address
printf("In main thread\n");
// Initialize user_data for callback
UserData arg;
arg.size = size;
arg.ptr = host;
// Synchronize device
HIPCHECK(hipDeviceSynchronize());
// Asynchronous copy from device to host
HIPCHECK(hipMemcpyAsync(host, data, sizeof(int) * size, hipMemcpyDeviceToHost, stream));
// Asynchronous memset on device
HIPCHECK(hipMemsetAsync(data, 0, sizeof(int) * size, stream));
// Add callback - should happen after hipMemsetAsync()
HIPCHECK(hipStreamAddCallback(stream, myCallback, &arg, 0));
printf("Will wait in main thread until callback completes\n");
//This should synchronize the stream (including the callback)
HIPCHECK(hipStreamSynchronize(stream));
if(gData.load() != ExecState::EXEC_CB_FINISHED)
{
std::cout<<"Callback is not finished\n";
return false;
}
printf("Callback completed will resume main thread execution\n");
if(host[size/2] != -1)
{
// Print some host data that just got copied
printf("Pseudo host data printing (should be -1): %d\n", host[size/2]);
result = false;
}
HIPCHECK(hipMemcpy(host, data, sizeof(int)*size, hipMemcpyDeviceToHost));
if(host[size-1] != 0)
{
printf("Pseudo host data printing (should be 0): %d\n", host[size-1]);
result = false;
}
HIPCHECK(hipFree(data));
HIPCHECK(hipHostFree(host));
HIPCHECK(hipStreamDestroy(stream));
gData.store(ExecState::EXEC_FINISHED);
return result;
}
int main()
{
// Test involves multithreading hence running multiple times
// to make sure consitency in the behavior
bool status = true;
for(int i=0; i < 10; i++){
status = test(i+1);
if(status == false)
{
failed("Test Failed!\n");
break;
}
}
if(status == true) passed();
return 0;
}
@@ -0,0 +1,409 @@
#include <hip/hip_runtime.h>
#include <stdexcept>
#include <memory>
#include <functional>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <future>
#include "test_common.h"
/* HIT_START
* BUILD: %t %s ../../test_common.cpp NVCC_OPTIONS -std=c++11
* TEST: %t
* HIT_END
*/
#define WORKAROUND 0 // Enable (1) this to make stream thread-safe by a workaround
template<bool IsBlocking> // <true> = queue blocks, until task is finished in enqueue(queue,task)
class QueueHipRt;
// Queue types used in the tests
using TestQueues = std::tuple<QueueHipRt<true>, QueueHipRt<false>>;
// --- Implementation
#define HIP_ASSERT(x) (assert((x)==hipSuccess))
#define HIP_ASSERT_IGNORE(x,ign) auto err=x; HIP_ASSERT(err==ign ? hipSuccess : err)
#ifdef __HIP_PLATFORM_HCC__
#define HIPRT_CB
#endif
template<bool isBlocking>
static auto currentThreadWaitFor(QueueHipRt<isBlocking> const & queue) -> void;
template<bool IsBlocking>
class QueueHipRt
{
public:
static constexpr bool isBlocking = IsBlocking;
//-----------------------------------------------------------------------------
QueueHipRt(
int dev) :
m_dev(dev),
m_HipQueue()
{
HIP_ASSERT(
hipSetDevice(
m_dev));
HIP_ASSERT(
hipStreamCreateWithFlags(
&m_HipQueue,
hipStreamNonBlocking));
}
//-----------------------------------------------------------------------------
QueueHipRt(QueueHipRt const &) = delete;
//-----------------------------------------------------------------------------
QueueHipRt(QueueHipRt &&) = delete;
//-----------------------------------------------------------------------------
auto operator=(QueueHipRt const &) -> QueueHipRt & = delete;
//-----------------------------------------------------------------------------
auto operator=(QueueHipRt &&) -> QueueHipRt & = delete;
//-----------------------------------------------------------------------------
~QueueHipRt()
{
if(isBlocking) {
#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking)
// we are a non-blocking queue, so we have to wait here with its destruction until all spawned tasks have been processed
currentThreadWaitFor(*this);
#endif
}
HIP_ASSERT(
hipSetDevice(
m_dev));
HIP_ASSERT(
hipStreamDestroy(
m_HipQueue));
}
public:
int m_dev; //!< The device this queue is bound to.
hipStream_t m_HipQueue;
#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking)
int m_callees = 0;
std::mutex m_mutex;
#endif
};
template<typename TTask>
struct Enqueue
{
//#############################################################################
enum class CallbackState
{
enqueued,
notified,
finished,
};
//#############################################################################
struct CallbackSynchronizationData : public std::enable_shared_from_this<CallbackSynchronizationData>
{
std::mutex m_mutex;
std::condition_variable m_event;
CallbackState state = CallbackState::enqueued;
};
//-----------------------------------------------------------------------------
static void HIPRT_CB hipRtCallback(hipStream_t /*queue*/, hipError_t /*status*/, void *arg)
{
// explicitly copy the shared_ptr so that this method holds the state even when the executing thread has already finished.
const auto pCallbackSynchronizationData = reinterpret_cast<CallbackSynchronizationData*>(arg)->shared_from_this();
// Notify the executing thread.
{
std::unique_lock<std::mutex> lock(pCallbackSynchronizationData->m_mutex);
pCallbackSynchronizationData->state = CallbackState::notified;
}
pCallbackSynchronizationData->m_event.notify_one();
// Wait for the executing thread to finish the task if it has not already finished.
std::unique_lock<std::mutex> lock(pCallbackSynchronizationData->m_mutex);
if(pCallbackSynchronizationData->state != CallbackState::finished)
{
pCallbackSynchronizationData->m_event.wait(
lock,
[pCallbackSynchronizationData](){
return pCallbackSynchronizationData->state == CallbackState::finished;
}
);
}
}
//-----------------------------------------------------------------------------
template<bool isBlocking>
static auto enqueue(
QueueHipRt<isBlocking> & queue,
TTask const & task)
-> void
{
#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking)
{
// thread-safe callee incrementing
std::lock_guard<std::mutex> guard(queue.m_mutex);
queue.m_callees += 1;
}
#endif
auto pCallbackSynchronizationData = std::make_shared<CallbackSynchronizationData>();
// test example: https://github.com/ROCm-Developer-Tools/HIP/blob/roc-1.9.x/tests/src/runtimeApi/stream/hipStreamAddCallback.cpp
HIP_ASSERT(hipStreamAddCallback(
queue.m_HipQueue,
hipRtCallback,
pCallbackSynchronizationData.get(),
0u));
// We start a new std::thread which stores the task to be executed.
// This circumvents the limitation that it is not possible to call HIP methods within the HIP callback thread.
// The HIP thread signals the std::thread when it is ready to execute the task.
// The HIP thread is waiting for the std::thread to signal that it is finished executing the task
// before it executes the next task in the queue (HIP stream).
std::thread t(
[pCallbackSynchronizationData,
task
#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking)
,&queue // requires queue's destructor to wait for all tasks
#endif
](){
#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking)
// thread-safe task execution and callee decrementing
std::lock_guard<std::mutex> guard(queue.m_mutex);
#endif
// If the callback has not yet been called, we wait for it.
{
std::unique_lock<std::mutex> lock(pCallbackSynchronizationData->m_mutex);
if(pCallbackSynchronizationData->state != CallbackState::notified)
{
pCallbackSynchronizationData->m_event.wait(
lock,
[pCallbackSynchronizationData](){
return pCallbackSynchronizationData->state == CallbackState::notified;
}
);
}
task();
// Notify the waiting HIP thread.
pCallbackSynchronizationData->state = CallbackState::finished;
}
pCallbackSynchronizationData->m_event.notify_one();
#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking)
queue.m_callees -= 1;
#endif
}
);
if(isBlocking)
t.join(); // => waiting for task completion
else
t.detach(); // => do not wait for task completion
}
};
//#############################################################################
//! The HIP RT non-blocking queue test trait specialization.
struct Empty
{
//-----------------------------------------------------------------------------
template<bool isBlocking>
static auto empty(
QueueHipRt<isBlocking> const & queue)
-> bool
{
#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking)
return (queue.m_callees==0);
#else
// Query is allowed even for queues on non current device.
hipError_t ret = hipSuccess;
HIP_ASSERT_IGNORE(
ret = hipStreamQuery(
queue.m_HipQueue),
hipErrorNotReady);
return (ret == hipSuccess);
#endif
}
};
template<bool isBlocking>
auto currentThreadWaitFor(QueueHipRt<isBlocking> const & queue) -> void
{
#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking)
while(queue.m_callees>0) {
std::this_thread::sleep_for(std::chrono::milliseconds(10u));
}
#else
// Sync is allowed even for queues on non current device.
HIP_ASSERT( hipStreamSynchronize(
queue.m_HipQueue));
#endif
}
// --- Tests
#define TEMPLATE_LIST_TEST_CASE(TestName) \
template<typename TestType> static void TestName (std::atomic<int> &check); \
static int TestName##Runner () { \
std::atomic<int> check{0}; \
TestName< QueueHipRt<true> >(check); \
fprintf(stderr, "After " #TestName " < QueueHipRt<true> > errors=%d\n", check.load()); \
TestName< QueueHipRt<false> >(check); \
fprintf(stderr, "After " #TestName " < QueueHipRt<false> > errors=%d\n", check.load()); \
return check.load(); \
} \
template<typename TestType> static void TestName (std::atomic<int> &check)
// add 1 if a check fails
#define CHECK(result) do{int arg=(!(result)); fprintf(stderr, "Checking " #result " %d\n", arg); check.fetch_add(arg);}while(false)
//-----------------------------------------------------------------------------
TEMPLATE_LIST_TEST_CASE( queueIsInitiallyEmpty )
{
TestType queue{0};
CHECK(Empty::empty(queue));
}
//-----------------------------------------------------------------------------
TEMPLATE_LIST_TEST_CASE( queueCallbackIsWorking )
{
std::promise<bool> promise;
auto task = [&](){ promise.set_value(true); };
TestType queue{0};
Enqueue<decltype(task)> enqueue;
enqueue.enqueue(
queue,
task
);
CHECK(promise.get_future().get());
}
//-----------------------------------------------------------------------------
TEMPLATE_LIST_TEST_CASE( queueWaitShouldWork )
{
bool CallbackFinished = false;
auto task =
[&CallbackFinished]() noexcept
{
std::this_thread::sleep_for(std::chrono::milliseconds(100u));
CallbackFinished = true;
};
TestType queue{0};
Enqueue<decltype(task)> enqueue;
enqueue.enqueue(
queue,
task
);
currentThreadWaitFor(queue);
CHECK(CallbackFinished);
}
//-----------------------------------------------------------------------------
TEMPLATE_LIST_TEST_CASE( queueShouldNotBeEmptyWhenLastTaskIsStillExecutingAndIsEmptyAfterProcessingFinished )
{
bool CallbackFinished = false;
TestType queue{0};
auto task = [&queue, &CallbackFinished, &check]() noexcept
{
CHECK(!Empty::empty(queue));
std::this_thread::sleep_for(std::chrono::milliseconds(100u));
CallbackFinished = true;
};
Enqueue<decltype(task)> enqueue;
enqueue.enqueue(
queue,
task
);
// A non-blocking queue will always stay empty because the task has been executed immediately.
if(!TestType::isBlocking)
{
currentThreadWaitFor(queue);
}
CHECK(Empty::empty(queue));
CHECK(CallbackFinished);
}
//-----------------------------------------------------------------------------
TEMPLATE_LIST_TEST_CASE( queueShouldNotExecuteTasksInParallel )
{
std::atomic<bool> taskIsExecuting(false);
std::promise<void> firstTaskFinished;
std::future<void> firstTaskFinishedFuture = firstTaskFinished.get_future();
std::promise<void> secondTaskFinished;
std::future<void> secondTaskFinishedFuture = secondTaskFinished.get_future();
TestType queue{0};
std::thread thread1(
[&queue, &taskIsExecuting, &firstTaskFinished, &check]()
{
auto task1 = [&taskIsExecuting, &firstTaskFinished, &check]() noexcept
{
CHECK(!taskIsExecuting.exchange(true));
std::this_thread::sleep_for(std::chrono::milliseconds(100u));
CHECK(taskIsExecuting.exchange(false));
firstTaskFinished.set_value();
};
Enqueue<decltype(task1)> enqueue;
enqueue.enqueue(
queue,
task1
);
});
std::thread thread2(
[&queue, &taskIsExecuting, &secondTaskFinished, &check]()
{
auto task2 = [&taskIsExecuting, &secondTaskFinished, &check]() noexcept
{
CHECK(!taskIsExecuting.exchange(true));
std::this_thread::sleep_for(std::chrono::milliseconds(100u));
CHECK(taskIsExecuting.exchange(false));
secondTaskFinished.set_value();
};
Enqueue<decltype(task2)> enqueue;
enqueue.enqueue(
queue,
task2
);
});
// Both tasks have to be enqueued
thread1.join();
thread2.join();
currentThreadWaitFor(queue);
firstTaskFinishedFuture.get();
secondTaskFinishedFuture.get();
}
#define TESTER(name) do { \
int result = name (); \
fprintf(stderr, #name " %s\n", result?"Errors":"No Errors"); \
if (result) { failed(#name " failed\n"); } \
} while (false)
int main()
{
TESTER(queueIsInitiallyEmptyRunner);
TESTER(queueCallbackIsWorkingRunner);
TESTER(queueWaitShouldWorkRunner);
TESTER(queueShouldNotBeEmptyWhenLastTaskIsStillExecutingAndIsEmptyAfterProcessingFinishedRunner);
TESTER(queueShouldNotExecuteTasksInParallelRunner);
passed();
}