From 2e0fd95d74f6968157ced82d55f8a5bec0b2156e Mon Sep 17 00:00:00 2001 From: German Andryeyev Date: Thu, 23 Feb 2023 11:27:47 -0500 Subject: [PATCH] SWDEV-353281 - Add support for MemPool in graphs Implement hipDeviceGetGraphMemAttribute, hipDeviceSetGraphMemAttribute and hipDeviceGraphMemTrim Change-Id: I4f8fc1250ce1e8b7636d43d59ba7343158e45088 [ROCm/clr commit: 314bdba632dc8e5f3b020c062b74af1c96a89612] --- projects/clr/hipamd/src/hip_device.cpp | 24 ++++++++++--- projects/clr/hipamd/src/hip_graph.cpp | 35 ++++++++++++------- .../clr/hipamd/src/hip_graph_internal.hpp | 16 +++------ projects/clr/hipamd/src/hip_internal.hpp | 10 ++++-- projects/clr/hipamd/src/hip_mempool_impl.hpp | 6 ++-- 5 files changed, 58 insertions(+), 33 deletions(-) diff --git a/projects/clr/hipamd/src/hip_device.cpp b/projects/clr/hipamd/src/hip_device.cpp index 2b83616d77..092652d098 100644 --- a/projects/clr/hipamd/src/hip_device.cpp +++ b/projects/clr/hipamd/src/hip_device.cpp @@ -44,7 +44,7 @@ hip::Stream* Device::GetNullStream() { if (null_stream_ == nullptr) { null_stream_ = new Stream(this, Stream::Priority::Normal, 0, true); } - + if (null_stream_ == nullptr) { return nullptr; } @@ -60,6 +60,18 @@ bool Device::Create() { if (default_mem_pool_ == nullptr) { return false; } + + // Create graph memory pool + graph_mem_pool_ = new MemoryPool(this); + if (graph_mem_pool_ == nullptr) { + return false; + } + + uint64_t max_size = std::numeric_limits::max(); + // Use maximum value to hold memory, because current implementation doesn't support VM + // Note: the call for the threshold is always successful + auto error = graph_mem_pool_->SetAttribute(hipMemPoolAttrReleaseThreshold, &max_size); + // Current is default pool after device creation current_mem_pool_ = default_mem_pool_; return true; @@ -85,7 +97,7 @@ void Device::RemoveMemoryPool(MemoryPool* pool) { bool Device::FreeMemory(amd::Memory* memory, Stream* stream) { amd::ScopedLock lock(lock_); // Search for memory in the entire list of pools - for (auto& it : mem_pools_) { + for (auto it : mem_pools_) { if (it->FreeMemory(memory, stream)) { return true; } @@ -97,7 +109,7 @@ bool Device::FreeMemory(amd::Memory* memory, Stream* stream) { void Device::ReleaseFreedMemory(Stream* stream) { amd::ScopedLock lock(lock_); // Search for memory in the entire list of pools - for (auto& it : mem_pools_) { + for (auto it : mem_pools_) { it->ReleaseFreedMemory(stream); } } @@ -106,7 +118,7 @@ void Device::ReleaseFreedMemory(Stream* stream) { void Device::RemoveStreamFromPools(Stream* stream) { amd::ScopedLock lock(lock_); // Update all pools with the destroyed stream - for (auto& it : mem_pools_) { + for (auto it : mem_pools_) { it->RemoveStream(stream); } } @@ -135,6 +147,10 @@ Device::~Device() { default_mem_pool_->release(); } + if (graph_mem_pool_ != nullptr) { + graph_mem_pool_->release(); + } + if (null_stream_!= nullptr) { delete null_stream_; } diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index 642a4bd76a..f3a93253cf 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -2216,54 +2216,63 @@ hipError_t hipDeviceGetGraphMemAttribute(int device, hipGraphMemAttributeType at if ((static_cast(device) >= g_devices.size()) || device < 0 || value == nullptr) { HIP_RETURN(hipErrorInvalidDevice); } - // later use this to access memory pool - auto* deviceHandle = g_devices[device]->devices()[0]; + hipError_t result = hipErrorInvalidValue; switch (attr) { case hipGraphMemAttrUsedMemCurrent: - *reinterpret_cast(value) = 0; + result = g_devices[device]->GetGraphMemoryPool()->GetAttribute( + hipMemPoolAttrUsedMemCurrent, value); break; case hipGraphMemAttrUsedMemHigh: - *reinterpret_cast(value) = 0; + result = g_devices[device]->GetGraphMemoryPool()->GetAttribute( + hipMemPoolAttrUsedMemHigh, value); break; case hipGraphMemAttrReservedMemCurrent: - *reinterpret_cast(value) = 0; + result = g_devices[device]->GetGraphMemoryPool()->GetAttribute( + hipMemPoolAttrReservedMemCurrent, value); break; case hipGraphMemAttrReservedMemHigh: - *reinterpret_cast(value) = 0; + result = g_devices[device]->GetGraphMemoryPool()->GetAttribute( + hipMemPoolAttrReservedMemHigh, value); break; default: - return HIP_RETURN(hipErrorInvalidValue); + break; } - return HIP_RETURN(hipSuccess); + return HIP_RETURN(result); } +// ================================================================================================ hipError_t hipDeviceSetGraphMemAttribute(int device, hipGraphMemAttributeType attr, void* value) { HIP_INIT_API(hipDeviceSetGraphMemAttribute, device, attr, value); if ((static_cast(device) >= g_devices.size()) || device < 0 || value == nullptr) { HIP_RETURN(hipErrorInvalidDevice); } - // later use this to access memory pool - auto* deviceHandle = g_devices[device]->devices()[0]; + hipError_t result = hipErrorInvalidValue; switch (attr) { case hipGraphMemAttrUsedMemHigh: + result = g_devices[device]->GetGraphMemoryPool()->SetAttribute( + hipMemPoolAttrUsedMemHigh, value); break; case hipGraphMemAttrReservedMemHigh: + result = g_devices[device]->GetGraphMemoryPool()->SetAttribute( + hipMemPoolAttrReservedMemHigh, value); break; default: - return HIP_RETURN(hipErrorInvalidValue); + break; } - return HIP_RETURN(hipSuccess); + return HIP_RETURN(result); } +// ================================================================================================ hipError_t hipDeviceGraphMemTrim(int device) { HIP_INIT_API(hipDeviceGraphMemTrim, device); if ((static_cast(device) >= g_devices.size()) || device < 0) { HIP_RETURN(hipErrorInvalidDevice); } - // not implemented yet + g_devices[device]->GetGraphMemoryPool()->TrimTo(0); return HIP_RETURN(hipSuccess); } +// ================================================================================================ hipError_t hipUserObjectCreate(hipUserObject_t* object_out, void* ptr, hipHostFn_t destroy, unsigned int initialRefcount, unsigned int flags) { HIP_INIT_API(hipUserObjectCreate, object_out, ptr, destroy, initialRefcount, flags); diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index 4f0b5dd31e..38f72581c0 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -405,17 +405,9 @@ struct ihipGraph { , device_(device) { amd::ScopedLock lock(graphSetLock_); graphSet_.insert(this); - if (original == nullptr) { - // Create memory pool, associated with the graph - mem_pool_ = new hip::MemoryPool(device); - uint64_t max_size = std::numeric_limits::max(); - // Note: the call for the threshold is always successful - auto error = mem_pool_->SetAttribute(hipMemPoolAttrReleaseThreshold, &max_size); - } else { - mem_pool_ = original->mem_pool_; - mem_pool_->retain(); - } - }; + mem_pool_ = device->GetGraphMemoryPool(); + mem_pool_->retain(); + } ~ihipGraph() { for (auto node : vertices_) { @@ -430,7 +422,7 @@ struct ihipGraph { mem_pool_->release(); } - }; + } void AddManualNodeDuringCapture(hipGraphNode* node) { capturedNodes_.insert(node); } diff --git a/projects/clr/hipamd/src/hip_internal.hpp b/projects/clr/hipamd/src/hip_internal.hpp index 84782cb6ef..0ebb4b5816 100644 --- a/projects/clr/hipamd/src/hip_internal.hpp +++ b/projects/clr/hipamd/src/hip_internal.hpp @@ -401,8 +401,9 @@ namespace hip { bool isActive_; - MemoryPool* default_mem_pool_; + MemoryPool* default_mem_pool_; //!< Default memory pool for this device MemoryPool* current_mem_pool_; + MemoryPool* graph_mem_pool_; //!< Memory pool, associated with graphs for this device std::set mem_pools_; @@ -412,7 +413,8 @@ namespace hip { flags_(hipDeviceScheduleSpin), isActive_(false), default_mem_pool_(nullptr), - current_mem_pool_(nullptr) + current_mem_pool_(nullptr), + graph_mem_pool_(nullptr) { assert(ctx != nullptr); } ~Device(); @@ -470,6 +472,9 @@ namespace hip { /// Get the default memory pool on the device MemoryPool* GetDefaultMemoryPool() const { return default_mem_pool_; } + /// Get the graph memory pool on the device + MemoryPool* GetGraphMemoryPool() const { return graph_mem_pool_; } + /// Add memory pool to the device void AddMemoryPool(MemoryPool* pool); @@ -484,6 +489,7 @@ namespace hip { /// Removes a destroyed stream from the safe list of memory pools void RemoveStreamFromPools(Stream* stream); + }; /// Thread Local Storage Variables Aggregator Class diff --git a/projects/clr/hipamd/src/hip_mempool_impl.hpp b/projects/clr/hipamd/src/hip_mempool_impl.hpp index 9d176b1710..5e18cb3599 100644 --- a/projects/clr/hipamd/src/hip_mempool_impl.hpp +++ b/projects/clr/hipamd/src/hip_mempool_impl.hpp @@ -213,17 +213,19 @@ public: /// Set memory pool access by different devices void GetAccess(hip::Device* device, hipMemAccessFlags* flags); + /// Frees all busy memory + void FreeAllMemory(hip::Stream* stream = nullptr); + /// Accessors for the pool state bool EventDependencies() const { return (state_.event_dependencies_) ? true : false; } bool Opportunistic() const { return (state_.opportunistic_) ? true : false; } bool InternalDependencies() const { return (state_.internal_dependencies_) ? true : false; } - void FreeAllMemory(hip::Stream* stream = nullptr); + private: MemoryPool() = delete; MemoryPool(const MemoryPool&) = delete; MemoryPool& operator=(const MemoryPool&) = delete; - Heap busy_heap_; //!< Heap of busy allocations Heap free_heap_; //!< Heap of freed allocations struct {