SWDEV-568260 - Validate sub-buffer coverage in hipMemSetAccess (#2451)
Tá an tiomantas seo le fáil i:
@@ -354,6 +354,38 @@ hipError_t hipMemRetainAllocationHandle(hipMemGenericAllocationHandle_t* handle,
|
||||
HIP_RETURN(hipSuccess);
|
||||
}
|
||||
|
||||
static inline address NextSubBufferPtr(const amd::Memory* mem) {
|
||||
return reinterpret_cast<address>(mem->getSvmPtr()) + mem->getSize();
|
||||
}
|
||||
|
||||
static hipError_t ValidateSubBufferCoverage(amd::Memory* vaddr_sub_buffer_obj, size_t range_size) {
|
||||
// Validate that the requested range size is within the parent sub-buffer bounds.
|
||||
if (vaddr_sub_buffer_obj == nullptr || (vaddr_sub_buffer_obj->parent() != nullptr &&
|
||||
range_size > (vaddr_sub_buffer_obj->parent()->getSize() -
|
||||
vaddr_sub_buffer_obj->getOrigin()))) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
address range_end_address =
|
||||
reinterpret_cast<address>(vaddr_sub_buffer_obj->getSvmPtr()) + range_size;
|
||||
size_t covered_size = 0;
|
||||
amd::Memory* current_sub_buffer_obj = vaddr_sub_buffer_obj;
|
||||
// Validate that the size matches the sum of sub-buffer sizes
|
||||
while (current_sub_buffer_obj && NextSubBufferPtr(current_sub_buffer_obj) <= range_end_address) {
|
||||
if (range_size > covered_size &&
|
||||
range_size < covered_size + current_sub_buffer_obj->getSize()) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
covered_size += current_sub_buffer_obj->getSize();
|
||||
current_sub_buffer_obj = amd::MemObjMap::FindMemObj(NextSubBufferPtr(current_sub_buffer_obj));
|
||||
}
|
||||
if (covered_size != range_size) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
hipError_t hipMemSetAccess(void* ptr, size_t size, const hipMemAccessDesc* desc, size_t count) {
|
||||
HIP_INIT_API(hipMemSetAccess, ptr, size, desc, count);
|
||||
|
||||
@@ -361,30 +393,12 @@ hipError_t hipMemSetAccess(void* ptr, size_t size, const hipMemAccessDesc* desc,
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
|
||||
// Ensure that the specified size parameter matches the total size of a complete set of
|
||||
// sub-buffers, disallowing partial sub-buffer coverage
|
||||
auto mem_object = amd::MemObjMap::FindMemObj(ptr);
|
||||
hipMemLocationType memLocationType = hipMemLocationTypeNone;
|
||||
|
||||
if (mem_object) {
|
||||
memLocationType = static_cast<hipMemLocationType>(mem_object->getUserData().locationType);
|
||||
if (mem_object->parent()) {
|
||||
size_t accumulated_buffer_size = 0;
|
||||
for (auto sub_buffer : mem_object->parent()->subBuffers()) {
|
||||
accumulated_buffer_size += sub_buffer->getSize();
|
||||
if (accumulated_buffer_size > size) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
} else if (accumulated_buffer_size == size) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (accumulated_buffer_size != size) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
// Ensure that the specified size parameter matches the sum of a complete set of
|
||||
// sub-buffers in the range, disallowing partial sub-buffer coverage.
|
||||
amd::Memory* vaddr_sub_obj = amd::MemObjMap::FindMemObj(ptr);
|
||||
hipError_t status = ValidateSubBufferCoverage(vaddr_sub_obj, size);
|
||||
if (status != hipSuccess) {
|
||||
HIP_RETURN(status);
|
||||
}
|
||||
|
||||
for (size_t desc_idx = 0; desc_idx < count; ++desc_idx) {
|
||||
@@ -421,36 +435,15 @@ hipError_t hipMemUnmap(void* ptr, size_t size) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
|
||||
// Helper lambda to get the next sub-buffer pointer
|
||||
auto next_subbuffer_ptr = [](const amd::Memory* mem) -> address {
|
||||
return reinterpret_cast<address>(mem->getSvmPtr()) + mem->getSize();
|
||||
};
|
||||
|
||||
amd::Memory* vaddr_sub_obj = amd::MemObjMap::FindMemObj(ptr);
|
||||
// Validate that the size is within range
|
||||
if (vaddr_sub_obj == nullptr ||
|
||||
(vaddr_sub_obj->parent() != nullptr &&
|
||||
size > (vaddr_sub_obj->parent()->getSize() - vaddr_sub_obj->getOrigin()))) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
|
||||
address end_address = reinterpret_cast<address>(vaddr_sub_obj->getSvmPtr()) + size;
|
||||
size_t total_processed_size = 0;
|
||||
amd::Memory* check_obj = vaddr_sub_obj;
|
||||
// Validate that the size matches the sum of sub-buffer sizes
|
||||
while (check_obj && next_subbuffer_ptr(check_obj) <= end_address) {
|
||||
if (size > total_processed_size && size < total_processed_size + check_obj->getSize()) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
total_processed_size += check_obj->getSize();
|
||||
check_obj = amd::MemObjMap::FindMemObj(next_subbuffer_ptr(check_obj));
|
||||
}
|
||||
if (total_processed_size != size) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
hipError_t status = ValidateSubBufferCoverage(vaddr_sub_obj, size);
|
||||
if (status != hipSuccess) {
|
||||
HIP_RETURN(status);
|
||||
}
|
||||
|
||||
// Unmap all sub-buffers in the range
|
||||
while (vaddr_sub_obj && next_subbuffer_ptr(vaddr_sub_obj) <= end_address) {
|
||||
address end_address = reinterpret_cast<address>(vaddr_sub_obj->getSvmPtr()) + size;
|
||||
while (vaddr_sub_obj && NextSubBufferPtr(vaddr_sub_obj) <= end_address) {
|
||||
amd::Memory* phys_mem_obj = vaddr_sub_obj->getUserData().phys_mem_obj;
|
||||
if (phys_mem_obj == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
@@ -467,7 +460,7 @@ hipError_t hipMemUnmap(void* ptr, size_t size) {
|
||||
reinterpret_cast<hip::GenericAllocation*>(phys_mem_obj->getUserData().data);
|
||||
ga->release();
|
||||
|
||||
address next_ptr = next_subbuffer_ptr(vaddr_sub_obj);
|
||||
address next_ptr = NextSubBufferPtr(vaddr_sub_obj);
|
||||
vaddr_sub_obj->release();
|
||||
vaddr_sub_obj = amd::MemObjMap::FindMemObj(next_ptr);
|
||||
}
|
||||
|
||||
@@ -2490,6 +2490,10 @@ bool Device::virtualFree(void* addr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline address NextSubBufferPtr(const amd::Memory* mem) {
|
||||
return reinterpret_cast<address>(mem->getSvmPtr()) + mem->getSize();
|
||||
}
|
||||
|
||||
// ================================================================================================
|
||||
bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags,
|
||||
VmmLocationType access_location) {
|
||||
@@ -2505,16 +2509,13 @@ bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags,
|
||||
LogPrintfError("Virtual address present, but not mapped yet: 0x%x \n", va_addr);
|
||||
}
|
||||
|
||||
// Check for valid size.
|
||||
if (va_size > amd_mem_obj->getSize()) {
|
||||
LogPrintfError("Given size: %u cannot be greater than mem_size: %u \n", va_size,
|
||||
amd_mem_obj->getSize());
|
||||
return false;
|
||||
address range_end_address = reinterpret_cast<address>(amd_mem_obj->getSvmPtr()) + va_size;
|
||||
while (amd_mem_obj && NextSubBufferPtr(amd_mem_obj) <= range_end_address) {
|
||||
device::Memory* dev_mem_obj = amd_mem_obj->getDeviceMemory(*this);
|
||||
dev_mem_obj->SetAccess(static_cast<device::Memory::MemAccess>(access_flags));
|
||||
amd_mem_obj = amd::MemObjMap::FindMemObj(NextSubBufferPtr(amd_mem_obj));
|
||||
}
|
||||
|
||||
device::Memory* dev_mem_obj = amd_mem_obj->getDeviceMemory(*this);
|
||||
dev_mem_obj->SetAccess(static_cast<device::Memory::MemAccess>(access_flags));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
+128
-4
@@ -61,8 +61,8 @@ static __global__ void square_kernel(int* Buff) {
|
||||
|
||||
// Simple HIP kernel: read from host-backed memory and write to a device buffer
|
||||
__global__ void copyFromHostMem(const int* hostMem, int* devOut, int N) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < N) devOut[i] = hostMem[i];
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < N) devOut[i] = hostMem[i];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -490,6 +490,130 @@ TEST_CASE("Unit_hipMemSetAccess_ChangeAccessProp") {
|
||||
CTX_DESTROY();
|
||||
}
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
* - Create a VA range split into 3 segments. Map all of them.
|
||||
* - Verify hipMemSetAccess() works when called on:
|
||||
* - a single segment (3 calls: segment 0, segment 1, segment 2)
|
||||
* - two segments (2 calls: segments 0-1, then segments 1-2)
|
||||
* - the full range (1 call: segments 0-2)
|
||||
* ------------------------
|
||||
*/
|
||||
TEST_CASE("Unit_hipMemSetAccess_SegmentsAccess") {
|
||||
size_t granularity = 0;
|
||||
int deviceId = 0;
|
||||
hipDevice_t device;
|
||||
CTX_CREATE();
|
||||
HIP_CHECK(hipDeviceGet(&device, deviceId));
|
||||
checkVMMSupported(device);
|
||||
|
||||
hipMemAllocationProp prop{};
|
||||
prop.type = hipMemAllocationTypePinned;
|
||||
prop.location.type = hipMemLocationTypeDevice;
|
||||
prop.location.id = device;
|
||||
HIP_CHECK(
|
||||
hipMemGetAllocationGranularity(&granularity, &prop, hipMemAllocationGranularityMinimum));
|
||||
REQUIRE(granularity > 0);
|
||||
|
||||
const size_t segment0_size = granularity;
|
||||
const size_t segment1_size = granularity * 2;
|
||||
const size_t segment2_size = granularity * 3;
|
||||
const size_t total_size = segment0_size + segment1_size + segment2_size;
|
||||
|
||||
void* base = nullptr;
|
||||
HIP_CHECK(hipMemAddressReserve(&base, total_size, 0, 0, 0));
|
||||
|
||||
auto* base_c = reinterpret_cast<char*>(base);
|
||||
void* segment_0 = base_c;
|
||||
void* segment_1 = base_c + segment0_size;
|
||||
void* segment_2 = base_c + segment0_size + segment1_size;
|
||||
|
||||
hipMemGenericAllocationHandle_t handle_0{};
|
||||
hipMemGenericAllocationHandle_t handle_1{};
|
||||
hipMemGenericAllocationHandle_t handle_2{};
|
||||
HIP_CHECK(hipMemCreate(&handle_0, segment0_size, &prop, 0));
|
||||
HIP_CHECK(hipMemCreate(&handle_1, segment1_size, &prop, 0));
|
||||
HIP_CHECK(hipMemCreate(&handle_2, segment2_size, &prop, 0));
|
||||
|
||||
HIP_CHECK(hipMemMap(segment_0, segment0_size, 0, handle_0, 0));
|
||||
HIP_CHECK(hipMemMap(segment_1, segment1_size, 0, handle_1, 0));
|
||||
HIP_CHECK(hipMemMap(segment_2, segment2_size, 0, handle_2, 0));
|
||||
|
||||
HIP_CHECK(hipMemRelease(handle_0));
|
||||
HIP_CHECK(hipMemRelease(handle_1));
|
||||
HIP_CHECK(hipMemRelease(handle_2));
|
||||
|
||||
hipMemAccessDesc rw{};
|
||||
rw.location.type = hipMemLocationTypeDevice;
|
||||
rw.location.id = device;
|
||||
rw.flags = hipMemAccessFlagsProtReadWrite;
|
||||
|
||||
hipMemLocation location{};
|
||||
location.type = hipMemLocationTypeDevice;
|
||||
location.id = device;
|
||||
|
||||
unsigned long long flags = 0;
|
||||
SECTION("Single segment access") {
|
||||
HIP_CHECK(hipMemSetAccess(segment_0, segment0_size, &rw, 1));
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_0));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
|
||||
flags = 0;
|
||||
HIP_CHECK(hipMemSetAccess(segment_1, segment1_size, &rw, 1));
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_1));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
|
||||
flags = 0;
|
||||
HIP_CHECK(hipMemSetAccess(segment_2, segment2_size, &rw, 1));
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_2));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
}
|
||||
|
||||
SECTION("Two segments access") {
|
||||
// First call targets segments 0 and 1.
|
||||
HIP_CHECK(hipMemSetAccess(segment_0, segment0_size + segment1_size, &rw, 1));
|
||||
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_0));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
flags = 0;
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_1));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
|
||||
// Second call targets segments 1 and 2.
|
||||
HIP_CHECK(hipMemSetAccess(segment_1, segment1_size + segment2_size, &rw, 1));
|
||||
|
||||
flags = 0;
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_0));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
flags = 0;
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_1));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
flags = 0;
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_2));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
}
|
||||
|
||||
SECTION("All three segments access") {
|
||||
HIP_CHECK(hipMemSetAccess(base, total_size, &rw, 1));
|
||||
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_0));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
flags = 0;
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_1));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
flags = 0;
|
||||
HIP_CHECK(hipMemGetAccess(&flags, &location, segment_2));
|
||||
REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
|
||||
}
|
||||
|
||||
HIP_CHECK(hipMemUnmap(segment_0, segment0_size));
|
||||
HIP_CHECK(hipMemUnmap(segment_1, segment1_size));
|
||||
HIP_CHECK(hipMemUnmap(segment_2, segment2_size));
|
||||
HIP_CHECK(hipMemAddressFree(base, total_size));
|
||||
CTX_DESTROY();
|
||||
}
|
||||
|
||||
/**
|
||||
* Test Description
|
||||
* ------------------------
|
||||
@@ -548,7 +672,7 @@ TEST_CASE("Unit_hipMemSetAccess_Vmm2UnifiedMemCpy") {
|
||||
HIP_CHECK(hipMemcpyHtoD(reinterpret_cast<hipDeviceptr_t>(ptrA), ptrA_h, buffer_size));
|
||||
HIP_CHECK(hipMalloc(reinterpret_cast<void**>(&ptrB), buffer_size));
|
||||
HIP_CHECK(hipMemcpyDtoD(reinterpret_cast<hipDeviceptr_t>(ptrB),
|
||||
reinterpret_cast<hipDeviceptr_t>(ptrA), buffer_size));
|
||||
reinterpret_cast<hipDeviceptr_t>(ptrA), buffer_size));
|
||||
HIP_CHECK(hipMemcpyDtoH(ptrB_h, reinterpret_cast<hipDeviceptr_t>(ptrB), buffer_size));
|
||||
bool bPassed = true;
|
||||
for (int idx = 0; idx < N; idx++) {
|
||||
@@ -1474,7 +1598,7 @@ TEST_CASE("Unit_hipMemSetAccessHost_devicealloc") {
|
||||
constexpr size_t N = 1024;
|
||||
constexpr size_t bytes = N * sizeof(int);
|
||||
|
||||
//get minimum granularity
|
||||
// get minimum granularity
|
||||
size_t gran = 0;
|
||||
HIP_CHECK(hipMemGetAllocationGranularity(&gran, &prop, hipMemAllocationGranularityMinimum));
|
||||
size_t mapSize = ((bytes + gran - 1) / gran) * gran;
|
||||
|
||||
Tagairt in Eagrán Nua
Cuir bac ar úsáideoir