diff --git a/inc/wddm/gpu_memory.h b/inc/wddm/gpu_memory.h index 50a7f7d3d8..4835a718f8 100644 --- a/inc/wddm/gpu_memory.h +++ b/inc/wddm/gpu_memory.h @@ -117,6 +117,7 @@ struct GpuMemoryDesc { uint32_t is_external : 1; uint32_t is_physical_only : 1; uint32_t is_locked : 1; + uint32_t is_queue_referenced : 1; uint32_t unused : 27; }; @@ -167,6 +168,9 @@ public: return (desc_.adapter_luid.HighPart == luid.HighPart && desc_.adapter_luid.LowPart == luid.LowPart); } + inline void GetQueueReference() { desc_.flags.is_queue_referenced = 1; } + inline void PutQueueReference() { desc_.flags.is_queue_referenced = 0; } + inline bool IsQueueReferenced() const { return desc_.flags.is_queue_referenced; } WinAllocationHandle GetAllocationHandle(size_t index) const { return alloc_handles_ptr_[index]; } size_t NumChunks() const { return num_allocations_; } diff --git a/inc/wddm/queue.h b/inc/wddm/queue.h index d08e654b09..7e078eb123 100644 --- a/inc/wddm/queue.h +++ b/inc/wddm/queue.h @@ -87,6 +87,7 @@ public: virtual hsa_status_t Init(void) { return HSA_STATUS_SUCCESS; } virtual hsa_status_t Fini(void) { return HSA_STATUS_SUCCESS; } virtual void RingDoorbell() { } + virtual void* GetHsaQueueAddr(void) const { return reinterpret_cast(GetCmdbufAddr()); } hsa_status_t SwsInit(void); hsa_status_t SwsFini(void); @@ -162,6 +163,7 @@ public: uint64_t GetAqlWriteIndex(void) const { return cmdbuf_aql_frame_write_index; } uint32_t GetAqlFrameSize(void) const { return cmdbuf_aql_frame_size; } + void* GetHsaQueueAddr(void) const { return ring; } bool IsInvalidPacket(void) const { uint16_t *packet = (uint16_t *)((char *)ring + @@ -278,6 +280,7 @@ public: uint64_t * GetRingRptr(void) { return WDDMQueue::GetSyncAddr(); } uint64_t * GetDoorbellPtr() { return &doorbell_; } void RingDoorbell(); + void* GetHsaQueueAddr(void) const { return reinterpret_cast(GetCmdbufAddr()); } private: uint64_t wptr_next_; diff --git a/libhsakmt.h b/libhsakmt.h index 136180ad98..0a2627ced3 100644 --- a/libhsakmt.h +++ b/libhsakmt.h @@ -177,6 +177,8 @@ bool is_forked_child(void); void clear_allocation_map(void); +bool queue_acquire_buffer(void *MemoryAddress); +bool queue_release_buffer(void *MemoryAddress); /* Calculate VGPR and SGPR register file size per CU */ uint32_t get_vgpr_size_per_cu(HSA_ENGINE_ID id); #define SGPR_SIZE_PER_CU 0x4000 diff --git a/memory.cpp b/memory.cpp index 9d97935031..04f73ef0fd 100644 --- a/memory.cpp +++ b/memory.cpp @@ -213,11 +213,55 @@ HSAKMT_STATUS HSAKMTAPI hsaKmtFreeMemory(void *MemoryAddress, allocation_map_.erase(it); } - delete gpu_mem; + if (gpu_mem->IsQueueReferenced()) + return HSAKMT_STATUS_ERROR; + delete gpu_mem; return HSAKMT_STATUS_SUCCESS; } +bool queue_acquire_buffer(void *MemoryAddress) { + if (!MemoryAddress) + return false; + + wsl::thunk::GpuMemory *gpu_mem = nullptr; + { + std::lock_guard gard(*allocation_map_lock_); + auto it = allocation_map_.find(MemoryAddress); + if (it == allocation_map_.end()) { + return HSAKMT_STATUS_ERROR; + } + + gpu_mem = wsl::thunk::GpuMemory::Convert(it->second.handle); + gpu_mem->GetQueueReference(); + } + if (gpu_mem == nullptr) + return false; + + return true; +} + +bool queue_release_buffer(void *MemoryAddress) { + if (!MemoryAddress) + return false; + + wsl::thunk::GpuMemory *gpu_mem = nullptr; + { + std::lock_guard gard(*allocation_map_lock_); + auto it = allocation_map_.find(MemoryAddress); + if (it == allocation_map_.end()) { + return HSAKMT_STATUS_ERROR; + } + + gpu_mem = wsl::thunk::GpuMemory::Convert(it->second.handle); + gpu_mem->PutQueueReference(); + } + if (gpu_mem == nullptr) + return false; + + return true; +} + HSAKMT_STATUS HSAKMTAPI hsaKmtAvailableMemory(HSAuint32 Node, HSAuint64 *AvailableBytes) { CHECK_DXG_OPEN(); @@ -531,8 +575,10 @@ HSAKMT_STATUS HSAKMTAPI hsaKmtUnmapMemoryToGPU(void *MemoryAddress) { allocation_map_.erase(it); } auto gpu_mem = wsl::thunk::GpuMemory::Convert(handle); - delete gpu_mem; + if (gpu_mem->IsQueueReferenced()) + return HSAKMT_STATUS_ERROR; + delete gpu_mem; return HSAKMT_STATUS_SUCCESS; } diff --git a/queues.cpp b/queues.cpp index 58a3ede5d3..e7b89e3529 100644 --- a/queues.cpp +++ b/queues.cpp @@ -76,6 +76,9 @@ HSAKMT_STATUS HSAKMTAPI hsaKmtCreateQueueExt(HSAuint32 NodeId, wsl::thunk::WDDMDevice *device_ = get_wddmdev(NodeId); assert(device_); + if (queue_acquire_buffer(QueueAddress) == false) + return HSAKMT_STATUS_INVALID_PARAMETER; + switch (Type) { case HSA_QUEUE_COMPUTE_AQL: { assert(QueueResource->ErrorReason == nullptr); @@ -138,11 +141,13 @@ HSAKMT_STATUS HSAKMTAPI hsaKmtDestroyQueue(HSA_QUEUEID QueueId) { CHECK_DXG_OPEN(); auto queue_ = reinterpret_cast(QueueId); + void *QueueAddress = queue_->GetHsaQueueAddr(); if (!queue_) return HSAKMT_STATUS_INVALID_PARAMETER; delete queue_; + queue_release_buffer(QueueAddress); return HSAKMT_STATUS_SUCCESS; }