SWDEV-313856, SWDEV-313907, SWDEV-313915 - Export StreamCapture APIs, handle few graph negitive senarios, set event, parallelstreams for endcapture
Change-Id: I3c6008e1a1195cd2e1a14ef24c943ef6b54033ab
This commit is contained in:
committato da
Maneesh Gupta
parent
d1a491b30b
commit
e3585209d7
@@ -64,7 +64,8 @@ hipError_t ihipGraphAddKernelNode(hipGraphNode_t* pGraphNode, hipGraph_t graph,
|
||||
const hipGraphNode_t* pDependencies, size_t numDependencies,
|
||||
const hipKernelNodeParams* pNodeParams) {
|
||||
if (pGraphNode == nullptr || graph == nullptr ||
|
||||
(numDependencies > 0 && pDependencies == nullptr) || pNodeParams == nullptr) {
|
||||
(numDependencies > 0 && pDependencies == nullptr) || pNodeParams == nullptr ||
|
||||
pNodeParams->func == nullptr) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
hipError_t status = ihipValidateKernelParams(pNodeParams);
|
||||
@@ -629,6 +630,7 @@ hipError_t capturehipEventRecord(hipStream_t& stream, hipEvent_t& event) {
|
||||
hip::Event* e = reinterpret_cast<hip::Event*>(event);
|
||||
e->StartCapture(stream);
|
||||
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
|
||||
s->SetCaptureEvent(event);
|
||||
std::vector<hipGraphNode_t> lastCapturedNodes = s->GetLastCapturedNodes();
|
||||
if (!lastCapturedNodes.empty()) {
|
||||
e->SetNodesPrevToRecorded(lastCapturedNodes);
|
||||
@@ -653,6 +655,7 @@ hipError_t capturehipStreamWaitEvent(hipEvent_t& event, hipStream_t& stream, uns
|
||||
s->SetCaptureGraph(reinterpret_cast<hip::Stream*>(e->GetCaptureStream())->GetCaptureGraph());
|
||||
s->SetCaptureMode(reinterpret_cast<hip::Stream*>(e->GetCaptureStream())->GetCaptureMode());
|
||||
s->SetParentStream(e->GetCaptureStream());
|
||||
s->SetParallelCaptureStream(stream);
|
||||
}
|
||||
s->AddCrossCapturedNode(e->GetNodesPrevToRecorded());
|
||||
g_captureStreams.push_back(stream);
|
||||
@@ -709,7 +712,7 @@ hipError_t hipStreamBeginCapture(hipStream_t stream, hipStreamCaptureMode mode)
|
||||
|
||||
hipError_t hipStreamEndCapture(hipStream_t stream, hipGraph_t* pGraph) {
|
||||
HIP_INIT_API(hipStreamEndCapture, stream, pGraph);
|
||||
if (!hip::isValid(stream)) {
|
||||
if (pGraph == nullptr || stream == nullptr || !hip::isValid(stream)) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
|
||||
@@ -972,7 +975,7 @@ hipError_t hipGraphKernelNodeGetParams(hipGraphNode_t node, hipKernelNodeParams*
|
||||
hipError_t hipGraphKernelNodeSetParams(hipGraphNode_t node,
|
||||
const hipKernelNodeParams* pNodeParams) {
|
||||
HIP_INIT_API(hipGraphKernelNodeSetParams, node, pNodeParams);
|
||||
if (node == nullptr || pNodeParams == nullptr) {
|
||||
if (node == nullptr || pNodeParams == nullptr || pNodeParams->func == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
HIP_RETURN(reinterpret_cast<hipGraphKernelNode*>(node)->SetParams(pNodeParams));
|
||||
@@ -1047,6 +1050,17 @@ hipError_t hipGraphAddDependencies(hipGraph_t graph, const hipGraphNode_t* from,
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
for (size_t i = 0; i < numDependencies; i++) {
|
||||
// When the same node is specified for both from and to
|
||||
if (from[i] == to[i]) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
// When the same edge added from->to return invalid value
|
||||
const std::vector<Node>& edges = from[i]->GetEdges();
|
||||
for (auto edge : edges) {
|
||||
if (edge == to[i]) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
}
|
||||
from[i]->AddEdge(to[i]);
|
||||
}
|
||||
HIP_RETURN(hipSuccess);
|
||||
@@ -1187,7 +1201,8 @@ hipError_t hipGraphGetEdges(hipGraph_t graph, hipGraphNode_t* from, hipGraphNode
|
||||
from[i] = edges[i].first;
|
||||
to[i] = edges[i].second;
|
||||
}
|
||||
// If numEdges > actual number of edges, the remaining entries in from and to will be set to NULL
|
||||
// If numEdges > actual number of edges, the remaining entries in from and to will be set to
|
||||
// NULL
|
||||
for (int i = edges.size(); i < *numEdges; i++) {
|
||||
from[i] = nullptr;
|
||||
to[i] = nullptr;
|
||||
@@ -1244,8 +1259,8 @@ hipError_t hipGraphNodeGetDependentNodes(hipGraphNode_t node, hipGraphNode_t* pD
|
||||
for (int i = 0; i < dependents.size(); i++) {
|
||||
pDependentNodes[i] = dependents[i];
|
||||
}
|
||||
// pNumDependentNodes > actual number of dependents, the remaining entries in pDependentNodes will
|
||||
// be set to NULL
|
||||
// pNumDependentNodes > actual number of dependents, the remaining entries in pDependentNodes
|
||||
// will be set to NULL
|
||||
for (int i = dependents.size(); i < *pNumDependentNodes; i++) {
|
||||
pDependentNodes[i] = nullptr;
|
||||
}
|
||||
|
||||
@@ -348,3 +348,5 @@ hipGraphExecMemsetNodeSetParams
|
||||
amd_dbgapi_get_build_name
|
||||
amd_dbgapi_get_git_hash
|
||||
amd_dbgapi_get_build_id
|
||||
hipStreamGetCaptureInfo
|
||||
hipStreamGetCaptureInfo_v2
|
||||
|
||||
@@ -380,6 +380,8 @@ global:
|
||||
amd_dbgapi_get_build_name;
|
||||
amd_dbgapi_get_git_hash;
|
||||
amd_dbgapi_get_build_id;
|
||||
hipStreamGetCaptureInfo;
|
||||
hipStreamGetCaptureInfo_v2;
|
||||
local:
|
||||
*;
|
||||
} hip_4.4;
|
||||
|
||||
@@ -308,6 +308,8 @@ namespace hip {
|
||||
}
|
||||
/// Get Capture ID
|
||||
int GetCaptureID() { return captureID_; }
|
||||
void SetCaptureEvent(hipEvent_t e) { captureEvents_.push_back(e); }
|
||||
void SetParallelCaptureStream(hipStream_t s) { parallelCaptureStreams_.push_back(s); }
|
||||
};
|
||||
|
||||
/// HIP Device class
|
||||
|
||||
Fai riferimento in un nuovo problema
Block a user