COMPMID-3240: Add support for layer normalization to CLQLSTMLayer

Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I45359a4ddb46c059097a2d77c008f802e8f4c143
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3065
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h b/arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h
index 1a2f311..2d47072 100644
--- a/arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h
+++ b/arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h
@@ -73,7 +73,7 @@
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *input, ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias);
+    static Status validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
diff --git a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
index 72a61f8..722275e 100644
--- a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
+++ b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
@@ -27,6 +27,7 @@
 #include "arm_compute/core/CL/kernels/CLElementwiseOperationKernel.h"
 #include "arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h"
 #include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h"
+#include "arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
@@ -216,6 +217,16 @@
     void prepare() override;
 
 private:
+    enum class LayerNormGate : uint8_t
+    {
+        Forget,
+        Cell,
+        Input,
+        Output,
+        Count
+    };
+    static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count);
+
     /** Internal method to configure matrix multiplication plus output stage of each gate.
      *
      * @param[in] compile_context The compile context to be used.
@@ -302,6 +313,7 @@
     CLGEMMLowpOutputStage                _projection_outstage{};
     CLSaturatedArithmeticOperationKernel _accumulate_projection{};
     CLActivationLayer                    _projection_clip{};
+    std::array<CLQLSTMLayerNormalizationKernel, _layer_norm_count> _layer_norms{ {} };
 
     // Tensor pointers
     const ICLTensor *_input_to_input_weights
@@ -317,6 +329,61 @@
     const ICLTensor *_recurrent_to_cell_weights{ nullptr };
     const ICLTensor *_recurrent_to_output_weights{ nullptr };
     const ICLTensor *_projection_weights{ nullptr };
+    std::array<const ICLTensor *, _layer_norm_count> _layer_norm_weights{ {} };
+    std::array<const ICLTensor *, _layer_norm_count> _layer_norm_bias{ {} };
+
+    using LayerNormIndexType = typename std::underlying_type<LayerNormGate>::type;
+    inline LayerNormIndexType getGateIndex(LayerNormGate g)
+    {
+        return static_cast<LayerNormIndexType>(g);
+    }
+
+    inline void set_layer_norm_weight(const ICLTensor *t, LayerNormGate g)
+    {
+        _layer_norm_weights[getGateIndex(g)] = t;
+    }
+
+    inline void set_layer_norm_bias(const ICLTensor *t, LayerNormGate g)
+    {
+        _layer_norm_bias[getGateIndex(g)] = t;
+    }
+
+    inline const ICLTensor *get_layer_norm_weight(LayerNormGate g)
+    {
+        return _layer_norm_weights[getGateIndex(g)];
+    }
+
+    inline const ICLTensor *get_layer_norm_bias(LayerNormGate g)
+    {
+        return _layer_norm_bias[getGateIndex(g)];
+    }
+
+    inline CLQLSTMLayerNormalizationKernel &get_layer_norm(LayerNormGate g)
+    {
+        return _layer_norms[getGateIndex(g)];
+    }
+
+    inline void configure_layer_norm(LayerNormGate g, const ICLTensor *in)
+    {
+        ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
+
+        CLTensor *out = &get_layer_norm_output(g);
+        _memory_group.manage(out);
+        out->allocator()->init(*(in->info()));
+
+        get_layer_norm(g).configure(in, out, get_layer_norm_weight(g), get_layer_norm_bias(g));
+    }
+
+    inline static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
+    {
+        // Output quantization scale will be different, but ignored here
+        // since it will be configured at configure() stage.
+        const TensorInfo out
+        {
+            in
+        };
+        return CLQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
+    }
 
     // Temporary tensors
     CLTensor _input_to_forget_weights_transposed{ nullptr };
@@ -368,6 +435,12 @@
     CLTensor _mm_projection_res{ nullptr };
     CLTensor _projection_outstage_res{ nullptr };
     CLTensor _ones{ nullptr };
+    std::array<CLTensor, _layer_norm_count> _layer_norm_output{ {} };
+
+    inline CLTensor &get_layer_norm_output(LayerNormGate g)
+    {
+        return _layer_norm_output[getGateIndex(g)];
+    }
 
     bool _is_prepared{ false };
     bool _has_cifg{ false };
@@ -375,6 +448,7 @@
     bool _has_projection{ false };
     bool _has_projection_clipping{ false };
     bool _has_peephole{ false };
+    bool _has_layer_norm{ false };
 };
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_CLQLSTMLAYER_H */
diff --git a/src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.cpp b/src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.cpp
index b9767e8..d9da3cb 100644
--- a/src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.cpp
+++ b/src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.cpp
@@ -31,11 +31,17 @@
 {
 namespace
 {
+QuantizationInfo compute_output_qinfo()
+{
+    return QuantizationInfo(1.f / 4096);
+}
+
 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input);
     // Output auto inizialitation if not yet initialized
     auto_init_if_empty(*output, *input);
+    output->set_quantization_info(compute_output_qinfo());
 
     const uint32_t temp_num_elems_processed_per_iteration = max_cl_vector_width / input->element_size();
     /* If width is less then step, then make step same as width to avoid global size being step instead of actual width. */
@@ -48,7 +54,7 @@
 
     return std::make_pair(Status{}, win);
 }
-Status validate_arguments(const ITensorInfo *input, ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weight, bias, output);
 
@@ -129,7 +135,7 @@
     configure(CLKernelLibrary::get().get_compile_context(), input, output, weight, bias);
 }
 
-Status CLQLSTMLayerNormalizationKernel::validate(const ITensorInfo *input, ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias)
+Status CLQLSTMLayerNormalizationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias)
 {
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, weight, bias));
     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get()).first);
diff --git a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
index e966c6b..29ffee8 100644
--- a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
+++ b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
@@ -172,10 +172,7 @@
 
 inline QuantizationInfo NEQLSTMLayerNormalizationKernel::compute_output_qinfo()
 {
-    const UniformQuantizationInfo iq_info      = _input->info()->quantization_info().uniform();
-    const UniformQuantizationInfo wq_info      = _weight->info()->quantization_info().uniform();
-    const float                   output_scale = (wq_info.scale * iq_info.scale) * 1024;
-    return QuantizationInfo(output_scale);
+    return QuantizationInfo(1.f / 4096);
 }
 
 inline std::pair<int64_t, int64_t> NEQLSTMLayerNormalizationKernel::sum_qsymm16(const int16_t *input_ptr)
diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp
index 88c5f77..d9b5c7c 100644
--- a/src/runtime/CL/functions/CLQLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp
@@ -92,9 +92,6 @@
                              ICLTensor *cell_state_out, ICLTensor *output_state_out,
                              const LSTMParams<ICLTensor> &lstm_params)
 {
-    ARM_COMPUTE_UNUSED(forget_gate_bias);
-    ARM_COMPUTE_UNUSED(cell_bias);
-    ARM_COMPUTE_UNUSED(output_gate_bias);
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                  forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
@@ -125,6 +122,21 @@
     _recurrent_to_output_weights = recurrent_to_output_weights;
     _projection_weights          = lstm_params.projection_weights();
 
+    // Layer normalization
+    _has_layer_norm = lstm_params.use_layer_norm();
+    if(_has_layer_norm)
+    {
+        set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
+        set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
+        set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
+        set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
+
+        set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
+        set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
+        set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
+        set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
+    }
+
     _has_cifg       = lstm_params.has_cifg_opt();
     _has_projection = lstm_params.has_projection();
     _has_peephole   = lstm_params.has_peephole_opt();
@@ -218,14 +230,23 @@
         _cell_to_forget_outstage_res.allocator()->allocate();
     }
 
+    CLTensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
+
+    if(_has_layer_norm)
+    {
+        configure_layer_norm(LayerNormGate::Forget, &_recurrent_to_forget_outstage_res);
+        _recurrent_to_forget_outstage_res.allocator()->allocate();
+        forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
+    }
+
     // Output quantization info of Sigmoid and Tanh activations
     const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
 
     const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     _memory_group.manage(&_forget_gate);
     _forget_gate.allocator()->init(forget_gate_info);
-    _forget_gate_sigmoid.configure(compile_context, &_recurrent_to_forget_outstage_res, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
-    _recurrent_to_forget_outstage_res.allocator()->allocate();
+    _forget_gate_sigmoid.configure(compile_context, forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+    forget_activation_input->allocator()->allocate();
 
     // Modulation gate.
     const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
@@ -245,11 +266,20 @@
                                                      ConvertPolicy::SATURATE);
     _input_to_cell_outstage_res.allocator()->allocate();
 
+    CLTensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
+
+    if(_has_layer_norm)
+    {
+        configure_layer_norm(LayerNormGate::Cell, &_recurrent_to_cell_outstage_res);
+        _recurrent_to_cell_outstage_res.allocator()->allocate();
+        cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
+    }
+
     const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     _memory_group.manage(&_cell_gate);
     _cell_gate.allocator()->init(cell_gate_info);
-    _cell_gate_tanh.configure(compile_context, &_recurrent_to_cell_outstage_res, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
-    _recurrent_to_cell_outstage_res.allocator()->allocate();
+    _cell_gate_tanh.configure(compile_context, cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+    cell_activation_input->allocator()->allocate();
 
     // Input gate.
     const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
@@ -293,8 +323,17 @@
             _cell_to_input_outstage_res.allocator()->allocate();
         }
 
-        _input_gate_tanh.configure(compile_context, &_recurrent_to_input_outstage_res, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
-        _recurrent_to_input_outstage_res.allocator()->allocate();
+        CLTensor *input_activation_input = &_recurrent_to_input_outstage_res;
+
+        if(_has_layer_norm)
+        {
+            configure_layer_norm(LayerNormGate::Input, &_recurrent_to_input_outstage_res);
+            _recurrent_to_input_outstage_res.allocator()->allocate();
+            input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
+        }
+
+        _input_gate_tanh.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+        input_activation_input->allocator()->allocate();
     }
     // Cell.
     // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
@@ -344,11 +383,20 @@
         _mul_cell_to_output_res.allocator()->allocate();
     }
 
+    CLTensor *output_activation_input = &_recurrent_to_output_outstage_res;
+
+    if(_has_layer_norm)
+    {
+        configure_layer_norm(LayerNormGate::Output, &_recurrent_to_output_outstage_res);
+        _recurrent_to_output_outstage_res.allocator()->allocate();
+        output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
+    }
+
     const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     _memory_group.manage(&_output_gate);
     _output_gate.allocator()->init(output_gate_info);
-    _output_gate_sigmoid.configure(compile_context, &_recurrent_to_output_outstage_res, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
-    _recurrent_to_output_outstage_res.allocator()->allocate();
+    _output_gate_sigmoid.configure(compile_context, output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+    output_activation_input->allocator()->allocate();
 
     // Hidden.
     _hidden_tanh.configure(compile_context, cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
@@ -525,6 +573,8 @@
     gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
     gemmlowp_info.output_data_type   = DataType::QSYMM16;
 
+    const bool has_layer_norm = lstm_params.use_layer_norm();
+
     // Forget gate.
     const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
     const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
@@ -547,6 +597,13 @@
         ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
     }
 
+    if(has_layer_norm)
+    {
+        const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
+        const ITensorInfo *b_info = forget_gate_bias;
+        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
+    }
+
     // Output quantization info of Sigmoid and Tanh activations
     const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
 
@@ -563,6 +620,13 @@
 
     ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
 
+    if(has_layer_norm)
+    {
+        const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
+        const ITensorInfo *b_info = cell_bias;
+        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
+    }
+
     const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
 
@@ -602,6 +666,13 @@
             ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
         }
 
+        if(has_layer_norm)
+        {
+            const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
+            const ITensorInfo *b_info = lstm_params.input_gate_bias();
+            ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
+        }
+
         ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
     }
     // Cell.
@@ -634,6 +705,13 @@
         ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
     }
 
+    if(has_layer_norm)
+    {
+        const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
+        const ITensorInfo *b_info = output_gate_bias;
+        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
+    }
+
     const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
 
@@ -715,6 +793,11 @@
         CLScheduler::get().enqueue(_accumulate_cell_forget);
     }
 
+    if(_has_layer_norm)
+    {
+        CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Forget));
+    }
+
     _forget_gate_sigmoid.run();
 
     // Modulation gate.
@@ -725,6 +808,11 @@
     _recurrent_to_cell_outstage.run();
     CLScheduler::get().enqueue(_accumulate_input_recurrent_modulation);
 
+    if(_has_layer_norm)
+    {
+        CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Cell));
+    }
+
     _cell_gate_tanh.run();
 
     // Input gate
@@ -747,6 +835,11 @@
             CLScheduler::get().enqueue(_accumulate_cell_input);
         }
 
+        if(_has_layer_norm)
+        {
+            CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input));
+        }
+
         _input_gate_tanh.run();
     }
 
@@ -771,6 +864,11 @@
         CLScheduler::get().enqueue(_accumulate_cell_to_output);
     }
 
+    if(_has_layer_norm)
+    {
+        CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Output));
+    }
+
     _output_gate_sigmoid.run();
 
     // Hidden.