COMPMID-2342: Add layer normalization support in CLLSTMLayer
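
Extend the LSTM layer validation fixture and the CL/NEON test suites
with an optional layer normalization path:

 - Add a "UseLayerNorm" dataset dimension to the CL FP32/FP16 test
   cases ({ true, false }) and to the NEON test cases ({ false }, as
   layer normalization is only exercised on CL here).
 - Plumb a use_layer_norm flag through
   LSTMLayerValidationFixture::setup(), compute_target() and
   compute_reference().
 - In compute_target(), create per-gate layer normalization weight
   tensors of shape (num_cells) and pass them via
   LSTMParams::set_layer_normalization_params(); the input gate
   weights are omitted when CIFG is enabled.
 - In compute_reference(), when layer normalization is enabled, fill
   the gate biases with zeros for the fully connected step, then
   normalize each gate pre-activation (mean/std normalization), scale
   it by the corresponding layer normalization weights and add the
   gate bias before applying the activation.

For reference, a minimal standalone sketch of the per-gate
normalization modelled by the reference fixture (illustrative only;
layer_norm_gate and epsilon are not names taken from the library):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Normalize a gate pre-activation to zero mean / unit variance,
    // scale it by the layer normalization weights and add the gate
    // bias afterwards, mirroring the steps in compute_reference().
    std::vector<float> layer_norm_gate(const std::vector<float> &gate,
                                       const std::vector<float> &norm_weights,
                                       const std::vector<float> &bias,
                                       float epsilon = 1e-8f)
    {
        const std::size_t n = gate.size();

        float mean = 0.f;
        for(float v : gate)
        {
            mean += v;
        }
        mean /= static_cast<float>(n);

        float var = 0.f;
        for(float v : gate)
        {
            var += (v - mean) * (v - mean);
        }
        var /= static_cast<float>(n);

        const float inv_std = 1.f / std::sqrt(var + epsilon);

        std::vector<float> out(n);
        for(std::size_t i = 0; i < n; ++i)
        {
            out[i] = (gate[i] - mean) * inv_std * norm_weights[i] + bias[i];
        }
        return out;
    }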

Change-Id: I25d974aa94e69c5f79a0bd99d5869a351d6d954d
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1324
Reviewed-by: Manuel Bottini <manuel.bottini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
diff --git a/tests/validation/CL/LSTMLayer.cpp b/tests/validation/CL/LSTMLayer.cpp
index 71a9383..69ac61d 100644
--- a/tests/validation/CL/LSTMLayer.cpp
+++ b/tests/validation/CL/LSTMLayer.cpp
@@ -153,10 +153,11 @@
 using CLLSTMLayerFixture = LSTMLayerValidationFixture<CLTensor, CLAccessor, CLLSTMLayer, LSTMParams<ICLTensor>, T>;
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLLSTMLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunSmall, CLLSTMLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
                                                                                                                  DataType::F32)),
-                                                                                                         framework::dataset::make("ProjectionOpt", { true, false })),
-                                                                                                 framework::dataset::make("PeepholeOpt", { true, false })))
+                                                                                                                 framework::dataset::make("ProjectionOpt", { true, false })),
+                                                                                                         framework::dataset::make("PeepholeOpt", { true, false })),
+                                                                                                 framework::dataset::make("UseLayerNorm", { true, false })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -165,9 +166,11 @@
 TEST_SUITE_END() // FP32
 
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLLSTMLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType", DataType::F16)),
-                                                                                                        framework::dataset::make("ProjectionOpt", { true, false })),
-                                                                                                framework::dataset::make("PeepholeOpt", { true, false })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLLSTMLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
+                                                                                                                        DataType::F16)),
+                                                                                                                framework::dataset::make("ProjectionOpt", { true, false })),
+                                                                                                        framework::dataset::make("PeepholeOpt", { true, false })),
+                                                                                                framework::dataset::make("UseLayerNorm", { true, false })))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
diff --git a/tests/validation/NEON/LSTMLayer.cpp b/tests/validation/NEON/LSTMLayer.cpp
index b27dfae..c503972 100644
--- a/tests/validation/NEON/LSTMLayer.cpp
+++ b/tests/validation/NEON/LSTMLayer.cpp
@@ -153,10 +153,11 @@
 using NELSTMLayerFixture = LSTMLayerValidationFixture<Tensor, Accessor, NELSTMLayer, LSTMParams<ITensor>, T>;
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NELSTMLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunSmall, NELSTMLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
                                                                                                                  DataType::F32)),
-                                                                                                         framework::dataset::make("ProjectionOpt", { true, false })),
-                                                                                                 framework::dataset::make("PeepholeOpt", { true, false })))
+                                                                                                                 framework::dataset::make("ProjectionOpt", { true, false })),
+                                                                                                         framework::dataset::make("PeepholeOpt", { true, false })),
+                                                                                                 framework::dataset::make("UseLayerNorm", { false })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f32);
@@ -166,9 +167,11 @@
 
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NELSTMLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType", DataType::F16)),
-                                                                                                        framework::dataset::make("ProjectionOpt", { true, false })),
-                                                                                                framework::dataset::make("PeepholeOpt", { true, false })))
+FIXTURE_DATA_TEST_CASE(RunSmall, NELSTMLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
+                                                                                                                        DataType::F16)),
+                                                                                                                framework::dataset::make("ProjectionOpt", { true, false })),
+                                                                                                        framework::dataset::make("PeepholeOpt", { true, false })),
+                                                                                                framework::dataset::make("UseLayerNorm", { false })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f16);
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index 2cf83b8..9260686 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -32,6 +32,7 @@
 #include "tests/validation/reference/ConcatenateLayer.h"
 #include "tests/validation/reference/FullyConnectedLayer.h"
 #include "tests/validation/reference/GEMM.h"
+#include "tests/validation/reference/MeanStdDevNormalizationLayer.h"
 #include "tests/validation/reference/PixelWiseMultiplication.h"
 #include "tests/validation/reference/Transpose.h"
 
@@ -47,12 +48,13 @@
 public:
     template <typename...>
     void setup(TensorShape input_shape, TensorShape input_weights_shape, TensorShape recurrent_weights_shape, TensorShape cell_bias_shape, TensorShape output_cell_shape, TensorShape output_shape,
-               TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt)
+               TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt,
+               bool use_layer_norm)
     {
         _target = compute_target(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold,
-                                 data_type, projection_opt, peephole_opt);
+                                 data_type, projection_opt, peephole_opt, use_layer_norm);
         _reference = compute_reference(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold,
-                                       data_type, projection_opt, peephole_opt);
+                                       data_type, projection_opt, peephole_opt, use_layer_norm);
     }
 
 protected:
@@ -70,7 +72,7 @@
     }
     TensorType compute_target(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
                               const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
-                              float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt)
+                              float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm)
     {
         const unsigned int num_cells   = input_weights_shape.y();
         const unsigned int num_outputs = recurrent_weights_shape.x();
@@ -100,6 +102,10 @@
         TensorType cell_to_output_w;
         TensorType projection_w;
         TensorType projection_bias;
+        TensorType input_layer_norm_w;
+        TensorType forget_layer_norm_w;
+        TensorType cell_layer_norm_w;
+        TensorType output_layer_norm_w;
 
         bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
 
@@ -131,6 +137,22 @@
             lstm_params.set_projection_params(&projection_w, &projection_bias);
         }
 
+        if(use_layer_norm)
+        {
+            forget_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
+            cell_layer_norm_w   = create_tensor<TensorType>(TensorShape(num_cells), data_type);
+            output_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
+            if(!cifg_opt)
+            {
+                input_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
+                lstm_params.set_layer_normalization_params(&input_layer_norm_w, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w);
+            }
+            else
+            {
+                lstm_params.set_layer_normalization_params(nullptr, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w);
+            }
+        }
+
         // Create and configure function
         FunctionType lstm;
         lstm.configure(&input, &input_to_forget_w, &input_to_cell_w, &input_to_output_w, &recurrent_to_forget_w,
@@ -257,6 +279,35 @@
             fill(AccessorType(projection_bias), 21);
         }
 
+        if(use_layer_norm)
+        {
+            if(!cifg_opt)
+            {
+                ARM_COMPUTE_EXPECT(input_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+                input_layer_norm_w.allocator()->allocate();
+
+                ARM_COMPUTE_EXPECT(!input_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+                fill(AccessorType(input_layer_norm_w), 22);
+            }
+            ARM_COMPUTE_EXPECT(forget_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+            ARM_COMPUTE_EXPECT(cell_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+            ARM_COMPUTE_EXPECT(output_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+            forget_layer_norm_w.allocator()->allocate();
+            cell_layer_norm_w.allocator()->allocate();
+            output_layer_norm_w.allocator()->allocate();
+
+            ARM_COMPUTE_EXPECT(!forget_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+            ARM_COMPUTE_EXPECT(!cell_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+            ARM_COMPUTE_EXPECT(!output_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+            fill(AccessorType(forget_layer_norm_w), 23);
+            fill(AccessorType(cell_layer_norm_w), 24);
+            fill(AccessorType(output_layer_norm_w), 25);
+        }
+
         // Compute function
         lstm.run();
 
@@ -266,7 +317,7 @@
 
     SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
                                       const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
-                                      float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt)
+                                      float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm)
     {
         const unsigned int num_cells   = input_weights_shape.y();
         const unsigned int num_outputs = recurrent_weights_shape.x();
@@ -306,6 +357,8 @@
         SimpleTensor<T> cell_state_out{ output_cell_shape, data_type };
         SimpleTensor<T> output{ output_shape, data_type };
 
+        bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
+
         // Fill reference
         fill(input, 0);
         fill(input_to_forget_w, 1);
@@ -314,9 +367,18 @@
         fill(recurrent_to_forget_w, 4);
         fill(recurrent_to_cell_w, 5);
         fill(recurrent_to_output_w, 6);
-        fill(forget_gate_bias, 7);
-        fill(cell_bias, 8);
-        fill(output_gate_bias, 9);
+        if(use_layer_norm)
+        {
+            fill_custom_val(forget_gate_bias, 0.f, 7);
+            fill_custom_val(cell_bias, 0.f, 8);
+            fill_custom_val(output_gate_bias, 0.f, 9);
+        }
+        else
+        {
+            fill(forget_gate_bias, 7);
+            fill(cell_bias, 8);
+            fill(output_gate_bias, 9);
+        }
         fill(output_state_in, 10);
         fill(cell_state_in, 11);
         fill(scratch, 12);
@@ -324,14 +386,19 @@
         fill(recurrent_to_input_w, 14);
         fill(cell_to_input_w, 15);
         fill(recurrent_to_input_w, 16);
-        fill(input_gate_bias, 17);
+        if(!cifg_opt && use_layer_norm)
+        {
+            fill_custom_val(input_gate_bias, 0.f, 17);
+        }
+        else
+        {
+            fill(input_gate_bias, 17);
+        }
         fill(cell_to_forget_w, 18);
         fill(cell_to_output_w, 19);
         fill(projection_w, 20);
         fill(projection_bias, 21);
 
-        bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
-
         // Compute forget_gate
         SimpleTensor<T> fully_connected_forget = reference::fully_connected_layer(input, input_to_forget_w, forget_gate_bias, output_cell_shape);
         SimpleTensor<T> transposed_weights     = reference::transpose(recurrent_to_forget_w);
@@ -344,6 +411,15 @@
             forget_gate                               = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
         }
 
+        if(use_layer_norm)
+        {
+            SimpleTensor<T> forget_layer_norm_w{ cell_bias_shape, data_type };
+            fill(forget_layer_norm_w, 23);
+            forget_gate = reference::mean_std_normalization_layer(forget_gate);
+            forget_gate = reference::pixel_wise_multiplication(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+            fill(forget_gate_bias, 7);
+            forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, forget_gate_bias, data_type, ConvertPolicy::SATURATE);
+        }
         forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
 
         // Compute input_gate
@@ -365,6 +441,15 @@
                 SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
                 input_gate                               = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
             }
+            if(use_layer_norm)
+            {
+                SimpleTensor<T> input_layer_norm_w{ cell_bias_shape, data_type };
+                fill(input_layer_norm_w, 22);
+                input_gate = reference::mean_std_normalization_layer(input_gate);
+                input_gate = reference::pixel_wise_multiplication(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+                fill(input_gate_bias, 17);
+                input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, input_gate_bias, data_type, ConvertPolicy::SATURATE);
+            }
             input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
         }
 
@@ -374,9 +459,18 @@
         gemm                                       = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
         SimpleTensor<T> pixelwise_mul              = reference::pixel_wise_multiplication(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
         cell_state_out                             = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
-        cell_state_out                             = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
-        cell_state_out                             = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
-        cell_state_out                             = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
+        if(use_layer_norm)
+        {
+            SimpleTensor<T> cell_layer_norm_w{ cell_bias_shape, data_type };
+            fill(cell_layer_norm_w, 24);
+            cell_state_out = reference::mean_std_normalization_layer(cell_state_out);
+            cell_state_out = reference::pixel_wise_multiplication(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+            fill(cell_bias, 8);
+            cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE);
+        }
+        cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+        cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+        cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
         if(cell_threshold != 0.f)
         {
             cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
@@ -392,6 +486,15 @@
             pixelwise_mul = reference::pixel_wise_multiplication(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
             output        = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
         }
+        if(use_layer_norm)
+        {
+            SimpleTensor<T> output_layer_norm_w{ cell_bias_shape, data_type };
+            fill(output_layer_norm_w, 25);
+            output = reference::mean_std_normalization_layer(output);
+            output = reference::pixel_wise_multiplication(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+            fill(output_gate_bias, 9);
+            output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, output_gate_bias, data_type, ConvertPolicy::SATURATE);
+        }
         output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
 
         // Compute output state