blob: f455289567989f117e67ee3234257a22e8924a1c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#pragma once
6
7#include <boost/assert.hpp>
8
9#include <functional>
10#include <map>
11#include <vector>
12
13namespace armnnUtils
14{
15
namespace
{

// Traversal state of a node during the depth-first search.
enum class NodeState
{
    Visiting, // On the current DFS path: its incoming edges are still being processed.
    Visited,  // Fully processed and already appended to the sorted output.
};

// Depth-first visit of "current" and, recursively, of every node returned for it by "getIncomingEdges".
// Each node is appended to "outSorted" only after all of its inputs have been appended.
//
// current          - the node to visit.
// getIncomingEdges - returns, for a given node, the nodes which must come before it.
// outSorted        - receives the nodes in topological order (appended to, never cleared here).
// nodeStates       - bookkeeping shared across the whole traversal; updated in place.
//
// Returns true if "current" and everything it depends on was sorted (or had been sorted already).
// Returns false if a cycle was detected, i.e. a node was reached again while still being visited.
// On failure "outSorted" and "nodeStates" are left partially populated.
template<typename TNodeId>
bool Visit(
    TNodeId current,
    const std::function<std::vector<TNodeId>(TNodeId)>& getIncomingEdges,
    std::vector<TNodeId>& outSorted,
    std::map<TNodeId, NodeState>& nodeStates)
{
    // Single lookup: either registers the node as Visiting or finds its existing state.
    const auto insertion = nodeStates.emplace(current, NodeState::Visiting);
    const auto stateIt = insertion.first;
    if (!insertion.second)
    {
        // Already known. Visited means it has been emitted, so there is nothing more to do.
        // Visiting means we followed a back edge onto the current DFS path - the graph has a cycle.
        return stateIt->second == NodeState::Visited;
    }

    for (TNodeId inputNode : getIncomingEdges(current))
    {
        // Propagate failures upwards, otherwise a cycle found deeper in the recursion would be lost.
        if (!Visit(inputNode, getIncomingEdges, outSorted, nodeStates))
        {
            return false;
        }
    }

    // std::map iterators stay valid across the insertions made by the recursive calls above.
    stateIt->second = NodeState::Visited;

    outSorted.push_back(current);
    return true;
}

}
63
64// Sorts an directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
65// Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
66// The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
67// it must return the list of nodes which are required to come before it.
68// "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
69// The implementation is based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
70template<typename TNodeId, typename TTargetNodes>
71bool GraphTopologicalSort(
72 const TTargetNodes& targetNodes,
73 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
74 std::vector<TNodeId>& outSorted)
75{
76 outSorted.clear();
77 std::map<TNodeId, NodeState> nodeStates;
78
79 for (TNodeId targetNode : targetNodes)
80 {
81 if (!Visit(targetNode, getIncomingEdges, outSorted, nodeStates))
82 {
83 return false;
84 }
85 }
86
87 return true;
88}
89
90}