COMPMID-2501: Support multiplier > 1 during QASYMM8 requantization for Quantized LSTM

Change-Id: I7eddbdf77881f313b707b9e59428245f1330a2cf
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2119
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
diff --git a/arm_compute/core/NEON/NESymm.h b/arm_compute/core/NEON/NESymm.h
index a60d5d0..8345e0b 100644
--- a/arm_compute/core/NEON/NESymm.h
+++ b/arm_compute/core/NEON/NESymm.h
@@ -54,13 +54,23 @@
                                       int16x8_t    min_s16,
                                       int16x8_t    max_s16)
 {
-    // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
-    in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
-    in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+    if(result_shift < 0)
+    {
+        in_s32.val[0] = vmulq_n_s32(in_s32.val[0], (1 << -result_shift));
+        in_s32.val[1] = vmulq_n_s32(in_s32.val[1], (1 << -result_shift));
 
-    // Round to the nearest division by a power-of-two using result_shift_s32
-    in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
-    in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
+        in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+        in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+    }
+    else
+    {
+        // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+        in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+        in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+        // Round to the nearest division by a power-of-two using result_shift_s32
+        in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
+        in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
+    }
 
     // Convert S32 to S16
     int16x8_t out_s16 = vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1]));
@@ -90,13 +100,18 @@
 inline int16_t finalize_quantization_int16(int32_t in_value, int result_fixedpoint_multiplier,
                                            int32_t result_shift, int16_t min_s16, int16_t max_s16)
 {
-    int32x4_t in_s32 = vdupq_n_s32(in_value);
-
-    // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
-    in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
-
-    // Shift value by result_shift_s32
-    in_value = rounding_divide_by_pow2(in_value, result_shift);
+    if(result_shift < 0)
+    {
+        const int64_t in_64 = static_cast<int64_t>(in_value) * (1 << (-result_shift)) * static_cast<int64_t>(result_fixedpoint_multiplier);
+        in_value = static_cast<int32_t>((in_64 + (1 << 30)) >> 31);
+    }
+    else
+    {
+        // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+        const int64_t in_64 = static_cast<int64_t>(in_value) * static_cast<int64_t>(result_fixedpoint_multiplier);
+        // Shift value by result_shift_s32
+        in_value = rounding_divide_by_pow2(static_cast<int32_t>((in_64 + (1 << 30)) >> 31), result_shift);
+    }
 
     // Bound the result
     int16_t out_s16 = static_cast<int16_t>(std::max<int32_t>(-32768, std::min<int32_t>(32767, in_value)));
diff --git a/arm_compute/core/utils/quantization/AsymmHelpers.h b/arm_compute/core/utils/quantization/AsymmHelpers.h
index a0efe12..bc5b9db 100644
--- a/arm_compute/core/utils/quantization/AsymmHelpers.h
+++ b/arm_compute/core/utils/quantization/AsymmHelpers.h
@@ -31,6 +31,15 @@
 {
 namespace quantization
 {
+/** Calculate quantized representation of multiplier.
+ *
+ * @param[in]  multiplier       Real multiplier.
+ * @param[out] quant_multiplier Integer multiplier.
+ * @param[out] shift            bit shift. A negative value indicates a left shift, while a positive value indicates a right shift
+ *
+ * @return a status
+ */
+Status calculate_quantized_multiplier(float multiplier, int *quant_multiplier, int *shift);
 /** Calculate quantized representation of multiplier with value less than one.
  *
  * @param[in]  multiplier       Real multiplier.
@@ -39,7 +48,7 @@
  *
  * @return a status
  */
-arm_compute::Status calculate_quantized_multiplier_less_than_one(float multiplier, int *quant_multiplier, int *right_shift);
+Status calculate_quantized_multiplier_less_than_one(float multiplier, int *quant_multiplier, int *right_shift);
 /** Calculate quantized representation of multiplier having value greater than one.
  *
  * @param[in]  multiplier           Real multiplier.
@@ -48,7 +57,7 @@
  *
  * @return a status
  */
-arm_compute::Status calculate_quantized_multiplier_greater_than_one(float multiplier, int *quantized_multiplier, int *left_shift);
+Status calculate_quantized_multiplier_greater_than_one(float multiplier, int *quantized_multiplier, int *left_shift);
 /** Get minimum and maximum values for the input quantized data type
  *
  * @ return min and max values for the quantized data type
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index fc90dbd..214c7a4 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -1888,7 +1888,11 @@
 #endif // defined(ADD_BIAS)
 
     // Multiply by result_mult_int and shift
+#if RESULT_SHIFT < 0
+    input_values = ASYMM_MULT(input_values * (1 << (-RESULT_SHIFT)), RESULT_FIXEDPOINT_MULTIPLIER, 4);
+#else // RESULT_SHIFT >= 0
     input_values = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(input_values, RESULT_FIXEDPOINT_MULTIPLIER, RESULT_SHIFT, 4);
+#endif // RESULT_SHIFT < 0
 
     short4 res = convert_short4_sat(input_values);
 
diff --git a/src/core/utils/quantization/AsymmHelpers.cpp b/src/core/utils/quantization/AsymmHelpers.cpp
index 5905244..42bd84d 100644
--- a/src/core/utils/quantization/AsymmHelpers.cpp
+++ b/src/core/utils/quantization/AsymmHelpers.cpp
@@ -34,6 +34,20 @@
 constexpr int64_t fixed_point_one_Q0 = (1LL << 31);
 constexpr float   epsilon            = 0.00001f;
 
+Status calculate_quantized_multiplier(float multiplier, int *quant_multiplier, int *shift)
+{
+    if(multiplier > 1.f)
+    {
+        Status status = calculate_quantized_multiplier_greater_than_one(multiplier, quant_multiplier, shift);
+        *shift *= -1;
+        return status;
+    }
+    else
+    {
+        return calculate_quantized_multiplier_less_than_one(multiplier, quant_multiplier, shift);
+    }
+}
+
 Status calculate_quantized_multiplier_less_than_one(float multiplier,
                                                     int *quant_multiplier,
                                                     int *right_shift)
diff --git a/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp b/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp
index 4e6df1d..e5f1278 100644
--- a/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp
@@ -159,8 +159,7 @@
     const float multiplier        = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
     int         output_multiplier = 0;
     int         output_shift      = 0;
-
-    quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+    quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
 
     _memory_group.manage(&_output_lowp);
     _output_stage.configure(&_output_highp, &_bias, &_output_lowp, output_multiplier, output_shift);
@@ -361,12 +360,13 @@
     input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
     weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));
 
-    // multiplier = (input_scale * weights_scale) / output_scale (2 ^ (-12))
     const TensorInfo output_lowp(output_highp.tensor_shape(), 1, DataType::QSYMM16, qsymm_3);
 
-    const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
-    ARM_COMPUTE_UNUSED(multiplier);
-    ARM_COMPUTE_RETURN_ERROR_ON(multiplier > 1.0f);
+    const float multiplier        = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
+    int         output_multiplier = 0;
+    int         output_shift      = 0;
+    ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
+
     // _output_stage
     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(&output_highp, &bias_concatenated, &output_lowp));
 
diff --git a/src/runtime/NEON/functions/NELSTMLayerQuantized.cpp b/src/runtime/NEON/functions/NELSTMLayerQuantized.cpp
index e325619..cfd996b 100644
--- a/src/runtime/NEON/functions/NELSTMLayerQuantized.cpp
+++ b/src/runtime/NEON/functions/NELSTMLayerQuantized.cpp
@@ -138,8 +138,7 @@
     const float multiplier        = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
     int         output_multiplier = 0;
     int         output_shift      = 0;
-
-    quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+    quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
 
     _memory_group.manage(&_output_lowp);
     _output_stage.configure(&_output_highp, &_bias, &_output_lowp, output_multiplier, output_shift);
@@ -340,12 +339,13 @@
     input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
     weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));
 
-    // multiplier = (input_scale * weights_scale) / output_scale (2 ^ (-12))
     const TensorInfo output_lowp(output_highp.tensor_shape(), 1, DataType::QSYMM16, qsymm_3);
 
-    const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
-    ARM_COMPUTE_UNUSED(multiplier);
-    ARM_COMPUTE_RETURN_ERROR_ON(multiplier > 1.0f);
+    const float multiplier        = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
+    int         output_multiplier = 0;
+    int         output_shift      = 0;
+    ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
+
     // _output_stage
     ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(&output_highp, &bias_concatenated, &output_lowp));
 
diff --git a/tests/validation/CL/GEMMLowp.cpp b/tests/validation/CL/GEMMLowp.cpp
index b8dfc03..f5bd871 100644
--- a/tests/validation/CL/GEMMLowp.cpp
+++ b/tests/validation/CL/GEMMLowp.cpp
@@ -305,6 +305,14 @@
                                                                          2)
                                                                          * framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true });
 
+const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases = framework::dataset::make("result_fixedpoint_multiplier", 1073741823, 1073741825) * framework::dataset::make("result_shift", -3,
+                                                                    -2)
+                                                                    * framework::dataset::make("min", 0) * framework::dataset::make("max", 0) * framework::dataset::make("addBias", { false, true });
+
+const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", -3,
+                                                                         -1)
+                                                                         * framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true });
+
 using CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture =
     GEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointValidationFixture<CLTensor, CLAccessor, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint>;
 
@@ -344,19 +352,41 @@
 }
 // clang-format on
 // *INDENT-ON*
+TEST_SUITE(NoRelu)
+TEST_SUITE(MultSmallerEq1)
 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
                        quantize_down_int32_to_int16_scale_by_fixedpoint_cases))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
+TEST_SUITE_END() // MultSmallerEq1
+TEST_SUITE(MultGreater1)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
+                       quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // MultGreater1
+TEST_SUITE_END() // NoRelu
 TEST_SUITE(BoundedReLu)
+TEST_SUITE(MultSmallerEq1)
 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
                        quantize_down_int32_to_int16_scale_by_fixedpoint_relu_cases))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
+TEST_SUITE_END() // MultSmallerEq1
+TEST_SUITE(MultGreater1)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
+                       quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // MultGreater1
 TEST_SUITE_END() // BoundedReLu
 TEST_SUITE_END() // QuantizeDownInt32ToInt16ScaleByFixedPoint
 TEST_SUITE_END() // OutputStage
diff --git a/tests/validation/CL/LSTMLayerQuantized.cpp b/tests/validation/CL/LSTMLayerQuantized.cpp
index 1fc0af1..686d6bc 100644
--- a/tests/validation/CL/LSTMLayerQuantized.cpp
+++ b/tests/validation/CL/LSTMLayerQuantized.cpp
@@ -72,13 +72,14 @@
 
 // *INDENT-OFF*
 // clang-format off
-TEST_CASE(IntegrationTestCaseSmall, framework::DatasetMode::PRECOMMIT)
+TEST_SUITE(IntegrationTestCase)
+TEST_SUITE(MultSmallerEq1)
+TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
 {
     const int batch_size  = 2;
     const int input_size  = 2;
     const int output_size = 4;
 
-
     QuantizationInfo qasymm(1.f / 128.f, 128);
     QuantizationInfo qweights(1.f / 128.f, 128);
     QuantizationInfo qsymm_3(8.f / 32768.f, 0);
@@ -211,7 +212,7 @@
     validate(CLAccessor(output_state), expected_output);
 }
 
-TEST_CASE(IntegrationTestCaseLarge, framework::DatasetMode::PRECOMMIT)
+TEST_CASE(RunLarge, framework::DatasetMode::PRECOMMIT)
 {
     const int batch_size  = 16;
     const int input_size  = 8;
@@ -448,11 +449,154 @@
     lstmq.run();
     validate(CLAccessor(output_state), expected_output);
 }
+TEST_SUITE_END() // MultSmallerEq1
+
+TEST_SUITE(MultGreater1)
+TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
+{
+    //Input sequence length is 1
+    const int batch_size  = 2;
+    const int input_size  = 2;
+    const int output_size = 4;
+
+    QuantizationInfo qasymm(1.f / 128.f, 128);
+    QuantizationInfo qweights(1.f / 16.f, 16);
+    QuantizationInfo qsymm_3(8.f / 32768.f, 0);
+    QuantizationInfo qsymm_4(16.f / 32768.f, 0);
+
+    TensorShape input_shape{ input_size, batch_size };
+    TensorShape input_weights_shape{ input_size, output_size };
+    TensorShape recurrent_weights_shape{ output_size, output_size };
+    TensorShape output_shape{ output_size, batch_size};
+    TensorShape bias_shape{ output_size };
+
+    auto input_to_input_weights      = create_tensor<CLTensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto input_to_forget_weights     = create_tensor<CLTensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto input_to_cell_weights       = create_tensor<CLTensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto input_to_output_weights     = create_tensor<CLTensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto recurrent_to_input_weights  = create_tensor<CLTensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto recurrent_to_forget_weights = create_tensor<CLTensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto recurrent_to_cell_weights   = create_tensor<CLTensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto recurrent_to_output_weights = create_tensor<CLTensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto input_gate_bias             = create_tensor<CLTensor>(bias_shape, DataType::S32);
+    auto forget_gate_bias            = create_tensor<CLTensor>(bias_shape, DataType::S32);
+    auto cell_gate_bias              = create_tensor<CLTensor>(bias_shape, DataType::S32);
+    auto output_gate_bias            = create_tensor<CLTensor>(bias_shape, DataType::S32);
+
+    // LSTM input
+    auto input = create_tensor<CLTensor>(input_shape, DataType::QASYMM8, 1, qasymm);
+
+    // LSTM output state
+    auto output_state = create_tensor<CLTensor>(output_shape, DataType::QASYMM8, 1, qasymm);
+
+    // LSTM cell state
+    auto cell_state = create_tensor<CLTensor>(output_shape, DataType::QSYMM16, 1, qsymm_4);
+
+    CLLSTMLayerQuantized lstmq;
+
+    lstmq.configure(&input, &input_to_input_weights, &input_to_forget_weights, &input_to_cell_weights, &input_to_output_weights,
+                    &recurrent_to_input_weights, &recurrent_to_forget_weights, &recurrent_to_cell_weights, &recurrent_to_output_weights,
+                    &input_gate_bias, &forget_gate_bias, &cell_gate_bias, &output_gate_bias, &cell_state, &output_state, &cell_state, &output_state);
+
+    input.allocator()->allocate();
+    input_to_input_weights.allocator()->allocate();
+    input_to_forget_weights.allocator()->allocate();
+    input_to_cell_weights.allocator()->allocate();
+    input_to_output_weights.allocator()->allocate();
+    recurrent_to_input_weights.allocator()->allocate();
+    recurrent_to_forget_weights.allocator()->allocate();
+    recurrent_to_cell_weights.allocator()->allocate();
+    recurrent_to_output_weights.allocator()->allocate();
+    input_gate_bias.allocator()->allocate();
+    forget_gate_bias.allocator()->allocate();
+    cell_gate_bias.allocator()->allocate();
+    output_gate_bias.allocator()->allocate();
+    cell_state.allocator()->allocate();
+    output_state.allocator()->allocate();
+
+    // Fill weights and biases
+    fill_tensor(input_to_input_weights, std::vector<uint8_t>{ 122,  130,
+                                                              124,  134,
+                                                               120,   122,
+                                                             134,  134 });
+
+    fill_tensor(input_to_forget_weights, std::vector<uint8_t> { 204,  193,
+                                                                148,  59,
+                                                                113,  17,
+                                                                 66, 197 });
+
+    fill_tensor(input_to_cell_weights, std::vector<uint8_t> { 172,  101,
+                                                              184, 209,
+                                                              165,  82,
+                                                              108, 209 });
+
+    fill_tensor(input_to_output_weights, std::vector<uint8_t> { 203, 244,
+                                                                219, 114,
+                                                                130,  16,
+                                                                163, 222 });
+
+    fill_tensor(recurrent_to_input_weights, std::vector<uint8_t> { 162, 168,  7,  95,
+                                                                    91, 155, 108, 216,
+                                                                   255, 100,  48, 188,
+                                                                    58,  37, 186, 147 });
+
+    fill_tensor(recurrent_to_forget_weights, std::vector<uint8_t> {  46,  58,  47, 170,
+                                                                    246,  96,  12,  99,
+                                                                     68,  23, 186, 161,
+                                                                    237, 164,  89,   6 });
+
+    fill_tensor(recurrent_to_cell_weights, std::vector<uint8_t> { 234,  99,   71, 206,
+                                                                  205, 159,   64, 253,
+                                                                  191, 148,  116,   8,
+                                                                  209, 136,   59, 138 });
+
+    fill_tensor(recurrent_to_output_weights, std::vector<uint8_t> {  23, 241, 137, 36,
+                                                                    206,   5, 227, 56,
+                                                                    254, 176, 231, 47,
+                                                                     18, 201, 161, 11 });
+
+    fill_tensor(input_gate_bias, std::vector<int>  {-103038,   30525,  115255, -38154 });
+    fill_tensor(forget_gate_bias, std::vector<int> { -23428,  126970,  116806,  46307 });
+    fill_tensor(cell_gate_bias, std::vector<int>   { 128006,   69949,  -42808,  42568 });
+    fill_tensor(output_gate_bias, std::vector<int> { -67066,  -53607,   47233,  7300  });
+
+    SimpleTensor<uint8_t> expected_output(output_shape, DataType::QASYMM8, 1, qasymm);
+
+    // Initialize state
+    fill_tensor(output_state, std::vector<uint8_t> { 128, 128, 128, 128,
+                                                     128, 128, 128, 128 });
+    fill_tensor(cell_state, std::vector<int16_t> { 0, 0, 0, 0,
+                                                   0, 0, 0, 0 });
+
+    // First input
+    fill_tensor(input, std::vector<uint8_t> { 106,  193,
+                                              155,  150 });
+
+    fill_tensor(expected_output, std::vector<uint8_t> { 128, 128,  31, 128,
+                                                        128, 128,  31, 128 });
+
+    lstmq.run();
+    validate(CLAccessor(output_state), expected_output);
+
+    // Second input
+    fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 5, 128,
+                                                        128, 128, 5, 128 });
+    lstmq.run();
+    validate(CLAccessor(output_state), expected_output);
+
+    // Third input
+    fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 1, 128,
+                                                        128, 128, 1, 128, });
+    lstmq.run();
+    validate(CLAccessor(output_state), expected_output);
+}
+TEST_SUITE_END() // MultGreater1
+TEST_SUITE_END() // IntegrationTestCase
 // clang-format on
 // *INDENT-ON*
 
 TEST_SUITE_END() // LSTMLayerQuantized
-TEST_SUITE_END() // NEON
+TEST_SUITE_END() // CL
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp
index 2f604c9..d79374e 100644
--- a/tests/validation/NEON/GEMMLowp.cpp
+++ b/tests/validation/NEON/GEMMLowp.cpp
@@ -417,6 +417,13 @@
 const auto quantize_down_int32_to_int16_scale_by_fixedpoint_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", 1,
                                                                          2)
                                                                          * framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true });
+const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases = framework::dataset::make("result_fixedpoint_multiplier", 1073741823, 1073741825) * framework::dataset::make("result_shift", -3,
+                                                                    -2)
+                                                                    * framework::dataset::make("min", 0) * framework::dataset::make("max", 0) * framework::dataset::make("addBias", { false, true });
+
+const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", -3,
+                                                                         -1)
+                                                                         * framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true });
 
 using NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture =
     GEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointValidationFixture<Tensor, Accessor, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint>;
@@ -499,27 +506,44 @@
         validate(bias.info()->padding(), padding);
     }
 }
-
+TEST_SUITE(NoRelu)
+TEST_SUITE(MultSmallerEq1)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
                        quantize_down_int32_to_int16_scale_by_fixedpoint_cases))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
-
+TEST_SUITE_END() // MultSmallerEq1
+TEST_SUITE(MultGreater1)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
+                       quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // MultGreater1
+TEST_SUITE_END() // NoRelu
 TEST_SUITE(BoundedReLu)
+TEST_SUITE(MultSmallerEq1)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
                        quantize_down_int32_to_int16_scale_by_fixedpoint_relu_cases))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
+TEST_SUITE_END() // MultSmallerEq1
+TEST_SUITE(MultGreater1)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
+                       quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // MultGreater1
 TEST_SUITE_END() // BoundedReLu
-
 TEST_SUITE_END() // QuantizeDownInt32ToInt16ScaleByFixedPoint
-
 TEST_SUITE_END() // OutputStage
-
 TEST_SUITE_END() // GEMMLowp
 TEST_SUITE_END() // NEON
 } // namespace validation
diff --git a/tests/validation/NEON/LSTMLayerQuantized.cpp b/tests/validation/NEON/LSTMLayerQuantized.cpp
index 0935165..b57a8f7 100644
--- a/tests/validation/NEON/LSTMLayerQuantized.cpp
+++ b/tests/validation/NEON/LSTMLayerQuantized.cpp
@@ -77,7 +77,9 @@
 
 // *INDENT-OFF*
 // clang-format off
-TEST_CASE(IntegrationTestCaseSmall, framework::DatasetMode::PRECOMMIT)
+TEST_SUITE(IntegrationTestCase)
+TEST_SUITE(MultSmallerEq1)
+TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
 {
     const int batch_size  = 2;
     const int input_size  = 2;
@@ -216,7 +218,7 @@
     validate(Accessor(output_state), expected_output, tolerance_qsymm16);
 }
 
-TEST_CASE(IntegrationTestCaseLarge, framework::DatasetMode::PRECOMMIT)
+TEST_CASE(RunLarge, framework::DatasetMode::PRECOMMIT)
 {
     const int batch_size  = 16;
     const int input_size  = 8;
@@ -453,11 +455,154 @@
     lstmq.run();
     validate(Accessor(output_state), expected_output, tolerance_qsymm16);
 }
+TEST_SUITE_END() // MultSmallerEq1
+
+TEST_SUITE(MultGreater1)
+TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
+{
+    //Input sequence length is 1
+    const int batch_size  = 2;
+    const int input_size  = 2;
+    const int output_size = 4;
+
+    QuantizationInfo qasymm(1.f / 128.f, 128);
+    QuantizationInfo qweights(1.f / 16.f, 16);
+    QuantizationInfo qsymm_3(8.f / 32768.f, 0);
+    QuantizationInfo qsymm_4(16.f / 32768.f, 0);
+
+    TensorShape input_shape{ input_size, batch_size };
+    TensorShape input_weights_shape{ input_size, output_size };
+    TensorShape recurrent_weights_shape{ output_size, output_size };
+    TensorShape output_shape{ output_size, batch_size};
+    TensorShape bias_shape{ output_size };
+
+    auto input_to_input_weights      = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto input_to_forget_weights     = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto input_to_cell_weights       = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto input_to_output_weights     = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto recurrent_to_input_weights  = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto recurrent_to_forget_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto recurrent_to_cell_weights   = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto recurrent_to_output_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+    auto input_gate_bias             = create_tensor<Tensor>(bias_shape, DataType::S32);
+    auto forget_gate_bias            = create_tensor<Tensor>(bias_shape, DataType::S32);
+    auto cell_gate_bias              = create_tensor<Tensor>(bias_shape, DataType::S32);
+    auto output_gate_bias            = create_tensor<Tensor>(bias_shape, DataType::S32);
+
+    // LSTM input
+    auto input = create_tensor<Tensor>(input_shape, DataType::QASYMM8, 1, qasymm);
+
+    // LSTM output state
+    auto output_state = create_tensor<Tensor>(output_shape, DataType::QASYMM8, 1, qasymm);
+
+    // LSTM cell state
+    auto cell_state = create_tensor<Tensor>(output_shape, DataType::QSYMM16, 1, qsymm_4);
+
+    NELSTMLayerQuantized lstmq;
+
+    lstmq.configure(&input, &input_to_input_weights, &input_to_forget_weights, &input_to_cell_weights, &input_to_output_weights,
+                    &recurrent_to_input_weights, &recurrent_to_forget_weights, &recurrent_to_cell_weights, &recurrent_to_output_weights,
+                    &input_gate_bias, &forget_gate_bias, &cell_gate_bias, &output_gate_bias, &cell_state, &output_state, &cell_state, &output_state);
+
+    input.allocator()->allocate();
+    input_to_input_weights.allocator()->allocate();
+    input_to_forget_weights.allocator()->allocate();
+    input_to_cell_weights.allocator()->allocate();
+    input_to_output_weights.allocator()->allocate();
+    recurrent_to_input_weights.allocator()->allocate();
+    recurrent_to_forget_weights.allocator()->allocate();
+    recurrent_to_cell_weights.allocator()->allocate();
+    recurrent_to_output_weights.allocator()->allocate();
+    input_gate_bias.allocator()->allocate();
+    forget_gate_bias.allocator()->allocate();
+    cell_gate_bias.allocator()->allocate();
+    output_gate_bias.allocator()->allocate();
+    cell_state.allocator()->allocate();
+    output_state.allocator()->allocate();
+
+    // Fill weights and biases
+    fill_tensor(input_to_input_weights, std::vector<uint8_t>{ 122,  130,
+                                                              124,  134,
+                                                               120,   122,
+                                                             134,  134 });
+
+    fill_tensor(input_to_forget_weights, std::vector<uint8_t> { 204,  193,
+                                                                148,  59,
+                                                                113,  17,
+                                                                 66, 197 });
+
+    fill_tensor(input_to_cell_weights, std::vector<uint8_t> { 172,  101,
+                                                              184, 209,
+                                                              165,  82,
+                                                              108, 209 });
+
+    fill_tensor(input_to_output_weights, std::vector<uint8_t> { 203, 244,
+                                                                219, 114,
+                                                                130,  16,
+                                                                163, 222 });
+
+    fill_tensor(recurrent_to_input_weights, std::vector<uint8_t> { 162, 168,  7,  95,
+                                                                    91, 155, 108, 216,
+                                                                   255, 100,  48, 188,
+                                                                    58,  37, 186, 147 });
+
+    fill_tensor(recurrent_to_forget_weights, std::vector<uint8_t> {  46,  58,  47, 170,
+                                                                    246,  96,  12,  99,
+                                                                     68,  23, 186, 161,
+                                                                    237, 164,  89,   6 });
+
+    fill_tensor(recurrent_to_cell_weights, std::vector<uint8_t> { 234,  99,   71, 206,
+                                                                  205, 159,   64, 253,
+                                                                  191, 148,  116,   8,
+                                                                  209, 136,   59, 138 });
+
+    fill_tensor(recurrent_to_output_weights, std::vector<uint8_t> {  23, 241, 137, 36,
+                                                                    206,   5, 227, 56,
+                                                                    254, 176, 231, 47,
+                                                                     18, 201, 161, 11 });
+
+    fill_tensor(input_gate_bias, std::vector<int>  {-103038,   30525,  115255, -38154 });
+    fill_tensor(forget_gate_bias, std::vector<int> { -23428,  126970,  116806,  46307 });
+    fill_tensor(cell_gate_bias, std::vector<int>   { 128006,   69949,  -42808,  42568 });
+    fill_tensor(output_gate_bias, std::vector<int> { -67066,  -53607,   47233,  7300  });
+
+    SimpleTensor<uint8_t> expected_output(output_shape, DataType::QASYMM8, 1, qasymm);
+
+    // Initialize state
+    fill_tensor(output_state, std::vector<uint8_t> { 128, 128, 128, 128,
+                                                     128, 128, 128, 128 });
+    fill_tensor(cell_state, std::vector<int16_t> { 0, 0, 0, 0,
+                                                   0, 0, 0, 0 });
+
+    // First input
+    fill_tensor(input, std::vector<uint8_t> { 106,  193,
+                                              155,  150 });
+
+    fill_tensor(expected_output, std::vector<uint8_t> { 128, 128,  31, 128,
+                                                        128, 128,  31, 128 });
+
+    lstmq.run();
+    validate(Accessor(output_state), expected_output);
+
+    // Second input
+    fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 5, 128,
+                                                        128, 128, 5, 128 });
+    lstmq.run();
+    validate(Accessor(output_state), expected_output);
+
+    // Third input
+    fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 1, 128,
+                                                        128, 128, 1, 128, });
+    lstmq.run();
+    validate(Accessor(output_state), expected_output);
+}
+TEST_SUITE_END() // MultGreater1
+TEST_SUITE_END() // IntegrationTestCase
 // clang-format on
 // *INDENT-ON*
 
 TEST_SUITE_END() // LSTMLayerQuantized
-TEST_SUITE_END() // CL
+TEST_SUITE_END() // NEON
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
diff --git a/tests/validation/reference/GEMMLowp.cpp b/tests/validation/reference/GEMMLowp.cpp
index 97d0532..4283cb5 100644
--- a/tests/validation/reference/GEMMLowp.cpp
+++ b/tests/validation/reference/GEMMLowp.cpp
@@ -112,7 +112,14 @@
         }
 
         // Fixed point multiplication
-        result = asymm_rounding_divide_by_pow2(asymm_int_mult(result, result_fixedpoint_multiplier), result_shift);
+        if(result_shift < 0)
+        {
+            result = asymm_int_mult(result * (1 << (-result_shift)), result_fixedpoint_multiplier);
+        }
+        else
+        {
+            result = asymm_rounding_divide_by_pow2(asymm_int_mult(result, result_fixedpoint_multiplier), result_shift);
+        }
 
         // Bounded ReLu
         if(min != max)