Laurent Carlier | 749294b | 2020-06-01 09:03:17 +0100 | [diff] [blame] | 1 | // |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 4 | // |
| 5 | #include "Optimizer.hpp" |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 6 | #include "Observable.hpp" |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 7 | #include "optimizations/All.hpp" |
| 8 | |
| 9 | namespace armnn |
| 10 | { |
| 11 | |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 12 | Optimizer::Optimizer() |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 13 | { |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 14 | } |
| 15 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 16 | void Optimizer::Pass(Graph& graph, const Optimizations& optimizations) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 17 | { |
Derek Lamberti | f1e0ad3 | 2021-10-13 18:02:25 +0100 | [diff] [blame] | 18 | ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Optimizer_Pass"); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 19 | // Create observables to observe changes to the graph |
| 20 | AddedLayerObservable addedLayerObservable(graph); |
| 21 | ErasedLayerNamesObservable erasedLayerNamesObservable(graph); |
| 22 | |
| 23 | bool graphNeedsSorting = false; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 24 | auto it = graph.TopologicalSort().end(); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 25 | |
| 26 | // Calls TopologicalSort() for every iteration to re-order the list in case layers were added/removed. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 27 | while (it != graph.TopologicalSort().begin()) |
| 28 | { |
| 29 | --it; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 30 | for (auto&& optimization : optimizations) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 31 | { |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame] | 32 | ARMNN_ASSERT(*it); |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 33 | optimization->Run(graph, **it); |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 34 | |
| 35 | if ((*it)->IsOutputUnconnected()) |
| 36 | { |
Matteo Martincigh | f3d1021 | 2019-05-09 19:06:22 +0100 | [diff] [blame] | 37 | auto next = std::next(graph.GetPosInGraph(**it)); |
| 38 | graph.EraseLayer(it); |
| 39 | it = next; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 40 | graphNeedsSorting = true; |
| 41 | } |
| 42 | |
| 43 | // Add the names of erased layers as related layers to the new added layers |
| 44 | for (auto& erasedLayerName : erasedLayerNamesObservable) |
| 45 | { |
| 46 | for (auto& addedLayer : addedLayerObservable) |
| 47 | { |
| 48 | addedLayer->AddRelatedLayerName(erasedLayerName); |
| 49 | } |
| 50 | } |
| 51 | |
| 52 | erasedLayerNamesObservable.Clear(); |
| 53 | addedLayerObservable.Clear(); |
| 54 | |
| 55 | if (graphNeedsSorting) |
| 56 | { |
| 57 | graphNeedsSorting = false; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 58 | break; |
| 59 | } |
| 60 | } |
| 61 | } |
| 62 | } |
| 63 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 64 | } // namespace armnn |