SWDEV-421027 - Add more Graph APIs
Signed-off-by: shadi <shadi.dashmiz@amd.com>
Change-Id: I0a1fc284e48317a49ca88d4ed4e3a10e752efd58
[ROCm/clr commit: e705e5e0d9]
This commit is contained in:
zatwierdzone przez
Maneesh Gupta
rodzic
4ebecc5b98
commit
19ce99d104
@@ -982,6 +982,13 @@ typedef hipError_t (*t_hipMemcpy2DArrayToArray)(hipArray_t dst, size_t wOffsetDs
|
||||
size_t height, hipMemcpyKind kind);
|
||||
|
||||
|
||||
typedef hipError_t (*t_hipGraphExecGetFlags)(hipGraphExec_t graphExec, unsigned long long* flags);
|
||||
typedef hipError_t (*t_hipGraphNodeSetParams)(hipGraphNode_t node, hipGraphNodeParams *nodeParams);
|
||||
typedef hipError_t (*t_hipGraphExecNodeSetParams)(hipGraphExec_t graphExec, hipGraphNode_t node,
|
||||
hipGraphNodeParams* nodeParams);
|
||||
|
||||
|
||||
|
||||
// HIP Compiler dispatch table
|
||||
struct HipCompilerDispatchTable {
|
||||
size_t size;
|
||||
@@ -1458,4 +1465,7 @@ struct HipDispatchTable {
|
||||
t_hipMemcpyAtoHAsync hipMemcpyAtoHAsync_fn;
|
||||
t_hipMemcpyHtoAAsync hipMemcpyHtoAAsync_fn;
|
||||
t_hipMemcpy2DArrayToArray hipMemcpy2DArrayToArray_fn;
|
||||
t_hipGraphExecGetFlags hipGraphExecGetFlags_fn;
|
||||
t_hipGraphNodeSetParams hipGraphNodeSetParams_fn;
|
||||
t_hipGraphExecNodeSetParams hipGraphExecNodeSetParams_fn;
|
||||
};
|
||||
|
||||
@@ -473,3 +473,6 @@ hipMemcpyAtoA
|
||||
hipMemcpyAtoHAsync
|
||||
hipMemcpyHtoAAsync
|
||||
hipMemcpy2DArrayToArray
|
||||
hipGraphExecGetFlags
|
||||
hipGraphNodeSetParams
|
||||
hipGraphExecNodeSetParams
|
||||
|
||||
@@ -789,6 +789,10 @@ hipError_t hipMemcpyHtoAAsync(hipArray_t dstArray, size_t dstOffset, const void*
|
||||
hipError_t hipMemcpy2DArrayToArray(hipArray_t dst, size_t wOffsetDst, size_t hOffsetDst,
|
||||
hipArray_const_t src, size_t wOffsetSrc, size_t hOffsetSrc,
|
||||
size_t width, size_t height, hipMemcpyKind kind);
|
||||
hipError_t hipGraphExecGetFlags(hipGraphExec_t graphExec, unsigned long long* flags);
|
||||
hipError_t hipGraphNodeSetParams(hipGraphNode_t node, hipGraphNodeParams *nodeParams);
|
||||
hipError_t hipGraphExecNodeSetParams(hipGraphExec_t graphExec, hipGraphNode_t node,
|
||||
hipGraphNodeParams* nodeParams);
|
||||
} // namespace hip
|
||||
|
||||
namespace hip {
|
||||
@@ -1275,6 +1279,9 @@ void UpdateDispatchTable(HipDispatchTable* ptrDispatchTable) {
|
||||
ptrDispatchTable->hipMemcpyAtoHAsync_fn = hip::hipMemcpyAtoHAsync;
|
||||
ptrDispatchTable->hipMemcpyHtoAAsync_fn = hip::hipMemcpyHtoAAsync;
|
||||
ptrDispatchTable->hipMemcpy2DArrayToArray_fn = hip::hipMemcpy2DArrayToArray;
|
||||
ptrDispatchTable->hipGraphExecGetFlags_fn = hip::hipGraphExecGetFlags;
|
||||
ptrDispatchTable->hipGraphNodeSetParams_fn = hip::hipGraphNodeSetParams;
|
||||
ptrDispatchTable->hipGraphExecNodeSetParams_fn = hip::hipGraphExecNodeSetParams;
|
||||
}
|
||||
|
||||
#if HIP_ROCPROFILER_REGISTER > 0
|
||||
@@ -1847,7 +1854,9 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpyAtoA_fn, 451)
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpyAtoHAsync_fn, 452)
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpyHtoAAsync_fn, 453)
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpy2DArrayToArray_fn, 454)
|
||||
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecGetFlags_fn, 455);
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphNodeSetParams_fn, 456);
|
||||
HIP_ENFORCE_ABI(HipDispatchTable, hipGraphExecNodeSetParams_fn, 457);
|
||||
|
||||
// if HIP_ENFORCE_ABI entries are added for each new function pointer in the table, the number below
|
||||
// will be +1 of the number in the last HIP_ENFORCE_ABI line. E.g.:
|
||||
@@ -1855,7 +1864,7 @@ HIP_ENFORCE_ABI(HipDispatchTable, hipMemcpy2DArrayToArray_fn, 454)
|
||||
// HIP_ENFORCE_ABI(<table>, <functor>, 8)
|
||||
//
|
||||
// HIP_ENFORCE_ABI_VERSIONING(<table>, 9) <- 8 + 1 = 9
|
||||
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 455)
|
||||
HIP_ENFORCE_ABI_VERSIONING(HipDispatchTable, 458)
|
||||
|
||||
static_assert(HIP_RUNTIME_API_TABLE_MAJOR_VERSION == 0 && HIP_RUNTIME_API_TABLE_STEP_VERSION == 3,
|
||||
"If you get this error, add new HIP_ENFORCE_ABI(...) code for the new function "
|
||||
|
||||
@@ -2891,11 +2891,11 @@ hipError_t hipGraphAddNode(hipGraphNode_t *pGraphNode, hipGraph_t graph,
|
||||
numDependencies, false);
|
||||
break;
|
||||
case hipGraphNodeTypeExtSemaphoreSignal:
|
||||
status = hipSuccess;
|
||||
status = hipErrorNotSupported;
|
||||
// to be added.
|
||||
break;
|
||||
case hipGraphNodeTypeExtSemaphoreWait:
|
||||
status = hipSuccess;
|
||||
status = hipErrorNotSupported;
|
||||
// to be added.
|
||||
break;
|
||||
case hipGraphNodeTypeMemAlloc:
|
||||
@@ -3140,4 +3140,90 @@ hipError_t hipDrvGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGrap
|
||||
HIP_RETURN(reinterpret_cast<hip::GraphMemsetNode*>(clonedNode)->SetParams(memsetParams, true));
|
||||
}
|
||||
|
||||
hipError_t hipGraphExecGetFlags(hipGraphExec_t graphExec, unsigned long long* flags) {
|
||||
HIP_INIT_API(hipGraphExecGetFlags, graphExec, flags);
|
||||
if (graphExec == nullptr || flags == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
hip::GraphExec* pgraphExec = reinterpret_cast<hip::GraphExec*>(graphExec);
|
||||
*flags = pgraphExec->GetFlags();
|
||||
HIP_RETURN(hipSuccess);
|
||||
}
|
||||
|
||||
hipError_t ihipGraphNodeSetParams(hip::GraphNode* n, hipGraphNodeParams *nodeParams) {
|
||||
hipGraphNodeType nodeType = nodeParams->type;
|
||||
hipError_t status = hipSuccess;
|
||||
switch(nodeType) {
|
||||
case hipGraphNodeTypeKernel:
|
||||
status = reinterpret_cast<hip::GraphKernelNode*>(n)->SetParams(&nodeParams->kernel);
|
||||
break;
|
||||
case hipGraphNodeTypeMemcpy:
|
||||
status = reinterpret_cast<hip::GraphMemcpyNode*>(n)->SetParams(
|
||||
&nodeParams->memcpy.copyParams);
|
||||
break;
|
||||
case hipGraphNodeTypeMemset:
|
||||
status =
|
||||
reinterpret_cast<hip::GraphMemsetNode*>(n)->SetParams(&nodeParams->memset);
|
||||
break;
|
||||
case hipGraphNodeTypeHost:
|
||||
status =
|
||||
reinterpret_cast<hip::GraphHostNode*>(n)->SetParams(&nodeParams->host);
|
||||
break;
|
||||
case hipGraphNodeTypeGraph:
|
||||
status = reinterpret_cast<hip::ChildGraphNode*>(n)->SetParams(
|
||||
reinterpret_cast<hip::Graph*>(nodeParams->graph.graph));
|
||||
break;
|
||||
case hipGraphNodeTypeWaitEvent:
|
||||
status = reinterpret_cast<hip::GraphEventWaitNode*>(n)->SetParams(
|
||||
nodeParams->eventWait.event);
|
||||
break;
|
||||
case hipGraphNodeTypeEventRecord:
|
||||
status = reinterpret_cast<hip::GraphEventRecordNode*>(n)->SetParams(
|
||||
nodeParams->eventRecord.event);
|
||||
break;
|
||||
case hipGraphNodeTypeExtSemaphoreSignal:
|
||||
status = hipErrorNotSupported;
|
||||
// to be added.
|
||||
break;
|
||||
case hipGraphNodeTypeExtSemaphoreWait:
|
||||
status = hipErrorNotSupported;
|
||||
// to be added.
|
||||
break;
|
||||
case hipGraphNodeTypeMemAlloc:
|
||||
status = hipErrorNotSupported;
|
||||
break;
|
||||
case hipGraphNodeTypeMemFree:
|
||||
status = hipErrorNotSupported;
|
||||
break;
|
||||
default:
|
||||
status = hipErrorInvalidValue;
|
||||
break;
|
||||
}
|
||||
HIP_RETURN(status);
|
||||
}
|
||||
|
||||
hipError_t hipGraphNodeSetParams(hipGraphNode_t node, hipGraphNodeParams *nodeParams) {
|
||||
HIP_INIT_API(hipGraphNodeSetParams, node, nodeParams);
|
||||
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(node);
|
||||
if (node == nullptr || nodeParams == nullptr || !hip::GraphNode::isNodeValid(n)) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
HIP_RETURN(ihipGraphNodeSetParams(n, nodeParams));
|
||||
}
|
||||
|
||||
hipError_t hipGraphExecNodeSetParams(hipGraphExec_t graphExec, hipGraphNode_t node,
|
||||
hipGraphNodeParams* nodeParams) {
|
||||
HIP_INIT_API(hipGraphNodeSetParams, graphExec, node, nodeParams);
|
||||
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(node);
|
||||
if (node == nullptr || nodeParams == nullptr || graphExec == nullptr
|
||||
|| !hip::GraphNode::isNodeValid(n)) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphNode*>(
|
||||
reinterpret_cast<hip::GraphExec*>(graphExec)->GetClonedNode(n));
|
||||
if (clonedNode == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
HIP_RETURN(ihipGraphNodeSetParams(clonedNode, nodeParams));
|
||||
}
|
||||
} // namespace hip
|
||||
|
||||
@@ -643,6 +643,7 @@ struct GraphExec : public amd::ReferenceCountedObject {
|
||||
}
|
||||
|
||||
void ResetQueueIndex() { currentQueueIndex_ = 0; }
|
||||
uint64_t GetFlags() const { return flags_; }
|
||||
hipError_t Init();
|
||||
hipError_t CreateStreams(uint32_t num_streams);
|
||||
hipError_t Run(hipStream_t stream);
|
||||
|
||||
@@ -570,6 +570,9 @@ global:
|
||||
hipMemcpyAtoHAsync;
|
||||
hipMemcpyHtoAAsync;
|
||||
hipMemcpy2DArrayToArray;
|
||||
hipGraphExecGetFlags;
|
||||
hipGraphNodeSetParams;
|
||||
hipGraphExecNodeSetParams;
|
||||
local:
|
||||
*;
|
||||
} hip_6.1;
|
||||
|
||||
Reference in New Issue
Block a user