diff --git a/hipamd/include/hip/amd_detail/hip_api_trace.hpp b/hipamd/include/hip/amd_detail/hip_api_trace.hpp
index 96031d3b48..957ea75627 100644
--- a/hipamd/include/hip/amd_detail/hip_api_trace.hpp
+++ b/hipamd/include/hip/amd_detail/hip_api_trace.hpp
@@ -954,6 +954,18 @@ typedef hipError_t (*t_hipStreamBeginCaptureToGraph)(hipStream_t stream, hipGrap
hipStreamCaptureMode mode);
typedef hipError_t (*t_hipGetFuncBySymbol)(hipFunction_t* functionPtr, const void* symbolPtr);
+typedef hipError_t (*t_hipDrvGraphAddMemFreeNode)(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
+ const hipGraphNode_t* dependencies, size_t numDependencies,
+ hipDeviceptr_t dptr);
+
+typedef hipError_t (*t_hipDrvGraphExecMemcpyNodeSetParams)(hipGraphExec_t hGraphExec,
+ hipGraphNode_t hNode, const HIP_MEMCPY3D* copyParams,
+ hipCtx_t ctx);
+
+typedef hipError_t (*t_hipDrvGraphExecMemsetNodeSetParams)(hipGraphExec_t hGraphExec,
+ hipGraphNode_t hNode, const HIP_MEMSET_NODE_PARAMS* memsetParams,
+ hipCtx_t ctx);
+
// HIP Compiler dispatch table
struct HipCompilerDispatchTable {
size_t size;
@@ -1420,4 +1432,7 @@ struct HipDispatchTable {
t_hipGetProcAddress hipGetProcAddress_fn;
t_hipStreamBeginCaptureToGraph hipStreamBeginCaptureToGraph_fn;
t_hipGetFuncBySymbol hipGetFuncBySymbol_fn;
+ t_hipDrvGraphAddMemFreeNode hipDrvGraphAddMemFreeNode_fn;
+ t_hipDrvGraphExecMemcpyNodeSetParams hipDrvGraphExecMemcpyNodeSetParams_fn;
+ t_hipDrvGraphExecMemsetNodeSetParams hipDrvGraphExecMemsetNodeSetParams_fn;
};
diff --git a/hipamd/src/amdhip.def b/hipamd/src/amdhip.def
index cfc3ae6b5d..c1017af7af 100644
--- a/hipamd/src/amdhip.def
+++ b/hipamd/src/amdhip.def
@@ -463,3 +463,6 @@ hipGraphAddNode
hipGraphInstantiateWithParams
hipStreamBeginCaptureToGraph
hipGetFuncBySymbol
+hipDrvGraphAddMemFreeNode
+hipDrvGraphExecMemcpyNodeSetParams
+hipDrvGraphExecMemsetNodeSetParams
diff --git a/hipamd/src/hip_api_trace.cpp b/hipamd/src/hip_api_trace.cpp
index 1479ae1507..05598dd7f1 100644
--- a/hipamd/src/hip_api_trace.cpp
+++ b/hipamd/src/hip_api_trace.cpp
@@ -768,6 +768,13 @@ hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph,
const hipGraphEdgeData* dependencyData,
size_t numDependencies, hipStreamCaptureMode mode);
hipError_t hipGetFuncBySymbol(hipFunction_t* functionPtr, const void* symbolPtr);
+hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
+ const hipGraphNode_t* dependencies, size_t numDependencies,
+ hipDeviceptr_t dptr);
+hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
+ const HIP_MEMCPY3D* copyParams, hipCtx_t ctx);
+hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
+ const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx);
} // namespace hip
namespace hip {
@@ -1244,6 +1251,9 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
ptrDispatchTable->hipGetProcAddress_fn = hip::hipGetProcAddress;
ptrDispatchTable->hipStreamBeginCaptureToGraph_fn = hip::hipStreamBeginCaptureToGraph;
ptrDispatchTable->hipGetFuncBySymbol_fn = hip::hipGetFuncBySymbol;
+ ptrDispatchTable->hipDrvGraphAddMemFreeNode_fn = hip::hipDrvGraphAddMemFreeNode;
+ ptrDispatchTable->hipDrvGraphExecMemcpyNodeSetParams_fn = hip::hipDrvGraphExecMemcpyNodeSetParams;
+ ptrDispatchTable->hipDrvGraphExecMemsetNodeSetParams_fn = hip::hipDrvGraphExecMemsetNodeSetParams;
}
#if HIP_ROCPROFILER_REGISTER > 0
@@ -1806,7 +1816,9 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipTexRefGetArray_fn, 441)
HIP_ENFORCE_ABI(HipDispatchTable, hipGetProcAddress_fn, 442)
HIP_ENFORCE_ABI(HipDispatchTable, hipStreamBeginCaptureToGraph_fn, 443);
HIP_ENFORCE_ABI(HipDispatchTable, hipGetFuncBySymbol_fn, 444);
-
+HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphAddMemFreeNode_fn, 445)
+HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphExecMemcpyNodeSetParams_fn, 446)
+HIP_ENFORCE_ABI(HipDispatchTable, hipDrvGraphExecMemsetNodeSetParams_fn, 447)
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
@@ -1814,7 +1826,7 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipGetFuncBySymbol_fn, 444);
// HIP_ENFORCE_ABI(
, , 8)
//
// HIP_ENFORCE_ABI_VERSIONING(, 9) <- 8 + 1 = 9
-HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 445)
+HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 448)
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 3,
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp
index bf13f3cafd..8440be6f93 100644
--- a/hipamd/src/hip_graph.cpp
+++ b/hipamd/src/hip_graph.cpp
@@ -2501,6 +2501,28 @@ hipError_t hipGraphMemAllocNodeGetParams(hipGraphNode_t node, hipMemAllocNodePar
HIP_RETURN(hipSuccess);
}
+hipError_t ihipGraphAddMemFreeNode(hip::GraphNode** graphNode, hip::Graph* graph,
+ hip::GraphNode* const* pDependencies, size_t numDependencies,
+ void* dptr) {
+ // Is memory passed to be free'd valid
+ size_t offset = 0;
+ auto memory = getMemoryObject(dptr, offset);
+ if (memory == nullptr) {
+ if (HIP_MEM_POOL_USE_VM) {
+ // When VM is on the address must be valid and may point to a VA object
+ memory = amd::MemObjMap::FindVirtualMemObj(dptr);
+ }
+ if (memory == nullptr) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ }
+
+ auto mem_free_node = new hip::GraphMemFreeNode(dptr);
+ *graphNode = mem_free_node;
+ auto status =
+ ihipGraphAddNode(*graphNode, graph, pDependencies, numDependencies);
+ HIP_RETURN(status);
+}
// ================================================================================================
hipError_t hipGraphAddMemFreeNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
const hipGraphNode_t* pDependencies, size_t numDependencies,
@@ -2512,26 +2534,12 @@ hipError_t hipGraphAddMemFreeNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
dev_ptr == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
-
- // Is memory passed to be free'd valid
- size_t offset = 0;
- auto memory = getMemoryObject(dev_ptr, offset);
- if (memory == nullptr) {
- if (HIP_MEM_POOL_USE_VM) {
- // When VM is on the address must be valid and may point to a VA object
- memory = amd::MemObjMap::FindVirtualMemObj(dev_ptr);
- }
- if (memory == nullptr) {
- HIP_RETURN(hipErrorInvalidValue);
- }
- }
-
- auto mem_free_node = new hip::GraphMemFreeNode(dev_ptr);
- hip::GraphNode* node = mem_free_node;
+ hip::GraphNode* pNode;
auto status =
- ihipGraphAddNode(node, reinterpret_cast(graph),
- reinterpret_cast(pDependencies), numDependencies);
- *pGraphNode = reinterpret_cast(node);
+ ihipGraphAddMemFreeNode(&pNode,
+ reinterpret_cast(graph),
+ reinterpret_cast(pDependencies), numDependencies, dev_ptr);
+ *pGraphNode = reinterpret_cast(pNode);
HIP_RETURN(status);
}
@@ -3036,4 +3044,86 @@ hipError_t hipGraphExecExternalSemaphoresWaitNodeSetParams(hipGraphExec_t hGraph
nodeParams));
}
+hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
+ const hipGraphNode_t* dependencies, size_t numDependencies,
+ hipDeviceptr_t dptr) {
+ HIP_INIT_API(hipDrvGraphAddMemFreeNode, phGraphNode, hGraph, dependencies, numDependencies, dptr);
+ if (phGraphNode == nullptr || hGraph == nullptr ||
+ ((numDependencies > 0 && dependencies == nullptr) ||
+ (dependencies != nullptr && numDependencies == 0)) ||
+ dptr == nullptr) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ // Is memory passed to be free'd valid
+ size_t offset = 0;
+ auto memory = getMemoryObject(dptr, offset);
+ if (memory == nullptr) {
+ if (HIP_MEM_POOL_USE_VM) {
+ // When VM is on the address must be valid and may point to a VA object
+ memory = amd::MemObjMap::FindVirtualMemObj(dptr);
+ }
+ if (memory == nullptr) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ }
+ hip::GraphNode* pNode;
+ auto status =
+ ihipGraphAddMemFreeNode(&pNode,
+ reinterpret_cast(hGraph),
+ reinterpret_cast(dependencies), numDependencies, dptr);
+ *phGraphNode = reinterpret_cast(pNode);
+ HIP_RETURN(status);
+}
+
+hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
+ const HIP_MEMCPY3D* copyParams, hipCtx_t ctx) {
+ HIP_INIT_API(hipDrvGraphExecMemcpyNodeSetParams, hGraphExec, hNode, copyParams);
+ hip::GraphNode* n = reinterpret_cast(hNode);
+ if (hGraphExec == nullptr ||
+ !hip::GraphNode::isNodeValid(reinterpret_cast(n))) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ if (ihipDrvMemcpy3D_validate(copyParams) != hipSuccess) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ // Check if pNodeParams passed is a empty struct
+ if (((copyParams->srcArray == 0) && (copyParams->srcHost == nullptr)
+ && (copyParams->srcDevice == nullptr)) ||
+ ((copyParams->dstArray == 0) && (copyParams->dstHost == nullptr)
+ && (copyParams->dstDevice == nullptr))) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ hip::GraphNode* clonedNode = reinterpret_cast(hGraphExec)->GetClonedNode(n);
+ if (clonedNode == nullptr) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(copyParams));
+}
+
+hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
+ const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx) {
+ HIP_INIT_API(hipDrvGraphExecMemsetNodeSetParams, hGraphExec, hNode, memsetParams);
+ hip::GraphNode* n = reinterpret_cast(hNode);
+
+ if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || memsetParams == nullptr ||
+ memsetParams->dst == nullptr) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ hipMemsetParams pmemsetParams;
+ pmemsetParams.dst = reinterpret_cast(memsetParams->dst);
+ pmemsetParams.elementSize = memsetParams->elementSize;
+ pmemsetParams.height = memsetParams->height;
+ pmemsetParams.pitch = memsetParams->pitch;
+ pmemsetParams.value = memsetParams->value;
+ pmemsetParams.width = memsetParams->width;
+ if (ihipGraphMemsetParams_validate(&pmemsetParams) != hipSuccess) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ hip::GraphNode* clonedNode = reinterpret_cast(hGraphExec)->GetClonedNode(n);
+ if (clonedNode == nullptr) {
+ HIP_RETURN(hipErrorInvalidValue);
+ }
+ HIP_RETURN(reinterpret_cast(clonedNode)->SetParams(memsetParams, true));
+}
+
} // namespace hip
diff --git a/hipamd/src/hip_hcc.map.in b/hipamd/src/hip_hcc.map.in
index 9bfa41d49a..2ec315cba0 100644
--- a/hipamd/src/hip_hcc.map.in
+++ b/hipamd/src/hip_hcc.map.in
@@ -560,6 +560,9 @@ local:
hip_6.2 {
global:
hipGetFuncBySymbol;
+ hipDrvGraphExecMemcpyNodeSetParams;
+ hipDrvGraphExecMemsetNodeSetParams;
+ hipDrvGraphAddMemFreeNode;
local:
*;
} hip_6.1;
diff --git a/hipamd/src/hip_table_interface.cpp b/hipamd/src/hip_table_interface.cpp
index 7173691e44..0b92c8696c 100644
--- a/hipamd/src/hip_table_interface.cpp
+++ b/hipamd/src/hip_table_interface.cpp
@@ -1743,3 +1743,20 @@ hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph,
hipError_t hipGetFuncBySymbol(hipFunction_t* functionPtr, const void* symbolPtr) {
return hip::GetHipDispatchTable()->hipGetFuncBySymbol_fn(functionPtr, symbolPtr);
}
+hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
+ const HIP_MEMSET_NODE_PARAMS* memsetParams, hipCtx_t ctx) {
+ return hip::GetHipDispatchTable()->hipDrvGraphExecMemsetNodeSetParams_fn(hGraphExec, hNode,
+ memsetParams, ctx);
+}
+hipError_t hipDrvGraphAddMemFreeNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
+ const hipGraphNode_t* dependencies, size_t numDependencies,
+ hipDeviceptr_t dptr) {
+ return hip::GetHipDispatchTable()->hipDrvGraphAddMemFreeNode_fn(phGraphNode, hGraph,
+ dependencies, numDependencies,
+ dptr);
+}
+hipError_t hipDrvGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
+ const HIP_MEMCPY3D* copyParams, hipCtx_t ctx) {
+ return hip::GetHipDispatchTable()->hipDrvGraphExecMemcpyNodeSetParams_fn(hGraphExec, hNode,
+ copyParams, ctx);
+}