From 15b8dc404c081a90dc571cc9149d52415e804bb9 Mon Sep 17 00:00:00 2001 From: shadi Date: Mon, 15 Jan 2024 14:34:56 -0500 Subject: [PATCH] SWDEV-420016 - Add more driver side graph APIs Signed-off-by: shadi Change-Id: Iff3ee7dcbcd24836f227fdc9bd5ff4b554ac914f [ROCm/clr commit: f2b01782ace57e8ab9819df369bac91dea92ae4c] --- .../include/hip/amd_detail/hip_api_trace.hpp | 15 ++ projects/clr/hipamd/src/amdhip.def | 3 + projects/clr/hipamd/src/hip_api_trace.cpp | 16 ++- projects/clr/hipamd/src/hip_graph.cpp | 128 +++++++++++++++--- projects/clr/hipamd/src/hip_hcc.map.in | 3 + .../clr/hipamd/src/hip_table_interface.cpp | 17 +++ 6 files changed, 161 insertions(+), 21 deletions(-) diff --git a/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp b/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp index 96031d3b48..957ea75627 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp +++ b/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp @@ -954,6 +954,18 @@ typedef hipError_t (*t_hipStreamBeginCaptureToGraph)(hipStream_t stream, hipGrap hipStreamCaptureMode mode); typedef hipError_t (*t_hipGetFuncBySymbol)(hipFunction_t* functionPtr, const void* symbolPtr); +typedef hipError_t (*t_hipDrvGraphAddMemFreeNode)(hipGraphNode_t* phGraphNode, hipGraph_t hGraph, + const hipGraphNode_t* dependencies, size_t numDependencies, + hipDeviceptr_t dptr); + +typedef hipError_t (*t_hipDrvGraphExecMemcpyNodeSetParams)(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, const HIP_MEMCPY3D* copyParams, + hipCtx_t ctx); + +typedef hipError_t (*t_hipDrvGraphExecMemsetNodeSetParams)(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, const HIP_MEMSET_NODE_PARAMS* memsetParams, + hipCtx_t ctx); + // HIP Compiler dispatch table struct HipCompilerDispatchTable { size_t size; @@ -1420,4 +1432,7 @@ struct HipDispatchTable { t_hipGetProcAddress hipGetProcAddress_fn; t_hipStreamBeginCaptureToGraph hipStreamBeginCaptureToGraph_fn; t_hipGetFuncBySymbol hipGetFuncBySymbol_fn; + t_hipDrvGraphAddMemFreeNode hipDrvGraphAddMemFreeNode_fn; + t_hipDrvGraphExecMemcpyNodeSetParams hipDrvGraphExecMemcpyNodeSetParams_fn; + t_hipDrvGraphExecMemsetNodeSetParams hipDrvGraphExecMemsetNodeSetParams_fn; }; diff --git a/projects/clr/hipamd/src/amdhip.def b/projects/clr/hipamd/src/amdhip.def index cfc3ae6b5d..c1017af7af 100644 --- a/projects/clr/hipamd/src/amdhip.def +++ b/projects/clr/hipamd/src/amdhip.def @@ -463,3 +463,6 @@ hipGraphAddNode hipGraphInstantiateWithParams hipStreamBeginCaptureToGraph hipGetFuncBySymbol +hipDrvGraphAddMemFreeNode +hipDrvGraphExecMemcpyNodeSetParams +hipDrvGraphExecMemsetNodeSetParams diff --git a/projects/clr/hipamd/src/hip_api_trace.cpp b/projects/clr/hipamd/src/hip_api_trace.cpp index 1479ae1507..05598dd7f1 100644 --- a/projects/clr/hipamd/src/hip_api_trace.cpp +++ b/projects/clr/hipamd/src/hip_api_trace.cpp @@ -768,6 +768,13 @@ hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph, const hipGraphEdgeData* dependencyData, size_t numDependencies, hipStreamCaptureMode mode); hipError_t hipGetFuncBySymbol(hipFunction_t* functionPtr, const void* symbolPtr); +hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph, + const hipGraphNode_t* dependencies, size_t numDependencies, + hipDeviceptr_t dptr); +hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const HIP_MEMCPY3D* copyParams, hipCtx_t ctx); +hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx); } // namespace hip namespace hip { @@ -1244,6 +1251,9 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) { ptrDispatchTable->hipGetProcAddress_fn = hip::hipGetProcAddress; ptrDispatchTable->hipStreamBeginCaptureToGraph_fn = hip::hipStreamBeginCaptureToGraph; ptrDispatchTable->hipGetFuncBySymbol_fn = hip::hipGetFuncBySymbol; + ptrDispatchTable->hipDrvGraphAddMemFreeNode_fn = hip::hipDrvGraphAddMemFreeNode; + ptrDispatchTable->hipDrvGraphExecMemcpyNodeSetParams_fn = hip::hipDrvGraphExecMemcpyNodeSetParams; + ptrDispatchTable->hipDrvGraphExecMemsetNodeSetParams_fn = hip::hipDrvGraphExecMemsetNodeSetParams; } #if HIP_ROCPROFILER_REGISTER > 0 @@ -1806,7 +1816,9 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipTexRefGetArray_fn, 441) HIP_ENFORCE_ABI(HipDispatchTable, hipGetProcAddress_fn, 442) HIP_ENFORCE_ABI(HipDispatchTable, hipStreamBeginCaptureToGraph_fn, 443); HIP_ENFORCE_ABI(HipDispatchTable, hipGetFuncBySymbol_fn, 444); - +HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphAddMemFreeNode_fn, 445) +HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphExecMemcpyNodeSetParams_fn, 446) +HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphExecMemsetNodeSetParams_fn, 447) // if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below // will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.: @@ -1814,7 +1826,7 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipGetFuncBySymbol_fn, 444); // HIP_ENFORCE_ABI(, , 8) // // HIP_ENFORCE_ABI_VERSIONING(
, 9) <- 8 + 1 = 9 -HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 445) +HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 448) static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 3, "If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function " diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index bf13f3cafd..8440be6f93 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -2501,6 +2501,28 @@ hipError_t hipGraphMemAllocNodeGetParams(hipGraphNode_t node, hipMemAllocNodePar HIP_RETURN(hipSuccess); } +hipError_t ihipGraphAddMemFreeNode(hip::GraphNode** graphNode, hip::Graph* graph, + hip::GraphNode* const* pDependencies, size_t numDependencies, + void* dptr) { + // Is memory passed to be free'd valid + size_t offset = 0; + auto memory = getMemoryObject(dptr, offset); + if (memory == nullptr) { + if (HIP_MEM_POOL_USE_VM) { + // When VM is on the address must be valid and may point to a VA object + memory = amd::MemObjMap::FindVirtualMemObj(dptr); + } + if (memory == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + } + + auto mem_free_node = new hip::GraphMemFreeNode(dptr); + *graphNode = mem_free_node; + auto status = + ihipGraphAddNode(*graphNode, graph, pDependencies, numDependencies); + HIP_RETURN(status); +} // ================================================================================================ hipError_t hipGraphAddMemFreeNode(hipGraphNode_t* pGraphNode, hipGraph_t graph, const hipGraphNode_t* pDependencies, size_t numDependencies, @@ -2512,26 +2534,12 @@ hipError_t hipGraphAddMemFreeNode(hipGraphNode_t* pGraphNode, hipGraph_t graph, dev_ptr == nullptr) { HIP_RETURN(hipErrorInvalidValue); } - - // Is memory passed to be free'd valid - size_t offset = 0; - auto memory = getMemoryObject(dev_ptr, offset); - if (memory == nullptr) { - if (HIP_MEM_POOL_USE_VM) { - // When VM is on the address must be valid and may point to a VA object - memory = amd::MemObjMap::FindVirtualMemObj(dev_ptr); - } - if (memory == nullptr) { - HIP_RETURN(hipErrorInvalidValue); - } - } - - auto mem_free_node = new hip::GraphMemFreeNode(dev_ptr); - hip::GraphNode* node = mem_free_node; + hip::GraphNode* pNode; auto status = - ihipGraphAddNode(node, reinterpret_cast(graph), - reinterpret_cast(pDependencies), numDependencies); - *pGraphNode = reinterpret_cast(node); + ihipGraphAddMemFreeNode(&pNode, + reinterpret_cast(graph), + reinterpret_cast(pDependencies), numDependencies, dev_ptr); + *pGraphNode = reinterpret_cast(pNode); HIP_RETURN(status); } @@ -3036,4 +3044,86 @@ hipError_t hipGraphExecExternalSemaphoresWaitNodeSetParams(hipGraphExec_t hGraph nodeParams)); } +hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph, + const hipGraphNode_t* dependencies, size_t numDependencies, + hipDeviceptr_t dptr) { + HIP_INIT_API(hipDrvGraphAddMemFreeNode, phGraphNode, hGraph, dependencies, numDependencies, dptr); + if (phGraphNode == nullptr || hGraph == nullptr || + ((numDependencies > 0 && dependencies == nullptr) || + (dependencies != nullptr && numDependencies == 0)) || + dptr == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + // Is memory passed to be free'd valid + size_t offset = 0; + auto memory = getMemoryObject(dptr, offset); + if (memory == nullptr) { + if (HIP_MEM_POOL_USE_VM) { + // When VM is on the address must be valid and may point to a VA object + memory = amd::MemObjMap::FindVirtualMemObj(dptr); + } + if (memory == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + } + hip::GraphNode* pNode; + auto status = + ihipGraphAddMemFreeNode(&pNode, + reinterpret_cast(hGraph), + reinterpret_cast(dependencies), numDependencies, dptr); + *phGraphNode = reinterpret_cast(pNode); + HIP_RETURN(status); +} + +hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const HIP_MEMCPY3D* copyParams, hipCtx_t ctx) { + HIP_INIT_API(hipDrvGraphExecMemcpyNodeSetParams, hGraphExec, hNode, copyParams); + hip::GraphNode* n = reinterpret_cast(hNode); + if (hGraphExec == nullptr || + !hip::GraphNode::isNodeValid(reinterpret_cast(n))) { + HIP_RETURN(hipErrorInvalidValue); + } + if (ihipDrvMemcpy3D_validate(copyParams) != hipSuccess) { + HIP_RETURN(hipErrorInvalidValue); + } + // Check if pNodeParams passed is a empty struct + if (((copyParams->srcArray == 0) && (copyParams->srcHost == nullptr) + && (copyParams->srcDevice == nullptr)) || + ((copyParams->dstArray == 0) && (copyParams->dstHost == nullptr) + && (copyParams->dstDevice == nullptr))) { + HIP_RETURN(hipErrorInvalidValue); + } + hip::GraphNode* clonedNode = reinterpret_cast(hGraphExec)->GetClonedNode(n); + if (clonedNode == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(copyParams)); +} + +hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx) { + HIP_INIT_API(hipDrvGraphExecMemsetNodeSetParams, hGraphExec, hNode, memsetParams); + hip::GraphNode* n = reinterpret_cast(hNode); + + if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || memsetParams == nullptr || + memsetParams->dst == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + hipMemsetParams pmemsetParams; + pmemsetParams.dst = reinterpret_cast(memsetParams->dst); + pmemsetParams.elementSize = memsetParams->elementSize; + pmemsetParams.height = memsetParams->height; + pmemsetParams.pitch = memsetParams->pitch; + pmemsetParams.value = memsetParams->value; + pmemsetParams.width = memsetParams->width; + if (ihipGraphMemsetParams_validate(&pmemsetParams) != hipSuccess) { + HIP_RETURN(hipErrorInvalidValue); + } + hip::GraphNode* clonedNode = reinterpret_cast(hGraphExec)->GetClonedNode(n); + if (clonedNode == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(memsetParams, true)); +} + } // namespace hip diff --git a/projects/clr/hipamd/src/hip_hcc.map.in b/projects/clr/hipamd/src/hip_hcc.map.in index 9bfa41d49a..2ec315cba0 100644 --- a/projects/clr/hipamd/src/hip_hcc.map.in +++ b/projects/clr/hipamd/src/hip_hcc.map.in @@ -560,6 +560,9 @@ local: hip_6.2 { global: hipGetFuncBySymbol; + hipDrvGraphExecMemcpyNodeSetParams; + hipDrvGraphExecMemsetNodeSetParams; + hipDrvGraphAddMemFreeNode; local: *; } hip_6.1; diff --git a/projects/clr/hipamd/src/hip_table_interface.cpp b/projects/clr/hipamd/src/hip_table_interface.cpp index 7173691e44..0b92c8696c 100644 --- a/projects/clr/hipamd/src/hip_table_interface.cpp +++ b/projects/clr/hipamd/src/hip_table_interface.cpp @@ -1743,3 +1743,20 @@ hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph, hipError_t hipGetFuncBySymbol(hipFunction_t* functionPtr, const void* symbolPtr) { return hip::GetHipDispatchTable()->hipGetFuncBySymbol_fn(functionPtr, symbolPtr); } +hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx) { + return hip::GetHipDispatchTable()->hipDrvGraphExecMemsetNodeSetParams_fn(hGraphExec, hNode, + memsetParams, ctx); +} +hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph, + const hipGraphNode_t* dependencies, size_t numDependencies, + hipDeviceptr_t dptr) { + return hip::GetHipDispatchTable()->hipDrvGraphAddMemFreeNode_fn(phGraphNode, hGraph, + dependencies, numDependencies, + dptr); +} +hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const HIP_MEMCPY3D* copyParams, hipCtx_t ctx) { + return hip::GetHipDispatchTable()->hipDrvGraphExecMemcpyNodeSetParams_fn(hGraphExec, hNode, + copyParams, ctx); +}