blob: 4b3f79921ff7f328f618caa6c0f972c2f0c156e5 [file] [log] [blame]
David Beckf98d21a2018-10-26 16:03:03 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "SubGraphSelector.hpp"
7#include "Graph.hpp"
#include <boost/assert.hpp>

#include <algorithm>
#include <map>
#include <unordered_map>
#include <vector>
11
12namespace armnn
13{
14
15namespace
16{
17
18struct LayerSelectionInfo
19{
20 using LayerInfoContainer = std::unordered_map<Layer*, LayerSelectionInfo>;
21 static constexpr uint32_t InitialSplitId() { return 1; }
22
23 LayerSelectionInfo(Layer* layer, const SubGraphSelector::LayerSelectorFunction& selector)
24 : m_Layer{layer}
25 , m_SplitId{0}
26 , m_IsSelected{selector(*layer)}
27 {
28 // fill topology information by storing direct children
29 for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
30 {
31 for (InputSlot* childLayerInputSlot : slot->GetConnections())
32 {
33 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
34 m_DirectChildren.push_back(&childLayer);
35 }
36 }
37 }
38
39 void MarkChildrenSplits(LayerInfoContainer& network,
40 uint32_t splitId,
41 bool prevSelected)
42 {
43 if (m_SplitId < splitId)
44 {
45 m_SplitId = splitId;
46 }
47
48 // introduce a new split point at all non-selected points, but only if the
49 // previous point was selected. this prevents creating a new subgraph at
50 // every non-selected layer
51 if (!m_IsSelected && prevSelected)
52 {
53 ++m_SplitId;
54 }
55
56 for (auto& layer : m_DirectChildren)
57 {
58 auto it = network.find(layer);
59 BOOST_ASSERT_MSG(it != network.end(), "All layers must be part of the topology.");
60 if (it != network.end())
61 {
62 it->second.MarkChildrenSplits(network, m_SplitId, m_IsSelected);
63 }
64 }
65 }
66
67 bool IsInputLayer() const
68 {
69 return m_Layer->GetType() == armnn::LayerType::Input;
70 }
71
Matteo Martincighadddddb2019-01-24 14:06:23 +000072 void CollectNonSelectedInputs(SubGraph::InputSlots& inputSlots,
David Beckf98d21a2018-10-26 16:03:03 +010073 const SubGraphSelector::LayerSelectorFunction& selector)
74 {
75 for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot)
76 {
77 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
Matteo Martincighadddddb2019-01-24 14:06:23 +000078 BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here.");
David Beckf98d21a2018-10-26 16:03:03 +010079 if (parentLayerOutputSlot)
80 {
81 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
82 if (selector(parentLayer) == false)
83 {
Matteo Martincighadddddb2019-01-24 14:06:23 +000084 inputSlots.push_back(&(*slot));
David Beckf98d21a2018-10-26 16:03:03 +010085 }
86 }
87 }
88 }
89
Matteo Martincighadddddb2019-01-24 14:06:23 +000090 void CollectNonSelectedOutputSlots(SubGraph::OutputSlots& outputSlots,
David Beckf98d21a2018-10-26 16:03:03 +010091 const SubGraphSelector::LayerSelectorFunction& selector)
92 {
93 for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
94 {
95 for (InputSlot* childLayerInputSlot : slot->GetConnections())
96 {
97 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
98 if (selector(childLayer) == false)
99 {
Matteo Martincighadddddb2019-01-24 14:06:23 +0000100 outputSlots.push_back(&(*slot));
David Beckf98d21a2018-10-26 16:03:03 +0100101 }
102 }
103 }
104 }
105
106 std::vector<Layer*> m_DirectChildren;
107 Layer* m_Layer;
108 uint32_t m_SplitId;
109 bool m_IsSelected;
110};
111
112} // namespace <anonymous>
113
114SubGraphSelector::SubGraphs
Matteo Martincighadddddb2019-01-24 14:06:23 +0000115SubGraphSelector::SelectSubGraphs(Graph& graph, const LayerSelectorFunction& selector)
116{
117 SubGraph subGraph(graph);
118 return SubGraphSelector::SelectSubGraphs(subGraph, selector);
119}
120
121SubGraphSelector::SubGraphs
122SubGraphSelector::SelectSubGraphs(SubGraph& subGraph, const LayerSelectorFunction& selector)
David Beckf98d21a2018-10-26 16:03:03 +0100123{
124 LayerSelectionInfo::LayerInfoContainer layerInfo;
125
Matteo Martincighadddddb2019-01-24 14:06:23 +0000126 for (auto& layer : subGraph)
David Beckf98d21a2018-10-26 16:03:03 +0100127 {
128 layerInfo.emplace(layer, LayerSelectionInfo{layer, selector});
129 }
130
131 uint32_t splitNo = LayerSelectionInfo::InitialSplitId();
132 for (auto& info : layerInfo)
133 {
134 if (info.second.IsInputLayer())
135 {
Matteo Martincighadddddb2019-01-24 14:06:23 +0000136 // For each input layer we mark the graph where subgraph
David Beckf98d21a2018-10-26 16:03:03 +0100137 // splits need to happen because of the dependency between
138 // the selected and non-selected nodes
139 info.second.MarkChildrenSplits(layerInfo, splitNo, false);
140 }
141 }
142
143 // Collect all selected layers keyed by split id into a map
144 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
145 std::unordered_map<uint32_t, SelectionInfoPtrs> splitMap;
146 for (auto& info : layerInfo)
147 {
148 if (info.second.m_IsSelected)
149 {
150 auto it = splitMap.find(info.second.m_SplitId);
151 if (it == splitMap.end())
152 {
153 splitMap.insert(std::make_pair(info.second.m_SplitId, SelectionInfoPtrs{&info.second}));
154 }
155 else
156 {
157 it->second.push_back(&info.second);
158 }
159 }
160 }
161
162 // Now each non-empty split id represents a subgraph
163 SubGraphs result;
164 for (auto& splitGraph : splitMap)
165 {
166 if (splitGraph.second.empty() == false)
167 {
David Beckf98d21a2018-10-26 16:03:03 +0100168 SubGraph::InputSlots inputs;
Matteo Martincighadddddb2019-01-24 14:06:23 +0000169 SubGraph::OutputSlots outputs;
David Beckf98d21a2018-10-26 16:03:03 +0100170 SubGraph::Layers layers;
171 for (auto&& infoPtr : splitGraph.second)
172 {
David Beckf98d21a2018-10-26 16:03:03 +0100173 infoPtr->CollectNonSelectedInputs(inputs, selector);
Matteo Martincighadddddb2019-01-24 14:06:23 +0000174 infoPtr->CollectNonSelectedOutputSlots(outputs, selector);
Matteo Martincigh49124022019-01-11 13:25:59 +0000175 layers.push_back(infoPtr->m_Layer);
David Beckf98d21a2018-10-26 16:03:03 +0100176 }
Matteo Martincigh0c051f92019-01-31 12:09:49 +0000177 // Create a new sub-graph with the new lists of input/output slots and layer, using
178 // the given sub-graph as a reference of which parent graph to use
Matteo Martincighadddddb2019-01-24 14:06:23 +0000179 result.emplace_back(std::make_unique<SubGraph>(subGraph,
180 std::move(inputs),
181 std::move(outputs),
182 std::move(layers)));
David Beckf98d21a2018-10-26 16:03:03 +0100183 }
184 }
185
186 return result;
187}
188
189} // namespace armnn