COMPMID-2847: Fuse output stage in GEMMLowpMatrixMultiplyReshapedOnlyRHS

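The GEMMLowp output stage is now fused into the matrix multiplication
kernel, so the requantized QASYMM8/QASYMM8_SIGNED result can be written
directly instead of producing an intermediate S32 tensor and running a
separate output stage kernel. The quantization offsets, the LHS/RHS
block information and the output stage parameters are all carried by a
single GEMMKernelInfo descriptor.

Minimal configuration sketch (illustrative only: tensor setup is
omitted; m, n, k, the offsets and the requantization parameters are
placeholders; the gemmlowp_* field names assume the
GEMMLowpOutputStageInfo definition in arm_compute/core/Types.h):

    GEMMKernelInfo gemm_info;
    gemm_info.m = m;
    gemm_info.n = n;
    gemm_info.k = k;
    gemm_info.lhs_info.m0 = 4;  // supported: 2,3,4,5,6,7,8
    gemm_info.lhs_info.k0 = 16; // supported: 2,3,4,8,16
    gemm_info.rhs_info.n0 = 4;  // supported: 2,3,4,8,16
    gemm_info.rhs_info.k0 = gemm_info.lhs_info.k0; // must match lhs k0
    gemm_info.rhs_info.transpose = true;           // required
    gemm_info.a_offset = a_offset; // added to each element of matrix A
    gemm_info.b_offset = b_offset; // added to each element of matrix B
    // Only QUANTIZE_DOWN_FIXEDPOINT can be fused
    gemm_info.output_stage.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    gemm_info.output_stage.gemmlowp_offset     = out_offset;
    gemm_info.output_stage.gemmlowp_multiplier = out_multiplier;
    gemm_info.output_stage.gemmlowp_shift      = out_shift;
    gemm_info.output_stage.gemmlowp_min_bound  = 0;   // QASYMM8 range
    gemm_info.output_stage.gemmlowp_max_bound  = 255;

    CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel mm_kernel;
    mm_kernel.configure(&input0, &reshaped_input1, &output, gemm_info,
                        &vector_sum_col, &vector_sum_row, &bias,
                        &output_multipliers, &output_shifts);
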
Change-Id: Icd60eb368a34295434e8c141885b4666973a92a1
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2732
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h
index 9dd5496..7845d24 100644
--- a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,6 +25,7 @@
 #define ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H
 
 #include "arm_compute/core/CL/ICLKernel.h"
+#include "arm_compute/core/KernelDescriptors.h"
 
 namespace arm_compute
 {
@@ -33,6 +34,7 @@
 /** OpenCL kernel to multiply matrices with QASYMM8 data type when only the input matrix RHS (input1) has been reshaped
  *
  * @note The input matrix input1 must be reshaped through @ref CLGEMMReshapeRHSMatrixKernel
+ * @note For the fused output stage, only GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT is supported
  */
 class CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel : public ICLKernel
 {
@@ -49,37 +51,59 @@
     CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel &operator=(CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel &&) = default;
     /** Initialise the kernel's input and output.
      *
-     * @param[in]  input0    Input tensor containing the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED
-     * @param[in]  input1    Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0
-     * @param[out] output    Output tensor to store the result of matrix multiplication. Data type supported: S32
-     * @param[in]  lhs_info  LHS matrix information used to retrieve the number of rows to be processed by each thread
-     *                       lhs_info.m0: 2,3,4,5,6,7,8
-     *                       lhs_info.k0: 2,3,4,8,16
-     * @param[in]  rhs_info  RHS matrix information used for reshaping the input1 tensor.  Only the following values are supported:
-     *                       rhs_info.n0: 2,3,4,8,16
-     *                       rhs_info.k0: 2,3,4,8,16
-     *                       rhs_info.transpose: true
-     * @param[in]  gemm_info GEMM information used to retrieve the original dimensions of the input matrices
+     * @param[in]  input0             Input tensor containing the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED
+     * @param[in]  input1             Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0
+     * @param[out] output             Output tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/S32.
+     * @param[in]  gemm_info          GEMM information used to retrieve the original dimensions of the input matrices, output stage information and RHS/LHS info.
+     *                                Only the following values are supported for LHS info:
+     *                                lhs_info.m0: 2,3,4,5,6,7,8
+     *                                lhs_info.k0: 2,3,4,8,16
+     *                                Only the following values are supported for RHS info:
+     *                                rhs_info.n0: 2,3,4,8,16
+     *                                rhs_info.k0: same as lhs_info.k0
+     *                                rhs_info.transpose: true
+     * @param[in]  vector_sum_col     (Optional) Input row-vector of sums of all the entries in each column of matrix B.
+     *                                Note: vector_sum_col can be a nullptr when a_offset = 0. Data type supported: S32
+     * @param[in]  vector_sum_row     (Optional) Input row-vector of sums of all the entries in each row of matrix A.
+     *                                Note: vector_sum_row can be a nullptr when b_offset = 0. Data type supported: S32
+     * @param[in]  bias               (Optional) Biases tensor. Only shared biases are supported, and it can be a nullptr if the addition of biases is not required.
+     *                                Biases are a 1D tensor with dimensions [OFM]. Data type supported: S32.
+     * @param[in]  output_multipliers (Optional) Output multipliers tensor. In case of per-channel quantization, the number of multipliers must be equal to the number of filters (OFM).
+     *                                Supported data types: S32.
+     * @param[in]  output_shifts      (Optional) Output shifts tensor. In case of per-channel quantization, the number of shifts must be equal to the number of filters (OFM).
+     *                                Supported data types: S32.
      */
-    void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info);
+    void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMKernelInfo &gemm_info, const ICLTensor *vector_sum_col = nullptr,
+                   const ICLTensor *vector_sum_row = nullptr, const ICLTensor *bias = nullptr, const ICLTensor *output_multipliers = nullptr, const ICLTensor *output_shifts = nullptr);
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel
      *
-     * @param[in] input0    Input tensor info for the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED
-     * @param[in] input1    Input tensor info for the RHS reshaped matrix. Data type supported: same as @p input0
-     * @param[in] output    Output tensor info. Data type supported: S32
-     * @param[in] lhs_info  LHS matrix information used to retrieve the number of rows to be processed by each thread
-     *                      lhs_info.m0: 2,3,4,5,6,7,8
-     *                      lhs_info.k0: 2,3,4,8,16
-     * @param[in] rhs_info  RHS matrix information used for reshaping the input1 tensor.  Only the following values are supported:
-     *                      rhs_info.n0: 2,3,4,8,16
-     *                      rhs_info.k0: same as lhs_info.k0
-     *                      rhs_info.transpose: true
-     * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices
+     * @param[in] input0             Input tensor info for the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED
+     * @param[in] input1             Input tensor info for the RHS reshaped matrix. Data type supported: same as @p input0
+     * @param[in] output             Output tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/S32.
+     * @param[in] gemm_info          GEMM information used to retrieve the original dimensions of the input matrices, output stage information and RHS/LHS info.
+     *                               Only the following values are supported for LHS info:
+     *                               lhs_info.m0: 2,3,4,5,6,7,8
+     *                               lhs_info.k0: 2,3,4,8,16
+     *                               Only the following values are supported for RHS info:
+     *                               rhs_info.n0: 2,3,4,8,16
+     *                               rhs_info.k0: same as lhs_info.k0
+     *                               rhs_info.transpose: true
+     * @param[in] vector_sum_col     (Optional) Input row-vector info of sums of all the entries in each column of matrix B.
+     *                               Note: vector_sum_col can be a nullptr when a_offset = 0. Data type supported: S32
+     * @param[in] vector_sum_row     (Optional) Input row-vector info of sums of all the entries in each row of matrix A.
+     *                               Note: vector_sum_row can be a nullptr when b_offset = 0. Data type supported: S32
+     * @param[in] bias               (Optional) Biases tensor info. Only shared biases are supported, and it can be a nullptr if the addition of biases is not required.
+     *                               Biases are a 1D tensor with dimensions [OFM]. Data type supported: S32.
+     * @param[in] output_multipliers (Optional) Output multipliers tensor info. In case of per-channel quantization, the number of multipliers must be equal to the number of filters (OFM).
+     *                               Supported data types: S32.
+     * @param[in] output_shifts      (Optional) Output shifts tensor info. In case of per-channel quantization, the number of shifts must be equal to the number of filters (OFM).
+     *                               Supported data types: S32.
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
-                           const GEMMReshapeInfo &gemm_info);
+    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMKernelInfo &gemm_info, const ITensorInfo *vector_sum_col = nullptr,
+                           const ITensorInfo *vector_sum_row = nullptr, const ITensorInfo *bias = nullptr, const ITensorInfo *output_multipliers = nullptr,
+                           const ITensorInfo *output_shifts = nullptr);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
@@ -88,10 +112,17 @@
     const ICLTensor *_input0;
     const ICLTensor *_input1;
     ICLTensor       *_output;
+    const ICLTensor *_vector_sum_col;
+    const ICLTensor *_vector_sum_row;
+    const ICLTensor *_bias;
+    const ICLTensor *_output_multipliers;
+    const ICLTensor *_output_shifts;
     bool             _slide_matrix_b;
     bool             _reinterpret_input_as_3d;
     bool             _reinterpret_output_as_3d;
     bool             _use_dummy_work_items;
+    bool             _is_quantized_per_channel;
+    bool             _fuse_output_stage;
 };
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H */
\ No newline at end of file
diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h
index 4b04beb..58400b1 100644
--- a/arm_compute/core/KernelDescriptors.h
+++ b/arm_compute/core/KernelDescriptors.h
@@ -54,14 +54,21 @@
 /** Descriptor used by the GEMM kernels */
 struct GEMMKernelInfo
 {
-    unsigned int        m{ 0 };                           /**< Number of LHS rows*/
-    unsigned int        n{ 0 };                           /**< Number of RHS columns*/
-    unsigned int        k{ 0 };                           /**< Number of LHS columns or RHS rows */
-    unsigned int        depth_output_gemm3d{ 0 };         /**< Depth of the output tensor in case is reinterpreted as 3D */
-    bool                reinterpret_input_as_3d{ false }; /**< Flag used to reinterpret the input as 3D */
-    bool                broadcast_bias{ false };          /**< Flag used to broadcase the bias addition */
-    bool                fp_mixed_precision{ false };      /**< Flag used to indicate wider accumulators (32 bit instead of 16 for FP16). */
-    ActivationLayerInfo activation_info{};                /**< Activation function to perform after the matrix multiplication */
+    unsigned int            m{ 0 };                           /**< Number of LHS rows */
+    unsigned int            n{ 0 };                           /**< Number of RHS columns */
+    unsigned int            k{ 0 };                           /**< Number of LHS columns or RHS rows */
+    unsigned int            depth_output_gemm3d{ 0 };         /**< Depth of the output tensor in case it is reinterpreted as 3D */
+    bool                    reinterpret_input_as_3d{ false }; /**< Flag used to reinterpret the input as 3D */
+    bool                    broadcast_bias{ false };          /**< Flag used to broadcast the bias addition */
+    bool                    fp_mixed_precision{ false };      /**< Flag used to indicate wider accumulators (32 bit instead of 16 for FP16). */
+    ActivationLayerInfo     activation_info{};                /**< Activation function to perform after the matrix multiplication */
+    int                     mult_transpose1xW_width{ 1 };     /**< Multiplication factor for the width of the 1xW transposed block */
+    int                     mult_interleave4x4_height{ 1 };   /**< Multiplication factor for the height of the 4x4 interleaved block */
+    GEMMLHSMatrixInfo       lhs_info{};                       /**< LHS matrix information used to retrieve the number of rows processed by each thread */
+    GEMMRHSMatrixInfo       rhs_info{};                       /**< RHS matrix information used for reshaping the RHS matrix */
+    int32_t                 a_offset{ 0 };                    /**< Offset to be added to each element of the matrix A */
+    int32_t                 b_offset{ 0 };                    /**< Offset to be added to each element of the matrix B */
+    GEMMLowpOutputStageInfo output_stage{};                   /**< GEMMLowp output stage information */
 };
 
 /** Descriptor used by the depthwise convolution kernels */