telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
| 3 | // See LICENSE file in the project root for full license information. |
| 4 | // |
| 5 | #pragma once |
| 6 | |
| 7 | #include "Optimization.hpp" |
| 8 | |
| 9 | namespace armnn |
| 10 | { |
| 11 | namespace optimizations |
| 12 | { |
| 13 | |
| 14 | class OptimizeConsecutiveReshapesImpl |
| 15 | { |
| 16 | public: |
| 17 | /// Run for every connection between a base RashapeLayer and a child ReshapeLayer. |
| 18 | /// Inserts an equivalent ReshapeLayer that bypasses both for that connection. |
| 19 | void Run(Graph& graph, InputSlot& connection) const |
| 20 | { |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 21 | Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); |
| 22 | Layer& child = connection.GetOwningLayer(); |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 23 | |
| 24 | BOOST_ASSERT(base.GetType() == LayerType::Reshape); |
| 25 | BOOST_ASSERT(child.GetType() == LayerType::Reshape); |
| 26 | |
| 27 | OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot(); |
| 28 | |
| 29 | const TensorInfo& inInfo = parentOut->GetTensorInfo(); |
| 30 | const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo(); |
| 31 | |
| 32 | if (inInfo.GetShape() != outInfo.GetShape()) |
| 33 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame^] | 34 | // Inserts equivalent reshape before base layer. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 35 | const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName(); |
| 36 | const ReshapeDescriptor descriptor{outInfo.GetShape()}; |
| 37 | auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame^] | 38 | // Sets tensor info for new layer. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 39 | newReshape.GetOutputHandler().SetTensorInfo(outInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame^] | 40 | // Reconnects base with original parent. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 41 | newReshape.GetOutputSlot().MoveAllConnections(*parentOut); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame^] | 42 | // Parent is now the new layer. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 43 | parentOut = &newReshape.GetOutputSlot(); |
| 44 | } |
| 45 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame^] | 46 | // Moves connections in child output to parent layer. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 47 | // Child layer will be removed as it's left unconnected. |
| 48 | // Base layer will be removed if left unconnected. |
| 49 | child.GetOutputSlot().MoveAllConnections(*parentOut); |
| 50 | } |
| 51 | |
| 52 | protected: |
| 53 | OptimizeConsecutiveReshapesImpl() = default; |
| 54 | ~OptimizeConsecutiveReshapesImpl() = default; |
| 55 | }; |
| 56 | |
| 57 | using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>; |
| 58 | |
| 59 | } // namespace optimizations |
| 60 | } // namespace armnn |