blob: 9a926a57a4efbc5b8fe91e5aa9d09e08f43b7209 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#pragma once
6
7#include "Optimization.hpp"
8
9namespace armnn
10{
11namespace optimizations
12{
13
14class OptimizeConsecutiveReshapesImpl
15{
16public:
17 /// Run for every connection between a base RashapeLayer and a child ReshapeLayer.
18 /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
19 void Run(Graph& graph, InputSlot& connection) const
20 {
surmeh01bceff2f2018-03-29 16:29:27 +010021 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22 Layer& child = connection.GetOwningLayer();
telsoa014fcda012018-03-09 14:13:49 +000023
24 BOOST_ASSERT(base.GetType() == LayerType::Reshape);
25 BOOST_ASSERT(child.GetType() == LayerType::Reshape);
26
27 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
28
29 const TensorInfo& inInfo = parentOut->GetTensorInfo();
30 const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
31
32 if (inInfo.GetShape() != outInfo.GetShape())
33 {
34 // Insert equivalent reshape before base layer
35 const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
36 const ReshapeDescriptor descriptor{outInfo.GetShape()};
37 auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
38 // Set tensor info for new layer
39 newReshape.GetOutputHandler().SetTensorInfo(outInfo);
40 // Reconnect base with original parent
41 newReshape.GetOutputSlot().MoveAllConnections(*parentOut);
42 // Parent is now the new layer
43 parentOut = &newReshape.GetOutputSlot();
44 }
45
46 // Move connections in child output to parent layer.
47 // Child layer will be removed as it's left unconnected.
48 // Base layer will be removed if left unconnected.
49 child.GetOutputSlot().MoveAllConnections(*parentOut);
50 }
51
52protected:
53 OptimizeConsecutiveReshapesImpl() = default;
54 ~OptimizeConsecutiveReshapesImpl() = default;
55};
56
/// Optimization type that wraps OptimizeConsecutiveReshapesImpl in OptimizeForConnection,
/// passing ReshapeLayer as both layer type arguments.
/// NOTE(review): OptimizeForConnection is declared outside this file (presumably in Optimization.hpp);
/// it is assumed to invoke Run() for each connection between layers of those two types - confirm there.
using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>;
58
59} // namespace optimizations
60} // namespace armnn