blob: 5f972a976791a58041860c2f268086542ba549c3 [file] [log] [blame]
David Beckf98d21a2018-10-26 16:03:03 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Francis Murtagha49ff082022-01-17 17:08:01 +00006#include <armnn/backends/SubgraphView.hpp>
7
8#include <Graph.hpp>
David Beckf98d21a2018-10-26 16:03:03 +01009
Jan Eilers8eb25602020-03-09 12:13:48 +000010#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan0663d662020-09-14 11:47:26 +010011#include <armnn/utility/NumericCast.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010012#include <armnn/utility/PolymorphicDowncast.hpp>
David Beckf98d21a2018-10-26 16:03:03 +010013
Matteo Martincighadddddb2019-01-24 14:06:23 +000014#include <utility>
15
David Beckf98d21a2018-10-26 16:03:03 +010016namespace armnn
17{
18
Matteo Martincighadddddb2019-01-24 14:06:23 +000019namespace
David Beckf98d21a2018-10-26 16:03:03 +010020{
Matteo Martincighadddddb2019-01-24 14:06:23 +000021
22template <class C>
23void AssertIfNullsOrDuplicates(const C& container, const std::string& errorMessage)
24{
25 using T = typename C::value_type;
26 std::unordered_set<T> duplicateSet;
27 std::for_each(container.begin(), container.end(), [&duplicateSet, &errorMessage](const T& i)
28 {
29 // Ignore unused for release builds
Jan Eilers8eb25602020-03-09 12:13:48 +000030 IgnoreUnused(errorMessage);
Matteo Martincighadddddb2019-01-24 14:06:23 +000031
32 // Check if the item is valid
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010033 ARMNN_ASSERT_MSG(i, errorMessage.c_str());
Matteo Martincighadddddb2019-01-24 14:06:23 +000034
35 // Check if a duplicate has been found
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010036 ARMNN_ASSERT_MSG(duplicateSet.find(i) == duplicateSet.end(), errorMessage.c_str());
Matteo Martincighadddddb2019-01-24 14:06:23 +000037
38 duplicateSet.insert(i);
39 });
David Beckf98d21a2018-10-26 16:03:03 +010040}
41
Matteo Martincighadddddb2019-01-24 14:06:23 +000042} // anonymous namespace
43
Derek Lambertiff05cc52019-04-26 13:05:17 +010044SubgraphView::SubgraphView(Graph& graph)
Matteo Martincighadddddb2019-01-24 14:06:23 +000045 : m_InputSlots{}
46 , m_OutputSlots{}
47 , m_Layers(graph.begin(), graph.end())
Francis Murtagh56ccf682021-12-13 18:48:12 +000048 , m_IConnectableLayers(graph.begin(), graph.end())
David Beckf98d21a2018-10-26 16:03:03 +010049{
Derek Lamberti161d29c2020-12-07 13:54:12 +000050 ArrangeBySortOrder();
Derek Lambertiff05cc52019-04-26 13:05:17 +010051 CheckSubgraph();
David Beckf98d21a2018-10-26 16:03:03 +010052}
53
Francis Murtagh56ccf682021-12-13 18:48:12 +000054/// IConnectable Duplication to maintain backwards compatibility
Matteo Martincigh602af092019-05-01 10:31:27 +010055SubgraphView::SubgraphView(InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers)
Francis Murtagh56ccf682021-12-13 18:48:12 +000056 : m_InputSlots{InputSlots{inputs.begin(), inputs.end()}}
57 , m_IInputSlots{IInputSlots{inputs.begin(), inputs.end()}}
58 , m_OutputSlots{OutputSlots{outputs.begin(), outputs.end()}}
59 , m_IOutputSlots{IOutputSlots{outputs.begin(), outputs.end()}}
60 , m_Layers(layers)
61 , m_IConnectableLayers(IConnectableLayers{layers.begin(), layers.end()})
Matteo Martincighadddddb2019-01-24 14:06:23 +000062{
Derek Lamberti161d29c2020-12-07 13:54:12 +000063 ArrangeBySortOrder();
Derek Lambertiff05cc52019-04-26 13:05:17 +010064 CheckSubgraph();
Matteo Martincighadddddb2019-01-24 14:06:23 +000065}
66
Francis Murtagh56ccf682021-12-13 18:48:12 +000067/// IConnectable Duplication to maintain backwards compatibility
Francis Murtagh9d74ba62022-01-19 16:31:58 +000068SubgraphView::SubgraphView(SubgraphView::IConnectableLayers&& layers,
69 SubgraphView::IInputSlots&& inputs,
70 SubgraphView::IOutputSlots&& outputs)
Francis Murtagh56ccf682021-12-13 18:48:12 +000071 : m_IInputSlots{inputs}
72 , m_IOutputSlots{outputs}
73 , m_IConnectableLayers(IConnectableLayers{layers.begin(), layers.end()})
74{
75 // Cast from IConnectableLayer to Layer for backward compatibility
76 auto f = [](IConnectableLayer* value)
77 {
78 return PolymorphicDowncast<Layer*>(value);
79 };
80 std::transform(layers.begin(), layers.end(), std::back_inserter(m_Layers), f);
81
Francis Murtagh9d74ba62022-01-19 16:31:58 +000082 m_InputSlots.resize(inputs.size());
83 m_IInputSlots.resize(inputs.size());
84 for (unsigned int i = 0; i < inputs.size(); i++)
85 {
86 m_InputSlots.at(i) = PolymorphicDowncast<InputSlot*>(inputs[i]);
87 m_IInputSlots.at(i) = inputs[i];
88 }
89
90 m_OutputSlots.resize(outputs.size());
91 m_IOutputSlots.resize(outputs.size());
92 for (unsigned int i = 0; i < outputs.size(); i++)
93 {
94 m_OutputSlots.at(i) = PolymorphicDowncast<OutputSlot*>(outputs[i]);
95 m_IOutputSlots.at(i) = outputs[i];
96 }
97
98 ArrangeBySortOrder();
99 CheckSubgraph();
100}
101
102/// IConnectable Duplication to maintain backwards compatibility
103SubgraphView::SubgraphView(SubgraphView::IConnectableLayers&& layers,
104 SubgraphView::IInputSlots&& inputs,
105 SubgraphView::IOutputSlots&& outputs,
106 std::shared_ptr<SubgraphViewWorkingCopy> ptr)
107 : m_IInputSlots{inputs}
108 , m_IOutputSlots{outputs}
109 , m_IConnectableLayers(IConnectableLayers{layers.begin(), layers.end()})
110 , p_WorkingCopyImpl(std::move(ptr))
111{
112 // Cast from IConnectableLayer to Layer for backward compatibility
113 auto f = [](IConnectableLayer* value)
114 {
115 return PolymorphicDowncast<Layer*>(value);
116 };
117 std::transform(layers.begin(), layers.end(), std::back_inserter(m_Layers), f);
Francis Murtagh56ccf682021-12-13 18:48:12 +0000118
119 m_InputSlots.resize(inputs.size());
120 m_IInputSlots.resize(inputs.size());
121 for (unsigned int i = 0; i < inputs.size(); i++)
122 {
123 m_InputSlots.at(i) = PolymorphicDowncast<InputSlot*>(inputs[i]);
124 m_IInputSlots.at(i) = inputs[i];
125 }
126
127 m_OutputSlots.resize(outputs.size());
128 m_IOutputSlots.resize(outputs.size());
129 for (unsigned int i = 0; i < outputs.size(); i++)
130 {
131 m_OutputSlots.at(i) = PolymorphicDowncast<OutputSlot*>(outputs[i]);
132 m_IOutputSlots.at(i) = outputs[i];
133 }
134
135 ArrangeBySortOrder();
136 CheckSubgraph();
137}
138
Derek Lambertiff05cc52019-04-26 13:05:17 +0100139SubgraphView::SubgraphView(const SubgraphView& subgraph)
140 : m_InputSlots(subgraph.m_InputSlots.begin(), subgraph.m_InputSlots.end())
Francis Murtagh56ccf682021-12-13 18:48:12 +0000141 , m_IInputSlots(subgraph.m_IInputSlots.begin(), subgraph.m_IInputSlots.end())
Derek Lambertiff05cc52019-04-26 13:05:17 +0100142 , m_OutputSlots(subgraph.m_OutputSlots.begin(), subgraph.m_OutputSlots.end())
Francis Murtagh56ccf682021-12-13 18:48:12 +0000143 , m_IOutputSlots(subgraph.m_IOutputSlots.begin(), subgraph.m_IOutputSlots.end())
Derek Lambertiff05cc52019-04-26 13:05:17 +0100144 , m_Layers(subgraph.m_Layers.begin(), subgraph.m_Layers.end())
Francis Murtagh56ccf682021-12-13 18:48:12 +0000145 , m_IConnectableLayers(IConnectableLayers{subgraph.m_IConnectableLayers.begin(),
146 subgraph.m_IConnectableLayers.end()})
Matteo Martincighadddddb2019-01-24 14:06:23 +0000147{
Derek Lamberti161d29c2020-12-07 13:54:12 +0000148 ArrangeBySortOrder();
Derek Lambertiff05cc52019-04-26 13:05:17 +0100149 CheckSubgraph();
Matteo Martincighadddddb2019-01-24 14:06:23 +0000150}
151
Derek Lambertiff05cc52019-04-26 13:05:17 +0100152SubgraphView::SubgraphView(SubgraphView&& subgraph)
153 : m_InputSlots(std::move(subgraph.m_InputSlots))
Francis Murtagh56ccf682021-12-13 18:48:12 +0000154 , m_IInputSlots(std::move(subgraph.m_IInputSlots))
Derek Lambertiff05cc52019-04-26 13:05:17 +0100155 , m_OutputSlots(std::move(subgraph.m_OutputSlots))
Francis Murtagh56ccf682021-12-13 18:48:12 +0000156 , m_IOutputSlots(std::move(subgraph.m_IOutputSlots))
Derek Lambertiff05cc52019-04-26 13:05:17 +0100157 , m_Layers(std::move(subgraph.m_Layers))
Francis Murtagh56ccf682021-12-13 18:48:12 +0000158 , m_IConnectableLayers(std::move(subgraph.m_IConnectableLayers))
Matteo Martincighadddddb2019-01-24 14:06:23 +0000159{
Derek Lamberti161d29c2020-12-07 13:54:12 +0000160 ArrangeBySortOrder();
Derek Lambertiff05cc52019-04-26 13:05:17 +0100161 CheckSubgraph();
Matteo Martincighadddddb2019-01-24 14:06:23 +0000162}
163
Matteo Martincigh602af092019-05-01 10:31:27 +0100164SubgraphView::SubgraphView(IConnectableLayer* layer)
Francis Murtagh56ccf682021-12-13 18:48:12 +0000165 : m_Layers{PolymorphicDowncast<Layer*>(layer)}
166 , m_IConnectableLayers{layer}
Matteo Martincighadddddb2019-01-24 14:06:23 +0000167{
168 unsigned int numInputSlots = layer->GetNumInputSlots();
169 m_InputSlots.resize(numInputSlots);
Francis Murtagh56ccf682021-12-13 18:48:12 +0000170 m_IInputSlots.resize(numInputSlots);
Matteo Martincighadddddb2019-01-24 14:06:23 +0000171 for (unsigned int i = 0; i < numInputSlots; i++)
172 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100173 m_InputSlots.at(i) = PolymorphicDowncast<InputSlot*>(&(layer->GetInputSlot(i)));
Francis Murtagh56ccf682021-12-13 18:48:12 +0000174 m_IInputSlots.at(i) = &(layer->GetInputSlot(i));
Matteo Martincighadddddb2019-01-24 14:06:23 +0000175 }
176
177 unsigned int numOutputSlots = layer->GetNumOutputSlots();
178 m_OutputSlots.resize(numOutputSlots);
Francis Murtagh56ccf682021-12-13 18:48:12 +0000179 m_IOutputSlots.resize(numOutputSlots);
Matteo Martincighadddddb2019-01-24 14:06:23 +0000180 for (unsigned int i = 0; i < numOutputSlots; i++)
181 {
Jan Eilersbb446e52020-04-02 13:56:54 +0100182 m_OutputSlots.at(i) = PolymorphicDowncast<OutputSlot*>(&(layer->GetOutputSlot(i)));
Francis Murtagh56ccf682021-12-13 18:48:12 +0000183 m_IOutputSlots.at(i) = &(layer->GetOutputSlot(i));
Matteo Martincighadddddb2019-01-24 14:06:23 +0000184 }
185
Derek Lambertiff05cc52019-04-26 13:05:17 +0100186 CheckSubgraph();
Matteo Martincighadddddb2019-01-24 14:06:23 +0000187}
188
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100189SubgraphView& SubgraphView::operator=(SubgraphView&& other)
190{
191 m_InputSlots = std::move(other.m_InputSlots);
Francis Murtagh56ccf682021-12-13 18:48:12 +0000192 m_IInputSlots = std::move(other.m_IInputSlots);
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100193 m_OutputSlots = std::move(other.m_OutputSlots);
Francis Murtagh56ccf682021-12-13 18:48:12 +0000194 m_IOutputSlots = std::move(other.m_IOutputSlots);
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100195 m_Layers = std::move(other.m_Layers);
Francis Murtagh56ccf682021-12-13 18:48:12 +0000196 m_IConnectableLayers = std::move(other.m_IConnectableLayers);
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100197
198 CheckSubgraph();
199
200 return *this;
201}
202
Derek Lambertiff05cc52019-04-26 13:05:17 +0100203void SubgraphView::CheckSubgraph()
Matteo Martincighadddddb2019-01-24 14:06:23 +0000204{
Matteo Martincighadddddb2019-01-24 14:06:23 +0000205 // Check for invalid or duplicate input slots
206 AssertIfNullsOrDuplicates(m_InputSlots, "Sub-graphs cannot contain null or duplicate input slots");
207
208 // Check for invalid or duplicate output slots
209 AssertIfNullsOrDuplicates(m_OutputSlots, "Sub-graphs cannot contain null or duplicate output slots");
210
211 // Check for invalid or duplicate layers
212 AssertIfNullsOrDuplicates(m_Layers, "Sub-graphs cannot contain null or duplicate layers");
Francis Murtagh56ccf682021-12-13 18:48:12 +0000213
214 // Check for invalid or duplicate input slots
215 AssertIfNullsOrDuplicates(m_IInputSlots, "Sub-graphs cannot contain null or duplicate IInputSlots");
216
217 // Check for invalid or duplicate output slots
218 AssertIfNullsOrDuplicates(m_IOutputSlots, "Sub-graphs cannot contain null or duplicate IOutputSlots");
219
220 // Check for invalid or duplicate layers
221 AssertIfNullsOrDuplicates(m_IConnectableLayers,
222 "Sub-graphs cannot contain null or duplicate IConnectableLayers");
Matteo Martincighadddddb2019-01-24 14:06:23 +0000223}
224
Derek Lambertiff05cc52019-04-26 13:05:17 +0100225const SubgraphView::InputSlots& SubgraphView::GetInputSlots() const
David Beckf98d21a2018-10-26 16:03:03 +0100226{
227 return m_InputSlots;
228}
229
Francis Murtagh56ccf682021-12-13 18:48:12 +0000230const SubgraphView::IInputSlots& SubgraphView::GetIInputSlots() const
231{
232 return m_IInputSlots;
233}
234
Derek Lambertiff05cc52019-04-26 13:05:17 +0100235const SubgraphView::OutputSlots& SubgraphView::GetOutputSlots() const
David Beckf98d21a2018-10-26 16:03:03 +0100236{
237 return m_OutputSlots;
238}
239
Francis Murtagh56ccf682021-12-13 18:48:12 +0000240const SubgraphView::IOutputSlots& SubgraphView::GetIOutputSlots() const
241{
242 return m_IOutputSlots;
243}
244
Derek Lambertiff05cc52019-04-26 13:05:17 +0100245const InputSlot* SubgraphView::GetInputSlot(unsigned int index) const
David Beckf98d21a2018-10-26 16:03:03 +0100246{
247 return m_InputSlots.at(index);
248}
249
Francis Murtagh56ccf682021-12-13 18:48:12 +0000250const IInputSlot* SubgraphView::GetIInputSlot(unsigned int index) const
251{
252 return m_IInputSlots.at(index);
253}
254
Derek Lambertiff05cc52019-04-26 13:05:17 +0100255InputSlot* SubgraphView::GetInputSlot(unsigned int index)
David Beckf98d21a2018-10-26 16:03:03 +0100256{
Matteo Martincigh602af092019-05-01 10:31:27 +0100257 return m_InputSlots.at(index);
David Beckf98d21a2018-10-26 16:03:03 +0100258}
259
Francis Murtagh56ccf682021-12-13 18:48:12 +0000260IInputSlot* SubgraphView::GetIInputSlot(unsigned int index)
261{
262 return m_IInputSlots.at(index);
263}
264
Derek Lambertiff05cc52019-04-26 13:05:17 +0100265const OutputSlot* SubgraphView::GetOutputSlot(unsigned int index) const
David Beckf98d21a2018-10-26 16:03:03 +0100266{
267 return m_OutputSlots.at(index);
268}
269
Francis Murtagh56ccf682021-12-13 18:48:12 +0000270const IOutputSlot* SubgraphView::GetIOutputSlot(unsigned int index) const
271{
272 return m_IOutputSlots.at(index);
273}
274
Derek Lambertiff05cc52019-04-26 13:05:17 +0100275OutputSlot* SubgraphView::GetOutputSlot(unsigned int index)
David Beckf98d21a2018-10-26 16:03:03 +0100276{
277 return m_OutputSlots.at(index);
278}
279
Francis Murtagh56ccf682021-12-13 18:48:12 +0000280IOutputSlot* SubgraphView::GetIOutputSlot(unsigned int index)
281{
282 return m_IOutputSlots.at(index);
283}
284
Derek Lambertiff05cc52019-04-26 13:05:17 +0100285unsigned int SubgraphView::GetNumInputSlots() const
David Beckf98d21a2018-10-26 16:03:03 +0100286{
Francis Murtagh56ccf682021-12-13 18:48:12 +0000287 return armnn::numeric_cast<unsigned int>(m_IInputSlots.size());
David Beckf98d21a2018-10-26 16:03:03 +0100288}
289
Derek Lambertiff05cc52019-04-26 13:05:17 +0100290unsigned int SubgraphView::GetNumOutputSlots() const
David Beckf98d21a2018-10-26 16:03:03 +0100291{
Francis Murtagh56ccf682021-12-13 18:48:12 +0000292 return armnn::numeric_cast<unsigned int>(m_IOutputSlots.size());
David Beckf98d21a2018-10-26 16:03:03 +0100293}
294
Matteo Martincigh602af092019-05-01 10:31:27 +0100295const SubgraphView::Layers& SubgraphView::GetLayers() const
David Beckf98d21a2018-10-26 16:03:03 +0100296{
297 return m_Layers;
298}
299
Francis Murtagh56ccf682021-12-13 18:48:12 +0000300const SubgraphView::IConnectableLayers& SubgraphView::GetIConnectableLayers() const
301{
302 return m_IConnectableLayers;
303}
304
Matteo Martincigh602af092019-05-01 10:31:27 +0100305SubgraphView::Iterator SubgraphView::begin()
Matteo Martincigh49124022019-01-11 13:25:59 +0000306{
307 return m_Layers.begin();
308}
309
Derek Lambertiff05cc52019-04-26 13:05:17 +0100310SubgraphView::Iterator SubgraphView::end()
Matteo Martincigh49124022019-01-11 13:25:59 +0000311{
312 return m_Layers.end();
313}
314
Francis Murtagh56ccf682021-12-13 18:48:12 +0000315// IConnectable Duplication to maintain backwards compatibility
316SubgraphView::IConnectableLayerIterator SubgraphView::beginIConnectable()
317{
318 return m_IConnectableLayers.begin();
319}
320
321SubgraphView::IConnectableLayerIterator SubgraphView::endIConnectable()
322{
323 return m_IConnectableLayers.end();
324}
325
Derek Lambertiff05cc52019-04-26 13:05:17 +0100326SubgraphView::ConstIterator SubgraphView::begin() const
Matteo Martincigh49124022019-01-11 13:25:59 +0000327{
328 return m_Layers.begin();
329}
330
Derek Lambertiff05cc52019-04-26 13:05:17 +0100331SubgraphView::ConstIterator SubgraphView::end() const
Matteo Martincigh49124022019-01-11 13:25:59 +0000332{
333 return m_Layers.end();
334}
335
Francis Murtagh56ccf682021-12-13 18:48:12 +0000336// IConnectable Duplication to maintain backwards compatibility
337SubgraphView::ConstIConnectableIterator SubgraphView::beginIConnectable() const
338{
339 return m_IConnectableLayers.begin();
340}
341
342SubgraphView::ConstIConnectableIterator SubgraphView::endIConnectable() const
343{
344 return m_IConnectableLayers.end();
345}
346
Derek Lambertiff05cc52019-04-26 13:05:17 +0100347SubgraphView::ConstIterator SubgraphView::cbegin() const
Matteo Martincigh49124022019-01-11 13:25:59 +0000348{
Francis Murtagh56ccf682021-12-13 18:48:12 +0000349 // Ignore deprecated call as this is internal to SubgraphView
350 ARMNN_NO_DEPRECATE_WARN_BEGIN
Matteo Martincigh49124022019-01-11 13:25:59 +0000351 return begin();
Francis Murtagh56ccf682021-12-13 18:48:12 +0000352 ARMNN_NO_DEPRECATE_WARN_END
Matteo Martincigh49124022019-01-11 13:25:59 +0000353}
354
Derek Lambertiff05cc52019-04-26 13:05:17 +0100355SubgraphView::ConstIterator SubgraphView::cend() const
Matteo Martincigh49124022019-01-11 13:25:59 +0000356{
Francis Murtagh56ccf682021-12-13 18:48:12 +0000357 // Ignore deprecated call as this is internal to SubgraphView
358 ARMNN_NO_DEPRECATE_WARN_BEGIN
Matteo Martincigh49124022019-01-11 13:25:59 +0000359 return end();
Francis Murtagh56ccf682021-12-13 18:48:12 +0000360 ARMNN_NO_DEPRECATE_WARN_END
361}
362
363// IConnectable Duplication to maintain backwards compatibility
364SubgraphView::ConstIConnectableIterator SubgraphView::cbeginIConnectable() const
365{
366 return beginIConnectable();
367}
368
369SubgraphView::ConstIConnectableIterator SubgraphView::cendIConnectable() const
370{
371 return endIConnectable();
Matteo Martincigh49124022019-01-11 13:25:59 +0000372}
373
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100374void SubgraphView::Clear()
375{
376 m_InputSlots.clear();
377 m_OutputSlots.clear();
378 m_Layers.clear();
Francis Murtagh56ccf682021-12-13 18:48:12 +0000379
380 m_IInputSlots.clear();
381 m_IOutputSlots.clear();
382 m_IConnectableLayers.clear();
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100383}
384
Derek Lamberti161d29c2020-12-07 13:54:12 +0000385void SubgraphView::ArrangeBySortOrder()
386{
387 using LayerList = std::list<Layer*>;
388 auto compareLayerPriority = [](const LayerList::value_type& layerA, const LayerList::value_type& layerB)
389 {
390 return layerA->GetPriority() < layerB->GetPriority();
391 };
392
393 m_Layers.sort(compareLayerPriority);
Francis Murtagh56ccf682021-12-13 18:48:12 +0000394
395 using IConnectableLayersList = std::list<IConnectableLayer*>;
396 auto compareIConnectableLayerPriority = [](const IConnectableLayersList::value_type& layerA,
397 const IConnectableLayersList::value_type& layerB)
398 {
399 return PolymorphicDowncast<Layer*>(layerA)->GetPriority() <
400 PolymorphicDowncast<Layer*>(layerB)->GetPriority();
401 };
402
403 m_IConnectableLayers.sort(compareIConnectableLayerPriority);
Derek Lamberti161d29c2020-12-07 13:54:12 +0000404}
405
Francis Murtagh9d74ba62022-01-19 16:31:58 +0000406struct SubgraphView::SubgraphViewWorkingCopy
407{
408public:
409
410 SubgraphViewWorkingCopy() = default;
411 SubgraphViewWorkingCopy(Graph graph)
412 : m_Graph(graph)
413 {};
414
415 Graph m_Graph;
416
417};
418
419SubgraphView SubgraphView::GetWorkingCopy()
420{
421 if (p_WorkingCopyImpl)
422 {
423 throw Exception("The SubgraphView calling GetWorkingCopy() is already a working copy. This function "
424 "should be called on original SubgraphView obtained from OptimizeSubgraphView()");
425 }
426
427 // Create a cut down SubgraphView with underlying graph containing only the relevant layers.
428 // It needs its own underlying layers so that they can be replaced safely.
429 Graph newGraph = Graph();
430 std::unordered_map<const IConnectableLayer*, IConnectableLayer*> originalToClonedLayerMap;
431 std::list<armnn::IConnectableLayer*> originalSubgraphLayers = GetIConnectableLayers();
432
433 auto ptr = std::make_shared<SubgraphViewWorkingCopy>(std::move(newGraph));
434 SubgraphView::IInputSlots workingCopyInputs;
435
436 for (auto&& originalLayer : originalSubgraphLayers)
437 {
438 Layer* const layer = PolymorphicDowncast<const Layer*>(originalLayer)->Clone(ptr->m_Graph);
439 originalToClonedLayerMap.emplace(originalLayer, layer);
440 }
441
442 // Add IInputSlots to workingCopy
443 std::vector<const IConnectableLayer*> processed;
444 for (auto originalSubgraphInputSlot : GetIInputSlots())
445 {
446 const IConnectableLayer& originalSubgraphLayer =
447 PolymorphicDowncast<InputSlot*>(originalSubgraphInputSlot)->GetOwningLayer();
448
449 // Only need process Slots of layer once
450 if (std::find(processed.begin(), processed.end(), &originalSubgraphLayer) == processed.end())
451 {
452 IConnectableLayer* clonedLayer = originalToClonedLayerMap[&originalSubgraphLayer];
453
454 // Add the InputSlot to WorkingCopy InputSlots
455 for (unsigned int i = 0; i < clonedLayer->GetNumInputSlots(); i++)
456 {
457 workingCopyInputs.push_back(&clonedLayer->GetInputSlot(i));
458 }
459 processed.push_back(&originalSubgraphLayer);
460 }
461 }
462 // Empty processed
463 processed.clear();
464
465 for (auto originalSubgraphLayer : originalSubgraphLayers)
466 {
467 IConnectableLayer* const clonedLayer = originalToClonedLayerMap[originalSubgraphLayer];
468
469 // connect all cloned layers as per original subgraph
470 for (unsigned int i = 0; i < clonedLayer->GetNumOutputSlots(); i++)
471 {
472 // OutputLayers have no OutputSlots to be connected
473 if (clonedLayer->GetType() != LayerType::Output)
474 {
475 auto& outputSlot = clonedLayer->GetOutputSlot(i);
476 for (unsigned int k = 0; k < originalSubgraphLayer->GetNumOutputSlots(); k++)
477 {
478 auto& originalOutputSlot = originalSubgraphLayer->GetOutputSlot(k);
479 for (unsigned int j = 0; j < originalOutputSlot.GetNumConnections(); j++)
480 {
481 // nextLayer is the layer with IInputSlot connected to IOutputSlot we are working on
482 const IConnectableLayer& nextLayer =
483 originalOutputSlot.GetConnection(j)->GetOwningIConnectableLayer();
484
485 // Check the layer is in our map and so has a clonedLayer
486 if (originalToClonedLayerMap.find(&nextLayer) != originalToClonedLayerMap.end())
487 {
488 IConnectableLayer* newGraphTargetLayer = originalToClonedLayerMap[&nextLayer];
489
490 IInputSlot& inputSlot =
491 newGraphTargetLayer->GetInputSlot(
492 PolymorphicDowncast<OutputSlot*>(
493 &originalOutputSlot)->GetConnection(j)->GetSlotIndex());
494
495 // Then make the connection
496 outputSlot.Connect(inputSlot);
497 }
498 }
499 // Copy the tensorInfo to the clonedOutputSlot
500 outputSlot.SetTensorInfo(originalOutputSlot.GetTensorInfo());
501 }
502 }
503 }
504 }
505
506 SubgraphView::IOutputSlots workingCopyOutputs;
507
508 // Add IOutputSlots to workingCopy
509 for (auto outputSlot : GetIOutputSlots())
510 {
511
512 const IConnectableLayer& originalSubgraphLayer = outputSlot->GetOwningIConnectableLayer();
513
514 // OutputLayers have no OutputSlots to be connected
515 // Only need process Slots of layer once
516 if (originalSubgraphLayer.GetType() != LayerType::Output &&
517 std::find(processed.begin(), processed.end(), &originalSubgraphLayer) == processed.end())
518 {
519 IConnectableLayer* clonedLayer = originalToClonedLayerMap[&originalSubgraphLayer];
520
521 // Add the OutputSlot to WorkingCopy InputSlots
522 for (unsigned int i = 0; i < clonedLayer->GetNumOutputSlots(); i++)
523 {
524 workingCopyOutputs.push_back(&clonedLayer->GetOutputSlot(i));
525 }
526 processed.push_back(&originalSubgraphLayer);
527 }
528 }
529 processed.clear();
530
531 SubgraphView::IConnectableLayers workingCopyLayers;
532 for (auto& pair : originalToClonedLayerMap)
533 {
534 workingCopyLayers.push_back(pair.second);
535 }
536
537 return {std::move(workingCopyLayers),
538 std::move(workingCopyInputs),
539 std::move(workingCopyOutputs),
540 ptr};
541}
542
543void SubgraphView::SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substituteLayer)
544{
545 ARMNN_ASSERT(substituteLayer != nullptr);
546 SubgraphView substituteSubgraph(substituteLayer);
547
548 SubstituteSubgraph(subgraph, substituteSubgraph);
549}
550
551void SubgraphView::SubstituteSubgraph(SubgraphView& patternSubgraph, const SubgraphView& substituteSubgraph)
552{
553 if (!p_WorkingCopyImpl)
554 {
555 throw NullPointerException("The SubgraphView calling SubstituteSubgraphView is not a working copy. "
556 "Call this function on SubgraphView returned from SubgraphView::GetWorkingCopy()");
557 }
558
559 // Add substitute layer to the Main graph i.e. graph in p_WorkingCopyImpl
560 auto workingCopyGraph = &p_WorkingCopyImpl->m_Graph;
561 substituteSubgraph.ForEachIConnectableLayer([workingCopyGraph](IConnectableLayer* iConnectableLayer)
562 {
563 // Search WorkingCopy Graph for substituteLayer and add if missing
564 if (std::find(std::begin(workingCopyGraph->m_Layers),
565 std::end(workingCopyGraph->m_Layers),
566 iConnectableLayer) ==
567 std::end(workingCopyGraph->m_Layers))
568 {
569 auto layer = PolymorphicDowncast<Layer*>(iConnectableLayer);
570
571 layer->Reparent(*workingCopyGraph,
572 (workingCopyGraph->m_Layers).end());
573
574 workingCopyGraph->m_LayersInOrder = false;
575 }
576 });
577
578 // Replace the old connections with connections to new layer
579 workingCopyGraph->ReplaceSubgraphConnections(patternSubgraph, substituteSubgraph);
580
581 // Update input/outputSlot pointers
582 m_IInputSlots = std::move(substituteSubgraph.m_IInputSlots);
583 m_IOutputSlots = std::move(substituteSubgraph.m_IOutputSlots);
584
585 // Delete the old layers.
586 workingCopyGraph->EraseSubgraphLayers(patternSubgraph);
587
588 // Sort
589 workingCopyGraph->TopologicalSort();
590
591 // Update SubgraphView layer pointers to match those of the internal WorkingCopy layer pointers
592 m_IConnectableLayers = IConnectableLayers{ workingCopyGraph->m_Layers.begin(),
593 workingCopyGraph->m_Layers.end() };
594}
595
596
David Beckf98d21a2018-10-26 16:03:03 +0100597} // namespace armnn