diff --git a/src/hip_hcc.cpp b/src/hip_hcc.cpp index f15a0eb1d8..0ca170152b 100644 --- a/src/hip_hcc.cpp +++ b/src/hip_hcc.cpp @@ -37,6 +37,7 @@ THE SOFTWARE. #include #include #include +#include #include #include @@ -1409,9 +1410,38 @@ void ihipInit() tprintf(DB_SYNC, "pid=%u %-30s g_numLogicalThreads=%u\n", getpid(), "", g_numLogicalThreads); } +hipError_t ihipStreamSynchronize(hipStream_t stream) +{ + hipError_t e = hipSuccess; + if (stream == hipStreamNull) { + ihipCtx_t *ctx = ihipGetTlsDefaultCtx(); + ctx->locked_syncDefaultStream(true/*waitOnSelf*/, true/*syncToHost*/); + } else { + // note this does not synchornize with the NULL stream: + stream->locked_wait(); + e = hipSuccess; + } + return e; +} +void ihipStreamCallbackHandler(ihipStreamCallback_t *cb) +{ + hipError_t e = hipSuccess; + + // Notify hipStreamAddCallback that callback handler thread is active + std::lock_guard guard(cb->_mtx); + cb->_ready = true; + + // Synchronize stream + tprintf(DB_SYNC, "ihipStreamCallbackHandler wait on stream %s\n", ToString(cb->_stream).c_str()); + e = ihipStreamSynchronize(cb->_stream); + + // Call registered callback function + cb->_callback(cb->_stream, e, cb->_userData); + delete cb; +} //--- // Get the stream to use for a command submission. diff --git a/src/hip_hcc_internal.h b/src/hip_hcc_internal.h index 4891f54fee..601b66f343 100644 --- a/src/hip_hcc_internal.h +++ b/src/hip_hcc_internal.h @@ -622,6 +622,24 @@ private: // Data }; +//---- +// Internal structure for stream callback handler +class ihipStreamCallback_t { +public: + ihipStreamCallback_t(hipStream_t stream, hipStreamCallback_t callback, void *userData) : + _stream(stream), + _callback(callback), + _userData(userData) + { + _ready = false; + }; + hipStream_t _stream; + hipStreamCallback_t _callback; + void* _userData; + std::mutex _mtx; + bool _ready; +}; + //---- // Internal event structure: @@ -931,6 +949,8 @@ ihipCtx_t * ihipGetPrimaryCtx(unsigned deviceIndex); hipStream_t ihipSyncAndResolveStream(hipStream_t); +hipError_t ihipStreamSynchronize(hipStream_t stream); +void ihipStreamCallbackHandler(ihipStreamCallback_t *cb); // Stream printf functions: inline std::ostream& operator<<(std::ostream& os, const ihipStream_t& s) diff --git a/src/hip_stream.cpp b/src/hip_stream.cpp index dab31dad62..94fc436b75 100644 --- a/src/hip_stream.cpp +++ b/src/hip_stream.cpp @@ -20,6 +20,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +#include +#include #include "hip/hip_runtime.h" #include "hip_hcc_internal.h" #include "trace_helper.h" @@ -147,20 +149,8 @@ hipError_t hipStreamSynchronize(hipStream_t stream) { HIP_INIT_SPECIAL_API(TRACE_SYNC, stream); - hipError_t e = hipSuccess; - - if (stream == hipStreamNull) { - ihipCtx_t *ctx = ihipGetTlsDefaultCtx(); - ctx->locked_syncDefaultStream(true/*waitOnSelf*/, true/*syncToHost*/); - } else { - // note this does not synchornize with the NULL stream: - stream->locked_wait(); - e = hipSuccess; - } - - - return ihipLogStatus(e); -}; + return ihipLogStatus(ihipStreamSynchronize(stream)); +} //--- @@ -216,8 +206,20 @@ hipError_t hipStreamAddCallback(hipStream_t stream, hipStreamCallback_t callback { HIP_INIT_API(stream, callback, userData, flags); hipError_t e = hipSuccess; - //--- explicitly synchronize stream to add callback routines - hipStreamSynchronize(stream); - callback(stream, e, userData); + + // Create a thread in detached mode to handle callback + ihipStreamCallback_t *cb = new ihipStreamCallback_t(stream, callback, userData); + std::thread (ihipStreamCallbackHandler, cb).detach(); + + // Wait for thread to be ready + cb->_mtx.lock(); + while(cb->_ready != true) + { + cb->_mtx.unlock(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + cb->_mtx.lock(); + } + cb->_mtx.unlock(); + return ihipLogStatus(e); } diff --git a/tests/src/runtimeApi/stream/hipStreamAddCallback.cpp b/tests/src/runtimeApi/stream/hipStreamAddCallback.cpp index 32a2793479..692d090509 100644 --- a/tests/src/runtimeApi/stream/hipStreamAddCallback.cpp +++ b/tests/src/runtimeApi/stream/hipStreamAddCallback.cpp @@ -23,8 +23,7 @@ THE SOFTWARE. * HIT_END */ -// Test under-development. Call hipStreamAddCallback function and see if it works as expected. - +#include #include "hip/hip_runtime.h" #include "test_common.h" @@ -32,32 +31,57 @@ THE SOFTWARE. #define HIPRT_CB #endif -class CallbackClass +__global__ void vector_square(float *C_d, float *A_d, size_t N) { -public: - static void HIPRT_CB Callback(hipStream_t stream, hipError_t status, void *userData); + size_t offset = (blockIdx.x * blockDim.x + threadIdx.x); + size_t stride = blockDim.x * gridDim.x ; -private: - void callbackFunc(hipError_t status); -}; - -void HIPRT_CB CallbackClass::Callback(hipStream_t stream, hipError_t status, void *userData) -{ - CallbackClass* obj = (CallbackClass*) userData; - obj->callbackFunc(status); + for (size_t i=offset; i