//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "Optimizer.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "Observable.hpp"
telsoa014fcda012018-03-09 14:13:49 +00007#include "optimizations/All.hpp"
8
9namespace armnn
10{
11
surmeh01bceff2f2018-03-29 16:29:27 +010012Optimizer::Optimizer()
telsoa014fcda012018-03-09 14:13:49 +000013{
telsoa014fcda012018-03-09 14:13:49 +000014}
15
telsoa01c577f2c2018-08-31 09:22:23 +010016void Optimizer::Pass(Graph& graph, const Optimizations& optimizations)
telsoa014fcda012018-03-09 14:13:49 +000017{
telsoa01c577f2c2018-08-31 09:22:23 +010018 // Create observables to observe changes to the graph
19 AddedLayerObservable addedLayerObservable(graph);
20 ErasedLayerNamesObservable erasedLayerNamesObservable(graph);
21
22 bool graphNeedsSorting = false;
telsoa014fcda012018-03-09 14:13:49 +000023 auto it = graph.TopologicalSort().end();
telsoa01c577f2c2018-08-31 09:22:23 +010024
25 // Calls TopologicalSort() for every iteration to re-order the list in case layers were added/removed.
telsoa014fcda012018-03-09 14:13:49 +000026 while (it != graph.TopologicalSort().begin())
27 {
28 --it;
telsoa01c577f2c2018-08-31 09:22:23 +010029 for (auto&& optimization : optimizations)
telsoa014fcda012018-03-09 14:13:49 +000030 {
surmeh01bceff2f2018-03-29 16:29:27 +010031 optimization->Run(graph, **it);
telsoa014fcda012018-03-09 14:13:49 +000032
33 if ((*it)->IsOutputUnconnected())
34 {
35 it = graph.EraseLayer(it);
telsoa01c577f2c2018-08-31 09:22:23 +010036 graphNeedsSorting = true;
37 }
38
39 // Add the names of erased layers as related layers to the new added layers
40 for (auto& erasedLayerName : erasedLayerNamesObservable)
41 {
42 for (auto& addedLayer : addedLayerObservable)
43 {
44 addedLayer->AddRelatedLayerName(erasedLayerName);
45 }
46 }
47
48 erasedLayerNamesObservable.Clear();
49 addedLayerObservable.Clear();
50
51 if (graphNeedsSorting)
52 {
53 graphNeedsSorting = false;
telsoa014fcda012018-03-09 14:13:49 +000054 break;
55 }
56 }
57 }
58}
59
telsoa014fcda012018-03-09 14:13:49 +000060} // namespace armnn