SWDEV-353281 - Add support for MemPool in graphs

Implement hipDeviceGetGraphMemAttribute, hipDeviceSetGraphMemAttribute
and hipDeviceGraphMemTrim

Change-Id: I4f8fc1250ce1e8b7636d43d59ba7343158e45088


[ROCm/clr commit: 314bdba632]
Esse commit está contido em:
German Andryeyev
2023-02-23 11:27:47 -05:00
commit 2e0fd95d74
5 arquivos alterados com 58 adições e 33 exclusões
+20 -4
Ver Arquivo
@@ -44,7 +44,7 @@ hip::Stream* Device::GetNullStream() {
if (null_stream_ == nullptr) {
null_stream_ = new Stream(this, Stream::Priority::Normal, 0, true);
}
if (null_stream_ == nullptr) {
return nullptr;
}
@@ -60,6 +60,18 @@ bool Device::Create() {
if (default_mem_pool_ == nullptr) {
return false;
}
// Create graph memory pool
graph_mem_pool_ = new MemoryPool(this);
if (graph_mem_pool_ == nullptr) {
return false;
}
uint64_t max_size = std::numeric_limits<uint64_t>::max();
// Use maximum value to hold memory, because current implementation doesn't support VM
// Note: the call for the threshold is always successful
auto error = graph_mem_pool_->SetAttribute(hipMemPoolAttrReleaseThreshold, &max_size);
// Current is default pool after device creation
current_mem_pool_ = default_mem_pool_;
return true;
@@ -85,7 +97,7 @@ void Device::RemoveMemoryPool(MemoryPool* pool) {
bool Device::FreeMemory(amd::Memory* memory, Stream* stream) {
amd::ScopedLock lock(lock_);
// Search for memory in the entire list of pools
for (auto& it : mem_pools_) {
for (auto it : mem_pools_) {
if (it->FreeMemory(memory, stream)) {
return true;
}
@@ -97,7 +109,7 @@ bool Device::FreeMemory(amd::Memory* memory, Stream* stream) {
void Device::ReleaseFreedMemory(Stream* stream) {
amd::ScopedLock lock(lock_);
// Search for memory in the entire list of pools
for (auto& it : mem_pools_) {
for (auto it : mem_pools_) {
it->ReleaseFreedMemory(stream);
}
}
@@ -106,7 +118,7 @@ void Device::ReleaseFreedMemory(Stream* stream) {
void Device::RemoveStreamFromPools(Stream* stream) {
amd::ScopedLock lock(lock_);
// Update all pools with the destroyed stream
for (auto& it : mem_pools_) {
for (auto it : mem_pools_) {
it->RemoveStream(stream);
}
}
@@ -135,6 +147,10 @@ Device::~Device() {
default_mem_pool_->release();
}
if (graph_mem_pool_ != nullptr) {
graph_mem_pool_->release();
}
if (null_stream_!= nullptr) {
delete null_stream_;
}
+22 -13
Ver Arquivo
@@ -2216,54 +2216,63 @@ hipError_t hipDeviceGetGraphMemAttribute(int device, hipGraphMemAttributeType at
if ((static_cast<size_t>(device) >= g_devices.size()) || device < 0 || value == nullptr) {
HIP_RETURN(hipErrorInvalidDevice);
}
// later use this to access memory pool
auto* deviceHandle = g_devices[device]->devices()[0];
hipError_t result = hipErrorInvalidValue;
switch (attr) {
case hipGraphMemAttrUsedMemCurrent:
*reinterpret_cast<int32_t*>(value) = 0;
result = g_devices[device]->GetGraphMemoryPool()->GetAttribute(
hipMemPoolAttrUsedMemCurrent, value);
break;
case hipGraphMemAttrUsedMemHigh:
*reinterpret_cast<int32_t*>(value) = 0;
result = g_devices[device]->GetGraphMemoryPool()->GetAttribute(
hipMemPoolAttrUsedMemHigh, value);
break;
case hipGraphMemAttrReservedMemCurrent:
*reinterpret_cast<int32_t*>(value) = 0;
result = g_devices[device]->GetGraphMemoryPool()->GetAttribute(
hipMemPoolAttrReservedMemCurrent, value);
break;
case hipGraphMemAttrReservedMemHigh:
*reinterpret_cast<int32_t*>(value) = 0;
result = g_devices[device]->GetGraphMemoryPool()->GetAttribute(
hipMemPoolAttrReservedMemHigh, value);
break;
default:
return HIP_RETURN(hipErrorInvalidValue);
break;
}
return HIP_RETURN(hipSuccess);
return HIP_RETURN(result);
}
// ================================================================================================
hipError_t hipDeviceSetGraphMemAttribute(int device, hipGraphMemAttributeType attr, void* value) {
HIP_INIT_API(hipDeviceSetGraphMemAttribute, device, attr, value);
if ((static_cast<size_t>(device) >= g_devices.size()) || device < 0 || value == nullptr) {
HIP_RETURN(hipErrorInvalidDevice);
}
// later use this to access memory pool
auto* deviceHandle = g_devices[device]->devices()[0];
hipError_t result = hipErrorInvalidValue;
switch (attr) {
case hipGraphMemAttrUsedMemHigh:
result = g_devices[device]->GetGraphMemoryPool()->SetAttribute(
hipMemPoolAttrUsedMemHigh, value);
break;
case hipGraphMemAttrReservedMemHigh:
result = g_devices[device]->GetGraphMemoryPool()->SetAttribute(
hipMemPoolAttrReservedMemHigh, value);
break;
default:
return HIP_RETURN(hipErrorInvalidValue);
break;
}
return HIP_RETURN(hipSuccess);
return HIP_RETURN(result);
}
// ================================================================================================
hipError_t hipDeviceGraphMemTrim(int device) {
HIP_INIT_API(hipDeviceGraphMemTrim, device);
if ((static_cast<size_t>(device) >= g_devices.size()) || device < 0) {
HIP_RETURN(hipErrorInvalidDevice);
}
// not implemented yet
g_devices[device]->GetGraphMemoryPool()->TrimTo(0);
return HIP_RETURN(hipSuccess);
}
// ================================================================================================
hipError_t hipUserObjectCreate(hipUserObject_t* object_out, void* ptr, hipHostFn_t destroy,
unsigned int initialRefcount, unsigned int flags) {
HIP_INIT_API(hipUserObjectCreate, object_out, ptr, destroy, initialRefcount, flags);
+4 -12
Ver Arquivo
@@ -405,17 +405,9 @@ struct ihipGraph {
, device_(device) {
amd::ScopedLock lock(graphSetLock_);
graphSet_.insert(this);
if (original == nullptr) {
// Create memory pool, associated with the graph
mem_pool_ = new hip::MemoryPool(device);
uint64_t max_size = std::numeric_limits<uint64_t>::max();
// Note: the call for the threshold is always successful
auto error = mem_pool_->SetAttribute(hipMemPoolAttrReleaseThreshold, &max_size);
} else {
mem_pool_ = original->mem_pool_;
mem_pool_->retain();
}
};
mem_pool_ = device->GetGraphMemoryPool();
mem_pool_->retain();
}
~ihipGraph() {
for (auto node : vertices_) {
@@ -430,7 +422,7 @@ struct ihipGraph {
mem_pool_->release();
}
};
}
void AddManualNodeDuringCapture(hipGraphNode* node) { capturedNodes_.insert(node); }
+8 -2
Ver Arquivo
@@ -401,8 +401,9 @@ namespace hip {
bool isActive_;
MemoryPool* default_mem_pool_;
MemoryPool* default_mem_pool_; //!< Default memory pool for this device
MemoryPool* current_mem_pool_;
MemoryPool* graph_mem_pool_; //!< Memory pool, associated with graphs for this device
std::set<MemoryPool*> mem_pools_;
@@ -412,7 +413,8 @@ namespace hip {
flags_(hipDeviceScheduleSpin),
isActive_(false),
default_mem_pool_(nullptr),
current_mem_pool_(nullptr)
current_mem_pool_(nullptr),
graph_mem_pool_(nullptr)
{ assert(ctx != nullptr); }
~Device();
@@ -470,6 +472,9 @@ namespace hip {
/// Get the default memory pool on the device
MemoryPool* GetDefaultMemoryPool() const { return default_mem_pool_; }
/// Get the graph memory pool on the device
MemoryPool* GetGraphMemoryPool() const { return graph_mem_pool_; }
/// Add memory pool to the device
void AddMemoryPool(MemoryPool* pool);
@@ -484,6 +489,7 @@ namespace hip {
/// Removes a destroyed stream from the safe list of memory pools
void RemoveStreamFromPools(Stream* stream);
};
/// Thread Local Storage Variables Aggregator Class
+4 -2
Ver Arquivo
@@ -213,17 +213,19 @@ public:
/// Set memory pool access by different devices
void GetAccess(hip::Device* device, hipMemAccessFlags* flags);
/// Frees all busy memory
void FreeAllMemory(hip::Stream* stream = nullptr);
/// Accessors for the pool state
bool EventDependencies() const { return (state_.event_dependencies_) ? true : false; }
bool Opportunistic() const { return (state_.opportunistic_) ? true : false; }
bool InternalDependencies() const { return (state_.internal_dependencies_) ? true : false; }
void FreeAllMemory(hip::Stream* stream = nullptr);
private:
MemoryPool() = delete;
MemoryPool(const MemoryPool&) = delete;
MemoryPool& operator=(const MemoryPool&) = delete;
Heap busy_heap_; //!< Heap of busy allocations
Heap free_heap_; //!< Heap of freed allocations
struct {