SWDEV-420016 - Add more driver side graph APIs

Signed-off-by: shadi <shadi.dashmiz@amd.com>
Change-Id: Iff3ee7dcbcd24836f227fdc9bd5ff4b554ac914f


[ROCm/clr commit: f2b01782ac]
This commit is contained in:
shadi
2024-01-15 14:34:56 -05:00
committed by Saleel Kudchadker
parent 08e7942cf8
commit 15b8dc404c
6 changed files with 161 additions and 21 deletions
@@ -954,6 +954,18 @@ typedef hipError_t (*t_hipStreamBeginCaptureToGraph)(hipStream_t stream, hipGrap
hipStreamCaptureMode mode);
typedef hipError_t (*t_hipGetFuncBySymbol)(hipFunction_t* functionPtr, const void* symbolPtr);
typedef hipError_t (*t_hipDrvGraphAddMemFreeNode)(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
const hipGraphNode_t* dependencies, size_t numDependencies,
hipDeviceptr_t dptr);
typedef hipError_t (*t_hipDrvGraphExecMemcpyNodeSetParams)(hipGraphExec_t hGraphExec,
hipGraphNode_t hNode, const HIP_MEMCPY3D* copyParams,
hipCtx_t ctx);
typedef hipError_t (*t_hipDrvGraphExecMemsetNodeSetParams)(hipGraphExec_t hGraphExec,
hipGraphNode_t hNode, const HIP_MEMSET_NODE_PARAMS* memsetParams,
hipCtx_t ctx);
// HIP Compiler dispatch table
struct HipCompilerDispatchTable {
size_t size;
@@ -1420,4 +1432,7 @@ struct HipDispatchTable {
t_hipGetProcAddress hipGetProcAddress_fn;
t_hipStreamBeginCaptureToGraph hipStreamBeginCaptureToGraph_fn;
t_hipGetFuncBySymbol hipGetFuncBySymbol_fn;
t_hipDrvGraphAddMemFreeNode hipDrvGraphAddMemFreeNode_fn;
t_hipDrvGraphExecMemcpyNodeSetParams hipDrvGraphExecMemcpyNodeSetParams_fn;
t_hipDrvGraphExecMemsetNodeSetParams hipDrvGraphExecMemsetNodeSetParams_fn;
};
+3
View File
@@ -463,3 +463,6 @@ hipGraphAddNode
hipGraphInstantiateWithParams
hipStreamBeginCaptureToGraph
hipGetFuncBySymbol
hipDrvGraphAddMemFreeNode
hipDrvGraphExecMemcpyNodeSetParams
hipDrvGraphExecMemsetNodeSetParams
+14 -2
View File
@@ -768,6 +768,13 @@ hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph,
const hipGraphEdgeData* dependencyData,
size_t numDependencies, hipStreamCaptureMode mode);
hipError_t hipGetFuncBySymbol(hipFunction_t* functionPtr, const void* symbolPtr);
hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
const hipGraphNode_t* dependencies, size_t numDependencies,
hipDeviceptr_t dptr);
hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
const HIP_MEMCPY3D* copyParams, hipCtx_t ctx);
hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx);
} // namespace hip
namespace hip {
@@ -1244,6 +1251,9 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
ptrDispatchTable->hipGetProcAddress_fn = hip::hipGetProcAddress;
ptrDispatchTable->hipStreamBeginCaptureToGraph_fn = hip::hipStreamBeginCaptureToGraph;
ptrDispatchTable->hipGetFuncBySymbol_fn = hip::hipGetFuncBySymbol;
ptrDispatchTable->hipDrvGraphAddMemFreeNode_fn = hip::hipDrvGraphAddMemFreeNode;
ptrDispatchTable->hipDrvGraphExecMemcpyNodeSetParams_fn = hip::hipDrvGraphExecMemcpyNodeSetParams;
ptrDispatchTable->hipDrvGraphExecMemsetNodeSetParams_fn = hip::hipDrvGraphExecMemsetNodeSetParams;
}
#if HIP_ROCPROFILER_REGISTER > 0
@@ -1806,7 +1816,9 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipTexRefGetArray_fn, 441)
HIP_ENFORCE_ABI(HipDispatchTable, hipGetProcAddress_fn, 442)
HIP_ENFORCE_ABI(HipDispatchTable, hipStreamBeginCaptureToGraph_fn, 443);
HIP_ENFORCE_ABI(HipDispatchTable, hipGetFuncBySymbol_fn, 444);
HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphAddMemFreeNode_fn, 445)
HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphExecMemcpyNodeSetParams_fn, 446)
HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphExecMemsetNodeSetParams_fn, 447)
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
@@ -1814,7 +1826,7 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipGetFuncBySymbol_fn, 444);
// HIP_ENFORCE_ABI(<table>, <functor>, 8)
//
// HIP_ENFORCE_ABI_VERSIONING(<table>, 9) <- 8 + 1 = 9
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 445)
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 448)
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 3,
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
+109 -19
View File
@@ -2501,6 +2501,28 @@ hipError_t hipGraphMemAllocNodeGetParams(hipGraphNode_t node, hipMemAllocNodePar
HIP_RETURN(hipSuccess);
}
hipError_t ihipGraphAddMemFreeNode(hip::GraphNode** graphNode, hip::Graph* graph,
hip::GraphNode* const* pDependencies, size_t numDependencies,
void* dptr) {
// Is memory passed to be free'd valid
size_t offset = 0;
auto memory = getMemoryObject(dptr, offset);
if (memory == nullptr) {
if (HIP_MEM_POOL_USE_VM) {
// When VM is on the address must be valid and may point to a VA object
memory = amd::MemObjMap::FindVirtualMemObj(dptr);
}
if (memory == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
}
auto mem_free_node = new hip::GraphMemFreeNode(dptr);
*graphNode = mem_free_node;
auto status =
ihipGraphAddNode(*graphNode, graph, pDependencies, numDependencies);
HIP_RETURN(status);
}
// ================================================================================================
hipError_t hipGraphAddMemFreeNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
const hipGraphNode_t* pDependencies, size_t numDependencies,
@@ -2512,26 +2534,12 @@ hipError_t hipGraphAddMemFreeNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
dev_ptr == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
// Is memory passed to be free'd valid
size_t offset = 0;
auto memory = getMemoryObject(dev_ptr, offset);
if (memory == nullptr) {
if (HIP_MEM_POOL_USE_VM) {
// When VM is on the address must be valid and may point to a VA object
memory = amd::MemObjMap::FindVirtualMemObj(dev_ptr);
}
if (memory == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
}
auto mem_free_node = new hip::GraphMemFreeNode(dev_ptr);
hip::GraphNode* node = mem_free_node;
hip::GraphNode* pNode;
auto status =
ihipGraphAddNode(node, reinterpret_cast<hip::Graph*>(graph),
reinterpret_cast<hip::GraphNode* const*>(pDependencies), numDependencies);
*pGraphNode = reinterpret_cast<hipGraphNode_t>(node);
ihipGraphAddMemFreeNode(&pNode,
reinterpret_cast<hip::Graph*>(graph),
reinterpret_cast<hip::GraphNode* const*>(pDependencies), numDependencies, dev_ptr);
*pGraphNode = reinterpret_cast<hipGraphNode_t>(pNode);
HIP_RETURN(status);
}
@@ -3036,4 +3044,86 @@ hipError_t hipGraphExecExternalSemaphoresWaitNodeSetParams(hipGraphExec_t hGraph
nodeParams));
}
hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
const hipGraphNode_t* dependencies, size_t numDependencies,
hipDeviceptr_t dptr) {
HIP_INIT_API(hipDrvGraphAddMemFreeNode, phGraphNode, hGraph, dependencies, numDependencies, dptr);
if (phGraphNode == nullptr || hGraph == nullptr ||
((numDependencies > 0 && dependencies == nullptr) ||
(dependencies != nullptr && numDependencies == 0)) ||
dptr == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
// Is memory passed to be free'd valid
size_t offset = 0;
auto memory = getMemoryObject(dptr, offset);
if (memory == nullptr) {
if (HIP_MEM_POOL_USE_VM) {
// When VM is on the address must be valid and may point to a VA object
memory = amd::MemObjMap::FindVirtualMemObj(dptr);
}
if (memory == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
}
hip::GraphNode* pNode;
auto status =
ihipGraphAddMemFreeNode(&pNode,
reinterpret_cast<hip::Graph*>(hGraph),
reinterpret_cast<hip::GraphNode* const*>(dependencies), numDependencies, dptr);
*phGraphNode = reinterpret_cast<hipGraphNode_t>(pNode);
HIP_RETURN(status);
}
hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
const HIP_MEMCPY3D* copyParams, hipCtx_t ctx) {
HIP_INIT_API(hipDrvGraphExecMemcpyNodeSetParams, hGraphExec, hNode, copyParams);
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
if (hGraphExec == nullptr ||
!hip::GraphNode::isNodeValid(reinterpret_cast<hip::GraphNode*>(n))) {
HIP_RETURN(hipErrorInvalidValue);
}
if (ihipDrvMemcpy3D_validate(copyParams) != hipSuccess) {
HIP_RETURN(hipErrorInvalidValue);
}
// Check if pNodeParams passed is a empty struct
if (((copyParams->srcArray == 0) && (copyParams->srcHost == nullptr)
&& (copyParams->srcDevice == nullptr)) ||
((copyParams->dstArray == 0) && (copyParams->dstHost == nullptr)
&& (copyParams->dstDevice == nullptr))) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphExec*>(hGraphExec)->GetClonedNode(n);
if (clonedNode == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
HIP_RETURN(reinterpret_cast<hip::GraphDrvMemcpyNode*>(clonedNode)->SetParams(copyParams));
}
hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx) {
HIP_INIT_API(hipDrvGraphExecMemsetNodeSetParams, hGraphExec, hNode, memsetParams);
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || memsetParams == nullptr ||
memsetParams->dst == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
hipMemsetParams pmemsetParams;
pmemsetParams.dst = reinterpret_cast<void*>(memsetParams->dst);
pmemsetParams.elementSize = memsetParams->elementSize;
pmemsetParams.height = memsetParams->height;
pmemsetParams.pitch = memsetParams->pitch;
pmemsetParams.value = memsetParams->value;
pmemsetParams.width = memsetParams->width;
if (ihipGraphMemsetParams_validate(&pmemsetParams) != hipSuccess) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphExec*>(hGraphExec)->GetClonedNode(n);
if (clonedNode == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
HIP_RETURN(reinterpret_cast<hip::GraphMemsetNode*>(clonedNode)->SetParams(memsetParams, true));
}
} // namespace hip
+3
View File
@@ -560,6 +560,9 @@ local:
hip_6.2 {
global:
hipGetFuncBySymbol;
hipDrvGraphExecMemcpyNodeSetParams;
hipDrvGraphExecMemsetNodeSetParams;
hipDrvGraphAddMemFreeNode;
local:
*;
} hip_6.1;
@@ -1743,3 +1743,20 @@ hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph,
hipError_t hipGetFuncBySymbol(hipFunction_t* functionPtr, const void* symbolPtr) {
return hip::GetHipDispatchTable()->hipGetFuncBySymbol_fn(functionPtr, symbolPtr);
}
hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx) {
return hip::GetHipDispatchTable()->hipDrvGraphExecMemsetNodeSetParams_fn(hGraphExec, hNode,
memsetParams, ctx);
}
hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
const hipGraphNode_t* dependencies, size_t numDependencies,
hipDeviceptr_t dptr) {
return hip::GetHipDispatchTable()->hipDrvGraphAddMemFreeNode_fn(phGraphNode, hGraph,
dependencies, numDependencies,
dptr);
}
hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
const HIP_MEMCPY3D* copyParams, hipCtx_t ctx) {
return hip::GetHipDispatchTable()->hipDrvGraphExecMemcpyNodeSetParams_fn(hGraphExec, hNode,
copyParams, ctx);
}