From 6c174f79fab0848bac19859a2e7904177e672cd9 Mon Sep 17 00:00:00 2001 From: sdashmiz Date: Wed, 21 Jun 2023 14:43:57 -0400 Subject: [PATCH] SWDEV-405485 - move the param validation Signed-off-by: sdashmiz Change-Id: Ic3a27c47a88954da866a91494bcfb8721f33ad2b [ROCm/clr commit: 2cdaf7e80f93910ccd31fea58b95609a26cd531e] --- projects/clr/hipamd/src/hip_event.cpp | 1 + projects/clr/hipamd/src/hip_stream.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/projects/clr/hipamd/src/hip_event.cpp b/projects/clr/hipamd/src/hip_event.cpp index c0923adf0d..50e5c091b8 100644 --- a/projects/clr/hipamd/src/hip_event.cpp +++ b/projects/clr/hipamd/src/hip_event.cpp @@ -389,6 +389,7 @@ hipError_t hipEventRecord_common(hipEvent_t event, hipStream_t stream) { if (event == nullptr) { return hipErrorInvalidHandle; } + getStreamPerThread(stream); if (!hip::isValid(stream)) { return hipErrorContextIsDestroyed; } diff --git a/projects/clr/hipamd/src/hip_stream.cpp b/projects/clr/hipamd/src/hip_stream.cpp index c71afbcf89..f3dceeae80 100644 --- a/projects/clr/hipamd/src/hip_stream.cpp +++ b/projects/clr/hipamd/src/hip_stream.cpp @@ -31,7 +31,7 @@ namespace hip { // ================================================================================================ Stream::Stream(hip::Device* dev, Priority p, unsigned int f, bool null_stream, const std::vector& cuMask, hipStreamCaptureStatus captureStatus) - : amd::HostQueue(*dev->asContext(), *dev->devices()[0], 0, amd::CommandQueue::RealTimeDisabled, + : amd::HostQueue(*dev->asContext(), *dev->devices()[0], 0, amd::CommandQueue::RealTimeDisabled, convertToQueuePriority(p), cuMask), lock_("Stream Callback lock"), device_(dev), @@ -518,9 +518,6 @@ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsig if (event == nullptr) { return hipErrorInvalidHandle; } - if (stream == nullptr) { - return hipErrorInvalidValue; - } if (!hip::isValid(stream)) { return hipErrorContextIsDestroyed; } @@ -529,6 +526,9 @@ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsig hip::Stream* eventStream = reinterpret_cast(e->GetCaptureStream()); if (eventStream != nullptr && eventStream->IsEventCaptured(event) == true) { + if (waitStream == nullptr) { + return hipErrorInvalidHandle; + } if (!waitStream->IsOriginStream()) { waitStream->SetCaptureGraph((eventStream)->GetCaptureGraph()); waitStream->SetCaptureId((eventStream)->GetCaptureID());