From f1c05e902656c15d7a6c229e0b400a1656dbcb5f Mon Sep 17 00:00:00 2001
From: Sourabh Betigeri
Date: Thu, 12 Sep 2024 15:08:43 -0700
Subject: [PATCH] SWDEV-421020 - Adds hipGraphAddBatchMemOp, SetGetParams and execSetParams APIs

Change-Id: Ieccecfe6173cc68fd3c01f86c99f7cc09fe194a3
---
 .../include/hip/amd_detail/hip_api_trace.hpp  |  22 +++-
 hipamd/include/hip/amd_detail/hip_prof_str.h  | 121 +++++++++++++++++-
 hipamd/src/amdhip.def                         |   4 +
 hipamd/src/hip_api_trace.cpp                  |  25 +++-
 hipamd/src/hip_graph.cpp                      |  62 +++++++++
 hipamd/src/hip_graph_internal.hpp             |  62 +++++++++
 hipamd/src/hip_hcc.map.in                     |  10 ++
 hipamd/src/hip_table_interface.cpp            |  20 +++
 8 files changed, 319 insertions(+), 7 deletions(-)

diff --git a/hipamd/include/hip/amd_detail/hip_api_trace.hpp b/hipamd/include/hip/amd_detail/hip_api_trace.hpp
index acfbe1a2e0..efab82a610 100644
--- a/hipamd/include/hip/amd_detail/hip_api_trace.hpp
+++ b/hipamd/include/hip/amd_detail/hip_api_trace.hpp
@@ -61,7 +61,7 @@
 // - Reset any of the *_STEP_VERSION defines to zero if the corresponding *_MAJOR_VERSION increases
 #define HIP_API_TABLE_STEP_VERSION 0
 #define HIP_COMPILER_API_TABLE_STEP_VERSION 0
-#define HIP_RUNTIME_API_TABLE_STEP_VERSION 7
+#define HIP_RUNTIME_API_TABLE_STEP_VERSION 8
 
 // HIP API interface
 typedef hipError_t (*t___hipPopCallConfiguration)(dim3* gridDim, dim3* blockDim, size_t* sharedMem,
@@ -723,7 +723,8 @@ typedef hipError_t (*t_hipStreamWriteValue32)(hipStream_t stream, void* ptr, uin
 typedef hipError_t (*t_hipStreamWriteValue64)(hipStream_t stream, void* ptr, uint64_t value,
                                               unsigned int flags);
 typedef hipError_t (*t_hipStreamBatchMemOp)(hipStream_t stream, unsigned int count,
-                                            hipStreamBatchMemOpParams* paramArray, unsigned int flags);
+                                            hipStreamBatchMemOpParams* paramArray,
+                                            unsigned int flags);
 typedef hipError_t (*t_hipTexObjectCreate)(hipTextureObject_t* pTexObject,
                                            const HIP_RESOURCE_DESC* pResDesc,
                                            const HIP_TEXTURE_DESC* pTexDesc,
@@ -1006,6 +1007,17 @@ typedef hipError_t (*t_hipExtHostAlloc)(void **ptr, size_t size,
 typedef hipError_t (*t_hipDeviceGetTexture1DLinearMaxWidth)(size_t *maxWidthInElements,
                                                             const hipChannelFormatDesc *fmtDesc,
                                                             int device);
+
+typedef hipError_t (*t_hipGraphAddBatchMemOpNode)(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
+                                                  const hipGraphNode_t* dependencies,
+                                                  size_t numDependencies,
+                                                  const hipBatchMemOpNodeParams* nodeParams);
+typedef hipError_t (*t_hipGraphBatchMemOpNodeGetParams)(hipGraphNode_t hNode,
+                                                        hipBatchMemOpNodeParams* nodeParams_out);
+typedef hipError_t (*t_hipGraphBatchMemOpNodeSetParams)(hipGraphNode_t hNode,
+                                                        hipBatchMemOpNodeParams* nodeParams);
+typedef hipError_t (*t_hipGraphExecBatchMemOpNodeSetParams)(
+    hipGraphExec_t hGraphExec, hipGraphNode_t hNode, const hipBatchMemOpNodeParams* nodeParams);
 // HIP Compiler dispatch table
 struct HipCompilerDispatchTable {
   // HIP_COMPILER_API_TABLE_STEP_VERSION == 0
@@ -1524,6 +1536,12 @@ struct HipDispatchTable {
   // HIP_RUNTIME_API_TABLE_STEP_VERSION == 7
   t_hipStreamBatchMemOp hipStreamBatchMemOp_fn;
 
+  // HIP_RUNTIME_API_TABLE_STEP_VERSION == 8
+  t_hipGraphAddBatchMemOpNode hipGraphAddBatchMemOpNode_fn;
+  t_hipGraphBatchMemOpNodeGetParams hipGraphBatchMemOpNodeGetParams_fn;
+  t_hipGraphBatchMemOpNodeSetParams hipGraphBatchMemOpNodeSetParams_fn;
+  t_hipGraphExecBatchMemOpNodeSetParams hipGraphExecBatchMemOpNodeSetParams_fn;
+
   // DO NOT EDIT ABOVE!
   // HIP_RUNTIME_API_TABLE_STEP_VERSION == 7
 
diff --git a/hipamd/include/hip/amd_detail/hip_prof_str.h b/hipamd/include/hip/amd_detail/hip_prof_str.h
index 9cb1e275e6..a7658cfc38 100644
--- a/hipamd/include/hip/amd_detail/hip_prof_str.h
+++ b/hipamd/include/hip/amd_detail/hip_prof_str.h
@@ -426,7 +426,11 @@ enum hip_api_id_t {
   HIP_API_ID_hipSetValidDevices = 406,
   HIP_API_ID_hipExtHostAlloc = 407,
   HIP_API_ID_hipStreamBatchMemOp = 408,
-  HIP_API_ID_LAST = 408,
+  HIP_API_ID_hipGraphAddBatchMemOpNode = 409,
+  HIP_API_ID_hipGraphBatchMemOpNodeGetParams = 410,
+  HIP_API_ID_hipGraphBatchMemOpNodeSetParams = 411,
+  HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams = 412,
+  HIP_API_ID_LAST = 412,
 
   HIP_API_ID_hipChooseDevice = HIP_API_ID_CONCAT(HIP_API_ID_,hipChooseDevice),
   HIP_API_ID_hipGetDeviceProperties = HIP_API_ID_CONCAT(HIP_API_ID_,hipGetDeviceProperties),
@@ -861,6 +865,10 @@ static inline const char* hip_api_name(const uint32_t id) {
     case HIP_API_ID_hipWaitExternalSemaphoresAsync: return "hipWaitExternalSemaphoresAsync";
     case HIP_API_ID_hipExtGetLastError: return "hipExtGetLastError";
     case HIP_API_ID_hipStreamBatchMemOp: return "hipStreamBatchMemOp";
+    case HIP_API_ID_hipGraphAddBatchMemOpNode: return "hipGraphAddBatchMemOpNode";
+    case HIP_API_ID_hipGraphBatchMemOpNodeGetParams: return "hipGraphBatchMemOpNodeGetParams";
+    case HIP_API_ID_hipGraphBatchMemOpNodeSetParams: return "hipGraphBatchMemOpNodeSetParams";
+    case HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams: return "hipGraphExecBatchMemOpNodeSetParams";
   };
   return "unknown";
 };
@@ -1265,6 +1273,10 @@ static inline uint32_t hipApiIdByName(const char* name) {
   if (strcmp("hipUserObjectRetain", name) == 0) return HIP_API_ID_hipUserObjectRetain;
   if (strcmp("hipWaitExternalSemaphoresAsync", name) == 0) return HIP_API_ID_hipWaitExternalSemaphoresAsync;
   if (strcmp("hipStreamBatchMemOp", name) == 0) return HIP_API_ID_hipStreamBatchMemOp;
+  if (strcmp("hipGraphAddBatchMemOpNode", name) == 0) return HIP_API_ID_hipGraphAddBatchMemOpNode;
+  if (strcmp("hipGraphBatchMemOpNodeGetParams", name) == 0) return HIP_API_ID_hipGraphBatchMemOpNodeGetParams;
+  if (strcmp("hipGraphBatchMemOpNodeSetParams", name) == 0) return HIP_API_ID_hipGraphBatchMemOpNodeSetParams;
+  if (strcmp("hipGraphExecBatchMemOpNodeSetParams", name) == 0) return HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams;
   return HIP_API_ID_NONE;
 }
 
@@ -3633,6 +3645,32 @@ typedef struct hip_api_data_s {
       hipStreamBatchMemOpParams paramArray__val;
       unsigned int flags;
     } hipStreamBatchMemOp;
+    struct {
+      hipGraphNode_t* phGraphNode;
+      hipGraphNode_t phGraphNode__val;
+      hipGraph_t hGraph;
+      const hipGraphNode_t* dependencies;
+      hipGraphNode_t dependencies__val;
+      size_t numDependencies;
+      const hipBatchMemOpNodeParams* nodeParams;
+      hipBatchMemOpNodeParams nodeParams__val;
+    } hipGraphAddBatchMemOpNode;
+    struct {
+      hipGraphNode_t hNode;
+      hipBatchMemOpNodeParams* nodeParams_out;
+      hipBatchMemOpNodeParams nodeParams_out__val;
+    } hipGraphBatchMemOpNodeGetParams;
+    struct {
+      hipGraphNode_t hNode;
+      hipBatchMemOpNodeParams* nodeParams;
+      hipBatchMemOpNodeParams nodeParams__val;
+    } hipGraphBatchMemOpNodeSetParams;
+    struct {
+      hipGraphExec_t hGraphExec;
+      hipGraphNode_t hNode;
+      const hipBatchMemOpNodeParams* nodeParams;
+      hipBatchMemOpNodeParams nodeParams__val;
+    } hipGraphExecBatchMemOpNodeSetParams;
   } args;
   uint64_t *phase_data;
 } hip_api_data_t;
@@ -6045,6 +6083,36 @@ typedef struct hip_api_data_s {
   cb_data.args.hipWaitExternalSemaphoresAsync.numExtSems = (unsigned int)numExtSems; \
   cb_data.args.hipWaitExternalSemaphoresAsync.stream = (hipStream_t)stream; \
 };
+
+// hipGraphAddBatchMemOpNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', 'hGraph'),
+// ('const hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'),
+// ('const hipBatchMemOpNodeParams*', 'nodeParams')]
+#define INIT_hipGraphAddBatchMemOpNode_CB_ARGS_DATA(cb_data) { \
+  cb_data.args.hipGraphAddBatchMemOpNode.phGraphNode = (hipGraphNode_t*)phGraphNode; \
+  cb_data.args.hipGraphAddBatchMemOpNode.hGraph = (hipGraph_t)hGraph; \
+  cb_data.args.hipGraphAddBatchMemOpNode.dependencies = (const hipGraphNode_t*)dependencies; \
+  cb_data.args.hipGraphAddBatchMemOpNode.numDependencies = (size_t)numDependencies; \
+  cb_data.args.hipGraphAddBatchMemOpNode.nodeParams = (const hipBatchMemOpNodeParams*)nodeParams; \
+};
+// hipGraphBatchMemOpNodeGetParams[('hipGraphNode_t', 'hNode'),
+// ('hipBatchMemOpNodeParams*', 'nodeParams_out')]
+#define INIT_hipGraphBatchMemOpNodeGetParams_CB_ARGS_DATA(cb_data) { \
+  cb_data.args.hipGraphBatchMemOpNodeGetParams.hNode = (hipGraphNode_t)hNode; \
+  cb_data.args.hipGraphBatchMemOpNodeGetParams.nodeParams_out = (hipBatchMemOpNodeParams*)nodeParams_out; \
+};
+// hipGraphBatchMemOpNodeSetParams[('hipGraphNode_t', 'hNode'),
+// ('hipBatchMemOpNodeParams*', 'nodeParams')]
+#define INIT_hipGraphBatchMemOpNodeSetParams_CB_ARGS_DATA(cb_data) { \
+  cb_data.args.hipGraphBatchMemOpNodeSetParams.hNode = (hipGraphNode_t)hNode; \
+  cb_data.args.hipGraphBatchMemOpNodeSetParams.nodeParams = (hipBatchMemOpNodeParams*)nodeParams; \
+};
+// hipGraphExecBatchMemOpNodeSetParams[('hipGraphExec_t', 'hGraphExec'),
+// ('hipGraphNode_t', 'hNode'), ('const hipBatchMemOpNodeParams*', 'nodeParams')]
+#define INIT_hipGraphExecBatchMemOpNodeSetParams_CB_ARGS_DATA(cb_data) { \
+  cb_data.args.hipGraphExecBatchMemOpNodeSetParams.hGraphExec = (hipGraphExec_t)hGraphExec; \
+  cb_data.args.hipGraphExecBatchMemOpNodeSetParams.hNode = (hipGraphNode_t)hNode; \
+  cb_data.args.hipGraphExecBatchMemOpNodeSetParams.nodeParams = (const hipBatchMemOpNodeParams*)nodeParams; \
+};
 #define INIT_CB_ARGS_DATA(cb_id, cb_data) INIT_##cb_id##_CB_ARGS_DATA(cb_data)
 
 // Macros for non-public API primitives
@@ -7551,6 +7619,29 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
     case HIP_API_ID_hipStreamBatchMemOp:
       if (data->args.hipStreamBatchMemOp.paramArray) data->args.hipStreamBatchMemOp.paramArray__val = *(data->args.hipStreamBatchMemOp.paramArray);
       break;
+// hipGraphAddBatchMemOpNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', 'hGraph'),
+// ('const hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'),
+// ('const hipBatchMemOpNodeParams*', 'nodeParams')]
+    case HIP_API_ID_hipGraphAddBatchMemOpNode:
+      if (data->args.hipGraphAddBatchMemOpNode.phGraphNode) data->args.hipGraphAddBatchMemOpNode.phGraphNode__val = *(data->args.hipGraphAddBatchMemOpNode.phGraphNode);
+      if (data->args.hipGraphAddBatchMemOpNode.dependencies) data->args.hipGraphAddBatchMemOpNode.dependencies__val = *(data->args.hipGraphAddBatchMemOpNode.dependencies);
+      if (data->args.hipGraphAddBatchMemOpNode.nodeParams) data->args.hipGraphAddBatchMemOpNode.nodeParams__val = *(data->args.hipGraphAddBatchMemOpNode.nodeParams);
+      break;
+// hipGraphBatchMemOpNodeGetParams[('hipGraphNode_t', 'hNode'),
+// ('hipBatchMemOpNodeParams*', 'nodeParams_out')]
+    case HIP_API_ID_hipGraphBatchMemOpNodeGetParams:
+      if (data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out) data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out__val = *(data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out);
+      break;
+// hipGraphBatchMemOpNodeSetParams[('hipGraphNode_t', 'hNode'),
+// ('hipBatchMemOpNodeParams*', 'nodeParams')]
+    case HIP_API_ID_hipGraphBatchMemOpNodeSetParams:
+      if (data->args.hipGraphBatchMemOpNodeSetParams.nodeParams) data->args.hipGraphBatchMemOpNodeSetParams.nodeParams__val = *(data->args.hipGraphBatchMemOpNodeSetParams.nodeParams);
+      break;
+// hipGraphExecBatchMemOpNodeSetParams[('hipGraphExec_t', 'hGraphExec'),
+// ('hipGraphNode_t', 'hNode'), ('const hipBatchMemOpNodeParams*', 'nodeParams')]
+    case HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams:
+      if (data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams) data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams__val = *(data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams);
+      break;
 // hipTexRefGetAddress[('hipDeviceptr_t*', 'dev_ptr'), ('const textureReference*', 'texRef')]
     case HIP_API_ID_hipTexRefGetAddress:
       if (data->args.hipTexRefGetAddress.dev_ptr) data->args.hipTexRefGetAddress.dev_ptr__val = *(data->args.hipTexRefGetAddress.dev_ptr);
@@ -10843,6 +10934,34 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
       oss << ", stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipWaitExternalSemaphoresAsync.stream);
       oss << ")";
     break;
+    case HIP_API_ID_hipGraphAddBatchMemOpNode:
+      oss << "hipGraphAddBatchMemOpNode(";
+      oss << "phGraphNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.phGraphNode);
+      oss << ", hGraph="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.hGraph);
+      oss << ", dependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.dependencies);
+      oss << ", numDependencies="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.numDependencies);
+      oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphAddBatchMemOpNode.nodeParams);
+      oss << ")";
+    break;
+    case HIP_API_ID_hipGraphBatchMemOpNodeGetParams:
+      oss << "hipGraphBatchMemOpNodeGetParams(";
+      oss << "hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeGetParams.hNode);
+      oss << ", nodeParams_out="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeGetParams.nodeParams_out);
+      oss << ")";
+    break;
+    case HIP_API_ID_hipGraphBatchMemOpNodeSetParams:
+      oss << "hipGraphBatchMemOpNodeSetParams(";
+      oss << "hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeSetParams.hNode);
+      oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphBatchMemOpNodeSetParams.nodeParams);
+      oss << ")";
+    break;
+    case HIP_API_ID_hipGraphExecBatchMemOpNodeSetParams:
+      oss << "hipGraphExecBatchMemOpNodeSetParams(";
+      oss << "hGraphExec="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphExecBatchMemOpNodeSetParams.hGraphExec);
+      oss << ", hNode="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphExecBatchMemOpNodeSetParams.hNode);
+      oss << ", nodeParams="; roctracer::hip_support::detail::operator<<(oss, data->args.hipGraphExecBatchMemOpNodeSetParams.nodeParams);
+      oss << ")";
+    break;
     default: oss << "unknown";
   };
   return strdup(oss.str().c_str());
diff --git a/hipamd/src/amdhip.def b/hipamd/src/amdhip.def
index 355f180d7b..b8dbd49323 100644
--- a/hipamd/src/amdhip.def
+++ b/hipamd/src/amdhip.def
@@ -481,3 +481,7 @@ hipDrvGraphMemcpyNodeSetParams
 hipDrvGraphMemcpyNodeGetParams
 hipExtHostAlloc
 hipStreamBatchMemOp
+hipGraphAddBatchMemOpNode
+hipGraphBatchMemOpNodeGetParams
+hipGraphBatchMemOpNodeSetParams
+hipGraphExecBatchMemOpNodeSetParams
diff --git a/hipamd/src/hip_api_trace.cpp b/hipamd/src/hip_api_trace.cpp
index b2178d78e8..6d0e27d7ad 100644
--- a/hipamd/src/hip_api_trace.cpp
+++ b/hipamd/src/hip_api_trace.cpp
@@ -804,7 +804,15 @@ hipError_t hipExternalMemoryGetMappedMipmappedArray(
     const hipExternalMemoryMipmappedArrayDesc* mipmapDesc);
 hipError_t hipDrvGraphMemcpyNodeGetParams(hipGraphNode_t hNode, HIP_MEMCPY3D* nodeParams);
 hipError_t hipDrvGraphMemcpyNodeSetParams(hipGraphNode_t hNode, const HIP_MEMCPY3D* nodeParams);
-
+hipError_t hipGraphAddBatchMemOpNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
+                                     const hipGraphNode_t* dependencies, size_t numDependencies,
+                                     const hipBatchMemOpNodeParams* nodeParams);
+hipError_t hipGraphBatchMemOpNodeGetParams(hipGraphNode_t hNode,
+                                           hipBatchMemOpNodeParams* nodeParams_out);
+hipError_t hipGraphBatchMemOpNodeSetParams(hipGraphNode_t hNode,
+                                           hipBatchMemOpNodeParams* nodeParams);
+hipError_t hipGraphExecBatchMemOpNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t hNode,
+                                               const hipBatchMemOpNodeParams* nodeParams);
 }  // namespace hip
 
 namespace hip {
@@ -1301,6 +1309,11 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
       hip::hipExternalMemoryGetMappedMipmappedArray;
   ptrDispatchTable->hipDrvGraphMemcpyNodeGetParams_fn = hip::hipDrvGraphMemcpyNodeGetParams;
   ptrDispatchTable->hipDrvGraphMemcpyNodeSetParams_fn = hip::hipDrvGraphMemcpyNodeSetParams;
+  ptrDispatchTable->hipGraphAddBatchMemOpNode_fn = hip::hipGraphAddBatchMemOpNode;
+  ptrDispatchTable->hipGraphBatchMemOpNodeGetParams_fn = hip::hipGraphBatchMemOpNodeGetParams;
+  ptrDispatchTable->hipGraphBatchMemOpNodeSetParams_fn = hip::hipGraphBatchMemOpNodeSetParams;
+  ptrDispatchTable->hipGraphExecBatchMemOpNodeSetParams_fn =
+      hip::hipGraphExecBatchMemOpNodeSetParams;
 }
 
 #if HIP_ROCPROFILER_REGISTER > 0
@@ -1892,7 +1905,11 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipExtHostAlloc_fn, 461)
 HIP_ENFORCE_ABI(HipDispatchTable, hipDeviceGetTexture1DLinearMaxWidth_fn, 462)
 // HIP_RUNTIME_API_TABLE_STEP_VERSION == 7
 HIP_ENFORCE_ABI(HipDispatchTable, hipStreamBatchMemOp_fn, 463);
-
+// HIP_RUNTIME_API_TABLE_STEP_VERSION == 8
+HIP_ENFORCE_ABI(HipDispatchTable, hipGraphAddBatchMemOpNode_fn, 464);
+HIP_ENFORCE_ABI(HipDispatchTable, hipGraphBatchMemOpNodeGetParams_fn, 465);
+HIP_ENFORCE_ABI(HipDispatchTable, hipGraphBatchMemOpNodeSetParams_fn, 466);
+HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecBatchMemOpNodeSetParams_fn, 467);
 
 // if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
 // will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
@@ -1903,6 +1920,6 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipStreamBatchMemOp_fn, 463);
-HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 464)
+HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 468)
 
-static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 7,
+static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 8,
               "If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
               "pointers and then update this check so it is true");
 #endif
diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp
index b608c03d7c..c229670845 100644
--- a/hipamd/src/hip_graph.cpp
+++ b/hipamd/src/hip_graph.cpp
@@ -3428,3 +3428,65 @@ hipError_t hipDrvGraphMemcpyNodeSetParams(hipGraphNode_t hNode, const HIP_MEMCPY
 }
 
+// Creates a batch memory operation node from nodeParams, adds it to hGraph behind the given
+// dependencies and returns its handle through phGraphNode.
+hipError_t hipGraphAddBatchMemOpNode(hipGraphNode_t* phGraphNode, hipGraph_t hGraph,
+                                     const hipGraphNode_t* dependencies, size_t numDependencies,
+                                     const hipBatchMemOpNodeParams* nodeParams) {
+  HIP_INIT_API(hipGraphAddBatchMemOpNode, phGraphNode, hGraph, dependencies, numDependencies,
+               nodeParams);
+  if (phGraphNode == nullptr || hGraph == nullptr ||
+      (numDependencies > 0 && dependencies == nullptr) || nodeParams == nullptr) {
+    HIP_RETURN(hipErrorInvalidValue);
+  }
+  hip::GraphNode* node = new hip::hipGraphBatchMemOpNode(nodeParams);
+  hipError_t status =
+      ihipGraphAddNode(node, reinterpret_cast<hip::Graph*>(hGraph),
+                       reinterpret_cast<hip::GraphNode* const*>(dependencies), numDependencies);
+  // NOTE(review): the handle is stored even when status is an error; confirm who owns node then.
+  *phGraphNode = reinterpret_cast<hipGraphNode_t>(node);
+  HIP_RETURN(status);
+}
+
+// Copies the parameters currently stored in node hNode into nodeParams_out.
+hipError_t hipGraphBatchMemOpNodeGetParams(hipGraphNode_t hNode,
+                                           hipBatchMemOpNodeParams* nodeParams_out) {
+  HIP_INIT_API(hipGraphBatchMemOpNodeGetParams, hNode, nodeParams_out);
+  hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
+  if (!hip::GraphNode::isNodeValid(n) || nodeParams_out == nullptr) {
+    HIP_RETURN(hipErrorInvalidValue);
+  }
+  // NOTE(review): the node type is not checked before this downcast; confirm that is intended.
+  reinterpret_cast<hip::hipGraphBatchMemOpNode*>(n)->GetParams(nodeParams_out);
+  HIP_RETURN(hipSuccess);
+}
+
+// Replaces the parameters stored in node hNode with nodeParams.
+hipError_t hipGraphBatchMemOpNodeSetParams(hipGraphNode_t hNode,
+                                           hipBatchMemOpNodeParams* nodeParams) {
+  HIP_INIT_API(hipGraphBatchMemOpNodeSetParams, hNode, nodeParams);
+  hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
+  if (!hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) {
+    HIP_RETURN(hipErrorInvalidValue);
+  }
+  // NOTE(review): the node type is not checked before this downcast; confirm that is intended.
+  HIP_RETURN(reinterpret_cast<hip::hipGraphBatchMemOpNode*>(n)->SetParams(nodeParams));
+}
+
+// Replaces the parameters of the node that hGraphExec cloned from hNode.
+hipError_t hipGraphExecBatchMemOpNodeSetParams(hipGraphExec_t hGraphExec,
+                                               hipGraphNode_t hNode,
+                                               const hipBatchMemOpNodeParams* nodeParams) {
+  HIP_INIT_API(hipGraphExecBatchMemOpNodeSetParams, hGraphExec, hNode, nodeParams);
+  hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
+  hip::GraphExec* graphExec = reinterpret_cast<hip::GraphExec*>(hGraphExec);
+  if (hGraphExec == nullptr || hNode == nullptr || !hip::GraphExec::isGraphExecValid(graphExec) ||
+      !hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) {
+    HIP_RETURN(hipErrorInvalidValue);
+  }
+  hip::GraphNode* clonedNode = graphExec->GetClonedNode(n);
+  if (clonedNode == nullptr) {
+    HIP_RETURN(hipErrorInvalidValue);
+  }
+  // NOTE(review): the node type is not checked before this downcast; confirm that is intended.
+  HIP_RETURN(reinterpret_cast<hip::hipGraphBatchMemOpNode*>(clonedNode)->SetParams(nodeParams));
+}
 }  // namespace hip
diff --git a/hipamd/src/hip_graph_internal.hpp b/hipamd/src/hip_graph_internal.hpp
index 46695cd555..2b1fbc4453 100644
--- a/hipamd/src/hip_graph_internal.hpp
+++ b/hipamd/src/hip_graph_internal.hpp
@@ -2736,4 +2736,66 @@ class hipGraphExternalSemWaitNode : public GraphNode {
   }
 };
 
+class hipGraphBatchMemOpNode : public GraphNode {
+  hipBatchMemOpNodeParams batchMemOpNodeParam_;
+  // Node-owned copy of the operation array. batchMemOpNodeParam_.paramArray points into it, so
+  // the caller's array does not have to outlive the add/set call.
+  std::vector<hipStreamBatchMemOpParams> paramArrayCopy_;
+
+  // Stores params and takes a private copy of its count-element paramArray.
+  void CopyParams(const hipBatchMemOpNodeParams& params) {
+    // Build into a temporary: params.paramArray may point at paramArrayCopy_ itself when the
+    // caller passes back a struct obtained from GetParams().
+    std::vector<hipStreamBatchMemOpParams> ops;
+    if (params.paramArray != nullptr) {
+      ops.assign(params.paramArray, params.paramArray + params.count);
+    }
+    paramArrayCopy_.swap(ops);
+    batchMemOpNodeParam_ = params;
+    batchMemOpNodeParam_.paramArray = paramArrayCopy_.empty() ? nullptr : paramArrayCopy_.data();
+  }
+
+ public:
+  hipGraphBatchMemOpNode(const hipBatchMemOpNodeParams* pNodeParams)
+      : GraphNode(hipGraphNodeTypeBatchMemOp, "solid", "rectangle", "BATCH_MEM_OP_NODE") {
+    CopyParams(*pNodeParams);
+  }
+
+  hipGraphBatchMemOpNode(const hipGraphBatchMemOpNode& rhs) : GraphNode(rhs) {
+    CopyParams(rhs.batchMemOpNodeParam_);
+  }
+  ~hipGraphBatchMemOpNode() {}
+
+  GraphNode* clone() const { return new hipGraphBatchMemOpNode(*this); }
+
+  // Appends one batch memory operation command for this node, targeting stream, to commands_.
+  hipError_t CreateCommand(hip::Stream* stream) {
+    hipError_t status = GraphNode::CreateCommand(stream);
+    if (status != hipSuccess) {
+      return status;
+    }
+    amd::Command::EventWaitList waitList;
+    amd::BatchMemoryOperationCommand* command = new amd::BatchMemoryOperationCommand(
+        *stream, ROCCLR_COMMAND_BATCH_STREAM, batchMemOpNodeParam_.count,
+        batchMemOpNodeParam_.flags, waitList, batchMemOpNodeParam_.paramArray,
+        sizeof(hipStreamBatchMemOpParams));
+    if (command == nullptr) {
+      return hipErrorOutOfMemory;
+    }
+    commands_.emplace_back(command);
+    return hipSuccess;
+  }
+
+  // Returns the stored parameters; the returned paramArray refers to storage owned by this node.
+  void GetParams(hipBatchMemOpNodeParams* pNodeParams) const {
+    *pNodeParams = batchMemOpNodeParam_;
+  }
+
+  // Replaces the stored parameters, copying the caller's paramArray.
+  hipError_t SetParams(const hipBatchMemOpNodeParams* pNodeParams) {
+    CopyParams(*pNodeParams);
+    return hipSuccess;
+  }
+};
+
 }  // namespace hip
diff --git a/hipamd/src/hip_hcc.map.in b/hipamd/src/hip_hcc.map.in
index 1b7c323563..8cf57d6b53 100644
--- a/hipamd/src/hip_hcc.map.in
+++ b/hipamd/src/hip_hcc.map.in
@@ -583,7 +583,17 @@ local:
 hip_6.3 {
 global:
     hipExtHostAlloc;
     hipStreamBatchMemOp;
 local:
     *;
 } hip_6.2;
+
+hip_6.4 {
+global:
+    hipGraphAddBatchMemOpNode;
+    hipGraphBatchMemOpNodeGetParams;
+    hipGraphBatchMemOpNodeSetParams;
+    hipGraphExecBatchMemOpNodeSetParams;
+local:
+    *;
+} hip_6.3;
diff --git a/hipamd/src/hip_table_interface.cpp b/hipamd/src/hip_table_interface.cpp
index 237f0d8805..1ee51926f4 100644
--- a/hipamd/src/hip_table_interface.cpp
+++ b/hipamd/src/hip_table_interface.cpp
@@ -1823,3 +1823,23 @@ hipError_t hipGraphNodeSetParams(hipGraphNode_t node, hipGraphNodeParams *nodePa
 hipError_t hipExtHostAlloc(void** ptr, size_t size, unsigned int flags) {
   return hip::GetHipDispatchTable()->hipExtHostAlloc_fn(ptr, size, flags);
 }
+hipError_t hipGraphAddBatchMemOpNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
+                                     const hipGraphNode_t* dependencies, size_t numDependencies,
+                                     const hipBatchMemOpNodeParams* nodeParams) {
+  return hip::GetHipDispatchTable()->hipGraphAddBatchMemOpNode_fn(pGraphNode, graph, dependencies,
+                                                                  numDependencies, nodeParams);
+}
+hipError_t hipGraphBatchMemOpNodeGetParams(hipGraphNode_t hNode,
+                                           hipBatchMemOpNodeParams* nodeParams_out) {
+  return hip::GetHipDispatchTable()->hipGraphBatchMemOpNodeGetParams_fn(hNode, nodeParams_out);
+}
+hipError_t hipGraphBatchMemOpNodeSetParams(hipGraphNode_t hNode,
+                                           hipBatchMemOpNodeParams* nodeParams) {
+  return hip::GetHipDispatchTable()->hipGraphBatchMemOpNodeSetParams_fn(hNode, nodeParams);
+}
+hipError_t hipGraphExecBatchMemOpNodeSetParams(hipGraphExec_t hGraphExec,
+                                               hipGraphNode_t hNode,
+                                               const hipBatchMemOpNodeParams* nodeParams) {
+  return hip::GetHipDispatchTable()->hipGraphExecBatchMemOpNodeSetParams_fn(hGraphExec, hNode,
+                                                                            nodeParams);
+}