From 9a3d3aef8c2a5fa855333fb77235bd82c97a75bc Mon Sep 17 00:00:00 2001 From: sdashmiz Date: Thu, 4 May 2023 10:25:50 -0400 Subject: [PATCH] SWDEV-360031 - Event record clean up Signed-off-by: sdashmiz Change-Id: Ic7b5edd501c5f61b2bce87252ac040cbc4083006 [ROCm/clr commit: 1e9dfdfe99945d6bacac5bff785f6c301b1765be] --- projects/clr/hipamd/src/hip_event.cpp | 26 ++++++++--- projects/clr/hipamd/src/hip_event.hpp | 16 ------- projects/clr/hipamd/src/hip_graph.cpp | 44 ------------------- projects/clr/hipamd/src/hip_graph_capture.hpp | 4 -- projects/clr/hipamd/src/hip_internal.hpp | 19 +++++--- projects/clr/hipamd/src/hip_stream.cpp | 42 +++++++++++++----- 6 files changed, 62 insertions(+), 89 deletions(-) diff --git a/projects/clr/hipamd/src/hip_event.cpp b/projects/clr/hipamd/src/hip_event.cpp index 601a1f5bca..c0923adf0d 100644 --- a/projects/clr/hipamd/src/hip_event.cpp +++ b/projects/clr/hipamd/src/hip_event.cpp @@ -383,18 +383,32 @@ hipError_t hipEventElapsedTime(float* ms, hipEvent_t start, hipEvent_t stop) { } hipError_t hipEventRecord_common(hipEvent_t event, hipStream_t stream) { - STREAM_CAPTURE(hipEventRecord, stream, event); - + ClPrint(amd::LOG_INFO, amd::LOG_API, + "[hipGraph] current capture node EventRecord on stream : %p, Event %p", stream, event); + hipError_t status = hipSuccess; if (event == nullptr) { return hipErrorInvalidHandle; } + if (!hip::isValid(stream)) { + return hipErrorContextIsDestroyed; + } hip::Event* e = reinterpret_cast(event); + hip::Stream* s = reinterpret_cast(stream); hip::Stream* hip_stream = hip::getStream(stream); e->SetCaptureStream(stream); - if (g_devices[e->deviceId()]->devices()[0] != &hip_stream->device()) { - return hipErrorInvalidHandle; + if ((s != nullptr) && (s->GetCaptureStatus() == hipStreamCaptureStatusActive)) { + s->SetCaptureEvent(event); + std::vector lastCapturedNodes = s->GetLastCapturedNodes(); + if (!lastCapturedNodes.empty()) { + e->SetNodesPrevToRecorded(lastCapturedNodes); + } + } else { + if (g_devices[e->deviceId()]->devices()[0] != &hip_stream->device()) { + return hipErrorInvalidHandle; + } + status = e->addMarker(stream, nullptr, true); } - return e->addMarker(stream, nullptr, true); + return status; } hipError_t hipEventRecord(hipEvent_t event, hipStream_t stream) { @@ -417,7 +431,7 @@ hipError_t hipEventSynchronize(hipEvent_t event) { hip::Event* e = reinterpret_cast(event); hip::Stream* s = reinterpret_cast(e->GetCaptureStream()); if ((s != nullptr) && (s->GetCaptureStatus() == hipStreamCaptureStatusActive)) { - if (e->GetCaptureStatus() == false) { + if (s->IsEventCaptured(event) == false) { return HIP_RETURN(hipErrorStreamCaptureUnsupported); } } diff --git a/projects/clr/hipamd/src/hip_event.hpp b/projects/clr/hipamd/src/hip_event.hpp index 74f57ca163..bb6f6a239c 100644 --- a/projects/clr/hipamd/src/hip_event.hpp +++ b/projects/clr/hipamd/src/hip_event.hpp @@ -91,8 +91,6 @@ class EventMarker : public amd::Marker { enum eventType { Query, StreamWait, ElapsedTime }; class Event { - /// event recorded on stream where capture is active - bool onCapture_; /// capture stream where event is recorded hipStream_t captureStream_ = nullptr; /// Previous captured nodes before event record @@ -112,7 +110,6 @@ class Event { Event(unsigned int flags) : flags(flags), lock_("hipEvent_t", true), event_(nullptr), unrecorded_(false), stream_(nullptr) { // No need to init event_ here as addMarker does that - onCapture_ = false; device_id_ = hip::getCurrentDevice()->deviceId(); // Created in current device ctx } @@ -151,19 +148,6 @@ class Event { const int deviceId() const { return device_id_; } void setDeviceId(int id) { device_id_ = id; } amd::Event* event() { return event_; } - - /// End capture on this event - void EndCapture() { - onCapture_ = false; - captureStream_ = nullptr; - } - /// Start capture when waited on this event - void StartCapture(hipStream_t stream) { - onCapture_ = true; - captureStream_ = stream; - } - /// Get capture status of the graph - bool GetCaptureStatus() const { return onCapture_; } /// Get capture stream where event is recorded hipStream_t GetCaptureStream() const { return captureStream_; } /// Set capture stream where event is recorded diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index 9e695287a9..dd7a4982f0 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -786,50 +786,6 @@ hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDe return hipSuccess; } -hipError_t capturehipEventRecord(hipStream_t& stream, hipEvent_t& event) { - ClPrint(amd::LOG_INFO, amd::LOG_API, - "[hipGraph] current capture node EventRecord on stream : %p, Event %p", stream, event); - if (event == nullptr) { - return hipErrorInvalidHandle; - } - if (!hip::isValid(stream)) { - return hipErrorContextIsDestroyed; - } - hip::Event* e = reinterpret_cast(event); - e->StartCapture(stream); - hip::Stream* s = reinterpret_cast(stream); - s->SetCaptureEvent(event); - std::vector lastCapturedNodes = s->GetLastCapturedNodes(); - if (!lastCapturedNodes.empty()) { - e->SetNodesPrevToRecorded(lastCapturedNodes); - } - return hipSuccess; -} - -hipError_t capturehipStreamWaitEvent(hipEvent_t& event, hipStream_t& stream, unsigned int& flags) { - ClPrint(amd::LOG_INFO, amd::LOG_API, - "[hipGraph] current capture node StreamWaitEvent on stream : %p, Event %p", stream, - event); - if (!hip::isValid(stream)) { - return hipErrorContextIsDestroyed; - } - hip::Stream* s = reinterpret_cast(stream); - hip::Event* e = reinterpret_cast(event); - - if (event == nullptr || stream == nullptr) { - return hipErrorInvalidValue; - } - if (!s->IsOriginStream()) { - s->SetCaptureGraph(reinterpret_cast(e->GetCaptureStream())->GetCaptureGraph()); - s->SetCaptureId(reinterpret_cast(e->GetCaptureStream())->GetCaptureID()); - s->SetCaptureMode(reinterpret_cast(e->GetCaptureStream())->GetCaptureMode()); - s->SetParentStream(e->GetCaptureStream()); - reinterpret_cast(s->GetParentStream())->SetParallelCaptureStream(stream); - } - s->AddCrossCapturedNode(e->GetNodesPrevToRecorded()); - return hipSuccess; -} - hipError_t capturehipLaunchHostFunc(hipStream_t& stream, hipHostFn_t& fn, void*& userData) { ClPrint(amd::LOG_INFO, amd::LOG_API, "[hipGraph] current capture node host on stream : %p", stream); diff --git a/projects/clr/hipamd/src/hip_graph_capture.hpp b/projects/clr/hipamd/src/hip_graph_capture.hpp index 51226ac13e..443b19147d 100644 --- a/projects/clr/hipamd/src/hip_graph_capture.hpp +++ b/projects/clr/hipamd/src/hip_graph_capture.hpp @@ -100,10 +100,6 @@ hipError_t capturehipMemset2DAsync(hipStream_t& stream, void*& dst, size_t& pitc hipError_t capturehipMemset3DAsync(hipStream_t& stream, hipPitchedPtr& pitchedDevPtr, int& value, hipExtent& extent); -hipError_t capturehipEventRecord(hipStream_t& stream, hipEvent_t& event); - -hipError_t capturehipStreamWaitEvent(hipEvent_t& event, hipStream_t& stream, unsigned int& flags); - hipError_t capturehipLaunchHostFunc(hipStream_t& stream, hipHostFn_t& fn, void*& userData); hipError_t capturehipMallocAsync(hipStream_t stream, hipMemPool_t mem_pool, size_t size, void** dev_ptr); diff --git a/projects/clr/hipamd/src/hip_internal.hpp b/projects/clr/hipamd/src/hip_internal.hpp index c7d51b95d4..1fe9d1d17e 100644 --- a/projects/clr/hipamd/src/hip_internal.hpp +++ b/projects/clr/hipamd/src/hip_internal.hpp @@ -190,12 +190,6 @@ extern amd::Monitor g_hipInitlock; return status; \ } -#define EVENT_CAPTURE(name, event, ...) \ - if (event != nullptr && reinterpret_cast(event)->GetCaptureStatus() == true) { \ - hipError_t status = capture##name(event, ##__VA_ARGS__); \ - HIP_RETURN(status); \ - } - #define PER_THREAD_DEFAULT_STREAM(stream) \ if (stream == nullptr) { \ stream = getPerThreadDefaultStream(); \ @@ -369,8 +363,19 @@ namespace hip { } /// Get Capture ID unsigned long long GetCaptureID() { return captureID_; } - void SetCaptureEvent(hipEvent_t e) { captureEvents_.emplace(e); } + void SetCaptureEvent(hipEvent_t e) { + amd::ScopedLock lock(lock_); + captureEvents_.emplace(e); } + bool IsEventCaptured(hipEvent_t e) { + amd::ScopedLock lock(lock_); + auto it = captureEvents_.find(e); + if (it != captureEvents_.end()) { + return true; + } + return false; + } void EraseCaptureEvent(hipEvent_t e) { + amd::ScopedLock lock(lock_); auto it = captureEvents_.find(e); if (it != captureEvents_.end()) { captureEvents_.erase(it); diff --git a/projects/clr/hipamd/src/hip_stream.cpp b/projects/clr/hipamd/src/hip_stream.cpp index 7f832af8aa..c71afbcf89 100644 --- a/projects/clr/hipamd/src/hip_stream.cpp +++ b/projects/clr/hipamd/src/hip_stream.cpp @@ -51,7 +51,7 @@ Stream::Stream(hip::Device* dev, Priority p, unsigned int f, bool null_stream, hipError_t Stream::EndCapture() { for (auto event : captureEvents_) { hip::Event* e = reinterpret_cast(event); - e->EndCapture(); + e->SetCaptureStream(nullptr); } for (auto stream : parallelCaptureStreams_) { hip::Stream* s = reinterpret_cast(stream); @@ -511,26 +511,44 @@ void WaitThenDecrementSignal(hipStream_t stream, hipError_t status, void* user_d // ================================================================================================ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsigned int flags) { - EVENT_CAPTURE(hipStreamWaitEvent, event, stream, flags); - + ClPrint(amd::LOG_INFO, amd::LOG_API, + "[hipGraph] current capture node StreamWaitEvent on stream : %p, Event %p", stream, + event); + hipError_t status = hipSuccess; if (event == nullptr) { return hipErrorInvalidHandle; } - - if (flags != 0 || !hip::isValid(stream)) { + if (stream == nullptr) { return hipErrorInvalidValue; } - + if (!hip::isValid(stream)) { + return hipErrorContextIsDestroyed; + } + hip::Stream* waitStream = reinterpret_cast(stream); hip::Event* e = reinterpret_cast(event); - if ((e->GetCaptureStream() != nullptr) && - (reinterpret_cast(e->GetCaptureStream())->GetCaptureStatus() - == hipStreamCaptureStatusActive)) { - // If stream is capturing but event is not recorded on event's stream. - if (e->GetCaptureStatus() == false) { + hip::Stream* eventStream = reinterpret_cast(e->GetCaptureStream()); + + if (eventStream != nullptr && eventStream->IsEventCaptured(event) == true) { + if (!waitStream->IsOriginStream()) { + waitStream->SetCaptureGraph((eventStream)->GetCaptureGraph()); + waitStream->SetCaptureId((eventStream)->GetCaptureID()); + waitStream->SetCaptureMode((eventStream)->GetCaptureMode()); + waitStream->SetParentStream(reinterpret_cast(eventStream)); + eventStream->SetParallelCaptureStream(stream); + } + waitStream->AddCrossCapturedNode(e->GetNodesPrevToRecorded()); + } else { + if (flags != 0) { + return hipErrorInvalidValue; + } + if ((eventStream != nullptr) && + (eventStream->GetCaptureStatus() == hipStreamCaptureStatusActive)) { + // If stream is capturing but event is not recorded on event's stream. return hipErrorStreamCaptureIsolation; } + status = e->streamWait(stream, flags); } - return e->streamWait(stream, flags); + return status; } // ================================================================================================