SWDEV-240807 - Added Implementation for few more graph APIs

Change-Id: I76336a22233a208a3f54ff9e90f0c5bf4a1bddb4
This commit is contained in:
Anusha GodavarthySurya
2021-07-07 03:58:37 -07:00
förälder 4aab5c29ff
incheckning e5cbfa8ca9
7 ändrade filer med 224 tillägg och 3 borttagningar
+97
Visa fil
@@ -413,3 +413,100 @@ hipError_t hipGraphLaunch(hipGraphExec_t graphExec, hipStream_t stream) {
HIP_INIT_API(hipGraphLaunch, graphExec, stream);
HIP_RETURN_DURATION(ihipGraphlaunch(graphExec, stream));
}
hipError_t hipGraphGetNodes(hipGraph_t graph, hipGraphNode_t* nodes, size_t* numNodes) {
HIP_INIT_API(hipGraphGetNodes, graph, nodes, numNodes);
if (graph == nullptr || numNodes == nullptr) {
*numNodes = graph->GetNodeCount();
}
if (*numNodes > 0) {
nodes = graph->GetNodes().data();
}
HIP_RETURN(hipSuccess);
}
hipError_t hipGraphGetRootNodes(hipGraph_t graph, hipGraphNode_t* pRootNodes,
size_t* pNumRootNodes) {
HIP_INIT_API(hipGraphGetRootNodes, graph, pRootNodes, pNumRootNodes);
if (graph == nullptr || pNumRootNodes == nullptr) {
return HIP_RETURN(hipErrorInvalidValue);
}
std::vector<Node> rootNodes = graph->GetRootNodes();
pRootNodes = rootNodes.data();
*pNumRootNodes = rootNodes.size();
HIP_RETURN(hipSuccess);
}
hipError_t hipGraphKernelNodeGetParams(hipGraphNode_t node, hipKernelNodeParams* pNodeParams) {
HIP_INIT_API(hipGraphKernelNodeGetParams, node, pNodeParams);
if (node == nullptr || pNodeParams == nullptr) {
return HIP_RETURN(hipErrorInvalidValue);
}
reinterpret_cast<hipGraphKernelNode*>(node)->GetParams(pNodeParams);
HIP_RETURN(hipSuccess);
}
hipError_t hipGraphKernelNodeSetParams(hipGraphNode_t node,
const hipKernelNodeParams* pNodeParams) {
HIP_INIT_API(hipGraphKernelNodeSetParams, node, pNodeParams);
if (node == nullptr || pNodeParams == nullptr) {
return HIP_RETURN(hipErrorInvalidValue);
}
reinterpret_cast<hipGraphKernelNode*>(node)->SetParams(pNodeParams);
HIP_RETURN(hipSuccess);
}
hipError_t hipGraphMemcpyNodeGetParams(hipGraphNode_t node, hipMemcpy3DParms* pNodeParams) {
HIP_INIT_API(hipGraphMemcpyNodeGetParams, node, pNodeParams);
if (node == nullptr || pNodeParams == nullptr) {
return HIP_RETURN(hipErrorInvalidValue);
}
reinterpret_cast<hipGraphMemcpyNode*>(node)->GetParams(pNodeParams);
HIP_RETURN(hipSuccess);
}
hipError_t hipGraphMemcpyNodeSetParams(hipGraphNode_t node, const hipMemcpy3DParms* pNodeParams) {
HIP_INIT_API(hipGraphMemcpyNodeSetParams, node, pNodeParams);
if (node == nullptr || pNodeParams == nullptr) {
return HIP_RETURN(hipErrorInvalidValue);
}
reinterpret_cast<hipGraphMemcpyNode*>(node)->SetParams(pNodeParams);
HIP_RETURN(hipSuccess);
}
hipError_t hipGraphMemsetNodeGetParams(hipGraphNode_t node, hipMemsetParams* pNodeParams) {
HIP_INIT_API(hipGraphMemsetNodeGetParams, node, pNodeParams);
if (node == nullptr || pNodeParams == nullptr) {
return HIP_RETURN(hipErrorInvalidValue);
}
reinterpret_cast<hipGraphMemsetNode*>(node)->GetParams(pNodeParams);
HIP_RETURN(hipSuccess);
}
hipError_t hipGraphMemsetNodeSetParams(hipGraphNode_t node, const hipMemsetParams* pNodeParams) {
HIP_INIT_API(hipGraphMemsetNodeSetParams, node, pNodeParams);
if (node == nullptr || pNodeParams == nullptr) {
return HIP_RETURN(hipErrorInvalidValue);
}
reinterpret_cast<hipGraphMemsetNode*>(node)->SetParams(pNodeParams);
HIP_RETURN(hipSuccess);
}
hipError_t hipGraphAddDependencies(hipGraph_t graph, const hipGraphNode_t* from,
const hipGraphNode_t* to, size_t numDependencies) {
HIP_INIT_API(hipGraphAddDependencies, graph, from, to, numDependencies);
if (graph == nullptr) {
return HIP_RETURN(hipErrorInvalidValue);
}
if (numDependencies == 0) {
HIP_RETURN(hipSuccess);
} else if (from == nullptr || to == nullptr) {
return HIP_RETURN(hipErrorInvalidValue);
}
for (size_t i = 0; i < numDependencies; i++) {
if (graph->AddEdge(from[i], to[i]) != hipSuccess) {
HIP_RETURN(hipErrorInvalidValue);
}
}
HIP_RETURN(hipSuccess);
}