SWDEV-439234 - Access check before memcpy and kernel operations.
Change-Id: I7057125c03460db205409e19980145298c190fe2
[ROCm/clr commit: 6211037f63]
This commit is contained in:
committed by
Karthik Jayaprakash
parent
29e9bed35d
commit
eecbcddaf3
@@ -419,22 +419,40 @@ hipError_t ihipMemcpy_validate(void* dst, const void* src, size_t sizeBytes,
|
||||
size_t dOffset = 0;
|
||||
amd::Memory* dstMemory = getMemoryObject(dst, dOffset);
|
||||
|
||||
// If the mem object is a VMM sub buffer (subbuffer has parent set),
|
||||
// then use parent's size for validation.
|
||||
if (srcMemory && srcMemory->parent() && (srcMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) {
|
||||
if (srcMemory != nullptr) {
|
||||
// Validate Mem Access in case of VMM Memory
|
||||
if (!srcMemory->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], false)) {
|
||||
return hipErrorUnknown;
|
||||
}
|
||||
|
||||
// Size validation
|
||||
if (sizeBytes > (srcMemory->getSize() - sOffset)) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
// If the mem object is a VMM sub buffer (subbuffer has parent set),
|
||||
// then use parent's size for validation.
|
||||
if (srcMemory->parent() && (srcMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) {
|
||||
srcMemory = srcMemory->parent();
|
||||
}
|
||||
}
|
||||
|
||||
// If the mem object is a VMM sub buffer (subbuffer has parent set),
|
||||
// then use parent's size for validation.
|
||||
if (dstMemory && dstMemory->parent() && (dstMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) {
|
||||
if (dstMemory != nullptr) {
|
||||
// Validate Mem Access in case of VMM Memory
|
||||
if (!dstMemory->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], true)) {
|
||||
return hipErrorUnknown;
|
||||
}
|
||||
|
||||
// Size validation
|
||||
if (sizeBytes > (dstMemory->getSize() - dOffset)) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
// If the mem object is a VMM sub buffer (subbuffer has parent set),
|
||||
// then use parent's size for validation.
|
||||
if (dstMemory->parent() && (dstMemory->getMemFlags() & CL_MEM_VA_RANGE_AMD)) {
|
||||
dstMemory = dstMemory->parent();
|
||||
}
|
||||
|
||||
// Return error if sizeBytes passed to memcpy is more than the actual size allocated
|
||||
if ((dstMemory && sizeBytes > (dstMemory->getSize() - dOffset)) ||
|
||||
(srcMemory && sizeBytes > (srcMemory->getSize() - sOffset))) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
}
|
||||
|
||||
//If src and dst ptr are null then kind must be either h2h or def.
|
||||
@@ -3185,6 +3203,11 @@ hipError_t ihipMemset_validate(void* dst, int64_t value, size_t valueSize,
|
||||
memory = memory->parent();
|
||||
}
|
||||
|
||||
// Validate Mem Access in case of VMM Memory
|
||||
if (!memory->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], true)) {
|
||||
return hipErrorUnknown;
|
||||
}
|
||||
|
||||
// Return error if sizeBytes passed to memcpy is more than the actual size allocated
|
||||
if (sizeBytes > (memory->getSize() - offset)){
|
||||
return hipErrorInvalidValue;
|
||||
@@ -3212,6 +3235,11 @@ hipError_t ihipGraphMemsetParams_validate(const hipMemsetParams* pNodeParams) {
|
||||
size_t discardOffset = 0;
|
||||
amd::Memory *memObj = getMemoryObject(pNodeParams->dst, discardOffset);
|
||||
if (memObj != nullptr) {
|
||||
// Validate Mem Access in case of VMM Memory
|
||||
if (!memObj->ValidateMemAccess(*hip::getCurrentDevice()->devices()[0], true)) {
|
||||
return hipErrorUnknown;
|
||||
}
|
||||
|
||||
if ((pNodeParams->pitch * pNodeParams->height) > memObj->getSize()) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
@@ -1630,7 +1630,6 @@ class Device : public RuntimeObject {
|
||||
enum class VmmAccess {
|
||||
kNone = 0x0,
|
||||
kReadOnly = 0x1,
|
||||
kWriteOnly = 0x2,
|
||||
kReadWrite = 0x3
|
||||
};
|
||||
|
||||
@@ -1883,7 +1882,15 @@ class Device : public RuntimeObject {
|
||||
* @param va_addr Virtual Address ptr
|
||||
* @param access_flags_ptr Access permissions to be filled
|
||||
*/
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) = 0;
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const = 0;
|
||||
|
||||
/**
|
||||
* Validate Access permisions for a virtual memory object.
|
||||
*
|
||||
* @param va_addr Virtual Address ptr
|
||||
* @param access_flags_ptr Access permissions to be filled
|
||||
*/
|
||||
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const = 0;
|
||||
|
||||
/**
|
||||
* Free a VA range
|
||||
|
||||
@@ -2511,34 +2511,50 @@ void Device::virtualFree(void* addr) {
|
||||
// ================================================================================================
|
||||
bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags) {
|
||||
|
||||
amd::Memory* phys_mem_obj = amd::MemObjMap::FindMemObj(va_addr);
|
||||
if (phys_mem_obj == nullptr) {
|
||||
// If the phys_mem_obj is null, the check if this is a valid va_addr, but not-mapped,
|
||||
amd::Memory* amd_mem_obj = amd::MemObjMap::FindMemObj(va_addr);
|
||||
if (amd_mem_obj == nullptr) {
|
||||
// If the amd_mem_obj is null, the check if this is a valid va_addr, but not-mapped,
|
||||
// if not-mapped then print a different error message. (No functional change due to this check).
|
||||
amd::Memory* vaddr_mem_obj = amd::MemObjMap::FindVirtualMemObj(va_addr);
|
||||
if (vaddr_mem_obj == nullptr) {
|
||||
amd_mem_obj = amd::MemObjMap::FindVirtualMemObj(va_addr);
|
||||
if (amd_mem_obj == nullptr) {
|
||||
LogPrintfError("Cannot find virtual address: 0x%x \n", va_addr);
|
||||
return false;
|
||||
}
|
||||
LogPrintfError("Virtual address present, but not mapped yet: 0x%x \n", va_addr);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for valid size.
|
||||
if (va_size > phys_mem_obj->getSize()) {
|
||||
if (va_size > amd_mem_obj->getSize()) {
|
||||
LogPrintfError("Given size: %u cannot be greater than mem_size: %u \n", va_size,
|
||||
phys_mem_obj->getSize());
|
||||
amd_mem_obj->getSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
device::Memory* phys_dev_mem = phys_mem_obj->getDeviceMemory(*this);
|
||||
phys_dev_mem->SetAccess(static_cast<device::Memory::MemAccess>(access_flags));
|
||||
device::Memory* dev_mem_obj = amd_mem_obj->getDeviceMemory(*this);
|
||||
dev_mem_obj->SetAccess(static_cast<device::Memory::MemAccess>(access_flags));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// ================================================================================================
|
||||
bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) {
|
||||
bool Device::ValidateMemAccess(amd::Memory& amd_mem_obj, bool read_write) const {
|
||||
|
||||
device::Memory* dev_mem = amd_mem_obj.getDeviceMemory(*this);
|
||||
device::Memory::MemAccess mem_access = dev_mem->GetAccess();
|
||||
|
||||
// If read_write flag is set, then only read_write is valid, else it could be a read or write.
|
||||
if (read_write && mem_access != device::Memory::MemAccess::kMemAccessReadWrite) {
|
||||
return false;
|
||||
} else if ((mem_access != device::Memory::MemAccess::kMemAccessRead)
|
||||
&& (mem_access != device::Memory::MemAccess::kMemAccessReadWrite)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// ================================================================================================
|
||||
bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const {
|
||||
|
||||
amd::Memory* phys_mem_obj = amd::MemObjMap::FindMemObj(va_addr);
|
||||
if (phys_mem_obj == nullptr) {
|
||||
@@ -2549,7 +2565,7 @@ bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) {
|
||||
LogPrintfError("Cannot find virtual address: 0x%x \n", va_addr);
|
||||
return false;
|
||||
}
|
||||
LogPrintfError("Virtual address present, but not mapped yet: 0x%x \n", va_addr);
|
||||
LogPrintfInfo("Virtual address present, but not mapped yet: 0x%x \n", va_addr);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -151,7 +151,11 @@ class NullDevice : public amd::Device {
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) {
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -561,7 +565,8 @@ class Device : public NullDevice {
|
||||
|
||||
//! Set/Get memory access set by the app
|
||||
virtual bool SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags);
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr);
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const;
|
||||
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const;
|
||||
|
||||
virtual bool ExportShareableVMMHandle(amd::Memory& amd_mem_obj, int flags, void* shareableHandle);
|
||||
|
||||
|
||||
@@ -3517,6 +3517,12 @@ bool VirtualGPU::processMemObjectsHSA(const amd::Kernel& kernel, const_address p
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
|
||||
// Validate Mem Access in case of VMM Memory
|
||||
if (!memory->ValidateMemAccess(dev(), true)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Memory* gpuMemory = dev().getGpuMemory(memory);
|
||||
if (nullptr != gpuMemory) {
|
||||
// Synchronize data with other memory instances if necessary
|
||||
|
||||
@@ -71,9 +71,6 @@ static_assert(static_cast<uint32_t>(amd::Device::VmmAccess::kNone)
|
||||
static_assert(static_cast<uint32_t>(amd::Device::VmmAccess::kReadOnly)
|
||||
== static_cast<uint32_t>(HSA_ACCESS_PERMISSION_RO),
|
||||
"Vmm Access Flag Read mismatch with ROCr-runtime!");
|
||||
static_assert(static_cast<uint32_t>(amd::Device::VmmAccess::kWriteOnly)
|
||||
== static_cast<uint32_t>(HSA_ACCESS_PERMISSION_WO),
|
||||
"Vmm Access Flag Write mismatch with ROC-runtime!");
|
||||
static_assert(static_cast<uint32_t>(amd::Device::VmmAccess::kReadWrite)
|
||||
== static_cast<uint32_t>(HSA_ACCESS_PERMISSION_RW),
|
||||
"Vmm Access Flag Read Write mismatch with ROC-runtime!");
|
||||
@@ -2511,7 +2508,7 @@ bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags)
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) {
|
||||
bool Device::GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const {
|
||||
hsa_status_t hsa_status = HSA_STATUS_SUCCESS;
|
||||
hsa_access_permission_t perms;
|
||||
|
||||
|
||||
@@ -240,11 +240,16 @@ class NullDevice : public amd::Device {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) override {
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const {
|
||||
ShouldNotReachHere();
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) const {
|
||||
ShouldNotReachHere();
|
||||
return true;
|
||||
}
|
||||
|
||||
//! Determine if we can use device memory for SVM
|
||||
const bool forceFineGrain(amd::Memory* memory) const {
|
||||
return (memory->getContext().devices().size() > 1);
|
||||
@@ -478,7 +483,8 @@ class Device : public NullDevice {
|
||||
virtual void virtualFree(void* addr);
|
||||
|
||||
virtual bool SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags);
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr);
|
||||
virtual bool GetMemAccess(void* va_addr, VmmAccess* access_flags_ptr) const;
|
||||
virtual bool ValidateMemAccess(amd::Memory& mem, bool read_write) { return true; }
|
||||
|
||||
virtual bool ExportShareableVMMHandle(amd::Memory& amd_mem_obj, int flags, void* shareableHandle);
|
||||
|
||||
|
||||
@@ -545,6 +545,14 @@ Device* Memory::GetDeviceById() {
|
||||
return getContext().devices()[device_idx];
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
bool Memory::ValidateMemAccess(const Device& dev, bool read_write) {
|
||||
if (flags_ & CL_MEM_VA_RANGE_AMD) {
|
||||
return dev.ValidateMemAccess(*this, read_write);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Buffer::initDeviceMemory() {
|
||||
deviceMemories_ = reinterpret_cast<DeviceMemory*>(reinterpret_cast<char*>(this) + sizeof(Buffer));
|
||||
memset(deviceMemories_, 0, NumDevicesWithP2P() * sizeof(DeviceMemory));
|
||||
|
||||
@@ -417,6 +417,9 @@ class Memory : public amd::RuntimeObject {
|
||||
|
||||
//! get device by id when glb ctx is used.
|
||||
Device* GetDeviceById();
|
||||
|
||||
//! Validate memory access for vmm memory
|
||||
bool ValidateMemAccess(const Device& dev, bool read_write);
|
||||
};
|
||||
|
||||
//! Buffers are a specialization of memory. Just a wrapper, really,
|
||||
|
||||
Reference in New Issue
Block a user