SWDEV-568260 - Validate sub-buffer coverage in hipMemSetAccess (#2451)

This commit is contained in:
marandje
2026-01-26 23:09:46 +01:00
committato da GitHub
parent 1255ba2bcc
commit 5cda2a496e
3 file modificati con 181 aggiunte e 63 eliminazioni
+44 -51
Vedi File
@@ -354,6 +354,38 @@ hipError_t hipMemRetainAllocationHandle(hipMemGenericAllocationHandle_t* handle,
HIP_RETURN(hipSuccess);
}
// Returns the address immediately past the end of `mem`: its SVM pointer advanced by its size.
// `mem` is dereferenced unconditionally, so callers must pass a non-null object.
static inline address NextSubBufferPtr(const amd::Memory* mem) {
  const address range_begin = reinterpret_cast<address>(mem->getSvmPtr());
  return range_begin + mem->getSize();
}
// Verifies that `range_size` bytes, counted from the SVM pointer of `vaddr_sub_buffer_obj`, are
// covered exactly by a run of memory objects: each next object is looked up in amd::MemObjMap at
// the end address of the previous one.
//
// Returns hipErrorInvalidValue when:
//   - `vaddr_sub_buffer_obj` is null,
//   - the object has a parent and `range_size` exceeds the parent's size minus this object's
//     origin,
//   - the range would end part-way through an object, or
//   - the sizes of the visited objects do not add up to `range_size`.
// Returns hipSuccess otherwise.
static hipError_t ValidateSubBufferCoverage(amd::Memory* vaddr_sub_buffer_obj, size_t range_size) {
  if (vaddr_sub_buffer_obj == nullptr) {
    return hipErrorInvalidValue;
  }

  // The requested range may not extend past what remains of the parent after this object's origin.
  const auto* parent_obj = vaddr_sub_buffer_obj->parent();
  if (parent_obj != nullptr) {
    const auto remaining_in_parent = parent_obj->getSize() - vaddr_sub_buffer_obj->getOrigin();
    if (range_size > remaining_in_parent) {
      return hipErrorInvalidValue;
    }
  }

  const address range_limit =
      reinterpret_cast<address>(vaddr_sub_buffer_obj->getSvmPtr()) + range_size;

  // Walk object by object while each one still ends at or before the requested limit.
  size_t accumulated = 0;
  for (amd::Memory* cursor = vaddr_sub_buffer_obj; cursor != nullptr;) {
    const address cursor_end = NextSubBufferPtr(cursor);
    if (cursor_end > range_limit) {
      break;
    }

    const size_t cursor_size = cursor->getSize();
    // Reject a range that would stop strictly inside the current object.
    if (accumulated < range_size && range_size < accumulated + cursor_size) {
      return hipErrorInvalidValue;
    }

    accumulated += cursor_size;
    cursor = amd::MemObjMap::FindMemObj(cursor_end);
  }

  // Only an exact match between requested and accumulated size counts as full coverage.
  return (accumulated == range_size) ? hipSuccess : hipErrorInvalidValue;
}
hipError_t hipMemSetAccess(void* ptr, size_t size, const hipMemAccessDesc* desc, size_t count) {
HIP_INIT_API(hipMemSetAccess, ptr, size, desc, count);
@@ -361,30 +393,12 @@ hipError_t hipMemSetAccess(void* ptr, size_t size, const hipMemAccessDesc* desc,
HIP_RETURN(hipErrorInvalidValue);
}
// Ensure that the specified size parameter matches the total size of a complete set of
// sub-buffers, disallowing partial sub-buffer coverage
auto mem_object = amd::MemObjMap::FindMemObj(ptr);
hipMemLocationType memLocationType = hipMemLocationTypeNone;
if (mem_object) {
memLocationType = static_cast<hipMemLocationType>(mem_object->getUserData().locationType);
if (mem_object->parent()) {
size_t accumulated_buffer_size = 0;
for (auto sub_buffer : mem_object->parent()->subBuffers()) {
accumulated_buffer_size += sub_buffer->getSize();
if (accumulated_buffer_size > size) {
HIP_RETURN(hipErrorInvalidValue);
} else if (accumulated_buffer_size == size) {
break;
}
}
if (accumulated_buffer_size != size) {
HIP_RETURN(hipErrorInvalidValue);
}
}
} else {
HIP_RETURN(hipErrorInvalidValue);
// Ensure that the specified size parameter matches the sum of a complete set of
// sub-buffers in the range, disallowing partial sub-buffer coverage.
amd::Memory* vaddr_sub_obj = amd::MemObjMap::FindMemObj(ptr);
hipError_t status = ValidateSubBufferCoverage(vaddr_sub_obj, size);
if (status != hipSuccess) {
HIP_RETURN(status);
}
for (size_t desc_idx = 0; desc_idx < count; ++desc_idx) {
@@ -421,36 +435,15 @@ hipError_t hipMemUnmap(void* ptr, size_t size) {
HIP_RETURN(hipErrorInvalidValue);
}
// Helper lambda to get the next sub-buffer pointer
auto next_subbuffer_ptr = [](const amd::Memory* mem) -> address {
return reinterpret_cast<address>(mem->getSvmPtr()) + mem->getSize();
};
amd::Memory* vaddr_sub_obj = amd::MemObjMap::FindMemObj(ptr);
// Validate that the size is within range
if (vaddr_sub_obj == nullptr ||
(vaddr_sub_obj->parent() != nullptr &&
size > (vaddr_sub_obj->parent()->getSize() - vaddr_sub_obj->getOrigin()))) {
HIP_RETURN(hipErrorInvalidValue);
}
address end_address = reinterpret_cast<address>(vaddr_sub_obj->getSvmPtr()) + size;
size_t total_processed_size = 0;
amd::Memory* check_obj = vaddr_sub_obj;
// Validate that the size matches the sum of sub-buffer sizes
while (check_obj && next_subbuffer_ptr(check_obj) <= end_address) {
if (size > total_processed_size && size < total_processed_size + check_obj->getSize()) {
HIP_RETURN(hipErrorInvalidValue);
}
total_processed_size += check_obj->getSize();
check_obj = amd::MemObjMap::FindMemObj(next_subbuffer_ptr(check_obj));
}
if (total_processed_size != size) {
HIP_RETURN(hipErrorInvalidValue);
hipError_t status = ValidateSubBufferCoverage(vaddr_sub_obj, size);
if (status != hipSuccess) {
HIP_RETURN(status);
}
// Unmap all sub-buffers in the range
while (vaddr_sub_obj && next_subbuffer_ptr(vaddr_sub_obj) <= end_address) {
address end_address = reinterpret_cast<address>(vaddr_sub_obj->getSvmPtr()) + size;
while (vaddr_sub_obj && NextSubBufferPtr(vaddr_sub_obj) <= end_address) {
amd::Memory* phys_mem_obj = vaddr_sub_obj->getUserData().phys_mem_obj;
if (phys_mem_obj == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
@@ -467,7 +460,7 @@ hipError_t hipMemUnmap(void* ptr, size_t size) {
reinterpret_cast<hip::GenericAllocation*>(phys_mem_obj->getUserData().data);
ga->release();
address next_ptr = next_subbuffer_ptr(vaddr_sub_obj);
address next_ptr = NextSubBufferPtr(vaddr_sub_obj);
vaddr_sub_obj->release();
vaddr_sub_obj = amd::MemObjMap::FindMemObj(next_ptr);
}
@@ -2490,6 +2490,10 @@ bool Device::virtualFree(void* addr) {
return true;
}
// Returns the address immediately past the end of `mem`: its SVM pointer advanced by its size.
// `mem` is dereferenced unconditionally, so callers must pass a non-null object.
static inline address NextSubBufferPtr(const amd::Memory* mem) {
  const address range_begin = reinterpret_cast<address>(mem->getSvmPtr());
  return range_begin + mem->getSize();
}
// ================================================================================================
bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags,
VmmLocationType access_location) {
@@ -2505,16 +2509,13 @@ bool Device::SetMemAccess(void* va_addr, size_t va_size, VmmAccess access_flags,
LogPrintfError("Virtual address present, but not mapped yet: 0x%x \n", va_addr);
}
// Check for valid size.
if (va_size > amd_mem_obj->getSize()) {
LogPrintfError("Given size: %u cannot be greater than mem_size: %u \n", va_size,
amd_mem_obj->getSize());
return false;
address range_end_address = reinterpret_cast<address>(amd_mem_obj->getSvmPtr()) + va_size;
while (amd_mem_obj && NextSubBufferPtr(amd_mem_obj) <= range_end_address) {
device::Memory* dev_mem_obj = amd_mem_obj->getDeviceMemory(*this);
dev_mem_obj->SetAccess(static_cast<device::Memory::MemAccess>(access_flags));
amd_mem_obj = amd::MemObjMap::FindMemObj(NextSubBufferPtr(amd_mem_obj));
}
device::Memory* dev_mem_obj = amd_mem_obj->getDeviceMemory(*this);
dev_mem_obj->SetAccess(static_cast<device::Memory::MemAccess>(access_flags));
return true;
}
@@ -61,8 +61,8 @@ static __global__ void square_kernel(int* Buff) {
// Simple HIP kernel: read from host-backed memory and write to a device buffer
__global__ void copyFromHostMem(const int* hostMem, int* devOut, int N) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < N) devOut[i] = hostMem[i];
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < N) devOut[i] = hostMem[i];
}
/**
@@ -490,6 +490,130 @@ TEST_CASE("Unit_hipMemSetAccess_ChangeAccessProp") {
CTX_DESTROY();
}
/**
* Test Description
* ------------------------
* - Create a VA range split into 3 segments. Map all of them.
* - Verify hipMemSetAccess() works when called on:
* - a single segment (3 calls: segment 0, segment 1, segment 2)
* - two segments (2 calls: segments 0-1, then segments 1-2)
* - the full range (1 call: segments 0-2)
* ------------------------
*/
TEST_CASE("Unit_hipMemSetAccess_SegmentsAccess") {
  // These three declarations keep their original names and order ahead of the context/VMM macros.
  size_t granularity = 0;
  int deviceId = 0;
  hipDevice_t device;
  CTX_CREATE();

  HIP_CHECK(hipDeviceGet(&device, deviceId));
  checkVMMSupported(device);

  hipMemAllocationProp prop{};
  prop.type = hipMemAllocationTypePinned;
  prop.location.type = hipMemLocationTypeDevice;
  prop.location.id = device;

  HIP_CHECK(
      hipMemGetAllocationGranularity(&granularity, &prop, hipMemAllocationGranularityMinimum));
  REQUIRE(granularity > 0);

  // Three back-to-back segments of 1x, 2x and 3x the minimum granularity.
  constexpr int kSegmentCount = 3;
  const size_t segment_sizes[kSegmentCount] = {granularity, granularity * 2, granularity * 3};
  const size_t total_size = segment_sizes[0] + segment_sizes[1] + segment_sizes[2];

  void* base = nullptr;
  HIP_CHECK(hipMemAddressReserve(&base, total_size, 0, 0, 0));

  // Start address of each segment inside the reserved range.
  void* segments[kSegmentCount] = {};
  {
    char* cursor = reinterpret_cast<char*>(base);
    for (int idx = 0; idx < kSegmentCount; ++idx) {
      segments[idx] = cursor;
      cursor += segment_sizes[idx];
    }
  }

  // Create all physical allocations, then map all of them, then drop the handles.
  hipMemGenericAllocationHandle_t handles[kSegmentCount] = {};
  for (int idx = 0; idx < kSegmentCount; ++idx) {
    HIP_CHECK(hipMemCreate(&handles[idx], segment_sizes[idx], &prop, 0));
  }
  for (int idx = 0; idx < kSegmentCount; ++idx) {
    HIP_CHECK(hipMemMap(segments[idx], segment_sizes[idx], 0, handles[idx], 0));
  }
  for (int idx = 0; idx < kSegmentCount; ++idx) {
    HIP_CHECK(hipMemRelease(handles[idx]));
  }

  hipMemAccessDesc rw{};
  rw.location.type = hipMemLocationTypeDevice;
  rw.location.id = device;
  rw.flags = hipMemAccessFlagsProtReadWrite;

  hipMemLocation location{};
  location.type = hipMemLocationTypeDevice;
  location.id = device;

  // Queries the access flags of `segment` from a zeroed value and requires read-write.
  const auto require_read_write = [&location](void* segment) {
    unsigned long long flags = 0;
    HIP_CHECK(hipMemGetAccess(&flags, &location, segment));
    REQUIRE(flags == hipMemAccessFlagsProtReadWrite);
  };

  SECTION("Single segment access") {
    for (int idx = 0; idx < kSegmentCount; ++idx) {
      HIP_CHECK(hipMemSetAccess(segments[idx], segment_sizes[idx], &rw, 1));
      require_read_write(segments[idx]);
    }
  }

  SECTION("Two segments access") {
    // First call targets segments 0 and 1.
    HIP_CHECK(hipMemSetAccess(segments[0], segment_sizes[0] + segment_sizes[1], &rw, 1));
    require_read_write(segments[0]);
    require_read_write(segments[1]);

    // Second call targets segments 1 and 2; afterwards every segment must report read-write.
    HIP_CHECK(hipMemSetAccess(segments[1], segment_sizes[1] + segment_sizes[2], &rw, 1));
    for (void* segment : segments) {
      require_read_write(segment);
    }
  }

  SECTION("All three segments access") {
    HIP_CHECK(hipMemSetAccess(base, total_size, &rw, 1));
    for (void* segment : segments) {
      require_read_write(segment);
    }
  }

  // Tear down: unmap each segment individually, then free the whole reservation.
  for (int idx = 0; idx < kSegmentCount; ++idx) {
    HIP_CHECK(hipMemUnmap(segments[idx], segment_sizes[idx]));
  }
  HIP_CHECK(hipMemAddressFree(base, total_size));
  CTX_DESTROY();
}
/**
* Test Description
* ------------------------
@@ -548,7 +672,7 @@ TEST_CASE("Unit_hipMemSetAccess_Vmm2UnifiedMemCpy") {
HIP_CHECK(hipMemcpyHtoD(reinterpret_cast<hipDeviceptr_t>(ptrA), ptrA_h, buffer_size));
HIP_CHECK(hipMalloc(reinterpret_cast<void**>(&ptrB), buffer_size));
HIP_CHECK(hipMemcpyDtoD(reinterpret_cast<hipDeviceptr_t>(ptrB),
reinterpret_cast<hipDeviceptr_t>(ptrA), buffer_size));
reinterpret_cast<hipDeviceptr_t>(ptrA), buffer_size));
HIP_CHECK(hipMemcpyDtoH(ptrB_h, reinterpret_cast<hipDeviceptr_t>(ptrB), buffer_size));
bool bPassed = true;
for (int idx = 0; idx < N; idx++) {
@@ -1474,7 +1598,7 @@ TEST_CASE("Unit_hipMemSetAccessHost_devicealloc") {
constexpr size_t N = 1024;
constexpr size_t bytes = N * sizeof(int);
//get minimum granularity
// get minimum granularity
size_t gran = 0;
HIP_CHECK(hipMemGetAllocationGranularity(&gran, &prop, hipMemAllocationGranularityMinimum));
size_t mapSize = ((bytes + gran - 1) / gran) * gran;