libhsakmt/virtio: add non SVM mode in libhsakmt virtio driver and many fixes (#1756)

* libhsakmt/virtio: change shmem size to 80

Some DGPU props have a lot of information,
so it is necessary to increase the size of shmem.

Signed-off-by: Honglei Huang <honghuan@amd.com>

* libhsakmt/virtio: use BO handle instead of pointer in memory registration

Change vhsakmt_map_to_gpu() return type from void* to vhsakmt_bo_handle
to properly handle buffer object information. This allows access to
both the host address and resource ID needed for memory registration.

Signed-off-by: Honglei Huang <honghuan@amd.com>

* libhsakmt/virtio: Improve memory mapping logic

- Update vhsakmt_mappable() to check NoAddress flag and require HostAccess
- Remove mappable checks in cpu_map/unmap to allow all BOs to be mapped
- Set BO flags properly in vhsakmt_alloc_memory and scratch memory creation
- Ensure scratch memory is correctly flagged for proper handling

Signed-off-by: Honglei Huang <honghuan@amd.com>

* libhsakmt/virtio: add no svm mode for libhsakmt virtio

Add a non-SVM mode to the libhsakmt virtio driver. In non-SVM mode,
userptrs must be managed by the UMD (user-mode driver), so add an
interval tree to track them.

New Features:
- Add augmented red-black tree based interval tree implementation
  * Implement RB-tree insertion, deletion, and color balancing
  * Provide interval query for fast overlapping range lookup
  * Based on Linux kernel's augmented rbtree implementation

- Improve userptr memory management
  * Use interval tree to efficiently track userptr memory regions
  * Support finding registered memory within given address ranges
  * Optimize memory mapping and unmapping performance

Signed-off-by: Honglei Huang <honghuan@amd.com>

---------

Signed-off-by: Honglei Huang <honghuan@amd.com>
This commit is contained in:
Honglei Huang
2025-11-28 09:20:43 +08:00
committed by GitHub
vanhempi 792ecc1a83
commit aaa06e1609
9 muutettua tiedostoa jossa 943 lisäystä ja 41 poistoa
@@ -29,6 +29,10 @@
#if defined(__linux__)
#include "hsakmt/linux/kfd_ioctl.h"
#endif
// Forward declaration for HsaKFDContext to avoid dependency issues
typedef struct _HsaKFDContext HsaKFDContext;
#include "hsakmt/hsakmt.h"
#include <libdrm/amdgpu.h>
@@ -53,7 +53,8 @@ set ( HSAKMT_VIRTIO_SRC "virtio_gpu.c"
"hsakmt_virtio_queues.c"
"hsakmt_virtio_topology.c"
"hsakmt_virtio_openclose.c"
"../rbtree.c" )
"../rbtree.c"
"hsakmt_interval_tree.c" )
add_library ( ${HSAKMT_VIRTIO_TARGET} STATIC ${HSAKMT_VIRTIO_SRC} )
@@ -0,0 +1,534 @@
/*
* HSAKMT Interval Tree Implementation
* Based on Linux kernel's augmented red-black tree implementation
*
* This implementation is derived from the Linux kernel source code
* (lib/rbtree.c, lib/rbtree_augmented.h, lib/interval_tree_generic.h)
* and simplified for use in this project.
*
* Copyright (C) 2025 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies
* of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice (including
* the next paragraph) shall be included in all copies or substantial
* portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
#include "hsakmt_interval_tree.h"
#include <stddef.h>
/* Point rb at a new parent p while preserving rb's current colour bit. */
static inline void rb_set_parent(struct rb_node* rb, struct rb_node* p) {
  const unsigned long colour = rb_color(rb);

  rb->__rb_parent_color = (unsigned long)p | colour;
}
/* Overwrite rb's packed parent/colour word with parent p and the given colour. */
static inline void rb_set_parent_color(struct rb_node* rb, struct rb_node* p, int color) {
  unsigned long packed = (unsigned long)p;

  packed |= color;
  rb->__rb_parent_color = packed;
}
/* Colour rb black by setting the colour bit; the parent bits are left alone. */
static inline void rb_set_black(struct rb_node* rb) {
  rb->__rb_parent_color = rb->__rb_parent_color | RB_BLACK;
}
/*
 * Parent of a node that is known to be red. The word is cast without masking,
 * so the result is only a valid pointer when the colour bit is clear (RB_RED).
 */
static inline struct rb_node* rb_red_parent(struct rb_node* red) {
  const unsigned long word = red->__rb_parent_color;

  return (struct rb_node*)word;
}
/*
 * Make `new` occupy the slot currently holding `old`: the matching child
 * pointer of `parent`, or the root pointer when `parent` is NULL.
 * Only the downward link is rewritten; `new`'s own parent field is untouched.
 */
static inline void __rb_change_child(struct rb_node* old, struct rb_node* new,
                                     struct rb_node* parent, interval_tree_t* root) {
  if (!parent) {
    root->rb_node = new;
    return;
  }

  if (parent->rb_left == old)
    parent->rb_left = new;
  else
    parent->rb_right = new;
}
/*
 * Parent/colour bookkeeping shared by every rotation: `new` inherits `old`'s
 * parent and colour, `old` becomes a child of `new` with the requested colour,
 * and the slot that used to reference `old` is redirected to `new`.
 */
static inline void __rb_rotate_set_parents(struct rb_node* old, struct rb_node* new,
                                           interval_tree_t* root, int color) {
  /* Capture the grandparent before old's parent word is overwritten below. */
  struct rb_node* above = rb_parent(old);

  new->__rb_parent_color = old->__rb_parent_color;
  rb_set_parent_color(old, new, color);
  __rb_change_child(old, new, above, root);
}
static inline unsigned long compute_subtree_last(interval_tree_node_t* node);
static void interval_tree_propagate(struct rb_node* rb, struct rb_node* stop);
static void interval_tree_copy(struct rb_node* rb_old, struct rb_node* rb_new);
static void interval_tree_rotate(struct rb_node* rb_old, struct rb_node* rb_new);
/*
 * __rb_insert_color - rebalance the tree after @node has been linked in.
 *
 * @node: freshly linked node; it is treated as red, since its parent is read
 *        with rb_red_parent(), which takes the whole word as the pointer
 * @root: tree root; rewritten when a rotation replaces the topmost node
 * @augment_rotate: callback invoked as (old, new) after each rotation in
 *        which @new takes over @old's position
 */
void __rb_insert_color(struct rb_node* node, interval_tree_t* root,
void (*augment_rotate)(struct rb_node* old, struct rb_node* new)) {
struct rb_node *parent = rb_red_parent(node), *gparent, *tmp;
while (true) {
if (!parent) {
/* node has no parent, i.e. it is the root: colour it black and stop. */
rb_set_parent_color(node, NULL, RB_BLACK);
break;
} else if (rb_is_black(parent)) {
/* A red node below a black parent needs no fix-up. */
break;
}
/* parent is red here, so its colour bit is clear and rb_red_parent() applies. */
gparent = rb_red_parent(parent);
tmp = gparent->rb_right;
if (parent != tmp) {
/* parent is gparent's left child; tmp is the uncle. */
if (tmp && rb_is_red(tmp)) {
/* Red uncle: blacken parent and uncle, redden gparent, retry from gparent. */
rb_set_parent_color(tmp, gparent, RB_BLACK);
rb_set_parent_color(parent, gparent, RB_BLACK);
node = gparent;
parent = rb_parent(node);
rb_set_parent_color(node, parent, RB_RED);
continue;
}
tmp = parent->rb_right;
if (node == tmp) {
/* node is parent's right child: rotate left at parent so node moves up. */
parent->rb_right = tmp = node->rb_left;
node->rb_left = parent;
if (tmp) rb_set_parent_color(tmp, parent, RB_BLACK);
rb_set_parent_color(parent, node, RB_RED);
augment_rotate(parent, node);
parent = node;
tmp = node->rb_right;
}
/* Rotate right at gparent: parent takes gparent's place, gparent turns red. */
gparent->rb_left = tmp;
parent->rb_right = gparent;
if (tmp) rb_set_parent_color(tmp, gparent, RB_BLACK);
__rb_rotate_set_parents(gparent, parent, root, RB_RED);
augment_rotate(gparent, parent);
break;
} else {
/* Mirror image: parent is gparent's right child; tmp is the left uncle. */
tmp = gparent->rb_left;
if (tmp && rb_is_red(tmp)) {
/* Red uncle: recolour and continue from gparent. */
rb_set_parent_color(tmp, gparent, RB_BLACK);
rb_set_parent_color(parent, gparent, RB_BLACK);
node = gparent;
parent = rb_parent(node);
rb_set_parent_color(node, parent, RB_RED);
continue;
}
tmp = parent->rb_left;
if (node == tmp) {
/* node is parent's left child: rotate right at parent so node moves up. */
parent->rb_left = tmp = node->rb_right;
node->rb_right = parent;
if (tmp) rb_set_parent_color(tmp, parent, RB_BLACK);
rb_set_parent_color(parent, node, RB_RED);
augment_rotate(parent, node);
parent = node;
tmp = node->rb_left;
}
/* Rotate left at gparent: parent takes gparent's place, gparent turns red. */
gparent->rb_right = tmp;
parent->rb_left = gparent;
if (tmp) rb_set_parent_color(tmp, gparent, RB_BLACK);
__rb_rotate_set_parents(gparent, parent, root, RB_RED);
augment_rotate(gparent, parent);
break;
}
}
}
/*
 * __rb_erase_color - rebalance the tree after a node was unlinked.
 *
 * @parent: node to start rebalancing from; hsakmt_interval_tree_remove()
 *          passes the non-NULL value returned by __rb_erase_node()
 * @root: tree root; rewritten when a rotation replaces the topmost node
 * @augment_rotate: callback invoked as (old, new) after each rotation in
 *          which @new takes over @old's position
 *
 * `node` is the child of `parent` on the side being fixed up. It starts as
 * NULL and becomes the previous `parent` whenever the loop moves up a level.
 */
void __rb_erase_color(struct rb_node* parent, interval_tree_t* root,
void (*augment_rotate)(struct rb_node* old, struct rb_node* new)) {
struct rb_node *node = NULL, *sibling, *tmp1, *tmp2;
while (true) {
sibling = parent->rb_right;
if (node != sibling) {
/* node is parent's left child; sibling is the right child. */
if (rb_is_red(sibling)) {
/* Red sibling: rotate left at parent; the new sibling is the old sibling's left child. */
parent->rb_right = tmp1 = sibling->rb_left;
sibling->rb_left = parent;
rb_set_parent_color(tmp1, parent, RB_BLACK);
__rb_rotate_set_parents(parent, sibling, root, RB_RED);
augment_rotate(parent, sibling);
sibling = tmp1;
}
tmp1 = sibling->rb_right;
if (!tmp1 || rb_is_black(tmp1)) {
tmp2 = sibling->rb_left;
if (!tmp2 || rb_is_black(tmp2)) {
/*
 * Sibling has no red child: colour it red. A red parent is simply
 * turned black; a black parent makes the loop continue one level up
 * (unless parent was the root).
 */
rb_set_parent_color(sibling, parent, RB_RED);
if (rb_is_red(parent))
rb_set_black(parent);
else {
node = parent;
parent = rb_parent(node);
if (parent) continue;
}
break;
}
/* Only sibling's left child (tmp2) is red: rotate right at sibling. */
sibling->rb_left = tmp1 = tmp2->rb_right;
tmp2->rb_right = sibling;
parent->rb_right = tmp2;
if (tmp1) rb_set_parent_color(tmp1, sibling, RB_BLACK);
augment_rotate(sibling, tmp2);
tmp1 = sibling;
sibling = tmp2;
}
/* Sibling's right child (tmp1) is red: rotate left at parent and blacken tmp1. */
parent->rb_right = tmp2 = sibling->rb_left;
sibling->rb_left = parent;
rb_set_parent_color(tmp1, sibling, RB_BLACK);
if (tmp2) rb_set_parent(tmp2, parent);
__rb_rotate_set_parents(parent, sibling, root, RB_BLACK);
augment_rotate(parent, sibling);
break;
} else {
/* Mirror image: node is parent's right child; sibling is the left child. */
sibling = parent->rb_left;
if (rb_is_red(sibling)) {
/* Red sibling: rotate right at parent. */
parent->rb_left = tmp1 = sibling->rb_right;
sibling->rb_right = parent;
rb_set_parent_color(tmp1, parent, RB_BLACK);
__rb_rotate_set_parents(parent, sibling, root, RB_RED);
augment_rotate(parent, sibling);
sibling = tmp1;
}
tmp1 = sibling->rb_left;
if (!tmp1 || rb_is_black(tmp1)) {
tmp2 = sibling->rb_right;
if (!tmp2 || rb_is_black(tmp2)) {
/* Sibling has no red child: recolour and possibly continue upwards. */
rb_set_parent_color(sibling, parent, RB_RED);
if (rb_is_red(parent))
rb_set_black(parent);
else {
node = parent;
parent = rb_parent(node);
if (parent) continue;
}
break;
}
/* Only sibling's right child (tmp2) is red: rotate left at sibling. */
sibling->rb_right = tmp1 = tmp2->rb_left;
tmp2->rb_left = sibling;
parent->rb_left = tmp2;
if (tmp1) rb_set_parent_color(tmp1, sibling, RB_BLACK);
augment_rotate(sibling, tmp2);
tmp1 = sibling;
sibling = tmp2;
}
/* Sibling's left child (tmp1) is red: rotate right at parent and blacken tmp1. */
parent->rb_left = tmp2 = sibling->rb_right;
sibling->rb_right = parent;
rb_set_parent_color(tmp1, sibling, RB_BLACK);
if (tmp2) rb_set_parent(tmp2, parent);
__rb_rotate_set_parents(parent, sibling, root, RB_BLACK);
augment_rotate(parent, sibling);
break;
}
}
}
/*
 * __rb_erase_node - unlink @node from the tree and fix up __subtree_last.
 *
 * @node: node to remove
 * @root: tree root; rewritten when @node was the topmost node
 * @augment_rotate: not used by this function; the augmented data is
 *        maintained through interval_tree_copy()/interval_tree_propagate()
 *
 * Returns the node from which __rb_erase_color() must rebalance, or NULL
 * when the colours were already repaired locally.
 */
struct rb_node* __rb_erase_node(struct rb_node* node, interval_tree_t* root,
void (*augment_rotate)(struct rb_node* old, struct rb_node* new)) {
struct rb_node *child = node->rb_right, *tmp = node->rb_left;
struct rb_node *parent, *rebalance;
unsigned long pc;
if (!tmp) {
/*
 * Case 1: node to erase has no more than 1 child (easy!)
 *
 * Note that if there is one child it must be red due to 5)
 * and node must be black due to 4). We adjust colors locally
 * so as to bypass __rb_erase_color() later on.
 */
pc = node->__rb_parent_color;
parent = __rb_parent(pc);
__rb_change_child(node, child, parent, root);
if (child) {
/* The only child takes over node's parent and colour. */
child->__rb_parent_color = pc;
rebalance = NULL;
} else {
/* Removing a black leaf requires rebalancing from its parent. */
rebalance = __rb_is_black(pc) ? parent : NULL;
}
tmp = parent;
} else if (!child) {
/* Still case 1, but this time the child is node->rb_left */
tmp->__rb_parent_color = pc = node->__rb_parent_color;
parent = __rb_parent(pc);
__rb_change_child(node, tmp, parent, root);
rebalance = NULL;
tmp = parent;
} else {
struct rb_node *successor = child, *child2;
tmp = child->rb_left;
if (!tmp) {
/*
 * Case 2: node's successor is its right child
 *
 *    (n)          (s)
 *    / \          / \
 *  (x) (s)  ->  (x) (c)
 *        \
 *        (c)
 */
parent = successor;
child2 = successor->rb_right;
/* Copy augmented data: successor takes node's place */
interval_tree_copy(node, successor);
} else {
/*
 * Case 3: node's successor is leftmost under
 * node's right child subtree
 *
 *    (n)          (s)
 *    / \          / \
 *  (x) (y)  ->  (x) (y)
 *      /            /
 *    (p)          (p)
 *    /            /
 *  (s)          (c)
 *    \
 *    (c)
 */
do {
parent = successor;
successor = tmp;
tmp = tmp->rb_left;
} while (tmp);
parent->rb_left = child2 = successor->rb_right;
successor->rb_right = child;
rb_set_parent(child, successor);
/* Copy augmented data */
interval_tree_copy(node, successor);
/* Propagate changes up from parent */
interval_tree_propagate(parent, successor);
}
/* Hand node's left subtree over to the successor. */
successor->rb_left = tmp = node->rb_left;
rb_set_parent(tmp, successor);
pc = node->__rb_parent_color;
tmp = __rb_parent(pc);
__rb_change_child(node, successor, tmp, root);
if (child2) {
/* The successor's former right child fills its old slot as a black node. */
successor->__rb_parent_color = pc;
rb_set_parent_color(child2, parent, RB_BLACK);
rebalance = NULL;
} else {
/* Read the successor's own colour before it is overwritten with node's. */
unsigned long pc2 = successor->__rb_parent_color;
successor->__rb_parent_color = pc;
rebalance = __rb_is_black(pc2) ? parent : NULL;
}
tmp = successor;
}
/* Propagate augmented data changes */
interval_tree_propagate(tmp, NULL);
return rebalance;
}
/* Return the leftmost node of the tree, or NULL when the tree is empty. */
struct rb_node* rb_first(const interval_tree_t* root) {
  struct rb_node* cur = root->rb_node;

  if (cur == NULL) return NULL;

  for (; cur->rb_left != NULL; cur = cur->rb_left) {
  }
  return cur;
}
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wcast-qual"
/*
 * Return the in-order successor of node, or NULL when there is none or when
 * RB_EMPTY_NODE(node) holds.
 */
struct rb_node* rb_next(const struct rb_node* node) {
  struct rb_node* up;

  if (RB_EMPTY_NODE(node)) return NULL;

  /* With a right subtree, the successor is that subtree's leftmost node. */
  if (node->rb_right != NULL) {
    const struct rb_node* cur = node->rb_right;

    while (cur->rb_left != NULL) cur = cur->rb_left;
    return (struct rb_node*)cur;
  }

  /* Otherwise climb until we arrive at a parent from its left side. */
  for (;;) {
    up = rb_parent(node);
    if (up == NULL || node != up->rb_right) break;
    node = up;
  }
  return up;
}
#pragma GCC diagnostic pop
/*
 * Recompute the augmented value for node: the largest of node->last and the
 * __subtree_last values cached in whichever children exist.
 */
static inline unsigned long compute_subtree_last(interval_tree_node_t* node) {
  struct rb_node* const kids[2] = {node->rb.rb_left, node->rb.rb_right};
  unsigned long best = node->last;

  for (int i = 0; i < 2; i++) {
    unsigned long below;

    if (kids[i] == NULL) continue;
    below = rb_entry(kids[i], interval_tree_node_t, rb)->__subtree_last;
    if (best < below) best = below;
  }
  return best;
}
/*
 * Walk from rb towards the root, refreshing __subtree_last on each node.
 * Stops on reaching `stop`, or as soon as a node's cached value is already
 * correct.
 */
static void interval_tree_propagate(struct rb_node* rb, struct rb_node* stop) {
  for (; rb != stop;) {
    interval_tree_node_t* cur = rb_entry(rb, interval_tree_node_t, rb);
    const unsigned long fresh = compute_subtree_last(cur);

    if (cur->__subtree_last == fresh) return;

    cur->__subtree_last = fresh;
    rb = rb_parent(&cur->rb);
  }
}
static void interval_tree_copy(struct rb_node* rb_old, struct rb_node* rb_new) {
interval_tree_node_t* old = rb_entry(rb_old, interval_tree_node_t, rb);
interval_tree_node_t* new = rb_entry(rb_new, interval_tree_node_t, rb);
new->__subtree_last = old->__subtree_last;
}
static void interval_tree_rotate(struct rb_node* rb_old, struct rb_node* rb_new) {
interval_tree_node_t* old = rb_entry(rb_old, interval_tree_node_t, rb);
interval_tree_node_t* new = rb_entry(rb_new, interval_tree_node_t, rb);
new->__subtree_last = old->__subtree_last;
old->__subtree_last = compute_subtree_last(old);
}
/*
 * Callback handed to the rb-tree primitives.
 *
 * With both nodes present it performs the rotation fix-up. When only rb_old
 * is given it falls back to a full upward propagation. None of the callers
 * in this file pass a NULL rb_new, so that branch is purely defensive.
 */
static void interval_tree_augment_rotate_wrapper(struct rb_node* rb_old, struct rb_node* rb_new) {
  if (!rb_old) return;

  if (rb_new)
    interval_tree_rotate(rb_old, rb_new);
  else
    interval_tree_propagate(rb_old, NULL);
}
/*
 * Insert node into the tree, ordered by node->start (equal starts descend to
 * the right). Every ancestor on the descent path has its __subtree_last
 * raised to cover node->last before the node is linked and rebalanced.
 */
void hsakmt_interval_tree_insert(interval_tree_t* root, interval_tree_node_t* node) {
  const unsigned long key = node->start;
  const unsigned long end = node->last;
  struct rb_node** slot = &root->rb_node;
  struct rb_node* above = NULL;

  while (*slot != NULL) {
    interval_tree_node_t* cur;

    above = *slot;
    cur = rb_entry(above, interval_tree_node_t, rb);
    if (cur->__subtree_last < end) cur->__subtree_last = end;
    slot = (key < cur->start) ? &cur->rb.rb_left : &cur->rb.rb_right;
  }

  node->__subtree_last = end;
  rb_link_node(&node->rb, above, slot);
  __rb_insert_color(&node->rb, root, interval_tree_augment_rotate_wrapper);
}
/*
 * Unlink node from the tree, then run the colour fix-up pass if the unlink
 * step reports a node to rebalance from.
 */
void hsakmt_interval_tree_remove(interval_tree_t* root, interval_tree_node_t* node) {
  struct rb_node* fix_from =
      __rb_erase_node(&node->rb, root, interval_tree_augment_rotate_wrapper);

  if (fix_from == NULL) return;
  __rb_erase_color(fix_from, root, interval_tree_augment_rotate_wrapper);
}
/*
 * Find the leftmost node in the subtree rooted at `node` whose interval
 * intersects [start, last], or NULL if there is none.
 *
 * Cond1: node->start <= last
 * Cond2: start <= node->last
 *
 * Callers must ensure start <= node->__subtree_last on entry; the loop keeps
 * that invariant for every node it moves to.
 */
static interval_tree_node_t* interval_tree_subtree_search(interval_tree_node_t* node,
                                                          unsigned long start,
                                                          unsigned long last) {
  for (;;) {
    struct rb_node* lchild = node->rb.rb_left;

    if (lchild != NULL) {
      interval_tree_node_t* lnode = rb_entry(lchild, interval_tree_node_t, rb);

      /*
       * Something on the left satisfies Cond2. Go there: either the leftmost
       * such node also satisfies Cond1 and is the answer, or nothing further
       * right can satisfy Cond1 either.
       */
      if (lnode->__subtree_last >= start) {
        node = lnode;
        continue;
      }
    }

    if (node->start > last) return NULL; /* !Cond1 */
    if (node->last >= start) return node; /* Cond1 && Cond2: match */

    /* Only the right subtree is left, and only if it can satisfy Cond2. */
    if (node->rb.rb_right == NULL) return NULL;
    node = rb_entry(node->rb.rb_right, interval_tree_node_t, rb);
    if (node->__subtree_last < start) return NULL;
  }
}
/*
 * Return the leftmost interval intersecting [start, last], or NULL when the
 * tree is empty or no stored interval reaches `start`.
 */
interval_tree_node_t* hsakmt_interval_tree_iter_first(interval_tree_t* root, unsigned long start,
                                                      unsigned long last) {
  interval_tree_node_t* top;

  if (root->rb_node == NULL) return NULL;

  top = rb_entry(root->rb_node, interval_tree_node_t, rb);
  /* The root's __subtree_last is the maximum `last` in the whole tree. */
  return (top->__subtree_last < start) ? NULL : interval_tree_subtree_search(top, start, last);
}
/*
 * Continue an iteration started by hsakmt_interval_tree_iter_first(): return
 * the next interval after `node` that intersects [start, last], or NULL when
 * the matches are exhausted. `root` is accepted for API symmetry only.
 */
interval_tree_node_t* hsakmt_interval_tree_iter_next(interval_tree_t* root,
                                                     interval_tree_node_t* node,
                                                     unsigned long start, unsigned long last) {
  struct rb_node* right = node->rb.rb_right;

  (void)root;

  for (;;) {
    /*
     * Invariants here: node->start <= last, and right == node->rb.rb_right.
     * Try the right subtree first if it can contain a match.
     */
    if (right != NULL) {
      interval_tree_node_t* sub = rb_entry(right, interval_tree_node_t, rb);

      if (sub->__subtree_last >= start) return interval_tree_subtree_search(sub, start, last);
    }

    /* Climb until we reach a parent from its left child. */
    for (;;) {
      struct rb_node* came_from = &node->rb;
      struct rb_node* up = rb_parent(came_from);

      if (up == NULL) return NULL;
      node = rb_entry(up, interval_tree_node_t, rb);
      right = node->rb.rb_right;
      if (came_from != right) break;
    }

    /* Does this ancestor itself intersect [start, last]? */
    if (node->start > last) return NULL; /* !Cond1 */
    if (node->last >= start) return node; /* Cond2 */
  }
}
@@ -0,0 +1,197 @@
/*
* HSAKMT Interval Tree Implementation
* Based on Linux kernel's augmented red-black tree implementation
*
* This implementation is derived from the Linux kernel source code
* (lib/rbtree.c, lib/rbtree_augmented.h, lib/interval_tree_generic.h)
* and simplified for use in this project.
*
* Copyright (C) 2025 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies
* of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice (including
* the next paragraph) shall be included in all copies or substantial
* portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
#ifndef _HSAKMT_INTERVAL_TREE_H_
#define _HSAKMT_INTERVAL_TREE_H_
#include <stddef.h>
#include <stdbool.h>
#ifdef __cplusplus
extern "C" {
#endif
/*
 * Red-black tree node, meant to be embedded in a containing structure
 * (see rb_entry()).
 */
struct rb_node {
/* Parent pointer and colour packed in one word: bit 0 is the colour (see rb_color()), rb_parent() masks the low two bits off. */
unsigned long __rb_parent_color;
struct rb_node* rb_right;
struct rb_node* rb_left;
/* Alignment keeps the low address bits free for the colour encoding above. */
} __attribute__((aligned(sizeof(long))));
/* Tree handle: holds only the pointer to the topmost node (NULL when empty). */
struct rb_root {
struct rb_node* rb_node;
};
/* Compound-literal initializer for an empty tree. */
#define RB_ROOT \
(struct rb_root) { NULL }
/*
 * One closed interval [start, last] stored in an interval tree.
 * __subtree_last is maintained by the tree code and should not be written
 * by users after insertion.
 */
typedef struct interval_tree_node {
struct rb_node rb; /* Tree linkage; rb_entry() maps it back to this struct */
unsigned long start; /* Start of interval */
unsigned long last; /* Last location in interval */
unsigned long __subtree_last; /* Max 'last' in this subtree */
} interval_tree_node_t;
/* An interval tree is represented by a plain rb_root. */
typedef struct rb_root interval_tree_t;
/* Parent pointer of node r: the packed word with its low two bits cleared. */
#define rb_parent(r) ((struct rb_node*)((r)->__rb_parent_color & ~3UL))
/* Same extraction applied to a raw parent/colour word pc. */
#define __rb_parent(pc) ((struct rb_node*)((pc) & ~3UL))
/* Colour values stored in bit 0 of __rb_parent_color. */
#define RB_RED 0
#define RB_BLACK 1
/* Colour tests on a raw parent/colour word pc. */
#define __rb_color(pc) ((pc)&1)
#define __rb_is_black(pc) __rb_color(pc)
#define __rb_is_red(pc) (!__rb_color(pc))
/* Colour tests on a node pointer rb. */
#define rb_color(rb) ((rb)->__rb_parent_color & 1)
#define rb_is_red(rb) (!rb_color(rb))
#define rb_is_black(rb) rb_color(rb)
/*
 * rb_entry - recover the enclosing structure from a pointer to one of its
 * members (the usual container_of computation).
 * @ptr:    pointer to the embedded member
 * @type:   type of the enclosing structure
 * @member: name of the member inside @type
 */
#define rb_entry(ptr, type, member) ((type*)((char*)(ptr)-offsetof(type, member)))
/*
 * rb_entry_safe - like rb_entry(), but yields NULL when @ptr is NULL.
 * @ptr is evaluated exactly once.
 *
 * __typeof__ is used instead of the bare typeof keyword so the macro still
 * expands correctly in strict ISO modes (-std=c99 / -std=c11), where GCC and
 * Clang do not recognise typeof.
 */
#define rb_entry_safe(ptr, type, member)              \
  ({                                                  \
    __typeof__(ptr) ____ptr = (ptr);                  \
    ____ptr ? rb_entry(____ptr, type, member) : NULL; \
  })
/* True when the tree holds no nodes. */
#define RB_EMPTY_ROOT(root) ((root)->rb_node == NULL)
/* True when the node's parent field points back at the node itself, which is the state RB_CLEAR_NODE() produces. */
#define RB_EMPTY_NODE(node) (rb_parent(node) == (node))
/* Mark a node as not linked by making it its own parent. */
#define RB_CLEAR_NODE(node) ((node)->__rb_parent_color = (unsigned long)(node))
/* Accessors for the interval bounds of an interval_tree_node_t. */
#define interval_start(node) ((node)->start)
#define interval_last(node) ((node)->last)
/* Reset an interval tree to the empty state. */
#define interval_tree_init(itree) ((itree)->rb_node = NULL)
/**
 * rb_link_node - Attach a new leaf node below @parent
 * @node: node being added; its child pointers are cleared
 * @parent: node that becomes @node's parent (NULL when @node is the root)
 * @rb_link: slot that will reference @node (a child pointer of @parent,
 *           or the root pointer)
 *
 * The parent word is stored with the colour bit clear, so the node starts
 * out red. No rebalancing is done here.
 */
static inline void rb_link_node(struct rb_node* node, struct rb_node* parent,
                                struct rb_node** rb_link) {
  node->rb_left = NULL;
  node->rb_right = NULL;
  node->__rb_parent_color = (unsigned long)parent;

  *rb_link = node;
}
/*
 * Low-level red-black tree primitives. Each takes an augment_rotate
 * callback with the signature (old, new).
 */
/* Rebalance after @node has been linked with rb_link_node(). */
void __rb_insert_color(struct rb_node* node, struct rb_root* root,
void (*augment_rotate)(struct rb_node* old, struct rb_node* new));
/* Unlink @node; returns the node to pass to __rb_erase_color(), or NULL if no rebalancing is needed. */
struct rb_node* __rb_erase_node(struct rb_node* node, struct rb_root* root,
void (*augment_rotate)(struct rb_node* old, struct rb_node* new));
/* Rebalance after an unlink, starting from @parent. */
void __rb_erase_color(struct rb_node* parent, struct rb_root* root,
void (*augment_rotate)(struct rb_node* old, struct rb_node* new));
/* Leftmost node of the tree, or NULL when the tree is empty. */
struct rb_node* rb_first(const struct rb_root* root);
/* In-order successor of @node, or NULL when there is none. */
struct rb_node* rb_next(const struct rb_node* node);
/**
* hsakmt_interval_tree_insert - Insert an interval node into the tree
* @itree: root of the interval tree
* @node: interval node to insert
*/
void hsakmt_interval_tree_insert(interval_tree_t* itree, interval_tree_node_t* node);
/**
* hsakmt_interval_tree_remove - Remove an interval node from the tree
* @itree: root of the interval tree
* @node: interval node to remove
*/
void hsakmt_interval_tree_remove(interval_tree_t* itree, interval_tree_node_t* node);
/**
* hsakmt_interval_tree_iter_first - Find first interval overlapping [start, last]
* @itree: root of the interval tree
* @start: start of the query interval
* @last: last location of the query interval
*
* Returns: First overlapping interval node, or NULL if none found
*/
interval_tree_node_t* hsakmt_interval_tree_iter_first(interval_tree_t* itree, unsigned long start,
unsigned long last);
/**
* hsakmt_interval_tree_iter_next - Find next interval overlapping [start, last]
* @itree: root of the interval tree (unused, kept for API consistency)
* @node: previous interval node returned by iter_first or iter_next
* @start: start of the query interval
* @last: last location of the query interval
*
* Returns: Next overlapping interval node, or NULL if no more found
*/
interval_tree_node_t* hsakmt_interval_tree_iter_next(interval_tree_t* itree,
interval_tree_node_t* node,
unsigned long start, unsigned long last);
/**
 * interval_tree_node_init - Prepare an interval node for insertion
 * @node: node to set up
 * @start: first location covered by the interval
 * @last: last location covered by the interval (inclusive)
 *
 * Clears the tree linkage and seeds __subtree_last with @last.
 */
static inline void interval_tree_node_init(interval_tree_node_t* node, unsigned long start,
                                           unsigned long last) {
  struct rb_node* link = &node->rb;

  link->__rb_parent_color = 0;
  link->rb_left = NULL;
  link->rb_right = NULL;

  node->start = start;
  node->last = last;
  node->__subtree_last = last;
}
/**
 * interval_tree_overlap - Test whether two closed intervals intersect
 * @start1: first location of interval 1
 * @last1: last location of interval 1 (inclusive)
 * @start2: first location of interval 2
 * @last2: last location of interval 2 (inclusive)
 *
 * Returns: 1 if [start1, last1] and [start2, last2] share at least one
 * location, 0 otherwise
 */
static inline int interval_tree_overlap(unsigned long start1, unsigned long last1,
                                        unsigned long start2, unsigned long last2) {
  /* Interval 1 begins after interval 2 has ended. */
  if (start1 > last2) return 0;

  return start2 <= last1;
}
#ifdef __cplusplus
}
#endif
#endif /* _HSAKMT_INTERVAL_TREE_H_ */
@@ -27,6 +27,7 @@
#include "hsakmt_virtio_proto.h"
#include "rbtree.h"
#include "hsakmt_interval_tree.h"
#include "virtio_gpu.h"
#include <stdatomic.h>
@@ -98,6 +99,7 @@ struct vhsakmt_device {
int refcount;
pthread_mutex_t bo_handles_mutex;
rbtree_t bo_rbt;
interval_tree_t userptr_tree;
struct vhsakmt_bo* shmem_bo;
@@ -110,10 +112,13 @@ struct vhsakmt_device {
pthread_mutex_t vhsakmt_mutex;
struct vhsakmt_node* vhsakmt_nodes;
HsaSystemProperties* sys_props;
bool use_svm;
};
struct vhsakmt_bo {
rbtree_node_t rbtn;
interval_tree_node_t itn;
struct vhsakmt_device* dev;
int refcount;
@@ -22,6 +22,8 @@
#include "hsakmt/hsakmt_virtio.h"
#include "hsakmt_virtio_device.h"
#include <unistd.h>
#include <xf86drm.h>
#define VHSA_GL_METADATA_MAX_SIZE (0x50)
@@ -29,9 +31,11 @@ vhsakmt_bo_handle vhsakmt_entry_to_bo_handle(bo_entry e) { return (vhsakmt_bo_ha
/* Return the address of the rbtn member embedded in bo. */
bo_entry vhsakmt_bo_handle_to_entry(vhsakmt_bo_handle bo) {
  bo_entry entry = &bo->rbtn;

  return entry;
}
/* A BO counts as a memory BO when it has neither a queue_id nor an event set. */
static inline bool vhsakmt_is_mem_bo(vhsakmt_bo_handle bo) {
  return !(bo->queue_id || bo->event);
}
static bool vhsakmt_mappable(HsaMemFlags flags) { return (!flags.ui32.Scratch); }
static bool vhsakmt_mappable(HsaMemFlags flags) {
if (flags.ui32.Scratch || flags.ui32.NoAddress) return false;
static bool vhsakmt_bo_mappable(vhsakmt_bo_handle bo) { return vhsakmt_mappable(bo->flags); }
return flags.ui32.HostAccess;
}
void vhsakmt_insert_bo(vhsakmt_device_handle dev, vhsakmt_bo_handle bo, void* addr, uint64_t size) {
bo->rbtn.key.addr = (unsigned long)addr;
@@ -90,6 +94,51 @@ vhsakmt_bo_handle vhsakmt_find_bo_by_addr(vhsakmt_device_handle dev, void* addr)
return NULL;
}
/*
 * Record a userptr BO in the device's userptr interval tree, keyed by the
 * closed range [cpu_addr, cpu_addr + size - 1]. BOs without the
 * VHSA_BO_USERPTR type bit are ignored. The tree update is done under
 * dev->bo_handles_mutex.
 */
static void vhsakmt_insert_userptr(vhsakmt_device_handle dev, vhsakmt_bo_handle userptr) {
  unsigned long first;
  unsigned long final;

  if ((userptr->bo_type & VHSA_BO_USERPTR) == 0) return;

  first = (unsigned long)userptr->cpu_addr;
  final = (unsigned long)userptr->cpu_addr + userptr->size - 1UL;
  interval_tree_node_init(&userptr->itn, first, final);

  pthread_mutex_lock(&dev->bo_handles_mutex);
  hsakmt_interval_tree_insert(&dev->userptr_tree, &userptr->itn);
  pthread_mutex_unlock(&dev->bo_handles_mutex);
}
/*
 * Look up a registered userptr BO that fully contains [addr, last].
 *
 * The interval tree yields every BO overlapping the range; only a BO whose
 * own range covers the whole request is accepted. Returns NULL when no such
 * BO exists. The search runs under dev->bo_handles_mutex.
 */
static vhsakmt_bo_handle vhsakmt_find_userptr(vhsakmt_device_handle dev, unsigned long addr,
                                              unsigned long last) {
  vhsakmt_bo_handle found = NULL;
  interval_tree_node_t* it;

  pthread_mutex_lock(&dev->bo_handles_mutex);

  for (it = hsakmt_interval_tree_iter_first(&dev->userptr_tree, addr, last); it != NULL;
       it = hsakmt_interval_tree_iter_next(&dev->userptr_tree, it, addr, last)) {
    /* Map the embedded tree node back to its owning BO. */
    vhsakmt_bo_handle bo = (vhsakmt_bo_handle)((char*)it - offsetof(struct vhsakmt_bo, itn));

    /* Skip BOs that overlap the request but do not cover all of it. */
    if ((unsigned long)bo->cpu_addr > addr ||
        ((unsigned long)bo->cpu_addr + bo->size - 1UL) < last)
      continue;

    found = bo;
    break;
  }

  pthread_mutex_unlock(&dev->bo_handles_mutex);
  return found;
}
/*
 * Tear down a userptr BO: drop it from the userptr interval tree, destroy
 * its map mutex, close the GEM handle and free the BO.
 *
 * NOTE(review): unlike the insert/find helpers, the tree removal here does
 * not take dev->bo_handles_mutex itself. Presumably the caller already holds
 * it or otherwise serialises access; confirm against the call sites.
 */
static void vhsakmt_destroy_userptr(vhsakmt_device_handle dev, vhsakmt_bo_handle bo) {
  struct drm_gem_close close_req = {0};

  close_req.handle = bo->real.handle;

  hsakmt_interval_tree_remove(&dev->userptr_tree, &bo->itn);
  pthread_mutex_destroy(&bo->map_mutex);

  drmIoctl(dev->vgdev->fd, DRM_IOCTL_GEM_CLOSE, &close_req);
  free(bo);
}
void* vhsakmt_gpu_va(vhsakmt_device_handle dev, void* va) {
if (!vhsakmt_is_userptr(dev, va)) return va;
@@ -103,8 +152,6 @@ void* vhsakmt_gpu_va(vhsakmt_device_handle dev, void* va) {
int vhsakmt_bo_cpu_map(vhsakmt_bo_handle bo, void** cpu, void* fixed_cpu) {
int r;
if (!vhsakmt_bo_mappable(bo)) return 0;
pthread_mutex_lock(&bo->map_mutex);
if (!bo->cpu_addr) {
@@ -124,8 +171,6 @@ int vhsakmt_bo_cpu_map(vhsakmt_bo_handle bo, void** cpu, void* fixed_cpu) {
int vhsakmt_bo_cpu_unmap(vhsakmt_bo_handle bo) {
int r = 0;
if (!vhsakmt_bo_mappable(bo)) return 0;
pthread_mutex_lock(&bo->map_mutex);
if (!bo->cpu_addr || bo->real.map_count == 0) {
@@ -280,6 +325,7 @@ HSAKMT_STATUS HSAKMTAPI vhsaKmtAllocMemory(HSAuint32 PreferredNode, HSAuint64 Si
vhsakmt_mappable(MemFlags) ? VIRTGPU_BLOB_FLAG_USE_MAPPABLE : 0,
req.blob_id, VHSA_BO_KFD_MEM, (void*)rsp->memory_handle, &bo);
if (r) return r;
bo->flags = MemFlags;
if (!vhsakmt_mappable(MemFlags)) {
bo->cpu_addr = bo->host_addr;
@@ -379,8 +425,18 @@ HSAKMT_STATUS HSAKMTAPI vhsaKmtMapMemoryToGPUNodes(void* MemoryAddress, HSAuint6
if (bo) {
req->map_to_GPU_nodes_args.MemoryAddress = (uint64_t)bo->host_addr;
if (bo->bo_type & VHSA_BO_USERPTR) vhsakmt_remove_userptr_bo(dev, bo);
} else
req->map_to_GPU_nodes_args.MemoryAddress = (uint64_t)MemoryAddress;
} else if (!dev->use_svm) {
bo = vhsakmt_find_userptr(dev, (uint64_t)MemoryAddress,
(uint64_t)MemoryAddress + MemorySizeInBytes - 1UL);
if (bo)
req->map_to_GPU_nodes_args.MemoryAddress =
(uint64_t)bo->host_addr + ((char*)MemoryAddress - (char*)bo->cpu_addr);
}
if (!bo) {
free(req);
return HSAKMT_STATUS_INVALID_HANDLE;
}
memcpy(req->payload, NodeArray, NumberOfNodes * sizeof(*NodeArray));
@@ -477,6 +533,7 @@ static int vhsakmt_create_scratch_map_memory(vhsakmt_device_handle dev, void* Me
// TODO: insert scratch bo into rbtree, or insert it in dev nodes.
out->flags.ui32.Scratch = 1;
out->cpu_addr = MemoryAddress;
out->host_addr = (void*)rsp->memory_handle;
*AlternateVAGPU = rsp->alternate_vagpu;
@@ -509,7 +566,17 @@ HSAKMT_STATUS HSAKMTAPI vhsaKmtMapMemoryToGPU(void* MemoryAddress, HSAuint64 Mem
},
};
if (bo && (bo->bo_type & VHSA_BO_USERPTR)) vhsakmt_remove_userptr_bo(dev, bo);
if (bo && (bo->bo_type & VHSA_BO_USERPTR)) {
vhsakmt_remove_userptr_bo(dev, bo);
} else if (!bo) {
bo = vhsakmt_find_userptr(dev, (uint64_t)MemoryAddress,
(uint64_t)MemoryAddress + MemorySizeInBytes - 1UL);
if (bo)
req.map_to_GPU_args.MemoryAddress =
(uint64_t)bo->host_addr + ((char*)MemoryAddress - (char*)bo->cpu_addr);
}
if (!bo) return HSAKMT_STATUS_INVALID_HANDLE;
rsp = vhsakmt_alloc_rsp(dev, &req.hdr, sizeof(struct vhsakmt_ccmd_memory_rsp));
if (!rsp) return -ENOMEM;
@@ -542,31 +609,46 @@ static int vhsakmt_map_userptr(vhsakmt_device_handle dev, void* addr, size_t siz
return rsp->ret;
}
static void* vhsakmt_map_to_gpu(void* addr, size_t size) {
static vhsakmt_bo_handle vhsakmt_map_to_gpu(void* addr, size_t size, bool use_svm) {
vhsakmt_device_handle dev = vhsakmt_dev();
size_t offset = (uint64_t)addr % getpagesize();
size_t map_size = (VHSA_ALIGN_UP(size + offset, getpagesize()) / getpagesize()) * getpagesize();
uint64_t userptr_offset, userptr_handle = 0;
size_t page_size = getpagesize();
size_t addr_offset = (uint64_t)addr % page_size;
void* blob_addr;
size_t blob_size;
uint64_t userptr_offset = 0, userptr_handle = 0;
vhsakmt_bo_handle userptr;
int r;
vhsa_debug("%s: addr: %p, size: 0x%lx, size + offset: 0x%lx, map_size: 0x%lx\n", __FUNCTION__,
addr, size, size + offset, map_size);
if (use_svm) {
blob_addr = addr;
blob_size = size;
} else {
blob_addr = (void*)((uint64_t)addr - addr_offset);
blob_size = VHSA_ALIGN_UP(size + addr_offset, page_size);
}
r = vhsakmt_init_userptr_blob(dev, addr, size, &userptr, &userptr_offset);
vhsa_debug("%s: addr: %p, size: 0x%lx, offset: 0x%lx, blob_addr: %p, blob_size: 0x%lx, svm: %d\n",
__FUNCTION__, addr, size, addr_offset, blob_addr, blob_size, use_svm);
r = vhsakmt_init_userptr_blob(dev, blob_addr, blob_size, &userptr, &userptr_offset);
if (r < 0) {
vhsa_debug("%s: userptr create failed at address: %p, ret = %d\n", __FUNCTION__, addr, r);
return NULL;
}
vhsakmt_map_userptr(dev, addr, size, userptr->real.res_id, &userptr_handle);
r = vhsakmt_map_userptr(dev, addr, size, userptr->real.res_id, &userptr_handle);
if (!userptr_handle) {
vhsa_debug("%s: map userptr failed at address: %p, ret = %d\n", __FUNCTION__, addr, r);
vhsakmt_destroy_handle(dev, userptr);
vhsakmt_remove_userptr_bo(dev, userptr);
return NULL;
}
userptr->host_addr = VHSA_UINT64_TO_VPTR(VHSA_VPTR_TO_UINT64(userptr_handle) + offset);
if (use_svm) {
userptr->host_addr = VHSA_UINT64_TO_VPTR(VHSA_VPTR_TO_UINT64(userptr_handle) + addr_offset);
} else {
userptr->host_addr = VHSA_UINT64_TO_VPTR(userptr_handle);
}
if (r > 0) {
vhsa_debug("%s: userptr: %p already registered, offset: %lx\n", __FUNCTION__, addr,
@@ -574,12 +656,17 @@ static void* vhsakmt_map_to_gpu(void* addr, size_t size) {
userptr->host_addr =
VHSA_UINT64_TO_VPTR(VHSA_VPTR_TO_UINT64(userptr->host_addr) + userptr_offset);
}
vhsakmt_insert_bo(dev, userptr, userptr->cpu_addr, userptr->size);
vhsa_debug("%s: real gva: %p, gva: %p, hva: %p, size: %lx, offset: %" PRIu64
", map_size: 0x%lx\n",
__FUNCTION__, addr, userptr->cpu_addr, userptr->host_addr, size, offset, map_size);
return userptr->host_addr;
if (use_svm) {
vhsakmt_insert_bo(dev, userptr, userptr->cpu_addr, userptr->size);
} else {
vhsakmt_insert_userptr(dev, userptr);
}
vhsa_debug("%s: gva: %p, cpu_addr: %p, hva: %p, size: %lx, offset: %lx, blob_size: 0x%lx\n",
__FUNCTION__, addr, userptr->cpu_addr, userptr->host_addr, size, addr_offset,
blob_size);
return userptr;
}
HSAKMT_STATUS HSAKMTAPI vhsaKmtRegisterMemoryWithFlags(void* MemoryAddress,
@@ -589,7 +676,7 @@ HSAKMT_STATUS HSAKMTAPI vhsaKmtRegisterMemoryWithFlags(void* MemoryAddress,
vhsakmt_device_handle dev = vhsakmt_dev();
struct vhsakmt_ccmd_memory_rsp* rsp;
void* addr;
vhsakmt_bo_handle userptr;
struct vhsakmt_ccmd_memory_req req = {
.hdr = VHSAKMT_CCMD(MEMORY, sizeof(struct vhsakmt_ccmd_memory_req)),
.type = VHSAKMT_CCMD_MEMORY_REG_MEM_WITH_FLAG,
@@ -603,14 +690,30 @@ HSAKMT_STATUS HSAKMTAPI vhsaKmtRegisterMemoryWithFlags(void* MemoryAddress,
/* no need to register memory from libhsakmt / not a userptr */
if (!vhsakmt_is_userptr(dev, MemoryAddress)) return HSAKMT_STATUS_SUCCESS;
addr = vhsakmt_map_to_gpu(MemoryAddress, MemorySizeInBytes);
if (!addr) {
vhsa_debug("%s: register memory failed, gva: %p, size: %lx\n", __FUNCTION__, MemoryAddress,
MemorySizeInBytes);
if (!dev->use_svm) {
vhsakmt_bo_handle bo = vhsakmt_find_userptr(dev, (uint64_t)MemoryAddress,
(uint64_t)MemoryAddress + MemorySizeInBytes - 1UL);
if (bo) {
vhsa_debug(
"%s: memory already registered, MemoryAddress:%p, bo address: %p, size: %x, "
"res_id: %d, count: %d\n",
__FUNCTION__, MemoryAddress, bo->cpu_addr, bo->size, bo->real.res_id, bo->refcount);
(void)vhsakmt_atomic_inc_return(&bo->refcount);
return HSAKMT_STATUS_SUCCESS;
}
}
userptr = vhsakmt_map_to_gpu(MemoryAddress, MemorySizeInBytes, dev->use_svm);
if (!userptr) {
vhsa_debug(
"%s: register memory failed at address: %p, size: %lx (vhsakmt_map_to_gpu returned %p)\n",
__FUNCTION__, MemoryAddress, MemorySizeInBytes, userptr);
return HSAKMT_STATUS_ERROR;
}
req.reg_mem_with_flag.MemoryAddress = (uint64_t)addr;
req.reg_mem_with_flag.MemoryAddress = (uint64_t)userptr->host_addr;
req.res_id = userptr->real.res_id;
rsp = vhsakmt_alloc_rsp(dev, &req.hdr, sizeof(struct vhsakmt_ccmd_memory_rsp));
if (!rsp) return -ENOMEM;
@@ -640,21 +743,62 @@ static int vhsakmt_remove_clgl_bo(vhsakmt_device_handle dev, vhsakmt_bo_handle b
return rsp->ret;
}
/*
 * Deregister a userptr in non-SVM mode.
 *
 * MemoryAddress is rounded down to a page boundary, and every entry of
 * dev->userptr_tree that overlaps that address AND whose cpu_addr equals it
 * has its refcount decremented. Only when none of those entries is still
 * referenced afterwards are they all destroyed.
 *
 * The whole operation runs with dev->bo_handles_mutex held.
 * Always returns 0.
 */
static int vhsakmt_deregister_userptr_non_svm(vhsakmt_device_handle dev, void* MemoryAddress) {
  size_t page_size = getpagesize();
  unsigned long aligned_addr = ((uint64_t)MemoryAddress / page_size) * page_size;
  interval_tree_node_t* node;
  interval_tree_node_t* following;
  vhsakmt_bo_handle entry;
  bool still_referenced = false;

  pthread_mutex_lock(&dev->bo_handles_mutex);

  /* Pass 1: drop one reference from each matching userptr. */
  for (node = hsakmt_interval_tree_iter_first(&dev->userptr_tree, aligned_addr, aligned_addr);
       node != NULL;
       node = hsakmt_interval_tree_iter_next(&dev->userptr_tree, node, aligned_addr,
                                             aligned_addr)) {
    /* Recover the BO that embeds this interval tree node. */
    entry = (vhsakmt_bo_handle)((char*)node - offsetof(struct vhsakmt_bo, itn));
    if (entry->cpu_addr != (void*)aligned_addr) {
      continue;
    }

    vhsa_debug("%s: found userptr: %p, size: %x, res_id: %d, count: %d\n", __FUNCTION__,
               entry->cpu_addr, entry->size, entry->real.res_id, entry->refcount);
    if (vhsakmt_atomic_dec_return(&entry->refcount) > 0) {
      still_referenced = true;
    }
  }

  /* Pass 2: tear everything down, but only if nothing is referenced anymore. */
  if (!still_referenced) {
    node = hsakmt_interval_tree_iter_first(&dev->userptr_tree, aligned_addr, aligned_addr);
    while (node != NULL) {
      /*
       * Fetch the successor before destroying the current entry
       * (presumably vhsakmt_destroy_userptr unlinks the node - confirm).
       */
      following =
          hsakmt_interval_tree_iter_next(&dev->userptr_tree, node, aligned_addr, aligned_addr);
      entry = (vhsakmt_bo_handle)((char*)node - offsetof(struct vhsakmt_bo, itn));

      if (entry->cpu_addr == (void*)aligned_addr) {
        vhsa_debug("%s: destroying userptr: %p, size: %x, res_id: %d\n", __FUNCTION__, entry->cpu_addr,
                   entry->size, entry->real.res_id);
        vhsakmt_destroy_userptr(dev, entry);
      }

      node = following;
    }
  }

  pthread_mutex_unlock(&dev->bo_handles_mutex);
  return 0;
}
HSAKMT_STATUS HSAKMTAPI vhsaKmtDeregisterMemory(void* MemoryAddress) {
CHECK_VIRTIO_KFD_OPEN();
vhsakmt_device_handle dev = vhsakmt_dev();
vhsakmt_bo_handle bo = vhsakmt_find_bo_by_addr(dev, MemoryAddress);
if (!bo) return HSAKMT_STATUS_SUCCESS;
vhsa_debug("%s: remove userptr %p size: 0x%lx, res id: %d\n", __FUNCTION__, MemoryAddress,
(size_t)bo->size, bo->real.res_id);
if (bo && (bo->bo_type & VHSA_BO_CLGL)) return vhsakmt_remove_clgl_bo(dev, bo);
if (bo->bo_type & VHSA_BO_CLGL)
return vhsakmt_remove_clgl_bo(dev, bo);
else {
vhsakmt_remove_bo(dev, bo);
free(bo);
if (!dev->use_svm) {
return vhsakmt_deregister_userptr_non_svm(dev, MemoryAddress);
}
return 0;
@@ -85,12 +85,15 @@ static vhsakmt_device_handle vhsakmt_device_init(void) {
if (!dev->vgdev) goto malloc_failed;
rbtree_init(&dev->bo_rbt);
interval_tree_init(&dev->userptr_tree);
atomic_store(&dev->next_blob_id, 1);
atomic_store(&dev->refcount, 1);
pthread_mutex_init(&dev->bo_handles_mutex, NULL);
pthread_mutex_init(&dev->vhsakmt_mutex, NULL);
dev_list = dev;
dev->use_svm = false;
pthread_mutex_unlock(&dev_mutex);
return dev;
@@ -102,14 +105,24 @@ open_failed:
return dev;
}
/*
 * Apply environment-variable overrides:
 *   VHSAKMT_USE_SVM      - a value that atoi() parses as non-zero sets
 *                          use_svm on the device returned by vhsakmt_dev().
 *   VHSAKMT_DEBUG_LEVEL  - when set, its atoi() value becomes
 *                          vhsakmt_debug_level.
 */
static void vhsakmt_init_vars_from_env(void) {
  /* Each getenv() result is consumed before the next lookup is made. */
  const char* svm_env = getenv("VHSAKMT_USE_SVM");
  if (svm_env != NULL && atoi(svm_env) != 0) {
    vhsakmt_dev()->use_svm = true;
  }

  const char* level_env = getenv("VHSAKMT_DEBUG_LEVEL");
  if (level_env != NULL) {
    vhsakmt_debug_level = atoi(level_env);
  }
}
/*
 * Open the virtio KFD backend.
 *
 * Initializes the device via vhsakmt_device_init(), applies environment
 * overrides, then returns the result of vhsakmt_openKFD_cmd().
 *
 * Returns HSAKMT_STATUS_ERROR when device initialization yields no device.
 */
HSAKMT_STATUS HSAKMTAPI vhsaKmtOpenKFD(void) {
vhsakmt_device_handle dev;
/*
 * Debug level is read here, before device initialization.
 * NOTE(review): vhsakmt_init_vars_from_env() below reads
 * VHSAKMT_DEBUG_LEVEL again; these two lines may be pre-change leftovers
 * of a diff - confirm whether both reads are intended.
 */
char* d = getenv("VHSAKMT_DEBUG_LEVEL");
if (d) vhsakmt_debug_level = atoi(d);
dev = vhsakmt_device_init();
if (!dev) return HSAKMT_STATUS_ERROR;
/* Environment overrides are applied only after the device exists. */
vhsakmt_init_vars_from_env();
return vhsakmt_openKFD_cmd(vhsakmt_dev());
}
@@ -24,6 +24,10 @@
#define VHSAKMT_VIRTIO_PROTO_H
#include "hsakmt/linux/kfd_ioctl.h"
// Forward declaration for HsaKFDContext to avoid dependency issues
typedef struct _HsaKFDContext HsaKFDContext;
#include "hsakmt/hsakmt.h"
#include <drm/amdgpu_drm.h>
@@ -30,7 +30,7 @@
#include "virtio_gpu.h"
#define SHMEM_SZ (25 * 0x1000)
#define SHMEM_SZ (80 * 0x1000)
static int set_context(int fd) {
struct drm_virtgpu_context_set_param params[] = {