COMPMID-1706: Fuse the bias addition within CLGEMM

Change-Id: I378f2023f4fa010f195f76716ac07aa86279bfae
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/280
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
index 797bda8..724a7d6 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
@@ -30,12 +30,14 @@
 {
 class ICLTensor;
 
-/** OpenCL kernel to multiply two input matrices "A" and "B" . All elements of the output matrix will be multiplied by alpha
+/** OpenCL kernel to multiply two input matrices "A" and "B" and add a vector "C" if provided. All elements of the output matrix will be multiplied by alpha. In case vector C is passed, it will be added to the previous result (a broadcast addition will be performed).
  *
  * @note If the input tensors @p input0 and @p input1 have been reshaped respectively with @ref CLGEMMReshapeLHSMatrixKernel" and @ref CLGEMMReshapeRHSMatrixKernel,
  *       the flag @p is_interleaved_transposed must be set to true
  *
- * @attention The second input tensor must have at least 2 dimensions (matrix)
+ * @attention Vector C (@p input2) must be 1D. A broadcast addition is performed.
+ *
+ * @attention @p input1 tensor must have at least 2 dimensions (matrix)
  *
  */
 class CLGEMMMatrixMultiplyKernel : public ICLKernel
@@ -55,21 +57,25 @@
      *
      * @param[in]  input0                    Input tensor containing the Matrix A. Data types supported: F16/F32
      * @param[in]  input1                    Input tensor containing the Matrix B. Data type supported: same as @p input0
+     * @param[in]  input2                    Input tensor containing the Vector C. Can be nullptr. Data type supported: same as @p input0
      * @param[out] output                    Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
      * @param[in]  alpha                     Weight of the matrix product
+     * @param[in]  beta                      (Optional) Weight of vector C. Default value is 0. Only beta = 1 is currently supported.
      * @param[in]  is_interleaved_transposed (Optional) True if input0 and input1 have been reshaped respectively using @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel
      * @param[in]  reshape_info              (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
      * @param[in]  fp_mixed_precision        (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy
      *
      */
-    void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed = true, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo(),
-                   bool fp_mixed_precision = false);
+    void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta = 0.f,
+                   bool is_interleaved_transposed = true, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo(), bool fp_mixed_precision = false);
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyKernel
      *
-     * @param[in] input0                    Input tensor containing the Matrix A. Data types supported: F16/F32
-     * @param[in] input1                    Input tensor containing the Matrix B. Data type supported: same as @p input0
+     * @param[in] input0                    Input tensor containing the Matrix A info. Data types supported: F16/F32
+     * @param[in] input1                    Input tensor containing the Matrix B info. Data type supported: same as @p input0
+     * @param[in] input2                    Input tensor containing the Vector C info. Can be nullptr. Data type supported: same as @p input0
      * @param[in] output                    Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
      * @param[in] alpha                     Weight of the matrix product
+     * @param[in] beta                      Weight of vector C. Default value is 0. Only beta = 1 is currently supported.
      * @param[in] is_interleaved_transposed True if input0 and input1 have been reshaped respectively using @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel
      * @param[in] reshape_info              GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
      * @param[in] gpu_target                GPU Target
@@ -77,8 +83,8 @@
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
-                           GPUTarget gpu_target, bool fp_mixed_precision = false);
+    static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta,
+                           bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision = false);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
@@ -86,10 +92,12 @@
 public:
     const ICLTensor *_input0;
     const ICLTensor *_input1;
+    const ICLTensor *_input2;
     ICLTensor       *_output;
     bool             _slide_matrix_b;
     bool             _reinterpret_input_as_3d;
     bool             _reinterpret_output_as_3d;
+    bool             _has_vec_c;
 };
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CLGEMMMATRIXMULTIPLYKERNEL_H__ */
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index d6d88ce..e800dd7 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -136,7 +136,7 @@
     CLGEMM                                              _mm_gemm;
     CLGEMMLowpMatrixMultiplyCore                        _mm_gemmlowp;
     CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
-    CLGEMMMatrixAccumulateBiasesKernel                  _accumulate_biases_kernel;
+    CLGEMMMatrixAccumulateBiasesKernel                  _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer
     CLTensor                                            _flatten_output;
     CLTensor                                            _gemmlowp_output;
     CLTensor                                            _converted_weights_output;
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
index d7694a8..b304576 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -163,7 +163,7 @@
      * @param[in, out] output                Output tensor. Data types supported: Same as @p input,
      *                                       except for input of QASYMM8 type where output should be of S32 type.
      * @param[in]      gemmlowp_output_stage GEMMLowp output stage info
-     * @param[in]      gemm_3d_depth         (Optional) Depth of GEMM 3D (Defaults to 1)
+     * @param[in]      gemm_3d_depth         Depth of GEMM 3D
      */
     void configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth = 1);
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMConvolutionLayer matrix multiply routines
@@ -175,13 +175,14 @@
      * @param[in] biases                Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
      *                                  Data type supported: Should match @p input data type, except for input of QASYMM8 type where biases should be of S32 type.
      * @param[in] gemmlowp_output_stage GEMMLowp output stage info
-     * @param[in] gemm_3d_depth         (Optional) Depth of GEMM 3D (Defaults to 1)
-     * @param[in] skip_im2col           (Optional) Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout. (Default to false)
+     * @param[in] gemm_3d_depth         Depth of GEMM 3D
+     * @param[in] skip_im2col           Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout.
+     * @param[in] run_addition          Flag which specifies if @ref CLGEMMMatrixMatrixMultiplyAddition to be run.
      *
      * @return a status
      */
     static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage,
-                              int gemm_3d_depth = 1, bool skip_im2col = false);
+                              int gemm_3d_depth, bool skip_im2col, bool run_addition);
 
 private:
     CLMemoryGroup                        _memory_group;
@@ -207,6 +208,7 @@
     bool _is_quantized;
     bool _is_activationlayer_enabled;
     bool _is_prepared;
+    bool _run_addition;
 };
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CLGEMMCONVOLUTIONLAYER_H__ */
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 3359a42..4736f80 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -1784,6 +1784,8 @@
 /** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
  *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
  *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+ *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
@@ -1796,6 +1798,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1808,6 +1812,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1821,6 +1829,9 @@
  */
 __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                                 VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
@@ -1910,6 +1921,16 @@
     c30 = c30 * (float4)ALPHA;
 #endif // defined(ALPHA)
 
+#if defined(ADD_VEC_C)
+    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    float4          c0        = vload4(0, src2_addr);
+
+    c00 += c0;
+    c10 += c0;
+    c20 += c0;
+    c30 += c0;
+#endif /* defined(ADD_VEC_C) */
+
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
@@ -1959,7 +1980,9 @@
 }
 
 /** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
- *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
+ *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication.
+ *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
  *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
@@ -1974,6 +1997,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1986,6 +2011,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1999,6 +2028,9 @@
  */
 __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                                         VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                                          IMAGE_DECLARATION(dst),
                                                          uint src0_stride_z,
                                                          uint src1_stride_z,
@@ -2223,6 +2255,28 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+#if defined(ADD_VEC_C)
+    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    float4          c0        = vload4(0, src2_addr);
+
+    c00 += c0.s0;
+    c01 += c0.s1;
+    c02 += c0.s2;
+    c03 += c0.s3;
+    c10 += c0.s0;
+    c11 += c0.s1;
+    c12 += c0.s2;
+    c13 += c0.s3;
+    c20 += c0.s0;
+    c21 += c0.s1;
+    c22 += c0.s2;
+    c23 += c0.s3;
+    c30 += c0.s0;
+    c31 += c0.s1;
+    c32 += c0.s2;
+    c33 += c0.s3;
+#endif /* defined(ADD_VEC_C) */
+
 #if defined(REINTERPRET_OUTPUT_AS_3D)
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
@@ -2275,6 +2329,8 @@
 /** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
  *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
  *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+ *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
@@ -2287,6 +2343,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2299,6 +2357,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2312,6 +2374,9 @@
  */
 __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                                 VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
@@ -2401,6 +2466,20 @@
     c30 = c30 * (half8)ALPHA;
 #endif // defined(ALPHA)
 
+#if defined(ADD_VEC_C)
+    // *INDENT-OFF*
+    // clang-format off
+    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    half8          c0        = vload8(0, src2_addr);
+    // clang-format on
+    // *INDENT-ON*
+
+    c00 += c0;
+    c10 += c0;
+    c20 += c0;
+    c30 += c0;
+#endif /* defined(ADD_VEC_C) */
+
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
@@ -2452,6 +2531,8 @@
 /** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32 floating point variable.
  *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
  *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+ *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
@@ -2464,6 +2545,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2476,6 +2559,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2489,6 +2576,9 @@
  */
 __kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                        IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                                       VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                                        IMAGE_DECLARATION(dst),
                                                        uint src0_stride_z,
                                                        uint src1_stride_z,
@@ -2578,6 +2668,20 @@
     c30 = c30 * (float8)ALPHA;
 #endif // defined(ALPHA)
 
+#if defined(ADD_VEC_C)
+    // *INDENT-OFF*
+    // clang-format off
+    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    float8         c0        = convert_float8(vload8(0, src2_addr));
+    // clang-format on
+    // *INDENT-ON*
+
+    c00 += c0;
+    c10 += c0;
+    c20 += c0;
+    c30 += c0;
+#endif /* defined(ADD_VEC_C) */
+
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
@@ -2629,6 +2733,8 @@
 /** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
  *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
  *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+ *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
@@ -2641,6 +2747,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2653,6 +2761,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2663,6 +2775,9 @@
  */
 __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                                         VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                                          IMAGE_DECLARATION(dst),
                                                          uint src0_stride_z,
                                                          uint src1_stride_z,
@@ -2834,6 +2949,20 @@
     c30 = c30 * (half8)ALPHA;
 #endif // defined(ALPHA)
 
+#if defined(ADD_VEC_C)
+    // *INDENT-OFF*
+    // clang-format off
+    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    half8          c0        = vload8(0, src2_addr);
+    // clang-format on
+    // *INDENT-ON*
+
+    c00 += c0;
+    c10 += c0;
+    c20 += c0;
+    c30 += c0;
+#endif /* defined(ADD_VEC_C) */
+
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
@@ -2892,7 +3021,9 @@
 #if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
 #if defined(DATA_TYPE)
 #define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
-/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
+/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
+ *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
  *
  * @note This OpenCL kernel works with floating point data types (F16/F32)
  * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
@@ -2908,6 +3039,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2920,6 +3053,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2934,6 +3071,9 @@
  */
 __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                      IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                     VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                      IMAGE_DECLARATION(dst),
                                      uint src0_stride_z,
                                      uint src1_stride_z,
@@ -3134,6 +3274,26 @@
     acc3 = acc3 * (VECTOR_TYPE)ALPHA;
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
 
+#if defined(ADD_VEC_C)
+    // *INDENT-OFF*
+    // clang-format off
+    __global DATA_TYPE *src2_addr = (__global DATA_TYPE *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    VECTOR_TYPE         c0        = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src2_addr);
+    // clang-format on
+    // *INDENT-ON*
+
+    acc0 += c0;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    acc1 += c0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    acc2 += c0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    acc3 += c0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif /* defined(ADD_VEC_C) */
+
     int z = get_global_id(2);
 
 #if defined(REINTERPRET_OUTPUT_AS_3D)
@@ -3204,6 +3364,8 @@
 
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
  *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+ *
  * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
  * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
@@ -3219,6 +3381,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -3231,6 +3395,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -3245,6 +3413,9 @@
  */
 __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                                 VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
@@ -3599,6 +3770,34 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+#if defined(ADD_VEC_C)
+    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    float4          c0        = vload4(0, src2_addr);
+
+    acc00 += c0.s0;
+    acc01 += c0.s1;
+    acc02 += c0.s2;
+    acc03 += c0.s3;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    acc10 += c0.s0;
+    acc11 += c0.s1;
+    acc12 += c0.s2;
+    acc13 += c0.s3;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    acc20 += c0.s0;
+    acc21 += c0.s1;
+    acc22 += c0.s2;
+    acc23 += c0.s3;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    acc30 += c0.s0;
+    acc31 += c0.s1;
+    acc32 += c0.s2;
+    acc33 += c0.s3;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif /* defined(ADD_VEC_C) */
+
 #if defined(REINTERPRET_OUTPUT_AS_3D)
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
@@ -3658,6 +3857,8 @@
 
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
  *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+ *
  * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
  * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
@@ -3674,6 +3875,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -3686,6 +3889,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -3700,6 +3907,9 @@
  */
 __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                                      VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
@@ -3986,6 +4196,26 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+#if defined(ADD_VEC_C)
+    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    float2          c0        = vload2(0, src2_addr);
+
+    acc00 += c0.s0;
+    acc01 += c0.s1;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    acc10 += c0.s0;
+    acc11 += c0.s1;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    acc20 += c0.s0;
+    acc21 += c0.s1;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    acc30 += c0.s0;
+    acc31 += c0.s1;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif /* defined(ADD_VEC_C) */
+
 #if defined(REINTERPRET_OUTPUT_AS_3D)
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
@@ -4046,6 +4276,8 @@
 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
  *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+ *
  * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulating the result in a 32 floating point variable.
  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
  * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
@@ -4061,6 +4293,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -4073,6 +4307,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -4087,6 +4325,9 @@
  */
 __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                        IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                                       VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                                        IMAGE_DECLARATION(dst),
                                                        uint src0_stride_z,
                                                        uint src1_stride_z,
@@ -4327,6 +4568,26 @@
 #endif // defined(ALPHA)
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
+#if defined(ADD_VEC_C)
+    // *INDENT-OFF*
+    // clang-format off
+    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    half8          c0        = vload8(0, src2_addr);
+    // clang-format on
+    // *INDENT-ON*
+
+    hacc0 += c0;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    hacc1 += c0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    hacc2 += c0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    hacc3 += c0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif /* defined(ADD_VEC_C) */
+
     int z = get_global_id(2);
 
     // Compute destination address
@@ -4394,6 +4655,8 @@
 
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
  *
+ * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+ *
  * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
  * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
@@ -4409,6 +4672,8 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
+ * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -4421,6 +4686,10 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -4435,6 +4704,9 @@
  */
 __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
+#if defined(ADD_VEC_C)
+                                                 VECTOR_DECLARATION(src2),
+#endif /* defined(ADD_VEC_C) */
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
@@ -4659,6 +4931,26 @@
     acc3 = acc3 * (half8)ALPHA;
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
 
+#if defined(ADD_VEC_C)
+    // *INDENT-OFF*
+    // clang-format off
+    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
+    half8          c0        = vload8(0, src2_addr);
+    // clang-format on
+    // *INDENT-ON*
+
+    acc0 += c0;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    acc1 += c0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    acc2 += c0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    acc3 += c0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif /* defined(ADD_VEC_C) */
+
     int z = get_global_id(2);
 
     // Compute destination address
diff --git a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
index 825d7fb..803ed30 100644
--- a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -86,14 +86,13 @@
     _input  = input;
     _output = output;
 
-    std::ostringstream ma_arguments;
-    ma_arguments << "-DBETA=" << beta;
-    std::set<std::string> build_opts;
-    build_opts.emplace(ma_arguments.str());
+    // Create build options
+    CLBuildOptions build_opts;
+    build_opts.add_option("-DBETA=" + float_to_string_with_full_precision(beta));
 
     // Create kernel
     std::string data_type_name = lower_string(string_from_data_type(input->info()->data_type()));
-    _kernel                    = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(("gemm_ma_" + data_type_name), build_opts));
+    _kernel                    = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(("gemm_ma_" + data_type_name), build_opts.options()));
 
     // Configure kernel window
     auto win_config = validate_and_configure_window(input->info(), output->info());
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index b667621..2b004c2 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -48,8 +48,8 @@
 {
 using ElementsProcessed = Steps;
 
-inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
-                                 bool fp_mixed_precision)
+inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float beta,
+                                 bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
     ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
@@ -61,9 +61,20 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && reshape_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D");
 
+    const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f;
+    const bool has_vec_c   = input2 != nullptr && beta != 0.f;
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(has_vec_c && !is_beta_one, "Adding input2 is only supported for beta equal to 1");
+
     if(!is_interleaved_transposed)
     {
         ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
+
+        if(has_vec_c)
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2);
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor");
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->dimension(0) != input1->dimension(0), "Length of Vector C must match the number of columns of matrix B");
+        }
     }
     else
     {
@@ -101,6 +112,12 @@
 
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
+
+        if(has_vec_c)
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2);
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor");
+        }
     }
 
     if(output->total_size() != 0)
@@ -113,10 +130,11 @@
     return Status{};
 }
 
-inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output,
-                                                               bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target,
+inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output,
+                                                               float beta, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target,
                                                                ElementsProcessed &num_elements_processed)
 {
+    ARM_COMPUTE_UNUSED(beta);
     bool   window_changed = false;
     Window win{};
     Window win_out{};
@@ -126,6 +144,7 @@
     unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1];
     bool           reinterpret_input_as_3d             = reshape_info.reinterpret_input_as_3d();
     bool           reinterpret_output_as_3d            = (reshape_info.depth_output_gemm3d() != 0);
+    const bool     has_vec_c                           = input2 != nullptr && beta != 0.f;
 
     // In case both input and output have to be reinterpreted as 3D tensors,
     // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
@@ -176,6 +195,11 @@
 
         window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
                          update_window_and_padding(win_out, output_access);              // window used to update the padding requirements of output tensor
+        if(has_vec_c)
+        {
+            AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x);
+            window_changed = window_changed || update_window_and_padding(win, input2_access);
+        }
 
         output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
     }
@@ -209,6 +233,11 @@
 
         window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
                          update_window_and_padding(win_out, output_access);              // window used to update the padding requirements of output tensor
+        if(has_vec_c)
+        {
+            AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x);
+            window_changed = window_changed || update_window_and_padding(win, input2_access);
+        }
 
         Coordinates coord;
         coord.set_num_dimensions(output->num_dimensions());
@@ -227,20 +256,22 @@
 } // namespace
 
 CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
-    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false)
+    : _input0(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _has_vec_c(false)
 {
 }
 
-void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
-                                           bool fp_mixed_precision)
+void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta,
+                                           bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
 
     // Perform validate step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, fp_mixed_precision));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), (input2 != nullptr) ? input2->info() : nullptr, output->info(), beta,
+                                                  is_interleaved_transposed, reshape_info, fp_mixed_precision));
 
     _input0                   = input0;
     _input1                   = input1;
+    _input2                   = input2;
     _output                   = output;
     _reinterpret_input_as_3d  = reshape_info.reinterpret_input_as_3d();
     _reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0);
@@ -266,7 +297,8 @@
     ElementsProcessed num_elements_processed{};
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, gpu_target, num_elements_processed);
+    auto win_config = validate_and_configure_window(input0->info(), input1->info(), (input2 != nullptr) ? input2->info() : nullptr, output->info(), beta, is_interleaved_transposed, reshape_info,
+                                                    gpu_target, num_elements_processed);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     ICLKernel::configure_internal(win_config.second);
 
@@ -288,6 +320,8 @@
 
     const bool is_bifrost = get_arch_from_target(gpu_target) == GPUTarget::BIFROST;
 
+    _has_vec_c = input2 != nullptr && beta != 0.f;
+
     std::string kernel_name;
     if(is_interleaved_transposed)
     {
@@ -351,6 +385,9 @@
         build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x()));
     }
 
+    // Configure matrix C addition if necessary
+    build_opts.add_option_if(_has_vec_c, "-DADD_VEC_C");
+
     // Create kernel
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
 
@@ -373,16 +410,18 @@
     _config_id += (is_interleaved_transposed ? support::cpp11::to_string(input1->info()->dimension(0)) : support::cpp11::to_string(input1->info()->dimension(1)));
 }
 
-Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed,
-                                            const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision)
+Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta,
+                                            bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision)
 {
     // Note: num_elements_processed will be set in validate_and_configure_window()
     ElementsProcessed num_elements_processed{};
     ARM_COMPUTE_UNUSED(alpha);
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info, fp_mixed_precision));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, input2, output, beta, is_interleaved_transposed, reshape_info, fp_mixed_precision));
     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(),
                                                               input1->clone().get(),
+                                                              (input2 != nullptr) ? input2->clone().get() : nullptr,
                                                               output->clone().get(),
+                                                              beta,
                                                               is_interleaved_transposed,
                                                               reshape_info,
                                                               gpu_target,
@@ -409,10 +448,12 @@
     slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
     slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
 
+    const unsigned int num_arguments_vec_c = (_has_vec_c) ? num_arguments_per_1D_tensor() : 0;
+
     if(_reinterpret_input_as_3d)
     {
         // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
-        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3;
+        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3 + num_arguments_vec_c;
         const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom;
         _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
     }
@@ -420,7 +461,7 @@
     if(_reinterpret_output_as_3d)
     {
         // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor
-        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
+        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0) + num_arguments_vec_c;
         const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom;
         _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
     }
@@ -438,6 +479,10 @@
         unsigned int idx = 0;
         add_2D_tensor_argument(idx, _input0, slice);
         add_2D_tensor_argument(idx, _input1, slice_b);
+        if(_has_vec_c)
+        {
+            add_1D_tensor_argument(idx, _input2, slice);
+        }
         add_2D_tensor_argument(idx, _output, slice);
         _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input0->info()->strides_in_bytes()[2]));
         _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input1->info()->strides_in_bytes()[2]));
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index cd40fc6..e91038f 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -160,6 +160,10 @@
     const auto workload   = static_cast<float>((m * n) / 20.0f);
     _is_new_gemm_reshaped = (workload > 1600.0f) && (get_arch_from_target(gpu_target) == GPUTarget::BIFROST) && _is_interleaved_transposed && (data_type == DataType::F32);
 
+    const bool add_matrix_c  = (beta != 0.f && c != nullptr);
+    const bool is_beta_one   = std::abs(1.0f - beta) < 0.00001f;
+    const bool use_fused_add = is_beta_one && (c != nullptr && c->info()->num_dimensions() == 1) && !_is_new_gemm_reshaped;
+
     // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
     if(_is_interleaved_transposed)
     {
@@ -202,9 +206,8 @@
     if(!_is_new_gemm_reshaped)
     {
         // Configure and tune matrix multiply kernel
-        _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
-                                                                                                            mult_transpose1xW_width, mult_interleave4x4_height,
-                                                                                                            depth_output_gemm3d, reinterpret_input_as_3d),
+        _mm_kernel.configure(matrix_a, matrix_b, (add_matrix_c && !use_fused_add) ? nullptr : c, output, alpha, beta, _is_interleaved_transposed,
+                             GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d),
                              gemm_info.fp_mixed_precision());
         CLScheduler::get().tune_kernel_static(_mm_kernel);
     }
@@ -220,7 +223,7 @@
     }
 
     // Configure matrix addition kernel
-    if(beta != 0 && c != nullptr)
+    if(add_matrix_c && !use_fused_add)
     {
         _ma_kernel.configure(c, output, beta);
         _run_addition = true;
@@ -284,6 +287,10 @@
     const auto workload             = static_cast<float>((m * n) / 20.0f);
     const bool is_new_gemm_reshaped = (workload > 1600.f) && (get_arch_from_target(gpu_target) == GPUTarget::BIFROST) && run_interleave_transpose && (data_type == DataType::F32);
 
+    const bool add_matrix_c  = (beta != 0.f && c != nullptr);
+    const bool is_beta_one   = std::abs(1.0f - beta) < 0.00001f;
+    const bool use_fused_add = is_beta_one && (c != nullptr && c->num_dimensions() == 1) && !is_new_gemm_reshaped;
+
     // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
     if(run_interleave_transpose)
     {
@@ -328,10 +335,11 @@
     if(!is_new_gemm_reshaped)
     {
         // Validate matrix multiply
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, alpha, run_interleave_transpose, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, (add_matrix_c && !use_fused_add) ? nullptr : c, output, alpha, beta,
+                                                                         run_interleave_transpose, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
     }
 
-    if(beta != 0 && c != nullptr)
+    if(add_matrix_c && !use_fused_add)
     {
         // Validate matrix addition kernel
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 3a8b1a5..7105e85 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -93,7 +93,7 @@
 CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _add_bias_kernel(),
       _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false), _skip_col2im(false), _is_quantized(false),
-      _is_activationlayer_enabled(false), _is_prepared(false)
+      _is_activationlayer_enabled(false), _is_prepared(false), _run_addition(true)
 {
 }
 
@@ -101,7 +101,8 @@
                                           int gemm_3d_depth)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), gemmlowp_output_stage, gemm_3d_depth, _skip_im2col));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), gemmlowp_output_stage, gemm_3d_depth, _skip_im2col,
+                                           _run_addition));
 
     const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
                                          gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
@@ -125,13 +126,15 @@
     }
     else
     {
+        // Bias does not need to be added in GEMM if im2col is being used or the Matrix Addition kernel needs to be run
+        const bool skip_bias_in_gemm = _run_addition || !_skip_im2col;
         // Configure matrix multiply function
-        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
+        _mm_gemm.configure(input, weights, (skip_bias_in_gemm) ? nullptr : biases, output, 1.0f, 1.0f, gemm_info);
     }
 }
 
 Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
-                                           const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, bool skip_im2col)
+                                           const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, bool skip_im2col, bool run_addition)
 {
     const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
 
@@ -156,8 +159,10 @@
     }
     else
     {
+        // Bias does not need to be added in GEMM if im2col is being used or the Matrix Addition kernel needs to be run
+        const bool skip_bias_in_gemm = run_addition || !skip_im2col;
         // Perform validation step on Matrix multiply function
-        return CLGEMM::validate(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
+        return CLGEMM::validate(input, weights, (skip_bias_in_gemm) ? nullptr : biases, output, 1.0f, 1.0f, gemm_info);
     }
 }
 
@@ -193,6 +198,8 @@
     _skip_col2im                = data_layout == DataLayout::NHWC;
     _append_bias                = (biases != nullptr) && (!_is_quantized);
     _is_activationlayer_enabled = act_info.enabled();
+    // In case of F16, fused bias will be used in GEMM
+    _run_addition = (_skip_im2col) && (_append_bias) && (data_type != DataType::F16);
 
     // Set the GPU target for im2col and col2im
     _im2col_kernel.set_target(CLScheduler::get().target());
@@ -375,6 +382,8 @@
     const bool skip_im2col                = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
     const bool skip_col2im                = data_layout == DataLayout::NHWC;
     bool       is_activationlayer_enabled = act_info.enabled();
+    // In case of F16, fused bias will be used in GEMM
+    const bool run_addition = (skip_im2col) && (append_bias) && (data_type != DataType::F16);
 
     ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(idx_channel) * num_groups) != input->dimension(idx_channel));
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -429,7 +438,7 @@
         ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, kernel_dims, conv_info, append_bias, dilation, num_groups));
         gemm_input_to_use = &im2col_reshaped_info;
     }
-    else if(append_bias)
+    else if(run_addition)
     {
         // Validate add bias kernel
         ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, output, biases, output, ConvertPolicy::SATURATE));
@@ -496,7 +505,7 @@
     // In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix
     const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0;
 
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, skip_im2col));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, run_addition));
 
     // Validate Col2Im
     if(!skip_col2im)
@@ -537,7 +546,7 @@
         _mm_gemm.run();
     }
 
-    if(_skip_im2col && _append_bias)
+    if(_run_addition)
     {
         CLScheduler::get().enqueue(_add_bias_kernel);
     }
diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index bbf362c..0876ae1 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -44,6 +44,7 @@
     {
         add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 429U), TensorShape(871U, 429U), 1.0f, 0.0f);
         add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f);
+        add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 1.0f);
         add_config(TensorShape(681U, 1023U), TensorShape(213U, 681U), TensorShape(213U, 1023U), TensorShape(213U, 1023U), 0.2f, 1.2f);
         add_config(TensorShape(941U, 1U), TensorShape(623U, 941U), TensorShape(623U, 1U), TensorShape(623U, 1U), 0.4f, 0.7f);
     }
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index 15a3504..ae3c3ed 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -44,6 +44,7 @@
     {
         add_config(TensorShape(21U, 13U), TensorShape(33U, 21U), TensorShape(33U, 13U), TensorShape(33U, 13U), 1.0f, 0.0f);
         add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U), TensorShape(23U, 1U), 1.0f, 0.0f);
+        add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U), TensorShape(23U, 1U), 1.0f, 1.0f);
         add_config(TensorShape(8U, 2U), TensorShape(16U, 8U), TensorShape(16U, 2U), TensorShape(16U, 2U), 1.0f, 0.0f);
         add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U, 12U), TensorShape(21U, 12U), 0.2f, 1.2f);
         add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U), TensorShape(17U, 1U), 0.4f, 0.7f);