blob: fcc1fec34e7435c71bb8c1cb8805639960724972 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "Optimization.hpp"
8
9namespace armnn
10{
11namespace optimizations
12{
13
14template <typename Comparable>
15class SquashEqualSiblingsImpl
16{
17public:
18 /// Run for every connection between a base Layer (any) and a child ComparableLayer.
19 /// For all siblings of the child layer that compare equal to it, bypasses and removes
20 /// them. I.e., moves the connections in the outputs of the siblings to the outputs of
21 /// the child layer, so the siblings are left unconnected (and later removed).
22 void Run(Graph& graph, InputSlot& connection) const
23 {
24 auto& child = connection.GetOwningLayer();
25
26 if (!child.IsOutputUnconnected())
27 {
28 OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
telsoa014fcda012018-03-09 14:13:49 +000029
surmeh01bceff2f2018-03-29 16:29:27 +010030 if (baseOutput.GetNumConnections() > 1)
telsoa014fcda012018-03-09 14:13:49 +000031 {
surmeh01bceff2f2018-03-29 16:29:27 +010032 auto& comparableChild = *boost::polymorphic_downcast<Comparable*>(&child);
33
34 Layer* lowestPriorityChild = &child;
35 for (auto&& it : baseOutput.GetConnections())
telsoa014fcda012018-03-09 14:13:49 +000036 {
surmeh01bceff2f2018-03-29 16:29:27 +010037 Layer* sibling = &it->GetOwningLayer();
38 if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling))
telsoa014fcda012018-03-09 14:13:49 +000039 {
surmeh01bceff2f2018-03-29 16:29:27 +010040 if (sibling->GetPriority() < lowestPriorityChild->GetPriority())
41 {
42 std::swap(sibling, lowestPriorityChild);
43 }
telsoa01c577f2c2018-08-31 09:22:23 +010044 // Bypasses sibling. It will be removed as it's left unconnected.
surmeh01bceff2f2018-03-29 16:29:27 +010045 auto siblingOut = sibling->BeginOutputSlots();
46 for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots();
47 lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut)
48 {
49 siblingOut->MoveAllConnections(*lowestPriorityChildOut);
50 ++siblingOut;
51 }
telsoa014fcda012018-03-09 14:13:49 +000052 }
53 }
54 }
55 }
56 }
57
58protected:
59 SquashEqualSiblingsImpl() = default;
60 ~SquashEqualSiblingsImpl() = default;
61};
62
63using SquashEqualPermuteSiblings = OptimizeForConnection<Layer, PermuteLayer, SquashEqualSiblingsImpl<PermuteLayer>>;
64using SquashEqualReshapeSiblings = OptimizeForConnection<Layer, ReshapeLayer, SquashEqualSiblingsImpl<ReshapeLayer>>;
65
66} // namespace optimizations
67} // namespace armnn