NEQLSTM: Add support for QASYMM8_SIGNED for input_to_forget_weights

* NEQLSTMLayer previously accepted only QSYMM8 for the input_to_forget_weights argument

* Add support for QASYMM8_SIGNED by dequantizing the weights to F32 and requantizing them to QSYMM8

* Resolves COMPMID-5184

Change-Id: I1cae18d81dafdb7ae722b520a1354cf4a56b9606
Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7321
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
(cherry picked from commit 187a041dedf8e9db0c9e0652f13f8639dca880f3)
diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
index 76bb8c0..c6e6a71 100644
--- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
@@ -111,17 +111,81 @@
 NEQLSTMLayer::~NEQLSTMLayer() = default;
 
 NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(), _transpose_input_to_forget_weights(), _transpose_input_to_cell_weights(), _transpose_input_to_output_weights(), _transpose_input_to_input_weights(),
-      _transpose_recurrent_to_forget_weights(), _transpose_recurrent_to_cell_weights(), _transpose_recurrent_to_output_weights(), _transpose_recurrent_to_input_weights(), _transpose_projection_weights(),
-      _input_to_input_reduction(), _recurrent_to_input_reduction(), _input_to_forget_reduction(), _recurrent_to_forget_reduction(), _input_to_cell_reduction(), _recurrent_to_cell_reduction(),
-      _input_to_output_reduction(), _recurrent_to_output_reduction(), _projection_reduction(), _projection_bias_add(), _mm_input_to_forget(), _mm_recurrent_to_forget(), _pixelwise_mul_cell_to_forget(),
-      _input_to_forget_outstage(), _recurrent_to_forget_outstage(), _cell_to_forget_outstage(), _accumulate_input_recurrent_forget(), _accumulate_cell_forget(), _forget_gate_sigmoid(), _mm_input_to_cell(),
-      _input_to_cell_outstage(), _mm_recurrent_to_cell(), _recurrent_to_cell_outstage(), _accumulate_input_recurrent_modulation(), _cell_gate_tanh(), _input_gate_sub(), _mm_input_to_input(),
-      _input_to_input_outstage(), _mm_recurrent_to_input(), _recurrent_to_input_outstage(), _accumulate_input_recurrent_input(), _pixelwise_mul_cell_to_input(), _cell_to_input_outstage(),
-      _accumulate_cell_input(), _input_gate_sigmoid(), _pixelwise_mul_forget_cell(), _pixelwise_mul_input_cell(), _add_forget_cell(), _cell_clip(), _mm_input_to_output(), _input_to_output_outstage(),
-      _mm_recurrent_to_output(), _recurrent_to_output_outstage(), _accumulate_input_recurrent_output(), _pixelwise_mul_cell_to_output(), _cell_to_output_outstage(), _accumulate_cell_to_output(),
-      _output_gate_sigmoid(), _hidden_tanh(), _pixelwise_mul_hidden(), _hidden_outstage(), _mm_projection(), _projection_outstage(), _accumulate_projection(), _projection_clip(), _projection_bias_copy(),
-      _projection_output_to_accumulate_copy(), _projection_accumulate_to_output_copy(), _hidden_to_output_copy(), _layer_norms(), _copy_output(), _layer_norm_weights(), _layer_norm_bias(),
+    : _memory_group(),
+      _dequantize_input_to_forget_weights(),
+      _quantize_input_to_forget_weights(),
+      _transpose_input_to_forget_weights(),
+      _transpose_input_to_cell_weights(),
+      _transpose_input_to_output_weights(),
+      _transpose_input_to_input_weights(),
+      _transpose_recurrent_to_forget_weights(),
+      _transpose_recurrent_to_cell_weights(),
+      _transpose_recurrent_to_output_weights(),
+      _transpose_recurrent_to_input_weights(),
+      _transpose_projection_weights(),
+      _input_to_input_reduction(),
+      _recurrent_to_input_reduction(),
+      _input_to_forget_reduction(),
+      _recurrent_to_forget_reduction(),
+      _input_to_cell_reduction(),
+      _recurrent_to_cell_reduction(),
+      _input_to_output_reduction(),
+      _recurrent_to_output_reduction(),
+      _projection_reduction(),
+      _projection_bias_add(),
+      _mm_input_to_forget(),
+      _mm_recurrent_to_forget(),
+      _pixelwise_mul_cell_to_forget(),
+      _input_to_forget_outstage(),
+      _recurrent_to_forget_outstage(),
+      _cell_to_forget_outstage(),
+      _accumulate_input_recurrent_forget(),
+      _accumulate_cell_forget(),
+      _forget_gate_sigmoid(),
+      _mm_input_to_cell(),
+      _input_to_cell_outstage(),
+      _mm_recurrent_to_cell(),
+      _recurrent_to_cell_outstage(),
+      _accumulate_input_recurrent_modulation(),
+      _cell_gate_tanh(),
+      _input_gate_sub(),
+      _mm_input_to_input(),
+      _input_to_input_outstage(),
+      _mm_recurrent_to_input(),
+      _recurrent_to_input_outstage(),
+      _accumulate_input_recurrent_input(),
+      _pixelwise_mul_cell_to_input(),
+      _cell_to_input_outstage(),
+      _accumulate_cell_input(),
+      _input_gate_sigmoid(),
+      _pixelwise_mul_forget_cell(),
+      _pixelwise_mul_input_cell(),
+      _add_forget_cell(),
+      _cell_clip(),
+      _mm_input_to_output(),
+      _input_to_output_outstage(),
+      _mm_recurrent_to_output(),
+      _recurrent_to_output_outstage(),
+      _accumulate_input_recurrent_output(),
+      _pixelwise_mul_cell_to_output(),
+      _cell_to_output_outstage(),
+      _accumulate_cell_to_output(),
+      _output_gate_sigmoid(),
+      _hidden_tanh(),
+      _pixelwise_mul_hidden(),
+      _hidden_outstage(),
+      _mm_projection(),
+      _projection_outstage(),
+      _accumulate_projection(),
+      _projection_clip(),
+      _projection_bias_copy(),
+      _projection_output_to_accumulate_copy(),
+      _projection_accumulate_to_output_copy(),
+      _hidden_to_output_copy(),
+      _layer_norms(),
+      _copy_output(),
+      _layer_norm_weights(),
+      _layer_norm_bias(),
       _layer_norm_output()
 {
     _memory_group = MemoryGroup(std::move(memory_manager));
@@ -174,12 +238,37 @@
     _recurrent_to_cell_weights_transposed.info()->set_quantization_info(recurrent_to_cell_weights->info()->quantization_info());
     _recurrent_to_output_weights_transposed.info()->set_quantization_info(recurrent_to_output_weights->info()->quantization_info());
 
-    // Validate
-    ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
-                                                      recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
-                                                      forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
-                                                      cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
-                                                      lstm_params_info));
+    if(input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED)
+    {
+        _convert_input_to_forget_weights_to_qsymm8 = true;
+        // Set up the dequantize output tensor to go from QASYMM8_SIGNED -> F32
+
+        _input_to_forget_weights_f32.allocator()->init(TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::F32)
+                                                       .set_data_layout(input_to_forget_weights->info()->data_layout()));
+        // Set up the quantize output tensor to go from F32 -> QSYMM8
+        _input_to_forget_weights_symm8.allocator()->init((TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::QSYMM8)
+                                                          .set_data_layout(input_to_forget_weights->info()->data_layout())
+                                                          .set_quantization_info(input_to_forget_weights->info()->quantization_info())));
+
+        _dequantize_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_f32);
+        _quantize_input_to_forget_weights.configure(&_input_to_forget_weights_f32, &_input_to_forget_weights_symm8);
+        _input_to_forget_weights_f32.allocator()->allocate();
+        _input_to_forget_weights_symm8.allocator()->allocate();
+
+        ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), _input_to_forget_weights_symm8.info(), input_to_cell_weights->info(), input_to_output_weights->info(),
+                                                          recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
+                                                          forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
+                                                          cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
+                                                          lstm_params_info));
+    }
+    else
+    {
+        ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
+                                                          recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
+                                                          forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
+                                                          cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
+                                                          lstm_params_info));
+    }
 
     const int batch_size  = input->info()->dimension(1);
     const int num_units   = input_to_output_weights->info()->dimension(1);
@@ -190,7 +279,7 @@
     const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
 
     _projection_bias             = lstm_params.projection_bias();
-    _input_to_forget_weights     = input_to_forget_weights;
+    _input_to_forget_weights     = (input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED) ? &_input_to_forget_weights_symm8 : input_to_forget_weights;
     _input_to_cell_weights       = input_to_cell_weights;
     _input_to_output_weights     = input_to_output_weights;
     _recurrent_to_forget_weights = recurrent_to_forget_weights;
@@ -611,10 +700,9 @@
     ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
     ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QSYMM8);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QSYMM8,DataType::QASYMM8_SIGNED);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                                        recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
-
     ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
     ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
@@ -967,6 +1055,12 @@
     // Acquire all the temporaries
     MemoryGroupResourceScope scope_mg(_memory_group);
 
+    if(_convert_input_to_forget_weights_to_qsymm8)
+    {
+        _dequantize_input_to_forget_weights.run();
+        _quantize_input_to_forget_weights.run();
+    }
+
     // Forget gate.
     _mm_input_to_forget.run();
     _input_to_forget_outstage.run();