From 19ce99d104a66d80e29b7744b6c8e9bfb1b00243 Mon Sep 17 00:00:00 2001 From: shadi Date: Mon, 8 Jan 2024 15:39:43 -0500 Subject: [PATCH] SWDEV-421027 - Add more Graph APIs Signed-off-by: shadi Change-Id: I0a1fc284e48317a49ca88d4ed4e3a10e752efd58 [ROCm/clr commit: e705e5e0d99dac7e7ff0d0dc351604ae70976bf3] --- .../include/hip/amd_detail/hip_api_trace.hpp | 10 +++ projects/clr/hipamd/src/amdhip.def | 3 + projects/clr/hipamd/src/hip_api_trace.cpp | 13 ++- projects/clr/hipamd/src/hip_graph.cpp | 90 ++++++++++++++++++- .../clr/hipamd/src/hip_graph_internal.hpp | 1 + projects/clr/hipamd/src/hip_hcc.map.in | 3 + 6 files changed, 116 insertions(+), 4 deletions(-) diff --git a/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp b/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp index 08889fa5f3..413c3cb79b 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp +++ b/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp @@ -982,6 +982,13 @@ typedef hipError_t (*t_hipMemcpy2DArrayToArray)(hipArray_t dst, size_t wOffsetDs size_t height, hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphExecGetFlags)(hipGraphExec_t graphExec, unsigned long long* flags); +typedef hipError_t (*t_hipGraphNodeSetParams)(hipGraphNode_t node, hipGraphNodeParams *nodeParams); +typedef hipError_t (*t_hipGraphExecNodeSetParams)(hipGraphExec_t graphExec, hipGraphNode_t node, + hipGraphNodeParams* nodeParams); + + + // HIP Compiler dispatch table struct HipCompilerDispatchTable { size_t size; @@ -1458,4 +1465,7 @@ struct HipDispatchTable { t_hipMemcpyAtoHAsync hipMemcpyAtoHAsync_fn; t_hipMemcpyHtoAAsync hipMemcpyHtoAAsync_fn; t_hipMemcpy2DArrayToArray hipMemcpy2DArrayToArray_fn; + t_hipGraphExecGetFlags hipGraphExecGetFlags_fn; + t_hipGraphNodeSetParams hipGraphNodeSetParams_fn; + t_hipGraphExecNodeSetParams hipGraphExecNodeSetParams_fn; }; diff --git a/projects/clr/hipamd/src/amdhip.def b/projects/clr/hipamd/src/amdhip.def index 5ad7942013..9a156a459f 100644 --- a/projects/clr/hipamd/src/amdhip.def +++ b/projects/clr/hipamd/src/amdhip.def @@ -473,3 +473,6 @@ hipMemcpyAtoA hipMemcpyAtoHAsync hipMemcpyHtoAAsync hipMemcpy2DArrayToArray +hipGraphExecGetFlags +hipGraphNodeSetParams +hipGraphExecNodeSetParams diff --git a/projects/clr/hipamd/src/hip_api_trace.cpp b/projects/clr/hipamd/src/hip_api_trace.cpp index 551660041a..f58cd87774 100644 --- a/projects/clr/hipamd/src/hip_api_trace.cpp +++ b/projects/clr/hipamd/src/hip_api_trace.cpp @@ -789,6 +789,10 @@ hipError_t hipMemcpyHtoAAsync(hipArray_t dstArray, size_t dstOffset, const void* hipError_t hipMemcpy2DArrayToArray(hipArray_t dst, size_t wOffsetDst, size_t hOffsetDst, hipArray_const_t src, size_t wOffsetSrc, size_t hOffsetSrc, size_t width, size_t height, hipMemcpyKind kind); +hipError_t hipGraphExecGetFlags(hipGraphExec_t graphExec, unsigned long long* flags); +hipError_t hipGraphNodeSetParams(hipGraphNode_t node, hipGraphNodeParams *nodeParams); +hipError_t hipGraphExecNodeSetParams(hipGraphExec_t graphExec, hipGraphNode_t node, + hipGraphNodeParams* nodeParams); } // namespace hip namespace hip { @@ -1275,6 +1279,9 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) { ptrDispatchTable->hipMemcpyAtoHAsync_fn = hip::hipMemcpyAtoHAsync; ptrDispatchTable->hipMemcpyHtoAAsync_fn = hip::hipMemcpyHtoAAsync; ptrDispatchTable->hipMemcpy2DArrayToArray_fn = hip::hipMemcpy2DArrayToArray; + ptrDispatchTable->hipGraphExecGetFlags_fn = hip::hipGraphExecGetFlags; + ptrDispatchTable->hipGraphNodeSetParams_fn = hip::hipGraphNodeSetParams; + ptrDispatchTable->hipGraphExecNodeSetParams_fn = hip::hipGraphExecNodeSetParams; } #if HIP_ROCPROFILER_REGISTER > 0 @@ -1847,7 +1854,9 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpyAtoA_fn, 451) HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpyAtoHAsync_fn, 452) HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpyHtoAAsync_fn, 453) HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpy2DArrayToArray_fn, 454) - +HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecGetFlags_fn, 455); +HIP_ENFORCE_ABI(HipDispatchTable, hipGraphNodeSetParams_fn, 456); +HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecNodeSetParams_fn, 457); // if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below // will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.: @@ -1855,7 +1864,7 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpy2DArrayToArray_fn, 454) // HIP_ENFORCE_ABI(, , 8) // // HIP_ENFORCE_ABI_VERSIONING(
, 9) <- 8 + 1 = 9 -HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 455) +HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 458) static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 3, "If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function " diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index 66b567cb45..807e883fa7 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -2891,11 +2891,11 @@ hipError_t hipGraphAddNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, numDependencies, false); break; case hipGraphNodeTypeExtSemaphoreSignal: - status = hipSuccess; + status = hipErrorNotSupported; // to be added. break; case hipGraphNodeTypeExtSemaphoreWait: - status = hipSuccess; + status = hipErrorNotSupported; // to be added. break; case hipGraphNodeTypeMemAlloc: @@ -3140,4 +3140,90 @@ hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGrap HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(memsetParams, true)); } +hipError_t hipGraphExecGetFlags(hipGraphExec_t graphExec, unsigned long long* flags) { + HIP_INIT_API(hipGraphExecGetFlags, graphExec, flags); + if (graphExec == nullptr || flags == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + hip::GraphExec* pgraphExec = reinterpret_cast(graphExec); + *flags = pgraphExec->GetFlags(); + HIP_RETURN(hipSuccess); +} + +hipError_t ihipGraphNodeSetParams(hip::GraphNode* n, hipGraphNodeParams *nodeParams) { + hipGraphNodeType nodeType = nodeParams->type; + hipError_t status = hipSuccess; + switch(nodeType) { + case hipGraphNodeTypeKernel: + status = reinterpret_cast(n)->SetParams(&nodeParams->kernel); + break; + case hipGraphNodeTypeMemcpy: + status = reinterpret_cast(n)->SetParams( + &nodeParams->memcpy.copyParams); + break; + case hipGraphNodeTypeMemset: + status = + reinterpret_cast(n)->SetParams(&nodeParams->memset); + break; + case hipGraphNodeTypeHost: + status = + reinterpret_cast(n)->SetParams(&nodeParams->host); + break; + case hipGraphNodeTypeGraph: + status = reinterpret_cast(n)->SetParams( + reinterpret_cast(nodeParams->graph.graph)); + break; + case hipGraphNodeTypeWaitEvent: + status = reinterpret_cast(n)->SetParams( + nodeParams->eventWait.event); + break; + case hipGraphNodeTypeEventRecord: + status = reinterpret_cast(n)->SetParams( + nodeParams->eventRecord.event); + break; + case hipGraphNodeTypeExtSemaphoreSignal: + status = hipErrorNotSupported; + // to be added. + break; + case hipGraphNodeTypeExtSemaphoreWait: + status = hipErrorNotSupported; + // to be added. + break; + case hipGraphNodeTypeMemAlloc: + status = hipErrorNotSupported; + break; + case hipGraphNodeTypeMemFree: + status = hipErrorNotSupported; + break; + default: + status = hipErrorInvalidValue; + break; + } + HIP_RETURN(status); +} + +hipError_t hipGraphNodeSetParams(hipGraphNode_t node, hipGraphNodeParams *nodeParams) { + HIP_INIT_API(hipGraphNodeSetParams, node, nodeParams); + hip::GraphNode* n = reinterpret_cast(node); + if (node == nullptr || nodeParams == nullptr || !hip::GraphNode::isNodeValid(n)) { + HIP_RETURN(hipErrorInvalidValue); + } + HIP_RETURN(ihipGraphNodeSetParams(n, nodeParams)); +} + +hipError_t hipGraphExecNodeSetParams(hipGraphExec_t graphExec, hipGraphNode_t node, + hipGraphNodeParams* nodeParams) { + HIP_INIT_API(hipGraphNodeSetParams, graphExec, node, nodeParams); + hip::GraphNode* n = reinterpret_cast(node); + if (node == nullptr || nodeParams == nullptr || graphExec == nullptr + || !hip::GraphNode::isNodeValid(n)) { + HIP_RETURN(hipErrorInvalidValue); + } + hip::GraphNode* clonedNode = reinterpret_cast( + reinterpret_cast(graphExec)->GetClonedNode(n)); + if (clonedNode == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + HIP_RETURN(ihipGraphNodeSetParams(clonedNode, nodeParams)); +} } // namespace hip diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index d69d986b3d..d4faee7d95 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -643,6 +643,7 @@ struct GraphExec : public amd::ReferenceCountedObject { } void ResetQueueIndex() { currentQueueIndex_ = 0; } + uint64_t GetFlags() const { return flags_; } hipError_t Init(); hipError_t CreateStreams(uint32_t num_streams); hipError_t Run(hipStream_t stream); diff --git a/projects/clr/hipamd/src/hip_hcc.map.in b/projects/clr/hipamd/src/hip_hcc.map.in index 4efc55b23f..89b2c6520e 100644 --- a/projects/clr/hipamd/src/hip_hcc.map.in +++ b/projects/clr/hipamd/src/hip_hcc.map.in @@ -570,6 +570,9 @@ global: hipMemcpyAtoHAsync; hipMemcpyHtoAAsync; hipMemcpy2DArrayToArray; + hipGraphExecGetFlags; + hipGraphNodeSetParams; + hipGraphExecNodeSetParams; local: *; } hip_6.1;