SWDEV-420016 - Add more driver side graph APIs
Signed-off-by: shadi <shadi.dashmiz@amd.com> Change-Id: Iff3ee7dcbcd24836f227fdc9bd5ff4b554ac914f
Этот коммит содержится в:
коммит произвёл
Saleel Kudchadker
родитель
9fdddb7c5d
Коммит
f2b01782ac
@@ -954,6 +954,18 @@ typedef hipError_t (*t_hipStreamBeginCaptureToGraph)(hipStream_t stream, hipGrap
|
||||
hipStreamCaptureMode mode);
|
||||
typedef hipError_t (*t_hipGetFuncBySymbol)(hipFunction_t* functionPtr, const void* symbolPtr);
|
||||
|
||||
typedef hipError_t (*t_hipDrvGraphAddMemFreeNode)(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
|
||||
const hipGraphNode_t* dependencies, size_t numDependencies,
|
||||
hipDeviceptr_t dptr);
|
||||
|
||||
typedef hipError_t (*t_hipDrvGraphExecMemcpyNodeSetParams)(hipGraphExec_t hGraphExec,
|
||||
hipGraphNode_t hNode, const HIP_MEMCPY3D* copyParams,
|
||||
hipCtx_t ctx);
|
||||
|
||||
typedef hipError_t (*t_hipDrvGraphExecMemsetNodeSetParams)(hipGraphExec_t hGraphExec,
|
||||
hipGraphNode_t hNode, const HIP_MEMSET_NODE_PARAMS* memsetParams,
|
||||
hipCtx_t ctx);
|
||||
|
||||
// HIP Compiler dispatch table
|
||||
struct HipCompilerDispatchTable {
|
||||
size_t size;
|
||||
@@ -1420,4 +1432,7 @@ struct HipDispatchTable {
|
||||
t_hipGetProcAddress hipGetProcAddress_fn;
|
||||
t_hipStreamBeginCaptureToGraph hipStreamBeginCaptureToGraph_fn;
|
||||
t_hipGetFuncBySymbol hipGetFuncBySymbol_fn;
|
||||
t_hipDrvGraphAddMemFreeNode hipDrvGraphAddMemFreeNode_fn;
|
||||
t_hipDrvGraphExecMemcpyNodeSetParams hipDrvGraphExecMemcpyNodeSetParams_fn;
|
||||
t_hipDrvGraphExecMemsetNodeSetParams hipDrvGraphExecMemsetNodeSetParams_fn;
|
||||
};
|
||||
|
||||
@@ -463,3 +463,6 @@ hipGraphAddNode
|
||||
hipGraphInstantiateWithParams
|
||||
hipStreamBeginCaptureToGraph
|
||||
hipGetFuncBySymbol
|
||||
hipDrvGraphAddMemFreeNode
|
||||
hipDrvGraphExecMemcpyNodeSetParams
|
||||
hipDrvGraphExecMemsetNodeSetParams
|
||||
|
||||
@@ -768,6 +768,13 @@ hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph,
|
||||
const hipGraphEdgeData* dependencyData,
|
||||
size_t numDependencies, hipStreamCaptureMode mode);
|
||||
hipError_t hipGetFuncBySymbol(hipFunction_t* functionPtr, const void* symbolPtr);
|
||||
hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
|
||||
const hipGraphNode_t* dependencies, size_t numDependencies,
|
||||
hipDeviceptr_t dptr);
|
||||
hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
|
||||
const HIP_MEMCPY3D* copyParams, hipCtx_t ctx);
|
||||
hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
|
||||
const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx);
|
||||
} // namespace hip
|
||||
|
||||
namespace hip {
|
||||
@@ -1244,6 +1251,9 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
|
||||
ptrDispatchTable->hipGetProcAddress_fn = hip::hipGetProcAddress;
|
||||
ptrDispatchTable->hipStreamBeginCaptureToGraph_fn = hip::hipStreamBeginCaptureToGraph;
|
||||
ptrDispatchTable->hipGetFuncBySymbol_fn = hip::hipGetFuncBySymbol;
|
||||
ptrDispatchTable->hipDrvGraphAddMemFreeNode_fn = hip::hipDrvGraphAddMemFreeNode;
|
||||
ptrDispatchTable->hipDrvGraphExecMemcpyNodeSetParams_fn = hip::hipDrvGraphExecMemcpyNodeSetParams;
|
||||
ptrDispatchTable->hipDrvGraphExecMemsetNodeSetParams_fn = hip::hipDrvGraphExecMemsetNodeSetParams;
|
||||
}
|
||||
|
||||
#if HIP_ROCPROFILER_REGISTER > 0
|
||||
@@ -1806,7 +1816,9 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipTexRefGetArray_fn, 441)
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipGetProcAddress_fn, 442)
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipStreamBeginCaptureToGraph_fn, 443);
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipGetFuncBySymbol_fn, 444);
|
||||
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphAddMemFreeNode_fn, 445)
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphExecMemcpyNodeSetParams_fn, 446)
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphExecMemsetNodeSetParams_fn, 447)
|
||||
|
||||
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
|
||||
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
|
||||
@@ -1814,7 +1826,7 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipGetFuncBySymbol_fn, 444);
|
||||
// HIP_ENFORCE_ABI(<table>, <functor>, 8)
|
||||
//
|
||||
// HIP_ENFORCE_ABI_VERSIONING(<table>, 9) <- 8 + 1 = 9
|
||||
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 445)
|
||||
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 448)
|
||||
|
||||
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 3,
|
||||
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
|
||||
|
||||
+109
-19
@@ -2501,6 +2501,28 @@ hipError_t hipGraphMemAllocNodeGetParams(hipGraphNode_t node, hipMemAllocNodePar
|
||||
HIP_RETURN(hipSuccess);
|
||||
}
|
||||
|
||||
hipError_t ihipGraphAddMemFreeNode(hip::GraphNode** graphNode, hip::Graph* graph,
|
||||
hip::GraphNode* const* pDependencies, size_t numDependencies,
|
||||
void* dptr) {
|
||||
// Is memory passed to be free'd valid
|
||||
size_t offset = 0;
|
||||
auto memory = getMemoryObject(dptr, offset);
|
||||
if (memory == nullptr) {
|
||||
if (HIP_MEM_POOL_USE_VM) {
|
||||
// When VM is on the address must be valid and may point to a VA object
|
||||
memory = amd::MemObjMap::FindVirtualMemObj(dptr);
|
||||
}
|
||||
if (memory == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
}
|
||||
|
||||
auto mem_free_node = new hip::GraphMemFreeNode(dptr);
|
||||
*graphNode = mem_free_node;
|
||||
auto status =
|
||||
ihipGraphAddNode(*graphNode, graph, pDependencies, numDependencies);
|
||||
HIP_RETURN(status);
|
||||
}
|
||||
// ================================================================================================
|
||||
hipError_t hipGraphAddMemFreeNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
|
||||
const hipGraphNode_t* pDependencies, size_t numDependencies,
|
||||
@@ -2512,26 +2534,12 @@ hipError_t hipGraphAddMemFreeNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
|
||||
dev_ptr == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
|
||||
// Is memory passed to be free'd valid
|
||||
size_t offset = 0;
|
||||
auto memory = getMemoryObject(dev_ptr, offset);
|
||||
if (memory == nullptr) {
|
||||
if (HIP_MEM_POOL_USE_VM) {
|
||||
// When VM is on the address must be valid and may point to a VA object
|
||||
memory = amd::MemObjMap::FindVirtualMemObj(dev_ptr);
|
||||
}
|
||||
if (memory == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
}
|
||||
|
||||
auto mem_free_node = new hip::GraphMemFreeNode(dev_ptr);
|
||||
hip::GraphNode* node = mem_free_node;
|
||||
hip::GraphNode* pNode;
|
||||
auto status =
|
||||
ihipGraphAddNode(node, reinterpret_cast<hip::Graph*>(graph),
|
||||
reinterpret_cast<hip::GraphNode* const*>(pDependencies), numDependencies);
|
||||
*pGraphNode = reinterpret_cast<hipGraphNode_t>(node);
|
||||
ihipGraphAddMemFreeNode(&pNode,
|
||||
reinterpret_cast<hip::Graph*>(graph),
|
||||
reinterpret_cast<hip::GraphNode* const*>(pDependencies), numDependencies, dev_ptr);
|
||||
*pGraphNode = reinterpret_cast<hipGraphNode_t>(pNode);
|
||||
HIP_RETURN(status);
|
||||
}
|
||||
|
||||
@@ -3036,4 +3044,86 @@ hipError_t hipGraphExecExternalSemaphoresWaitNodeSetParams(hipGraphExec_t hGraph
|
||||
nodeParams));
|
||||
}
|
||||
|
||||
hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
|
||||
const hipGraphNode_t* dependencies, size_t numDependencies,
|
||||
hipDeviceptr_t dptr) {
|
||||
HIP_INIT_API(hipDrvGraphAddMemFreeNode, phGraphNode, hGraph, dependencies, numDependencies, dptr);
|
||||
if (phGraphNode == nullptr || hGraph == nullptr ||
|
||||
((numDependencies > 0 && dependencies == nullptr) ||
|
||||
(dependencies != nullptr && numDependencies == 0)) ||
|
||||
dptr == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
// Is memory passed to be free'd valid
|
||||
size_t offset = 0;
|
||||
auto memory = getMemoryObject(dptr, offset);
|
||||
if (memory == nullptr) {
|
||||
if (HIP_MEM_POOL_USE_VM) {
|
||||
// When VM is on the address must be valid and may point to a VA object
|
||||
memory = amd::MemObjMap::FindVirtualMemObj(dptr);
|
||||
}
|
||||
if (memory == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
}
|
||||
hip::GraphNode* pNode;
|
||||
auto status =
|
||||
ihipGraphAddMemFreeNode(&pNode,
|
||||
reinterpret_cast<hip::Graph*>(hGraph),
|
||||
reinterpret_cast<hip::GraphNode* const*>(dependencies), numDependencies, dptr);
|
||||
*phGraphNode = reinterpret_cast<hipGraphNode_t>(pNode);
|
||||
HIP_RETURN(status);
|
||||
}
|
||||
|
||||
hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
|
||||
const HIP_MEMCPY3D* copyParams, hipCtx_t ctx) {
|
||||
HIP_INIT_API(hipDrvGraphExecMemcpyNodeSetParams, hGraphExec, hNode, copyParams);
|
||||
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
|
||||
if (hGraphExec == nullptr ||
|
||||
!hip::GraphNode::isNodeValid(reinterpret_cast<hip::GraphNode*>(n))) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
if (ihipDrvMemcpy3D_validate(copyParams) != hipSuccess) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
// Check if pNodeParams passed is a empty struct
|
||||
if (((copyParams->srcArray == 0) && (copyParams->srcHost == nullptr)
|
||||
&& (copyParams->srcDevice == nullptr)) ||
|
||||
((copyParams->dstArray == 0) && (copyParams->dstHost == nullptr)
|
||||
&& (copyParams->dstDevice == nullptr))) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphExec*>(hGraphExec)->GetClonedNode(n);
|
||||
if (clonedNode == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
HIP_RETURN(reinterpret_cast<hip::GraphDrvMemcpyNode*>(clonedNode)->SetParams(copyParams));
|
||||
}
|
||||
|
||||
hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
|
||||
const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx) {
|
||||
HIP_INIT_API(hipDrvGraphExecMemsetNodeSetParams, hGraphExec, hNode, memsetParams);
|
||||
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
|
||||
|
||||
if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || memsetParams == nullptr ||
|
||||
memsetParams->dst == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
hipMemsetParams pmemsetParams;
|
||||
pmemsetParams.dst = reinterpret_cast<void*>(memsetParams->dst);
|
||||
pmemsetParams.elementSize = memsetParams->elementSize;
|
||||
pmemsetParams.height = memsetParams->height;
|
||||
pmemsetParams.pitch = memsetParams->pitch;
|
||||
pmemsetParams.value = memsetParams->value;
|
||||
pmemsetParams.width = memsetParams->width;
|
||||
if (ihipGraphMemsetParams_validate(&pmemsetParams) != hipSuccess) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphExec*>(hGraphExec)->GetClonedNode(n);
|
||||
if (clonedNode == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
HIP_RETURN(reinterpret_cast<hip::GraphMemsetNode*>(clonedNode)->SetParams(memsetParams, true));
|
||||
}
|
||||
|
||||
} // namespace hip
|
||||
|
||||
@@ -560,6 +560,9 @@ local:
|
||||
hip_6.2 {
|
||||
global:
|
||||
hipGetFuncBySymbol;
|
||||
hipDrvGraphExecMemcpyNodeSetParams;
|
||||
hipDrvGraphExecMemsetNodeSetParams;
|
||||
hipDrvGraphAddMemFreeNode;
|
||||
local:
|
||||
*;
|
||||
} hip_6.1;
|
||||
|
||||
@@ -1743,3 +1743,20 @@ hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph,
|
||||
hipError_t hipGetFuncBySymbol(hipFunction_t* functionPtr, const void* symbolPtr) {
|
||||
return hip::GetHipDispatchTable()->hipGetFuncBySymbol_fn(functionPtr, symbolPtr);
|
||||
}
|
||||
hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
|
||||
const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx) {
|
||||
return hip::GetHipDispatchTable()->hipDrvGraphExecMemsetNodeSetParams_fn(hGraphExec, hNode,
|
||||
memsetParams, ctx);
|
||||
}
|
||||
hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
|
||||
const hipGraphNode_t* dependencies, size_t numDependencies,
|
||||
hipDeviceptr_t dptr) {
|
||||
return hip::GetHipDispatchTable()->hipDrvGraphAddMemFreeNode_fn(phGraphNode, hGraph,
|
||||
dependencies, numDependencies,
|
||||
dptr);
|
||||
}
|
||||
hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
|
||||
const HIP_MEMCPY3D* copyParams, hipCtx_t ctx) {
|
||||
return hip::GetHipDispatchTable()->hipDrvGraphExecMemcpyNodeSetParams_fn(hGraphExec, hNode,
|
||||
copyParams, ctx);
|
||||
}
|
||||
|
||||
Ссылка в новой задаче
Block a user