SWDEV-541096 - add hipEventWaitDefault and hipEventWaitExternal flags (#507)

Co-authored-by: Li, Todd tiantuo <Toddtiantuo.Li@amd.com>
这个提交包含在:
systems-assistant[bot]
2025-09-11 14:50:55 -07:00
提交者 GitHub
父节点 3742814d82
当前提交 c85200fc42
修改 2 个文件,包含 70 行新增51 行删除
+57 -47
查看文件
@@ -478,6 +478,53 @@ class GraphNode : public hipGraphNodeDOTAttribute {
//!< when explicitly added)
};
class GraphEventWaitNode : public GraphNode {
hipEvent_t event_;
public:
GraphEventWaitNode(hipEvent_t event)
: GraphNode(hipGraphNodeTypeWaitEvent, "solid", "rectangle", "EVENT_WAIT"), event_(event) {}
~GraphEventWaitNode() {}
GraphEventWaitNode(const GraphEventWaitNode& rhs) : GraphNode(rhs) { event_ = rhs.event_; }
GraphNode* clone() const override { return new GraphEventWaitNode(*this); }
hipError_t CreateCommand(hip::Stream* stream) override {
hipError_t status = GraphNode::CreateCommand(stream);
if (status != hipSuccess) {
return status;
}
hip::Event* e = reinterpret_cast<hip::Event*>(event_);
commands_.reserve(1);
amd::Command* command;
status = e->streamWaitCommand(command, stream);
commands_.emplace_back(command);
return status;
}
void EnqueueCommands(hip::Stream* stream) override {
if (!commands_.empty()) {
hip::Event* e = reinterpret_cast<hip::Event*>(event_);
commands_[0]->enqueue();
commands_[0]->release();
}
}
void GetParams(hipEvent_t* event) const { *event = event_; }
hipError_t SetParams(hipEvent_t event) {
event_ = event;
return hipSuccess;
}
hipError_t SetParams(GraphNode* node) override {
const GraphEventWaitNode* eventWaitNode = static_cast<GraphEventWaitNode const*>(node);
return SetParams(eventWaitNode->event_);
}
};
class Graph {
public:
//!< Contains mem alloc dptrs whose corresponding free node is not added to the graph.
@@ -533,6 +580,16 @@ class Graph {
std::unordered_set<GraphNode*> GetManualNodesDuringCapture() { return capturedNodes_; }
GraphNode* AddExternalEventWaitNode(hip::GraphNode* pDependencies, size_t numDependencies,
hipEvent_t event) {
GraphNode* node = new GraphEventWaitNode(event);
for (size_t i = 0; i < numDependencies; i++) {
pDependencies[i].AddEdgeDep(node);
}
AddNode(node);
return node;
}
void RemoveManualNodesDuringCapture() {
capturedNodes_.erase(capturedNodes_.begin(), capturedNodes_.end());
}
@@ -2167,53 +2224,6 @@ class GraphEventRecordNode : public GraphNode {
}
};
class GraphEventWaitNode : public GraphNode {
hipEvent_t event_;
public:
GraphEventWaitNode(hipEvent_t event)
: GraphNode(hipGraphNodeTypeWaitEvent, "solid", "rectangle", "EVENT_WAIT"), event_(event) {}
~GraphEventWaitNode() {}
GraphEventWaitNode(const GraphEventWaitNode& rhs) : GraphNode(rhs) { event_ = rhs.event_; }
GraphNode* clone() const override { return new GraphEventWaitNode(*this); }
hipError_t CreateCommand(hip::Stream* stream) override {
hipError_t status = GraphNode::CreateCommand(stream);
if (status != hipSuccess) {
return status;
}
hip::Event* e = reinterpret_cast<hip::Event*>(event_);
commands_.reserve(1);
amd::Command* command;
status = e->streamWaitCommand(command, stream);
commands_.emplace_back(command);
return status;
}
void EnqueueCommands(hip::Stream* stream) override {
if (!commands_.empty()) {
hip::Event* e = reinterpret_cast<hip::Event*>(event_);
commands_[0]->enqueue();
commands_[0]->release();
}
}
void GetParams(hipEvent_t* event) const { *event = event_; }
hipError_t SetParams(hipEvent_t event) {
event_ = event;
return hipSuccess;
}
hipError_t SetParams(GraphNode* node) override {
const GraphEventWaitNode* eventWaitNode = static_cast<GraphEventWaitNode const*>(node);
return SetParams(eventWaitNode->event_);
}
};
class GraphHostNode : public GraphNode {
hipHostNodeParams NodeParams_;
+13 -4
查看文件
@@ -467,6 +467,9 @@ void WaitThenDecrementSignal(hipStream_t stream, hipError_t status, void* user_d
// ================================================================================================
hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsigned int flags) {
if (flags != hipEventWaitDefault && flags != hipEventWaitExternal) {
return hipErrorInvalidValue;
}
hipError_t status = hipSuccess;
if (event == nullptr) {
return hipErrorInvalidHandle;
@@ -483,7 +486,16 @@ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsig
}
hip::Stream* eventStream = reinterpret_cast<hip::Stream*>(eventStreamHandle);
if (eventStream != nullptr && eventStream->IsEventCaptured(event) == true) {
if (flags == hipEventWaitExternal) {
auto lastCapturedNodes = waitStream->GetLastCapturedNodes();
hip::GraphNode* pGraphNode = waitStream->GetCaptureGraph()->AddExternalEventWaitNode(
reinterpret_cast<hip::GraphNode*>(lastCapturedNodes.data()),
lastCapturedNodes.size(),
event);
waitStream->SetLastCapturedNode(pGraphNode);
return hipSuccess;
}
else if (eventStream != nullptr && eventStream->IsEventCaptured(event) == true) {
ClPrint(amd::LOG_DETAIL_DEBUG, amd::LOG_API,
"[hipGraph] Current capture node StreamWaitEvent on stream : %p, Event %p", stream,
event);
@@ -501,9 +513,6 @@ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsig
}
waitStream->AddCrossCapturedNode(e->GetNodesPrevToRecorded());
} else {
if (flags != 0) {
return hipErrorInvalidValue;
}
if (eventStream != nullptr) {
if (eventStream->GetCaptureStatus() == hipStreamCaptureStatusActive) {
// If stream is capturing but event is not recorded on event's stream.