SWDEV-497288 - Enable hipGraphExecSetParams for Ext SemWait and SemSignal Nodes

Change-Id: I7184a3a04ac17d3d841222ae1559db66d73a429c
This commit is contained in:
Rahul Manocha
2024-11-21 15:50:47 -08:00
zatwierdzone przez Rahul Manocha
rodzic 9820480cbd
commit e0c11624e5
+16 -13
Wyświetl plik
@@ -1231,7 +1231,7 @@ hipError_t hipGraphExecMemcpyNodeSetParams1D(hipGraphExec_t hGraphExec, hipGraph
HIP_INIT_API(hipGraphExecMemcpyNodeSetParams1D, hGraphExec, node, dst, src, count, kind);
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(node);
if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || dst == nullptr ||
src == nullptr || count == 0 || src == dst) {
src == nullptr || count == 0 || src == dst || n->GetType() != hipGraphNodeTypeMemcpy) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphNode*>(
@@ -1636,7 +1636,8 @@ hipError_t hipGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo
HIP_INIT_API(hipGraphExecMemcpyNodeSetParams, hGraphExec, node, pNodeParams);
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(node);
if (hGraphExec == nullptr ||
!hip::GraphNode::isNodeValid(reinterpret_cast<hip::GraphNode*>(n))) {
!hip::GraphNode::isNodeValid(reinterpret_cast<hip::GraphNode*>(n)) ||
n->GetType() != hipGraphNodeTypeMemcpy) {
HIP_RETURN(hipErrorInvalidValue);
}
if (ihipMemcpy3D_validate(pNodeParams) != hipSuccess) {
@@ -1697,7 +1698,7 @@ hipError_t hipGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(node);
if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || pNodeParams == nullptr ||
pNodeParams->dst == nullptr) {
pNodeParams->dst == nullptr || n->GetType() != hipGraphNodeTypeMemset) {
HIP_RETURN(hipErrorInvalidValue);
}
if (ihipGraphMemsetParams_validate(pNodeParams) != hipSuccess) {
@@ -1762,7 +1763,7 @@ hipError_t hipGraphExecKernelNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNo
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(node);
if (hGraphExec == nullptr ||
!hip::GraphNode::isNodeValid(n) ||
pNodeParams == nullptr || pNodeParams->func == nullptr) {
pNodeParams == nullptr || pNodeParams->func == nullptr || n->GetType() != hipGraphNodeTypeKernel) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphExec*>(hGraphExec)->GetClonedNode(n);
@@ -1799,7 +1800,7 @@ hipError_t hipGraphExecChildGraphNodeSetParams(hipGraphExec_t hGraphExec, hipGra
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(node);
hip::Graph* cg = reinterpret_cast<hip::Graph*>(childGraph);
if (hGraphExec == nullptr || !hip::GraphNode::isNodeValid(n) || childGraph == nullptr ||
!hip::Graph::isGraphValid(cg)) {
!hip::Graph::isGraphValid(cg) || n->GetType() != hipGraphNodeTypeGraph) {
HIP_RETURN(hipErrorInvalidValue);
}
@@ -2428,7 +2429,7 @@ hipError_t hipGraphExecEventWaitNodeSetEvent(hipGraphExec_t hGraphExec, hipGraph
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
if (hGraphExec == nullptr || hNode == nullptr || event == nullptr ||
(n->GetType() != hipGraphNodeTypeWaitEvent)) {
(n->GetType() != hipGraphNodeTypeWaitEvent) || n->GetType() != hipGraphNodeTypeWaitEvent) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphExec*>(hGraphExec)->GetClonedNode(n);
@@ -2479,7 +2480,7 @@ hipError_t hipGraphExecHostNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode
HIP_INIT_API(hipGraphExecHostNodeSetParams, hGraphExec, node, pNodeParams);
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(node);
if (hGraphExec == nullptr || pNodeParams == nullptr || pNodeParams->fn == nullptr ||
!hip::GraphNode::isNodeValid(n)) {
!hip::GraphNode::isNodeValid(n) || n->GetType() != hipGraphNodeTypeHost) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* clonedNode = reinterpret_cast<hip::GraphExec*>(hGraphExec)->GetClonedNode(n);
@@ -3165,7 +3166,8 @@ hipError_t hipGraphExecExternalSemaphoresSignalNodeSetParams(hipGraphExec_t hGra
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
hip::GraphExec* graphExec = reinterpret_cast<hip::GraphExec*>(hGraphExec);
if (hGraphExec == nullptr || hNode == nullptr || !hip::GraphExec::isGraphExecValid(graphExec) ||
!hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) {
!hip::GraphNode::isNodeValid(n) || nodeParams == nullptr ||
n->GetType() != hipGraphNodeTypeExtSemaphoreSignal) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* clonedNode = graphExec->GetClonedNode(n);
@@ -3183,7 +3185,8 @@ hipError_t hipGraphExecExternalSemaphoresWaitNodeSetParams(hipGraphExec_t hGraph
hip::GraphNode* n = reinterpret_cast<hip::GraphNode*>(hNode);
hip::GraphExec* graphExec = reinterpret_cast<hip::GraphExec*>(hGraphExec);
if (hGraphExec == nullptr || hNode == nullptr || !hip::GraphExec::isGraphExecValid(graphExec) ||
!hip::GraphNode::isNodeValid(n) || nodeParams == nullptr) {
!hip::GraphNode::isNodeValid(n) || nodeParams == nullptr ||
n->GetType() != hipGraphNodeTypeExtSemaphoreWait) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::GraphNode* clonedNode = graphExec->GetClonedNode(n);
@@ -3326,12 +3329,12 @@ hipError_t ihipGraphNodeSetParams(hip::GraphNode* n, hipGraphNodeParams *nodePar
nodeParams->eventRecord.event);
break;
case hipGraphNodeTypeExtSemaphoreSignal:
status = hipErrorNotSupported;
// to be added.
status = reinterpret_cast<hip::hipGraphExternalSemSignalNode*>(n)->SetParams(
&nodeParams->extSemSignal);
break;
case hipGraphNodeTypeExtSemaphoreWait:
status = hipErrorNotSupported;
// to be added.
status = reinterpret_cast<hip::hipGraphExternalSemWaitNode*>(n)->SetParams(
&nodeParams->extSemWait);
break;
case hipGraphNodeTypeMemAlloc:
status = hipErrorNotSupported;