COMPMID-2171: Fuse bias addition with CLGEMMMatrixMultiplyReshapedOnlyRHSKernel

Change-Id: I1d1e1f28fe7022309d72900893e8368820ca0f89
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1259
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 41e5c33..2ac2eb7 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -731,29 +731,29 @@
     // 3x4 -> 4x3
     // 3x8 -> 8x3
     // 3x16 -> 16x3
-    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
-    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
+    res0                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
+    res1                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
 #if N0 > 2
-    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
+    res2                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
 #endif // N0 > 2
 #if N0 > 3
-    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
+    res3                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
 #endif // N0 > 3
 #if N0 > 4
-    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
-    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
-    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
-    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
+    res4                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
+    res5                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
+    res6                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
+    res7                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
 #endif // N0 > 4
 #if N0 > 8
-    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
-    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
-    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
-    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
-    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
-    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
-    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
-    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
+    res8                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
+    res9                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
+    resA                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
+    resB                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
+    resC                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
+    resD                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
+    resE                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
+    resF                      = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
 #endif // N0 > 8
 
 #elif K0 == 4 // K0 == 4
@@ -1029,35 +1029,48 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
  *
- * @param[in]  lhs_ptr                           Pointer to the LHS reshaped matrix. Supported data type: F16/F32
- * @param[in]  lhs_stride_x                      Stride of the LHS reshaped matrix in X dimension (in bytes)
- * @param[in]  lhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  lhs_stride_y                      Stride of the LHS reshaped matrix in Y dimension (in bytes)
- * @param[in]  lhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
- * @param[in]  rhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
- * @param[in]  rhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr                           Pointer to the destination matrix Supported data type: same as @p lhs_ptr
- * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
- * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
- * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in]  lhs_stride_z                      Stride of the LHS reshaped matrix in Z dimension (in bytes)
- * @param[in]  rhs_stride_z                      Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  lhs_cross_plane_pad               (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
- * @param[in]  dst_cross_plane_pad               (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F16/F32
+ * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
+ * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
+ * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
+ * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
+ * @param[in]  rhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
+ * @param[in]  rhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
+ * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
+ * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
+ * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
+ * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
+ * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
+ * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
+ * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
                                           IMAGE_DECLARATION(rhs),
+#if defined(BETA)
+                                          IMAGE_DECLARATION(bias),
+#endif // defined(BETA)
                                           IMAGE_DECLARATION(dst),
                                           uint lhs_stride_z,
                                           uint rhs_stride_z,
+#if defined(BETA)
+                                          uint bias_stride_z,
+#endif // defined(BETA)
                                           uint dst_stride_z
 #if defined(REINTERPRET_INPUT_AS_3D)
                                           ,
@@ -1108,7 +1121,7 @@
 #endif // defined(MATRIX_B_DEPTH)
 
     REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
-    REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
+    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
 
 #if defined(REINTERPRET_INPUT_AS_3D)
     // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
@@ -1144,7 +1157,7 @@
         LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
 
         // Load values from RHS matrix
-        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
+        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
 
         // Accumulate
         ARM_DOT_K0XN0(K0, a0, b, c0);
@@ -1181,7 +1194,7 @@
         LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
 
         // Load values from RHS matrix
-        LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
+        LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
 
         // Accumulate
         ARM_DOT_K0XN0(1, a0, b, c0);
@@ -1236,6 +1249,36 @@
     SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
 #endif // defined(ALPHA)
 
+    // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
+
+    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(M0, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
+                                    2) * bias_stride_z;
+
+    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(M0, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
     // Store output block
     STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
 
@@ -1360,35 +1403,48 @@
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
  *
- * @param[in]  lhs_ptr                           Pointer to the LHS reshaped matrix. Supported data type: F16/F32
- * @param[in]  lhs_stride_x                      Stride of the LHS reshaped matrix in X dimension (in bytes)
- * @param[in]  lhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  lhs_stride_y                      Stride of the LHS reshaped matrix in Y dimension (in bytes)
- * @param[in]  lhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
- * @param[in]  rhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
- * @param[in]  rhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr                           Pointer to the destination matrix Supported data type: same as @p lhs_ptr
- * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
- * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
- * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in]  lhs_stride_z                      Stride of the LHS reshaped matrix in Z dimension (in bytes)
- * @param[in]  rhs_stride_z                      Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  lhs_cross_plane_pad               (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
- * @param[in]  dst_cross_plane_pad               (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F16/F32
+ * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
+ * @param[in]  lhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
+ * @param[in]  lhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
+ * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
+ * @param[in]  rhs_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
+ * @param[in]  rhs_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
+ * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
+ * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
+ * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
+ * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
+ * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
+ * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
+ * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
                                            IMAGE_DECLARATION(rhs),
+#if defined(BETA)
+                                           IMAGE_DECLARATION(bias),
+#endif // defined(BETA)
                                            IMAGE_DECLARATION(dst),
                                            uint lhs_stride_z,
                                            uint rhs_stride_z,
+#if defined(BETA)
+                                           uint bias_stride_z,
+#endif // defined(BETA)
                                            uint dst_stride_z
 #if defined(REINTERPRET_INPUT_AS_3D)
                                            ,
@@ -1438,7 +1494,8 @@
     rhs_offset += z * rhs_stride_z;
 #endif // defined(MATRIX_B_DEPTH)
 
-    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
+    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);   //uint zin0=0,zin1=0,zin2=0,... zin7=0;
+    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zeroF=0;
 
 #if defined(REINTERPRET_INPUT_AS_3D)
 
@@ -1568,6 +1625,36 @@
     SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
 #endif // defined(ALPHA)
 
+    // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
+
+    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    ADD_BLOCK_BROADCAST(M0, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
+                                    2) * bias_stride_z;
+
+    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    ADD_BLOCK(M0, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
     // Store output block
     STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);