blob: c9930a65243ef7a824bdde852d2cb2ce22e80305 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Mike Kelly490b7be2020-03-03 12:39:09 +00002// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
#include "Optimization.hpp"

#include <stdexcept>
#include <string>
#include <vector>
8
9namespace armnn
10{
11namespace optimizations
12{
13
14class TransposeAsReshapeImpl
15{
16public:
17 /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
18 void Run(Graph& graph, TransposeLayer& transpose) const
19 {
20 if (IsReshape(transpose))
21 {
22 const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo();
23
24 const std::string name = std::string("as_reshape-") + transpose.GetName();
25 const ReshapeDescriptor descriptor{outInfo.GetShape()};
26 // Inserts NewLayer so layers don't need to be re-sorted.
27 auto reshape = graph.InsertNewLayer<ReshapeLayer>(transpose.GetInputSlot(0), descriptor, name.c_str());
Mike Kelly490b7be2020-03-03 12:39:09 +000028
29 // Bypass transpose. It will be deleted since it's left unconnected.
30 transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
31 }
32 }
33
34protected:
35 TransposeAsReshapeImpl() = default;
36 ~TransposeAsReshapeImpl() = default;
37
38private:
39 static bool IsReshape(const TransposeLayer& layer)
40 {
41 const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
42 const PermutationVector& permutation = layer.GetPermutation();
43
44 const unsigned int numDimensions = permutation.GetSize();
45 std::map<unsigned int, unsigned int> permuteMappings;
46 for (unsigned int i = 0; i < permutation.GetSize(); ++i)
47 {
48 permuteMappings[permutation[i]] = i;
49 }
50
51 std::vector<unsigned int> permuteVector;
52 for (unsigned int i = 0; i < permutation.GetSize(); ++i)
53 {
54 permuteVector.push_back(permuteMappings.at(i));
55 }
56
57 unsigned int lastGtOne = 0;
58 while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U))
59 {
60 ++lastGtOne;
61 }
62
63 bool isReshape = true;
64 for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
65 {
66 if (outShape[permuteVector[i]] > 1U)
67 {
68 isReshape = permuteVector[lastGtOne] < permuteVector[i];
69 lastGtOne = i;
70 }
71 }
72
73 return isReshape;
74 }
75};
76
77using TransposeAsReshape = OptimizeForType<TransposeLayer, TransposeAsReshapeImpl>;
78
79} // namespace optimizations
80} // namespace armnn