IVGCVSW-6147 ConstTensorsAsInput: Optimizer - FusePermuteIntoConstLayer

  * No trailing permute layer after a constant layer
  * Unit test for optimization

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I0d098f5af41d2c55df7cef1ccfb848093320ddc1
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 903f06c..52e60e0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -529,6 +529,7 @@
         src/armnn/test/OptimizerTests.cpp
         src/armnn/test/optimizations/AddBroadcastReshapeLayerTests.cpp
         src/armnn/test/optimizations/ConvertConstDequantisationLayersToConstLayersTest.cpp
+        src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp
         src/armnn/test/optimizations/ConvertConstantsBFloatTests.cpp
         src/armnn/test/optimizations/ConvertConstantsFloatToHalfTests.cpp
         src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 9da28ce..fecc766 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1158,6 +1158,7 @@
         if(selectedBackend == armnn::Compute::GpuAcc || selectedBackend == armnn::Compute::CpuAcc)
         {
             Optimizer::Pass(optGraph, MakeOptimizations(optimizations::PermuteDepthwiseConv2dWeights()));
+            Optimizer::Pass(optGraph, MakeOptimizations(optimizations::FusePermuteIntoConstLayer()));
         }
 
         // Select sub-graphs based on backend
@@ -1719,6 +1720,10 @@
         optGraph.InferTensorInfos();
     }
 
+    // Need to FusePermuteIntoConstantLayer before FoldPadIntoDepthwiseConvolution2d or
+    // FuseBatchNormIntoDepthwiseConvolution2D optimizations are called.
+    Optimizer::Pass(optGraph, MakeOptimizations(FusePermuteIntoConstLayer()));
+
     // Perform optimisation passes
     Optimizer::Pass(optGraph, MakeOptimizations(SquashEqualPermuteSiblings(),
                                                 SquashEqualTransposeSiblings(),
@@ -1739,8 +1744,7 @@
                                                 FuseBatchNormIntoConvolution2DFloat16(),
                                                 FuseBatchNormIntoDepthwiseConvolution2DFloat32(),
                                                 FuseBatchNormIntoDepthwiseConvolution2DFloat16(),
-                                                ConvertConstDequantisationLayersToConstLayers(),
-                                                RedirectMembersToConstantInputs()));
+                                                ConvertConstDequantisationLayersToConstLayers()));
 
     // If Fp32 to Fp16 optimization is set convert Fp32 network to Fp16
     if (options.m_ReduceFp32ToFp16)
diff --git a/src/armnn/optimizations/All.hpp b/src/armnn/optimizations/All.hpp
index e4a1f33..900e763 100644
--- a/src/armnn/optimizations/All.hpp
+++ b/src/armnn/optimizations/All.hpp
@@ -8,6 +8,7 @@
 #include "AddDebug.hpp"
 #include "ConvertConstants.hpp"
 #include "ConvertConstDequantisationLayersToConstLayers.hpp"
+#include "ConvertConstPermuteLayersToConstLayers.hpp"
 #include "ConvertFp32NetworkToBf16.hpp"
 #include "ConvertFp32NetworkToFp16.hpp"
 #include "FoldPadIntoLayer2d.hpp"
diff --git a/src/armnn/optimizations/ConvertConstPermuteLayersToConstLayers.hpp b/src/armnn/optimizations/ConvertConstPermuteLayersToConstLayers.hpp
new file mode 100644
index 0000000..2cc3e8e
--- /dev/null
+++ b/src/armnn/optimizations/ConvertConstPermuteLayersToConstLayers.hpp
@@ -0,0 +1,127 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "Optimization.hpp"
+#include <armnnUtils/Permute.hpp>
+#include <ResolveType.hpp>
+
+namespace armnn
+{
+namespace optimizations
+{
+
+class ConvertConstPermuteLayersToConstLayers
+{
+public:
+    void Run(Graph& graph, InputSlot& connection) const
+    {
+        Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
+        Layer& child = connection.GetOwningLayer();
+
+        ARMNN_ASSERT(base.GetType() == LayerType::Constant);
+        ARMNN_ASSERT(child.GetType() == LayerType::Permute);
+
+        if (base.GetDataType() == child.GetDataType())
+        {
+            switch (base.GetDataType())
+            {
+                case DataType::Float16:
+                    ReplaceConstPermuteLayer<DataType::Float16>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+                case DataType::Float32:
+                    ReplaceConstPermuteLayer<DataType::Float32>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+                case DataType::QAsymmU8:
+                    ReplaceConstPermuteLayer<DataType::QAsymmU8>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+                case DataType::Signed32:
+                    ReplaceConstPermuteLayer<DataType::Signed32>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+                case DataType::QSymmS16:
+                    ReplaceConstPermuteLayer<DataType::QSymmS16>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+                case DataType::QSymmS8:
+                    ReplaceConstPermuteLayer<DataType::QSymmS8>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+                case DataType::QAsymmS8:
+                    ReplaceConstPermuteLayer<DataType::QAsymmS8>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+                case DataType::BFloat16:
+                    ReplaceConstPermuteLayer<DataType::BFloat16>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+                case DataType::Signed64:
+                    ReplaceConstPermuteLayer<DataType::Signed64>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+                case DataType::Boolean:
+                    ReplaceConstPermuteLayer<DataType::Boolean>(graph,
+                                                                 PolymorphicDowncast<ConstantLayer*>(&base),
+                                                                 PolymorphicDowncast<PermuteLayer*>(&child));
+                    break;
+            }
+        }
+    }
+protected:
+    ConvertConstPermuteLayersToConstLayers()  = default;
+    ~ConvertConstPermuteLayersToConstLayers() = default;
+private:
+    template<armnn::DataType ArmnnType,
+             typename T = armnn::ResolveType<ArmnnType>>
+    static void ReplaceConstPermuteLayer(Graph& graph,
+                                         ConstantLayer* constantLayer,
+                                         PermuteLayer* permuteLayer)
+    {
+        IgnoreUnused(graph);
+        /**
+         * This optimisation is to find situations where a constant set of inputs is being provided to a Permute
+         * layer. In this case we don't want the overhead of Permuting the values on every inference, instead we
+         * want to Permute them once and store them in a Const layer to be used everytime as they will not change.
+         */
+        TensorInfo outputPermuteInfo = permuteLayer->GetOutputSlot(0).GetTensorInfo();
+        std::vector<T> newValues(outputPermuteInfo.GetNumElements());
+        armnnUtils::Permute(outputPermuteInfo.GetShape(), permuteLayer->GetPermutation(),
+                            constantLayer->m_LayerOutput->Map(true), newValues.data(),
+                            GetDataTypeSize(outputPermuteInfo.GetDataType()));
+
+        TensorInfo newInfo = outputPermuteInfo;
+        newInfo.SetConstant(true);
+        ConstTensor newInput(newInfo, newValues);
+        constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput));
+
+        // Moves connections in permute output to the constant layer.
+        // Permute layer will be removed if left unconnected.
+        permuteLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot());
+
+        // Updating the output tensor
+        constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo);
+        ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true);
+    }
+};
+
+using FusePermuteIntoConstLayer = OptimizeForConnection<ConstantLayer,
+                                                        PermuteLayer,
+                                                        ConvertConstPermuteLayersToConstLayers>;
+
+} // namespace optimizations
+} // namespace armnn
\ No newline at end of file
diff --git a/src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp b/src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp
new file mode 100644
index 0000000..1fcba0e
--- /dev/null
+++ b/src/armnn/test/optimizations/ConvertConstPermuteLayersToConstLayersTest.cpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "LayersFwd.hpp"
+#include <Network.hpp>
+#include <doctest/doctest.h>
+#include <Optimizer.hpp>
+#include <TestUtils.hpp>
+
+TEST_SUITE("Optimizer")
+{
+using namespace armnn;
+using namespace armnn::optimizations;
+
+TEST_CASE("ConvertConstPermuteToConst")
+{
+    Graph graph;
+    const unsigned int shape[]  = {1, 2, 2, 3};
+
+    const TensorInfo constTensorInfo(4, shape, DataType::Float32, 1.0, 0, true);
+
+    ConstantLayer* constant = graph.AddLayer<ConstantLayer>("constant");
+    std::vector<float> constantValues(constTensorInfo.GetNumElements(), 4.5f);
+    ConstTensor constTensor(constTensorInfo, constantValues.data());
+    constant->m_LayerOutput = std::make_shared<ScopedTensorHandle>(constTensor);
+    constant->GetOutputSlot().SetTensorInfo(constTensorInfo);
+
+    PermuteDescriptor desc({ 0, 2, 3, 1 });
+    PermuteLayer* permuteLayer = graph.AddLayer<PermuteLayer>(desc, "permute");
+    TensorInfo infoPermuted = armnnUtils::Permuted(constTensorInfo, { 0, 2, 3, 1 });
+    permuteLayer->GetOutputSlot().SetTensorInfo(infoPermuted);
+
+    OutputLayer* output = graph.AddLayer<OutputLayer>(0, "output");
+
+    // Connect up constant -> permute -> output
+    constant->GetOutputSlot().Connect(permuteLayer->GetInputSlot(0));
+    permuteLayer->GetOutputSlot().Connect(output->GetInputSlot(0));
+
+    CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+                        &IsLayerOfType<ConstantLayer>,
+                        &IsLayerOfType<PermuteLayer>,
+                        &IsLayerOfType<OutputLayer>));
+
+    armnn::Optimizer::Pass(graph, MakeOptimizations(FusePermuteIntoConstLayer()));
+
+    CHECK(CheckSequence(graph.cbegin(), graph.cend(),
+                        &IsLayerOfType<ConstantLayer>,
+                        &IsLayerOfType<OutputLayer>));
+
+    TensorShape tensorShape = constant->GetOutputSlot(0).GetTensorInfo().GetShape();
+    CHECK(tensorShape[0] == shape[0]);
+    CHECK(tensorShape[1] == shape[3]);
+    CHECK(tensorShape[2] == shape[1]);
+    CHECK(tensorShape[3] == shape[2]);
+
+}
+
+}
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 4eaf636..60bd962 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -1043,15 +1043,24 @@
     desc.m_BiasEnabled  = convDesc.m_BiasEnabled;
 
     armnn::IConnectableLayer* layer = m_Network->AddDepthwiseConvolution2dLayer(desc, node.name().c_str());
-    std::vector<std::string> tensorIndexes= {node.input(0), node.input(1)};
+    std::string permuteStr = "permute_" + node.input(1);
+    std::vector<std::string> tensorIndexes= {node.input(0), permuteStr};
 
-    // weights come in as [O,1,H,W] from ONNX and need to be converted to ArmNNs dephtwise weights layout [1,H,W,O]
-    armnn::PermutationVector perVec {3,0,1,2};
-    auto weightTensor = CreateConstTensor(node.input(1), perVec);
-
+    auto weightTensor = CreateConstTensor(node.input(1));
     IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(weightTensor.first);
+
+    // weights come in as [O,1,H,W] from ONNX and need to be converted to ArmNNs depthwise weights layout [1,H,W,O]
+    armnn::PermutationVector perVec {3, 0, 1, 2};
+    TensorInfo weightsPermuted = armnnUtils::Permuted(weightTensor.first.GetInfo(), perVec);
+
+    // Inserts NewLayer so layers don't need to be re-sorted.
+    IConnectableLayer* permuteLayer = m_Network->AddPermuteLayer(PermuteDescriptor(perVec),
+                                                                 "permute_layer");
+    permuteLayer->GetOutputSlot(0).SetTensorInfo(weightsPermuted);
+    permuteLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+
     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightTensor.first.GetInfo());
-    weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+    weightsLayer->GetOutputSlot(0).Connect(permuteLayer->GetInputSlot(0u));
 
     if (node.input_size() == 3)
     {
@@ -1076,7 +1085,7 @@
 
     auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                         { m_TensorsInfo[node.input(0)].m_info->GetShape(),
-                                          weightTensor.first.GetInfo().GetShape() });
+                                          weightsPermuted.GetShape() });
 
     layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);