IVGCVSW-3337 Add Neon backend support for LSTM layer normalisation

* Update neon lstm workload
* Add unit tests
* Add isLstmSupported

Change-Id: I493c159137f6544b0f2532d16d4fafd7a7e587e5
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index 4fee53f..ea875f6 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -26,6 +26,7 @@
 #include "workloads/NeonDequantizeWorkload.hpp"
 #include "workloads/NeonGreaterWorkload.hpp"
 #include "workloads/NeonL2NormalizationFloatWorkload.hpp"
+#include "workloads/NeonLstmFloatWorkload.hpp"
 #include "workloads/NeonMaximumWorkload.hpp"
 #include "workloads/NeonMeanWorkload.hpp"
 #include "workloads/NeonConcatWorkload.hpp"
@@ -334,6 +335,30 @@
     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
 }
 
+bool NeonLayerSupport::IsLstmSupported(const TensorInfo& input,
+                                       const TensorInfo& outputStateIn,
+                                       const TensorInfo& cellStateIn,
+                                       const TensorInfo& scratchBuffer,
+                                       const TensorInfo& outputStateOut,
+                                       const TensorInfo& cellStateOut,
+                                       const TensorInfo& output,
+                                       const LstmDescriptor& descriptor,
+                                       const LstmInputParamsInfo& paramsInfo,
+                                       Optional<std::string&> reasonIfUnsupported) const
+{
+    FORWARD_WORKLOAD_VALIDATE_FUNC(NeonLstmFloatWorkloadValidate,
+                                   reasonIfUnsupported,
+                                   input,
+                                   outputStateIn,
+                                   cellStateIn,
+                                   scratchBuffer,
+                                   outputStateOut,
+                                   cellStateOut,
+                                   output,
+                                   descriptor,
+                                   paramsInfo);
+}
+
 bool NeonLayerSupport::IsMaximumSupported(const TensorInfo& input0,
                                           const TensorInfo& input1,
                                           const TensorInfo& output,
diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp
index 315248c..318cad7 100644
--- a/src/backends/neon/NeonLayerSupport.hpp
+++ b/src/backends/neon/NeonLayerSupport.hpp
@@ -96,6 +96,17 @@
                                     const L2NormalizationDescriptor& descriptor,
                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsLstmSupported(const TensorInfo& input,
+                         const TensorInfo& outputStateIn,
+                         const TensorInfo& cellStateIn,
+                         const TensorInfo& scratchBuffer,
+                         const TensorInfo& outputStateOut,
+                         const TensorInfo& cellStateOut,
+                         const TensorInfo& output,
+                         const LstmDescriptor& descriptor,
+                         const LstmInputParamsInfo& paramsInfo,
+                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     bool IsMaximumSupported(const TensorInfo& input0,
                             const TensorInfo& input1,
                             const TensorInfo& output,
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
index 4968d0e..49c5a72 100644
--- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp
+++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
@@ -710,6 +710,29 @@
     NeonCreateL2NormalizationWorkloadTest<NeonL2NormalizationFloatWorkload, DataType::Float32>(DataLayout::NHWC);
 }
 
+template <typename LstmWorkloadType>
+static void NeonCreateLstmWorkloadTest()
+{
+    Graph graph;
+    NeonWorkloadFactory factory =
+            NeonWorkloadFactoryHelper::GetFactory(NeonWorkloadFactoryHelper::GetMemoryManager());
+
+    auto workload = CreateLstmWorkloadTest<LstmWorkloadType>(factory, graph);
+
+    LstmQueueDescriptor queueDescriptor = workload->GetData();
+
+    auto inputHandle  = boost::polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Inputs[0]);
+    auto outputHandle = boost::polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Outputs[1]);
+
+    BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({ 2, 2 }, DataType::Float32)));
+    BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({ 2, 4 }, DataType::Float32)));
+}
+
+BOOST_AUTO_TEST_CASE(CreateLSTMWorkloadFloatWorkload)
+{
+    NeonCreateLstmWorkloadTest<NeonLstmFloatWorkload>();
+}
+
 template <typename ConcatWorkloadType, armnn::DataType DataType>
 static void NeonCreateConcatWorkloadTest(std::initializer_list<unsigned int> outputShape,
                                          unsigned int concatAxis)
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index 51fd219..049680a 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -469,6 +469,8 @@
                      LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest)
 ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection,
                      LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest)
+ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm,
+                     LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest)
 
 // Mean
 ARMNN_AUTO_TEST_CASE(MeanSimpleFloat32, MeanSimpleTest<armnn::DataType::Float32>)
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
index c7f5f09..6dd9f4f 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
@@ -97,6 +97,30 @@
         lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
     }
 
+    if (m_Data.m_Parameters.m_LayerNormEnabled)
+    {
+        m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+        if (!m_Data.m_Parameters.m_CifgEnabled)
+        {
+            BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
+        }
+
+        m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+        BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
+
+        m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+        BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
+
+        m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+        BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
+
+        lstm_param.set_layer_normalization_params(m_Data.m_Parameters.m_CifgEnabled ?
+                                                  nullptr : m_InputLayerNormWeightsTensor.get(),
+                                                  m_ForgetLayerNormWeightsTensor.get(),
+                                                  m_CellLayerNormWeightsTensor.get(),
+                                                  m_OutputLayerNormWeightsTensor.get());
+    }
+
     const arm_compute::ITensor& input           = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     const arm_compute::ITensor& output_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
     const arm_compute::ITensor& cell_state_in   = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
@@ -113,13 +137,13 @@
     m_ScratchBuffer = std::make_unique<arm_compute::Tensor>();
     if (m_Data.m_Parameters.m_CifgEnabled)
     {
-        // 2D tensor with dimensions [num_units * 4, batch_size] with CIFG
+        // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG
         armnn::TensorInfo scratchBuffer1({ batch_size, num_units * 3 }, DataType::Float32);
         BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer1);
     }
     else
     {
-        // scratch_buffer [num_units * 3, batch_size] without CIFG
+        // scratch_buffer [num_units * 4, batch_size] without CIFG
         armnn::TensorInfo scratchBuffer2({ batch_size, num_units * 4 }, DataType::Float32);
         BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer2);
     }
@@ -222,6 +246,17 @@
                                        m_Data.m_CellToOutputWeights);
     }
 
+    if (m_Data.m_Parameters.m_LayerNormEnabled)
+    {
+        if (!m_Data.m_Parameters.m_CifgEnabled)
+        {
+            InitializeArmComputeTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
+        }
+        InitializeArmComputeTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
+        InitializeArmComputeTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
+        InitializeArmComputeTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
+    }
+
     // Force Compute Library to perform the necessary copying and reshaping, after which
     // delete all the input tensors that will no longer be needed
     m_LstmLayer.prepare();
@@ -241,27 +276,11 @@
                                                   const TensorInfo& cellStateOut,
                                                   const TensorInfo& output,
                                                   const LstmDescriptor& descriptor,
-                                                  const TensorInfo& inputToForgetWeights,
-                                                  const TensorInfo& inputToCellWeights,
-                                                  const TensorInfo& inputToOutputWeights,
-                                                  const TensorInfo& recurrentToForgetWeights,
-                                                  const TensorInfo& recurrentToCellWeights,
-                                                  const TensorInfo& recurrentToOutputWeights,
-                                                  const TensorInfo& forgetGateBias,
-                                                  const TensorInfo& cellBias,
-                                                  const TensorInfo& outputGateBias,
-                                                  const TensorInfo* inputToInputWeights,
-                                                  const TensorInfo* recurrentToInputWeights,
-                                                  const TensorInfo* cellToInputWeights,
-                                                  const TensorInfo* inputGateBias,
-                                                  const TensorInfo* projectionWeights,
-                                                  const TensorInfo* projectionBias,
-                                                  const TensorInfo* cellToForgetWeights,
-                                                  const TensorInfo* cellToOutputWeights)
+                                                  const LstmInputParamsInfo& paramsInfo)
 {
     arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
 
-    // The inputs and the outputs
+    // The inputs and outputs
     const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
     const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
     const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
@@ -271,18 +290,24 @@
     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
 
     // Basic parameters
-    const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
-    const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
-    const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
+    const arm_compute::TensorInfo aclInputToForgetWeightsInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+    const arm_compute::TensorInfo aclInputToCellWeightsInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+    const arm_compute::TensorInfo aclInputToOutputWeightsInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(recurrentToForgetWeights);
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(recurrentToCellWeights);
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(recurrentToOutputWeights);
-    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
-    const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
-    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+    const arm_compute::TensorInfo aclForgetGateBiasInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+    const arm_compute::TensorInfo aclCellBiasInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+    const arm_compute::TensorInfo aclOutputGateBiasInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
 
     arm_compute::TensorInfo aclInputToInputWeightsInfo;
     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -293,48 +318,65 @@
     arm_compute::TensorInfo aclCellToForgetWeightsInfo;
     arm_compute::TensorInfo aclCellToOutputWeightsInfo;
 
+    arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
+    arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
+    arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
+    arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
+
+
     if (!descriptor.m_CifgEnabled)
     {
-        armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
-        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
-        armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
-        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
-
-        if (cellToInputWeights != nullptr)
+        if (descriptor.m_PeepholeEnabled)
         {
-            armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
-            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
+            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
         }
-        armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
-        aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
+        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
+        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
+
         lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
-                                         cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
+                                         descriptor.m_PeepholeEnabled ? &aclCellToInputWeightsInfo : nullptr,
                                          &aclInputGateBiasInfo);
     }
 
     if (descriptor.m_ProjectionEnabled)
     {
-        const armnn::TensorInfo& projectionWInfo = *projectionWeights;
-        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
-
-        if (projectionBias != nullptr)
+        if (paramsInfo.m_ProjectionBias != nullptr)
         {
-            const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
-            aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
+            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionBias());
         }
+        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
+
         lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
-                                               projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
+                                               paramsInfo.m_ProjectionBias != nullptr ?
+                                               &aclProjectionBiasInfo : nullptr);
     }
 
     if (descriptor.m_PeepholeEnabled)
     {
-        const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
-        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
-        const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
-        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
+        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
+        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
+
         lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
     }
 
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights());
+        }
+        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights());
+        aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights());
+        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights());
+
+        lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ?
+                                                        nullptr : &aclInputLayerNormWeightsInfo,
+                                                        &aclForgetLayerNormWeightsInfo,
+                                                        &aclCellLayerNormWeightsInfo,
+                                                        &aclOutputLayerNormWeightsInfo);
+    }
+
     float cell_threshold = descriptor.m_ClippingThresCell;
     float projection_threshold = descriptor.m_ClippingThresProj;
 
@@ -407,6 +449,10 @@
     FreeTensorIfUnused(m_ProjectionWeightsTensor);
     FreeTensorIfUnused(m_ProjectionBiasTensor);
     FreeTensorIfUnused(m_ScratchBuffer);
+    FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
+    FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
+    FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
+    FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
 }
 
 } //namespace armnn
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
index f87f24d..c116cdd 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
@@ -43,6 +43,11 @@
 
     std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
 
+    std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
+    std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
+    std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
+    std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
+
     void FreeUnusedTensors();
 };
 
@@ -50,21 +55,6 @@
                                                   const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                                                   const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                                                   const TensorInfo& output, const LstmDescriptor &descriptor,
-                                                  const TensorInfo& inputToForgetWeights,
-                                                  const TensorInfo& inputToCellWeights,
-                                                  const TensorInfo& inputToOutputWeights,
-                                                  const TensorInfo& recurrentToForgetWeights,
-                                                  const TensorInfo& recurrentToCellWeights,
-                                                  const TensorInfo& recurrentToOutputWeights,
-                                                  const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
-                                                  const TensorInfo& outputGateBias,
-                                                  const TensorInfo* inputToInputWeights,
-                                                  const TensorInfo* recurrentToInputWeights,
-                                                  const TensorInfo* cellToInputWeights,
-                                                  const TensorInfo* inputGateBias,
-                                                  const TensorInfo* projectionWeights,
-                                                  const TensorInfo* projectionBias,
-                                                  const TensorInfo* cellToForgetWeights,
-                                                  const TensorInfo* cellToOutputWeights);
+                                                  const LstmInputParamsInfo& paramsInfo);
 
 } //namespace armnn