COMPMID-2171: Fuse bias addition with CLGEMMMatrixMultiplyReshapedOnlyRHSKernel

Change-Id: I1d1e1f28fe7022309d72900893e8368820ca0f89
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1259
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
index 26a1378..e3b3880 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
@@ -51,8 +51,10 @@
      *
      * @param[in]  input0    Input tensor containing the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4.
      * @param[in]  input1    Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3.
+     * @param[in]  input2    Input tensor containing the bias matrix. Data type supported: same as @p input0.
      * @param[out] output    Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
      * @param[in]  alpha     Weight of the matrix product
+     * @param[in]  beta      Weight of the matrix bias
      * @param[in]  lhs_info  LHS matrix information used to retrieve the number of rows to be processed by each thread. Only the following values are supported:
      *                       lhs_info.m0: 1,2,3,4,5,6,7,8
      * @param[in]  rhs_info  RHS matrix information used for reshaping the input1 tensor.  Only the following values are supported:
@@ -61,14 +63,17 @@
      *                       rhs_info.transpose: true,false
      * @param[in]  gemm_info GEMM information used to retrieve the original dimensions of the input matrices
      */
-    void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
-                   const GEMMReshapeInfo &gemm_info);
+    void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+                   const GEMMRHSMatrixInfo &rhs_info,
+                   const GEMMReshapeInfo   &gemm_info);
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyReshapedOnlyRHSKernel
      *
      * @param[in] input0    Input tensor info for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4.
      * @param[in] input1    Input tensor info for the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3.
+     * @param[in] input2    Input tensor info containing the bias matrix. Data type supported: same as @p input0.
      * @param[in] output    Output tensor info. Data type supported: same as @p input0
      * @param[in] alpha     Weight of the matrix product
+     * @param[in] beta      Weight of the matrix bias
      * @param[in] lhs_info  LHS matrix information used to retrieve the number of rows to be processed by each thread. Only the following values are supported:
      *                      lhs_info.m0: 1,2,3,4,5,6,7,8
      * @param[in] rhs_info  RHS matrix information used for reshaping the input1 tensor.  Only the following values are supported:
@@ -79,8 +84,9 @@
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
-                           const GEMMReshapeInfo &gemm_info);
+    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+                           const GEMMRHSMatrixInfo &rhs_info,
+                           const GEMMReshapeInfo   &gemm_info);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
@@ -88,11 +94,14 @@
 private:
     const ICLTensor *_input0;
     const ICLTensor *_input1;
+    const ICLTensor *_input2;
     ICLTensor       *_output;
     bool             _slide_matrix_b;
     bool             _reinterpret_input_as_3d;
     bool             _reinterpret_output_as_3d;
     bool             _use_dummy_work_items;
+    bool             _add_bias;
+    bool             _broadcast_bias;
 };
 } // namespace arm_compute
-#endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__*/
\ No newline at end of file
+#endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__*/
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 1787e68..d49315d 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1602,7 +1602,7 @@
 public:
     /** Default constructor */
     GEMMReshapeInfo()
-        : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false)
+        : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _broadcast_bias(false)
     {
     }
     /** Constructor
@@ -1615,11 +1615,12 @@
      * @param[in] depth_output_gemm3d       (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
      *                                      If 0 the output will not be reinterpreted as 3D. Default 0
      * @param[in] reinterpret_input_as_3d   (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used
-     *                                                 to perform 1x1 convolutions with the NHWC data layout)
+     *                                      to perform 1x1 convolutions with the NHWC data layout)
+     * @param[in] broadcast_bias            (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
      */
-    GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false)
+    GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool broadcast_bias = false)
         : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d),
-          _reinterpret_input_as_3d(reinterpret_input_as_3d)
+          _reinterpret_input_as_3d(reinterpret_input_as_3d), _broadcast_bias(broadcast_bias)
     {
     }
     /** Number of matrix A rows
@@ -1681,6 +1682,14 @@
     {
         return _reinterpret_input_as_3d;
     };
+    /** Flag which specifies whether to broadcast the shape of the bias tensor.
+     *
+     * @return True if the bias tensor is to be broadcast from a vector to a matrix, false otherwise.
+     */
+    bool broadcast_bias() const
+    {
+        return _broadcast_bias;
+    };
 
 private:
     const int  _m;
@@ -1690,6 +1699,7 @@
     const int  _mult_interleave4x4_height;
     const int  _depth_output_gemm3d;
     const bool _reinterpret_input_as_3d;
+    const bool _broadcast_bias;
 };
 
 struct DepthwiseConvolutionReshapeInfo
@@ -1749,7 +1759,7 @@
     /** Default constructor */
     GEMMInfo()
         : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(true), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false), _gemmlowp_output_stage(),
-          _fp_mixed_precision(false)
+          _fp_mixed_precision(false), _broadcast_bias(false)
     {
     }
     /** Constructor
@@ -1764,12 +1774,13 @@
      * @param[in] retain_internal_weights     (Optional) Retain the weights tensor from previous run
      * @param[in] gemmlowp_output_stage       (Optional) GEMMLowp Output stage info
      * @param[in] fp_mixed_precision          (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
-     *
+     * @param[in] broadcast_bias              (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
      */
     GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
-             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false)
+             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false)
         : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d),
-          _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision)
+          _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision),
+          _broadcast_bias(broadcast_bias)
     {
     }
     /** Flag which specifies if the matrix A has been reshaped
@@ -1838,6 +1849,14 @@
     {
         return _fp_mixed_precision;
     };
+    /** Flag which specifies whether to broadcast the shape of the bias tensor.
+     *
+     * @return True if the bias tensor is to be broadcast from a vector to a matrix, false otherwise.
+     */
+    bool broadcast_bias() const
+    {
+        return _broadcast_bias;
+    };
 
 private:
     const bool                    _is_a_reshaped;
@@ -1848,6 +1867,7 @@
     const bool                    _retain_internal_weights;
     const GEMMLowpOutputStageInfo _gemmlowp_output_stage;
     const bool                    _fp_mixed_precision;
+    const bool                    _broadcast_bias;
 };
 
 /** Winograd information */