From 4bb028a49daafdcc3066b785bbe5f3caefed5967 Mon Sep 17 00:00:00 2001 From: German Andryeyev Date: Mon, 4 Mar 2024 17:09:52 -0500 Subject: [PATCH] SWDEV-311271 - Add dependency tracking for streams Mempool has capability to track dependency between streams for faster memory reuse. Enable that capability. Change-Id: I28266a7e38d0fc4c5d027b9542d3719653840821 [ROCm/clr commit: 17d0c166d2e04420c98d763e04b5789c285d134f] --- projects/clr/hipamd/src/hip_device.cpp | 9 +++++++ projects/clr/hipamd/src/hip_internal.hpp | 3 +++ projects/clr/hipamd/src/hip_mempool_impl.hpp | 28 +++++++++++++++++--- projects/clr/hipamd/src/hip_stream.cpp | 12 ++++++--- 4 files changed, 45 insertions(+), 7 deletions(-) diff --git a/projects/clr/hipamd/src/hip_device.cpp b/projects/clr/hipamd/src/hip_device.cpp index cc8cfdacf4..cd2ecb18e8 100644 --- a/projects/clr/hipamd/src/hip_device.cpp +++ b/projects/clr/hipamd/src/hip_device.cpp @@ -127,6 +127,15 @@ void Device::RemoveStreamFromPools(Stream* stream) { } } +// ================================================================================================ +void Device::AddSafeStream(Stream* event_stream, Stream* wait_stream) { + amd::ScopedLock lock(lock_); + // Update all pools with the safe streams + for (auto it : mem_pools_) { + it->AddSafeStream(event_stream, wait_stream); + } +} + // ================================================================================================ void Device::Reset() { { diff --git a/projects/clr/hipamd/src/hip_internal.hpp b/projects/clr/hipamd/src/hip_internal.hpp index ec20e47873..46a03e0b4d 100644 --- a/projects/clr/hipamd/src/hip_internal.hpp +++ b/projects/clr/hipamd/src/hip_internal.hpp @@ -535,6 +535,9 @@ public: /// Removes a destroyed stream from the safe list of memory pools void RemoveStreamFromPools(Stream* stream); + /// Add safe streams into the memppools for reuse + void AddSafeStream(Stream* event_stream, Stream* wait_stream); + /// Returns true if memory pool is valid on this device bool IsMemoryPoolValid(MemoryPool* pool); }; diff --git a/projects/clr/hipamd/src/hip_mempool_impl.hpp b/projects/clr/hipamd/src/hip_mempool_impl.hpp index fd33f00bf6..c4dd51f8ca 100644 --- a/projects/clr/hipamd/src/hip_mempool_impl.hpp +++ b/projects/clr/hipamd/src/hip_mempool_impl.hpp @@ -46,9 +46,15 @@ struct MemoryTimestamp { MemoryTimestamp(): event_(nullptr) {} /// Adds a safe stream to the list of stream for possible reuse - void AddSafeStream(hip::Stream* stream) { - if (safe_streams_.find(stream) != safe_streams_.end()) { - safe_streams_.insert(stream); + void AddSafeStream(Stream* event_stream, Stream* wait_stream = nullptr) { + if (wait_stream == nullptr) { + if (safe_streams_.find(event_stream) == safe_streams_.end()) { + safe_streams_.insert(event_stream); + } + } else { + if (safe_streams_.find(event_stream) != safe_streams_.end()) { + safe_streams_.insert(wait_stream); + } } } /// Changes last known valid event asociated with memory @@ -144,6 +150,14 @@ public: /// Erases single allocation form the heap's map SortedMap::iterator EraseAllocaton(SortedMap::iterator& it); + /// Add a safe stream for quick looks-ups in all allocations + void AddSafeStream(Stream* event_stream, Stream* wait_stream) { + for (auto& it : allocations_) { + it.second.AddSafeStream(event_stream, wait_stream); + } + } + + /// Checks if memory belongs to this heap bool IsActiveMemory(amd::Memory* memory) const { return (allocations_.find({memory->getSize(), memory}) != allocations_.end()); @@ -233,6 +247,14 @@ class MemoryPool : public amd::ReferenceCountedObject { void AddBusyMemory(amd::Memory* memory) { busy_heap_.AddMemory(memory, nullptr); } + + /// Add a safe stream for quick looks-ups if event dependencies option is enabled + void AddSafeStream(Stream* event_stream, Stream* wait_stream) { + if (EventDependencies()) { + free_heap_.AddSafeStream(event_stream, wait_stream); + } + } + /// Trims the pool until it has only min_bytes_to_hold void TrimTo(size_t min_bytes_to_hold); diff --git a/projects/clr/hipamd/src/hip_stream.cpp b/projects/clr/hipamd/src/hip_stream.cpp index e0395ba9e7..9315082883 100644 --- a/projects/clr/hipamd/src/hip_stream.cpp +++ b/projects/clr/hipamd/src/hip_stream.cpp @@ -553,10 +553,14 @@ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsig if (flags != 0) { return hipErrorInvalidValue; } - if ((eventStream != nullptr) && - (eventStream->GetCaptureStatus() == hipStreamCaptureStatusActive)) { - // If stream is capturing but event is not recorded on event's stream. - return hipErrorStreamCaptureIsolation; + if (eventStream != nullptr) { + if (eventStream->GetCaptureStatus() == hipStreamCaptureStatusActive) { + // If stream is capturing but event is not recorded on event's stream. + return hipErrorStreamCaptureIsolation; + } + if (eventStream->DeviceId() == waitStream->DeviceId()) { + eventStream->GetDevice()->AddSafeStream(eventStream, waitStream); + } } status = e->streamWait(stream, flags); }