blob: cc821ec9561fcbaf08941701be223b9a8fadf7b8 [file] [log] [blame]
David Beckf98d21a2018-10-26 16:03:03 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Derek Lambertiff05cc52019-04-26 13:05:17 +01006#include "SubgraphViewSelector.hpp"
David Beckf98d21a2018-10-26 16:03:03 +01007#include "Graph.hpp"
8#include <boost/assert.hpp>
9#include <algorithm>
10#include <unordered_map>
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010011#include <queue>
David Beckf98d21a2018-10-26 16:03:03 +010012
13namespace armnn
14{
15
16namespace
17{
18
19struct LayerSelectionInfo
20{
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010021 using SplitId = uint32_t;
David Beckf98d21a2018-10-26 16:03:03 +010022 using LayerInfoContainer = std::unordered_map<Layer*, LayerSelectionInfo>;
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010023 using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
David Beckf98d21a2018-10-26 16:03:03 +010024 static constexpr uint32_t InitialSplitId() { return 1; }
25
Derek Lambertiff05cc52019-04-26 13:05:17 +010026 LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector)
David Beckf98d21a2018-10-26 16:03:03 +010027 : m_Layer{layer}
28 , m_SplitId{0}
29 , m_IsSelected{selector(*layer)}
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010030 , m_IsProcessed(false)
David Beckf98d21a2018-10-26 16:03:03 +010031 {
32 // fill topology information by storing direct children
33 for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
34 {
35 for (InputSlot* childLayerInputSlot : slot->GetConnections())
36 {
37 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
38 m_DirectChildren.push_back(&childLayer);
39 }
40 }
41 }
42
David Beckf98d21a2018-10-26 16:03:03 +010043 bool IsInputLayer() const
44 {
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010045 return m_Layer->GetType() == armnn::LayerType::Input || m_Layer->GetType() == armnn::LayerType::Constant;
David Beckf98d21a2018-10-26 16:03:03 +010046 }
47
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010048 void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
49 SubgraphView::InputSlots& inputSlots)
David Beckf98d21a2018-10-26 16:03:03 +010050 {
51 for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot)
52 {
53 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
Matteo Martincighadddddb2019-01-24 14:06:23 +000054 BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here.");
David Beckf98d21a2018-10-26 16:03:03 +010055 if (parentLayerOutputSlot)
56 {
57 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010058 auto parentInfo = layerInfos.find(&parentLayer);
59 if (m_SplitId != parentInfo->second.m_SplitId)
David Beckf98d21a2018-10-26 16:03:03 +010060 {
Matteo Martincighadddddb2019-01-24 14:06:23 +000061 inputSlots.push_back(&(*slot));
David Beckf98d21a2018-10-26 16:03:03 +010062 }
63 }
64 }
65 }
66
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010067 void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
68 SubgraphView::OutputSlots& outputSlots)
David Beckf98d21a2018-10-26 16:03:03 +010069 {
70 for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
71 {
72 for (InputSlot* childLayerInputSlot : slot->GetConnections())
73 {
74 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010075 auto childInfo = layerInfos.find(&childLayer);
76 if (m_SplitId != childInfo->second.m_SplitId)
David Beckf98d21a2018-10-26 16:03:03 +010077 {
Matteo Martincighadddddb2019-01-24 14:06:23 +000078 outputSlots.push_back(&(*slot));
David Beckf98d21a2018-10-26 16:03:03 +010079 }
80 }
81 }
82 }
83
84 std::vector<Layer*> m_DirectChildren;
85 Layer* m_Layer;
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010086 SplitId m_SplitId;
David Beckf98d21a2018-10-26 16:03:03 +010087 bool m_IsSelected;
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010088 bool m_IsProcessed;
David Beckf98d21a2018-10-26 16:03:03 +010089};
90
91} // namespace <anonymous>
92
Derek Lambertiff05cc52019-04-26 13:05:17 +010093SubgraphViewSelector::Subgraphs
94SubgraphViewSelector::SelectSubgraphs(Graph& graph, const LayerSelectorFunction& selector)
Matteo Martincighadddddb2019-01-24 14:06:23 +000095{
Derek Lambertiff05cc52019-04-26 13:05:17 +010096 SubgraphView subgraph(graph);
97 return SubgraphViewSelector::SelectSubgraphs(subgraph, selector);
Matteo Martincighadddddb2019-01-24 14:06:23 +000098}
99
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100100
101template<typename Delegate>
102void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
103 LayerSelectionInfo& layerInfo,
104 Delegate function)
105{
106 Layer& layer = *layerInfo.m_Layer;
107
108 for (auto inputSlot : layer.GetInputSlots())
109 {
110 auto connectedInput = boost::polymorphic_downcast<OutputSlot*>(inputSlot.GetConnection());
111 BOOST_ASSERT_MSG(connectedInput, "Dangling input slot detected.");
112 Layer& inputLayer = connectedInput->GetOwningLayer();
113
114 auto parentInfo = layerInfos.find(&inputLayer);
115 function(parentInfo->second);
116 }
117}
118
119template<typename Delegate>
120void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
121 LayerSelectionInfo& layerInfo,
122 Delegate function)
123{
124 Layer& layer= *layerInfo.m_Layer;
125
126 for (auto& outputSlot : layer.GetOutputSlots())
127 {
128 for (auto& output : outputSlot.GetConnections())
129 {
130 Layer& childLayer = output->GetOwningLayer();
131
132 auto childInfo = layerInfos.find(&childLayer);
133 function(childInfo->second);
134 }
135 }
136}
137
138void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
139{
140 bool newSplit = false;
141 LayerSelectionInfo::SplitId minSplitId = std::numeric_limits<LayerSelectionInfo::SplitId>::max();
142 LayerSelectionInfo::SplitId maxSplitId = std::numeric_limits<LayerSelectionInfo::SplitId>::lowest();
143 LayerSelectionInfo::SplitId maxSelectableId = std::numeric_limits<LayerSelectionInfo::SplitId>::lowest();
144
145 ForEachLayerInput(layerInfos, layerInfo, [&newSplit, &minSplitId, &maxSplitId, &maxSelectableId, &layerInfo](
146 LayerSelectionInfo& parentInfo)
147 {
148 minSplitId = std::min(minSplitId, parentInfo.m_SplitId);
149 maxSplitId = std::max(maxSplitId, parentInfo.m_SplitId);
150 if (parentInfo.m_IsSelected && layerInfo.m_IsSelected)
151 {
152 maxSelectableId = std::max(maxSelectableId, parentInfo.m_SplitId);
153 }
154
155 if (layerInfo.m_IsSelected != parentInfo.m_IsSelected)
156 {
157 newSplit = true;
158 }
159
160 });
161
162 // Assign the split Id for the current layerInfo
163 if (newSplit)
164 {
165 if (maxSelectableId > minSplitId)
166 {
167 // We can be overly aggressive when choosing to create a new split so
168 // here we determine if one of the parent branches are suitable candidates for continuation instead.
169 // Any splitId > minSplitId will come from a shorter branch...and therefore should not be from
170 // the split containing the original fork and thus we avoid the execution dependency.
171 layerInfo.m_SplitId = maxSelectableId;
172 }
173 else
174 {
175 layerInfo.m_SplitId = ++maxSplitId;
176 }
177 } else
178 {
179 // The branch with the highest splitId represents the shortest path of selected nodes.
180 layerInfo.m_SplitId = maxSplitId;
181 }
182}
183
184bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
185{
186 bool ready = true;
187 ForEachLayerInput(layerInfos, layerInfo,
188 [&ready](LayerSelectionInfo& parentInfo)
189 {
190 if (!parentInfo.m_IsProcessed)
191 {
192 ready = false;
193 }
194 });
195 return ready;
196}
197
Derek Lambertiff05cc52019-04-26 13:05:17 +0100198SubgraphViewSelector::Subgraphs
199SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelectorFunction& selector)
David Beckf98d21a2018-10-26 16:03:03 +0100200{
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100201 LayerSelectionInfo::LayerInfoContainer layerInfos;
David Beckf98d21a2018-10-26 16:03:03 +0100202
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100203 LayerSelectionInfo::LayerInfoQueue processQueue;
Derek Lambertiff05cc52019-04-26 13:05:17 +0100204 for (auto& layer : subgraph)
David Beckf98d21a2018-10-26 16:03:03 +0100205 {
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100206 auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{layer, selector});
207 LayerSelectionInfo& layerInfo = emplaced.first->second;
208
209 // Start with Input type layers
210 if (layerInfo.IsInputLayer())
211 {
212 processQueue.push(&layerInfo);
213 }
David Beckf98d21a2018-10-26 16:03:03 +0100214 }
215
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100216 while (!processQueue.empty())
David Beckf98d21a2018-10-26 16:03:03 +0100217 {
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100218 LayerSelectionInfo& layerInfo = *processQueue.front();
219 processQueue.pop(); // remove front from queue
220
221 // This layerInfo may have been added to the queue multiple times, so skip if we have already processed it
222 if (!layerInfo.m_IsProcessed)
David Beckf98d21a2018-10-26 16:03:03 +0100223 {
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100224
225 // Only process this layerInfo if all inputs have been processed
226 if (!IsReadyForSplitAssignment(layerInfos, layerInfo))
227 {
228 // Put back of the process queue if we can't process it just yet
229 processQueue.push(&layerInfo);
230 continue; // Skip to next iteration
231 }
232
233 // Now we do the processing
234 AssignSplitId(layerInfos, layerInfo);
235
236 // Queue any child nodes for processing
237 ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
238 {
239 processQueue.push(&childInfo);
240 });
241
242 // We don't need to process this node again
243 layerInfo.m_IsProcessed = true;
David Beckf98d21a2018-10-26 16:03:03 +0100244 }
245 }
246
247 // Collect all selected layers keyed by split id into a map
248 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
249 std::unordered_map<uint32_t, SelectionInfoPtrs> splitMap;
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100250 for (auto& info : layerInfos)
David Beckf98d21a2018-10-26 16:03:03 +0100251 {
252 if (info.second.m_IsSelected)
253 {
254 auto it = splitMap.find(info.second.m_SplitId);
255 if (it == splitMap.end())
256 {
257 splitMap.insert(std::make_pair(info.second.m_SplitId, SelectionInfoPtrs{&info.second}));
258 }
259 else
260 {
261 it->second.push_back(&info.second);
262 }
263 }
264 }
265
266 // Now each non-empty split id represents a subgraph
Derek Lambertiff05cc52019-04-26 13:05:17 +0100267 Subgraphs result;
David Beckf98d21a2018-10-26 16:03:03 +0100268 for (auto& splitGraph : splitMap)
269 {
270 if (splitGraph.second.empty() == false)
271 {
Derek Lambertiff05cc52019-04-26 13:05:17 +0100272 SubgraphView::InputSlots inputs;
273 SubgraphView::OutputSlots outputs;
274 SubgraphView::Layers layers;
David Beckf98d21a2018-10-26 16:03:03 +0100275 for (auto&& infoPtr : splitGraph.second)
276 {
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100277 infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
278 infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
Matteo Martincigh49124022019-01-11 13:25:59 +0000279 layers.push_back(infoPtr->m_Layer);
David Beckf98d21a2018-10-26 16:03:03 +0100280 }
Matteo Martincigh602af092019-05-01 10:31:27 +0100281 // Create a new sub-graph with the new lists of input/output slots and layer
282 result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs),
Derek Lambertiff05cc52019-04-26 13:05:17 +0100283 std::move(outputs),
284 std::move(layers)));
David Beckf98d21a2018-10-26 16:03:03 +0100285 }
286 }
287
288 return result;
289}
290
291} // namespace armnn