From 13e2e797c0b217cbfa91b7bc67ce563822b63d76 Mon Sep 17 00:00:00 2001 From: Anusha GodavarthySurya Date: Thu, 12 Dec 2024 09:22:29 +0000 Subject: [PATCH] SWDEV-469422 - Derive GraphExec from Graph and ChildGraphNode from GraphExec Change-Id: I54d67a1665355579bc249d8ff4f9806e9ee14588 --- hipamd/src/hip_graph.cpp | 54 ++++++++++------------- hipamd/src/hip_graph_internal.cpp | 51 +++++++++++----------- hipamd/src/hip_graph_internal.hpp | 71 +++++++++++-------------------- 3 files changed, 74 insertions(+), 102 deletions(-) diff --git a/hipamd/src/hip_graph.cpp b/hipamd/src/hip_graph.cpp index c229670845..e68e348483 100644 --- a/hipamd/src/hip_graph.cpp +++ b/hipamd/src/hip_graph.cpp @@ -1374,37 +1374,30 @@ hipError_t ihipGraphInstantiate(hip::GraphExec** pGraphExec, hip::Graph* graph, } } } - std::unordered_map clonedNodes; - hip::Graph* clonedGraph = graph->clone(clonedNodes); - clonedGraph->memAllocNodePtrs_ = graph->memAllocNodePtrs_; - if (clonedGraph == nullptr) { - return hipErrorInvalidValue; - } - std::vector graphNodes; - clonedGraph->ScheduleNodes(); - if (false == clonedGraph->TopologicalOrder(graphNodes)) { - return hipErrorInvalidValue; - } - *pGraphExec = new hip::GraphExec(graphNodes, clonedGraph, clonedNodes, flags); - if (*pGraphExec != nullptr) { - graph->SetGraphInstantiated(true); - if (DEBUG_HIP_GRAPH_DOT_PRINT) { - static int i = 1; - std::string filename = - "graph_" + std::to_string(amd::Os::getProcessId()) + "_dot_print_" + std::to_string(i++); - hipError_t status = - ihipGraphDebugDotPrint(reinterpret_cast(clonedGraph), filename.c_str(), 0); - if (status == hipSuccess) { - LogPrintfInfo("[hipGraph] graph dump:%s", filename.c_str()); - } - } - if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { - (*pGraphExec)->SetKernelArgManager(new hip::GraphKernelArgManager()); - } - return (*pGraphExec)->Init(); - } else { + *pGraphExec = new hip::GraphExec(flags); + if (*pGraphExec == nullptr) { return hipErrorOutOfMemory; } + graph->clone(*pGraphExec, true); + (*pGraphExec)->ScheduleNodes(); + if (false == (*pGraphExec)->TopologicalOrder()) { + return hipErrorInvalidValue; + } + graph->SetGraphInstantiated(true); + if (DEBUG_HIP_GRAPH_DOT_PRINT) { + static int i = 1; + std::string filename = + "graph_" + std::to_string(amd::Os::getProcessId()) + "_dot_print_" + std::to_string(i++); + hipError_t status = + ihipGraphDebugDotPrint(reinterpret_cast(*pGraphExec), filename.c_str(), 0); + if (status == hipSuccess) { + LogPrintfInfo("[hipGraph] graph dump:%s", filename.c_str()); + } + } + if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { + (*pGraphExec)->SetKernelArgManager(new hip::GraphKernelArgManager()); + } + return (*pGraphExec)->Init(); } hipError_t hipGraphInstantiate(hipGraphExec_t* pGraphExec, hipGraph_t graph, @@ -1865,8 +1858,7 @@ hipError_t hipGraphExecChildGraphNodeSetParams(hipGraphExec_t hGraphExec, hipGra for (std::vector::size_type i = 0; i != childGraphNodes.size(); i++) { if (childGraphNodes[i]->GraphCaptureEnabled()) { status = reinterpret_cast(clonedNode) - ->graphExec_.UpdateAQLPacket( - reinterpret_cast(childGraphNodes[i])); + ->UpdateAQLPacket(reinterpret_cast(childGraphNodes[i])); if (status != hipSuccess) { return status; } diff --git a/hipamd/src/hip_graph_internal.cpp b/hipamd/src/hip_graph_internal.cpp index 4fa163655e..2503486897 100644 --- a/hipamd/src/hip_graph_internal.cpp +++ b/hipamd/src/hip_graph_internal.cpp @@ -193,7 +193,7 @@ void Graph::ScheduleOneNode(Node node, int stream_id) { child->ScheduleNodes(); max_streams_ = std::max(max_streams_, child->max_streams_); if (child->max_streams_ == 1) { - reinterpret_cast(node)->TopologicalOrder(); + reinterpret_cast(node)->GraphExec::TopologicalOrder(); } } for (auto edge: node->GetEdges()) { @@ -269,13 +269,13 @@ bool Graph::TopologicalOrder(std::vector& TopoOrder) { } // ================================================================================================ -Graph* Graph::clone(std::unordered_map& clonedNodes) const { - Graph* newGraph = new Graph(device_, this); - for (auto entry : vertices_) { +void Graph::clone(Graph* newGraph, bool cloneNodes) const { + newGraph->pOriginalGraph_ = this; + for (hip::GraphNode* entry : vertices_) { GraphNode* node = entry->clone(); node->SetParentGraph(newGraph); newGraph->vertices_.push_back(node); - clonedNodes[entry] = node; + newGraph->clonedNodes_[entry] = node; } std::vector clonedEdges; @@ -284,17 +284,17 @@ Graph* Graph::clone(std::unordered_map& clonedNodes) const { const std::vector& edges = node->GetEdges(); clonedEdges.clear(); for (auto edge : edges) { - clonedEdges.push_back(clonedNodes[edge]); + clonedEdges.push_back(newGraph->clonedNodes_[edge]); } - clonedNodes[node]->SetEdges(clonedEdges); + newGraph->clonedNodes_[node]->SetEdges(clonedEdges); } for (auto node : vertices_) { const std::vector& dependencies = node->GetDependencies(); clonedDependencies.clear(); for (auto dep : dependencies) { - clonedDependencies.push_back(clonedNodes[dep]); + clonedDependencies.push_back(newGraph->clonedNodes_[dep]); } - clonedNodes[node]->SetDependencies(clonedDependencies); + newGraph->clonedNodes_[node]->SetDependencies(clonedDependencies); } for (auto& userObj : graphUserObj_) { userObj.first->retain(); @@ -307,13 +307,17 @@ Graph* Graph::clone(std::unordered_map& clonedNodes) const { if (roots_.size() > 0) { memcpy(&newGraph->roots_[0], &roots_[0], sizeof(Node) * roots_.size()); } - return newGraph; + newGraph->memAllocNodePtrs_ = memAllocNodePtrs_; + if(!cloneNodes) { + newGraph->clonedNodes_.clear(); + } } // ================================================================================================ Graph* Graph::clone() const { - std::unordered_map clonedNodes; - return clone(clonedNodes); + Graph* newGraph = new Graph(device_); + clone(newGraph); + return newGraph; } // ================================================================================================ @@ -350,7 +354,7 @@ hipError_t GraphExec::CreateStreams(uint32_t num_streams) { hipError_t GraphExec::Init() { hipError_t status = hipSuccess; // create extra stream to avoid queue collision with the default execution stream - status = CreateStreams(clonedGraph_->max_streams_); + status = CreateStreams(max_streams_); if (status != hipSuccess) { return status; } @@ -376,11 +380,11 @@ void GraphExec::GetKernelArgSizeForGraph(size_t& kernArgSizeForGraph) { // Child graph shares same kernel arg manager GraphKernelArgManager* KernelArgManager = GetKernelArgManager(); KernelArgManager->retain(); - childNode->graphExec_.SetKernelArgManager(KernelArgManager); + childNode->SetKernelArgManager(KernelArgManager); // Set capture stream for child graph - childNode->graphExec_.capture_stream_ = capture_stream_; + childNode->capture_stream_ = capture_stream_; if (childNode->GetChildGraph()->max_streams_ == 1) { - childNode->graphExec_.GetKernelArgSizeForGraph(kernArgSizeForGraph); + childNode->GetKernelArgSizeForGraph(kernArgSizeForGraph); } } } @@ -404,7 +408,7 @@ hipError_t GraphExec::AllocKernelArgForGraphNode() { auto childNode = reinterpret_cast(node); if (childNode->GetChildGraph()->max_streams_ == 1) { childNode->SetGraphCaptureStatus(true); - status = childNode->graphExec_.AllocKernelArgForGraphNode(); + status = childNode->AllocKernelArgForGraphNode(); if (status != hipSuccess) { return status; } @@ -417,7 +421,7 @@ hipError_t GraphExec::AllocKernelArgForGraphNode() { // ================================================================================================ hipError_t GraphExec::CaptureAQLPackets() { hipError_t status = hipSuccess; - if (clonedGraph_->max_streams_ == 1) { + if (max_streams_ == 1) { size_t kernArgSizeForGraph = 0; GetKernelArgSizeForGraph(kernArgSizeForGraph); auto device = g_devices[ihipGetDevice()]->devices()[0]; @@ -439,7 +443,7 @@ hipError_t GraphExec::CaptureAQLPackets() { // ================================================================================================ hipError_t GraphExec::UpdateAQLPacket(hip::GraphNode* node) { hipError_t status = hipSuccess; - if (clonedGraph_->max_streams_ == 1) { + if (max_streams_ == 1) { node->CaptureAndFormPacket(capture_stream_, kernArgManager_); } return hipSuccess; @@ -696,7 +700,7 @@ hipError_t GraphExec::Run(hipStream_t graph_launch_stream) { repeatLaunch_ = true; } - if (clonedGraph_->max_streams_ == 1 && instantiateDeviceId_ == launch_stream->DeviceId()) { + if (max_streams_ == 1 && instantiateDeviceId_ == launch_stream->DeviceId()) { if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { // If the graph has kernels that does device side allocation, during packet capture, heap is // allocated because heap pointer has to be added to the AQL packet, and initialized during @@ -708,7 +712,7 @@ hipError_t GraphExec::Run(hipStream_t graph_launch_stream) { } } status = EnqueueGraphWithSingleList(launch_stream); - } else if (clonedGraph_->max_streams_ == 1 && instantiateDeviceId_ != launch_stream->DeviceId()) { + } else if (max_streams_ == 1 && instantiateDeviceId_ != launch_stream->DeviceId()) { for (int i = 0; i < topoOrder_.size(); i++) { topoOrder_[i]->SetStream(launch_stream); status = topoOrder_[i]->CreateCommand(topoOrder_[i]->GetQueue()); @@ -716,9 +720,9 @@ hipError_t GraphExec::Run(hipStream_t graph_launch_stream) { } } else { // Update streams for the graph execution - clonedGraph_->UpdateStreams(launch_stream, parallel_streams_); + UpdateStreams(launch_stream, parallel_streams_); // Execute all nodes in the graph - if (!clonedGraph_->RunNodes()) { + if (!RunNodes()) { LogError("Failed to launch nodes!"); return hipErrorOutOfMemory; } @@ -744,7 +748,6 @@ hipError_t GraphExec::Run(hipStream_t graph_launch_stream) { block_command->enqueue(); block_command->release(); CallbackCommand->release(); - ResetQueueIndex(); return status; } diff --git a/hipamd/src/hip_graph_internal.hpp b/hipamd/src/hip_graph_internal.hpp index 2b1fbc4453..af3f619060 100644 --- a/hipamd/src/hip_graph_internal.hpp +++ b/hipamd/src/hip_graph_internal.hpp @@ -491,6 +491,7 @@ struct Graph { std::unordered_set capturedNodes_; bool graphInstantiated_; std::unordered_set memAllocNodePtrs_; + std::unordered_map clonedNodes_; public: Graph(hip::Device* device, const Graph* original = nullptr) : pOriginalGraph_(original) @@ -636,7 +637,7 @@ struct Graph { bool TopologicalOrder(std::vector& TopoOrder); - Graph* clone(std::unordered_map& clonedNodes) const; + void clone(Graph* newGraph, bool cloneNodes = false) const; Graph* clone() const; void GenerateDOT(std::ostream& fout, hipGraphDebugDotFlags flag) { fout << "subgraph cluster_" << GetID() << " {" << std::endl; @@ -724,14 +725,11 @@ struct Graph { }; struct GraphKernelNode; -struct GraphExec : public amd::ReferenceCountedObject { +struct GraphExec : public amd::ReferenceCountedObject, public Graph { //! Topological order of the graph doesn't include nodes embedded as part of the child graph std::vector topoOrder_; - struct Graph* clonedGraph_; std::vector parallel_streams_; hip::Stream* capture_stream_; - uint currentQueueIndex_; - std::unordered_map clonedNodes_; static std::unordered_set graphExecSet_; static amd::Monitor graphExecSetLock_; uint64_t flags_ = 0; @@ -741,23 +739,14 @@ struct GraphExec : public amd::ReferenceCountedObject { bool repeatLaunch_ = false; public: - GraphExec(std::vector& topoOrder, struct Graph*& clonedGraph, - std::unordered_map& clonedNodes, uint64_t flags = 0) + GraphExec(uint64_t flags = 0) : ReferenceCountedObject(), - topoOrder_(topoOrder), - clonedGraph_(clonedGraph), - clonedNodes_(clonedNodes), - currentQueueIndex_(0), + Graph(hip::getCurrentDevice()), flags_(flags) { amd::ScopedLock lock(graphExecSetLock_); graphExecSet_.insert(this); } - GraphExec() : ReferenceCountedObject() { - amd::ScopedLock lock(graphExecSetLock_); - graphExecSet_.insert(this); - } - ~GraphExec() { for (auto stream : parallel_streams_) { if (stream != nullptr) { @@ -767,7 +756,6 @@ struct GraphExec : public amd::ReferenceCountedObject { } amd::ScopedLock lock(graphExecSetLock_); graphExecSet_.erase(this); - delete clonedGraph_; if (DEBUG_CLR_GRAPH_PACKET_CAPTURE) { if (kernArgManager_ != nullptr) { kernArgManager_->release(); @@ -793,15 +781,6 @@ struct GraphExec : public amd::ReferenceCountedObject { //! Check executable graphs validity static bool isGraphExecValid(GraphExec* pGraphExec); std::vector& GetNodes() { return topoOrder_; } - - hip::Stream* GetAvailableStreams() { - if (currentQueueIndex_ < parallel_streams_.size()) { - return parallel_streams_[currentQueueIndex_++]; - } - return nullptr; - } - - void ResetQueueIndex() { currentQueueIndex_ = 0; } uint64_t GetFlags() const { return flags_; } hipError_t Init(); hipError_t CreateStreams(uint32_t num_streams); @@ -822,19 +801,19 @@ struct GraphExec : public amd::ReferenceCountedObject { hipError_t AllocKernelArgForGraphNode(); void GetKernelArgSizeForGraph(size_t& kernArgSizeForGraph); hipError_t EnqueueGraphWithSingleList(hip::Stream* hip_stream); + bool TopologicalOrder() { return Graph::TopologicalOrder(topoOrder_); } }; -struct ChildGraphNode : public GraphNode { - struct GraphExec graphExec_; +struct ChildGraphNode : public GraphNode, public GraphExec { bool graphCaptureStatus_; public: - ChildGraphNode(Graph* g) : GraphNode(hipGraphNodeTypeGraph, "solid", "rectangle") { - graphExec_.clonedGraph_ = g->clone(); + ChildGraphNode(Graph* g) : GraphNode(hipGraphNodeTypeGraph, "solid", "rectangle"), GraphExec() { + g->clone(this); graphCaptureStatus_ = false; } - ChildGraphNode(const ChildGraphNode& rhs) : GraphNode(rhs) { - graphExec_.clonedGraph_ = rhs.graphExec_.clonedGraph_->clone(); + ChildGraphNode(const ChildGraphNode& rhs) : GraphNode(rhs), GraphExec() { + rhs.Graph::clone(this); graphCaptureStatus_ = rhs.graphCaptureStatus_; } @@ -842,14 +821,14 @@ struct ChildGraphNode : public GraphNode { return new ChildGraphNode(static_cast(*this)); } - Graph* GetChildGraph() override { return graphExec_.clonedGraph_; } + Graph* GetChildGraph() override { return this; } void SetGraphCaptureStatus(bool status) { graphCaptureStatus_ = status; } bool GetGraphCaptureStatus() { return graphCaptureStatus_; } std::vector& GetChildGraphNodeOrder() { - return graphExec_.topoOrder_; + return topoOrder_; } void SetStream(hip::Stream* stream) override { @@ -857,27 +836,25 @@ struct ChildGraphNode : public GraphNode { } bool TopologicalOrder(std::vector& TopoOrder) override { - return graphExec_.clonedGraph_->TopologicalOrder(TopoOrder); + return Graph::TopologicalOrder(TopoOrder); } - bool TopologicalOrder() { return graphExec_.clonedGraph_->TopologicalOrder(graphExec_.topoOrder_); } - void EnqueueCommands(hip::Stream* stream) override { if (graphCaptureStatus_) { - hipError_t status = graphExec_.EnqueueGraphWithSingleList(stream); - } else if (graphExec_.clonedGraph_->max_streams_ == 1) { - for (int i = 0; i < graphExec_.topoOrder_.size(); i++) { - graphExec_.topoOrder_[i]->SetStream(stream_); + hipError_t status = EnqueueGraphWithSingleList(stream); + } else if (max_streams_ == 1) { + for (int i = 0; i < topoOrder_.size(); i++) { + topoOrder_[i]->SetStream(stream_); hipError_t status = - graphExec_.topoOrder_[i]->CreateCommand(graphExec_.topoOrder_[i]->GetQueue()); - graphExec_.topoOrder_[i]->EnqueueCommands(stream_); + topoOrder_[i]->CreateCommand(topoOrder_[i]->GetQueue()); + topoOrder_[i]->EnqueueCommands(stream_); } } } hipError_t SetParams(const Graph* childGraph) { const std::vector& newNodes = childGraph->GetNodes(); - const std::vector& oldNodes = graphExec_.clonedGraph_->GetNodes(); + const std::vector& oldNodes = Graph::GetNodes(); for (std::vector::size_type i = 0; i != newNodes.size(); i++) { hipError_t status = oldNodes[i]->SetParams(newNodes[i]); if (status != hipSuccess) { @@ -889,15 +866,15 @@ struct ChildGraphNode : public GraphNode { hipError_t SetParams(GraphNode* node) override { const ChildGraphNode* childGraphNode = static_cast(node); - return SetParams(childGraphNode->graphExec_.clonedGraph_); + return SetParams((Graph*)this); } virtual std::string GetLabel(hipGraphDebugDotFlags flag) override { - return std::to_string(GetID()) + "\n" + "graph_" + std::to_string(graphExec_.clonedGraph_->GetID()); + return std::to_string(GraphNode::GetID()) + "\n" + "graph_" + std::to_string(Graph::GetID()); } virtual void GenerateDOT(std::ostream& fout, hipGraphDebugDotFlags flag) override { - graphExec_.clonedGraph_->GenerateDOT(fout, flag); + Graph::GenerateDOT(fout, flag); } };