Add test for ClGemmLowpMatrixMultiplyCore covering batched matrix multiplication with variable input tensors

Resolves: COMPMID-5506
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Change-Id: I8345a3b7a83ef46f9ec7a77197cc65c933ec9ac6
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8239
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 6d073cd..f1ec81a 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -26,8 +26,8 @@
 
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "tests/framework/Fixture.h"
-#include "tests/validation/reference/GEMMLowp.h"
 #include "tests/validation/Validation.h"
+#include "tests/validation/reference/GEMMLowp.h"
 
 namespace arm_compute
 {
@@ -85,7 +85,7 @@
     }
 }
 
-template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false, bool run_twice = false>
 TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
                                    GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8,
                                    QuantizationInfo b_qinfo = QuantizationInfo(), bool reshape_b_only_on_first_run = false)
@@ -146,12 +146,25 @@
         ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
         fill(AccessorType(bias), 2);
     }
+
+    // Variable inputs: run once, then refill the inputs with new seeds before the final run.
+    if(run_twice)
+    {
+        gemmlowp.run();
+        fill(AccessorType(a), 3); // Fill tensors with new seed after run
+        fill(AccessorType(b), 4);
+        if(is_fused)
+        {
+            fill(AccessorType(bias), 5);
+        }
+    }
+
     // Compute GEMM function
     gemmlowp.run();
     return output;
 }
 
-template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false>
+template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false, bool run_twice = false>
 SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
                                                  DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo())
 {
@@ -196,11 +209,19 @@
         transpose_matrix<TW>(b, b_transposed);
     }
 
+    // Variable inputs: compute once, then refill the inputs with new seeds before the final computation.
+    if(run_twice)
+    {
+        reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
+        fill((pretranspose_A) ? a_transposed : a, 3);
+        fill((pretranspose_B) ? b_transposed : b, 4);
+    }
+
     return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
 }
 }
 
-template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
 class GEMMLowpMatrixMultiplyCoreValidationFixture : public framework::Fixture
 {
 public:
@@ -214,12 +235,12 @@
 protected:
     TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
     {
-        return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t>(shape_a, shape_b, shape_output, a_offset, b_offset);
+        return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset);
     }
 
     SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
     {
-        return compute_gemmlowp_reference<reinterpret_input_as_3d>(shape_a, shape_b, shape_output, a_offset, b_offset);
+        return compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset);
     }
 
     TensorType            _target{};
@@ -1395,7 +1416,7 @@
                                      broadcast_bias ? 1 : m,
                                      broadcast_bias ? 1 : batch_size);
 
-        _target    = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset);
+        _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset);
         if(gemm_validated == true)
         {
             _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, output_stage, a_offset, b_offset);
@@ -1584,7 +1605,7 @@
         const TensorShape lhs_shape(k, m, batch_size);
         const TensorShape rhs_shape(n, k, batch_size);
 
-        _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
+        _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
         if(gemm_validated == true)
         {
             _reference = compute_reference(lhs_shape, rhs_shape, data_type);