IVGCVSW-3338 Add CL backend support for LSTM layer normalization

 * Enable the LSTM layer normalization unit tests on the CL backend.
 * Update ClLstmFloatWorkload to set the layer normalization parameters and weights.

!android-nn-driver:1461

Change-Id: Ia5a29918961c391c1f1d8f331add377a38822ddd
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
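
For reference, a minimal sketch of how the layer normalization weights are
handed to Compute Library via arm_compute::LSTMParams (the tensor names below
are illustrative placeholders, not the actual workload members in the diff):

    // Assumes the four CLTensor objects are already allocated and initialised
    // with the layer norm weights. The input-gate weights are only supplied
    // when CIFG is disabled; otherwise a nullptr is passed in their place.
    arm_compute::LSTMParams<arm_compute::ICLTensor> lstm_param;
    lstm_param.set_layer_normalization_params(
        cifgEnabled ? nullptr : inputLayerNormWeights.get(),
        forgetLayerNormWeights.get(),
        cellLayerNormWeights.get(),
        outputLayerNormWeights.get());
    // ClLstmFloatWorkloadValidate() mirrors this with TensorInfo objects so
    // that CLLSTMLayer::validate() can reject unsupported configurations.
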
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index ac96bf8..5575a05 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -354,6 +354,9 @@
 ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection,
                      LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest)
 
+ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm,
+                     LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest)
+
 // Convert from Float16 to Float32
 ARMNN_AUTO_TEST_CASE(SimpleConvertFp16ToFp32, SimpleConvertFp16ToFp32Test)
 // Convert from Float32 to Float16
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
index 3dbbbc3..f5d081e 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
@@ -100,6 +100,28 @@
         lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
     }
 
+    if (m_Data.m_Parameters.m_LayerNormEnabled)
+    {
+        m_InputLayerNormWeightsTensor  = std::make_unique<arm_compute::CLTensor>();
+        m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+        m_CellLayerNormWeightsTensor   = std::make_unique<arm_compute::CLTensor>();
+        m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
+
+        if (!m_Data.m_Parameters.m_CifgEnabled)
+        {
+            BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
+        }
+        BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
+        BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
+        BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
+
+        lstm_param.set_layer_normalization_params(m_Data.m_Parameters.m_CifgEnabled ? nullptr :
+                                                  m_InputLayerNormWeightsTensor.get(),
+                                                  m_ForgetLayerNormWeightsTensor.get(),
+                                                  m_CellLayerNormWeightsTensor.get(),
+                                                  m_OutputLayerNormWeightsTensor.get());
+    }
+
     const arm_compute::ICLTensor& input           = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     const arm_compute::ICLTensor& output_state_in = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
     const arm_compute::ICLTensor& cell_state_in   = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
@@ -161,7 +183,6 @@
         throw armnn::Exception("Wrong Type of Activation Function!");
     }
 
-
     m_LstmLayer.configure(&input, m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(),
                           m_InputToOutputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(),
                           m_RecurrentToCellWeightsTensor.get(), m_RecurrentToOutputWeightsTensor.get(),
@@ -172,15 +193,15 @@
 
     armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
 
-    InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
-    InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
-    InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
+    InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor,     m_Data.m_InputToForgetWeights);
+    InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor,       m_Data.m_InputToCellWeights);
+    InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor,     m_Data.m_InputToOutputWeights);
     InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
-    InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
+    InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor,   m_Data.m_RecurrentToCellWeights);
     InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
-    InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
-    InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
-    InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
+    InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor,           m_Data.m_ForgetGateBias);
+    InitializeArmComputeClTensorData(*m_CellBiasTensor,                 m_Data.m_CellBias);
+    InitializeArmComputeClTensorData(*m_OutputGateBiasTensor,           m_Data.m_OutputGateBias);
 
     if (!m_Data.m_Parameters.m_CifgEnabled)
     {
@@ -208,6 +229,18 @@
         InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
     }
 
+    if (m_Data.m_Parameters.m_LayerNormEnabled)
+    {
+        if (!m_Data.m_Parameters.m_CifgEnabled)
+        {
+            InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor,  m_Data.m_InputLayerNormWeights);
+        }
+
+        InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
+        InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor,   m_Data.m_CellLayerNormWeights);
+        InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
+    }
+
     // Force Compute Library to perform the necessary copying and reshaping, after which
     // delete all the input tensors that will no longer be needed
     m_LstmLayer.prepare();
@@ -262,6 +295,10 @@
     arm_compute::TensorInfo aclProjectionBiasInfo;
     arm_compute::TensorInfo aclCellToForgetWeightsInfo;
     arm_compute::TensorInfo aclCellToOutputWeightsInfo;
+    arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
+    arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
+    arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
+    arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
 
     if (!descriptor.m_CifgEnabled)
     {
@@ -333,6 +370,26 @@
         throw armnn::Exception("Wrong Type of Activation Function!");
     }
 
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            aclInputLayerNormWeightsInfo  = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights());
+        }
+
+        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights());
+
+        aclCellLayerNormWeightsInfo   = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights());
+
+        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights());
+
+        lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ?
+                                                        nullptr : &aclInputLayerNormWeightsInfo,
+                                                        &aclForgetLayerNormWeightsInfo,
+                                                        &aclCellLayerNormWeightsInfo,
+                                                        &aclOutputLayerNormWeightsInfo);
+    }
+
     return arm_compute::CLLSTMLayer::validate(&aclInputInfo, &aclInputToForgetWeightsInfo,
                                               &aclInputToCellWeightsInfo,
                                               &aclInputToOutputWeightsInfo,
@@ -369,6 +426,10 @@
     FreeTensorIfUnused(m_ProjectionWeightsTensor);
     FreeTensorIfUnused(m_ProjectionBiasTensor);
     FreeTensorIfUnused(m_ScratchBuffer);
+    FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
+    FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
+    FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
+    FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
 }
 
 } //namespace armnn
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
index 9a3211a..5bd67c2 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
@@ -39,6 +39,10 @@
     std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
     std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor;
     std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor;
+    std::unique_ptr<arm_compute::CLTensor> m_InputLayerNormWeightsTensor;
+    std::unique_ptr<arm_compute::CLTensor> m_ForgetLayerNormWeightsTensor;
+    std::unique_ptr<arm_compute::CLTensor> m_CellLayerNormWeightsTensor;
+    std::unique_ptr<arm_compute::CLTensor> m_OutputLayerNormWeightsTensor;
 
     std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer;