blob: 935186d32efc7d390b5497b4d47d3bacbbf6abcd [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#pragma once
6
7#include "Optimization.hpp"
8
9namespace armnn
10{
11namespace optimizations
12{
13
14class OptimizeConsecutiveReshapesImpl
15{
16public:
17 /// Run for every connection between a base RashapeLayer and a child ReshapeLayer.
18 /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
19 void Run(Graph& graph, InputSlot& connection) const
20 {
surmeh01bceff2f2018-03-29 16:29:27 +010021 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22 Layer& child = connection.GetOwningLayer();
telsoa014fcda012018-03-09 14:13:49 +000023
24 BOOST_ASSERT(base.GetType() == LayerType::Reshape);
25 BOOST_ASSERT(child.GetType() == LayerType::Reshape);
26
27 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
28
29 const TensorInfo& inInfo = parentOut->GetTensorInfo();
30 const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
31
32 if (inInfo.GetShape() != outInfo.GetShape())
33 {
telsoa01c577f2c2018-08-31 09:22:23 +010034 // Inserts equivalent reshape before base layer.
telsoa014fcda012018-03-09 14:13:49 +000035 const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
36 const ReshapeDescriptor descriptor{outInfo.GetShape()};
37 auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
telsoa01c577f2c2018-08-31 09:22:23 +010038 // Sets tensor info for new layer.
telsoa014fcda012018-03-09 14:13:49 +000039 newReshape.GetOutputHandler().SetTensorInfo(outInfo);
telsoa01c577f2c2018-08-31 09:22:23 +010040 // Reconnects base with original parent.
telsoa014fcda012018-03-09 14:13:49 +000041 newReshape.GetOutputSlot().MoveAllConnections(*parentOut);
telsoa01c577f2c2018-08-31 09:22:23 +010042 // Parent is now the new layer.
telsoa014fcda012018-03-09 14:13:49 +000043 parentOut = &newReshape.GetOutputSlot();
44 }
45
telsoa01c577f2c2018-08-31 09:22:23 +010046 // Moves connections in child output to parent layer.
telsoa014fcda012018-03-09 14:13:49 +000047 // Child layer will be removed as it's left unconnected.
48 // Base layer will be removed if left unconnected.
49 child.GetOutputSlot().MoveAllConnections(*parentOut);
50 }
51
52protected:
53 OptimizeConsecutiveReshapesImpl() = default;
54 ~OptimizeConsecutiveReshapesImpl() = default;
55};
56
/// Optimization type that plugs OptimizeConsecutiveReshapesImpl into OptimizeForConnection
/// with ReshapeLayer as both the base and the child layer type.
/// NOTE(review): OptimizeForConnection is declared outside this file (presumably in
/// Optimization.hpp) - confirm how it selects connections and invokes Run().
using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>;
58
59} // namespace optimizations
60} // namespace armnn