blob: e2c5f911a079e89384d97a65af8a8996a4b20b86 [file] [log] [blame]
David Beckf98d21a2018-10-26 16:03:03 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Derek Lambertiff05cc52019-04-26 13:05:17 +01006#include "SubgraphViewSelector.hpp"
David Beckf98d21a2018-10-26 16:03:03 +01007#include "Graph.hpp"
Jan Eilers8eb25602020-03-09 12:13:48 +00008
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01009#include <armnn/utility/Assert.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010011#include <armnn/utility/PolymorphicDowncast.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000012
David Beckf98d21a2018-10-26 16:03:03 +010013#include <algorithm>
Matteo Martincighf02e6cd2019-05-17 12:15:30 +010014#include <map>
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +010015#include <queue>
Rob Hughes30db8ad2019-11-08 15:50:10 +000016#include <unordered_set>
David Beckf98d21a2018-10-26 16:03:03 +010017
18namespace armnn
19{
20
21namespace
22{
23
Rob Hughes30db8ad2019-11-08 15:50:10 +000024/// Intermediate data-structure to store the subgraph that a layer has been assigned to.
25/// This is a "disjoint set" data structure that allows efficient merging of subgraphs,
26/// which is a key part of the algorithm. Subgraphs are arranged in singly-linked trees
27/// (with each node storing a pointer to its parent). Subgraphs in the same tree are considered
28/// to have been merged. Merging subgraphs is performed by attaching one tree to another,
29/// which is a simple pointer update.
30///
31/// NOTE: Due to the way this is stored, it is almost never correct to directly compare pointers
32/// to two PartialSubgraphs to check if two layers belong in the same subgraph. Instead you
33/// should use IsMergedWith().
34///
35/// This structure also stores information about the dependencies of each subgraph, which is needed
36/// to determine whether certain subgraphs can be merged. Checking whether a subgraph
37/// depends on another subgraph is a frequent operation in the algorithm (see AssignSplitId) and so this is optimized
38/// in preference to the merging of subgraphs. This leads to an approach where each subgraph stores
39/// a set of all the subgraphs it depends on (for a fast lookup). In order to efficiently update this
40/// set as subgraphs are merged means we also store a set of subgraphs which *depend on us* (i.e. the
41/// complement of our dependencies).
42class PartialSubgraph
43{
44public:
45 /// If this subgraph has been merged with another then there is an agreed "representative" for the combined
46 /// subgraph, which uniquely identifies the subgraph.
47 PartialSubgraph* GetRepresentative()
48 {
49 // Recurse up the tree to find the root node.
50 if (m_Parent == nullptr)
51 {
52 return this;
53 }
54 else
55 {
56 PartialSubgraph* result = m_Parent->GetRepresentative();
57 // Update our parent pointer to point directly to the root in order to speed up future calls to this method.
58 // This essentially "flattens" the tree.
59 m_Parent = result;
60 return result;
61 }
62 }
63
64 /// Merges this subgraph with another.
65 void MergeWith(PartialSubgraph* other)
66 {
67 if (m_Parent == nullptr)
68 {
69 other = other->GetRepresentative();
70 if (this == other)
71 {
72 // Already merged - no-op
73 return;
74 }
75 m_Parent = other;
76
77 // Update others' dependency sets to point to the new representative rather than us.
78 // Keeping these up-to-date means we can rely on these sets containing representatives when
79 // we perform a lookup in HasAntecedent() and so don't need to resolve the representative for each element
80 // of the set. See description at the top of this class for more rationale.
81 for (PartialSubgraph* a : m_Antecedents)
82 {
83 size_t numErased = a->m_Dependants.erase(this);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010084 ARMNN_ASSERT(numErased == 1);
Jan Eilers8eb25602020-03-09 12:13:48 +000085 IgnoreUnused(numErased);
Rob Hughes30db8ad2019-11-08 15:50:10 +000086 a->m_Dependants.insert(m_Parent);
87 }
88 for (PartialSubgraph* a : m_Dependants)
89 {
90 size_t numErased = a->m_Antecedents.erase(this);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010091 ARMNN_ASSERT(numErased == 1);
Jan Eilers8eb25602020-03-09 12:13:48 +000092 IgnoreUnused(numErased);
Rob Hughes30db8ad2019-11-08 15:50:10 +000093 a->m_Antecedents.insert(m_Parent);
94 }
95
96 // Merge our dependency sets into our new representative.
97 // We no longer need to maintain our own sets, as requests will always be forwarded to the representative.
98 m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
99 m_Antecedents.clear();
100 m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
101 m_Dependants.clear();
102 }
103 else
104 {
105 // Defer request to the representative
106 GetRepresentative()->MergeWith(other);
107 }
108 }
109
110 /// Checks if this subgraph has been merged with the given subgraph.
111 bool IsMergedWith(PartialSubgraph* other)
112 {
113 return GetRepresentative() == other->GetRepresentative();
114 }
115
116 /// Marks the given subgraph as a direct antecedent (dependency) of this one.
117 void AddDirectAntecedent(PartialSubgraph* antecedent)
118 {
119 if (m_Parent == nullptr)
120 {
121 antecedent = antecedent->GetRepresentative();
122
123 m_Antecedents.insert(antecedent);
124 // Also record all of its antecedents, so that we end up with direct and indirect antecedents.
125 // This makes the lookup in HasAntecedent() faster.
126 m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
127 // All of our dependents also need to include the new antecedents
128 for (PartialSubgraph* d : m_Dependants)
129 {
130 d->m_Antecedents.insert(antecedent);
131 d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
132 }
133
134 // Store reverse dependencies as well, required so that we can efficiently navigate the graph
135 // when making updates.
136 antecedent->m_Dependants.insert(this);
137 antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
138 for (PartialSubgraph* a : antecedent->m_Antecedents)
139 {
140 a->m_Dependants.insert(this);
141 a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
142 }
143 }
144 else
145 {
146 // Defer request to the representative
147 GetRepresentative()->AddDirectAntecedent(antecedent);
148 }
149 }
150
151 /// Checks if this subgraph is dependent on the given subgraph, either directly or indirectly.
152 bool HasAntecedent(PartialSubgraph* antecedent)
153 {
154 if (m_Parent == nullptr)
155 {
156 antecedent = antecedent->GetRepresentative();
157 // Thanks to keeping this set updated in MergeWith and AddDirectAntecedent, we can do an efficient lookup.
158 return m_Antecedents.count(antecedent) > 0;
159 }
160 else
161 {
162 // Defer request to the representative
163 return GetRepresentative()->HasAntecedent(antecedent);
164 }
165 }
166
167private:
168 /// Pointer to the parent node in the tree. If this is null then we are the representative for our merged subgraph.
169 PartialSubgraph* m_Parent;
170 /// The representatives of all the subgraphs which we depend on, either directly or indirectly.
171 std::unordered_set<PartialSubgraph*> m_Antecedents;
172 /// The representatives of all the subgraphs which depend on us, either directly or indirectly.
173 std::unordered_set<PartialSubgraph*> m_Dependants;
174};
175
176/// Intermediate data structure to store information associated with a particular layer.
David Beckf98d21a2018-10-26 16:03:03 +0100177struct LayerSelectionInfo
178{
Francis Murtagh56ccf682021-12-13 18:48:12 +0000179 using LayerInfoContainer = std::map<IConnectableLayer*, LayerSelectionInfo>;
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100180 using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
David Beckf98d21a2018-10-26 16:03:03 +0100181
Derek Lambertiff05cc52019-04-26 13:05:17 +0100182 LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector)
David Beckf98d21a2018-10-26 16:03:03 +0100183 : m_Layer{layer}
Rob Hughes30db8ad2019-11-08 15:50:10 +0000184 , m_Subgraph{nullptr}
David Beckf98d21a2018-10-26 16:03:03 +0100185 , m_IsSelected{selector(*layer)}
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100186 , m_IsProcessed(false)
David Beckf98d21a2018-10-26 16:03:03 +0100187 {
David Beckf98d21a2018-10-26 16:03:03 +0100188 }
189
David Beckf98d21a2018-10-26 16:03:03 +0100190 bool IsInputLayer() const
191 {
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100192 return m_Layer->GetType() == armnn::LayerType::Input || m_Layer->GetType() == armnn::LayerType::Constant;
David Beckf98d21a2018-10-26 16:03:03 +0100193 }
194
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100195 void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
Francis Murtagh56ccf682021-12-13 18:48:12 +0000196 SubgraphView::IInputSlots& inputSlots)
David Beckf98d21a2018-10-26 16:03:03 +0100197 {
Francis Murtagh56ccf682021-12-13 18:48:12 +0000198 for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginInputSlots();
199 slot != PolymorphicDowncast<Layer*>(m_Layer)->EndInputSlots();
200 ++slot)
David Beckf98d21a2018-10-26 16:03:03 +0100201 {
202 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100203 ARMNN_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here.");
David Beckf98d21a2018-10-26 16:03:03 +0100204 if (parentLayerOutputSlot)
205 {
206 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100207 auto parentInfo = layerInfos.find(&parentLayer);
Matteo Martincighf02e6cd2019-05-17 12:15:30 +0100208 if (parentInfo == layerInfos.end() ||
Rob Hughes30db8ad2019-11-08 15:50:10 +0000209 !m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
David Beckf98d21a2018-10-26 16:03:03 +0100210 {
Matteo Martincigh05349c52019-05-21 13:29:00 +0100211 // Avoid collecting duplicate input slots
212 InputSlot* inputSlot = &(*slot);
213 if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
214 {
215 inputSlots.push_back(inputSlot);
216 }
David Beckf98d21a2018-10-26 16:03:03 +0100217 }
218 }
219 }
220 }
221
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100222 void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
Francis Murtagh56ccf682021-12-13 18:48:12 +0000223 SubgraphView::IOutputSlots& outputSlots)
David Beckf98d21a2018-10-26 16:03:03 +0100224 {
Francis Murtagh56ccf682021-12-13 18:48:12 +0000225 for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginOutputSlots();
226 slot != PolymorphicDowncast<Layer*>(m_Layer)->EndOutputSlots();
227 ++slot)
David Beckf98d21a2018-10-26 16:03:03 +0100228 {
229 for (InputSlot* childLayerInputSlot : slot->GetConnections())
230 {
231 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100232 auto childInfo = layerInfos.find(&childLayer);
Matteo Martincighf02e6cd2019-05-17 12:15:30 +0100233 if (childInfo == layerInfos.end() ||
Rob Hughes30db8ad2019-11-08 15:50:10 +0000234 !m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
David Beckf98d21a2018-10-26 16:03:03 +0100235 {
Matteo Martincigh05349c52019-05-21 13:29:00 +0100236 // Avoid collecting duplicate output slots
237 OutputSlot* outputSlot = &(*slot);
238 if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
239 {
240 outputSlots.push_back(outputSlot);
241 }
David Beckf98d21a2018-10-26 16:03:03 +0100242 }
243 }
244 }
245 }
246
Francis Murtagh56ccf682021-12-13 18:48:12 +0000247 IConnectableLayer* m_Layer;
Rob Hughes30db8ad2019-11-08 15:50:10 +0000248 /// Which subgraph this layer has been assigned to. Only valid once m_IsProcessed is true.
249 /// Two layers with different m_Subgraph pointers may in fact have been merged into the same subgraph -
250 /// see the description of the PartialSubgraph class.
251 std::shared_ptr<PartialSubgraph> m_Subgraph;
David Beckf98d21a2018-10-26 16:03:03 +0100252 bool m_IsSelected;
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100253 bool m_IsProcessed;
David Beckf98d21a2018-10-26 16:03:03 +0100254};
255
256} // namespace <anonymous>
257
Derek Lambertiff05cc52019-04-26 13:05:17 +0100258SubgraphViewSelector::Subgraphs
259SubgraphViewSelector::SelectSubgraphs(Graph& graph, const LayerSelectorFunction& selector)
Matteo Martincighadddddb2019-01-24 14:06:23 +0000260{
Derek Lambertiff05cc52019-04-26 13:05:17 +0100261 SubgraphView subgraph(graph);
262 return SubgraphViewSelector::SelectSubgraphs(subgraph, selector);
Matteo Martincighadddddb2019-01-24 14:06:23 +0000263}
264
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100265
266template<typename Delegate>
267void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
268 LayerSelectionInfo& layerInfo,
269 Delegate function)
270{
Francis Murtagh56ccf682021-12-13 18:48:12 +0000271 Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100272
273 for (auto inputSlot : layer.GetInputSlots())
274 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100275 auto connectedInput = PolymorphicDowncast<OutputSlot*>(inputSlot.GetConnection());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100276 ARMNN_ASSERT_MSG(connectedInput, "Dangling input slot detected.");
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100277 Layer& inputLayer = connectedInput->GetOwningLayer();
278
279 auto parentInfo = layerInfos.find(&inputLayer);
Matteo Martincighf02e6cd2019-05-17 12:15:30 +0100280 if (parentInfo != layerInfos.end())
281 {
282 function(parentInfo->second);
283 }
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100284 }
285}
286
287template<typename Delegate>
288void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
289 LayerSelectionInfo& layerInfo,
290 Delegate function)
291{
Francis Murtagh56ccf682021-12-13 18:48:12 +0000292 Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100293
294 for (auto& outputSlot : layer.GetOutputSlots())
295 {
296 for (auto& output : outputSlot.GetConnections())
297 {
298 Layer& childLayer = output->GetOwningLayer();
299
300 auto childInfo = layerInfos.find(&childLayer);
Matteo Martincighf02e6cd2019-05-17 12:15:30 +0100301 if (childInfo != layerInfos.end())
302 {
303 function(childInfo->second);
304 }
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100305 }
306 }
307}
308
309void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
310{
Rob Hughes30db8ad2019-11-08 15:50:10 +0000311 // Check each input to see if we can attach ourselves to any of the subgraphs that have already been assigned.
312 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100313 {
Rob Hughes30db8ad2019-11-08 15:50:10 +0000314 // We can only attach ourselves to the subgraph from this input if there isn't a cut here.
315 if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100316 {
Rob Hughes30db8ad2019-11-08 15:50:10 +0000317 // We also need to check that merging into this subgraph won't cause a dependency cycle between subgraphs.
318 // This will be the case if the subgraph that we will become part of is already a dependency
319 // of one of the subgraphs that are input to this layer, e.g:
320 //
321 // 0 | The numbers (0, 1) are the subgraph IDs of each layer and we are looking at layer X.
322 // / \ |
323 // 1 0 | We can't merge X into subgraph 0, because the left-hand input already depends on subgraph 0.
324 // \ / | We can however merge X into subgraph 1.
325 // X |
326 //
327 bool dependenciesOk = true;
328 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
329 {
330 // We call HasAntecedent() ~ n^2 times, where n is the number of inputs to this layer.
331 // Hence it is important that this is efficient - see PartialSubgraph class description.
332 if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
333 {
334 dependenciesOk = false;
335 }
336 });
337
338 if (dependenciesOk)
339 {
340 // Merge into the subgraph of this input. If we have already been merged into another subgraph
341 // (from another input of this layer), then merge both of them together.
342 if (layerInfo.m_Subgraph == nullptr)
343 {
344 layerInfo.m_Subgraph = parentInfo.m_Subgraph;
345 }
346 else
347 {
348 // We call MergeWith() ~ n times, where n is the number of inputs to this layer.
349 // Therefore it does not need to be as performant as HasAntecedent().
350 layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
351 }
352 }
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100353 }
Rob Hughes30db8ad2019-11-08 15:50:10 +0000354 });
355
356 // If we weren't able to merge into an existing subgraph then we need to make a new one
357 if (layerInfo.m_Subgraph == nullptr)
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100358 {
Rob Hughes30db8ad2019-11-08 15:50:10 +0000359 layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100360 }
Rob Hughes30db8ad2019-11-08 15:50:10 +0000361
362 // Record dependencies of the chosen subgraph based on the inputs of this layer.
363 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
364 {
365 // These functions are called ~n times, where n is the number of inputs to this layer.
366 // Therefore it does not need to be as performant as HasAntecedent().
367 if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
368 {
369 layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
370 }
371 });
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100372}
373
374bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
375{
376 bool ready = true;
377 ForEachLayerInput(layerInfos, layerInfo,
378 [&ready](LayerSelectionInfo& parentInfo)
379 {
380 if (!parentInfo.m_IsProcessed)
381 {
382 ready = false;
383 }
384 });
385 return ready;
386}
387
Derek Lambertiff05cc52019-04-26 13:05:17 +0100388SubgraphViewSelector::Subgraphs
389SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelectorFunction& selector)
David Beckf98d21a2018-10-26 16:03:03 +0100390{
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100391 LayerSelectionInfo::LayerInfoContainer layerInfos;
David Beckf98d21a2018-10-26 16:03:03 +0100392
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100393 LayerSelectionInfo::LayerInfoQueue processQueue;
Francis Murtagh56ccf682021-12-13 18:48:12 +0000394 const SubgraphView::IConnectableLayers& subgraphLayers = subgraph.GetIConnectableLayers();
395 for (auto& layer : subgraphLayers)
David Beckf98d21a2018-10-26 16:03:03 +0100396 {
Francis Murtagh56ccf682021-12-13 18:48:12 +0000397
398 auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{PolymorphicDowncast<Layer*>(layer), selector});
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100399 LayerSelectionInfo& layerInfo = emplaced.first->second;
400
401 // Start with Input type layers
402 if (layerInfo.IsInputLayer())
403 {
404 processQueue.push(&layerInfo);
405 }
David Beckf98d21a2018-10-26 16:03:03 +0100406 }
407
Francis Murtagh56ccf682021-12-13 18:48:12 +0000408 const SubgraphView::IInputSlots& subgraphInputSlots = subgraph.GetIInputSlots();
Matteo Martincighf02e6cd2019-05-17 12:15:30 +0100409 for (auto& inputSlot : subgraphInputSlots)
410 {
Francis Murtagh56ccf682021-12-13 18:48:12 +0000411 Layer& layer = PolymorphicDowncast<InputSlot*>(inputSlot)->GetOwningLayer();
Matteo Martincighf02e6cd2019-05-17 12:15:30 +0100412 auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
413 LayerSelectionInfo& layerInfo = emplaced.first->second;
414
415 processQueue.push(&layerInfo);
416 }
417
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100418 while (!processQueue.empty())
David Beckf98d21a2018-10-26 16:03:03 +0100419 {
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100420 LayerSelectionInfo& layerInfo = *processQueue.front();
421 processQueue.pop(); // remove front from queue
422
423 // This layerInfo may have been added to the queue multiple times, so skip if we have already processed it
424 if (!layerInfo.m_IsProcessed)
David Beckf98d21a2018-10-26 16:03:03 +0100425 {
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100426 // Only process this layerInfo if all inputs have been processed
427 if (!IsReadyForSplitAssignment(layerInfos, layerInfo))
428 {
429 // Put back of the process queue if we can't process it just yet
430 processQueue.push(&layerInfo);
431 continue; // Skip to next iteration
432 }
433
434 // Now we do the processing
435 AssignSplitId(layerInfos, layerInfo);
436
437 // Queue any child nodes for processing
438 ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
439 {
440 processQueue.push(&childInfo);
441 });
442
443 // We don't need to process this node again
444 layerInfo.m_IsProcessed = true;
David Beckf98d21a2018-10-26 16:03:03 +0100445 }
446 }
447
Rob Hughes30db8ad2019-11-08 15:50:10 +0000448 // Collect all selected layers keyed by subgraph representative into a map
David Beckf98d21a2018-10-26 16:03:03 +0100449 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
Rob Hughes30db8ad2019-11-08 15:50:10 +0000450 std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
Derek Lamberti5cf4d1c2019-05-03 18:57:12 +0100451 for (auto& info : layerInfos)
David Beckf98d21a2018-10-26 16:03:03 +0100452 {
453 if (info.second.m_IsSelected)
454 {
Rob Hughes30db8ad2019-11-08 15:50:10 +0000455 auto it = splitMap.find(info.second.m_Subgraph->GetRepresentative());
David Beckf98d21a2018-10-26 16:03:03 +0100456 if (it == splitMap.end())
457 {
Rob Hughes30db8ad2019-11-08 15:50:10 +0000458 splitMap.insert(
459 std::make_pair(info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&info.second}));
David Beckf98d21a2018-10-26 16:03:03 +0100460 }
461 else
462 {
463 it->second.push_back(&info.second);
464 }
465 }
466 }
467
Rob Hughes30db8ad2019-11-08 15:50:10 +0000468 // Now each entry in splitMap represents a subgraph
Derek Lambertiff05cc52019-04-26 13:05:17 +0100469 Subgraphs result;
David Beckf98d21a2018-10-26 16:03:03 +0100470 for (auto& splitGraph : splitMap)
471 {
Francis Murtagh56ccf682021-12-13 18:48:12 +0000472 SubgraphView::IInputSlots inputs;
473 SubgraphView::IOutputSlots outputs;
474 SubgraphView::IConnectableLayers layers;
Rob Hughes30db8ad2019-11-08 15:50:10 +0000475 for (auto&& infoPtr : splitGraph.second)
David Beckf98d21a2018-10-26 16:03:03 +0100476 {
Rob Hughes30db8ad2019-11-08 15:50:10 +0000477 infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
478 infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
479 layers.push_back(infoPtr->m_Layer);
David Beckf98d21a2018-10-26 16:03:03 +0100480 }
Rob Hughes1addbf32021-02-19 09:24:44 +0000481
482 // Sort lists into deterministic order, not relying on pointer values which may be different on each execution.
483 // This makes debugging the optimised graph much easier as subsequent stages can also be deterministic.
Francis Murtagh56ccf682021-12-13 18:48:12 +0000484 std::sort(inputs.begin(), inputs.end(), [](const IInputSlot* a, const IInputSlot* b)
Rob Hughes1addbf32021-02-19 09:24:44 +0000485 {
Francis Murtagh56ccf682021-12-13 18:48:12 +0000486 auto* castA = PolymorphicDowncast<const InputSlot*>(a);
487 auto* castB = PolymorphicDowncast<const InputSlot*>(b);
488 const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
489 const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
Rob Hughes1addbf32021-02-19 09:24:44 +0000490 if (guidA < guidB)
491 {
492 return true;
493 }
494 else if (guidA == guidB)
495 {
Francis Murtagh56ccf682021-12-13 18:48:12 +0000496 return (castA->GetSlotIndex() < castB->GetSlotIndex());
Rob Hughes1addbf32021-02-19 09:24:44 +0000497 }
498 return false;
499 });
Francis Murtagh56ccf682021-12-13 18:48:12 +0000500 std::sort(outputs.begin(), outputs.end(), [](const IOutputSlot* a, const IOutputSlot* b)
Rob Hughes1addbf32021-02-19 09:24:44 +0000501 {
Francis Murtagh56ccf682021-12-13 18:48:12 +0000502 auto* castA = PolymorphicDowncast<const OutputSlot*>(a);
503 auto* castB = PolymorphicDowncast<const OutputSlot*>(b);
504 const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
505 const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
Rob Hughes1addbf32021-02-19 09:24:44 +0000506 if (guidA < guidB)
507 {
508 return true;
509 }
510 else if (guidA == guidB)
511 {
512 return (a->CalculateIndexOnOwner() < b->CalculateIndexOnOwner());
513 }
514 return false;
515 });
Francis Murtagh56ccf682021-12-13 18:48:12 +0000516 layers.sort([](const IConnectableLayer* a, const IConnectableLayer* b) { return a->GetGuid() < b->GetGuid(); });
Rob Hughes1addbf32021-02-19 09:24:44 +0000517
Rob Hughes30db8ad2019-11-08 15:50:10 +0000518 // Create a new sub-graph with the new lists of input/output slots and layer
Francis Murtagh56ccf682021-12-13 18:48:12 +0000519 result.emplace_back(std::make_unique<SubgraphView>(std::move(layers),
520 std::move(inputs),
521 std::move(outputs)));
David Beckf98d21a2018-10-26 16:03:03 +0100522 }
523
524 return result;
525}
526
527} // namespace armnn