From b3c4e94e70e6b4ee65ca703b6e8758343e5e8cec Mon Sep 17 00:00:00 2001 From: hongkzha-amd Date: Wed, 14 Jan 2026 02:08:20 +0800 Subject: [PATCH] rocr: Improve memory protection and WSL compatibility (#2274) * rocr: Add ProtectMemory API and use it in RemoveAccess Replace munmap + mmap with mprotect when removing memory access. This improves performance by 5-10x, ensures atomicity (no race condition window), and prepares for WSL/DXG compatibility fixes. Suggested-by: David Yat Sin Signed-off-by: Flora Cui Signed-off-by: Horatio Zhang * rocr: Skip CPU mapping operations on WSL On WSL, CPU cannot access GPU VRAM due to platform restrictions. CPU access would fault-in system RAM instead, causing data corruption and memory leaks. Return HSA_STATUS_ERROR to fail fast rather than silently creating broken mappings. GPU-to-GPU mappings remain functional. Signed-off-by: Flora Cui Signed-off-by: Horatio Zhang * rocr: reduce ifdef linux v2: Fix IsDXG check logic Signed-off-by: David Yat Sin Signed-off-by: Horatio Zhang --------- Signed-off-by: Horatio Zhang Signed-off-by: David Yat Sin Signed-off-by: Flora Cui --- .../hsa-runtime/core/runtime/runtime.cpp | 31 ++++++++++--------- .../hsa-runtime/core/util/lnx/os_linux.cpp | 4 +++ .../runtime/hsa-runtime/core/util/os.h | 2 ++ .../hsa-runtime/core/util/win/os_win.cpp | 8 +++++ 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/projects/rocr-runtime/runtime/hsa-runtime/core/runtime/runtime.cpp b/projects/rocr-runtime/runtime/hsa-runtime/core/runtime/runtime.cpp index 965a0d5c63..d23f889637 100644 --- a/projects/rocr-runtime/runtime/hsa-runtime/core/runtime/runtime.cpp +++ b/projects/rocr-runtime/runtime/hsa-runtime/core/runtime/runtime.cpp @@ -3810,11 +3810,13 @@ Runtime::MappedHandleAllowedAgent::~MappedHandleAllowedAgent() { hsa_status_t Runtime::MappedHandleAllowedAgent::EnableAccess(hsa_access_permission_t perms) { if (targetAgent->device_type() == core::Agent::DeviceType::kAmdCpuDevice) { - if (!core::Runtime::runtime_singleton_->thunkLoader()->IsDXG()) { - if (!rocr::os::MapMemory(va, size, PermissionsToMemProt(perms), mappedHandle->drm_fd, - reinterpret_cast(mappedHandle->drm_cpu_addr))) { - return HSA_STATUS_ERROR; - } +#if defined(__linux__) + if (core::Runtime::runtime_singleton_->thunkLoader()->IsDXG()) return HSA_STATUS_ERROR; +#endif + + if (!rocr::os::MapMemory(va, size, PermissionsToMemProt(perms), mappedHandle->drm_fd, + reinterpret_cast(mappedHandle->drm_cpu_addr))) { + return HSA_STATUS_ERROR; } } else { hsa_status_t status = targetAgent->driver().Map( @@ -3829,12 +3831,11 @@ hsa_status_t Runtime::MappedHandleAllowedAgent::EnableAccess(hsa_access_permissi hsa_status_t Runtime::MappedHandleAllowedAgent::RemoveAccess() { if (targetAgent->device_type() == core::Agent::DeviceType::kAmdCpuDevice) { if (permissions != HSA_ACCESS_PERMISSION_NONE) { +#if defined(__linux__) + if (core::Runtime::runtime_singleton_->thunkLoader()->IsDXG()) return HSA_STATUS_ERROR; +#endif hsa_access_permission_t perms = HSA_ACCESS_PERMISSION_NONE; - if (!rocr::os::UnmapMemory(va, size)) { - return HSA_STATUS_ERROR; - } - if (!rocr::os::MapMemory(va, size, PermissionsToMemProt(perms), mappedHandle->drm_fd, - reinterpret_cast(mappedHandle->drm_cpu_addr))) { + if (!rocr::os::ProtectMemory(va, size, PermissionsToMemProt(perms))) { return HSA_STATUS_ERROR; } permissions = perms; @@ -3855,17 +3856,19 @@ Runtime::MappedHandle::MappedHandle(MemoryHandle *mem_handle, AddressHandle *add { /* Create a CPU mapping with PROT_NONE */ #if defined(__linux__) + if (core::Runtime::runtime_singleton_->thunkLoader()->IsDXG()) return; + #endif + auto cpu_agent = static_cast(agentOwner())->GetNearestCpuAgent(); auto agentPermsIt = allowed_agents.emplace(std::piecewise_construct, - std::forward_as_tuple(cpu_agent), - std::forward_as_tuple(this, cpu_agent, va, - size, HSA_ACCESS_PERMISSION_NONE)) + std::forward_as_tuple(cpu_agent), + std::forward_as_tuple(this, cpu_agent, va, + size, HSA_ACCESS_PERMISSION_NONE)) .first; auto ret = agentPermsIt->second.EnableAccess(HSA_ACCESS_PERMISSION_NONE); if (ret != HSA_STATUS_SUCCESS) throw AMD::hsa_exception(ret, "Failed to create default CPU mapping"); - #endif } // Note: VMemorySetAccessPerHandle should be called with &memory_lock_ held diff --git a/projects/rocr-runtime/runtime/hsa-runtime/core/util/lnx/os_linux.cpp b/projects/rocr-runtime/runtime/hsa-runtime/core/util/lnx/os_linux.cpp index 608676a773..f73fd8c928 100644 --- a/projects/rocr-runtime/runtime/hsa-runtime/core/util/lnx/os_linux.cpp +++ b/projects/rocr-runtime/runtime/hsa-runtime/core/util/lnx/os_linux.cpp @@ -930,6 +930,10 @@ bool UncommitMemory(void* addr, size_t size) { 0) != MAP_FAILED; } +bool ProtectMemory(void* va, size_t size, MemProt perms) { + return ::mprotect(va, size, MemProtToOsProt(perms)) == 0; +} + uint64_t HostTotalPhysicalMemory() { static uint64_t totalPhys = 0; diff --git a/projects/rocr-runtime/runtime/hsa-runtime/core/util/os.h b/projects/rocr-runtime/runtime/hsa-runtime/core/util/os.h index b7835df567..3c07850e83 100644 --- a/projects/rocr-runtime/runtime/hsa-runtime/core/util/os.h +++ b/projects/rocr-runtime/runtime/hsa-runtime/core/util/os.h @@ -355,6 +355,8 @@ bool UncommitMemory(void* addr, size_t size); bool UnmapMemory(void* addr, size_t size); bool MapMemory(void* addr, size_t size, MemProt prot, int fd, uint64_t cpu_addr); +bool ProtectMemory(void* va, size_t size, MemProt perms); + uint64_t HostTotalPhysicalMemory(); /// Find First Set for any OS diff --git a/projects/rocr-runtime/runtime/hsa-runtime/core/util/win/os_win.cpp b/projects/rocr-runtime/runtime/hsa-runtime/core/util/win/os_win.cpp index cf42be7cf2..d7b2567945 100644 --- a/projects/rocr-runtime/runtime/hsa-runtime/core/util/win/os_win.cpp +++ b/projects/rocr-runtime/runtime/hsa-runtime/core/util/win/os_win.cpp @@ -470,6 +470,14 @@ bool MapMemory(void* addr, size_t size, MemProt perms, int fd [[maybe_unused]], return VirtualProtect(addr, size, memProtToOsProt(perms), &OldProtect) != 0; } +bool ProtectMemory(void* va, size_t size, MemProt perms) { + if (perms == MEM_PROT_NONE) { + return UncommitMemory(addr, size); + } + DWORD oldProt; + return VirtualProtect(va, size, memProtToOsProt(perms), &oldProt) != 0; +} + int Ffs(int i) { int res = 0; unsigned long index;