blob: 85b9f2803c56835c6be9125ab9d8a1945b47354e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5#include "Optimizer.hpp"
6#include "optimizations/All.hpp"
7
8namespace armnn
9{
10
11const Optimizer& Optimizer::Get()
12{
13 // Add optimizations here
14 static optimizations::SquashEqualPermuteSiblings squashEqualPermuteSiblings;
15 static optimizations::SquashEqualReshapeSiblings squashEqualReshapeSiblings;
16 static optimizations::OptimizeInversePermutes optimizeInversePermutes;
17 static optimizations::MovePermuteUp movePermuteUp;
18 static optimizations::PermuteAsReshape permuteAsReshape;
19 static optimizations::OptimizeConsecutiveReshapes optimizeConsecutiveReshapes;
20
21 // Set optimizations in desired order
22 static const Optimizer optimizer({
23 &squashEqualPermuteSiblings,
24 &squashEqualReshapeSiblings,
25 &optimizeInversePermutes,
26 &movePermuteUp,
27 &permuteAsReshape,
28 &optimizeConsecutiveReshapes,
29 });
30
31 return optimizer;
32}
33
34void Optimizer::Optimize(Graph& graph) const
35{
36 auto it = graph.TopologicalSort().end();
37 // Call TopologicalSort() in every iteration to re-order the list in case layers where added/removed.
38 while (it != graph.TopologicalSort().begin())
39 {
40 --it;
41 for (auto&& optimization : m_Optimizations)
42 {
43 optimization->Run(graph, it);
44
45 if ((*it)->IsOutputUnconnected())
46 {
47 it = graph.EraseLayer(it);
48 break;
49 }
50 }
51 }
52}
53
54
55} // namespace armnn