blob: f455289567989f117e67ee3234257a22e8924a1c [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once
#include <boost/assert.hpp>
#include <functional>
#include <map>
#include <vector>
namespace armnnUtils
{

namespace
{

// Depth-first search bookkeeping for a single node.
enum class NodeState
{
    Visiting, // The node is on the current DFS path: its inputs are still being explored.
    Visited,  // The node and all of its inputs have already been appended to the output.
};

// Appends "current" to "outSorted" after every node it (transitively) depends on, as reported by
// "getIncomingEdges". Nodes that were already emitted are skipped, so each node is output at most once.
//
// current          - node to visit.
// getIncomingEdges - returns the nodes that must come before a given node.
// outSorted        - receives nodes in dependency order (post-order of the depth-first search).
// nodeStates       - per-node DFS state, shared by all Visit calls of a single sort.
// Returns false if a cycle is reachable from "current", true otherwise.
//
// The recursion depth grows with the length of the dependency chain being explored, so very deep
// graphs consume proportional stack space.
template<typename TNodeId>
bool Visit(
    TNodeId current,
    const std::function<std::vector<TNodeId>(TNodeId)>& getIncomingEdges,
    std::vector<TNodeId>& outSorted,
    std::map<TNodeId, NodeState>& nodeStates)
{
    auto currentStateIt = nodeStates.find(current);
    if (currentStateIt != nodeStates.end())
    {
        // Exhaustive switch (no default) so that adding a NodeState triggers a -Wswitch warning.
        switch (currentStateIt->second)
        {
            case NodeState::Visited:
                // Already emitted together with all of its inputs - nothing more to do.
                return true;
            case NodeState::Visiting:
                // This node is still on the current DFS path, so we just followed an edge back into it:
                // the graph contains a cycle.
                return false;
        }
    }

    nodeStates[current] = NodeState::Visiting;

    for (TNodeId inputNode : getIncomingEdges(current))
    {
        // Propagate failure: a cycle anywhere below this node invalidates the whole sort.
        if (!Visit(inputNode, getIncomingEdges, outSorted, nodeStates))
        {
            return false;
        }
    }

    nodeStates[current] = NodeState::Visited;
    outSorted.push_back(current);
    return true;
}

} // anonymous namespace

// Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node come before the node itself.
// Returns true if successful, or false if there is an error in the graph structure (e.g. it contains a cycle).
// The graph is defined entirely by the "getIncomingEdges" function which the user provides: for a given node,
// it must return the list of nodes which are required to come before it.
// "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
// Only nodes reachable from "targetNodes" (by following incoming edges) are output, each exactly once.
// "outSorted" is cleared first. When false is returned its contents are only a partial result and must not be used.
// The implementation is based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
template<typename TNodeId, typename TTargetNodes>
bool GraphTopologicalSort(
    const TTargetNodes& targetNodes,
    std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
    std::vector<TNodeId>& outSorted)
{
    outSorted.clear();
    std::map<TNodeId, NodeState> nodeStates;

    for (TNodeId targetNode : targetNodes)
    {
        if (!Visit(targetNode, getIncomingEdges, outSorted, nodeStates))
        {
            return false;
        }
    }

    return true;
}

} // namespace armnnUtils