From eecbcddaf331a259cdb1bd1478f1e81231367ad2 Mon Sep 17 00:00:00 2001 From: kjayapra-amd Date: Fri, 19 Jul 2024 10:16:18 -0400 Subject: [PATCH] SWDEV-439234 - Access check before memcpy and kernel operations. Change-Id: I7057125c03460db205409e19980145298c190fe2 [ROCm/clr commit: 6211037f63378994a0412626737f41c77f58dee4] --- projects/clr/hipamd/src/hip_memory.cpp | 52 ++++++++++++++----- projects/clr/rocclr/device/device.hpp | 11 +++- projects/clr/rocclr/device/pal/paldevice.cpp | 40 +++++++++----- projects/clr/rocclr/device/pal/paldevice.hpp | 9 +++- projects/clr/rocclr/device/pal/palvirtual.cpp | 6 +++ projects/clr/rocclr/device/rocm/rocdevice.cpp | 5 +- projects/clr/rocclr/device/rocm/rocdevice.hpp | 10 +++- projects/clr/rocclr/platform/memory.cpp | 8 +++ projects/clr/rocclr/platform/memory.hpp | 3 ++ 9 files changed, 110 insertions(+), 34 deletions(-) diff --git a/projects/clr/hipamd/src/hip_memory.cpp b/projects/clr/hipamd/src/hip_memory.cpp index c3d8770bd1..72b934a56c 100644 --- a/projects/clr/hipamd/src/hip_memory.cpp +++ b/projects/clr/hipamd/src/hip_memory.cpp @@ -419,22 +419,40 @@ hipError_t ihipMemcpy_validate(void* dst, const void* src, size_t sizeBytes, size_t dOffset = 0; amd::Memory* dstMemory = getMemoryObject(dst, dOffset); - // If the mem object is a VMM sub buffer (subbuffer has parent set), - // then use parent's size for validation. - if (srcMemory && srcMemory->parent() && (srcMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) { + if (srcMemory != nullptr) { + // Validate Mem Access in case of VMM Memory + if (!srcMemory->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], false)) { + return hipErrorUnknown; + } + + // Size validation + if (sizeBytes > (srcMemory->getSize() - sOffset)) { + return hipErrorInvalidValue; + } + + // If the mem object is a VMM sub buffer (subbuffer has parent set), + // then use parent's size for validation. 
+ if (srcMemory->parent() && (srcMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) { srcMemory = srcMemory->parent(); + } } - // If the mem object is a VMM sub buffer (subbuffer has parent set), - // then use parent's size for validation. - if (dstMemory && dstMemory->parent() && (dstMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) { + if (dstMemory != nullptr) { + // Validate Mem Access in case of VMM Memory + if (!dstMemory->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], true)) { + return hipErrorUnknown; + } + + // Size validation + if (sizeBytes > (dstMemory->getSize() - dOffset)) { + return hipErrorInvalidValue; + } + + // If the mem object is a VMM sub buffer (subbuffer has parent set), + // then use parent's size for validation. + if (dstMemory->parent() && (dstMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) { dstMemory = dstMemory->parent(); - } - - // Return error if sizeBytes passed to memcpy is more than the actual size allocated - if ((dstMemory && sizeBytes > (dstMemory->getSize() - dOffset)) || - (srcMemory && sizeBytes > (srcMemory->getSize() - sOffset))) { - return hipErrorInvalidValue; + } } //If src and dst ptr are null then kind must be either h2h or def. 
@@ -3185,6 +3203,11 @@ hipError_t ihipMemset_validate(void* dst, int64_t value, size_t valueSize, memory = memory->parent(); } + // Validate Mem Access in case of VMM Memory + if (!memory->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], true)) { + return hipErrorUnknown; + } + // Return error if sizeBytes passed to memcpy is more than the actual size allocated if (sizeBytes > (memory->getSize() - offset)){ return hipErrorInvalidValue; @@ -3212,6 +3235,11 @@ hipError_t ihipGraphMemsetParams_validate(const hipMemsetParams* pNodeParams) { size_t discardOffset = 0; amd::Memory *memObj = getMemoryObject(pNodeParams->dst, discardOffset); if (memObj != nullptr) { + // Validate Mem Access in case of VMM Memory + if (!memObj->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], true)) { + return hipErrorUnknown; + } + if ((pNodeParams->pitch * pNodeParams->height) > memObj->getSize()) { return hipErrorInvalidValue; } diff --git a/projects/clr/rocclr/device/device.hpp b/projects/clr/rocclr/device/device.hpp index 005c498f0f..a654823455 100644 --- a/projects/clr/rocclr/device/device.hpp +++ b/projects/clr/rocclr/device/device.hpp @@ -1630,7 +1630,6 @@ class Device : public RuntimeObject { enum class VmmAccess { kNone = 0x0, kReadOnly = 0x1, - kWriteOnly = 0x2, kReadWrite = 0x3 }; @@ -1883,7 +1882,15 @@ class Device : public RuntimeObject { * @param va_addr Virtual Address ptr * @param access_flags_ptr Access permissions to be filled */ - virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) = 0; + virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const = 0; + + /** + * Validate Access permissions for a virtual memory object. 
+ * + * @param mem Memory object whose access permissions are validated + * @param read_write True when the caller needs read-write access, false when read access is enough + */ + virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const = 0; /** * Free a VA range diff --git a/projects/clr/rocclr/device/pal/paldevice.cpp b/projects/clr/rocclr/device/pal/paldevice.cpp index 97a63ec14f..42c1f25cbf 100644 --- a/projects/clr/rocclr/device/pal/paldevice.cpp +++ b/projects/clr/rocclr/device/pal/paldevice.cpp @@ -2511,34 +2511,50 @@ void Device::virtualFree(void* addr) { // ================================================================================================ bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags) { - amd::Memory* phys_mem_obj = amd::MemObjMap::FindMemObj(va_addr); - if (phys_mem_obj == nullptr) { - // If the phys_mem_obj is null, the check if this is a valid va_addr, but not-mapped, + amd::Memory* amd_mem_obj = amd::MemObjMap::FindMemObj(va_addr); + if (amd_mem_obj == nullptr) { + // If the amd_mem_obj is null, then check if this is a valid va_addr, but not-mapped, // if not-mapped then print a different error message. (No functional change due to this check). - amd::Memory* vaddr_mem_obj = amd::MemObjMap::FindVirtualMemObj(va_addr); - if (vaddr_mem_obj == nullptr) { + amd_mem_obj = amd::MemObjMap::FindVirtualMemObj(va_addr); + if (amd_mem_obj == nullptr) { LogPrintfError("Cannot find virtual address: 0x%x \n", va_addr); return false; } LogPrintfError("Virtual address present, but not mapped yet: 0x%x \n", va_addr); - return false; } // Check for valid size. 
- if (va_size > phys_mem_obj->getSize()) { + if (va_size > amd_mem_obj->getSize()) { LogPrintfError("Given size: %u cannot be greater than mem_size: %u \n", va_size, - phys_mem_obj->getSize()); + amd_mem_obj->getSize()); return false; } - device::Memory* phys_dev_mem = phys_mem_obj->getDeviceMemory(*this); - phys_dev_mem->SetAccess(static_cast(access_flags)); + device::Memory* dev_mem_obj = amd_mem_obj->getDeviceMemory(*this); + dev_mem_obj->SetAccess(static_cast(access_flags)); return true; } // ================================================================================================ -bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) { +bool Device::ValidateMemAccess(amd::Memory& amd_mem_obj, bool read_write) const { + + device::Memory* dev_mem = amd_mem_obj.getDeviceMemory(*this); + device::Memory::MemAccess mem_access = dev_mem->GetAccess(); + + // If read_write flag is set, then only read_write is valid, else it could be a read or write. + if (read_write && mem_access != device::Memory::MemAccess::kMemAccessReadWrite) { + return false; + } else if ((mem_access != device::Memory::MemAccess::kMemAccessRead) + && (mem_access != device::Memory::MemAccess::kMemAccessReadWrite)) { + return false; + } + + return true; +} + +// ================================================================================================ +bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const { amd::Memory* phys_mem_obj = amd::MemObjMap::FindMemObj(va_addr); if (phys_mem_obj == nullptr) { @@ -2549,7 +2565,7 @@ bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) { LogPrintfError("Cannot find virtual address: 0x%x \n", va_addr); return false; } - LogPrintfError("Virtual address present, but not mapped yet: 0x%x \n", va_addr); + LogPrintfInfo("Virtual address present, but not mapped yet: 0x%x \n", va_addr); return false; } diff --git a/projects/clr/rocclr/device/pal/paldevice.hpp 
b/projects/clr/rocclr/device/pal/paldevice.hpp index d6f0c31506..28a45d61c8 100644 --- a/projects/clr/rocclr/device/pal/paldevice.hpp +++ b/projects/clr/rocclr/device/pal/paldevice.hpp @@ -151,7 +151,11 @@ class NullDevice : public amd::Device { return true; } - virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) { + virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const { + return true; + } + + virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const { return true; } @@ -561,7 +565,8 @@ class Device : public NullDevice { //! Set/Get memory access set by the app virtual bool SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags); - virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr); + virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const; + virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const; virtual bool ExportShareableVMMHandle(amd::Memory& amd_mem_obj, int flags, void* shareableHandle); diff --git a/projects/clr/rocclr/device/pal/palvirtual.cpp b/projects/clr/rocclr/device/pal/palvirtual.cpp index 14fb8ce3a1..ba3f4fc727 100644 --- a/projects/clr/rocclr/device/pal/palvirtual.cpp +++ b/projects/clr/rocclr/device/pal/palvirtual.cpp @@ -3517,6 +3517,12 @@ bool VirtualGPU::processMemObjectsHSA(const amd::Kernel& kernel, const_address p continue; } } else { + + // Validate Mem Access in case of VMM Memory + if (!memory->ValidateMemAccess(dev(), true)) { + return false; + } + Memory* gpuMemory = dev().getGpuMemory(memory); if (nullptr != gpuMemory) { // Synchronize data with other memory instances if necessary diff --git a/projects/clr/rocclr/device/rocm/rocdevice.cpp b/projects/clr/rocclr/device/rocm/rocdevice.cpp index 78c55eeb9b..a80d8c6580 100644 --- a/projects/clr/rocclr/device/rocm/rocdevice.cpp +++ b/projects/clr/rocclr/device/rocm/rocdevice.cpp @@ -71,9 +71,6 @@ static_assert(static_cast(amd::Device::VmmAccess::kNone) 
static_assert(static_cast(amd::Device::VmmAccess::kReadOnly) == static_cast(HSA_ACCESS_PERMISSION_RO), "Vmm Access Flag Read mismatch with ROCr-runtime!"); -static_assert(static_cast(amd::Device::VmmAccess::kWriteOnly) - == static_cast(HSA_ACCESS_PERMISSION_WO), - "Vmm Access Flag Write mismatch with ROC-runtime!"); static_assert(static_cast(amd::Device::VmmAccess::kReadWrite) == static_cast(HSA_ACCESS_PERMISSION_RW), "Vmm Access Flag Read Write mismatch with ROC-runtime!"); @@ -2511,7 +2508,7 @@ bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags) return true; } -bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) { +bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const { hsa_status_t hsa_status = HSA_STATUS_SUCCESS; hsa_access_permission_t perms; diff --git a/projects/clr/rocclr/device/rocm/rocdevice.hpp b/projects/clr/rocclr/device/rocm/rocdevice.hpp index dc968b2ee4..bc1f5a5383 100644 --- a/projects/clr/rocclr/device/rocm/rocdevice.hpp +++ b/projects/clr/rocclr/device/rocm/rocdevice.hpp @@ -240,11 +240,16 @@ class NullDevice : public amd::Device { return false; } - virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) override { + virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const { ShouldNotReachHere(); return false; } + virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const { + ShouldNotReachHere(); + return true; + } + //! 
Determine if we can use device memory for SVM const bool forceFineGrain(amd::Memory* memory) const { return (memory->getContext().devices().size() > 1); @@ -478,7 +483,8 @@ class Device : public NullDevice { virtual void virtualFree(void* addr); virtual bool SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags); - virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr); + virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const; + virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const { return true; } virtual bool ExportShareableVMMHandle(amd::Memory& amd_mem_obj, int flags, void* shareableHandle); diff --git a/projects/clr/rocclr/platform/memory.cpp b/projects/clr/rocclr/platform/memory.cpp index d26b6c802e..948067dcbb 100644 --- a/projects/clr/rocclr/platform/memory.cpp +++ b/projects/clr/rocclr/platform/memory.cpp @@ -545,6 +545,14 @@ Device* Memory::GetDeviceById() { return getContext().devices()[device_idx]; } +// ================================================================================================= +bool Memory::ValidateMemAccess(const Device& dev, bool read_write) { + if (flags_ & CL_MEM_VA_RANGE_AMD) { + return dev.ValidateMemAccess(*this, read_write); + } + return true; +} + void Buffer::initDeviceMemory() { deviceMemories_ = reinterpret_cast(reinterpret_cast(this) + sizeof(Buffer)); memset(deviceMemories_, 0, NumDevicesWithP2P() * sizeof(DeviceMemory)); diff --git a/projects/clr/rocclr/platform/memory.hpp b/projects/clr/rocclr/platform/memory.hpp index 6dfc480e34..95f37cfc45 100644 --- a/projects/clr/rocclr/platform/memory.hpp +++ b/projects/clr/rocclr/platform/memory.hpp @@ -417,6 +417,9 @@ class Memory : public amd::RuntimeObject { //! get device by id when glb ctx is used. Device* GetDeviceById(); + + //! Validate memory access for vmm memory + bool ValidateMemAccess(const Device& dev, bool read_write); }; //! Buffers are a specialization of memory. Just a wrapper, really,