diff --git a/projects/rocr-runtime/src/fmm.c b/projects/rocr-runtime/src/fmm.c index 411aa8ff1c..87894c21fc 100644 --- a/projects/rocr-runtime/src/fmm.c +++ b/projects/rocr-runtime/src/fmm.c @@ -322,21 +322,6 @@ static void vm_split_area(manageble_aperture_t *app, vm_area_t *area, } static vm_object_t *vm_find_object_by_address(manageble_aperture_t *app, - const void *address) -{ - vm_object_t *cur = app->vm_objects; - - while (cur) { - if (address >= cur->start && - (uint64_t)address < ((uint64_t)cur->start + cur->size)) - break; - cur = cur->next; - } - - return cur; /* NULL if not found */ -} - -static vm_object_t *vm_find_object_by_start_address(manageble_aperture_t *app, const void *address, uint64_t size) { vm_object_t *cur = app->vm_objects; @@ -353,6 +338,21 @@ static vm_object_t *vm_find_object_by_start_address(manageble_aperture_t *app, return cur; /* NULL if not found */ } +static vm_object_t *vm_find_object_by_address_range(manageble_aperture_t *app, + const void *address) +{ + vm_object_t *cur = app->vm_objects; + + while (cur) { + if (address >= cur->start && + (uint64_t)address < ((uint64_t)cur->start + cur->size)) + break; + cur = cur->next; + } + + return cur; /* NULL if not found */ +} + static vm_object_t *vm_find_object_by_userptr(manageble_aperture_t *app, const void *address) { @@ -363,7 +363,23 @@ static vm_object_t *vm_find_object_by_userptr(manageble_aperture_t *app, if (cur->userptr == address) break; cur = cur->next; - }; + } + + return cur; /* NULL if not found */ +} + +static vm_object_t *vm_find_object_by_userptr_range(manageble_aperture_t *app, + const void *address) +{ + vm_object_t *cur = app->vm_objects; + + /* Look up the appropriate address range containing the given address */ + while (cur) { + if (address >= cur->userptr && + (uint64_t)address < (uint64_t)cur->userptr + cur->userptr_size) + break; + cur = cur->next; + } return cur; /* NULL if not found */ } @@ -1060,7 +1076,7 @@ static void __fmm_release(void *address, manageble_aperture_t *aperture) pthread_mutex_lock(&aperture->fmm_mutex); /* Find the object to retrieve the handle */ - object = vm_find_object_by_start_address(aperture, address, 0); + object = vm_find_object_by_address(aperture, address, 0); if (!object) { pthread_mutex_unlock(&aperture->fmm_mutex); return; @@ -1127,7 +1143,7 @@ void fmm_release(void *address) if (!found) { /* Release the vm object in CPUVM */ pthread_mutex_lock(&cpuvm_aperture.fmm_mutex); - object = vm_find_object_by_start_address(&cpuvm_aperture, address, 0); + object = vm_find_object_by_address(&cpuvm_aperture, address, 0); if (object) vm_remove_object(&cpuvm_aperture, object); pthread_mutex_unlock(&cpuvm_aperture.fmm_mutex); @@ -1407,7 +1423,7 @@ static int _fmm_map_to_gpu_gtt(manageble_aperture_t *aperture, object = obj; if (!object) { /* Find the object to retrieve the handle */ - object = vm_find_object_by_start_address(aperture, address, 0); + object = vm_find_object_by_address(aperture, address, 0); if (!object) goto err_object_not_found; } @@ -1535,7 +1551,7 @@ static int _fmm_map_to_gpu(uint32_t gpu_id, manageble_aperture_t *aperture, pthread_mutex_lock(&aperture->fmm_mutex); /* Find the object to retrieve the handle */ - object = vm_find_object_by_start_address(aperture, address, 0); + object = vm_find_object_by_address(aperture, address, 0); if (!object) goto err_object_not_found; @@ -1702,7 +1718,7 @@ static int _fmm_unmap_from_gpu(manageble_aperture_t *aperture, void *address, /* Find the object to retrieve the handle */ object = obj; if (!object) { - object = vm_find_object_by_start_address(aperture, address, 0); + object = vm_find_object_by_address(aperture, address, 0); if (!object) { ret = -1; goto err; @@ -1774,7 +1790,7 @@ static int _fmm_unmap_from_gpu_scratch(uint32_t gpu_id, pthread_mutex_lock(&aperture->fmm_mutex); /* Find the object to retrieve the handle and size */ - object = vm_find_object_by_start_address(aperture, address, 0); + object = vm_find_object_by_address(aperture, address, 0); if (!object) goto err; @@ -2062,7 +2078,7 @@ bool fmm_get_handle(void *address, uint64_t *handle) pthread_mutex_lock(&aperture->fmm_mutex); /* Find the object to retrieve the handle */ - object = vm_find_object_by_start_address(aperture, address, 0); + object = vm_find_object_by_address(aperture, address, 0); if (object && handle) { *handle = object->handle; found = true; @@ -2151,7 +2167,7 @@ HSAKMT_STATUS fmm_register_memory(void *address, uint64_t size_in_bytes, if (!object) { pthread_mutex_lock(&aperture->fmm_mutex); - object = vm_find_object_by_start_address(aperture, address, 0); + object = vm_find_object_by_address(aperture, address, 0); pthread_mutex_unlock(&aperture->fmm_mutex); } @@ -2331,7 +2347,7 @@ HSAKMT_STATUS fmm_deregister_memory(void *address) pthread_mutex_lock(&aperture->fmm_mutex); - object = vm_find_object_by_start_address(aperture, address, 0); + object = vm_find_object_by_address(aperture, address, 0); if (!object) { pthread_mutex_unlock(&aperture->fmm_mutex); return HSAKMT_STATUS_MEMORY_NOT_REGISTERED; @@ -2404,7 +2420,7 @@ HSAKMT_STATUS fmm_map_to_gpu_nodes(void *address, uint64_t size, if (userptr && is_dgpu) object = vm_find_object_by_userptr(aperture, address); else - object = vm_find_object_by_start_address(aperture, address, 0); + object = vm_find_object_by_address(aperture, address, 0); if (!object) { pthread_mutex_unlock(&aperture->fmm_mutex); @@ -2505,9 +2521,9 @@ HSAKMT_STATUS fmm_get_mem_info(const void *address, HsaPointerInfo *info) aperture = fmm_find_aperture(address); - vm_obj = vm_find_object_by_address(aperture, address); + vm_obj = vm_find_object_by_address_range(aperture, address); if (!vm_obj) - vm_obj = vm_find_object_by_userptr(aperture, address); + vm_obj = vm_find_object_by_userptr_range(aperture, address); if (!vm_obj) { info->Type = HSA_POINTER_UNKNOWN; @@ -2577,7 +2593,7 @@ HSAKMT_STATUS fmm_set_mem_user_data(const void *mem, void *usr_data) aperture = fmm_find_aperture(mem); - vm_obj = vm_find_object_by_start_address(aperture, mem, 0); + vm_obj = vm_find_object_by_address(aperture, mem, 0); if (!vm_obj) vm_obj = vm_find_object_by_userptr(aperture, mem); if (!vm_obj)