blob: 78b08ecace1c61e1115405f6c499455a91eb1fd4 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighadddddb2019-01-24 14:06:23 +00005
telsoa014fcda012018-03-09 14:13:49 +00006#include "Graph.hpp"
Derek Lambertiff05cc52019-04-26 13:05:17 +01007#include "SubgraphView.hpp"
surmeh013537c2c2018-05-18 16:31:43 +01008#include "LayersFwd.hpp"
telsoa014fcda012018-03-09 14:13:49 +00009
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000010#include <armnn/backends/IBackendInternal.hpp>
Derek Lamberti84da38b2019-06-13 11:40:08 +010011
12#include <armnn/BackendId.hpp>
Matthew Benthamf48afc62020-01-15 17:55:08 +000013#include <armnn/Logging.hpp>
telsoa014fcda012018-03-09 14:13:49 +000014#include <armnn/TypesUtils.hpp>
Matthew Benthamf48afc62020-01-15 17:55:08 +000015#include <armnn/Utils.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010016#include <armnn/utility/Assert.hpp>
telsoa014fcda012018-03-09 14:13:49 +000017
18#include <boost/polymorphic_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019#include <boost/format.hpp>
20
21#include <unordered_map>
surmeh01bceff2f2018-03-29 16:29:27 +010022#include <DotSerializer.hpp>
23#include <sstream>
24
telsoa014fcda012018-03-09 14:13:49 +000025namespace armnn
26{
27
// Copy constructor: clones every layer of 'other' into this graph, then
// re-creates the slot connections and tensor infos between the clones.
// Two passes are required because a connection can target a layer that has
// not been cloned yet during the first traversal.
Graph::Graph(const Graph& other)
: m_LayersInOrder(other.m_LayersInOrder)
{
    // Maps each original layer to its clone so connections can be mirrored.
    std::unordered_map<const Layer*, Layer*> otherToClonedMap;

    // Pass 1: clone all layers (no connections yet).
    for (auto&& otherLayer : other.m_Layers)
    {
        Layer* const layer = otherLayer->Clone(*this);
        otherToClonedMap.emplace(otherLayer, layer);
    }

    // Pass 2: copies slot connections.
    for (auto&& otherLayer : other.m_Layers)
    {
        Layer* const thisLayer = otherToClonedMap[otherLayer];

        // Walk this layer's output slots in lock-step with the original's.
        auto outputSlot = thisLayer->BeginOutputSlots();
        for (auto&& otherOutputSlot : otherLayer->GetOutputSlots())
        {
            for (auto&& otherInputSlot : otherOutputSlot.GetConnections())
            {
                const Layer& otherTgtLayer = otherInputSlot->GetOwningLayer();
                Layer* const thisTgtLayer = otherToClonedMap[&otherTgtLayer];

                // Connect the cloned output to the same input-slot index on the cloned target.
                InputSlot& inputSlot = thisTgtLayer->GetInputSlot(otherInputSlot->GetSlotIndex());
                outputSlot->Connect(inputSlot);
            }
            outputSlot->SetTensorInfo(otherOutputSlot.GetTensorInfo());
            ++outputSlot;
        }
    }
}
60
61Status Graph::Print() const
62{
63 if (m_Layers.empty())
64 {
Derek Lamberti08446972019-11-26 16:38:31 +000065 ARMNN_LOG(info) << "\n Graph is empty.\n";
telsoa014fcda012018-03-09 14:13:49 +000066 return Status::Success;
67 }
Derek Lamberti08446972019-11-26 16:38:31 +000068 ARMNN_LOG(info) << "\n";
69 ARMNN_LOG(info) << "Walking Pattern: \n";
telsoa014fcda012018-03-09 14:13:49 +000070
71 for (auto&& it : TopologicalSort())
72 {
Derek Lamberti08446972019-11-26 16:38:31 +000073 ARMNN_LOG(info) << it->GetName() << ":" << GetLayerTypeAsCString(it->GetType())
David Beck33f0ae02018-10-18 15:13:56 +010074 << ":" << it->GetBackendId().Get();
telsoa014fcda012018-03-09 14:13:49 +000075 }
Derek Lamberti08446972019-11-26 16:38:31 +000076 ARMNN_LOG(info) << "\n\n";
telsoa014fcda012018-03-09 14:13:49 +000077
78 return Status::Success;
79}
80
surmeh01bceff2f2018-03-29 16:29:27 +010081Status Graph::SerializeToDot(std::ostream& stream)
82{
83 {
84 DotGraph graph(stream, "Optimized");
85
86 {
87 // Default node attributes:
88 DotDefaults nodes(stream, "node");
89 nodes.GetAttributeSet()
90 .AddAttribute("shape", "record");
91 }
92
93 {
94 // Default edge attributes:
95 DotDefaults edges(stream, "edge");
96 edges.GetAttributeSet()
97 .AddAttribute("fontsize", 8)
98 .AddAttribute("fontcolor", "blue")
99 .AddAttribute("fontname", "arial-bold");
100 }
101
telsoa01c577f2c2018-08-31 09:22:23 +0100102 // First declares the nodes.
surmeh01bceff2f2018-03-29 16:29:27 +0100103 for (auto&& layer : m_Layers)
104 {
105 DotNode node(stream, layer->GetGuid(), GetLayerTypeAsCString(layer->GetType()));
telsoa01c577f2c2018-08-31 09:22:23 +0100106 // Extracts the layer parameters.
surmeh01bceff2f2018-03-29 16:29:27 +0100107 ParameterStringifyFunction extractParams = [&node](const std::string & name, const std::string & value){
108 node.GetContents().AddContent(name + " : " + value);
109 };
110 layer->SerializeLayerParameters(extractParams);
111 }
112
telsoa01c577f2c2018-08-31 09:22:23 +0100113 // Second declares the edges.
surmeh01bceff2f2018-03-29 16:29:27 +0100114 for (auto&& layer : m_Layers)
115 {
116 LayerGuid toId = layer->GetGuid();
117
118 for (unsigned int i=0;i<layer->GetNumInputSlots(); i++)
119 {
120 OutputSlot* outputSlot = static_cast<OutputSlot*>(layer->GetInputSlot(i).GetConnection());
121 LayerGuid fromId = outputSlot->GetOwningLayer().GetGuid();
122 DotEdge edge(stream, fromId, toId);
123
telsoa01c577f2c2018-08-31 09:22:23 +0100124 // Now print the tensor shape on the edge.
surmeh01bceff2f2018-03-29 16:29:27 +0100125 {
telsoa01c577f2c2018-08-31 09:22:23 +0100126 // Constructs the label attribute with HTML markup.
surmeh01bceff2f2018-03-29 16:29:27 +0100127 std::stringstream ss;
surmeh013537c2c2018-05-18 16:31:43 +0100128 ss << "< " << outputSlot->GetTensorInfo().GetShape() << " >";
surmeh01bceff2f2018-03-29 16:29:27 +0100129 edge.GetAttributeSet().AddAttribute("label", ss);
130 }
131 }
132 }
133 }
134
135 if (stream.bad())
136 {
137 return Status::Failure;
138 }
139 return Status::Success;
140}
141
// Allocates buffers for every tensor produced by the graph's layers.
// Constant-layer outputs are allocated upfront and kept alive for the whole
// execution; all other tensor handles are reference-counted across the
// topologically-ordered layers so that Manage()/Allocate() bracket each
// buffer's live range for the backend's memory manager.
// @return Status::Success (the current implementation cannot fail).
Status Graph::AllocateDynamicBuffers()
{
    // Layers must be sorted in topological order
    ARMNN_ASSERT(m_LayersInOrder);

    // Tensors allocated upfront (constants); excluded from reference counting.
    std::unordered_set<const ITensorHandle*> preallocatedTensors;
    // Outstanding consumer count per tensor handle; entry removed once it hits zero.
    std::unordered_map<const ITensorHandle*, unsigned int> handleReferenceCounts;

    // Finds the first TensorHandle ancestor of a SubTensorHandle. If the ITensorHandle provided
    // is a TensorHandle, the function just returns it
    auto TraceSubTensorHandleAncestry = [](ITensorHandle* const subTensorHandle)
    {
        ITensorHandle* ancestor = subTensorHandle;
        while (ancestor && ancestor->GetParent())
        {
            ancestor = ancestor->GetParent();
        }
        return ancestor;
    };

    // Checks whether a TensorHandle has been pre-allocated
    auto IsPreallocated = [&](ITensorHandle* const tensorHandle)
    {
        return tensorHandle && preallocatedTensors.find(tensorHandle) != preallocatedTensors.end();
    };

    // Constant tensor handles need to last from the beginning of execution till the end,
    // therefore we pre-allocate them upfront
    for (auto&& layer : m_Layers)
    {
        if (layer->GetType() == LayerType::Constant)
        {
            for (auto&& slot = layer->BeginOutputSlots(); slot != layer->EndOutputSlots(); ++slot)
            {
                ITensorHandle *tensorHandle = TraceSubTensorHandleAncestry(slot->GetOutputHandler().GetData());

                if (tensorHandle && !IsPreallocated(tensorHandle))
                {
                    tensorHandle->Allocate();
                    preallocatedTensors.insert(tensorHandle);
                }
            }
        }
    }

    // Iterate over the network in topological order
    for (auto&& layer : m_Layers)
    {
        // Count the amount of times each output slot references a certain buffer (ITensorHandle).
        // The first time we encounter a new tensor handle, we start managing its lifetime.
        for (auto&& slot = layer->BeginOutputSlots(); slot != layer->EndOutputSlots(); ++slot)
        {
            ITensorHandle *tensorHandle = TraceSubTensorHandleAncestry(slot->GetOutputHandler().GetData());

            if (tensorHandle && !IsPreallocated(tensorHandle))
            {
                unsigned int numConnections = slot->GetNumConnections();
                if (handleReferenceCounts.find(tensorHandle) == handleReferenceCounts.end())
                {
                    // First sight of this buffer: start its managed lifetime.
                    handleReferenceCounts[tensorHandle] = numConnections;
                    tensorHandle->Manage();
                    if (handleReferenceCounts[tensorHandle] == 0u)
                    {
                        // if nobody consumes this tensor we call Allocate()
                        tensorHandle->Allocate();
                    }
                }
                else
                {
                    // Same buffer produced via another slot (e.g. sub-tensors): add its consumers.
                    handleReferenceCounts[tensorHandle] += numConnections;
                }
            }
        }

        // Loop through the input slots in the same layer and decrement the reference counter associated
        // to each tensor handle we encounter. Once it reaches zero, we end the lifetime of the tensor handle
        for (auto&& slot = layer->BeginInputSlots(); slot != layer->EndInputSlots(); ++slot)
        {
            ITensorHandle *tensorHandle = TraceSubTensorHandleAncestry(
                slot->GetConnectedOutputSlot()->GetOutputHandler().GetData());

            if (tensorHandle && !IsPreallocated(tensorHandle))
            {
                --handleReferenceCounts[tensorHandle];

                if (handleReferenceCounts[tensorHandle] == 0u)
                {
                    // Stop managing lifetime of tensor handle
                    tensorHandle->Allocate();
                    handleReferenceCounts.erase(tensorHandle);
                }
            }
        }
    }

    return Status::Success;
}
239
240const Graph& Graph::TopologicalSort() const
241{
242 if (!m_LayersInOrder)
243 {
telsoa01c577f2c2018-08-31 09:22:23 +0100244 // Resets layer order.
telsoa014fcda012018-03-09 14:13:49 +0000245 for (auto&& it : m_Layers)
246 {
247 it->ResetPriority();
248 }
249
Matteo Martincighadddddb2019-01-24 14:06:23 +0000250 auto compareLayerPriority = [](const LayerList::value_type& layerA, const LayerList::value_type& layerB)
telsoa014fcda012018-03-09 14:13:49 +0000251 {
252 return layerA->GetPriority() < layerB->GetPriority();
253 };
254
255 m_Layers.sort(compareLayerPriority);
256
257 m_LayersInOrder = true;
258 }
259
260 return *this;
261}
262
// Inserts MemCopy/MemImport compatibility layers on every graph edge whose
// EdgeStrategy requires data to be copied or exported/imported between
// backends, and selects a tensor handle factory for each inserted layer's
// output slot.
// @param backends  Map of backend-id to backend instance; queried for the
//                  destination backend's tensor-handle-factory preferences.
// @param registry  Registry used to resolve tensor handle factories by id.
void Graph::AddCompatibilityLayers(std::map<BackendId, std::unique_ptr<IBackendInternal>>& backends,
                                   TensorHandleFactoryRegistry& registry)
{
    // Returns true if the given layer could potentially need an intermediate copy/import layer (depending on its
    // connections to other layers).
    auto MayNeedCompatibilityLayer = [](const Layer& layer)
    {
        // All layers should have been associated with a valid compute device at this point.
        ARMNN_ASSERT(layer.GetBackendId() != Compute::Undefined);
        // Does not need another compatibility layer if a copy or import layer is already present.
        return layer.GetType() != LayerType::MemCopy &&
               layer.GetType() != LayerType::MemImport;
    };

    // True for the strategies that require an intermediate compatibility layer.
    auto IsCompatibilityStrategy = [](EdgeStrategy strategy)
    {
        return strategy == EdgeStrategy::CopyToTarget ||
               strategy == EdgeStrategy::ExportToTarget;
    };

    ForEachLayer([this, &backends, &registry, MayNeedCompatibilityLayer, IsCompatibilityStrategy](Layer* srcLayer)
    {
        ARMNN_ASSERT(srcLayer);

        if (!MayNeedCompatibilityLayer(*srcLayer))
        {
            // The current layer does not need copy layers, move to the next one
            return;
        }

        const std::vector<OutputSlot>& srcOutputSlots = srcLayer->GetOutputSlots();
        for (unsigned int srcOutputIndex = 0; srcOutputIndex < srcOutputSlots.size(); srcOutputIndex++)
        {
            OutputSlot& srcOutputSlot = srcLayer->GetOutputSlot(srcOutputIndex);
            // Note: connections taken by value — inserting layers below mutates the slot.
            const std::vector<InputSlot*> srcConnections = srcOutputSlot.GetConnections();
            const std::vector<EdgeStrategy> srcEdgeStrategies = srcOutputSlot.GetEdgeStrategies();
            for (unsigned int srcConnectionIndex = 0; srcConnectionIndex < srcConnections.size(); srcConnectionIndex++)
            {
                InputSlot* dstInputSlot = srcConnections[srcConnectionIndex];
                ARMNN_ASSERT(dstInputSlot);

                EdgeStrategy strategy = srcEdgeStrategies[srcConnectionIndex];
                ARMNN_ASSERT_MSG(strategy != EdgeStrategy::Undefined,
                                 "Undefined memory strategy found while adding copy layers for compatibility");

                const Layer& dstLayer = dstInputSlot->GetOwningLayer();
                if (MayNeedCompatibilityLayer(dstLayer) &&
                    IsCompatibilityStrategy(strategy))
                {
                    // A copy layer is needed in between the source and destination layers.
                    // Record the operation rather than attempting to modify the graph as we go.
                    // (invalidating iterators)
                    const std::string compLayerName = boost::str(boost::format("[ %1% (%2%) -> %3% (%4%) ]")
                                                                 % srcLayer->GetName()
                                                                 % srcOutputIndex
                                                                 % dstLayer.GetName()
                                                                 % dstInputSlot->GetSlotIndex());

                    // Insert the compatibility layer in front of the destination input slot.
                    Layer* compLayer = nullptr;
                    if (strategy == EdgeStrategy::CopyToTarget)
                    {
                        compLayer = InsertNewLayer<MemCopyLayer>(*dstInputSlot, compLayerName.c_str());
                    }
                    else
                    {
                        ARMNN_ASSERT_MSG(strategy == EdgeStrategy::ExportToTarget, "Invalid edge strategy found.");
                        compLayer = InsertNewLayer<MemImportLayer>(*dstInputSlot, compLayerName.c_str());
                    }

                    // The compatibility layer runs on the destination backend.
                    compLayer->SetBackendId(dstLayer.GetBackendId());

                    OutputSlot& compOutputSlot = compLayer->GetOutputSlot(0);
                    auto backendIt = backends.find(dstLayer.GetBackendId());
                    if (backendIt != backends.end() &&
                        backendIt->second &&
                        backendIt->second->SupportsTensorAllocatorAPI())
                    {
                        // Pick the first destination-backend factory preference that can
                        // either map/unmap or export/import against the source's factory.
                        auto backend = backendIt->second.get();
                        auto tensorHandleFactoryIds = backend->GetHandleFactoryPreferences();
                        bool found = false;

                        for (auto preference : tensorHandleFactoryIds)
                        {
                            auto factory = registry.GetFactory(preference);
                            if (factory)
                            {
                                auto srcPref = srcOutputSlot.GetTensorHandleFactoryId();
                                auto srcFactory = registry.GetFactory(srcPref);

                                if (srcFactory)
                                {
                                    bool canExportImport =
                                        (factory->GetImportFlags() & srcFactory->GetExportFlags()) != 0;

                                    if (factory->SupportsMapUnmap() || canExportImport)
                                    {
                                        compOutputSlot.SetTensorHandleFactory(preference);
                                        found = true;
                                        break;
                                    }
                                }
                            }
                        }

                        if (!found)
                        {
                            // No compatible factory: fall back to the legacy factory.
                            compOutputSlot.SetTensorHandleFactory(ITensorHandleFactory::LegacyFactoryId);
                        }
                    }
                    else
                    {
                        compOutputSlot.SetTensorHandleFactory(ITensorHandleFactory::LegacyFactoryId);
                    }

                    // The output strategy of a compatibility layer is always DirectCompatibility.
                    compOutputSlot.SetEdgeStrategy(0, EdgeStrategy::DirectCompatibility);

                    // Recalculate the connection index on the previous layer as we have just inserted into it.
                    const std::vector<InputSlot*>& newSourceConnections = srcOutputSlot.GetConnections();
                    long newSrcConnectionIndex = std::distance(newSourceConnections.begin(),
                                                               std::find(newSourceConnections.begin(),
                                                                         newSourceConnections.end(),
                                                                         &compLayer->GetInputSlot(0)));

                    // The input strategy of a compatibility layer is always DirectCompatibilty.
                    srcOutputSlot.SetEdgeStrategy(boost::numeric_cast<unsigned int>(newSrcConnectionIndex),
                                                  EdgeStrategy::DirectCompatibility);
                }
            }
        }
    });
}
395
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100396void Graph::SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substituteLayer)
Matteo Martincigh49124022019-01-11 13:25:59 +0000397{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100398 ARMNN_ASSERT(substituteLayer != nullptr);
Matteo Martincigh49124022019-01-11 13:25:59 +0000399
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100400 ReplaceSubgraphConnections(subgraph, substituteLayer);
401 EraseSubgraphLayers(subgraph);
Matteo Martincigh49124022019-01-11 13:25:59 +0000402}
403
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100404void Graph::SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph)
Matteo Martincighadddddb2019-01-24 14:06:23 +0000405{
David Monahan5200afa2019-05-10 11:52:14 +0100406 // Look through each layer in the new subgraph and add any that are not already a member of this graph
407 substituteSubgraph.ForEachLayer([this](Layer* layer)
408 {
409 if (std::find(std::begin(m_Layers), std::end(m_Layers), layer) == std::end(m_Layers))
410 {
411 layer->Reparent(*this, m_Layers.end());
412 m_LayersInOrder = false;
413 }
414 });
415
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100416 ReplaceSubgraphConnections(subgraph, substituteSubgraph);
417 EraseSubgraphLayers(subgraph);
Matteo Martincighe3a42452019-05-23 13:56:01 +0100418 TopologicalSort();
Matteo Martincighadddddb2019-01-24 14:06:23 +0000419}
420
Derek Lambertiff05cc52019-04-26 13:05:17 +0100421void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, IConnectableLayer* substituteLayer)
Matteo Martincigh49124022019-01-11 13:25:59 +0000422{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100423 ARMNN_ASSERT(substituteLayer != nullptr);
Matteo Martincighadddddb2019-01-24 14:06:23 +0000424
Matteo Martincigh0c051f92019-01-31 12:09:49 +0000425 // Create a new sub-graph with only the given layer, using
426 // the given sub-graph as a reference of which parent graph to use
Matteo Martincigh602af092019-05-01 10:31:27 +0100427 SubgraphView substituteSubgraph(substituteLayer);
Derek Lambertiff05cc52019-04-26 13:05:17 +0100428 ReplaceSubgraphConnections(subgraph, substituteSubgraph);
Matteo Martincighadddddb2019-01-24 14:06:23 +0000429}
430
// Disconnects 'subgraph' from the enclosing graph and wires 'substituteSubgraph'
// into its place. Input and output slots are matched by index, so both
// sub-graphs must expose identical slot counts (asserted below). Layers of
// 'subgraph' are not deleted here — see EraseSubgraphLayers.
void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const SubgraphView& substituteSubgraph)
{
    ARMNN_ASSERT_MSG(!substituteSubgraph.GetLayers().empty(), "New sub-graph used for substitution must not be empty");

    // Sanity check: every substitute layer must already belong to this graph.
    const SubgraphView::Layers& substituteSubgraphLayers = substituteSubgraph.GetLayers();
    std::for_each(substituteSubgraphLayers.begin(), substituteSubgraphLayers.end(), [&](Layer* layer)
    {
        IgnoreUnused(layer);  // 'layer' is only referenced inside the assert, which may compile out
        ARMNN_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), layer) != m_Layers.end(),
                         "Substitute layer is not a member of graph");
    });

    const SubgraphView::InputSlots& subgraphInputSlots = subgraph.GetInputSlots();
    const SubgraphView::OutputSlots& subgraphOutputSlots = subgraph.GetOutputSlots();

    unsigned int subgraphNumInputSlots = boost::numeric_cast<unsigned int>(subgraphInputSlots.size());
    unsigned int subgraphNumOutputSlots = boost::numeric_cast<unsigned int>(subgraphOutputSlots.size());

    const SubgraphView::InputSlots& substituteSubgraphInputSlots = substituteSubgraph.GetInputSlots();
    const SubgraphView::OutputSlots& substituteSubgraphOutputSlots = substituteSubgraph.GetOutputSlots();

    // Slot counts must line up one-to-one for the index-based rewiring below.
    ARMNN_ASSERT(subgraphNumInputSlots == substituteSubgraphInputSlots.size());
    ARMNN_ASSERT(subgraphNumOutputSlots == substituteSubgraphOutputSlots.size());

    // Disconnect the sub-graph and replace it with the substitute sub-graph

    // Step 1: process input slots
    for (unsigned int inputSlotIdx = 0; inputSlotIdx < subgraphNumInputSlots; ++inputSlotIdx)
    {
        InputSlot* subgraphInputSlot = subgraphInputSlots.at(inputSlotIdx);
        ARMNN_ASSERT(subgraphInputSlot);

        // Detach the external producer from the old sub-graph input...
        IOutputSlot* connectedOutputSlot = subgraphInputSlot->GetConnection();
        ARMNN_ASSERT(connectedOutputSlot);
        connectedOutputSlot->Disconnect(*subgraphInputSlot);

        // ...and attach it to the corresponding substitute input.
        IInputSlot* substituteInputSlot = substituteSubgraphInputSlots.at(inputSlotIdx);
        ARMNN_ASSERT(substituteInputSlot);
        connectedOutputSlot->Connect(*substituteInputSlot);
    }

    // Step 2: process output slots
    for(unsigned int outputSlotIdx = 0; outputSlotIdx < subgraphNumOutputSlots; ++outputSlotIdx)
    {
        OutputSlot* subgraphOutputSlot = subgraphOutputSlots.at(outputSlotIdx);
        ARMNN_ASSERT(subgraphOutputSlot);

        OutputSlot* substituteOutputSlot = substituteSubgraphOutputSlots.at(outputSlotIdx);
        ARMNN_ASSERT(substituteOutputSlot);
        // Move every downstream consumer from the old output to the substitute output.
        subgraphOutputSlot->MoveAllConnections(*substituteOutputSlot);
    }
}
483
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100484void Graph::EraseSubgraphLayers(SubgraphView &subgraph)
Matteo Martincigh49124022019-01-11 13:25:59 +0000485{
Derek Lambertiff05cc52019-04-26 13:05:17 +0100486 for (auto layer : subgraph.GetLayers())
Matteo Martincigh49124022019-01-11 13:25:59 +0000487 {
488 EraseLayer(layer);
489 }
Derek Lambertic2fe5fb2019-05-08 10:23:08 +0100490 subgraph.Clear();
Matteo Martincigh49124022019-01-11 13:25:59 +0000491}
492
telsoa014fcda012018-03-09 14:13:49 +0000493void Graph::InferTensorInfos()
494{
495 for (auto&& layer : TopologicalSort())
496 {
497 for (auto&& input : layer->GetInputSlots())
498 {
Matthew Bentham584a2b82019-11-01 13:29:48 +0000499 const IOutputSlot* source = input.GetConnectedOutputSlot();
500 if (source == NULL)
501 {
502 std::ostringstream message;
503 message << "Input not connected on "
504 << GetLayerTypeAsCString(layer->GetType())
505 << " layer \""
506 << layer->GetName()
507 << "\"";
508 throw LayerValidationException(message.str());
509 }
510
511 if (!source->IsTensorInfoSet())
512 {
513 throw LayerValidationException("All inputs must have the TensorInfo set at this point.");
514 }
telsoa014fcda012018-03-09 14:13:49 +0000515 }
516 layer->ValidateTensorShapesFromInputs();
517 }
518}
519
520} // namespace armnn