telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 4 | // |
| 5 | #pragma once |
| 6 | |
| 7 | #include "Optimization.hpp" |
| 8 | |
| 9 | namespace armnn |
| 10 | { |
| 11 | namespace optimizations |
| 12 | { |
| 13 | |
| 14 | template <typename Comparable> |
| 15 | class SquashEqualSiblingsImpl |
| 16 | { |
| 17 | public: |
| 18 | /// Run for every connection between a base Layer (any) and a child ComparableLayer. |
| 19 | /// For all siblings of the child layer that compare equal to it, bypasses and removes |
| 20 | /// them. I.e., moves the connections in the outputs of the siblings to the outputs of |
| 21 | /// the child layer, so the siblings are left unconnected (and later removed). |
| 22 | void Run(Graph& graph, InputSlot& connection) const |
| 23 | { |
| 24 | auto& child = connection.GetOwningLayer(); |
| 25 | |
| 26 | if (!child.IsOutputUnconnected()) |
| 27 | { |
| 28 | OutputSlot& baseOutput = *connection.GetConnectedOutputSlot(); |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 29 | |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 30 | if (baseOutput.GetNumConnections() > 1) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 31 | { |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 32 | auto& comparableChild = *boost::polymorphic_downcast<Comparable*>(&child); |
| 33 | |
| 34 | Layer* lowestPriorityChild = &child; |
| 35 | for (auto&& it : baseOutput.GetConnections()) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 36 | { |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 37 | Layer* sibling = &it->GetOwningLayer(); |
| 38 | if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling)) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 39 | { |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 40 | if (sibling->GetPriority() < lowestPriorityChild->GetPriority()) |
| 41 | { |
| 42 | std::swap(sibling, lowestPriorityChild); |
| 43 | } |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 44 | // Bypasses sibling. It will be removed as it's left unconnected. |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 45 | auto siblingOut = sibling->BeginOutputSlots(); |
| 46 | for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots(); |
| 47 | lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut) |
| 48 | { |
| 49 | siblingOut->MoveAllConnections(*lowestPriorityChildOut); |
| 50 | ++siblingOut; |
| 51 | } |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 52 | } |
| 53 | } |
| 54 | } |
| 55 | } |
| 56 | } |
| 57 | |
| 58 | protected: |
| 59 | SquashEqualSiblingsImpl() = default; |
| 60 | ~SquashEqualSiblingsImpl() = default; |
| 61 | }; |
| 62 | |
| 63 | using SquashEqualPermuteSiblings = OptimizeForConnection<Layer, PermuteLayer, SquashEqualSiblingsImpl<PermuteLayer>>; |
| 64 | using SquashEqualReshapeSiblings = OptimizeForConnection<Layer, ReshapeLayer, SquashEqualSiblingsImpl<ReshapeLayer>>; |
| 65 | |
| 66 | } // namespace optimizations |
| 67 | } // namespace armnn |