From 1444f850ac6aaeed85d542e7a834ec968deff246 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Thu, 19 Mar 2020 03:46:04 -0700 Subject: [PATCH] fix hipStreamAddCallback, block future work on stream (#1934) --- hipamd/src/hip_hcc.cpp | 14 - hipamd/src/hip_hcc_internal.h | 14 - hipamd/src/hip_stream.cpp | 38 +- .../runtimeApi/stream/StreamAddCallback.cpp | 145 +++++++ .../stream/hipStreamAddCallbackCatch.cpp | 409 ++++++++++++++++++ 5 files changed, 587 insertions(+), 33 deletions(-) create mode 100644 hipamd/tests/src/runtimeApi/stream/StreamAddCallback.cpp create mode 100644 hipamd/tests/src/runtimeApi/stream/hipStreamAddCallbackCatch.cpp diff --git a/hipamd/src/hip_hcc.cpp b/hipamd/src/hip_hcc.cpp index 3f7128e964..807dcc7391 100644 --- a/hipamd/src/hip_hcc.cpp +++ b/hipamd/src/hip_hcc.cpp @@ -1520,20 +1520,6 @@ hipError_t ihipStreamSynchronize(TlsData *tls, hipStream_t stream) { return e; } -void ihipStreamCallbackHandler(ihipStreamCallback_t* cb) { - hipError_t e = hipSuccess; - - // Synchronize stream - tprintf(DB_SYNC, "ihipStreamCallbackHandler wait on stream %s\n", - ToString(cb->_stream).c_str()); - GET_TLS(); - e = ihipStreamSynchronize(tls, cb->_stream); - - // Call registered callback function - cb->_callback(cb->_stream, e, cb->_userData); - delete cb; -} - //--- // Get the stream to use for a command submission. 
// diff --git a/hipamd/src/hip_hcc_internal.h b/hipamd/src/hip_hcc_internal.h index ac63f49dba..c7ff27c7b5 100644 --- a/hipamd/src/hip_hcc_internal.h +++ b/hipamd/src/hip_hcc_internal.h @@ -654,19 +654,6 @@ class ihipStream_t { }; -//---- -// Internal structure for stream callback handler -class ihipStreamCallback_t { - public: - ihipStreamCallback_t(hipStream_t stream, hipStreamCallback_t callback, void* userData) - : _stream(stream), _callback(callback), _userData(userData) { - }; - hipStream_t _stream; - hipStreamCallback_t _callback; - void* _userData; -}; - - //---- // Internal event structure: enum hipEventStatus_t { @@ -980,7 +967,6 @@ hipError_t hipModuleGetFunctionEx(hipFunction_t* hfunc, hipModule_t hmod, hipStream_t ihipSyncAndResolveStream(hipStream_t, bool lockAcquired = 0); hipError_t ihipStreamSynchronize(TlsData *tls, hipStream_t stream); -void ihipStreamCallbackHandler(ihipStreamCallback_t* cb); // Stream printf functions: inline std::ostream& operator<<(std::ostream& os, const ihipStream_t& s) { diff --git a/hipamd/src/hip_stream.cpp b/hipamd/src/hip_stream.cpp index 2add6a77c4..63551d1204 100644 --- a/hipamd/src/hip_stream.cpp +++ b/hipamd/src/hip_stream.cpp @@ -257,11 +257,39 @@ hipError_t hipStreamGetPriority(hipStream_t stream, int* priority) { hipError_t hipStreamAddCallback(hipStream_t stream, hipStreamCallback_t callback, void* userData, unsigned int flags) { HIP_INIT_API(hipStreamAddCallback, stream, callback, userData, flags); - hipError_t e = hipSuccess; - // Create a thread in detached mode to handle callback - ihipStreamCallback_t* cb = new ihipStreamCallback_t(stream, callback, userData); - std::thread(ihipStreamCallbackHandler, cb).detach(); + auto stream_original{stream}; + stream = ihipSyncAndResolveStream(stream); - return ihipLogStatus(e); + if (!stream) return hipErrorInvalidValue; + + LockedAccessor_StreamCrit_t cs{stream->criticalData()}; + + // create first marker + auto cf = cs->_av.create_marker(hc::no_scope); + // get its 
signal + auto signal = *reinterpret_cast(cf.get_native_handle()); + // increment its signal value + hsa_signal_add_relaxed(signal, 1); + + // create callback that can be passed to hsa_amd_signal_async_handler + // this function will call the user's callback, then set the first packet's signal to 0 to indicate completion + auto t{new std::function{[=]() { + callback(stream_original, hipSuccess, userData); + hsa_signal_store_relaxed(signal, 0); + }}}; + + // register above callback with HSA runtime to be called when first packet's signal + // is decremented from 2 to 1 by CP (or it is already at 1) + hsa_amd_signal_async_handler(signal, HSA_SIGNAL_CONDITION_EQ, 1, + [](hsa_signal_value_t x, void* p) { + (*static_cast(p))(); + delete static_cast(p); + return false; + }, t); + + // create additional marker that blocks on the first one + cs->_av.create_blocking_marker(cf, hc::no_scope); + + return ihipLogStatus(hipSuccess); } diff --git a/hipamd/tests/src/runtimeApi/stream/StreamAddCallback.cpp b/hipamd/tests/src/runtimeApi/stream/StreamAddCallback.cpp new file mode 100644 index 0000000000..e6492c7ce2 --- /dev/null +++ b/hipamd/tests/src/runtimeApi/stream/StreamAddCallback.cpp @@ -0,0 +1,145 @@ +#include +#include +#include +#include "test_common.h" +#include + +/* HIT_START + * BUILD: %t %s ../../test_common.cpp NVCC_OPTIONS -std=c++11 + * TEST: %t + * HIT_END + */ + +enum class ExecState +{ + EXEC_NOT_STARTED, + EXEC_STARTED, + EXEC_CB_STARTED, + EXEC_CB_FINISHED, + EXEC_FINISHED +}; + +struct UserData +{ + size_t size; + int* ptr; +}; + +// Global variable to check execution order +std::atomic gData(ExecState::EXEC_NOT_STARTED); + + +void myCallback(hipStream_t stream, hipError_t status, void* user_data) +{ + if(gData.load() != ExecState::EXEC_STARTED) + return; // Error hence return early + + gData.store(ExecState::EXEC_CB_STARTED); + + UserData* data = reinterpret_cast(user_data); + printf("Callback started\n"); + + sleep(1); + + printf("Callback ending.\n"); + 
gData.store(ExecState::EXEC_CB_FINISHED); +} + +bool test(int count) +{ + printf("\n============ Test iteration %d =============\n",count); + // Stream + hipStream_t stream; + bool result = true; + + gData.store(ExecState::EXEC_STARTED); + + HIPCHECK(hipStreamCreate(&stream)); + + // Array size + size_t size = 10000; + + // Device array + int *data = NULL; + HIPCHECK(hipMalloc((void**)&data, sizeof(int) * size)); + + // Initialize device array to -1 + HIPCHECK(hipMemset(data, -1, sizeof(int) * size)); + + // Host array + int *host = NULL; + HIPCHECK(hipHostMalloc((void**)&host, sizeof(int) * size)); + + // Print host ptr address + printf("In main thread\n"); + + // Initialize user_data for callback + UserData arg; + arg.size = size; + arg.ptr = host; + + // Synchronize device + HIPCHECK(hipDeviceSynchronize()); + + // Asynchronous copy from device to host + HIPCHECK(hipMemcpyAsync(host, data, sizeof(int) * size, hipMemcpyDeviceToHost, stream)); + + // Asynchronous memset on device + HIPCHECK(hipMemsetAsync(data, 0, sizeof(int) * size, stream)); + + // Add callback - should happen after hipMemsetAsync() + HIPCHECK(hipStreamAddCallback(stream, myCallback, &arg, 0)); + + printf("Will wait in main thread until callback completes\n"); + + //This should synchronize the stream (including the callback) + HIPCHECK(hipStreamSynchronize(stream)); + + if(gData.load() != ExecState::EXEC_CB_FINISHED) + { + std::cout<<"Callback is not finished\n"; + return false; + } + printf("Callback completed will resume main thread execution\n"); + + if(host[size/2] != -1) + { + // Print some host data that just got copied + printf("Pseudo host data printing (should be -1): %d\n", host[size/2]); + result = false; + } + + HIPCHECK(hipMemcpy(host, data, sizeof(int)*size, hipMemcpyDeviceToHost)); + + if(host[size-1] != 0) + { + printf("Pseudo host data printing (should be 0): %d\n", host[size-1]); + result = false; + } + + HIPCHECK(hipFree(data)); + HIPCHECK(hipHostFree(host)); + 
HIPCHECK(hipStreamDestroy(stream)); + + gData.store(ExecState::EXEC_FINISHED); + return result; +} + +int main() +{ + // Test involves multithreading, hence it is run multiple times + // to make sure the behavior is consistent + bool status = true; + + for(int i=0; i < 10; i++){ + status = test(i+1); + if(status == false) + { + failed("Test Failed!\n"); + break; + } + } + + if(status == true) passed(); + return 0; +} diff --git a/hipamd/tests/src/runtimeApi/stream/hipStreamAddCallbackCatch.cpp b/hipamd/tests/src/runtimeApi/stream/hipStreamAddCallbackCatch.cpp new file mode 100644 index 0000000000..5f267bba28 --- /dev/null +++ b/hipamd/tests/src/runtimeApi/stream/hipStreamAddCallbackCatch.cpp @@ -0,0 +1,409 @@ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include "test_common.h" + +/* HIT_START + * BUILD: %t %s ../../test_common.cpp NVCC_OPTIONS -std=c++11 + * TEST: %t + * HIT_END + */ + +#define WORKAROUND 0 // Enable (1) this to make stream thread-safe by a workaround + +template // = queue blocks, until task is finished in enqueue(queue,task) +class QueueHipRt; + +// Queue types used in the tests +using TestQueues = std::tuple, QueueHipRt>; + + +// --- Implementation + +#define HIP_ASSERT(x) (assert((x)==hipSuccess)) +#define HIP_ASSERT_IGNORE(x,ign) auto err=x; HIP_ASSERT(err==ign ? 
hipSuccess : err) + +#ifdef __HIP_PLATFORM_HCC__ + #define HIPRT_CB +#endif + +template +static auto currentThreadWaitFor(QueueHipRt const & queue) -> void; + +template +class QueueHipRt +{ +public: + static constexpr bool isBlocking = IsBlocking; + //----------------------------------------------------------------------------- + QueueHipRt( + int dev) : + m_dev(dev), + m_HipQueue() + { + HIP_ASSERT( + hipSetDevice( + m_dev)); + HIP_ASSERT( + hipStreamCreateWithFlags( + &m_HipQueue, + hipStreamNonBlocking)); + } + //----------------------------------------------------------------------------- + QueueHipRt(QueueHipRt const &) = delete; + //----------------------------------------------------------------------------- + QueueHipRt(QueueHipRt &&) = delete; + //----------------------------------------------------------------------------- + auto operator=(QueueHipRt const &) -> QueueHipRt & = delete; + //----------------------------------------------------------------------------- + auto operator=(QueueHipRt &&) -> QueueHipRt & = delete; + //----------------------------------------------------------------------------- + ~QueueHipRt() + { + if(isBlocking) { +#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking) + // we are a non-blocking queue, so we have to wait here with its destruction until all spawned tasks have been processed + currentThreadWaitFor(*this); +#endif + } + HIP_ASSERT( + hipSetDevice( + m_dev)); + HIP_ASSERT( + hipStreamDestroy( + m_HipQueue)); + } + +public: + int m_dev; //!< The device this queue is bound to. 
+ hipStream_t m_HipQueue; + +#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking) + int m_callees = 0; + std::mutex m_mutex; +#endif +}; + +template +struct Enqueue +{ + //############################################################################# + enum class CallbackState + { + enqueued, + notified, + finished, + }; + + //############################################################################# + struct CallbackSynchronizationData : public std::enable_shared_from_this + { + std::mutex m_mutex; + std::condition_variable m_event; + CallbackState state = CallbackState::enqueued; + }; + + //----------------------------------------------------------------------------- + static void HIPRT_CB hipRtCallback(hipStream_t /*queue*/, hipError_t /*status*/, void *arg) + { + // explicitly copy the shared_ptr so that this method holds the state even when the executing thread has already finished. + const auto pCallbackSynchronizationData = reinterpret_cast(arg)->shared_from_this(); + + // Notify the executing thread. + { + std::unique_lock lock(pCallbackSynchronizationData->m_mutex); + pCallbackSynchronizationData->state = CallbackState::notified; + } + pCallbackSynchronizationData->m_event.notify_one(); + + // Wait for the executing thread to finish the task if it has not already finished. 
+ std::unique_lock lock(pCallbackSynchronizationData->m_mutex); + if(pCallbackSynchronizationData->state != CallbackState::finished) + { + pCallbackSynchronizationData->m_event.wait( + lock, + [pCallbackSynchronizationData](){ + return pCallbackSynchronizationData->state == CallbackState::finished; + } + ); + } + } + + //----------------------------------------------------------------------------- + template + static auto enqueue( + QueueHipRt & queue, + TTask const & task) + -> void + { + +#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking) + { + // thread-safe callee incrementing + std::lock_guard guard(queue.m_mutex); + queue.m_callees += 1; + } +#endif + auto pCallbackSynchronizationData = std::make_shared(); + // test example: https://github.com/ROCm-Developer-Tools/HIP/blob/roc-1.9.x/tests/src/runtimeApi/stream/hipStreamAddCallback.cpp + HIP_ASSERT(hipStreamAddCallback( + queue.m_HipQueue, + hipRtCallback, + pCallbackSynchronizationData.get(), + 0u)); + + // We start a new std::thread which stores the task to be executed. + // This circumvents the limitation that it is not possible to call HIP methods within the HIP callback thread. + // The HIP thread signals the std::thread when it is ready to execute the task. + // The HIP thread is waiting for the std::thread to signal that it is finished executing the task + // before it executes the next task in the queue (HIP stream). + std::thread t( + [pCallbackSynchronizationData, + task +#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking) + ,&queue // requires queue's destructor to wait for all tasks +#endif + ](){ + +#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking) + // thread-safe task execution and callee decrementing + std::lock_guard guard(queue.m_mutex); +#endif + + // If the callback has not yet been called, we wait for it. 
+ { + std::unique_lock lock(pCallbackSynchronizationData->m_mutex); + if(pCallbackSynchronizationData->state != CallbackState::notified) + { + pCallbackSynchronizationData->m_event.wait( + lock, + [pCallbackSynchronizationData](){ + return pCallbackSynchronizationData->state == CallbackState::notified; + } + ); + } + + task(); + + // Notify the waiting HIP thread. + pCallbackSynchronizationData->state = CallbackState::finished; + } + pCallbackSynchronizationData->m_event.notify_one(); +#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking) + queue.m_callees -= 1; +#endif + } + ); + if(isBlocking) + t.join(); // => waiting for task completion + else + t.detach(); // => do not wait for task completion + } +}; +//############################################################################# +//! The HIP RT non-blocking queue test trait specialization. +struct Empty +{ + //----------------------------------------------------------------------------- + template + static auto empty( + QueueHipRt const & queue) + -> bool + { + +#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking) + return (queue.m_callees==0); +#else + + // Query is allowed even for queues on non current device. + hipError_t ret = hipSuccess; + HIP_ASSERT_IGNORE( + ret = hipStreamQuery( + queue.m_HipQueue), + hipErrorNotReady); + return (ret == hipSuccess); +#endif + } +}; + +template +auto currentThreadWaitFor(QueueHipRt const & queue) -> void +{ +#if WORKAROUND // NOTE: workaround for unwanted nonblocking hip streams for HCC (NVCC streams are blocking) + while(queue.m_callees>0) { + std::this_thread::sleep_for(std::chrono::milliseconds(10u)); + } +#else + // Sync is allowed even for queues on non current device. 
+ HIP_ASSERT( hipStreamSynchronize( + queue.m_HipQueue)); +#endif +} + + + + +// --- Tests + +#define TEMPLATE_LIST_TEST_CASE(TestName) \ +template static void TestName (std::atomic &check); \ +static int TestName##Runner () { \ + std::atomic check{0}; \ + TestName< QueueHipRt >(check); \ + fprintf(stderr, "After " #TestName " < QueueHipRt > errors=%d\n", check.load()); \ + TestName< QueueHipRt >(check); \ + fprintf(stderr, "After " #TestName " < QueueHipRt > errors=%d\n", check.load()); \ + return check.load(); \ +} \ +template static void TestName (std::atomic &check) + +// add 1 if a check fails +#define CHECK(result) do{int arg=(!(result)); fprintf(stderr, "Checking " #result " %d\n", arg); check.fetch_add(arg);}while(false) + +//----------------------------------------------------------------------------- +TEMPLATE_LIST_TEST_CASE( queueIsInitiallyEmpty ) +{ + TestType queue{0}; + CHECK(Empty::empty(queue)); +} + +//----------------------------------------------------------------------------- +TEMPLATE_LIST_TEST_CASE( queueCallbackIsWorking ) +{ + std::promise promise; + auto task = [&](){ promise.set_value(true); }; + TestType queue{0}; + Enqueue enqueue; + enqueue.enqueue( + queue, + task + ); + + CHECK(promise.get_future().get()); +} + +//----------------------------------------------------------------------------- +TEMPLATE_LIST_TEST_CASE( queueWaitShouldWork ) +{ + bool CallbackFinished = false; + auto task = + [&CallbackFinished]() noexcept + { + std::this_thread::sleep_for(std::chrono::milliseconds(100u)); + CallbackFinished = true; + }; + TestType queue{0}; + Enqueue enqueue; + enqueue.enqueue( + queue, + task + ); + + currentThreadWaitFor(queue); + CHECK(CallbackFinished); +} + +//----------------------------------------------------------------------------- +TEMPLATE_LIST_TEST_CASE( queueShouldNotBeEmptyWhenLastTaskIsStillExecutingAndIsEmptyAfterProcessingFinished ) +{ + bool CallbackFinished = false; + TestType queue{0}; + auto task = [&queue, 
&CallbackFinished, &check]() noexcept + { + CHECK(!Empty::empty(queue)); + std::this_thread::sleep_for(std::chrono::milliseconds(100u)); + CallbackFinished = true; + }; + Enqueue enqueue; + enqueue.enqueue( + queue, + task + ); + // A non-blocking queue will always stay empty because the task has been executed immediately. + if(!TestType::isBlocking) + { + currentThreadWaitFor(queue); + } + + CHECK(Empty::empty(queue)); + CHECK(CallbackFinished); +} + +//----------------------------------------------------------------------------- +TEMPLATE_LIST_TEST_CASE( queueShouldNotExecuteTasksInParallel ) +{ + std::atomic taskIsExecuting(false); + std::promise firstTaskFinished; + std::future firstTaskFinishedFuture = firstTaskFinished.get_future(); + std::promise secondTaskFinished; + std::future secondTaskFinishedFuture = secondTaskFinished.get_future(); + + TestType queue{0}; + + std::thread thread1( + [&queue, &taskIsExecuting, &firstTaskFinished, &check]() + { + auto task1 = [&taskIsExecuting, &firstTaskFinished, &check]() noexcept + { + CHECK(!taskIsExecuting.exchange(true)); + std::this_thread::sleep_for(std::chrono::milliseconds(100u)); + CHECK(taskIsExecuting.exchange(false)); + firstTaskFinished.set_value(); + }; + Enqueue enqueue; + enqueue.enqueue( + queue, + task1 + ); + }); + + std::thread thread2( + [&queue, &taskIsExecuting, &secondTaskFinished, &check]() + { + auto task2 = [&taskIsExecuting, &secondTaskFinished, &check]() noexcept + { + CHECK(!taskIsExecuting.exchange(true)); + std::this_thread::sleep_for(std::chrono::milliseconds(100u)); + CHECK(taskIsExecuting.exchange(false)); + secondTaskFinished.set_value(); + }; + + Enqueue enqueue; + enqueue.enqueue( + queue, + task2 + ); + }); + + // Both tasks have to be enqueued + thread1.join(); + thread2.join(); + + currentThreadWaitFor(queue); + + firstTaskFinishedFuture.get(); + secondTaskFinishedFuture.get(); +} + +#define TESTER(name) do { \ + int result = name (); \ + fprintf(stderr, #name " %s\n", 
result?"Errors":"No Errors"); \ + if (result) { failed(#name " failed\n"); } \ +} while (false) + +int main() +{ + TESTER(queueIsInitiallyEmptyRunner); + TESTER(queueCallbackIsWorkingRunner); + TESTER(queueWaitShouldWorkRunner); + TESTER(queueShouldNotBeEmptyWhenLastTaskIsStillExecutingAndIsEmptyAfterProcessingFinishedRunner); + TESTER(queueShouldNotExecuteTasksInParallelRunner); + passed(); +}