blob: 1d6a52efed44759f098e458722d632153e405bab [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "Optimizer.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "Observable.hpp"
telsoa014fcda012018-03-09 14:13:49 +00007#include "optimizations/All.hpp"
8
9namespace armnn
10{
11
surmeh01bceff2f2018-03-29 16:29:27 +010012Optimizer::Optimizer()
telsoa014fcda012018-03-09 14:13:49 +000013{
telsoa014fcda012018-03-09 14:13:49 +000014}
15
telsoa01c577f2c2018-08-31 09:22:23 +010016void Optimizer::Pass(Graph& graph, const Optimizations& optimizations)
telsoa014fcda012018-03-09 14:13:49 +000017{
Derek Lambertif1e0ad32021-10-13 18:02:25 +010018 ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Optimizer_Pass");
telsoa01c577f2c2018-08-31 09:22:23 +010019 // Create observables to observe changes to the graph
20 AddedLayerObservable addedLayerObservable(graph);
21 ErasedLayerNamesObservable erasedLayerNamesObservable(graph);
22
23 bool graphNeedsSorting = false;
telsoa014fcda012018-03-09 14:13:49 +000024 auto it = graph.TopologicalSort().end();
telsoa01c577f2c2018-08-31 09:22:23 +010025
26 // Calls TopologicalSort() for every iteration to re-order the list in case layers were added/removed.
telsoa014fcda012018-03-09 14:13:49 +000027 while (it != graph.TopologicalSort().begin())
28 {
29 --it;
telsoa01c577f2c2018-08-31 09:22:23 +010030 for (auto&& optimization : optimizations)
telsoa014fcda012018-03-09 14:13:49 +000031 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010032 ARMNN_ASSERT(*it);
surmeh01bceff2f2018-03-29 16:29:27 +010033 optimization->Run(graph, **it);
telsoa014fcda012018-03-09 14:13:49 +000034
35 if ((*it)->IsOutputUnconnected())
36 {
Matteo Martincighf3d10212019-05-09 19:06:22 +010037 auto next = std::next(graph.GetPosInGraph(**it));
38 graph.EraseLayer(it);
39 it = next;
telsoa01c577f2c2018-08-31 09:22:23 +010040 graphNeedsSorting = true;
41 }
42
43 // Add the names of erased layers as related layers to the new added layers
44 for (auto& erasedLayerName : erasedLayerNamesObservable)
45 {
46 for (auto& addedLayer : addedLayerObservable)
47 {
48 addedLayer->AddRelatedLayerName(erasedLayerName);
49 }
50 }
51
52 erasedLayerNamesObservable.Clear();
53 addedLayerObservable.Clear();
54
55 if (graphNeedsSorting)
56 {
57 graphNeedsSorting = false;
telsoa014fcda012018-03-09 14:13:49 +000058 break;
59 }
60 }
61 }
62}
63
telsoa014fcda012018-03-09 14:13:49 +000064} // namespace armnn