COMPMID-2171: Fuse bias addition with CLGEMMMatrixMultiplyReshapedOnlyRHSKernel

Change-Id: I1d1e1f28fe7022309d72900893e8368820ca0f89
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1259
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 41e5c33..2ac2eb7 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -731,29 +731,29 @@
     // 3x4 -> 4x3
     // 3x8 -> 8x3
     // 3x16 -> 16x3
-    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
-    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
+    res0                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
+    res1                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
 #if N0 > 2
-    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
+    res2                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
 #endif // N0 > 2
 #if N0 > 3
-    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
+    res3                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
 #endif // N0 > 3
 #if N0 > 4
-    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
-    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
-    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
-    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
+    res4                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
+    res5                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
+    res6                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
+    res7                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
 #endif // N0 > 4
 #if N0 > 8
-    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
-    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
-    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
-    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
-    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
-    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
-    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
-    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
+    res8                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
+    res9                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
+    resA                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
+    resB                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
+    resC                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
+    resD                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
+    resE                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
+    resF                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
 #endif // N0 > 8
 
 #elif K0 == 4 // K0 == 4
@@ -1029,35 +1029,48 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
  *
- * @param[in]  lhs_ptr                           Pointer to the LHS reshaped matrix. Supported data type: F16/F32
- * @param[in]  lhs_stride_x                      Stride of the LHS reshaped matrix in X dimension (in bytes)
- * @param[in]  lhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  lhs_stride_y                      Stride of the LHS reshaped matrix in Y dimension (in bytes)
- * @param[in]  lhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
- * @param[in]  rhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
- * @param[in]  rhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr                           Pointer to the destination matrix Supported data type: same as @p lhs_ptr
- * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
- * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
- * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in]  lhs_stride_z                      Stride of the LHS reshaped matrix in Z dimension (in bytes)
- * @param[in]  rhs_stride_z                      Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  lhs_cross_plane_pad               (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
- * @param[in]  dst_cross_plane_pad               (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F16/F32
+ * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
+ * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
+ * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
+ * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
+ * @param[in]  rhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
+ * @param[in]  rhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
+ * @param[in]  bias_ptr                           (Optional) Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in]  bias_stride_x                      (Optional) Stride of the bias reshaped matrix in X dimension (in bytes)
+ * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  bias_stride_y                      (Optional) Stride of the bias reshaped matrix in Y dimension (in bytes)
+ * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias reshaped matrix
+ * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
+ * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
+ * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
+ * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
+ * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
                                           IMAGE_DECLARATION(rhs),
+#if defined(BETA)
+                                          IMAGE_DECLARATION(bias),
+#endif // defined(BETA)
                                           IMAGE_DECLARATION(dst),
                                           uint lhs_stride_z,
                                           uint rhs_stride_z,
+#if defined(BETA)
+                                          uint bias_stride_z,
+#endif // defined(BETA)
                                           uint dst_stride_z
 #if defined(REINTERPRET_INPUT_AS_3D)
                                           ,
@@ -1108,7 +1121,7 @@
 #endif // defined(MATRIX_B_DEPTH)
 
     REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
-    REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
+    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
 
 #if defined(REINTERPRET_INPUT_AS_3D)
     // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
@@ -1144,7 +1157,7 @@
         LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
 
         // Load values from RHS matrix
-        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
+        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
 
         // Accumulate
         ARM_DOT_K0XN0(K0, a0, b, c0);
@@ -1181,7 +1194,7 @@
         LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
 
         // Load values from RHS matrix
-        LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
+        LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
 
         // Accumulate
         ARM_DOT_K0XN0(1, a0, b, c0);
@@ -1236,6 +1249,36 @@
     SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
 #endif // defined(ALPHA)
 
+    // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
+
+    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(M0, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
+                                    2) * bias_stride_z;
+
+    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(M0, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
     // Store output block
     STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
 
@@ -1360,35 +1403,48 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
  *
- * @param[in]  lhs_ptr                           Pointer to the LHS reshaped matrix. Supported data type: F16/F32
- * @param[in]  lhs_stride_x                      Stride of the LHS reshaped matrix in X dimension (in bytes)
- * @param[in]  lhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  lhs_stride_y                      Stride of the LHS reshaped matrix in Y dimension (in bytes)
- * @param[in]  lhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
- * @param[in]  rhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
- * @param[in]  rhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr                           Pointer to the destination matrix Supported data type: same as @p lhs_ptr
- * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
- * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
- * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in]  lhs_stride_z                      Stride of the LHS reshaped matrix in Z dimension (in bytes)
- * @param[in]  rhs_stride_z                      Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  lhs_cross_plane_pad               (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
- * @param[in]  dst_cross_plane_pad               (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F16/F32
+ * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
+ * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
+ * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
+ * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
+ * @param[in]  rhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
+ * @param[in]  rhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
+ * @param[in]  bias_ptr                           (Optional) Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in]  bias_stride_x                      (Optional) Stride of the bias reshaped matrix in X dimension (in bytes)
+ * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  bias_stride_y                      (Optional) Stride of the bias reshaped matrix in Y dimension (in bytes)
+ * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
+ * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
+ * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
+ * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
+ * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
+ * @param[in]  bias_stride_z                      (Optional) Stride of the bias reshaped matrix in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
                                            IMAGE_DECLARATION(rhs),
+#if defined(BETA)
+                                           IMAGE_DECLARATION(bias),
+#endif // defined(BETA)
                                            IMAGE_DECLARATION(dst),
                                            uint lhs_stride_z,
                                            uint rhs_stride_z,
+#if defined(BETA)
+                                           uint bias_stride_z,
+#endif // defined(BETA)
                                            uint dst_stride_z
 #if defined(REINTERPRET_INPUT_AS_3D)
                                            ,
@@ -1438,7 +1494,8 @@
     rhs_offset += z * rhs_stride_z;
 #endif // defined(MATRIX_B_DEPTH)
 
-    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
+    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);   //uint zin0=0,zin1=0,zin2=0,... zin7=0;
+    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0;
 
 #if defined(REINTERPRET_INPUT_AS_3D)
 
@@ -1568,6 +1625,36 @@
     SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
 #endif // defined(ALPHA)
 
+    // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
+
+    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(M0, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
+                                    2) * bias_stride_z;
+
+    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(M0, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
     // Store output block
     STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
 
diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h
index 2c76992..cd2d39b 100644
--- a/src/core/CL/cl_kernels/gemm_helpers.h
+++ b/src/core/CL/cl_kernels/gemm_helpers.h
@@ -360,69 +360,69 @@
 #define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
 
 #define SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE) \
-    BASENAME##0 = BASENAME##0 * (DATA_TYPE)SCALE;
+    BASENAME##0 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##1 = BASENAME##1 * (DATA_TYPE)SCALE;
+    BASENAME##1 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##2 = BASENAME##2 * (DATA_TYPE)SCALE;
+    BASENAME##2 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##3 = BASENAME##3 * (DATA_TYPE)SCALE;
+    BASENAME##3 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##4 = BASENAME##4 * (DATA_TYPE)SCALE;
+    BASENAME##4 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##5 = BASENAME##5 * (DATA_TYPE)SCALE;
+    BASENAME##5 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##6 = BASENAME##6 * (DATA_TYPE)SCALE;
+    BASENAME##6 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##7 = BASENAME##7 * (DATA_TYPE)SCALE;
+    BASENAME##7 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##8 = BASENAME##8 * (DATA_TYPE)SCALE;
+    BASENAME##8 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE)      \
-    BASENAME##9 = BASENAME##9 * (DATA_TYPE)SCALE;
+    BASENAME##9 *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##A = BASENAME##A * (DATA_TYPE)SCALE;
+    BASENAME##A *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##B = BASENAME##B * (DATA_TYPE)SCALE;
+    BASENAME##B *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##C = BASENAME##C * (DATA_TYPE)SCALE;
+    BASENAME##C *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##D = BASENAME##D * (DATA_TYPE)SCALE;
+    BASENAME##D *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##E = BASENAME##E * (DATA_TYPE)SCALE;
+    BASENAME##E *= (DATA_TYPE)SCALE;
 
 #define SCALE_ROW_16(DATA_TYPE, BASENAME, SCALE) \
     SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE)     \
-    BASENAME##F = BASENAME##F * (DATA_TYPE)SCALE;
+    BASENAME##F *= (DATA_TYPE)SCALE;
 
-// SCALE_ROW_n scales the variables BASENAME##0 to BASENAME##(n-1) by SCALE
+// SCALE_BLOCK_n scales the variables BASENAME##0 to BASENAME##(n-1) by SCALE
 #define SCALE_BLOCK_STR(N, DATA_TYPE, BASENAME, SCALE) SCALE_ROW_##N(DATA_TYPE, BASENAME, SCALE)
 /** Scale elements stored in variables BASENAME##0 to BASENAME##(N-1) by SCALE
  * Supported cases N=1,2,3..16, for variables BASENAME[0..N]
@@ -479,3 +479,143 @@
 #define TRANSPOSE_K0XN0(K0, N0, BASENAME, B) \
     CONCAT(TRANSPOSE_K0X, N0)                \
     (K0, BASENAME, B);
+
+#define ADD_ROW_1(BASENAME, BIAS) \
+    BASENAME##0 += BIAS##0;
+
+#define ADD_ROW_2(BASENAME, BIAS) \
+    ADD_ROW_1(BASENAME, BIAS)     \
+    BASENAME##1 += BIAS##1;
+
+#define ADD_ROW_3(BASENAME, BIAS) \
+    ADD_ROW_2(BASENAME, BIAS)     \
+    BASENAME##2 += BIAS##2;
+
+#define ADD_ROW_4(BASENAME, BIAS) \
+    ADD_ROW_3(BASENAME, BIAS)     \
+    BASENAME##3 += BIAS##3;
+
+#define ADD_ROW_5(BASENAME, BIAS) \
+    ADD_ROW_4(BASENAME, BIAS)     \
+    BASENAME##4 += BIAS##4;
+
+#define ADD_ROW_6(BASENAME, BIAS) \
+    ADD_ROW_5(BASENAME, BIAS)     \
+    BASENAME##5 += BIAS##5;
+
+#define ADD_ROW_7(BASENAME, BIAS) \
+    ADD_ROW_6(BASENAME, BIAS)     \
+    BASENAME##6 += BIAS##6;
+
+#define ADD_ROW_8(BASENAME, BIAS) \
+    ADD_ROW_7(BASENAME, BIAS)     \
+    BASENAME##7 += BIAS##7;
+
+#define ADD_ROW_9(BASENAME, BIAS) \
+    ADD_ROW_8(BASENAME, BIAS)     \
+    BASENAME##8 += BIAS##8;
+
+#define ADD_ROW_10(BASENAME, BIAS) \
+    ADD_ROW_9(BASENAME, BIAS)      \
+    BASENAME##9 += BIAS##9;
+
+#define ADD_ROW_11(BASENAME, BIAS) \
+    ADD_ROW_10(BASENAME, BIAS)     \
+    BASENAME##A += BIAS##A;
+
+#define ADD_ROW_12(BASENAME, BIAS) \
+    ADD_ROW_11(BASENAME, BIAS)     \
+    BASENAME##B += BIAS##B;
+
+#define ADD_ROW_13(BASENAME, BIAS) \
+    ADD_ROW_12(BASENAME, BIAS)     \
+    BASENAME##C += BIAS##C;
+
+#define ADD_ROW_14(BASENAME, BIAS) \
+    ADD_ROW_13(BASENAME, BIAS)     \
+    BASENAME##D += BIAS##D;
+
+#define ADD_ROW_15(BASENAME, BIAS) \
+    ADD_ROW_14(BASENAME, BIAS)     \
+    BASENAME##E += BIAS##E;
+
+#define ADD_ROW_16(BASENAME, BIAS) \
+    ADD_ROW_15(BASENAME, BIAS)     \
+    BASENAME##F += BIAS##F;
+
+// ADD_ROW_n adds the variables BIAS##0... BIAS##(n-1) to BASENAME##0 to BASENAME##(n-1)
+#define ADD_BLOCK_STR(N, BASENAME, BIAS) ADD_ROW_##N(BASENAME, BIAS)
+/** Add the variables BIAS##0 ... BIAS##(N-1) to BASENAME##0 ... BASENAME##(N-1)
+ * Supported cases N=1,2,3..16, for variables BASENAME[0..N]
+ */
+#define ADD_BLOCK(N, BASENAME, BIAS) ADD_BLOCK_STR(N, BASENAME, BIAS)
+
+#define ADD_ROW_BROADCAST_1(BASENAME, BIAS) \
+    BASENAME##0 += BIAS;
+
+#define ADD_ROW_BROADCAST_2(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_1(BASENAME, BIAS)     \
+    BASENAME##1 += BIAS;
+
+#define ADD_ROW_BROADCAST_3(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_2(BASENAME, BIAS)     \
+    BASENAME##2 += BIAS;
+
+#define ADD_ROW_BROADCAST_4(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_3(BASENAME, BIAS)     \
+    BASENAME##3 += BIAS;
+
+#define ADD_ROW_BROADCAST_5(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_4(BASENAME, BIAS)     \
+    BASENAME##4 += BIAS;
+
+#define ADD_ROW_BROADCAST_6(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_5(BASENAME, BIAS)     \
+    BASENAME##5 += BIAS;
+
+#define ADD_ROW_BROADCAST_7(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_6(BASENAME, BIAS)     \
+    BASENAME##6 += BIAS;
+
+#define ADD_ROW_BROADCAST_8(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_7(BASENAME, BIAS)     \
+    BASENAME##7 += BIAS;
+
+#define ADD_ROW_BROADCAST_9(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_8(BASENAME, BIAS)     \
+    BASENAME##8 += BIAS;
+
+#define ADD_ROW_BROADCAST_10(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_9(BASENAME, BIAS)      \
+    BASENAME##9 += BIAS;
+
+#define ADD_ROW_BROADCAST_11(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_10(BASENAME, BIAS)     \
+    BASENAME##A += BIAS;
+
+#define ADD_ROW_BROADCAST_12(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_11(BASENAME, BIAS)     \
+    BASENAME##B += BIAS;
+
+#define ADD_ROW_BROADCAST_13(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_12(BASENAME, BIAS)     \
+    BASENAME##C += BIAS;
+
+#define ADD_ROW_BROADCAST_14(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_13(BASENAME, BIAS)     \
+    BASENAME##D += BIAS;
+
+#define ADD_ROW_BROADCAST_15(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_14(BASENAME, BIAS)     \
+    BASENAME##E += BIAS;
+
+#define ADD_ROW_BROADCAST_16(BASENAME, BIAS) \
+    ADD_ROW_BROADCAST_15(BASENAME, BIAS)     \
+    BASENAME##F += BIAS;
+
+// ADD_ROW_BROADCAST_n adds the variable BIAS to BASENAME##0 to BASENAME##(n-1)
+#define ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS) ADD_ROW_BROADCAST_##N(BASENAME, BIAS)
+/** Broadcast (add the single variable) BIAS to BASENAME##0 ... BASENAME##(N-1)
+ * Supported cases N=1,2,3..16, for variables BASENAME[0..N]
+ */
+#define ADD_BLOCK_BROADCAST(N, BASENAME, BIAS) ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS)
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index 2437265..58c4cdd 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -50,8 +50,9 @@
 {
 using ElementsProcessed = Steps;
 
-Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
-                          const GEMMReshapeInfo &gemm_info)
+Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+                          const GEMMRHSMatrixInfo &rhs_info,
+                          const GEMMReshapeInfo   &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
@@ -72,6 +73,22 @@
     tensor_shape1.set(0, n);
     tensor_shape1.set(1, k);
 
+    if(input2 != nullptr && std::abs(0.0f - beta) > 0.00001f)
+    {
+        const int input2_dim0 = static_cast<int>(input2->dimension(0));
+        const int input2_dim1 = static_cast<int>(input2->dimension(1));
+
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input2, input1);
+        if(gemm_info.broadcast_bias())
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG((input2_dim1 != 1 || input2_dim0 != n), "Incorrect dimension of bias matrix which is to be broadcasted");
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG((input2_dim0 != n || input2_dim1 != m), "Incorrect dimension of bias matrix");
+        }
+    }
+
     const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
 
     const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info));
@@ -97,7 +114,8 @@
     return Status{};
 }
 
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, const GEMMLHSMatrixInfo &lhs_info,
+                                                        const GEMMRHSMatrixInfo &rhs_info,
                                                         const GEMMReshapeInfo &gemm_info, ElementsProcessed &num_elements_processed)
 {
     unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0];
@@ -152,8 +170,24 @@
                                      ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
                                      output->dimension(1) + bottom_pad);
 
-    window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
-                     update_window_and_padding(win_out, output_access);              // window used to update the padding requirements of output tensor
+    if(input2 != nullptr)
+    {
+        const int bias_processed_per_iteration_x = num_elems_processed_per_iteration_x;
+
+        const int bias_processed_per_iteration_y = gemm_info.broadcast_bias() ? 1 : num_elems_processed_per_iteration_y;
+
+        AccessWindowStatic input2_access(input2, 0, 0,
+                                         ceil_to_multiple(input2->dimension(0), bias_processed_per_iteration_x),
+                                         ceil_to_multiple(input2->dimension(1), bias_processed_per_iteration_y));
+
+        window_changed = update_window_and_padding(win, input0_access, input1_access, input2_access) || // window used by the execute_window_loop
+                         update_window_and_padding(win_out, output_access);                             // window used to update the padding requirements of output tensor
+    }
+    else
+    {
+        window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
+                         update_window_and_padding(win_out, output_access);              // window used to update the padding requirements of output tensor
+    }
 
     output_access.set_valid_region(win_out, ValidRegion(Coordinates(), output->tensor_shape()));
 
@@ -169,23 +203,28 @@
 } // namespace
 
 CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::CLGEMMMatrixMultiplyReshapedOnlyRHSKernel()
-    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _use_dummy_work_items(false)
+    : _input0(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _use_dummy_work_items(false),
+      _add_bias(false), _broadcast_bias(false)
 {
 }
 
-void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, const GEMMLHSMatrixInfo &lhs_info,
+void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta,
+                                                          const GEMMLHSMatrixInfo &lhs_info,
                                                           const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
 
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), alpha, lhs_info, rhs_info, gemm_info));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), (input2 != nullptr ? input2->info() : nullptr), output->info(), alpha, beta, lhs_info, rhs_info, gemm_info));
 
     _input0                   = input0;
     _input1                   = input1;
+    _input2                   = std::abs(0.0f - beta) > 0.00001f ? input2 : nullptr;
     _output                   = output;
     _reinterpret_input_as_3d  = gemm_info.reinterpret_input_as_3d();
     _reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d() != 0);
     _use_dummy_work_items     = preferred_dummy_work_items_support(CLKernelLibrary::get().get_device());
+    _add_bias                 = _input2 != nullptr;
+    _broadcast_bias           = gemm_info.broadcast_bias();
 
     // In case both input and output have to be reinterpreted as 3D tensors,
     // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
@@ -202,7 +241,7 @@
     ElementsProcessed num_elements_processed{};
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), lhs_info, rhs_info, gemm_info, num_elements_processed);
+    auto win_config = validate_and_configure_window(input0->info(), input1->info(), input2 != nullptr ? input2->info() : nullptr, output->info(), lhs_info, rhs_info, gemm_info, num_elements_processed);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     ICLKernel::configure_internal(win_config.second);
 
@@ -210,8 +249,11 @@
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type()));
     build_opts.add_option_if(std::abs(1.0f - alpha) > 0.00001f, "-DALPHA=" + float_to_string_with_full_precision(alpha));
+    build_opts.add_option_if(std::abs(0.0f - beta) > 0.00001f && _input2 != nullptr, "-DBETA=" + float_to_string_with_full_precision(beta));
+    build_opts.add_option_if(std::abs(1.0f - beta) < 0.00001f, "-DUNIT_BETA");
     build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
     build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D");
+    build_opts.add_option_if(gemm_info.broadcast_bias(), "-DBROADCAST_BIAS");
     build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
     build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
     build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
@@ -257,13 +299,15 @@
     _config_id += support::cpp11::to_string(rhs_info.interleave);
 }
 
-Status CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info,
+Status CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta,
+                                                           const GEMMLHSMatrixInfo &lhs_info,
                                                            const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info)
 {
     ElementsProcessed num_elements_processed{};
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, lhs_info, rhs_info, gemm_info));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, input2, output, alpha, beta, lhs_info, rhs_info, gemm_info));
     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(),
                                                               input1->clone().get(),
+                                                              input2 != nullptr ? input2->clone().get() : nullptr,
                                                               output->clone().get(),
                                                               lhs_info,
                                                               rhs_info,
@@ -294,7 +338,15 @@
     if(_reinterpret_input_as_3d)
     {
         // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
-        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3;
+        unsigned int idx0;
+        if(_add_bias)
+        {
+            idx0 = 4 * num_arguments_per_2D_tensor() + 4;
+        }
+        else
+        {
+            idx0 = 3 * num_arguments_per_2D_tensor() + 3;
+        }
         const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom;
         _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
     }
@@ -302,7 +354,15 @@
     if(_reinterpret_output_as_3d)
     {
         // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor
-        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
+        unsigned int idx0;
+        if(_add_bias)
+        {
+            idx0 = 4 * num_arguments_per_2D_tensor() + 4 + (_reinterpret_input_as_3d ? 1 : 0);
+        }
+        else
+        {
+            idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
+        }
         const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom;
         _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
     }
@@ -320,12 +380,20 @@
         unsigned int idx = 0;
         add_2D_tensor_argument(idx, _input0, slice);
         add_2D_tensor_argument(idx, _input1, slice_b);
+        if(_add_bias)
+        {
+            add_2D_tensor_argument(idx, _input2, slice);
+        }
         add_2D_tensor_argument(idx, _output, slice);
         _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input0->info()->strides_in_bytes()[2]));
         _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input1->info()->strides_in_bytes()[2]));
+        if(_add_bias)
+        {
+            _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input2->info()->strides_in_bytes()[2]));
+        }
         _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_output->info()->strides_in_bytes()[2]));
         enqueue(queue, *this, slice, lws_hint(), _use_dummy_work_items);
     }
     while(window.slide_window_slice_3D(slice));
 }
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 492709f..21a9fce 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -242,10 +242,6 @@
 
 void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
 {
-    ARM_COMPUTE_ERROR_ON(c != nullptr);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(c);
-
     DataType           data_type               = a->info()->data_type();
     bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
     const unsigned int m                       = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
@@ -254,11 +250,12 @@
     const unsigned int batch_size              = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
     const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
     const GPUTarget    gpu_target              = CLScheduler::get().target();
+    bool               broadcast_bias          = gemm_info.broadcast_bias();
 
     // Set the target for the kernels
     _mm_kernel.set_target(gpu_target);
 
-    GEMMReshapeInfo reshape_info(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
+    GEMMReshapeInfo reshape_info(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, broadcast_bias);
 
     // Manage intermediate buffers
     if(!_reshape_b_only_on_first_run)
@@ -279,7 +276,7 @@
     _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
 
     // Configure and tune matrix multiply kernel
-    _mm_reshaped_only_rhs_kernel.configure(a, &_tmp_b, output, alpha, lhs_info, rhs_info, reshape_info);
+    _mm_reshaped_only_rhs_kernel.configure(a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, reshape_info);
 
     if(!_reshape_b_only_on_first_run)
     {
@@ -426,7 +423,6 @@
         // Validate matrix addition kernel
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
     }
-
     return Status{};
 }
 
@@ -438,17 +434,16 @@
     TensorInfo tmp_b_info{};
 
     // Get the GPU target
-    const GPUTarget    gpu_target              = CLScheduler::get().target();
-    const DataType     data_type               = a->data_type();
-    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
-    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
-    const unsigned int n                       = b->dimension(0);
-    const unsigned int k                       = a->dimension(0);
-    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
-    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
-    const bool         add_c                   = (beta != 0.f && c != nullptr);
-
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
+    const GPUTarget       gpu_target              = CLScheduler::get().target();
+    const DataType        data_type               = a->data_type();
+    bool                  reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int    m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int    n                       = b->dimension(0);
+    const unsigned int    k                       = a->dimension(0);
+    const unsigned int    batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+    const int             depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    const bool            broadcast_bias          = gemm_info.broadcast_bias();
+    const GEMMReshapeInfo reshape_info            = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, broadcast_bias);
 
     GEMMLHSMatrixInfo lhs_info;
     GEMMRHSMatrixInfo rhs_info;
@@ -464,13 +459,7 @@
     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
 
     // Validate matrix multiply
-    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, output, alpha, lhs_info, rhs_info, reshape_info));
-
-    if(add_c)
-    {
-        // Validate matrix addition kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
-    }
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, reshape_info));
 
     return Status{};
 }
@@ -497,10 +486,10 @@
     // Select GEMMType
     _gemm_type = select_gemm_type(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
 
-    const bool is_gemm_v2  = (_gemm_type == GEMMType::RESHAPED_V2) || (_gemm_type == GEMMType::RESHAPED_ONLY_RHS);
-    const bool add_c       = (beta != 0.f && c != nullptr);
-    const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f;
-    const bool fuse_add    = is_beta_one && (c != nullptr && c->info()->num_dimensions() == 1) && !is_gemm_v2;
+    const bool is_gemm_reshaped_only_rhs = _gemm_type == GEMMType::RESHAPED_ONLY_RHS;
+    const bool add_c                     = (beta != 0.f && c != nullptr);
+    const bool is_beta_one               = std::abs(1.0f - beta) < 0.00001f;
+    const bool fuse_add                  = (is_beta_one && (c != nullptr && c->info()->num_dimensions() == 1)) || is_gemm_reshaped_only_rhs;
 
     switch(_gemm_type)
     {