SWDEV-439234 - Access check before memcpy and kernel operations.

Change-Id: I7057125c03460db205409e19980145298c190fe2


[ROCm/clr commit: 6211037f63]
This commit is contained in:
kjayapra-amd
2024-07-19 10:16:18 -04:00
committed by Karthik Jayaprakash
parent 29e9bed35d
commit eecbcddaf3
9 changed files with 110 additions and 34 deletions
+40 -12
View File
@@ -419,22 +419,40 @@ hipError_t ihipMemcpy_validate(void* dst, const void* src, size_t sizeBytes,
size_t dOffset = 0;
amd::Memory* dstMemory = getMemoryObject(dst, dOffset);
// If the mem object is a VMM sub buffer (subbuffer has parent set),
// then use parent's size for validation.
if (srcMemory && srcMemory->parent() && (srcMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) {
if (srcMemory != nullptr) {
// Validate Mem Access in case of VMM Memory
if (!srcMemory->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], false)) {
return hipErrorUnknown;
}
// Size validation
if (sizeBytes > (srcMemory->getSize() - sOffset)) {
return hipErrorInvalidValue;
}
// If the mem object is a VMM sub buffer (subbuffer has parent set),
// then use parent's size for validation.
if (srcMemory->parent() && (srcMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) {
srcMemory = srcMemory->parent();
}
}
// If the mem object is a VMM sub buffer (subbuffer has parent set),
// then use parent's size for validation.
if (dstMemory && dstMemory->parent() && (dstMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) {
if (dstMemory != nullptr) {
// Validate Mem Access in case of VMM Memory
if (!dstMemory->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], true)) {
return hipErrorUnknown;
}
// Size validation
if (sizeBytes > (dstMemory->getSize() - dOffset)) {
return hipErrorInvalidValue;
}
// If the mem object is a VMM sub buffer (subbuffer has parent set),
// then use parent's size for validation.
if (dstMemory->parent() && (dstMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) {
dstMemory = dstMemory->parent();
}
// Return error if sizeBytes passed to memcpy is more than the actual size allocated
if ((dstMemory && sizeBytes > (dstMemory->getSize() - dOffset)) ||
(srcMemory && sizeBytes > (srcMemory->getSize() - sOffset))) {
return hipErrorInvalidValue;
}
}
//If src and dst ptr are null then kind must be either h2h or def.
@@ -3185,6 +3203,11 @@ hipError_t ihipMemset_validate(void* dst, int64_t value, size_t valueSize,
memory = memory->parent();
}
// Validate Mem Access in case of VMM Memory
if (!memory->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], true)) {
return hipErrorUnknown;
}
// Return error if sizeBytes passed to memcpy is more than the actual size allocated
if (sizeBytes > (memory->getSize() - offset)){
return hipErrorInvalidValue;
@@ -3212,6 +3235,11 @@ hipError_t ihipGraphMemsetParams_validate(const hipMemsetParams* pNodeParams) {
size_t discardOffset = 0;
amd::Memory *memObj = getMemoryObject(pNodeParams->dst, discardOffset);
if (memObj != nullptr) {
// Validate Mem Access in case of VMM Memory
if (!memObj->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], true)) {
return hipErrorUnknown;
}
if ((pNodeParams->pitch * pNodeParams->height) > memObj->getSize()) {
return hipErrorInvalidValue;
}
+9 -2
View File
@@ -1630,7 +1630,6 @@ class Device : public RuntimeObject {
enum class VmmAccess {
kNone = 0x0,
kReadOnly = 0x1,
kWriteOnly = 0x2,
kReadWrite = 0x3
};
@@ -1883,7 +1882,15 @@ class Device : public RuntimeObject {
* @param va_addr Virtual Address ptr
* @param access_flags_ptr Access permissions to be filled
*/
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) = 0;
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const = 0;
/**
* Validate access permissions for a virtual memory object.
*
* @param mem Memory object whose access permission is checked
* @param read_write Kind of access requested. In the implementation visible in this change,
*                   true requires read-write permission, false accepts read or read-write.
* @return true if the requested access is permitted, false otherwise
*/
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const = 0;
/**
* Free a VA range
+28 -12
View File
@@ -2511,34 +2511,50 @@ void Device::virtualFree(void* addr) {
// ================================================================================================
bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags) {
amd::Memory* phys_mem_obj = amd::MemObjMap::FindMemObj(va_addr);
if (phys_mem_obj == nullptr) {
// If the phys_mem_obj is null, the check if this is a valid va_addr, but not-mapped,
amd::Memory* amd_mem_obj = amd::MemObjMap::FindMemObj(va_addr);
if (amd_mem_obj == nullptr) {
// If the amd_mem_obj is null, then check if this is a valid va_addr, but not-mapped,
// if not-mapped then print a different error message. (No functional change due to this check).
amd::Memory* vaddr_mem_obj = amd::MemObjMap::FindVirtualMemObj(va_addr);
if (vaddr_mem_obj == nullptr) {
amd_mem_obj = amd::MemObjMap::FindVirtualMemObj(va_addr);
if (amd_mem_obj == nullptr) {
LogPrintfError("Cannot find virtual address: 0x%x \n", va_addr);
return false;
}
LogPrintfError("Virtual address present, but not mapped yet: 0x%x \n", va_addr);
return false;
}
// Check for valid size.
if (va_size > phys_mem_obj->getSize()) {
if (va_size > amd_mem_obj->getSize()) {
LogPrintfError("Given size: %u cannot be greater than mem_size: %u \n", va_size,
phys_mem_obj->getSize());
amd_mem_obj->getSize());
return false;
}
device::Memory* phys_dev_mem = phys_mem_obj->getDeviceMemory(*this);
phys_dev_mem->SetAccess(static_cast<device::Memory::MemAccess>(access_flags));
device::Memory* dev_mem_obj = amd_mem_obj->getDeviceMemory(*this);
dev_mem_obj->SetAccess(static_cast<device::Memory::MemAccess>(access_flags));
return true;
}
// ================================================================================================
bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) {
//! Checks whether the access permission recorded on the device memory backing
//! \p amd_mem_obj allows the requested kind of access.
//!
//! @param amd_mem_obj Memory object to validate.
//! @param read_write  true: the memory must be read-write accessible;
//!                    false: read or read-write access is sufficient.
//! @return true if the access is permitted, false otherwise (including the case
//!         where no device memory object could be obtained for this device).
bool Device::ValidateMemAccess(amd::Memory& amd_mem_obj, bool read_write) const {
  device::Memory* dev_mem = amd_mem_obj.getDeviceMemory(*this);
  if (dev_mem == nullptr) {
    // Without a device memory object there is no permission to inspect; deny the
    // access instead of dereferencing a null pointer.
    LogPrintfError("Cannot find device memory for mem object: %p \n",
                   static_cast<void*>(&amd_mem_obj));
    return false;
  }

  const device::Memory::MemAccess mem_access = dev_mem->GetAccess();

  // A read-write request is satisfied only by read-write permission.
  if (read_write) {
    return mem_access == device::Memory::MemAccess::kMemAccessReadWrite;
  }

  // A read request is satisfied by either read-only or read-write permission.
  return (mem_access == device::Memory::MemAccess::kMemAccessRead) ||
         (mem_access == device::Memory::MemAccess::kMemAccessReadWrite);
}
// ================================================================================================
bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const {
amd::Memory* phys_mem_obj = amd::MemObjMap::FindMemObj(va_addr);
if (phys_mem_obj == nullptr) {
@@ -2549,7 +2565,7 @@ bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) {
LogPrintfError("Cannot find virtual address: 0x%x \n", va_addr);
return false;
}
LogPrintfError("Virtual address present, but not mapped yet: 0x%x \n", va_addr);
LogPrintfInfo("Virtual address present, but not mapped yet: 0x%x \n", va_addr);
return false;
}
+7 -2
View File
@@ -151,7 +151,11 @@ class NullDevice : public amd::Device {
return true;
}
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) {
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const {
return true;
}
//! Access validation stub for this device type: performs no check and
//! unconditionally reports the access as valid, ignoring both arguments.
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const {
return true;
}
@@ -561,7 +565,8 @@ class Device : public NullDevice {
//! Set/Get memory access set by the app
virtual bool SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags);
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr);
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const;
//! Validate that the access permission of \p mem allows the requested access.
//! \p read_write selects the kind of access being requested; returns true if permitted.
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const;
virtual bool ExportShareableVMMHandle(amd::Memory& amd_mem_obj, int flags, void* shareableHandle);
@@ -3517,6 +3517,12 @@ bool VirtualGPU::processMemObjectsHSA(const amd::Kernel& kernel, const_address p
continue;
}
} else {
// Validate Mem Access in case of VMM Memory
if (!memory->ValidateMemAccess(dev(), true)) {
return false;
}
Memory* gpuMemory = dev().getGpuMemory(memory);
if (nullptr != gpuMemory) {
// Synchronize data with other memory instances if necessary
@@ -71,9 +71,6 @@ static_assert(static_cast<uint32_t>(amd::Device::VmmAccess::kNone)
static_assert(static_cast<uint32_t>(amd::Device::VmmAccess::kReadOnly)
== static_cast<uint32_t>(HSA_ACCESS_PERMISSION_RO),
"Vmm Access Flag Read mismatch with ROCr-runtime!");
static_assert(static_cast<uint32_t>(amd::Device::VmmAccess::kWriteOnly)
== static_cast<uint32_t>(HSA_ACCESS_PERMISSION_WO),
"Vmm Access Flag Write mismatch with ROC-runtime!");
static_assert(static_cast<uint32_t>(amd::Device::VmmAccess::kReadWrite)
== static_cast<uint32_t>(HSA_ACCESS_PERMISSION_RW),
"Vmm Access Flag Read Write mismatch with ROC-runtime!");
@@ -2511,7 +2508,7 @@ bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags)
return true;
}
bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) {
bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const {
hsa_status_t hsa_status = HSA_STATUS_SUCCESS;
hsa_access_permission_t perms;
@@ -240,11 +240,16 @@ class NullDevice : public amd::Device {
return false;
}
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) override {
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const {
ShouldNotReachHere();
return false;
}
//! Access validation is not expected to be requested on this device type: the call is
//! flagged through ShouldNotReachHere() and then reported as valid.
//! Marked override so that any drift from the base-class signature (for example a
//! missing const qualifier) is rejected at compile time.
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const override {
  ShouldNotReachHere();
  return true;
}
//! Determine if we can use device memory for SVM
const bool forceFineGrain(amd::Memory* memory) const {
return (memory->getContext().devices().size() > 1);
@@ -478,7 +483,8 @@ class Device : public NullDevice {
virtual void virtualFree(void* addr);
virtual bool SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags);
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr);
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const;
//! Access validation for this device: always reports the access as valid.
//! The const qualifier is required so that this function overrides the const virtual
//! declared in the base classes; without it the function merely hides the base version
//! and calls made through a const Device reference never reach it.
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const { return true; }
virtual bool ExportShareableVMMHandle(amd::Memory& amd_mem_obj, int flags, void* shareableHandle);
+8
View File
@@ -545,6 +545,14 @@ Device* Memory::GetDeviceById() {
return getContext().devices()[device_idx];
}
// =================================================================================================
//! Validates the requested access against \p dev for VA-range (VMM) memory objects.
//! Objects without the CL_MEM_VA_RANGE_AMD flag are not checked and always pass.
bool Memory::ValidateMemAccess(const Device& dev, bool read_write) {
  // Anything that is not a VA-range object is accepted without consulting the device.
  const bool is_va_range = (flags_ & CL_MEM_VA_RANGE_AMD) != 0;
  if (!is_va_range) {
    return true;
  }

  // VA-range objects defer the decision to the device.
  return dev.ValidateMemAccess(*this, read_write);
}
void Buffer::initDeviceMemory() {
deviceMemories_ = reinterpret_cast<DeviceMemory*>(reinterpret_cast<char*>(this) + sizeof(Buffer));
memset(deviceMemories_, 0, NumDevicesWithP2P() * sizeof(DeviceMemory));
+3
View File
@@ -417,6 +417,9 @@ class Memory : public amd::RuntimeObject {
//! get device by id when glb ctx is used.
Device* GetDeviceById();
//! Validate memory access for vmm memory.
//! Objects flagged CL_MEM_VA_RANGE_AMD are checked through dev.ValidateMemAccess();
//! every other object is reported as accessible. \p read_write is forwarded to the device
//! unchanged. Returns true if the access is permitted.
bool ValidateMemAccess(const Device& dev, bool read_write);
};
//! Buffers are a specialization of memory. Just a wrapper, really,