SWDEV-421027 - Add more Graph APIs

Signed-off-by: shadi <shadi.dashmiz@amd.com>
Change-Id: I0a1fc284e48317a49ca88d4ed4e3a10e752efd58


[ROCm/clr commit: e705e5e0d9]
This commit is contained in:
shadi
2024-01-08 15:39:43 -05:00
zatwierdzone przez Maneesh Gupta
rodzic 4ebecc5b98
commit 19ce99d104
6 zmienionych plików z 116 dodań i 4 usunięć
@@ -982,6 +982,13 @@ typedef hipError_t (*t_hipMemcpy2DArrayToArray)(hipArray_t dst, size_t wOffsetDs
size_t height, hipMemcpyKind kind);
typedef hipError_t (*t_hipGraphExecGetFlags)(hipGraphExec_t graphExec, unsigned long long* flags);
typedef hipError_t (*t_hipGraphNodeSetParams)(hipGraphNode_t node, hipGraphNodeParams *nodeParams);
typedef hipError_t (*t_hipGraphExecNodeSetParams)(hipGraphExec_t graphExec, hipGraphNode_t node,
hipGraphNodeParams* nodeParams);
// HIP Compiler dispatch table
struct HipCompilerDispatchTable {
size_t size;
@@ -1458,4 +1465,7 @@ struct HipDispatchTable {
t_hipMemcpyAtoHAsync hipMemcpyAtoHAsync_fn;
t_hipMemcpyHtoAAsync hipMemcpyHtoAAsync_fn;
t_hipMemcpy2DArrayToArray hipMemcpy2DArrayToArray_fn;
t_hipGraphExecGetFlags hipGraphExecGetFlags_fn;
t_hipGraphNodeSetParams hipGraphNodeSetParams_fn;
t_hipGraphExecNodeSetParams hipGraphExecNodeSetParams_fn;
};
+3
Wyświetl plik
@@ -473,3 +473,6 @@ hipMemcpyAtoA
hipMemcpyAtoHAsync
hipMemcpyHtoAAsync
hipMemcpy2DArrayToArray
hipGraphExecGetFlags
hipGraphNodeSetParams
hipGraphExecNodeSetParams
+11 -2
Wyświetl plik
@@ -789,6 +789,10 @@ hipError_t hipMemcpyHtoAAsync(hipArray_t dstArray, size_t dstOffset, const void*
hipError_t hipMemcpy2DArrayToArray(hipArray_t dst, size_t wOffsetDst, size_t hOffsetDst,
hipArray_const_t src, size_t wOffsetSrc, size_t hOffsetSrc,
size_t width, size_t height, hipMemcpyKind kind);
hipError_t hipGraphExecGetFlags(hipGraphExec_t graphExec, unsigned long long* flags);
hipError_t hipGraphNodeSetParams(hipGraphNode_t node, hipGraphNodeParams *nodeParams);
hipError_t hipGraphExecNodeSetParams(hipGraphExec_t graphExec, hipGraphNode_t node,
hipGraphNodeParams* nodeParams);
} // namespace hip
namespace hip {
@@ -1275,6 +1279,9 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
ptrDispatchTable->hipMemcpyAtoHAsync_fn = hip::hipMemcpyAtoHAsync;
ptrDispatchTable->hipMemcpyHtoAAsync_fn = hip::hipMemcpyHtoAAsync;
ptrDispatchTable->hipMemcpy2DArrayToArray_fn = hip::hipMemcpy2DArrayToArray;
ptrDispatchTable->hipGraphExecGetFlags_fn = hip::hipGraphExecGetFlags;
ptrDispatchTable->hipGraphNodeSetParams_fn = hip::hipGraphNodeSetParams;
ptrDispatchTable->hipGraphExecNodeSetParams_fn = hip::hipGraphExecNodeSetParams;
}
#if HIP_ROCPROFILER_REGISTER > 0
@@ -1847,7 +1854,9 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpyAtoA_fn, 451)
HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpyAtoHAsync_fn, 452)
HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpyHtoAAsync_fn, 453)
HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpy2DArrayToArray_fn, 454)
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecGetFlags_fn, 455);
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphNodeSetParams_fn, 456);
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecNodeSetParams_fn, 457);
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
@@ -1855,7 +1864,7 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpy2DArrayToArray_fn, 454)
// HIP_ENFORCE_ABI(<table>, <functor>, 8)
//
// HIP_ENFORCE_ABI_VERSIONING(<table>, 9) <- 8 + 1 = 9
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 455)
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 458)
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 3,
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
+88 -2
Wyświetl plik
@@ -2891,11 +2891,11 @@ hipError_t hipGraphAddNode(hipGraphNode_t *pGraphNode, hipGraph_t graph,
numDependencies, false);
break;
case hipGraphNodeTypeExtSemaphoreSignal:
status = hipSuccess;
status = hipErrorNotSupported;
// to be added.
break;
case hipGraphNodeTypeExtSemaphoreWait:
status = hipSuccess;
status = hipErrorNotSupported;
// to be added.
break;
case hipGraphNodeTypeMemAlloc:
@@ -3140,4 +3140,90 @@ hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGrap
HIP_RETURN(reinterpret_cast<hip::GraphMemsetNode*>(clonedNode)->SetParams(memsetParams, true));
}
// Public API: reports the flags stored on an executable graph.
//
// graphExec - executable graph handle to query; must not be null.
// flags     - output location that receives the value of GraphExec::GetFlags();
//             must not be null.
//
// Returns hipErrorInvalidValue when either argument is null, hipSuccess otherwise.
hipError_t hipGraphExecGetFlags(hipGraphExec_t graphExec, unsigned long long* flags) {
  HIP_INIT_API(hipGraphExecGetFlags, graphExec, flags);
  const bool argsUsable = (graphExec != nullptr) && (flags != nullptr);
  if (!argsUsable) {
    HIP_RETURN(hipErrorInvalidValue);
  }
  // The opaque handle is the runtime's hip::GraphExec object; GetFlags() is a const accessor.
  const hip::GraphExec* exec = reinterpret_cast<const hip::GraphExec*>(graphExec);
  *flags = exec->GetFlags();
  HIP_RETURN(hipSuccess);
}
// Internal helper shared by hipGraphNodeSetParams and hipGraphExecNodeSetParams.
//
// Dispatches on nodeParams->type and forwards the matching member of the
// hipGraphNodeParams payload to the SetParams() of the corresponding node class.
//
// n          - node to update; dereferenced without a null check (both callers
//              in this file validate it first).
// nodeParams - generic parameter block; dereferenced without a null check.
//
// Returns the status of the selected SetParams() call, hipErrorNotSupported for
// external-semaphore and mem-alloc/mem-free node types, and hipErrorInvalidValue
// for any other type value.
//
// NOTE(review): the cast target is chosen from nodeParams->type only; the actual
// type of `n` is not compared against it here - confirm callers guarantee a match.
hipError_t ihipGraphNodeSetParams(hip::GraphNode* n, hipGraphNodeParams *nodeParams) {
  hipError_t result = hipSuccess;
  switch (nodeParams->type) {
    case hipGraphNodeTypeKernel: {
      auto* kernelNode = reinterpret_cast<hip::GraphKernelNode*>(n);
      result = kernelNode->SetParams(&nodeParams->kernel);
      break;
    }
    case hipGraphNodeTypeMemcpy: {
      auto* memcpyNode = reinterpret_cast<hip::GraphMemcpyNode*>(n);
      result = memcpyNode->SetParams(&nodeParams->memcpy.copyParams);
      break;
    }
    case hipGraphNodeTypeMemset: {
      auto* memsetNode = reinterpret_cast<hip::GraphMemsetNode*>(n);
      result = memsetNode->SetParams(&nodeParams->memset);
      break;
    }
    case hipGraphNodeTypeHost: {
      auto* hostNode = reinterpret_cast<hip::GraphHostNode*>(n);
      result = hostNode->SetParams(&nodeParams->host);
      break;
    }
    case hipGraphNodeTypeGraph: {
      auto* childNode = reinterpret_cast<hip::ChildGraphNode*>(n);
      result = childNode->SetParams(reinterpret_cast<hip::Graph*>(nodeParams->graph.graph));
      break;
    }
    case hipGraphNodeTypeWaitEvent: {
      auto* waitNode = reinterpret_cast<hip::GraphEventWaitNode*>(n);
      result = waitNode->SetParams(nodeParams->eventWait.event);
      break;
    }
    case hipGraphNodeTypeEventRecord: {
      auto* recordNode = reinterpret_cast<hip::GraphEventRecordNode*>(n);
      result = recordNode->SetParams(nodeParams->eventRecord.event);
      break;
    }
    // External semaphore signal/wait support is still to be added; mem alloc/free
    // nodes are likewise reported as unsupported.
    case hipGraphNodeTypeExtSemaphoreSignal:
    case hipGraphNodeTypeExtSemaphoreWait:
    case hipGraphNodeTypeMemAlloc:
    case hipGraphNodeTypeMemFree:
      result = hipErrorNotSupported;
      break;
    default:
      result = hipErrorInvalidValue;
      break;
  }
  HIP_RETURN(result);
}
// Public API: updates the parameters of a graph node from a generic
// hipGraphNodeParams block.
//
// node       - node handle to update; must be non-null and accepted by
//              hip::GraphNode::isNodeValid().
// nodeParams - new parameters; must be non-null.
//
// Returns hipErrorInvalidValue when validation fails, otherwise the status
// produced by ihipGraphNodeSetParams().
hipError_t hipGraphNodeSetParams(hipGraphNode_t node, hipGraphNodeParams *nodeParams) {
  HIP_INIT_API(hipGraphNodeSetParams, node, nodeParams);
  // Reject null arguments before consulting the node validity check.
  if (node == nullptr || nodeParams == nullptr) {
    HIP_RETURN(hipErrorInvalidValue);
  }
  hip::GraphNode* graphNode = reinterpret_cast<hip::GraphNode*>(node);
  if (!hip::GraphNode::isNodeValid(graphNode)) {
    HIP_RETURN(hipErrorInvalidValue);
  }
  HIP_RETURN(ihipGraphNodeSetParams(graphNode, nodeParams));
}
// Public API: updates the parameters of a node inside an instantiated
// (executable) graph.
//
// graphExec  - executable graph that holds a clone of `node`; must be non-null.
// node       - node handle from the original graph; must be non-null and accepted
//              by hip::GraphNode::isNodeValid().
// nodeParams - new parameters; must be non-null.
//
// Returns hipErrorInvalidValue when any argument fails validation or when
// graphExec has no cloned counterpart for `node`; otherwise the status produced
// by ihipGraphNodeSetParams() on the cloned node.
hipError_t hipGraphExecNodeSetParams(hipGraphExec_t graphExec, hipGraphNode_t node,
                                     hipGraphNodeParams* nodeParams) {
  // Register the call under this function's own API name (it was previously
  // registered as hipGraphNodeSetParams, a copy-paste slip).
  HIP_INIT_API(hipGraphExecNodeSetParams, graphExec, node, nodeParams);
  hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(node);
  if (node == nullptr || nodeParams == nullptr || graphExec == nullptr
      || !hip::GraphNode::isNodeValid(n)) {
    HIP_RETURN(hipErrorInvalidValue);
  }
  // The update is applied to the executable graph's clone, not to the original node.
  hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphNode*>(
      reinterpret_cast<hip::GraphExec*>(graphExec)->GetClonedNode(n));
  if (clonedNode == nullptr) {
    HIP_RETURN(hipErrorInvalidValue);
  }
  HIP_RETURN(ihipGraphNodeSetParams(clonedNode, nodeParams));
}
} // namespace hip
@@ -643,6 +643,7 @@ struct GraphExec : public amd::ReferenceCountedObject {
}
void ResetQueueIndex() { currentQueueIndex_ = 0; }
uint64_t GetFlags() const { return flags_; }
hipError_t Init();
hipError_t CreateStreams(uint32_t num_streams);
hipError_t Run(hipStream_t stream);
@@ -570,6 +570,9 @@ global:
hipMemcpyAtoHAsync;
hipMemcpyHtoAAsync;
hipMemcpy2DArrayToArray;
hipGraphExecGetFlags;
hipGraphNodeSetParams;
hipGraphExecNodeSetParams;
local:
*;
} hip_6.1;