From ba07492b2012405095e81ee124d3a354322c31d2 Mon Sep 17 00:00:00 2001 From: sdashmiz Date: Thu, 18 Jan 2024 10:52:32 -0500 Subject: [PATCH] SWDEV-421021 - Add APIs cuMemcpyNodeGet/Set params Signed-off-by: sdashmiz Change-Id: I24bc0da56aad34c9d5876a3d83b59515f11dc3ea [ROCm/clr commit: 57e79802cda7ed25e6ca23a57487c12d613a1c57] --- .../include/hip/amd_detail/hip_api_trace.hpp | 10 +++++++++- projects/clr/hipamd/src/amdhip.def | 2 ++ projects/clr/hipamd/src/hip_api_trace.cpp | 13 ++++++++++--- projects/clr/hipamd/src/hip_graph.cpp | 19 +++++++++++++++++++ projects/clr/hipamd/src/hip_hcc.map.in | 2 ++ .../clr/hipamd/src/hip_table_interface.cpp | 6 ++++++ 6 files changed, 48 insertions(+), 4 deletions(-) diff --git a/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp b/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp index fa6c6b121f..62443460c7 100644 --- a/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp +++ b/projects/clr/hipamd/include/hip/amd_detail/hip_api_trace.hpp @@ -61,7 +61,7 @@ // - Reset any of the *_STEP_VERSION defines to zero if the corresponding *_MAJOR_VERSION increases #define HIP_API_TABLE_STEP_VERSION 0 #define HIP_COMPILER_API_TABLE_STEP_VERSION 0 -#define HIP_RUNTIME_API_TABLE_STEP_VERSION 3 +#define HIP_RUNTIME_API_TABLE_STEP_VERSION 4 // HIP API interface typedef hipError_t (*t___hipPopCallConfiguration)(dim3* gridDim, dim3* blockDim, size_t* sharedMem, @@ -992,6 +992,12 @@ typedef hipError_t (*t_hipGraphExecNodeSetParams)(hipGraphExec_t graphExec, hipG typedef hipError_t (*t_hipExternalMemoryGetMappedMipmappedArray)( hipMipmappedArray_t* mipmap, hipExternalMemory_t extMem, const hipExternalMemoryMipmappedArrayDesc* mipmapDesc); +typedef hipError_t (*t_hipDrvGraphMemcpyNodeGetParams)(hipGraphNode_t hNode, + HIP_MEMCPY3D* nodeParams); + +typedef hipError_t (*t_hipDrvGraphMemcpyNodeSetParams)(hipGraphNode_t hNode, + const HIP_MEMCPY3D* nodeParams); + // HIP Compiler dispatch table struct HipCompilerDispatchTable { size_t size; @@ -1472,4 +1478,6 @@ struct HipDispatchTable { t_hipGraphNodeSetParams hipGraphNodeSetParams_fn; t_hipGraphExecNodeSetParams hipGraphExecNodeSetParams_fn; t_hipExternalMemoryGetMappedMipmappedArray hipExternalMemoryGetMappedMipmappedArray_fn; + t_hipDrvGraphMemcpyNodeGetParams hipDrvGraphMemcpyNodeGetParams_fn; + t_hipDrvGraphMemcpyNodeSetParams hipDrvGraphMemcpyNodeSetParams_fn; }; diff --git a/projects/clr/hipamd/src/amdhip.def b/projects/clr/hipamd/src/amdhip.def index 9a156a459f..3b4836484f 100644 --- a/projects/clr/hipamd/src/amdhip.def +++ b/projects/clr/hipamd/src/amdhip.def @@ -476,3 +476,5 @@ hipMemcpy2DArrayToArray hipGraphExecGetFlags hipGraphNodeSetParams hipGraphExecNodeSetParams +hipDrvGraphMemcpyNodeSetParams +hipDrvGraphMemcpyNodeGetParams diff --git a/projects/clr/hipamd/src/hip_api_trace.cpp b/projects/clr/hipamd/src/hip_api_trace.cpp index 66e2827365..6bd03cf8ac 100644 --- a/projects/clr/hipamd/src/hip_api_trace.cpp +++ b/projects/clr/hipamd/src/hip_api_trace.cpp @@ -762,7 +762,7 @@ hipError_t hipExtGetLastError(); hipError_t hipTexRefGetBorderColor(float* pBorderColor, const textureReference* texRef); hipError_t hipTexRefGetArray(hipArray_t* pArray, const textureReference* texRef); hipError_t hipGetProcAddress(const char* symbol, void** pfn, int hipVersion, uint64_t flags, - hipDriverProcAddressQueryResult* symbolStatus); + hipDriverProcAddressQueryResult* symbolStatus = NULL); hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph, const hipGraphNode_t* dependencies, const hipGraphEdgeData* dependencyData, @@ -796,6 +796,9 @@ hipError_t hipGraphExecNodeSetParams(hipGraphExec_t graphExec, hipGraphNode_t no hipError_t hipExternalMemoryGetMappedMipmappedArray( hipMipmappedArray_t* mipmap, hipExternalMemory_t extMem, const hipExternalMemoryMipmappedArrayDesc* mipmapDesc); +hipError_t hipDrvGraphMemcpyNodeGetParams(hipGraphNode_t hNode, HIP_MEMCPY3D* nodeParams); +hipError_t hipDrvGraphMemcpyNodeSetParams(hipGraphNode_t hNode, const HIP_MEMCPY3D* nodeParams); + } // namespace hip namespace hip { @@ -1287,6 +1290,8 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) { ptrDispatchTable->hipGraphExecNodeSetParams_fn = hip::hipGraphExecNodeSetParams; ptrDispatchTable->hipExternalMemoryGetMappedMipmappedArray_fn = hip::hipExternalMemoryGetMappedMipmappedArray; + ptrDispatchTable->hipDrvGraphMemcpyNodeGetParams_fn = hip::hipDrvGraphMemcpyNodeGetParams; + ptrDispatchTable->hipDrvGraphMemcpyNodeSetParams_fn = hip::hipDrvGraphMemcpyNodeSetParams; } #if HIP_ROCPROFILER_REGISTER > 0 @@ -1863,6 +1868,8 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecGetFlags_fn, 455); HIP_ENFORCE_ABI(HipDispatchTable, hipGraphNodeSetParams_fn, 456); HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecNodeSetParams_fn, 457); HIP_ENFORCE_ABI(HipDispatchTable, hipExternalMemoryGetMappedMipmappedArray_fn, 458) +HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphMemcpyNodeGetParams_fn, 459) +HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphMemcpyNodeSetParams_fn, 460) // if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below // will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.: @@ -1870,9 +1877,9 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipExternalMemoryGetMappedMipmappedArray_fn, 4 // HIP_ENFORCE_ABI(, , 8) // // HIP_ENFORCE_ABI_VERSIONING(
, 9) <- 8 + 1 = 9 -HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 459) +HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 461) -static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 3, +static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 4, "If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function " "pointers and then update this check so it is true"); #endif diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index ee5b94ce23..3139ae7644 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -3271,4 +3271,23 @@ hipError_t hipGraphExecNodeSetParams(hipGraphExec_t graphExec, hipGraphNode_t no } HIP_RETURN(ihipGraphNodeSetParams(clonedNode, nodeParams)); } +hipError_t hipDrvGraphMemcpyNodeGetParams(hipGraphNode_t hNode, HIP_MEMCPY3D* nodeParams) { + HIP_INIT_API(hipDrvGraphMemcpyNodeGetParams, hNode, nodeParams); + if (!hip::GraphNode::isNodeValid(reinterpret_cast(hNode)) || + nodeParams == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + reinterpret_cast(hNode)->GetParams(nodeParams); + HIP_RETURN(hipSuccess); +} + +hipError_t hipDrvGraphMemcpyNodeSetParams(hipGraphNode_t hNode, const HIP_MEMCPY3D* nodeParams) { + HIP_INIT_API(hipDrvGraphMemcpyNodeSetParams, hNode, nodeParams); + if (!hip::GraphNode::isNodeValid(reinterpret_cast(hNode)) || + nodeParams == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + HIP_RETURN(reinterpret_cast(hNode)->SetParams(nodeParams)); +} + } // namespace hip diff --git a/projects/clr/hipamd/src/hip_hcc.map.in b/projects/clr/hipamd/src/hip_hcc.map.in index 89b2c6520e..31ab71a7d7 100644 --- a/projects/clr/hipamd/src/hip_hcc.map.in +++ b/projects/clr/hipamd/src/hip_hcc.map.in @@ -544,6 +544,8 @@ global: hipChooseDeviceR0600; hipGetDevicePropertiesR0600; hipExtGetLastError; + hipDrvGraphMemcpyNodeSetParams; + hipDrvGraphMemcpyNodeGetParams; local: *; } hip_5.6; diff --git a/projects/clr/hipamd/src/hip_table_interface.cpp b/projects/clr/hipamd/src/hip_table_interface.cpp index 1b5331dc3e..2553038271 100644 --- a/projects/clr/hipamd/src/hip_table_interface.cpp +++ b/projects/clr/hipamd/src/hip_table_interface.cpp @@ -1799,3 +1799,9 @@ hipError_t hipMemcpy2DArrayToArray(hipArray_t dst, size_t wOffsetDst, size_t hOf return hip::GetHipDispatchTable()->hipMemcpy2DArrayToArray_fn( dst, wOffsetDst, hOffsetDst, src, wOffsetSrc, hOffsetSrc, width, height, kind); } +hipError_t hipDrvGraphMemcpyNodeGetParams(hipGraphNode_t hNode, HIP_MEMCPY3D* nodeParams) { + return hip::GetHipDispatchTable()->hipDrvGraphMemcpyNodeGetParams_fn(hNode, nodeParams); +} +hipError_t hipDrvGraphMemcpyNodeSetParams(hipGraphNode_t hNode, const HIP_MEMCPY3D* nodeParams) { + return hip::GetHipDispatchTable()->hipDrvGraphMemcpyNodeSetParams_fn(hNode, nodeParams); +}