COMPMID-2847: Fuse output stage in GEMMLowpMatrixMultiplyReshapedOnlyRHS

Change-Id: Icd60eb368a34295434e8c141885b4666973a92a1
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2732
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index cdb78c2..54b63df 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -75,8 +75,9 @@
       _is_midgard(false),
       _reshape_b_only_on_first_run(false),
       _is_prepared(false),
-      _fuse_output_stage(false),
-      _convert_to_qasymm8(false)
+      _run_output_stage(false),
+      _convert_to_qasymm8(false),
+      _run_offset_contribution(false)
 {
 }
 
@@ -172,34 +173,19 @@
         _mtx_a_reduction_kernel.configure(a, &_vector_sum_row);
     }
 
+    GEMMKernelInfo gemm_kernel_info;
+    gemm_kernel_info.m                       = m;
+    gemm_kernel_info.n                       = n;
+    gemm_kernel_info.k                       = k;
+    gemm_kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
+    gemm_kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
+    gemm_kernel_info.lhs_info                = lhs_info;
+    gemm_kernel_info.rhs_info                = rhs_info;
+    gemm_kernel_info.a_offset                = _a_offset;
+    gemm_kernel_info.b_offset                = _b_offset;
     // If GEMMLowpOutputStage != NONE, fuse the offset contribution with the output stage
     if(gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE)
     {
-        _fuse_output_stage = true;
-
-        _memory_group.manage(&_mm_result_s32);
-
-        if(_is_gemm_reshaped)
-        {
-            // Configure and tune matrix multiply kernel
-            _mm_reshaped_only_rhs_kernel.configure(_matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
-        }
-        else
-        {
-            if(_is_midgard)
-            {
-                // Configure matrix multiply kernel
-                _mm_midgard_kernel.configure(_matrix_a, matrix_b, &_mm_result_s32, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
-            }
-            else
-            {
-                // Pick up the GEMM configuration
-                std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
-
-                // Configure matrix multiply kernel
-                _mm_native_kernel.configure(_matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
-            }
-        }
         // Configure offset contribution kernel
         const size_t num_filters = (gemm_info.gemmlowp_output_stage().is_quantized_per_channel) ? gemm_info.gemmlowp_output_stage().gemmlowp_multipliers.size() : 1;
 
@@ -208,8 +194,46 @@
 
         GEMMLowpOutputStageInfo gemmlowp_output_stage = gemm_info.gemmlowp_output_stage();
         gemmlowp_output_stage.output_data_type        = _matrix_a->info()->data_type();
-        _offset_contribution_output_stage_kernel.configure(&_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, output, a->info()->dimension(0),
-                                                           _a_offset, _b_offset, gemmlowp_output_stage, &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts);
+
+        gemm_kernel_info.output_stage = gemmlowp_output_stage;
+
+        if(_is_gemm_reshaped && gemmlowp_output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+        {
+            // Configure and tune matrix multiply kernel with fused output stage
+            _mm_reshaped_only_rhs_kernel.configure(_matrix_a, matrix_b, output, gemm_kernel_info, _a_offset == 0 ? nullptr : &_vector_sum_col,
+                                                   _b_offset == 0 ? nullptr : &_vector_sum_row, c, &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts);
+        }
+        else
+        {
+            _run_output_stage = true;
+
+            _memory_group.manage(&_mm_result_s32);
+
+            if(_is_gemm_reshaped)
+            {
+                _mm_reshaped_only_rhs_kernel.configure(_matrix_a, matrix_b, &_mm_result_s32, gemm_kernel_info);
+            }
+            else
+            {
+                if(_is_midgard)
+                {
+                    // Configure matrix multiply kernel
+                    _mm_midgard_kernel.configure(_matrix_a, matrix_b, &_mm_result_s32, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+                }
+                else
+                {
+                    // Pick up the GEMM configuration
+                    std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+
+                    // Configure matrix multiply kernel
+                    _mm_native_kernel.configure(_matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+                }
+                _offset_contribution_output_stage_kernel.configure(&_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, output, a->info()->dimension(0),
+                                                                   _a_offset, _b_offset, gemmlowp_output_stage, &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts);
+
+                _mm_result_s32.allocator()->allocate();
+            }
+        }
 
         _gemm_output_stage_multipliers.allocator()->allocate();
         _gemm_output_stage_shifts.allocator()->allocate();
@@ -220,15 +244,14 @@
         std::memcpy(_gemm_output_stage_shifts.ptr_to_element(Coordinates(0)), gemm_info.gemmlowp_output_stage().gemmlowp_shifts.data(), num_filters * sizeof(int32_t));
         _gemm_output_stage_multipliers.unmap();
         _gemm_output_stage_shifts.unmap();
-
-        _mm_result_s32.allocator()->allocate();
     }
     else
     {
+        _run_offset_contribution = true;
         if(_is_gemm_reshaped)
         {
             // Configure and tune matrix multiply kernel
-            _mm_reshaped_only_rhs_kernel.configure(_matrix_a, matrix_b, output, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+            _mm_reshaped_only_rhs_kernel.configure(_matrix_a, matrix_b, output, gemm_kernel_info);
         }
         else
         {
@@ -350,61 +373,85 @@
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(a, &info_vector_sum_row));
     }
 
+    GEMMKernelInfo gemm_kernel_info;
+    gemm_kernel_info.m                       = m;
+    gemm_kernel_info.n                       = n;
+    gemm_kernel_info.k                       = k;
+    gemm_kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
+    gemm_kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
+    gemm_kernel_info.lhs_info                = lhs_info;
+    gemm_kernel_info.rhs_info                = rhs_info;
+    gemm_kernel_info.a_offset                = a_offset;
+    gemm_kernel_info.b_offset                = b_offset;
     if(gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE)
     {
-        TensorInfo mm_result_s32_info{};
-
-        if(reshape_matrix_b)
-        {
-            // Output tensor auto inizialitation if not yet initialized
-            auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, reshape_info)).set_data_type(DataType::S32));
-
-            // Validate matrix multiply
-            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info));
-        }
-        else
-        {
-            // Output tensor auto inizialitation if not yet initialized
-            auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, false, reshape_info)).set_data_type(DataType::S32));
-
-            if(is_midgard)
-            {
-                // Validate matrix multiply
-                ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, reshape_info));
-            }
-            else
-            {
-                // Pick up the GEMM configuration
-                std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
-
-                // Validate matrix multiply
-                ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info));
-            }
-        }
-
-        // Validate offset contribution kernel
         const size_t num_filters = (gemm_info.gemmlowp_output_stage().is_quantized_per_channel) ? gemm_info.gemmlowp_output_stage().gemmlowp_multipliers.size() : 1;
 
         const TensorInfo gemm_output_stage_multipliers_shifts_info(TensorInfo(TensorShape(num_filters), 1, DataType::S32));
 
         GEMMLowpOutputStageInfo gemmlowp_output_stage = gemm_info.gemmlowp_output_stage();
         gemmlowp_output_stage.output_data_type        = a->data_type();
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOffsetContributionOutputStageKernel::validate(&mm_result_s32_info,
-                                                                                            a_offset == 0 ? nullptr : &info_vector_sum_col,
-                                                                                            b_offset == 0 ? nullptr : &info_vector_sum_row,
-                                                                                            c,
-                                                                                            output,
-                                                                                            a_offset, b_offset,
-                                                                                            gemmlowp_output_stage,
-                                                                                            &gemm_output_stage_multipliers_shifts_info,
-                                                                                            &gemm_output_stage_multipliers_shifts_info));
+
+        gemm_kernel_info.output_stage = gemmlowp_output_stage;
+        if(reshape_matrix_b && gemm_info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel::validate(matrix_a_info, matrix_b_info, output, gemm_kernel_info,
+                                                                                                a_offset == 0 ? nullptr : &info_vector_sum_col,
+                                                                                                b_offset == 0 ? nullptr : &info_vector_sum_row,
+                                                                                                c,
+                                                                                                &gemm_output_stage_multipliers_shifts_info,
+                                                                                                &gemm_output_stage_multipliers_shifts_info));
+        }
+        else
+        {
+            TensorInfo mm_result_s32_info{};
+
+            if(reshape_matrix_b)
+            {
+                // Output tensor auto inizialitation if not yet initialized
+                auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, reshape_info)).set_data_type(DataType::S32));
+
+                // Validate matrix multiply
+                ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, gemm_kernel_info));
+            }
+            else
+            {
+                // Output tensor auto inizialitation if not yet initialized
+                auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, false, reshape_info)).set_data_type(DataType::S32));
+
+                if(is_midgard)
+                {
+                    // Validate matrix multiply
+                    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, reshape_info));
+                }
+                else
+                {
+                    // Pick up the GEMM configuration
+                    std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+
+                    // Validate matrix multiply
+                    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info));
+                }
+            }
+
+            // Validate offset contribution kernel
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOffsetContributionOutputStageKernel::validate(&mm_result_s32_info,
+                                                                                                a_offset == 0 ? nullptr : &info_vector_sum_col,
+                                                                                                b_offset == 0 ? nullptr : &info_vector_sum_row,
+                                                                                                c,
+                                                                                                output,
+                                                                                                a_offset, b_offset,
+                                                                                                gemmlowp_output_stage,
+                                                                                                &gemm_output_stage_multipliers_shifts_info,
+                                                                                                &gemm_output_stage_multipliers_shifts_info));
+        }
     }
     else
     {
         if(reshape_matrix_b)
         {
             // Validate matrix multiply
-            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel::validate(matrix_a_info, matrix_b_info, output, lhs_info, rhs_info, reshape_info));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel::validate(matrix_a_info, matrix_b_info, output, gemm_kernel_info));
         }
         else
         {
@@ -458,6 +505,12 @@
         CLScheduler::get().enqueue(_mtx_b_reduction_kernel, false);
     }
 
+    // Run matrix A reduction kernel only if _b_offset is not equal to 0
+    if(_b_offset != 0)
+    {
+        CLScheduler::get().enqueue(_mtx_a_reduction_kernel, false);
+    }
+
     // Run matrix multiply
     if(_is_gemm_reshaped)
     {
@@ -474,19 +527,12 @@
             CLScheduler::get().enqueue(_mm_native_kernel, false);
         }
     }
-
-    // Run matrix A reduction kernel only if _b_offset is not equal to 0
-    if(_b_offset != 0)
-    {
-        CLScheduler::get().enqueue(_mtx_a_reduction_kernel, false);
-    }
-
-    if(_fuse_output_stage)
+    if(_run_output_stage)
     {
         // Run offset contribution/output stage kernel
         CLScheduler::get().enqueue(_offset_contribution_output_stage_kernel, true);
     }
-    else
+    if(_run_offset_contribution)
     {
         // Run offset contribution kernel
         CLScheduler::get().enqueue(_offset_contribution_kernel, true);