blob: 7ace2e067094d09c1ec4cee78c6e813ff07ff475 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
surmeh013537c2c2018-05-18 16:31:43 +01007#include "LayersFwd.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01008#include "IGraphObservable.hpp"
telsoa014fcda012018-03-09 14:13:49 +00009
10#include <armnn/Types.hpp>
11#include <armnn/TensorFwd.hpp>
12#include <armnn/NetworkFwd.hpp>
13#include <armnn/Exceptions.hpp>
14
15#include <list>
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <map>
telsoa014fcda012018-03-09 14:13:49 +000017#include <unordered_map>
18#include <unordered_set>
19#include <vector>
20
21#include <boost/assert.hpp>
22#include <boost/iterator/transform_iterator.hpp>
23
24namespace armnn
25{
telsoa01c577f2c2018-08-31 09:22:23 +010026
telsoa014fcda012018-03-09 14:13:49 +000027class Graph
28{
29public:
30 template <typename CVLayerT>
31 static CVLayerT* PtrCast(Layer* const layer)
32 {
33 return boost::polymorphic_downcast<CVLayerT*>(layer);
34 }
35
36 using LayersList = std::list<Layer*>;
telsoa01c577f2c2018-08-31 09:22:23 +010037 using Iterator = LayersList::const_iterator; // Const so pointers in the list can't be modified externally.
telsoa014fcda012018-03-09 14:13:49 +000038 using ConstIterator = boost::transform_iterator<decltype(&PtrCast<const Layer>), Iterator>;
39 using IteratorDifference = Iterator::difference_type;
40
41 using ConstIteratorInputs = boost::transform_iterator<decltype(&PtrCast<const InputLayer>), Iterator>;
42 using ConstIteratorOutputs = boost::transform_iterator<decltype(&PtrCast<const OutputLayer>), Iterator>;
43
44 /// Wrapper class returned by Graph::GetInputLayers()
45 struct InputLayersAccessor
46 {
47 explicit InputLayersAccessor(const Graph& graph) : m_Graph(graph) {}
48
49 ConstIteratorInputs begin() const
50 {
51 return { m_Graph.m_Layers.begin(), &PtrCast<const InputLayer> };
52 }
53
54 ConstIteratorInputs end() const
55 {
56 return { std::next(m_Graph.m_Layers.begin(), static_cast<IteratorDifference>(m_Graph.GetNumInputs())),
57 &PtrCast<const InputLayer> };
58 }
59
60 const Graph& m_Graph;
61 };
62
63 /// Wrapper class returned by Graph::GetOutputLayers()
64 struct OutputLayersAccessor
65 {
66 explicit OutputLayersAccessor(const Graph& graph) : m_Graph(graph) {}
67
68 ConstIteratorOutputs begin() const
69 {
70 return { std::prev(m_Graph.m_Layers.end(), static_cast<IteratorDifference>(m_Graph.GetNumOutputs())),
71 &PtrCast<const OutputLayer> };
72 }
73
74 ConstIteratorOutputs end() const
75 {
76 return { m_Graph.m_Layers.end(), &PtrCast<const OutputLayer> };
77 }
78
79 const Graph& m_Graph;
80 };
81
82 Graph() : m_LayersInOrder(true) {}
83
84 Graph(const Graph& other);
85
86 Graph& operator=(const Graph& other) = delete;
87
88 ~Graph()
89 {
90 for (auto&& layer : m_Layers)
91 {
92 delete layer;
93 }
94 }
95
96 Status Print() const;
97
surmeh01bceff2f2018-03-29 16:29:27 +010098 Status SerializeToDot(std::ostream& stream);
99
telsoa01c577f2c2018-08-31 09:22:23 +0100100 /// Adds a new layer, of type LayerType, to the graph constructed with the arguments passed.
telsoa014fcda012018-03-09 14:13:49 +0000101 template <typename LayerT, typename... Args>
102 LayerT* AddLayer(Args&&... args);
103
104 /// Inserts a new layer between the output slot currently connected to insertBefore
105 /// and insertBefore itself.
106 template <typename LayerT, typename... Args>
107 LayerT* InsertNewLayer(InputSlot& insertBefore, Args&&... args);
108
telsoa01c577f2c2018-08-31 09:22:23 +0100109 /// Inserts a new layer between insertAfter and the input slot(s) currently connected to it
110 template <typename LayerT, typename... Args>
111 LayerT* InsertNewLayer(OutputSlot& insertAfter, Args&&... args);
112
telsoa014fcda012018-03-09 14:13:49 +0000113 /// Deletes the layer at the specified position and returns an iterator pointing
114 /// to the next element after the one being deleted.
115 Iterator EraseLayer(Iterator pos);
116
117 /// Deletes the layer and returns an iterator pointing to the next layer in the graph
118 /// (next in the list, after the one being deleted). Sets @a layer to nullptr on return.
119 /// Templated to support pointers to any layer type.
120 template <typename LayerT>
121 Iterator EraseLayer(LayerT*& layer);
122
telsoa01c577f2c2018-08-31 09:22:23 +0100123 /// Returns iterator pointing to the beginning of the list. Lowercase for range-based for loops.
telsoa014fcda012018-03-09 14:13:49 +0000124 Iterator begin() { return m_Layers.begin(); }
telsoa01c577f2c2018-08-31 09:22:23 +0100125 /// Returns iterator pointing to the end of the list. Lowercase for range-based for loops.
telsoa014fcda012018-03-09 14:13:49 +0000126 Iterator end() { return m_Layers.end(); }
127
telsoa01c577f2c2018-08-31 09:22:23 +0100128 /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
telsoa014fcda012018-03-09 14:13:49 +0000129 ConstIterator begin() const { return {m_Layers.begin(), &PtrCast<const Layer>}; }
telsoa01c577f2c2018-08-31 09:22:23 +0100130 /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
telsoa014fcda012018-03-09 14:13:49 +0000131 ConstIterator end() const { return {m_Layers.end(), &PtrCast<const Layer>}; }
132
telsoa01c577f2c2018-08-31 09:22:23 +0100133 /// Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops.
surmeh01bceff2f2018-03-29 16:29:27 +0100134 ConstIterator cbegin() const { return begin(); }
telsoa01c577f2c2018-08-31 09:22:23 +0100135 /// Returns const iterator pointing to the end of the list. Lowercase for range-based for loops.
surmeh01bceff2f2018-03-29 16:29:27 +0100136 ConstIterator cend() const { return end(); }
137
telsoa01c577f2c2018-08-31 09:22:23 +0100138 /// Sorts layers in topological order and return this.
telsoa014fcda012018-03-09 14:13:49 +0000139 Graph& TopologicalSort() { const_cast<const Graph*>(this)->TopologicalSort(); return *this; }
140 const Graph& TopologicalSort() const;
141
142 size_t GetNumInputs() const { return m_InputIds.size(); }
143 size_t GetNumOutputs() const { return m_OutputIds.size(); }
144
145 /// Returns a wrapper object with begin(), end() methods to iterate over the input layers
telsoa01c577f2c2018-08-31 09:22:23 +0100146 /// in a range-based for loop.
telsoa014fcda012018-03-09 14:13:49 +0000147 InputLayersAccessor GetInputLayers() const { return InputLayersAccessor(*this); }
148
149 /// Returns a wrapper object with begin(), end() methods to iterate over the output layers
telsoa01c577f2c2018-08-31 09:22:23 +0100150 /// in a range-based for loop.
telsoa014fcda012018-03-09 14:13:49 +0000151 OutputLayersAccessor GetOutputLayers() const { return OutputLayersAccessor(*this); }
152
153 size_t GetNumLayers() const { return m_Layers.size(); }
154
telsoa01c577f2c2018-08-31 09:22:23 +0100155 /// Allocates memory for all tensors under output tensor handers of each layer.
telsoa014fcda012018-03-09 14:13:49 +0000156 Status AllocateDynamicBuffers();
157
158 /// Modifies the graph in-place, removing edges connecting layers using different compute devices,
159 /// and relinking them via an intermediary copy layers.
160 void AddCopyLayers();
161
162 void InferTensorInfos();
163
telsoa01c577f2c2018-08-31 09:22:23 +0100164 void AttachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
165 m_Views[notifyOnEvent].emplace_back(observable);
166 }
167
168 void DetachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
169 m_Views[notifyOnEvent].remove(observable);
170 }
171
telsoa014fcda012018-03-09 14:13:49 +0000172private:
173 template <typename LayerT>
174 class LayerInGraphBase;
175
176 template <typename LayerT>
177 class LayerInGraph;
178
surmeh01bceff2f2018-03-29 16:29:27 +0100179 Iterator ForwardToEndOfInputs(Iterator it) const
180 {
181 while ((it != m_Layers.end()) && ((*it)->GetType() == LayerType::Input))
182 {
183 ++it;
184 }
185 return it;
186 }
187
188 Iterator RewindToBeginOfOutputs(Iterator it) const
189 {
190 while ((it != m_Layers.begin()) && ((*std::prev(it))->GetType() == LayerType::Output))
191 {
192 --it;
193 }
194 return it;
195 }
196
telsoa01c577f2c2018-08-31 09:22:23 +0100197 /// Gets the position of a layer in the graph.
telsoa014fcda012018-03-09 14:13:49 +0000198 Iterator GetPosInGraph(Layer& layer);
199
telsoa01c577f2c2018-08-31 09:22:23 +0100200 void NotifyObservables(GraphEvent event, Layer* graphState)
201 {
202 // Iterate over all observables observing this event
203 for (auto& observable : m_Views[event])
204 {
205 observable->Update(graphState);
206 }
207 }
208
telsoa014fcda012018-03-09 14:13:49 +0000209 std::unordered_set<LayerBindingId> m_InputIds;
210 std::unordered_set<LayerBindingId> m_OutputIds;
211 std::unordered_map<const Layer*, Iterator> m_PosInGraphMap;
212
213 /// Mutable to allow sorting on const object.
214 mutable LayersList m_Layers;
215 mutable bool m_LayersInOrder;
telsoa01c577f2c2018-08-31 09:22:23 +0100216
217 std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;
telsoa014fcda012018-03-09 14:13:49 +0000218};
219
telsoa01c577f2c2018-08-31 09:22:23 +0100220/// Common base class for layers in the graph.
telsoa014fcda012018-03-09 14:13:49 +0000221template <typename LayerT>
222class Graph::LayerInGraphBase : public LayerT
223{
224protected:
225 template <typename... Args>
226 LayerInGraphBase(Graph& graph, Iterator insertBefore, Args&&... args)
227 : LayerT(std::forward<Args>(args)...), m_Graph(graph)
228 {
229 m_Graph.m_PosInGraphMap.emplace(this, m_Graph.m_Layers.emplace(insertBefore, this));
230 }
231 ~LayerInGraphBase()
232 {
233 const size_t numErased = m_Graph.m_PosInGraphMap.erase(this);
234 boost::ignore_unused(numErased);
235 BOOST_ASSERT(numErased == 1);
236 }
237
238 Graph& m_Graph;
239};
240
telsoa01c577f2c2018-08-31 09:22:23 +0100241/// Input/Output layers specialize this template.
telsoa014fcda012018-03-09 14:13:49 +0000242template <typename LayerT>
243class Graph::LayerInGraph final : public LayerInGraphBase<LayerT>
244{
245public:
246 template <typename... Args>
surmeh01bceff2f2018-03-29 16:29:27 +0100247 LayerInGraph(Graph& graph, Args&&... args)
248 : LayerInGraphBase<LayerT>(graph,
249 // Insert at the back of the intermediate layers (before outputs).
250 std::prev(graph.end(), IteratorDifference(graph.GetNumOutputs())),
251 std::forward<Args>(args)...)
252 {
253 }
254 template <typename... Args>
telsoa014fcda012018-03-09 14:13:49 +0000255 LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args)
surmeh01bceff2f2018-03-29 16:29:27 +0100256 : LayerInGraphBase<LayerT>(graph,
257 // Make sure it's inserted after all inputs and before all outputs.
258 graph.ForwardToEndOfInputs(graph.RewindToBeginOfOutputs(insertBefore)),
259 std::forward<Args>(args)...)
telsoa014fcda012018-03-09 14:13:49 +0000260 {
261 }
262};
263
264/// Inputs add/remove their binding id to m_InputIds in the graph.
265template <>
266class Graph::LayerInGraph<InputLayer> final : public LayerInGraphBase<InputLayer>
267{
268public:
269 template <typename... Args>
surmeh01bceff2f2018-03-29 16:29:27 +0100270 LayerInGraph(Graph& graph, Args&&... args)
271 : LayerInGraphBase<InputLayer>(graph,
272 // Always add to the back of the inputs.
273 std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())),
274 std::forward<Args>(args)...)
telsoa014fcda012018-03-09 14:13:49 +0000275 {
276 const bool isNewId = m_Graph.m_InputIds.emplace(GetBindingId()).second;
277 if (!isNewId)
278 {
279 throw InvalidArgumentException("A layer already exists with the specified id");
280 }
281 }
surmeh01bceff2f2018-03-29 16:29:27 +0100282 template <typename... Args>
surmeh013537c2c2018-05-18 16:31:43 +0100283 LayerInGraph(Graph& graph, Iterator, Args&&... args)
284 // Ignore Iterator argument. Always add to the back of the inputs.
surmeh01bceff2f2018-03-29 16:29:27 +0100285 : LayerInGraph(graph, std::forward<Args>(args)...)
286 {
287 }
telsoa014fcda012018-03-09 14:13:49 +0000288 ~LayerInGraph() override
289 {
290 const size_t numErased = m_Graph.m_InputIds.erase(GetBindingId());
291 boost::ignore_unused(numErased);
292 BOOST_ASSERT(numErased == 1);
293 }
294};
295
296/// Outputs add/remove their binding id to m_OutputIds in the graph.
297template <>
298class Graph::LayerInGraph<OutputLayer> final : public LayerInGraphBase<OutputLayer>
299{
300public:
301 template <typename... Args>
surmeh01bceff2f2018-03-29 16:29:27 +0100302 LayerInGraph(Graph& graph, Args&&... args)
303 : LayerInGraphBase<OutputLayer>(graph,
304 // Always add to the back of the outputs.
305 graph.end(),
306 std::forward<Args>(args)...)
telsoa014fcda012018-03-09 14:13:49 +0000307 {
308 const bool isNewId = m_Graph.m_OutputIds.emplace(GetBindingId()).second;
309 if (!isNewId)
310 {
311 throw InvalidArgumentException("A layer already exists with the specified id");
312 }
313 }
314 ~LayerInGraph() override
315 {
316 const size_t numErased = m_Graph.m_OutputIds.erase(GetBindingId());
317 boost::ignore_unused(numErased);
318 BOOST_ASSERT(numErased == 1);
319 }
320};
321
322inline Graph::Iterator Graph::GetPosInGraph(Layer& layer)
323{
324 auto it = m_PosInGraphMap.find(&layer);
325 BOOST_ASSERT(it != m_PosInGraphMap.end());
326 return it->second;
327}
328
329template <typename LayerT, typename... Args>
telsoa014fcda012018-03-09 14:13:49 +0000330inline LayerT* Graph::AddLayer(Args&&... args)
331{
surmeh01bceff2f2018-03-29 16:29:27 +0100332 m_LayersInOrder = m_LayersInOrder &&
333 ((LayerEnumOf<LayerT>() == LayerType::Input) || (LayerEnumOf<LayerT>() == LayerType::Output));
telsoa01c577f2c2018-08-31 09:22:23 +0100334 LayerT* const layer = new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...);
335
336 NotifyObservables(GraphEvent::LayerAdded, layer);
337
338 return layer;
telsoa014fcda012018-03-09 14:13:49 +0000339}
340
341template <typename LayerT, typename... Args>
342inline LayerT* Graph::InsertNewLayer(InputSlot& insertBefore, Args&&... args)
343{
telsoa01c577f2c2018-08-31 09:22:23 +0100344 // Insert after the parent if any, or before the child otherwise, so the topological order is kept.
surmeh01bceff2f2018-03-29 16:29:27 +0100345 OutputSlot* parentOut = insertBefore.GetConnectedOutputSlot();
346 const Iterator pos = (parentOut != nullptr)
347 ? std::next(GetPosInGraph(parentOut->GetOwningLayer()))
348 : GetPosInGraph(insertBefore.GetOwningLayer());
349 LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...);
telsoa014fcda012018-03-09 14:13:49 +0000350 insertBefore.Insert(*layer);
telsoa01c577f2c2018-08-31 09:22:23 +0100351
352 NotifyObservables(GraphEvent::LayerAdded, layer);
353
354 return layer;
355}
356
357template <typename LayerT, typename... Args>
358inline LayerT* Graph::InsertNewLayer(OutputSlot& insertAfter, Args&&... args)
359{
360 Layer& owningLayer = insertAfter.GetOwningLayer();
361
362 const Iterator pos = std::next(GetPosInGraph(owningLayer));
363 LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...);
364
365 BOOST_ASSERT(layer->GetNumInputSlots() == 1);
366
367 insertAfter.MoveAllConnections(layer->GetOutputSlot());
368 insertAfter.Connect(layer->GetInputSlot(0));
369
370 NotifyObservables(GraphEvent::LayerAdded, layer);
371
telsoa014fcda012018-03-09 14:13:49 +0000372 return layer;
373}
374
375inline Graph::Iterator Graph::EraseLayer(Iterator pos)
376{
telsoa01c577f2c2018-08-31 09:22:23 +0100377 NotifyObservables(GraphEvent::LayerErased, *pos);
378
telsoa014fcda012018-03-09 14:13:49 +0000379 delete *pos;
380 return m_Layers.erase(pos);
381}
382
383template <typename LayerT>
384inline Graph::Iterator Graph::EraseLayer(LayerT*& layer)
385{
386 BOOST_ASSERT(layer != nullptr);
387 Iterator next = EraseLayer(GetPosInGraph(*layer));
388 layer = nullptr;
389 return next;
390}
391
392} // namespace armnn