diff --git a/memory.cpp b/memory.cpp index 598834ea90..b6ef48cf29 100644 --- a/memory.cpp +++ b/memory.cpp @@ -39,7 +39,7 @@ struct Allocation { Allocation() : handle(0), cpu_addr(0), gpu_addr(0), size(0), userptr(false), user_data(nullptr), size_requested(0), node_id(0), mem_flags_value(0), - dmabuf_fd(-1) {} + dmabuf_fd(-1), rocr_userdata(nullptr) {} Allocation(wsl::thunk::GpuMemoryHandle handle_arg, void *cpu_addr_arg, uint64_t gpu_addr_arg, size_t size_arg, bool userptr_arg = false, void *user_data_arg = nullptr, size_t user_size_arg = 0, @@ -47,7 +47,7 @@ struct Allocation { : handle(handle_arg), cpu_addr(cpu_addr_arg), gpu_addr(gpu_addr_arg), size(size_arg), userptr(userptr_arg), user_data(user_data_arg), size_requested(user_size_arg), node_id(node_id_arg), - mem_flags_value(mem_flags_value_arg), dmabuf_fd(-1) {} + mem_flags_value(mem_flags_value_arg), dmabuf_fd(-1), rocr_userdata(nullptr) {} wsl::thunk::GpuMemoryHandle handle; void *cpu_addr; @@ -59,6 +59,7 @@ struct Allocation { HSAuint32 node_id; HSAuint32 mem_flags_value; int dmabuf_fd; + void *rocr_userdata; }; static std::map* allocation_map_ = new std::map(); @@ -938,7 +939,7 @@ HSAKMT_STATUS HSAKMTAPI hsaKmtQueryPointerInfo(const void *Pointer, PointerInfo->MemFlags.Value = allocation_info.mem_flags_value; PointerInfo->CPUAddress = allocation_info.cpu_addr; PointerInfo->GPUAddress = allocation_info.gpu_addr; - PointerInfo->UserData = allocation_info.user_data; + PointerInfo->UserData = allocation_info.rocr_userdata; return HSAKMT_STATUS_SUCCESS; } @@ -946,9 +947,17 @@ HSAKMT_STATUS HSAKMTAPI hsaKmtQueryPointerInfo(const void *Pointer, HSAKMT_STATUS HSAKMTAPI hsaKmtSetMemoryUserData(const void *Pointer, void *UserData) { CHECK_DXG_OPEN(); - pr_warn_once("not implemented\n"); - assert(false); - return HSAKMT_STATUS_SUCCESS; + + uint64_t aligned_ptr = wsl::AlignDown((uint64_t)Pointer, 4096); + + std::lock_guard gard(*allocation_map_lock_); + auto it = allocation_map_->find((void *)aligned_ptr); + if (it != allocation_map_->end()) { + it->second.rocr_userdata = UserData; + return HSAKMT_STATUS_SUCCESS; + } + + return HSAKMT_STATUS_ERROR; } HSAKMT_STATUS HSAKMTAPI hsaKmtReplaceAsanHeaderPage(void *addr) {