SWDEV-313856, SWDEV-313907, SWDEV-313915 - Export StreamCapture APIs, handle few graph negitive senarios, set event, parallelstreams for endcapture

Change-Id: I3c6008e1a1195cd2e1a14ef24c943ef6b54033ab
This commit is contained in:
anusha GodavarthySurya
2021-12-06 01:54:07 -08:00
committato da Maneesh Gupta
parent d1a491b30b
commit e3585209d7
4 ha cambiato i file con 27 aggiunte e 6 eliminazioni
+21 -6
Vedi File
@@ -64,7 +64,8 @@ hipError_t ihipGraphAddKernelNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
const hipGraphNode_t* pDependencies, size_t numDependencies,
const hipKernelNodeParams* pNodeParams) {
if (pGraphNode == nullptr || graph == nullptr ||
(numDependencies > 0 && pDependencies == nullptr) || pNodeParams == nullptr) {
(numDependencies > 0 && pDependencies == nullptr) || pNodeParams == nullptr ||
pNodeParams->func == nullptr) {
return hipErrorInvalidValue;
}
hipError_t status = ihipValidateKernelParams(pNodeParams);
@@ -629,6 +630,7 @@ hipError_t capturehipEventRecord(hipStream_t& stream, hipEvent_t& event) {
hip::Event* e = reinterpret_cast<hip::Event*>(event);
e->StartCapture(stream);
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
s->SetCaptureEvent(event);
std::vector<hipGraphNode_t> lastCapturedNodes = s->GetLastCapturedNodes();
if (!lastCapturedNodes.empty()) {
e->SetNodesPrevToRecorded(lastCapturedNodes);
@@ -653,6 +655,7 @@ hipError_t capturehipStreamWaitEvent(hipEvent_t& event, hipStream_t& stream, uns
s->SetCaptureGraph(reinterpret_cast<hip::Stream*>(e->GetCaptureStream())->GetCaptureGraph());
s->SetCaptureMode(reinterpret_cast<hip::Stream*>(e->GetCaptureStream())->GetCaptureMode());
s->SetParentStream(e->GetCaptureStream());
s->SetParallelCaptureStream(stream);
}
s->AddCrossCapturedNode(e->GetNodesPrevToRecorded());
g_captureStreams.push_back(stream);
@@ -709,7 +712,7 @@ hipError_t hipStreamBeginCapture(hipStream_t stream, hipStreamCaptureMode mode)
hipError_t hipStreamEndCapture(hipStream_t stream, hipGraph_t* pGraph) {
HIP_INIT_API(hipStreamEndCapture, stream, pGraph);
if (!hip::isValid(stream)) {
if (pGraph == nullptr || stream == nullptr || !hip::isValid(stream)) {
HIP_RETURN(hipErrorInvalidValue);
}
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
@@ -972,7 +975,7 @@ hipError_t hipGraphKernelNodeGetParams(hipGraphNode_t node, hipKernelNodeParams*
hipError_t hipGraphKernelNodeSetParams(hipGraphNode_t node,
const hipKernelNodeParams* pNodeParams) {
HIP_INIT_API(hipGraphKernelNodeSetParams, node, pNodeParams);
if (node == nullptr || pNodeParams == nullptr) {
if (node == nullptr || pNodeParams == nullptr || pNodeParams->func == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
HIP_RETURN(reinterpret_cast<hipGraphKernelNode*>(node)->SetParams(pNodeParams));
@@ -1047,6 +1050,17 @@ hipError_t hipGraphAddDependencies(hipGraph_t graph, const hipGraphNode_t* from,
HIP_RETURN(hipErrorInvalidValue);
}
for (size_t i = 0; i < numDependencies; i++) {
// When the same node is specified for both from and to
if (from[i] == to[i]) {
HIP_RETURN(hipErrorInvalidValue);
}
// When the same edge added from->to return invalid value
const std::vector<Node>& edges = from[i]->GetEdges();
for (auto edge : edges) {
if (edge == to[i]) {
HIP_RETURN(hipErrorInvalidValue);
}
}
from[i]->AddEdge(to[i]);
}
HIP_RETURN(hipSuccess);
@@ -1187,7 +1201,8 @@ hipError_t hipGraphGetEdges(hipGraph_t graph, hipGraphNode_t* from, hipGraphNode
from[i] = edges[i].first;
to[i] = edges[i].second;
}
// If numEdges > actual number of edges, the remaining entries in from and to will be set to NULL
// If numEdges > actual number of edges, the remaining entries in from and to will be set to
// NULL
for (int i = edges.size(); i < *numEdges; i++) {
from[i] = nullptr;
to[i] = nullptr;
@@ -1244,8 +1259,8 @@ hipError_t hipGraphNodeGetDependentNodes(hipGraphNode_t node, hipGraphNode_t* pD
for (int i = 0; i < dependents.size(); i++) {
pDependentNodes[i] = dependents[i];
}
// pNumDependentNodes > actual number of dependents, the remaining entries in pDependentNodes will
// be set to NULL
// pNumDependentNodes > actual number of dependents, the remaining entries in pDependentNodes
// will be set to NULL
for (int i = dependents.size(); i < *pNumDependentNodes; i++) {
pDependentNodes[i] = nullptr;
}
+2
Vedi File
@@ -348,3 +348,5 @@ hipGraphExecMemsetNodeSetParams
amd_dbgapi_get_build_name
amd_dbgapi_get_git_hash
amd_dbgapi_get_build_id
hipStreamGetCaptureInfo
hipStreamGetCaptureInfo_v2
+2
Vedi File
@@ -380,6 +380,8 @@ global:
amd_dbgapi_get_build_name;
amd_dbgapi_get_git_hash;
amd_dbgapi_get_build_id;
hipStreamGetCaptureInfo;
hipStreamGetCaptureInfo_v2;
local:
*;
} hip_4.4;
+2
Vedi File
@@ -308,6 +308,8 @@ namespace hip {
}
/// Get Capture ID
int GetCaptureID() { return captureID_; }
void SetCaptureEvent(hipEvent_t e) { captureEvents_.push_back(e); }
void SetParallelCaptureStream(hipStream_t s) { parallelCaptureStreams_.push_back(s); }
};
/// HIP Device class