blob: 11314590a0bebf17dc41b5a7ae2dc1b251098742 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
David Beck29c75de2018-10-23 13:35:58 +01007#include <armnn/Optional.hpp>
telsoa014fcda012018-03-09 14:13:49 +00008#include <boost/assert.hpp>
9
10#include <functional>
11#include <map>
telsoa01c577f2c2018-08-31 09:22:23 +010012#include <stack>
telsoa014fcda012018-03-09 14:13:49 +000013#include <vector>
14
telsoa01c577f2c2018-08-31 09:22:23 +010015
telsoa014fcda012018-03-09 14:13:49 +000016namespace armnnUtils
17{
18
19namespace
20{
21
// Progress marker for a node during the depth-first search. A node with no entry in the state map has not been
// reached yet; Visiting marks a node on the current search path; Visited marks a node already appended to the
// sorted output.
enum class NodeState
{
    Visiting = 0,
    Visited = 1
};
27
telsoa01c577f2c2018-08-31 09:22:23 +010028
29template <typename TNodeId>
David Beck29c75de2018-10-23 13:35:58 +010030armnn::Optional<TNodeId> GetNextChild(TNodeId node,
telsoa01c577f2c2018-08-31 09:22:23 +010031 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
32 std::map<TNodeId, NodeState>& nodeStates)
33{
34 for (TNodeId childNode : getIncomingEdges(node))
35 {
36 if (nodeStates.find(childNode) == nodeStates.end())
37 {
38 return childNode;
39 }
40 else
41 {
42 if (nodeStates.find(childNode)->second == NodeState::Visiting)
43 {
44 return childNode;
45 }
46 }
47 }
48
49 return {};
50}
51
telsoa014fcda012018-03-09 14:13:49 +000052template<typename TNodeId>
telsoa01c577f2c2018-08-31 09:22:23 +010053bool TopologicallySort(
54 TNodeId initialNode,
telsoa014fcda012018-03-09 14:13:49 +000055 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
56 std::vector<TNodeId>& outSorted,
57 std::map<TNodeId, NodeState>& nodeStates)
58{
telsoa01c577f2c2018-08-31 09:22:23 +010059 std::stack<TNodeId> nodeStack;
60
61 // If the node is never visited we should search it
62 if (nodeStates.find(initialNode) == nodeStates.end())
telsoa014fcda012018-03-09 14:13:49 +000063 {
telsoa01c577f2c2018-08-31 09:22:23 +010064 nodeStack.push(initialNode);
telsoa014fcda012018-03-09 14:13:49 +000065 }
66
telsoa01c577f2c2018-08-31 09:22:23 +010067 while (!nodeStack.empty())
telsoa014fcda012018-03-09 14:13:49 +000068 {
telsoa01c577f2c2018-08-31 09:22:23 +010069 TNodeId current = nodeStack.top();
70
71 nodeStates[current] = NodeState::Visiting;
72
David Beck29c75de2018-10-23 13:35:58 +010073 auto nextChildOfCurrent = GetNextChild(current, getIncomingEdges, nodeStates);
telsoa01c577f2c2018-08-31 09:22:23 +010074
75 if (nextChildOfCurrent)
76 {
David Beck29c75de2018-10-23 13:35:58 +010077 TNodeId nextChild = nextChildOfCurrent.value();
telsoa01c577f2c2018-08-31 09:22:23 +010078
79 // If the child has not been searched, add to the stack and iterate over this node
80 if (nodeStates.find(nextChild) == nodeStates.end())
81 {
82 nodeStack.push(nextChild);
83 continue;
84 }
85
86 // If we re-encounter a node being visited there is a cycle
87 if (nodeStates[nextChild] == NodeState::Visiting)
88 {
89 return false;
90 }
91 }
92
93 nodeStack.pop();
94
95 nodeStates[current] = NodeState::Visited;
96 outSorted.push_back(current);
telsoa014fcda012018-03-09 14:13:49 +000097 }
98
telsoa014fcda012018-03-09 14:13:49 +000099 return true;
100}
101
102}
103
telsoa01c577f2c2018-08-31 09:22:23 +0100104// Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
telsoa014fcda012018-03-09 14:13:49 +0000105// Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
106// The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
107// it must return the list of nodes which are required to come before it.
108// "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
telsoa01c577f2c2018-08-31 09:22:23 +0100109// This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
telsoa014fcda012018-03-09 14:13:49 +0000110template<typename TNodeId, typename TTargetNodes>
111bool GraphTopologicalSort(
112 const TTargetNodes& targetNodes,
113 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
114 std::vector<TNodeId>& outSorted)
115{
116 outSorted.clear();
117 std::map<TNodeId, NodeState> nodeStates;
118
119 for (TNodeId targetNode : targetNodes)
120 {
telsoa01c577f2c2018-08-31 09:22:23 +0100121 if (!TopologicallySort(targetNode, getIncomingEdges, outSorted, nodeStates))
telsoa014fcda012018-03-09 14:13:49 +0000122 {
123 return false;
124 }
125 }
126
127 return true;
128}
129
130}