From c85200fc42a8a83a648b57db0316aad9f8ba9a51 Mon Sep 17 00:00:00 2001 From: "systems-assistant[bot]" <221163467+systems-assistant[bot]@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:50:55 -0700 Subject: [PATCH] SWDEV-541096 - add hipEventWaitDefault and hipEventWaitExternal flags (#507) Co-authored-by: Li, Todd tiantuo --- .../clr/hipamd/src/hip_graph_internal.hpp | 104 ++++++++++-------- projects/clr/hipamd/src/hip_stream.cpp | 17 ++- 2 files changed, 70 insertions(+), 51 deletions(-) diff --git a/projects/clr/hipamd/src/hip_graph_internal.hpp b/projects/clr/hipamd/src/hip_graph_internal.hpp index c58c810432..babcb5f87c 100644 --- a/projects/clr/hipamd/src/hip_graph_internal.hpp +++ b/projects/clr/hipamd/src/hip_graph_internal.hpp @@ -478,6 +478,53 @@ class GraphNode : public hipGraphNodeDOTAttribute { //!< when explicitly added) }; +class GraphEventWaitNode : public GraphNode { + hipEvent_t event_; + + public: + GraphEventWaitNode(hipEvent_t event) + : GraphNode(hipGraphNodeTypeWaitEvent, "solid", "rectangle", "EVENT_WAIT"), event_(event) {} + + ~GraphEventWaitNode() {} + + GraphEventWaitNode(const GraphEventWaitNode& rhs) : GraphNode(rhs) { event_ = rhs.event_; } + + GraphNode* clone() const override { return new GraphEventWaitNode(*this); } + + hipError_t CreateCommand(hip::Stream* stream) override { + hipError_t status = GraphNode::CreateCommand(stream); + if (status != hipSuccess) { + return status; + } + hip::Event* e = reinterpret_cast(event_); + commands_.reserve(1); + amd::Command* command; + status = e->streamWaitCommand(command, stream); + commands_.emplace_back(command); + return status; + } + + void EnqueueCommands(hip::Stream* stream) override { + if (!commands_.empty()) { + hip::Event* e = reinterpret_cast(event_); + commands_[0]->enqueue(); + commands_[0]->release(); + } + } + + void GetParams(hipEvent_t* event) const { *event = event_; } + + hipError_t SetParams(hipEvent_t event) { + event_ = event; + return hipSuccess; + } + + hipError_t SetParams(GraphNode* node) override { + const GraphEventWaitNode* eventWaitNode = static_cast(node); + return SetParams(eventWaitNode->event_); + } +}; + class Graph { public: //!< Contains mem alloc dptrs whose corresponding free node is not added to the graph. @@ -533,6 +580,16 @@ class Graph { std::unordered_set GetManualNodesDuringCapture() { return capturedNodes_; } + GraphNode* AddExternalEventWaitNode(hip::GraphNode* pDependencies, size_t numDependencies, + hipEvent_t event) { + GraphNode* node = new GraphEventWaitNode(event); + for (size_t i = 0; i < numDependencies; i++) { + pDependencies[i].AddEdgeDep(node); + } + AddNode(node); + return node; + } + void RemoveManualNodesDuringCapture() { capturedNodes_.erase(capturedNodes_.begin(), capturedNodes_.end()); } @@ -2167,53 +2224,6 @@ class GraphEventRecordNode : public GraphNode { } }; -class GraphEventWaitNode : public GraphNode { - hipEvent_t event_; - - public: - GraphEventWaitNode(hipEvent_t event) - : GraphNode(hipGraphNodeTypeWaitEvent, "solid", "rectangle", "EVENT_WAIT"), event_(event) {} - - ~GraphEventWaitNode() {} - - GraphEventWaitNode(const GraphEventWaitNode& rhs) : GraphNode(rhs) { event_ = rhs.event_; } - - GraphNode* clone() const override { return new GraphEventWaitNode(*this); } - - hipError_t CreateCommand(hip::Stream* stream) override { - hipError_t status = GraphNode::CreateCommand(stream); - if (status != hipSuccess) { - return status; - } - hip::Event* e = reinterpret_cast(event_); - commands_.reserve(1); - amd::Command* command; - status = e->streamWaitCommand(command, stream); - commands_.emplace_back(command); - return status; - } - - void EnqueueCommands(hip::Stream* stream) override { - if (!commands_.empty()) { - hip::Event* e = reinterpret_cast(event_); - commands_[0]->enqueue(); - commands_[0]->release(); - } - } - - void GetParams(hipEvent_t* event) const { *event = event_; } - - hipError_t SetParams(hipEvent_t event) { - event_ = event; - return hipSuccess; - } - - hipError_t SetParams(GraphNode* node) override { - const GraphEventWaitNode* eventWaitNode = static_cast(node); - return SetParams(eventWaitNode->event_); - } -}; - class GraphHostNode : public GraphNode { hipHostNodeParams NodeParams_; diff --git a/projects/clr/hipamd/src/hip_stream.cpp b/projects/clr/hipamd/src/hip_stream.cpp index 4c99c0f340..f2081b97f8 100644 --- a/projects/clr/hipamd/src/hip_stream.cpp +++ b/projects/clr/hipamd/src/hip_stream.cpp @@ -467,6 +467,9 @@ void WaitThenDecrementSignal(hipStream_t stream, hipError_t status, void* user_d // ================================================================================================ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsigned int flags) { + if (flags != hipEventWaitDefault && flags != hipEventWaitExternal) { + return hipErrorInvalidValue; + } hipError_t status = hipSuccess; if (event == nullptr) { return hipErrorInvalidHandle; @@ -483,7 +486,16 @@ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsig } hip::Stream* eventStream = reinterpret_cast(eventStreamHandle); - if (eventStream != nullptr && eventStream->IsEventCaptured(event) == true) { + if (flags == hipEventWaitExternal) { + auto lastCapturedNodes = waitStream->GetLastCapturedNodes(); + hip::GraphNode* pGraphNode = waitStream->GetCaptureGraph()->AddExternalEventWaitNode( + reinterpret_cast(lastCapturedNodes.data()), + lastCapturedNodes.size(), + event); + waitStream->SetLastCapturedNode(pGraphNode); + return hipSuccess; + } + else if (eventStream != nullptr && eventStream->IsEventCaptured(event) == true) { ClPrint(amd::LOG_DETAIL_DEBUG, amd::LOG_API, "[hipGraph] Current capture node StreamWaitEvent on stream : %p, Event %p", stream, event); @@ -501,9 +513,6 @@ hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsig } waitStream->AddCrossCapturedNode(e->GetNodesPrevToRecorded()); } else { - if (flags != 0) { - return hipErrorInvalidValue; - } if (eventStream != nullptr) { if (eventStream->GetCaptureStatus() == hipStreamCaptureStatusActive) { // If stream is capturing but event is not recorded on event's stream.