| // |
| // Copyright © 2017 Arm Ltd. All rights reserved. |
| // See LICENSE file in the project root for full license information. |
| // |
| #pragma once |
| |
| #include "Optimization.hpp" |
| #include "Permute.hpp" |
| |
| namespace armnn |
| { |
| namespace optimizations |
| { |
| class MovePermuteUpImpl |
| { |
| public: |
| /// Run for every connection between a base Layer (any) and a child PermuteLayer. If the type |
| /// of the base layer allows it, it moves the permutation to the inputs of the base layer. |
| /// I.e., adds equivalent permutations before the inputs of the base layer and moves the |
| /// connections in the output of the child permute layer to the output of the base layer. |
| void Run(Graph& graph, InputSlot& connection) const |
| { |
| OutputSlot& baseOutput = *connection.GetConnectedOutputSlot(); |
| |
| if (baseOutput.GetNumConnections() == 1U) |
| { |
| Layer& base = baseOutput.GetOwningLayer(); |
| |
| if (CanMovePermuteToInputs(base)) |
| { |
| auto permute = boost::polymorphic_downcast<PermuteLayer*>(&connection.GetOwningLayer()); |
| const PermutationVector& perm = permute->GetPermutation(); |
| |
| // Inserts an equivalent permute before every input of the base layer. |
| for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput) |
| { |
| // Inserts a new permute layer. |
| const std::string name = std::string("moved_up-") + permute->GetName(); |
| PermuteLayer& permLayer = *graph.InsertNewLayer<PermuteLayer>(*baseInput, perm, name.c_str()); |
| |
| // Sets output tensor info for the new layer. |
| OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot(); |
| const TensorInfo permOutInfo = armnnUtils::Permuted(parentOutput.GetTensorInfo(), perm); |
| permLayer.GetOutputHandler().SetTensorInfo(permOutInfo); |
| } |
| |
| // Sets permuted output tensor info |
| const TensorInfo& childOutInfo = permute->GetOutputHandler().GetTensorInfo(); |
| base.GetOutputHandler().SetTensorInfo(childOutInfo); |
| |
| // Bypasses permute. It will be removed as it's left unconnected. |
| permute->GetOutputSlot().MoveAllConnections(base.GetOutputSlot()); |
| } |
| } |
| } |
| |
| protected: |
| MovePermuteUpImpl() = default; |
| ~MovePermuteUpImpl() = default; |
| |
| private: |
| static bool CanMovePermuteToInputs(const Layer& base) |
| { |
| switch (base.GetType()) |
| { |
| case LayerType::Activation: |
| case LayerType::Addition: |
| case LayerType::FakeQuantization: |
| case LayerType::Floor: |
| case LayerType::MemCopy: |
| case LayerType::Multiplication: |
| return true; |
| default: |
| return false; |
| } |
| } |
| }; |
| |
// Optimization entry point. It applies MovePermuteUpImpl::Run to every connection from a
// Layer (of any type) into a PermuteLayer.
using MovePermuteUp = OptimizeForConnection<Layer, PermuteLayer, MovePermuteUpImpl>;
| |
| } // namespace optimizations |
| } // namespace armnn |