diff --git a/projects/clr/hipamd/src/hip_device_runtime.cpp b/projects/clr/hipamd/src/hip_device_runtime.cpp index 5b861a4d51..6d18d9f4a7 100644 --- a/projects/clr/hipamd/src/hip_device_runtime.cpp +++ b/projects/clr/hipamd/src/hip_device_runtime.cpp @@ -509,7 +509,7 @@ hipError_t hipDeviceSetLimit ( hipLimit_t limit, size_t value ) { hipError_t hipDeviceSetSharedMemConfig ( hipSharedMemConfig config ) { HIP_INIT_API(hipDeviceSetSharedMemConfig, config); if (config != hipSharedMemBankSizeDefault && - config != hipSharedMemBankSizeFourByte && + config != hipSharedMemBankSizeFourByte && config != hipSharedMemBankSizeEightByte) { HIP_RETURN(hipErrorInvalidValue); } @@ -520,7 +520,8 @@ hipError_t hipDeviceSetSharedMemConfig ( hipSharedMemConfig config ) { hipError_t hipDeviceSynchronize() { HIP_INIT_API(hipDeviceSynchronize); - hip::Stream::SyncAllStreams(hip::getCurrentDevice()->deviceId()); + constexpr bool kDontWaitForCpu = false; + hip::Stream::SyncAllStreams(hip::getCurrentDevice()->deviceId(), kDontWaitForCpu); HIP_RETURN(hipSuccess); } diff --git a/projects/clr/hipamd/src/hip_internal.hpp b/projects/clr/hipamd/src/hip_internal.hpp index 2cef6902ed..08be24d7ae 100644 --- a/projects/clr/hipamd/src/hip_internal.hpp +++ b/projects/clr/hipamd/src/hip_internal.hpp @@ -293,7 +293,7 @@ namespace hip { const std::vector GetCUMask() const { return cuMask_; } /// Sync all streams - static void SyncAllStreams(int deviceId); + static void SyncAllStreams(int deviceId, bool cpu_wait = true); /// Check whether any blocking stream running static bool StreamCaptureBlocking(); diff --git a/projects/clr/hipamd/src/hip_stream.cpp b/projects/clr/hipamd/src/hip_stream.cpp index 138fa018fc..7eab6d8d83 100644 --- a/projects/clr/hipamd/src/hip_stream.cpp +++ b/projects/clr/hipamd/src/hip_stream.cpp @@ -122,7 +122,7 @@ int Stream::DeviceId(const hipStream_t hStream) { } // ================================================================================================ -void Stream::SyncAllStreams(int deviceId) { +void Stream::SyncAllStreams(int deviceId, bool cpu_wait) { // Make a local copy to avoid stalls for GPU finish with multiple threads std::vector streams; streams.reserve(streamSet.size()); @@ -136,7 +136,7 @@ void Stream::SyncAllStreams(int deviceId) { } } for (auto it : streams) { - it->finish(); + it->finish(cpu_wait); it->release(); } } @@ -442,8 +442,9 @@ hipError_t hipStreamSynchronize_common(hipStream_t stream) { } } bool wait = (stream == nullptr) ? true : false; + constexpr bool kDontWaitForCpu = false; // Wait for the current host queue - hip::getStream(stream, wait)->finish(); + hip::getStream(stream, wait)->finish(kDontWaitForCpu); return hipSuccess; } diff --git a/projects/clr/rocclr/platform/commandqueue.cpp b/projects/clr/rocclr/platform/commandqueue.cpp index 96829e4a4c..1dc91ad05f 100644 --- a/projects/clr/rocclr/platform/commandqueue.cpp +++ b/projects/clr/rocclr/platform/commandqueue.cpp @@ -113,7 +113,7 @@ bool HostQueue::terminate() { return true; } -void HostQueue::finish() { +void HostQueue::finish(bool cpu_wait) { Command* command = nullptr; if (IS_HIP) { command = getLastQueuedCommand(true); @@ -121,7 +121,8 @@ void HostQueue::finish() { return; } } - if (nullptr == command || vdev()->isHandlerPending() || vdev()->isFenceDirty()) { + if (nullptr == command || command->type() != CL_COMMAND_MARKER || + vdev()->isHandlerPending() || vdev()->isFenceDirty()) { if (nullptr != command) { command->release(); } @@ -135,7 +136,7 @@ void HostQueue::finish() { } // Check HW status of the ROCcrl event. Note: not all ROCclr modes support HW status static constexpr bool kWaitCompletion = true; - if (!device().IsHwEventReady(command->event(), kWaitCompletion)) { + if (cpu_wait || !device().IsHwEventReady(command->event(), kWaitCompletion)) { ClPrint(LOG_DEBUG, LOG_CMD, "HW Event not ready, awaiting completion instead"); command->awaitCompletion(); } diff --git a/projects/clr/rocclr/platform/commandqueue.hpp b/projects/clr/rocclr/platform/commandqueue.hpp index f98332bf33..3f3d4f5e1b 100644 --- a/projects/clr/rocclr/platform/commandqueue.hpp +++ b/projects/clr/rocclr/platform/commandqueue.hpp @@ -231,7 +231,7 @@ class HostQueue : public CommandQueue { } //! Finish all queued commands - void finish(); + void finish(bool cpu_wait = false); //! Check if hostQueue empty snapshot bool isEmpty();