blob: 88b305c9c774efd01d50988351a18ead84d3e0b7 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
David Beck29c75de2018-10-23 13:35:58 +01007#include <armnn/Optional.hpp>
telsoa014fcda012018-03-09 14:13:49 +00008
9#include <functional>
10#include <map>
telsoa01c577f2c2018-08-31 09:22:23 +010011#include <stack>
telsoa014fcda012018-03-09 14:13:49 +000012#include <vector>
13
telsoa01c577f2c2018-08-31 09:22:23 +010014
telsoa014fcda012018-03-09 14:13:49 +000015namespace armnnUtils
16{
17
18namespace
19{
20
// Progress marker recorded for each node during the depth-first search.
// A node with no entry in the state map has not been reached yet.
enum class NodeState
{
    Visiting = 0, // On the current search path; its incoming nodes are still being processed.
    Visited  = 1, // Fully processed and already appended to the sorted output.
};
26
telsoa01c577f2c2018-08-31 09:22:23 +010027
28template <typename TNodeId>
David Beck29c75de2018-10-23 13:35:58 +010029armnn::Optional<TNodeId> GetNextChild(TNodeId node,
telsoa01c577f2c2018-08-31 09:22:23 +010030 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
31 std::map<TNodeId, NodeState>& nodeStates)
32{
33 for (TNodeId childNode : getIncomingEdges(node))
34 {
35 if (nodeStates.find(childNode) == nodeStates.end())
36 {
37 return childNode;
38 }
39 else
40 {
41 if (nodeStates.find(childNode)->second == NodeState::Visiting)
42 {
43 return childNode;
44 }
45 }
46 }
47
48 return {};
49}
50
telsoa014fcda012018-03-09 14:13:49 +000051template<typename TNodeId>
telsoa01c577f2c2018-08-31 09:22:23 +010052bool TopologicallySort(
53 TNodeId initialNode,
telsoa014fcda012018-03-09 14:13:49 +000054 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
55 std::vector<TNodeId>& outSorted,
56 std::map<TNodeId, NodeState>& nodeStates)
57{
telsoa01c577f2c2018-08-31 09:22:23 +010058 std::stack<TNodeId> nodeStack;
59
60 // If the node is never visited we should search it
61 if (nodeStates.find(initialNode) == nodeStates.end())
telsoa014fcda012018-03-09 14:13:49 +000062 {
telsoa01c577f2c2018-08-31 09:22:23 +010063 nodeStack.push(initialNode);
telsoa014fcda012018-03-09 14:13:49 +000064 }
65
telsoa01c577f2c2018-08-31 09:22:23 +010066 while (!nodeStack.empty())
telsoa014fcda012018-03-09 14:13:49 +000067 {
telsoa01c577f2c2018-08-31 09:22:23 +010068 TNodeId current = nodeStack.top();
69
70 nodeStates[current] = NodeState::Visiting;
71
David Beck29c75de2018-10-23 13:35:58 +010072 auto nextChildOfCurrent = GetNextChild(current, getIncomingEdges, nodeStates);
telsoa01c577f2c2018-08-31 09:22:23 +010073
74 if (nextChildOfCurrent)
75 {
David Beck29c75de2018-10-23 13:35:58 +010076 TNodeId nextChild = nextChildOfCurrent.value();
telsoa01c577f2c2018-08-31 09:22:23 +010077
78 // If the child has not been searched, add to the stack and iterate over this node
79 if (nodeStates.find(nextChild) == nodeStates.end())
80 {
81 nodeStack.push(nextChild);
82 continue;
83 }
84
85 // If we re-encounter a node being visited there is a cycle
86 if (nodeStates[nextChild] == NodeState::Visiting)
87 {
88 return false;
89 }
90 }
91
92 nodeStack.pop();
93
94 nodeStates[current] = NodeState::Visited;
95 outSorted.push_back(current);
telsoa014fcda012018-03-09 14:13:49 +000096 }
97
telsoa014fcda012018-03-09 14:13:49 +000098 return true;
99}
100
101}
102
telsoa01c577f2c2018-08-31 09:22:23 +0100103// Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
telsoa014fcda012018-03-09 14:13:49 +0000104// Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
105// The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
106// it must return the list of nodes which are required to come before it.
107// "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
telsoa01c577f2c2018-08-31 09:22:23 +0100108// This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
telsoa014fcda012018-03-09 14:13:49 +0000109template<typename TNodeId, typename TTargetNodes>
110bool GraphTopologicalSort(
111 const TTargetNodes& targetNodes,
112 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
113 std::vector<TNodeId>& outSorted)
114{
115 outSorted.clear();
116 std::map<TNodeId, NodeState> nodeStates;
117
118 for (TNodeId targetNode : targetNodes)
119 {
telsoa01c577f2c2018-08-31 09:22:23 +0100120 if (!TopologicallySort(targetNode, getIncomingEdges, outSorted, nodeStates))
telsoa014fcda012018-03-09 14:13:49 +0000121 {
122 return false;
123 }
124 }
125
126 return true;
127}
128
129}