SWDEV-541096 - add hipEventWaitDefault and hipEventWaitExternal flags (#507)
Co-authored-by: Li, Todd tiantuo <Toddtiantuo.Li@amd.com>
这个提交包含在:
@@ -478,6 +478,53 @@ class GraphNode : public hipGraphNodeDOTAttribute {
|
||||
//!< when explicitly added)
|
||||
};
|
||||
|
||||
class GraphEventWaitNode : public GraphNode {
|
||||
hipEvent_t event_;
|
||||
|
||||
public:
|
||||
GraphEventWaitNode(hipEvent_t event)
|
||||
: GraphNode(hipGraphNodeTypeWaitEvent, "solid", "rectangle", "EVENT_WAIT"), event_(event) {}
|
||||
|
||||
~GraphEventWaitNode() {}
|
||||
|
||||
GraphEventWaitNode(const GraphEventWaitNode& rhs) : GraphNode(rhs) { event_ = rhs.event_; }
|
||||
|
||||
GraphNode* clone() const override { return new GraphEventWaitNode(*this); }
|
||||
|
||||
hipError_t CreateCommand(hip::Stream* stream) override {
|
||||
hipError_t status = GraphNode::CreateCommand(stream);
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
hip::Event* e = reinterpret_cast<hip::Event*>(event_);
|
||||
commands_.reserve(1);
|
||||
amd::Command* command;
|
||||
status = e->streamWaitCommand(command, stream);
|
||||
commands_.emplace_back(command);
|
||||
return status;
|
||||
}
|
||||
|
||||
void EnqueueCommands(hip::Stream* stream) override {
|
||||
if (!commands_.empty()) {
|
||||
hip::Event* e = reinterpret_cast<hip::Event*>(event_);
|
||||
commands_[0]->enqueue();
|
||||
commands_[0]->release();
|
||||
}
|
||||
}
|
||||
|
||||
void GetParams(hipEvent_t* event) const { *event = event_; }
|
||||
|
||||
hipError_t SetParams(hipEvent_t event) {
|
||||
event_ = event;
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
hipError_t SetParams(GraphNode* node) override {
|
||||
const GraphEventWaitNode* eventWaitNode = static_cast<GraphEventWaitNode const*>(node);
|
||||
return SetParams(eventWaitNode->event_);
|
||||
}
|
||||
};
|
||||
|
||||
class Graph {
|
||||
public:
|
||||
//!< Contains mem alloc dptrs whose corresponding free node is not added to the graph.
|
||||
@@ -533,6 +580,16 @@ class Graph {
|
||||
|
||||
std::unordered_set<GraphNode*> GetManualNodesDuringCapture() { return capturedNodes_; }
|
||||
|
||||
GraphNode* AddExternalEventWaitNode(hip::GraphNode* pDependencies, size_t numDependencies,
|
||||
hipEvent_t event) {
|
||||
GraphNode* node = new GraphEventWaitNode(event);
|
||||
for (size_t i = 0; i < numDependencies; i++) {
|
||||
pDependencies[i].AddEdgeDep(node);
|
||||
}
|
||||
AddNode(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
void RemoveManualNodesDuringCapture() {
|
||||
capturedNodes_.erase(capturedNodes_.begin(), capturedNodes_.end());
|
||||
}
|
||||
@@ -2167,53 +2224,6 @@ class GraphEventRecordNode : public GraphNode {
|
||||
}
|
||||
};
|
||||
|
||||
class GraphEventWaitNode : public GraphNode {
|
||||
hipEvent_t event_;
|
||||
|
||||
public:
|
||||
GraphEventWaitNode(hipEvent_t event)
|
||||
: GraphNode(hipGraphNodeTypeWaitEvent, "solid", "rectangle", "EVENT_WAIT"), event_(event) {}
|
||||
|
||||
~GraphEventWaitNode() {}
|
||||
|
||||
GraphEventWaitNode(const GraphEventWaitNode& rhs) : GraphNode(rhs) { event_ = rhs.event_; }
|
||||
|
||||
GraphNode* clone() const override { return new GraphEventWaitNode(*this); }
|
||||
|
||||
hipError_t CreateCommand(hip::Stream* stream) override {
|
||||
hipError_t status = GraphNode::CreateCommand(stream);
|
||||
if (status != hipSuccess) {
|
||||
return status;
|
||||
}
|
||||
hip::Event* e = reinterpret_cast<hip::Event*>(event_);
|
||||
commands_.reserve(1);
|
||||
amd::Command* command;
|
||||
status = e->streamWaitCommand(command, stream);
|
||||
commands_.emplace_back(command);
|
||||
return status;
|
||||
}
|
||||
|
||||
void EnqueueCommands(hip::Stream* stream) override {
|
||||
if (!commands_.empty()) {
|
||||
hip::Event* e = reinterpret_cast<hip::Event*>(event_);
|
||||
commands_[0]->enqueue();
|
||||
commands_[0]->release();
|
||||
}
|
||||
}
|
||||
|
||||
void GetParams(hipEvent_t* event) const { *event = event_; }
|
||||
|
||||
hipError_t SetParams(hipEvent_t event) {
|
||||
event_ = event;
|
||||
return hipSuccess;
|
||||
}
|
||||
|
||||
hipError_t SetParams(GraphNode* node) override {
|
||||
const GraphEventWaitNode* eventWaitNode = static_cast<GraphEventWaitNode const*>(node);
|
||||
return SetParams(eventWaitNode->event_);
|
||||
}
|
||||
};
|
||||
|
||||
class GraphHostNode : public GraphNode {
|
||||
hipHostNodeParams NodeParams_;
|
||||
|
||||
|
||||
@@ -467,6 +467,9 @@ void WaitThenDecrementSignal(hipStream_t stream, hipError_t status, void* user_d
|
||||
|
||||
// ================================================================================================
|
||||
hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsigned int flags) {
|
||||
if (flags != hipEventWaitDefault && flags != hipEventWaitExternal) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
hipError_t status = hipSuccess;
|
||||
if (event == nullptr) {
|
||||
return hipErrorInvalidHandle;
|
||||
@@ -483,7 +486,16 @@ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsig
|
||||
}
|
||||
|
||||
hip::Stream* eventStream = reinterpret_cast<hip::Stream*>(eventStreamHandle);
|
||||
if (eventStream != nullptr && eventStream->IsEventCaptured(event) == true) {
|
||||
if (flags == hipEventWaitExternal) {
|
||||
auto lastCapturedNodes = waitStream->GetLastCapturedNodes();
|
||||
hip::GraphNode* pGraphNode = waitStream->GetCaptureGraph()->AddExternalEventWaitNode(
|
||||
reinterpret_cast<hip::GraphNode*>(lastCapturedNodes.data()),
|
||||
lastCapturedNodes.size(),
|
||||
event);
|
||||
waitStream->SetLastCapturedNode(pGraphNode);
|
||||
return hipSuccess;
|
||||
}
|
||||
else if (eventStream != nullptr && eventStream->IsEventCaptured(event) == true) {
|
||||
ClPrint(amd::LOG_DETAIL_DEBUG, amd::LOG_API,
|
||||
"[hipGraph] Current capture node StreamWaitEvent on stream : %p, Event %p", stream,
|
||||
event);
|
||||
@@ -501,9 +513,6 @@ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsig
|
||||
}
|
||||
waitStream->AddCrossCapturedNode(e->GetNodesPrevToRecorded());
|
||||
} else {
|
||||
if (flags != 0) {
|
||||
return hipErrorInvalidValue;
|
||||
}
|
||||
if (eventStream != nullptr) {
|
||||
if (eventStream->GetCaptureStatus() == hipStreamCaptureStatusActive) {
|
||||
// If stream is capturing but event is not recorded on event's stream.
|
||||
|
||||
在新工单中引用
屏蔽一个用户