COMPMID-1979: Fuse Activation Function in CLGEMM - part 3

Fused beta*bias in the old CL GEMM kernels
Fused activation function in the old CL GEMM kernels

Change-Id: I695fb9189e6d4792010abd256784624982d17d79
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1587
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 213075d..8d638bc 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -46,15 +46,15 @@
 /** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
  *  the output matrix unrolling the values.
  *
- * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
- * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (i.e. -DSRC_WIDTH=16)
- * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
- * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
+ * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
+ * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
+ * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
+ * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
  * @note Only the following values for M0, K0 and V0 are supported:
  *                                      M0: 2,3,4,5,6,7,8
  *                                      K0: 2,3,4,8,16
  *                                      V0: greater than 0
- * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
+ * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
@@ -246,15 +246,15 @@
 /** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
  *  the output matrix unrolling the values.
  *
- * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
- * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (i.e. -DSRC_WIDTH=16)
- * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
- * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
+ * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
+ * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
+ * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
+ * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
  * @note Only the following values for M0, K0 and V0 are supported:
  *                                      M0: 2,3,4,5,6,7,8
  *                                      K0: 2,3,4,8,16
  *                                      V0: greater than 0
- * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
+ * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
@@ -402,10 +402,10 @@
 /** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
  *  the output matrix unrolling the values.
  *
- * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
- * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
- * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
- * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
+ * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
+ * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
+ * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
+ * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
  * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
  * @note Only the following values for K0, N0 and H0 are supported:
  *                                      N0: 2,3,4,8,16
@@ -555,10 +555,10 @@
 /** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
  *  the output matrix unrolling the values.
  *
- * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
- * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
- * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
- * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
+ * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
+ * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
+ * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
+ * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
  * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
  * @note The option -DTRANSPOSE must passed at compile time.
  * @note Only the following values for K0, N0 and H0 are supported:
@@ -1010,11 +1010,11 @@
  *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
  *
  * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
- * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90)
- * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
- * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
- * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
- * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
+ * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
+ * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
+ * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
+ * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
+ * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
  * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
  * @note Only the following configurations of M0, N0 and K0 are currently supported:
  *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
@@ -1022,7 +1022,7 @@
  *  - K0 = 2, 3, 4, 8, 16
  *  - H0 >= 1
  *
- * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
  *       The activation function is performed after the bias addition
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
@@ -1043,7 +1043,6 @@
  * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
  * @param[in]  rhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
- * @param[in]  bias_ptr                           (Optional)Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
  * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
  * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
  * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1392,10 +1391,10 @@
  *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
  *
  * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
- * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90).
- * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
- * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
- * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
+ * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90).
+ * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
+ * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
+ * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
  * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
  * @note Only the following configurations of M0, N0 and K0 are currently supported:
  *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
@@ -1403,7 +1402,7 @@
  *  - K0 = 2, 3, 4, 8, 16
  *  - H0 >= 1
  *
- * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
  *       The activation function is performed after the bias addition
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
@@ -1798,10 +1797,10 @@
  *  The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
  *
  * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
- * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90).
- * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
- * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
- * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
+ * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
+ * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
+ * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
+ * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
  * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
  * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
  * @note Only the following configurations of M0, N0 and K0 are currently supported:
@@ -1811,9 +1810,9 @@
  *  - V0 >= 1
  *  - H0 >= 1
  *
- * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
  *       The activation function is performed after the bias addition
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -2123,17 +2122,17 @@
  *  The RHS matrix is NOT reshaped
  *
  * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
- * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90)
- * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
- * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
- * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (i.e., -DK0=2)
- * @note The number of N0 columns to process must be passed at compile time using -DN0 (i.e. -DN0=2)
+ * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
+ * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
+ * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
+ * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (e.g. -DK0=2)
+ * @note The number of N0 columns to process must be passed at compile time using -DN0 (e.g. -DN0=2)
  * @note Only the following configurations of M0, N0 and K0 are currently supported:
  *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
  *  - N0 = 2, 3, 4, 8, 16
  *  - K0 = 2, 3, 4, 8, 16
  *
- * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
  *       The activation function is performed after the bias addition
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
@@ -2154,7 +2153,6 @@
  * @param[in]  rhs_stride_y                       Stride of the RHS matrix in Y dimension (in bytes)
  * @param[in]  rhs_step_y                         rhs_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS matrix
- * @param[in]  bias_ptr                           (Optional)Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
  * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
  * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
  * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2405,25 +2403,22 @@
 #endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
 
 #if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
-/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
- *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
- *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
  *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
- * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is performed after the bias addition
+ * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2436,10 +2431,12 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2448,17 +2445,21 @@
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                                 VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                                 IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
+#if defined(BETA)
+                                                 uint src2_stride_z,
+#endif //defined(BETA)
                                                  uint dst_stride_z
 #if defined(REINTERPRET_OUTPUT_AS_3D)
                                                  ,
@@ -2496,10 +2497,10 @@
     src_addr_b += offset_row_b;
 
     // Reset accumulators
-    float4 c00 = 0.0f;
-    float4 c10 = 0.0f;
-    float4 c20 = 0.0f;
-    float4 c30 = 0.0f;
+    float4 c0 = 0.0f;
+    float4 c1 = 0.0f;
+    float4 c2 = 0.0f;
+    float4 c3 = 0.0f;
 
     for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
     {
@@ -2507,19 +2508,19 @@
         float4 a0 = vload4(0, src_addr_a);
         float4 b0 = vload4(0, src_addr_b);
 
-        c00 += (float4)a0.s0 * b0;
-        c10 += (float4)a0.s1 * b0;
-        c20 += (float4)a0.s2 * b0;
-        c30 += (float4)a0.s3 * b0;
+        c0 += (float4)a0.s0 * b0;
+        c1 += (float4)a0.s1 * b0;
+        c2 += (float4)a0.s2 * b0;
+        c3 += (float4)a0.s3 * b0;
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
         b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
 
-        c00 += (float4)a0.s0 * b0;
-        c10 += (float4)a0.s1 * b0;
-        c20 += (float4)a0.s2 * b0;
-        c30 += (float4)a0.s3 * b0;
+        c0 += (float4)a0.s0 * b0;
+        c1 += (float4)a0.s1 * b0;
+        c2 += (float4)a0.s2 * b0;
+        c3 += (float4)a0.s3 * b0;
     }
 
     for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
@@ -2528,36 +2529,20 @@
         float4 a0 = vload4(0, src_addr_a);
         float4 b0 = vload4(0, src_addr_b);
 
-        c00 += (float4)a0.s0 * b0;
-        c10 += (float4)a0.s1 * b0;
-        c20 += (float4)a0.s2 * b0;
-        c30 += (float4)a0.s3 * b0;
+        c0 += (float4)a0.s0 * b0;
+        c1 += (float4)a0.s1 * b0;
+        c2 += (float4)a0.s2 * b0;
+        c3 += (float4)a0.s3 * b0;
     }
 
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
-#if defined(ALPHA)
-    // Multiply by the weight of matrix product
-    c00 = c00 * (float4)ALPHA;
-    c10 = c10 * (float4)ALPHA;
-    c20 = c20 * (float4)ALPHA;
-    c30 = c30 * (float4)ALPHA;
-#endif // defined(ALPHA)
-
-#if defined(ADD_VEC_C)
-    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    float4          c0        = vload4(0, src2_addr);
-
-    c00 += c0;
-    c10 += c0;
-    c20 += c0;
-    c30 += c0;
-#endif /* defined(ADD_VEC_C) */
-
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+    uint4 zout = 0;
+
 #if defined(REINTERPRET_OUTPUT_AS_3D)
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
@@ -2575,8 +2560,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (cross_plane_pad * dst_stride_y);
@@ -2584,45 +2569,76 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
-    // Store 4x4 block
-    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
-    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
-    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
-    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
-
 #else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+
+    // Multiply by the weight of the matrix-matrix product (alpha)
+#if defined(ALPHA)
+    SCALE_BLOCK(4, float, c, ALPHA);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
+
+    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, float, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(4, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
+                                    2) * src2_stride_z;
+
+    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(4, float, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(4, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
 
     // Store 4x4 block
-    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
-    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
-    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
-    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+    vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+    vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+    vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+    vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
 }
 
-/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
- *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication.
- *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+/** This OpenCL kernel is optimized for Bifrost and it computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
  *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
- * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is applied after the bias addition
+ * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2635,10 +2651,12 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2647,17 +2665,21 @@
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                                         VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                                         IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                                          IMAGE_DECLARATION(dst),
                                                          uint src0_stride_z,
                                                          uint src1_stride_z,
+#if defined(BETA)
+                                                         uint src2_stride_z,
+#endif //defined(BETA)
                                                          uint dst_stride_z
 #if defined(REINTERPRET_OUTPUT_AS_3D)
                                                          ,
@@ -2692,22 +2714,10 @@
     src_addr_b += offset_row_b;
 
     // Reset accumulators
-    float c00 = 0.0f;
-    float c01 = 0.0f;
-    float c02 = 0.0f;
-    float c03 = 0.0f;
-    float c10 = 0.0f;
-    float c11 = 0.0f;
-    float c12 = 0.0f;
-    float c13 = 0.0f;
-    float c20 = 0.0f;
-    float c21 = 0.0f;
-    float c22 = 0.0f;
-    float c23 = 0.0f;
-    float c30 = 0.0f;
-    float c31 = 0.0f;
-    float c32 = 0.0f;
-    float c33 = 0.0f;
+    float4 c0 = 0.0f;
+    float4 c1 = 0.0f;
+    float4 c2 = 0.0f;
+    float4 c3 = 0.0f;
 
 #define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
 
@@ -2721,25 +2731,25 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma(a0.s0, b0.s0, c00);
-        c01 = fma(a0.s0, b0.s1, c01);
-        c02 = fma(a0.s0, b0.s2, c02);
-        c03 = fma(a0.s0, b0.s3, c03);
+        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
+        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
+        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
+        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
 
-        c10 = fma(a0.s1, b0.s0, c10);
-        c11 = fma(a0.s1, b0.s1, c11);
-        c12 = fma(a0.s1, b0.s2, c12);
-        c13 = fma(a0.s1, b0.s3, c13);
+        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
+        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
+        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
+        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
 
-        c20 = fma(a0.s2, b0.s0, c20);
-        c21 = fma(a0.s2, b0.s1, c21);
-        c22 = fma(a0.s2, b0.s2, c22);
-        c23 = fma(a0.s2, b0.s3, c23);
+        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
+        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
+        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
+        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
 
-        c30 = fma(a0.s3, b0.s0, c30);
-        c31 = fma(a0.s3, b0.s1, c31);
-        c32 = fma(a0.s3, b0.s2, c32);
-        c33 = fma(a0.s3, b0.s3, c33);
+        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
+        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
+        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
+        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload4(0, src_addr_a);
@@ -2748,25 +2758,25 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma(a0.s0, b0.s0, c00);
-        c01 = fma(a0.s0, b0.s1, c01);
-        c02 = fma(a0.s0, b0.s2, c02);
-        c03 = fma(a0.s0, b0.s3, c03);
+        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
+        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
+        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
+        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
 
-        c10 = fma(a0.s1, b0.s0, c10);
-        c11 = fma(a0.s1, b0.s1, c11);
-        c12 = fma(a0.s1, b0.s2, c12);
-        c13 = fma(a0.s1, b0.s3, c13);
+        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
+        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
+        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
+        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
 
-        c20 = fma(a0.s2, b0.s0, c20);
-        c21 = fma(a0.s2, b0.s1, c21);
-        c22 = fma(a0.s2, b0.s2, c22);
-        c23 = fma(a0.s2, b0.s3, c23);
+        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
+        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
+        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
+        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
 
-        c30 = fma(a0.s3, b0.s0, c30);
-        c31 = fma(a0.s3, b0.s1, c31);
-        c32 = fma(a0.s3, b0.s2, c32);
-        c33 = fma(a0.s3, b0.s3, c33);
+        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
+        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
+        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
+        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload4(0, src_addr_a);
@@ -2775,25 +2785,25 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma(a0.s0, b0.s0, c00);
-        c01 = fma(a0.s0, b0.s1, c01);
-        c02 = fma(a0.s0, b0.s2, c02);
-        c03 = fma(a0.s0, b0.s3, c03);
+        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
+        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
+        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
+        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
 
-        c10 = fma(a0.s1, b0.s0, c10);
-        c11 = fma(a0.s1, b0.s1, c11);
-        c12 = fma(a0.s1, b0.s2, c12);
-        c13 = fma(a0.s1, b0.s3, c13);
+        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
+        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
+        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
+        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
 
-        c20 = fma(a0.s2, b0.s0, c20);
-        c21 = fma(a0.s2, b0.s1, c21);
-        c22 = fma(a0.s2, b0.s2, c22);
-        c23 = fma(a0.s2, b0.s3, c23);
+        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
+        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
+        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
+        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
 
-        c30 = fma(a0.s3, b0.s0, c30);
-        c31 = fma(a0.s3, b0.s1, c31);
-        c32 = fma(a0.s3, b0.s2, c32);
-        c33 = fma(a0.s3, b0.s3, c33);
+        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
+        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
+        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
+        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload4(0, src_addr_a);
@@ -2802,25 +2812,25 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma(a0.s0, b0.s0, c00);
-        c01 = fma(a0.s0, b0.s1, c01);
-        c02 = fma(a0.s0, b0.s2, c02);
-        c03 = fma(a0.s0, b0.s3, c03);
+        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
+        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
+        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
+        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
 
-        c10 = fma(a0.s1, b0.s0, c10);
-        c11 = fma(a0.s1, b0.s1, c11);
-        c12 = fma(a0.s1, b0.s2, c12);
-        c13 = fma(a0.s1, b0.s3, c13);
+        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
+        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
+        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
+        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
 
-        c20 = fma(a0.s2, b0.s0, c20);
-        c21 = fma(a0.s2, b0.s1, c21);
-        c22 = fma(a0.s2, b0.s2, c22);
-        c23 = fma(a0.s2, b0.s3, c23);
+        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
+        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
+        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
+        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
 
-        c30 = fma(a0.s3, b0.s0, c30);
-        c31 = fma(a0.s3, b0.s1, c31);
-        c32 = fma(a0.s3, b0.s2, c32);
-        c33 = fma(a0.s3, b0.s3, c33);
+        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
+        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
+        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
+        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
     }
 
     for(; i < (int)(COLS_MTX_B); ++i)
@@ -2832,74 +2842,34 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma(a0.s0, b0.s0, c00);
-        c01 = fma(a0.s0, b0.s1, c01);
-        c02 = fma(a0.s0, b0.s2, c02);
-        c03 = fma(a0.s0, b0.s3, c03);
+        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
+        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
+        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
+        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
 
-        c10 = fma(a0.s1, b0.s0, c10);
-        c11 = fma(a0.s1, b0.s1, c11);
-        c12 = fma(a0.s1, b0.s2, c12);
-        c13 = fma(a0.s1, b0.s3, c13);
+        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
+        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
+        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
+        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
 
-        c20 = fma(a0.s2, b0.s0, c20);
-        c21 = fma(a0.s2, b0.s1, c21);
-        c22 = fma(a0.s2, b0.s2, c22);
-        c23 = fma(a0.s2, b0.s3, c23);
+        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
+        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
+        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
+        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
 
-        c30 = fma(a0.s3, b0.s0, c30);
-        c31 = fma(a0.s3, b0.s1, c31);
-        c32 = fma(a0.s3, b0.s2, c32);
-        c33 = fma(a0.s3, b0.s3, c33);
+        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
+        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
+        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
+        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
     }
 
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
-#if defined(ALPHA)
-    // Multiply by the weight of matrix product
-    c00 = c00 * ALPHA;
-    c01 = c01 * ALPHA;
-    c02 = c02 * ALPHA;
-    c03 = c03 * ALPHA;
-    c10 = c10 * ALPHA;
-    c11 = c11 * ALPHA;
-    c12 = c12 * ALPHA;
-    c13 = c13 * ALPHA;
-    c20 = c20 * ALPHA;
-    c21 = c21 * ALPHA;
-    c22 = c22 * ALPHA;
-    c23 = c23 * ALPHA;
-    c30 = c30 * ALPHA;
-    c31 = c31 * ALPHA;
-    c32 = c32 * ALPHA;
-    c33 = c33 * ALPHA;
-#endif // defined(ALPHA)
-
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
-#if defined(ADD_VEC_C)
-    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    float4          c0        = vload4(0, src2_addr);
-
-    c00 += c0.s0;
-    c01 += c0.s1;
-    c02 += c0.s2;
-    c03 += c0.s3;
-    c10 += c0.s0;
-    c11 += c0.s1;
-    c12 += c0.s2;
-    c13 += c0.s3;
-    c20 += c0.s0;
-    c21 += c0.s1;
-    c22 += c0.s2;
-    c23 += c0.s3;
-    c30 += c0.s0;
-    c31 += c0.s1;
-    c32 += c0.s2;
-    c33 += c0.s3;
-#endif /* defined(ADD_VEC_C) */
+    uint4 zout = 0;
 
 #if defined(REINTERPRET_OUTPUT_AS_3D)
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
@@ -2918,8 +2888,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (cross_plane_pad * dst_stride_y);
@@ -2927,48 +2897,79 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
-    // Store 4x4 block
-    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
-    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
-    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
-    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
-
 #else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+
+    // Multiply by the weight of the matrix-matrix product (alpha)
+#if defined(ALPHA)
+    SCALE_BLOCK(4, float, c, ALPHA);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
+
+    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, float, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(4, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
+                                    2) * src2_stride_z;
+
+    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(4, float, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(4, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
 
     // Store 4x4 block
-    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
-    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
-    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
-    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+    vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+    vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+    vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+    vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
 }
 
 // Undefine local defines
 #undef COLS_MTX_B
 
 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
-/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
- *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
- *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
  *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
- * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is applied after the bias addition
+ * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2981,10 +2982,12 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -2993,17 +2996,21 @@
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                                 VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                                 IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
+#if defined(BETA)
+                                                 uint src2_stride_z,
+#endif //defined(BETA)
                                                  uint dst_stride_z
 #if defined(REINTERPRET_OUTPUT_AS_3D)
                                                  ,
@@ -3041,10 +3048,10 @@
     src_addr_b += offset_row_b;
 
     // Reset accumulators
-    half8 c00 = 0.0f;
-    half8 c10 = 0.0f;
-    half8 c20 = 0.0f;
-    half8 c30 = 0.0f;
+    half8 c0 = 0.0f;
+    half8 c1 = 0.0f;
+    half8 c2 = 0.0f;
+    half8 c3 = 0.0f;
 
     for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
     {
@@ -3052,19 +3059,19 @@
         half4 a0 = vload4(0, src_addr_a);
         half8 b0 = vload8(0, src_addr_b);
 
-        c00 += (half8)a0.s0 * b0;
-        c10 += (half8)a0.s1 * b0;
-        c20 += (half8)a0.s2 * b0;
-        c30 += (half8)a0.s3 * b0;
+        c0 += (half8)a0.s0 * b0;
+        c1 += (half8)a0.s1 * b0;
+        c2 += (half8)a0.s2 * b0;
+        c3 += (half8)a0.s3 * b0;
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
         b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
 
-        c00 += (half8)a0.s0 * b0;
-        c10 += (half8)a0.s1 * b0;
-        c20 += (half8)a0.s2 * b0;
-        c30 += (half8)a0.s3 * b0;
+        c0 += (half8)a0.s0 * b0;
+        c1 += (half8)a0.s1 * b0;
+        c2 += (half8)a0.s2 * b0;
+        c3 += (half8)a0.s3 * b0;
     }
 
     for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
@@ -3073,40 +3080,20 @@
         half4 a0 = vload4(0, src_addr_a);
         half8 b0 = vload8(0, src_addr_b);
 
-        c00 += (half8)a0.s0 * b0;
-        c10 += (half8)a0.s1 * b0;
-        c20 += (half8)a0.s2 * b0;
-        c30 += (half8)a0.s3 * b0;
+        c0 += (half8)a0.s0 * b0;
+        c1 += (half8)a0.s1 * b0;
+        c2 += (half8)a0.s2 * b0;
+        c3 += (half8)a0.s3 * b0;
     }
 
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
-#if defined(ALPHA)
-    // Multiply by the weight of matrix product
-    c00 = c00 * (half8)ALPHA;
-    c10 = c10 * (half8)ALPHA;
-    c20 = c20 * (half8)ALPHA;
-    c30 = c30 * (half8)ALPHA;
-#endif // defined(ALPHA)
-
-#if defined(ADD_VEC_C)
-    // *INDENT-OFF*
-    // clang-format off
-    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    half8          c0        = vload8(0, src2_addr);
-    // clang-format on
-    // *INDENT-ON*
-
-    c00 += c0;
-    c10 += c0;
-    c20 += c0;
-    c30 += c0;
-#endif /* defined(ADD_VEC_C) */
-
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+    uint4 zout = 0;
+
 #if defined(REINTERPRET_OUTPUT_AS_3D)
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
@@ -3124,8 +3111,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (cross_plane_pad * dst_stride_y);
@@ -3133,44 +3120,76 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
-    // Store 4x8 block
-    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
-    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
-    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
-    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
-
 #else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+
+    // Multiply by the weight of matrix-matrix product (alpha)
+#if defined(ALPHA)
+    SCALE_BLOCK(4, half, c, ALPHA);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
+
+    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, half, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(4, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
+                                    2) * src2_stride_z;
+
+    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(4, half, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(4, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
 
     // Store 4x8 block
-    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
-    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
-    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
-    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+    vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+    vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+    vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+    vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
 }
 
-/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32 floating point variable.
- *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
- *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32-bit floating point variable.
  *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
- * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is applied after the alpha scaling and the beta*bias addition.
+ * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -3183,10 +3202,12 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -3195,17 +3216,21 @@
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                        IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                                       VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                                       IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                                        IMAGE_DECLARATION(dst),
                                                        uint src0_stride_z,
                                                        uint src1_stride_z,
+#if defined(BETA)
+                                                       uint src2_stride_z,
+#endif //defined(BETA)
                                                        uint dst_stride_z
 #if defined(REINTERPRET_OUTPUT_AS_3D)
                                                        ,
@@ -3243,10 +3268,10 @@
     src_addr_b += offset_row_b;
 
     // Reset accumulators
-    float8 c00 = 0.0f;
-    float8 c10 = 0.0f;
-    float8 c20 = 0.0f;
-    float8 c30 = 0.0f;
+    float8 c0 = 0.0f;
+    float8 c1 = 0.0f;
+    float8 c2 = 0.0f;
+    float8 c3 = 0.0f;
 
     for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
     {
@@ -3254,19 +3279,19 @@
         float4 a0 = convert_float4(vload4(0, src_addr_a));
         float8 b0 = convert_float8(vload8(0, src_addr_b));
 
-        c00 += (float8)a0.s0 * b0;
-        c10 += (float8)a0.s1 * b0;
-        c20 += (float8)a0.s2 * b0;
-        c30 += (float8)a0.s3 * b0;
+        c0 += (float8)a0.s0 * b0;
+        c1 += (float8)a0.s1 * b0;
+        c2 += (float8)a0.s2 * b0;
+        c3 += (float8)a0.s3 * b0;
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
         b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));
 
-        c00 += (float8)a0.s0 * b0;
-        c10 += (float8)a0.s1 * b0;
-        c20 += (float8)a0.s2 * b0;
-        c30 += (float8)a0.s3 * b0;
+        c0 += (float8)a0.s0 * b0;
+        c1 += (float8)a0.s1 * b0;
+        c2 += (float8)a0.s2 * b0;
+        c3 += (float8)a0.s3 * b0;
     }
 
     for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
@@ -3275,40 +3300,20 @@
         float4 a0 = convert_float4(vload4(0, src_addr_a));
         float8 b0 = convert_float8(vload8(0, src_addr_b));
 
-        c00 += (float8)a0.s0 * b0;
-        c10 += (float8)a0.s1 * b0;
-        c20 += (float8)a0.s2 * b0;
-        c30 += (float8)a0.s3 * b0;
+        c0 += (float8)a0.s0 * b0;
+        c1 += (float8)a0.s1 * b0;
+        c2 += (float8)a0.s2 * b0;
+        c3 += (float8)a0.s3 * b0;
     }
 
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
-#if defined(ALPHA)
-    // Multiply by the weight of matrix product
-    c00 = c00 * (float8)ALPHA;
-    c10 = c10 * (float8)ALPHA;
-    c20 = c20 * (float8)ALPHA;
-    c30 = c30 * (float8)ALPHA;
-#endif // defined(ALPHA)
-
-#if defined(ADD_VEC_C)
-    // *INDENT-OFF*
-    // clang-format off
-    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    float8         c0        = convert_float8(vload8(0, src2_addr));
-    // clang-format on
-    // *INDENT-ON*
-
-    c00 += c0;
-    c10 += c0;
-    c20 += c0;
-    c30 += c0;
-#endif /* defined(ADD_VEC_C) */
-
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+    uint4 zout = 0;
+
 #if defined(REINTERPRET_OUTPUT_AS_3D)
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
@@ -3326,8 +3331,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (cross_plane_pad * dst_stride_y);
@@ -3335,44 +3340,86 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
-    // Store 4x8 block
-    vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
-    vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
-    vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
-    vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
-
 #else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+
+    // Multiply by the weight of matrix-matrix product (alpha)
+#if defined(ALPHA)
+    SCALE_BLOCK(4, float, c, ALPHA);
+#endif // defined(ALPHA)
+
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
+
+    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+    float8 bias_f0 = convert_float8(bias0);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, float, bias_f, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(4, c, bias_f0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
+                                    2) * src2_stride_z;
+
+    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+    float8 bias_f0 = convert_float8(bias0);
+    float8 bias_f1 = convert_float8(bias1);
+    float8 bias_f2 = convert_float8(bias2);
+    float8 bias_f3 = convert_float8(bias3);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(4, float, bias_f, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(4, c, bias_f);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+    half8 c_h0 = convert_half8(c0);
+    half8 c_h1 = convert_half8(c1);
+    half8 c_h2 = convert_half8(c2);
+    half8 c_h3 = convert_half8(c3);
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c_h, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
 
     // Store 4x8 block
-    vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y));
-    vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y));
-    vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y));
-    vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y));
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+    vstore8(c_h0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+    vstore8(c_h1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+    vstore8(c_h2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+    vstore8(c_h3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
 }
 
-/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
- *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
- *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
+/** This OpenCL kernel, optimized for Bifrost architectures, computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
  *
  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
- * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is applied after the alpha scaling and the beta*bias addition.
+ * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -3385,26 +3432,34 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                                         VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                                         IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                                          IMAGE_DECLARATION(dst),
                                                          uint src0_stride_z,
                                                          uint src1_stride_z,
+#if defined(BETA)
+                                                         uint src2_stride_z,
+#endif //defined(BETA)
                                                          uint dst_stride_z
 #if defined(REINTERPRET_OUTPUT_AS_3D)
                                                          ,
@@ -3442,10 +3497,10 @@
     src_addr_b += offset_row_b;
 
     // Reset accumulators
-    half8 c00 = 0.0f;
-    half8 c10 = 0.0f;
-    half8 c20 = 0.0f;
-    half8 c30 = 0.0f;
+    half8 c0 = 0.0f;
+    half8 c1 = 0.0f;
+    half8 c2 = 0.0f;
+    half8 c3 = 0.0f;
 
 #define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
 
@@ -3460,20 +3515,20 @@
         src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma((half8)a0.s0, b0, c00);
-        c10 = fma((half8)a0.s1, b0, c10);
-        c20 = fma((half8)a0.s2, b0, c20);
-        c30 = fma((half8)a0.s3, b0, c30);
+        c0 = fma((half8)a0.s0, b0, c0);
+        c1 = fma((half8)a0.s1, b0, c1);
+        c2 = fma((half8)a0.s2, b0, c2);
+        c3 = fma((half8)a0.s3, b0, c3);
 
         // Load values from matrix B (transposed)
         b0 = vload8(0, src_addr_b);
 
         src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma((half8)a0.s4, b0, c00);
-        c10 = fma((half8)a0.s5, b0, c10);
-        c20 = fma((half8)a0.s6, b0, c20);
-        c30 = fma((half8)a0.s7, b0, c30);
+        c0 = fma((half8)a0.s4, b0, c0);
+        c1 = fma((half8)a0.s5, b0, c1);
+        c2 = fma((half8)a0.s6, b0, c2);
+        c3 = fma((half8)a0.s7, b0, c3);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload8(0, src_addr_a);
@@ -3482,20 +3537,20 @@
         src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma((half8)a0.s0, b0, c00);
-        c10 = fma((half8)a0.s1, b0, c10);
-        c20 = fma((half8)a0.s2, b0, c20);
-        c30 = fma((half8)a0.s3, b0, c30);
+        c0 = fma((half8)a0.s0, b0, c0);
+        c1 = fma((half8)a0.s1, b0, c1);
+        c2 = fma((half8)a0.s2, b0, c2);
+        c3 = fma((half8)a0.s3, b0, c3);
 
         // Load values from matrix B (transposed)
         b0 = vload8(0, src_addr_b);
 
         src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma((half8)a0.s4, b0, c00);
-        c10 = fma((half8)a0.s5, b0, c10);
-        c20 = fma((half8)a0.s6, b0, c20);
-        c30 = fma((half8)a0.s7, b0, c30);
+        c0 = fma((half8)a0.s4, b0, c0);
+        c1 = fma((half8)a0.s5, b0, c1);
+        c2 = fma((half8)a0.s6, b0, c2);
+        c3 = fma((half8)a0.s7, b0, c3);
 #else  // MULT_INTERLEAVE4X4_HEIGHT == 1
         // Load values from matrix A (interleaved) and matrix B (transposed)
         half4 a0 = vload4(0, src_addr_a);
@@ -3504,10 +3559,10 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma((half8)a0.s0, b0, c00);
-        c10 = fma((half8)a0.s1, b0, c10);
-        c20 = fma((half8)a0.s2, b0, c20);
-        c30 = fma((half8)a0.s3, b0, c30);
+        c0 = fma((half8)a0.s0, b0, c0);
+        c1 = fma((half8)a0.s1, b0, c1);
+        c2 = fma((half8)a0.s2, b0, c2);
+        c3 = fma((half8)a0.s3, b0, c3);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload4(0, src_addr_a);
@@ -3516,10 +3571,10 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma((half8)a0.s0, b0, c00);
-        c10 = fma((half8)a0.s1, b0, c10);
-        c20 = fma((half8)a0.s2, b0, c20);
-        c30 = fma((half8)a0.s3, b0, c30);
+        c0 = fma((half8)a0.s0, b0, c0);
+        c1 = fma((half8)a0.s1, b0, c1);
+        c2 = fma((half8)a0.s2, b0, c2);
+        c3 = fma((half8)a0.s3, b0, c3);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload4(0, src_addr_a);
@@ -3528,10 +3583,10 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma((half8)a0.s0, b0, c00);
-        c10 = fma((half8)a0.s1, b0, c10);
-        c20 = fma((half8)a0.s2, b0, c20);
-        c30 = fma((half8)a0.s3, b0, c30);
+        c0 = fma((half8)a0.s0, b0, c0);
+        c1 = fma((half8)a0.s1, b0, c1);
+        c2 = fma((half8)a0.s2, b0, c2);
+        c3 = fma((half8)a0.s3, b0, c3);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload4(0, src_addr_a);
@@ -3540,10 +3595,10 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma((half8)a0.s0, b0, c00);
-        c10 = fma((half8)a0.s1, b0, c10);
-        c20 = fma((half8)a0.s2, b0, c20);
-        c30 = fma((half8)a0.s3, b0, c30);
+        c0 = fma((half8)a0.s0, b0, c0);
+        c1 = fma((half8)a0.s1, b0, c1);
+        c2 = fma((half8)a0.s2, b0, c2);
+        c3 = fma((half8)a0.s3, b0, c3);
 #endif // MULT_INTERLEAVE4X4_HEIGHT == 1
     }
 
@@ -3556,40 +3611,20 @@
         src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
         src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
 
-        c00 = fma((half8)a0.s0, b0, c00);
-        c10 = fma((half8)a0.s1, b0, c10);
-        c20 = fma((half8)a0.s2, b0, c20);
-        c30 = fma((half8)a0.s3, b0, c30);
+        c0 = fma((half8)a0.s0, b0, c0);
+        c1 = fma((half8)a0.s1, b0, c1);
+        c2 = fma((half8)a0.s2, b0, c2);
+        c3 = fma((half8)a0.s3, b0, c3);
     }
 
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
-#if defined(ALPHA)
-    // Multiply by the weight of matrix product
-    c00 = c00 * (half8)ALPHA;
-    c10 = c10 * (half8)ALPHA;
-    c20 = c20 * (half8)ALPHA;
-    c30 = c30 * (half8)ALPHA;
-#endif // defined(ALPHA)
-
-#if defined(ADD_VEC_C)
-    // *INDENT-OFF*
-    // clang-format off
-    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    half8          c0        = vload8(0, src2_addr);
-    // clang-format on
-    // *INDENT-ON*
-
-    c00 += c0;
-    c10 += c0;
-    c20 += c0;
-    c30 += c0;
-#endif /* defined(ADD_VEC_C) */
-
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+    uint4 zout = 0;
+
 #if defined(REINTERPRET_OUTPUT_AS_3D)
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
@@ -3607,8 +3642,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (cross_plane_pad * dst_stride_y);
@@ -3616,23 +3651,57 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
-    // Store 4x8 block
-    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
-    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
-    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
-    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
-
 #else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+
+    // Multiply by the weight of the matrix-matrix product (alpha)
+#if defined(ALPHA)
+    SCALE_BLOCK(4, half, c, ALPHA);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
+
+    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, half, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(4, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
+                                    2) * src2_stride_z;
+
+    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(4, half, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(4, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
 
     // Store 4x8 block
-    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
-    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
-    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
-    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+    vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+    vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+    vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+    vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
 }
 
 // Undefine local defines
@@ -3647,15 +3716,15 @@
 #define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
  *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
- *
  * @note This OpenCL kernel works with floating point data types (F16/F32)
  * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
  * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is applied after the bias addition.
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -3663,8 +3732,6 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -3677,10 +3744,12 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -3689,18 +3758,22 @@
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
  * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                      IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                     VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                     IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                      IMAGE_DECLARATION(dst),
                                      uint src0_stride_z,
                                      uint src1_stride_z,
+#if defined(BETA)
+                                     uint src2_stride_z,
+#endif //defined(BETA)
                                      uint dst_stride_z
 #if defined(REINTERPRET_INPUT_AS_3D)
                                      ,
@@ -3865,49 +3938,18 @@
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
     }
 
+    int z = get_global_id(2);
+
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
-    // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
-    acc0 = acc0 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
-    acc1 = acc1 * (VECTOR_TYPE)ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
-    acc2 = acc2 * (VECTOR_TYPE)ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
-    acc3 = acc3 * (VECTOR_TYPE)ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
-
-#if defined(ADD_VEC_C)
-    // *INDENT-OFF*
-    // clang-format off
-    __global DATA_TYPE *src2_addr = (__global DATA_TYPE *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    VECTOR_TYPE         c0        = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src2_addr);
-    // clang-format on
-    // *INDENT-ON*
-
-    acc0 += c0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    acc1 += c0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    acc2 += c0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    acc3 += c0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif /* defined(ADD_VEC_C) */
-
-    int z = get_global_id(2);
+    uint4 zout = 0;
 
 #if defined(REINTERPRET_OUTPUT_AS_3D)
+
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
     //
@@ -3924,8 +3966,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (dst_cross_plane_pad * dst_stride_y);
@@ -3933,44 +3975,69 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
+    // Add offset for batched GEMM
+    dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+
+    // Multiply by the weight of the matrix-matrix product (alpha)
+#if defined(ALPHA)
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, acc, ALPHA);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE));
+
+    LOAD_BLOCK(1, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE)) + (get_global_id(1) *
+                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
+
+    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, DATA_TYPE, acc, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
 
     // Store output block
     STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
-    // Add offset for batched GEMM
-    dst_addr += z * dst_stride_z;
-
-    // Store output block
-    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
-    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
-    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
-    (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
-    (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 #endif // defined(DATA_TYPE)
 
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
  *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
- *
  * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
  * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
  * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
  * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is applied after the bias addition.
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -3978,9 +4045,7 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
- * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
+ * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
@@ -3992,10 +4057,12 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -4004,18 +4071,22 @@
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
  * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                                 VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                                 IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
+#if defined(BETA)
+                                                 uint src2_stride_z,
+#endif //defined(BETA)
                                                  uint dst_stride_z
 #if defined(REINTERPRET_INPUT_AS_3D)
                                                  ,
@@ -4080,30 +4151,18 @@
 #endif // defined(MATRIX_B_DEPTH)
 
     // Initialize accumulators
-    float acc00 = 0.0f;
-    float acc01 = 0.0f;
-    float acc02 = 0.0f;
-    float acc03 = 0.0f;
+    float4 acc0 = 0.0f;
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    float acc10 = 0.0f;
-    float acc11 = 0.0f;
-    float acc12 = 0.0f;
-    float acc13 = 0.0f;
+    float4 acc1 = 0.0f;
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    float acc20 = 0.0f;
-    float acc21 = 0.0f;
-    float acc22 = 0.0f;
-    float acc23 = 0.0f;
+    float4 acc2 = 0.0f;
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    float acc30 = 0.0f;
-    float acc31 = 0.0f;
-    float acc32 = 0.0f;
-    float acc33 = 0.0f;
+    float4 acc3 = 0.0f;
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
     // A and B src indices get incremented at the same time.
@@ -4131,33 +4190,33 @@
         src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
-        acc00 = fma(a0.s0, b0.s0, acc00);
-        acc01 = fma(a0.s0, b0.s1, acc01);
-        acc02 = fma(a0.s0, b0.s2, acc02);
-        acc03 = fma(a0.s0, b0.s3, acc03);
+        acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
+        acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
+        acc0.s2 = fma(a0.s0, b0.s2, acc0.s2);
+        acc0.s3 = fma(a0.s0, b0.s3, acc0.s3);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 
-        acc10 = fma(a1.s0, b0.s0, acc10);
-        acc11 = fma(a1.s0, b0.s1, acc11);
-        acc12 = fma(a1.s0, b0.s2, acc12);
-        acc13 = fma(a1.s0, b0.s3, acc13);
+        acc1.s0 = fma(a1.s0, b0.s0, acc1.s0);
+        acc1.s1 = fma(a1.s0, b0.s1, acc1.s1);
+        acc1.s2 = fma(a1.s0, b0.s2, acc1.s2);
+        acc1.s3 = fma(a1.s0, b0.s3, acc1.s3);
 
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 
-        acc20 = fma(a2.s0, b0.s0, acc20);
-        acc21 = fma(a2.s0, b0.s1, acc21);
-        acc22 = fma(a2.s0, b0.s2, acc22);
-        acc23 = fma(a2.s0, b0.s3, acc23);
+        acc2.s0 = fma(a2.s0, b0.s0, acc2.s0);
+        acc2.s1 = fma(a2.s0, b0.s1, acc2.s1);
+        acc2.s2 = fma(a2.s0, b0.s2, acc2.s2);
+        acc2.s3 = fma(a2.s0, b0.s3, acc2.s3);
 
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
-        acc30 = fma(a3.s0, b0.s0, acc30);
-        acc31 = fma(a3.s0, b0.s1, acc31);
-        acc32 = fma(a3.s0, b0.s2, acc32);
-        acc33 = fma(a3.s0, b0.s3, acc33);
+        acc3.s0 = fma(a3.s0, b0.s0, acc3.s0);
+        acc3.s1 = fma(a3.s0, b0.s1, acc3.s1);
+        acc3.s2 = fma(a3.s0, b0.s2, acc3.s2);
+        acc3.s3 = fma(a3.s0, b0.s3, acc3.s3);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
         // Load values from matrix A and matrix B
@@ -4165,33 +4224,33 @@
         src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
-        acc00 = fma(a0.s1, b0.s0, acc00);
-        acc01 = fma(a0.s1, b0.s1, acc01);
-        acc02 = fma(a0.s1, b0.s2, acc02);
-        acc03 = fma(a0.s1, b0.s3, acc03);
+        acc0.s0 = fma(a0.s1, b0.s0, acc0.s0);
+        acc0.s1 = fma(a0.s1, b0.s1, acc0.s1);
+        acc0.s2 = fma(a0.s1, b0.s2, acc0.s2);
+        acc0.s3 = fma(a0.s1, b0.s3, acc0.s3);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 
-        acc10 = fma(a1.s1, b0.s0, acc10);
-        acc11 = fma(a1.s1, b0.s1, acc11);
-        acc12 = fma(a1.s1, b0.s2, acc12);
-        acc13 = fma(a1.s1, b0.s3, acc13);
+        acc1.s0 = fma(a1.s1, b0.s0, acc1.s0);
+        acc1.s1 = fma(a1.s1, b0.s1, acc1.s1);
+        acc1.s2 = fma(a1.s1, b0.s2, acc1.s2);
+        acc1.s3 = fma(a1.s1, b0.s3, acc1.s3);
 
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 
-        acc20 = fma(a2.s1, b0.s0, acc20);
-        acc21 = fma(a2.s1, b0.s1, acc21);
-        acc22 = fma(a2.s1, b0.s2, acc22);
-        acc23 = fma(a2.s1, b0.s3, acc23);
+        acc2.s0 = fma(a2.s1, b0.s0, acc2.s0);
+        acc2.s1 = fma(a2.s1, b0.s1, acc2.s1);
+        acc2.s2 = fma(a2.s1, b0.s2, acc2.s2);
+        acc2.s3 = fma(a2.s1, b0.s3, acc2.s3);
 
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
-        acc30 = fma(a3.s1, b0.s0, acc30);
-        acc31 = fma(a3.s1, b0.s1, acc31);
-        acc32 = fma(a3.s1, b0.s2, acc32);
-        acc33 = fma(a3.s1, b0.s3, acc33);
+        acc3.s0 = fma(a3.s1, b0.s0, acc3.s0);
+        acc3.s1 = fma(a3.s1, b0.s1, acc3.s1);
+        acc3.s2 = fma(a3.s1, b0.s2, acc3.s2);
+        acc3.s3 = fma(a3.s1, b0.s3, acc3.s3);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
         // Load values from matrix A and matrix B
@@ -4199,33 +4258,33 @@
         src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
-        acc00 = fma(a0.s2, b0.s0, acc00);
-        acc01 = fma(a0.s2, b0.s1, acc01);
-        acc02 = fma(a0.s2, b0.s2, acc02);
-        acc03 = fma(a0.s2, b0.s3, acc03);
+        acc0.s0 = fma(a0.s2, b0.s0, acc0.s0);
+        acc0.s1 = fma(a0.s2, b0.s1, acc0.s1);
+        acc0.s2 = fma(a0.s2, b0.s2, acc0.s2);
+        acc0.s3 = fma(a0.s2, b0.s3, acc0.s3);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 
-        acc10 = fma(a1.s2, b0.s0, acc10);
-        acc11 = fma(a1.s2, b0.s1, acc11);
-        acc12 = fma(a1.s2, b0.s2, acc12);
-        acc13 = fma(a1.s2, b0.s3, acc13);
+        acc1.s0 = fma(a1.s2, b0.s0, acc1.s0);
+        acc1.s1 = fma(a1.s2, b0.s1, acc1.s1);
+        acc1.s2 = fma(a1.s2, b0.s2, acc1.s2);
+        acc1.s3 = fma(a1.s2, b0.s3, acc1.s3);
 
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 
-        acc20 = fma(a2.s2, b0.s0, acc20);
-        acc21 = fma(a2.s2, b0.s1, acc21);
-        acc22 = fma(a2.s2, b0.s2, acc22);
-        acc23 = fma(a2.s2, b0.s3, acc23);
+        acc2.s0 = fma(a2.s2, b0.s0, acc2.s0);
+        acc2.s1 = fma(a2.s2, b0.s1, acc2.s1);
+        acc2.s2 = fma(a2.s2, b0.s2, acc2.s2);
+        acc2.s3 = fma(a2.s2, b0.s3, acc2.s3);
 
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
-        acc30 = fma(a3.s2, b0.s0, acc30);
-        acc31 = fma(a3.s2, b0.s1, acc31);
-        acc32 = fma(a3.s2, b0.s2, acc32);
-        acc33 = fma(a3.s2, b0.s3, acc33);
+        acc3.s0 = fma(a3.s2, b0.s0, acc3.s0);
+        acc3.s1 = fma(a3.s2, b0.s1, acc3.s1);
+        acc3.s2 = fma(a3.s2, b0.s2, acc3.s2);
+        acc3.s3 = fma(a3.s2, b0.s3, acc3.s3);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
         // Load values from matrix A and matrix B
@@ -4233,33 +4292,33 @@
         src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
-        acc00 = fma(a0.s3, b0.s0, acc00);
-        acc01 = fma(a0.s3, b0.s1, acc01);
-        acc02 = fma(a0.s3, b0.s2, acc02);
-        acc03 = fma(a0.s3, b0.s3, acc03);
+        acc0.s0 = fma(a0.s3, b0.s0, acc0.s0);
+        acc0.s1 = fma(a0.s3, b0.s1, acc0.s1);
+        acc0.s2 = fma(a0.s3, b0.s2, acc0.s2);
+        acc0.s3 = fma(a0.s3, b0.s3, acc0.s3);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 
-        acc10 = fma(a1.s3, b0.s0, acc10);
-        acc11 = fma(a1.s3, b0.s1, acc11);
-        acc12 = fma(a1.s3, b0.s2, acc12);
-        acc13 = fma(a1.s3, b0.s3, acc13);
+        acc1.s0 = fma(a1.s3, b0.s0, acc1.s0);
+        acc1.s1 = fma(a1.s3, b0.s1, acc1.s1);
+        acc1.s2 = fma(a1.s3, b0.s2, acc1.s2);
+        acc1.s3 = fma(a1.s3, b0.s3, acc1.s3);
 
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 
-        acc20 = fma(a2.s3, b0.s0, acc20);
-        acc21 = fma(a2.s3, b0.s1, acc21);
-        acc22 = fma(a2.s3, b0.s2, acc22);
-        acc23 = fma(a2.s3, b0.s3, acc23);
+        acc2.s0 = fma(a2.s3, b0.s0, acc2.s0);
+        acc2.s1 = fma(a2.s3, b0.s1, acc2.s1);
+        acc2.s2 = fma(a2.s3, b0.s2, acc2.s2);
+        acc2.s3 = fma(a2.s3, b0.s3, acc2.s3);
 
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
-        acc30 = fma(a3.s3, b0.s0, acc30);
-        acc31 = fma(a3.s3, b0.s1, acc31);
-        acc32 = fma(a3.s3, b0.s2, acc32);
-        acc33 = fma(a3.s3, b0.s3, acc33);
+        acc3.s0 = fma(a3.s3, b0.s0, acc3.s0);
+        acc3.s1 = fma(a3.s3, b0.s1, acc3.s1);
+        acc3.s2 = fma(a3.s3, b0.s2, acc3.s2);
+        acc3.s3 = fma(a3.s3, b0.s3, acc3.s3);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
         src_addr.s0 += 4 * sizeof(float);
@@ -4298,27 +4357,27 @@
         src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
-        acc00 = fma(a0, b0.s0, acc00);
-        acc01 = fma(a0, b0.s1, acc01);
-        acc02 = fma(a0, b0.s2, acc02);
-        acc03 = fma(a0, b0.s3, acc03);
+        acc0.s0 = fma(a0, b0.s0, acc0.s0);
+        acc0.s1 = fma(a0, b0.s1, acc0.s1);
+        acc0.s2 = fma(a0, b0.s2, acc0.s2);
+        acc0.s3 = fma(a0, b0.s3, acc0.s3);
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-        acc10 = fma(a1, b0.s0, acc10);
-        acc11 = fma(a1, b0.s1, acc11);
-        acc12 = fma(a1, b0.s2, acc12);
-        acc13 = fma(a1, b0.s3, acc13);
+        acc1.s0 = fma(a1, b0.s0, acc1.s0);
+        acc1.s1 = fma(a1, b0.s1, acc1.s1);
+        acc1.s2 = fma(a1, b0.s2, acc1.s2);
+        acc1.s3 = fma(a1, b0.s3, acc1.s3);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-        acc20 = fma(a2, b0.s0, acc20);
-        acc21 = fma(a2, b0.s1, acc21);
-        acc22 = fma(a2, b0.s2, acc22);
-        acc23 = fma(a2, b0.s3, acc23);
+        acc2.s0 = fma(a2, b0.s0, acc2.s0);
+        acc2.s1 = fma(a2, b0.s1, acc2.s1);
+        acc2.s2 = fma(a2, b0.s2, acc2.s2);
+        acc2.s3 = fma(a2, b0.s3, acc2.s3);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-        acc30 = fma(a3, b0.s0, acc30);
-        acc31 = fma(a3, b0.s1, acc31);
-        acc32 = fma(a3, b0.s2, acc32);
-        acc33 = fma(a3, b0.s3, acc33);
+        acc3.s0 = fma(a3, b0.s0, acc3.s0);
+        acc3.s1 = fma(a3, b0.s1, acc3.s1);
+        acc3.s2 = fma(a3, b0.s2, acc3.s2);
+        acc3.s3 = fma(a3, b0.s3, acc3.s3);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
         src_addr.s0 += sizeof(float);
@@ -4329,62 +4388,10 @@
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
-    // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
-    acc00 = acc00 * ALPHA;
-    acc01 = acc01 * ALPHA;
-    acc02 = acc02 * ALPHA;
-    acc03 = acc03 * ALPHA;
-#endif // defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
-    acc10 = acc10 * ALPHA;
-    acc11 = acc11 * ALPHA;
-    acc12 = acc12 * ALPHA;
-    acc13 = acc13 * ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
-    acc20 = acc20 * ALPHA;
-    acc21 = acc21 * ALPHA;
-    acc22 = acc22 * ALPHA;
-    acc23 = acc23 * ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
-    acc30 = acc30 * ALPHA;
-    acc31 = acc31 * ALPHA;
-    acc32 = acc32 * ALPHA;
-    acc33 = acc33 * ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
-
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
-#if defined(ADD_VEC_C)
-    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    float4          c0        = vload4(0, src2_addr);
-
-    acc00 += c0.s0;
-    acc01 += c0.s1;
-    acc02 += c0.s2;
-    acc03 += c0.s3;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    acc10 += c0.s0;
-    acc11 += c0.s1;
-    acc12 += c0.s2;
-    acc13 += c0.s3;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    acc20 += c0.s0;
-    acc21 += c0.s1;
-    acc22 += c0.s2;
-    acc23 += c0.s3;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    acc30 += c0.s0;
-    acc31 += c0.s1;
-    acc32 += c0.s2;
-    acc33 += c0.s3;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif /* defined(ADD_VEC_C) */
+    uint4 zout = 0;
 
 #if defined(REINTERPRET_OUTPUT_AS_3D)
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
@@ -4403,8 +4410,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (dst_cross_plane_pad * dst_stride_y);
@@ -4412,50 +4419,78 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
-    // Store the output block
-    vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+
+    // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
+
+    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, float, bias, BETA);
+#endif // UNIT_BETA
+
+    // acc = acc + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) *
+                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
+
+    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
+#endif // UNIT_BETA
+
+    // acc = acc + bias
+    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
 
     // Store the output block
-    vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+    vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+    vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+    vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+    vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
  *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
- *
  * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
  * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
  * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
  * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
  * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is applied after the optional beta*bias addition (enabled with -DBETA=beta; -DBROADCAST_BIAS broadcasts a bias row, -DUNIT_BETA skips the scaling by beta).
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -4463,9 +4498,7 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
- * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
+ * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
@@ -4477,10 +4510,12 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -4489,18 +4524,22 @@
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
  * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                                      VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                                      IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
+#if defined(BETA)
+                                                      uint src2_stride_z,
+#endif //defined(BETA)
                                                       uint dst_stride_z
 #if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
@@ -4566,20 +4605,15 @@
 #endif // defined(MATRIX_B_DEPTH)
 
     // Initialize accumulators
-    float acc00 = 0.0f;
-    float acc01 = 0.0f;
-
+    float2 acc0 = 0.0f;
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    float acc10 = 0.0f;
-    float acc11 = 0.0f;
+    float2 acc1 = 0.0f;
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    float acc20 = 0.0f;
-    float acc21 = 0.0f;
+    float2 acc2 = 0.0f;
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    float acc30 = 0.0f;
-    float acc31 = 0.0f;
+    float2 acc3 = 0.0f;
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
     // A and B src indices get incremented at the same time.
@@ -4613,95 +4647,95 @@
         src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
-        acc00 = fma(a0.s0, b0.s0, acc00);
-        acc00 = fma(a0.s1, b1.s0, acc00);
-        acc00 = fma(a0.s2, b2.s0, acc00);
-        acc00 = fma(a0.s3, b3.s0, acc00);
-        acc00 = fma(a0.s4, b4.s0, acc00);
-        acc00 = fma(a0.s5, b5.s0, acc00);
-        acc00 = fma(a0.s6, b6.s0, acc00);
-        acc00 = fma(a0.s7, b7.s0, acc00);
+        acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
+        acc0.s0 = fma(a0.s1, b1.s0, acc0.s0);
+        acc0.s0 = fma(a0.s2, b2.s0, acc0.s0);
+        acc0.s0 = fma(a0.s3, b3.s0, acc0.s0);
+        acc0.s0 = fma(a0.s4, b4.s0, acc0.s0);
+        acc0.s0 = fma(a0.s5, b5.s0, acc0.s0);
+        acc0.s0 = fma(a0.s6, b6.s0, acc0.s0);
+        acc0.s0 = fma(a0.s7, b7.s0, acc0.s0);
 
-        acc01 = fma(a0.s0, b0.s1, acc01);
-        acc01 = fma(a0.s1, b1.s1, acc01);
-        acc01 = fma(a0.s2, b2.s1, acc01);
-        acc01 = fma(a0.s3, b3.s1, acc01);
-        acc01 = fma(a0.s4, b4.s1, acc01);
-        acc01 = fma(a0.s5, b5.s1, acc01);
-        acc01 = fma(a0.s6, b6.s1, acc01);
-        acc01 = fma(a0.s7, b7.s1, acc01);
+        acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
+        acc0.s1 = fma(a0.s1, b1.s1, acc0.s1);
+        acc0.s1 = fma(a0.s2, b2.s1, acc0.s1);
+        acc0.s1 = fma(a0.s3, b3.s1, acc0.s1);
+        acc0.s1 = fma(a0.s4, b4.s1, acc0.s1);
+        acc0.s1 = fma(a0.s5, b5.s1, acc0.s1);
+        acc0.s1 = fma(a0.s6, b6.s1, acc0.s1);
+        acc0.s1 = fma(a0.s7, b7.s1, acc0.s1);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if defined(REINTERPRET_INPUT_AS_3D)
         a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
 #else  // defined(REINTERPRET_INPUT_AS_3D)
-        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+        a0                    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
 #endif // defined(REINTERPRET_INPUT_AS_3D)
-        acc10 = fma(a0.s0, b0.s0, acc10);
-        acc10 = fma(a0.s1, b1.s0, acc10);
-        acc10 = fma(a0.s2, b2.s0, acc10);
-        acc10 = fma(a0.s3, b3.s0, acc10);
-        acc10 = fma(a0.s4, b4.s0, acc10);
-        acc10 = fma(a0.s5, b5.s0, acc10);
-        acc10 = fma(a0.s6, b6.s0, acc10);
-        acc10 = fma(a0.s7, b7.s0, acc10);
+        acc1.s0 = fma(a0.s0, b0.s0, acc1.s0);
+        acc1.s0 = fma(a0.s1, b1.s0, acc1.s0);
+        acc1.s0 = fma(a0.s2, b2.s0, acc1.s0);
+        acc1.s0 = fma(a0.s3, b3.s0, acc1.s0);
+        acc1.s0 = fma(a0.s4, b4.s0, acc1.s0);
+        acc1.s0 = fma(a0.s5, b5.s0, acc1.s0);
+        acc1.s0 = fma(a0.s6, b6.s0, acc1.s0);
+        acc1.s0 = fma(a0.s7, b7.s0, acc1.s0);
 
-        acc11 = fma(a0.s0, b0.s1, acc11);
-        acc11 = fma(a0.s1, b1.s1, acc11);
-        acc11 = fma(a0.s2, b2.s1, acc11);
-        acc11 = fma(a0.s3, b3.s1, acc11);
-        acc11 = fma(a0.s4, b4.s1, acc11);
-        acc11 = fma(a0.s5, b5.s1, acc11);
-        acc11 = fma(a0.s6, b6.s1, acc11);
-        acc11 = fma(a0.s7, b7.s1, acc11);
+        acc1.s1 = fma(a0.s0, b0.s1, acc1.s1);
+        acc1.s1 = fma(a0.s1, b1.s1, acc1.s1);
+        acc1.s1 = fma(a0.s2, b2.s1, acc1.s1);
+        acc1.s1 = fma(a0.s3, b3.s1, acc1.s1);
+        acc1.s1 = fma(a0.s4, b4.s1, acc1.s1);
+        acc1.s1 = fma(a0.s5, b5.s1, acc1.s1);
+        acc1.s1 = fma(a0.s6, b6.s1, acc1.s1);
+        acc1.s1 = fma(a0.s7, b7.s1, acc1.s1);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if defined(REINTERPRET_INPUT_AS_3D)
         a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
 #else  // defined(REINTERPRET_INPUT_AS_3D)
-        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+        a0                    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
 #endif // defined(REINTERPRET_INPUT_AS_3D)
-        acc20 = fma(a0.s0, b0.s0, acc20);
-        acc20 = fma(a0.s1, b1.s0, acc20);
-        acc20 = fma(a0.s2, b2.s0, acc20);
-        acc20 = fma(a0.s3, b3.s0, acc20);
-        acc20 = fma(a0.s4, b4.s0, acc20);
-        acc20 = fma(a0.s5, b5.s0, acc20);
-        acc20 = fma(a0.s6, b6.s0, acc20);
-        acc20 = fma(a0.s7, b7.s0, acc20);
+        acc2.s0 = fma(a0.s0, b0.s0, acc2.s0);
+        acc2.s0 = fma(a0.s1, b1.s0, acc2.s0);
+        acc2.s0 = fma(a0.s2, b2.s0, acc2.s0);
+        acc2.s0 = fma(a0.s3, b3.s0, acc2.s0);
+        acc2.s0 = fma(a0.s4, b4.s0, acc2.s0);
+        acc2.s0 = fma(a0.s5, b5.s0, acc2.s0);
+        acc2.s0 = fma(a0.s6, b6.s0, acc2.s0);
+        acc2.s0 = fma(a0.s7, b7.s0, acc2.s0);
 
-        acc21 = fma(a0.s0, b0.s1, acc21);
-        acc21 = fma(a0.s1, b1.s1, acc21);
-        acc21 = fma(a0.s2, b2.s1, acc21);
-        acc21 = fma(a0.s3, b3.s1, acc21);
-        acc21 = fma(a0.s4, b4.s1, acc21);
-        acc21 = fma(a0.s5, b5.s1, acc21);
-        acc21 = fma(a0.s6, b6.s1, acc21);
-        acc21 = fma(a0.s7, b7.s1, acc21);
+        acc2.s1 = fma(a0.s0, b0.s1, acc2.s1);
+        acc2.s1 = fma(a0.s1, b1.s1, acc2.s1);
+        acc2.s1 = fma(a0.s2, b2.s1, acc2.s1);
+        acc2.s1 = fma(a0.s3, b3.s1, acc2.s1);
+        acc2.s1 = fma(a0.s4, b4.s1, acc2.s1);
+        acc2.s1 = fma(a0.s5, b5.s1, acc2.s1);
+        acc2.s1 = fma(a0.s6, b6.s1, acc2.s1);
+        acc2.s1 = fma(a0.s7, b7.s1, acc2.s1);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 #if defined(REINTERPRET_INPUT_AS_3D)
         a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
 #else  // defined(REINTERPRET_INPUT_AS_3D)
-        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+        a0                    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // defined(REINTERPRET_INPUT_AS_3D)
-        acc30 = fma(a0.s0, b0.s0, acc30);
-        acc30 = fma(a0.s1, b1.s0, acc30);
-        acc30 = fma(a0.s2, b2.s0, acc30);
-        acc30 = fma(a0.s3, b3.s0, acc30);
-        acc30 = fma(a0.s4, b4.s0, acc30);
-        acc30 = fma(a0.s5, b5.s0, acc30);
-        acc30 = fma(a0.s6, b6.s0, acc30);
-        acc30 = fma(a0.s7, b7.s0, acc30);
+        acc3.s0 = fma(a0.s0, b0.s0, acc3.s0);
+        acc3.s0 = fma(a0.s1, b1.s0, acc3.s0);
+        acc3.s0 = fma(a0.s2, b2.s0, acc3.s0);
+        acc3.s0 = fma(a0.s3, b3.s0, acc3.s0);
+        acc3.s0 = fma(a0.s4, b4.s0, acc3.s0);
+        acc3.s0 = fma(a0.s5, b5.s0, acc3.s0);
+        acc3.s0 = fma(a0.s6, b6.s0, acc3.s0);
+        acc3.s0 = fma(a0.s7, b7.s0, acc3.s0);
 
-        acc31 = fma(a0.s0, b0.s1, acc31);
-        acc31 = fma(a0.s1, b1.s1, acc31);
-        acc31 = fma(a0.s2, b2.s1, acc31);
-        acc31 = fma(a0.s3, b3.s1, acc31);
-        acc31 = fma(a0.s4, b4.s1, acc31);
-        acc31 = fma(a0.s5, b5.s1, acc31);
-        acc31 = fma(a0.s6, b6.s1, acc31);
-        acc31 = fma(a0.s7, b7.s1, acc31);
+        acc3.s1 = fma(a0.s0, b0.s1, acc3.s1);
+        acc3.s1 = fma(a0.s1, b1.s1, acc3.s1);
+        acc3.s1 = fma(a0.s2, b2.s1, acc3.s1);
+        acc3.s1 = fma(a0.s3, b3.s1, acc3.s1);
+        acc3.s1 = fma(a0.s4, b4.s1, acc3.s1);
+        acc3.s1 = fma(a0.s5, b5.s1, acc3.s1);
+        acc3.s1 = fma(a0.s6, b6.s1, acc3.s1);
+        acc3.s1 = fma(a0.s7, b7.s1, acc3.s1);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
         src_addr.s0 += sizeof(float) * 8;
@@ -4740,42 +4774,24 @@
         src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
-        acc00 = fma(a0, b0.s0, acc00);
-        acc01 = fma(a0, b0.s1, acc01);
+        acc0.s0 = fma(a0, b0.s0, acc0.s0);
+        acc0.s1 = fma(a0, b0.s1, acc0.s1);
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-        acc10 = fma(a1, b0.s0, acc10);
-        acc11 = fma(a1, b0.s1, acc11);
+        acc1.s0 = fma(a1, b0.s0, acc1.s0);
+        acc1.s1 = fma(a1, b0.s1, acc1.s1);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-        acc20 = fma(a2, b0.s0, acc20);
-        acc21 = fma(a2, b0.s1, acc21);
+        acc2.s0 = fma(a2, b0.s0, acc2.s0);
+        acc2.s1 = fma(a2, b0.s1, acc2.s1);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-        acc30 = fma(a3, b0.s0, acc30);
-        acc31 = fma(a3, b0.s1, acc31);
+        acc3.s0 = fma(a3, b0.s0, acc3.s0);
+        acc3.s1 = fma(a3, b0.s1, acc3.s1);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
         src_addr.s0 += sizeof(float);
     }
 
-    // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
-    acc00 = acc00 * ALPHA;
-    acc01 = acc01 * ALPHA;
-#endif // defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
-    acc10 = acc10 * ALPHA;
-    acc11 = acc11 * ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
-    acc20 = acc20 * ALPHA;
-    acc21 = acc21 * ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
-    acc30 = acc30 * ALPHA;
-    acc31 = acc31 * ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
-
     int z = get_global_id(2);
 
     // Compute destination address
@@ -4784,27 +4800,10 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
-#if defined(ADD_VEC_C)
-    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    float2          c0        = vload2(0, src2_addr);
-
-    acc00 += c0.s0;
-    acc01 += c0.s1;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    acc10 += c0.s0;
-    acc11 += c0.s1;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    acc20 += c0.s0;
-    acc21 += c0.s1;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    acc30 += c0.s0;
-    acc31 += c0.s1;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif /* defined(ADD_VEC_C) */
+    uint4 zout = 0;
 
 #if defined(REINTERPRET_OUTPUT_AS_3D)
+
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
     //
@@ -4821,8 +4820,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (dst_cross_plane_pad * dst_stride_y);
@@ -4830,50 +4829,78 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
-    // Store the output block
-    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+
+    // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));
+
+    LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, float, bias, BETA);
+#endif // UNIT_BETA
+
+    // acc = acc + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (get_global_id(1) *
+                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
+
+    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
+#endif // UNIT_BETA
+
+    // acc = acc + bias
+    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
 
     // Store the output block
-    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+    vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+    vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+    vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+    vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 
 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
  *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
- *
  * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulating the result in a 32 floating point variable.
  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
  * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
  * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
  * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is applied after the optional beta*bias addition.
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -4881,8 +4908,6 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -4895,10 +4920,12 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -4907,18 +4934,22 @@
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
  * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                        IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                                       VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                                       IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                                        IMAGE_DECLARATION(dst),
                                                        uint src0_stride_z,
                                                        uint src1_stride_z,
+#if defined(BETA)
+                                                       uint src2_stride_z,
+#endif //defined(BETA)
                                                        uint dst_stride_z
 #if defined(REINTERPRET_INPUT_AS_3D)
                                                        ,
@@ -5117,56 +5148,6 @@
 #endif                                    // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
     }
 
-    // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
-    half8 hacc0 = convert_half8(acc0) * (half8)ALPHA;
-#else  //defined(ALPHA)
-    half8 hacc0 = convert_half8(acc0);
-#endif // defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
-    half8 hacc1 = convert_half8(acc1) * (half8)ALPHA;
-#else  //defined(ALPHA)
-    half8 hacc1 = convert_half8(acc1);
-#endif //defined(ALPHA)
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
-    half8 hacc2 = convert_half8(acc2) * (half8)ALPHA;
-#else  //defined(ALPHA)
-    half8 hacc2 = convert_half8(acc2);
-#endif //defined(ALPHA)
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
-    half8 hacc3 = convert_half8(acc3) * (half8)ALPHA;
-#else  //defined(ALPHA)
-    half8 hacc3 = convert_half8(acc3);
-#endif // defined(ALPHA)
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
-#if defined(ADD_VEC_C)
-    // *INDENT-OFF*
-    // clang-format off
-    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    half8          c0        = vload8(0, src2_addr);
-    // clang-format on
-    // *INDENT-ON*
-
-    hacc0 += c0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    hacc1 += c0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    hacc2 += c0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    hacc3 += c0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif /* defined(ADD_VEC_C) */
-
     int z = get_global_id(2);
 
     // Compute destination address
@@ -5175,7 +5156,10 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+    uint4 zout = 0;
+
 #if defined(REINTERPRET_OUTPUT_AS_3D)
+
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
     //
@@ -5192,8 +5176,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (dst_cross_plane_pad * dst_stride_y);
@@ -5201,38 +5185,91 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-    // Store the output block
-    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, hacc, dst_addr, dst_stride_y, zout.s);
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 
-    // Store the output block
-    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
+    // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
+#endif // defined(ALPHA)
+
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
+
+    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+    float8 bias_f0 = convert_float8(bias0);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, float, bias_f, BETA);
+#endif // UNIT_BETA
+
+    // acc = acc + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
+                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
+
+    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+    float8 bias_f0 = convert_float8(bias0);
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
+    float8 bias_f1 = convert_float8(bias1);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
+    float8 bias_f2 = convert_float8(bias2);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+    float8 bias_f3 = convert_float8(bias3);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // REINTERPRET_OUTPUT_AS_3D
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias_f, BETA);
+#endif // UNIT_BETA
+
+    // acc = acc + bias
+    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+    half8 acc_h0 = convert_half8(acc0);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    half8 acc_h1 = convert_half8(acc1);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    half8 acc_h2 = convert_half8(acc2);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    half8 acc_h3 = convert_half8(acc3);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc_h, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
+
+    // Store the output block
+    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc_h, dst_addr, dst_stride_y, zout.s);
 }
 
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
  *
- * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
- *
  * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
  * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
  * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
  * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
- * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
- *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is applied after the bias addition.
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -5240,8 +5277,6 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
- * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
- *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -5254,10 +5289,12 @@
  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
- * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
- * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
+ * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
@@ -5266,18 +5303,22 @@
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
  * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
-#if defined(ADD_VEC_C)
-                                                 VECTOR_DECLARATION(src2),
-#endif /* defined(ADD_VEC_C) */
+#if defined(BETA)
+                                                 IMAGE_DECLARATION(src2),
+#endif // defined(BETA)
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
+#if defined(BETA)
+                                                 uint src2_stride_z,
+#endif //defined(BETA)
                                                  uint dst_stride_z
 #if defined(REINTERPRET_INPUT_AS_3D)
                                                  ,
@@ -5476,40 +5517,6 @@
 #endif                                   // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
     }
 
-    // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
-    acc0 = acc0 * (half8)ALPHA;
-#endif // defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
-    acc1 = acc1 * (half8)ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
-    acc2 = acc2 * (half8)ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
-    acc3 = acc3 * (half8)ALPHA;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
-
-#if defined(ADD_VEC_C)
-    // *INDENT-OFF*
-    // clang-format off
-    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
-    half8          c0        = vload8(0, src2_addr);
-    // clang-format on
-    // *INDENT-ON*
-
-    acc0 += c0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    acc1 += c0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    acc2 += c0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    acc3 += c0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif /* defined(ADD_VEC_C) */
-
     int z = get_global_id(2);
 
     // Compute destination address
@@ -5518,7 +5525,10 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+    uint4 zout = 0;
+
 #if defined(REINTERPRET_OUTPUT_AS_3D)
+
     // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
     // in order to take into account the presence of possible cross plane paddings
     //
@@ -5535,8 +5545,8 @@
     //  |__________________|
 
     // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
-    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
-    zout       = min(DEPTH_GEMM3D - 1, zout);
+    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
     zout *= (dst_cross_plane_pad * dst_stride_y);
@@ -5544,25 +5554,54 @@
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
     dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
+    // Add offset for batched GEMM
+    dst_addr += z * dst_stride_z;
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
+
+    // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, acc, ALPHA);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
+
+#if defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
+
+    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, half, bias, BETA);
+#endif // UNIT_BETA
+
+    // acc = acc + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
+                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
+
+    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, bias, BETA);
+#endif // UNIT_BETA
+
+    // acc = acc + bias
+    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
 
     // Store the output block
     STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
-    // Add offset for batched GEMM
-    dst_addr += z * dst_stride_z;
-
-    // Store the output block
-    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // REINTERPRET_OUTPUT_AS_3D
 }
 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
 
@@ -5746,7 +5785,7 @@
     Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
     Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
 
-    // Vector size, i.e. number of vector elements.
+    // Vector size, i.e. number of vector elements.
     VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
     accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
     VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)