COMPMID-2398: Add test for CLFuseBatchNormalizationLayer

Change-Id: I786df628ce15fc33fc42c9437fe82972e02e3b16
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1317
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index 904575a..b734fd2 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -298,7 +298,7 @@
     { "finalize", "optical_flow_pyramid_lk.cl" },
     { "flatten", "flatten.cl" },
     { "floor_layer", "floor.cl" },
-    { "fuse_batchnormalization_layer", "batchnormalization_layer.cl" },
+    { "fuse_batchnormalization_conv_layer", "batchnormalization_layer.cl" },
     { "gather", "gather.cl" },
     { "gaussian1x5_sub_x", "gaussian_pyramid.cl" },
     { "gaussian5x1_sub_y", "gaussian_pyramid.cl" },
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index 66d371c..a532131 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -259,161 +259,145 @@
 }
 #endif /* defined(VEC_SIZE) && defined(DATA_TYPE) && defined(DATA_TYPE)*/
 
-#if defined(NUM_CHANNELS) && defined(DATA_TYPE) && defined(EPSILON)
-/** Fuse batchnorm parameters to convolution layer parameters
+#if defined(DIM2) && defined(DATA_TYPE) && defined(EPSILON)
+/** OpenCL kernel to fuse the weights of the convolution layer with batch normalization when the data layout is either NCHW or NHWC
  *
- * @attention Data type should be passed using the -DDATA_TYPE compile flag, e.g. -DDATA_TYPE=float
- * @attention Input tensor depth should be given as a preprocessor argument using -DNUM_CHANNELS=size. e.g. -DNUM_CHANNELS=16
- * @attention Batch normalization epsilon parameter should be given as a preprocessor argument with -DEPSILON=value. e.g. -DEPSILON=0.001f
+ * @note The input weights tensor is assumed to be 4D with the OFMs in the fourth dimension
+ * @note Data type should be passed at compile time using -DDATA_TYPE, e.g. -DDATA_TYPE=float
+ * @note The third dimension of the input tensor should be passed at compile time using -DDIM2=size, e.g. -DDIM2=16
+ * @note Batch normalization epsilon parameter should be passed at compile time using -DEPSILON=value, e.g. -DEPSILON=0.001f
  *
- * @param[in]  conv_w_ptr                             Pointer to the source tensor. Supported data types: F16/F32
- * @param[in]  conv_w_stride_x                        Stride of the source tensor in X dimension (in bytes)
- * @param[in]  conv_w_step_x                          input_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  conv_w_stride_y                        Stride of the source tensor in Y dimension (in bytes)
- * @param[in]  conv_w_step_y                          input_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  conv_w_stride_z                        Stride of the source tensor in Z dimension (in bytes)
- * @param[in]  conv_w_step_z                          input_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  conv_w_stride_w                        Stride of the source tensor in W dimension (in bytes)
- * @param[in]  conv_w_step_w                          input_stride_w * number of elements along W processed per workitem(in bytes)
- * @param[in]  conv_w_offset_first_element_in_bytes   The offset of the first element in the source tensor
- * @param[in]  bn_mean_ptr                            Pointer to the mean source tensor. Supported data types: same as @p input_ptr
- * @param[in]  bn_mean_stride_x                       Stride of the mean source tensor in X dimension (in bytes)
- * @param[in]  bn_mean_step_x                         bn_mean_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  bn_mean_offset_first_element_in_bytes  The offset of the first element in the mean source tensor
- * @param[in]  bn_var_ptr                             Pointer to the var tensor. Supported data types: same as @p input_ptr
- * @param[in]  bn_var_stride_x                        Stride of the var tensor in X dimension (in bytes)
- * @param[in]  bn_var_step_x                          bn_var_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  bn_var_offset_first_element_in_bytes   The offset of the first element in the var source tensor
- * @param[out] fused_w_ptr                            Pointer to the destination weights tensors. Supported data types: same as @p input_ptr
- * @param[in]  fused_w_stride_x                       Stride of the destination tensor in X dimension (in bytes)
- * @param[in]  fused_w_step_x                         fused_w_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  fused_w_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
- * @param[in]  fused_w_step_y                         fused_w_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  fused_w_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  fused_w_step_z                         fused_w_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  fused_w_stride_w                       Stride of the destination tensor in W dimension (in bytes)
- * @param[in]  fused_w_step_w                         fused_w_stride_w * number of elements along W processed per workitem(in bytes)
- * @param[in]  fused_w_offset_first_element_in_bytes  The offset of the first element in the destination tensor
- * @param[in]  fused_b_ptr                            Pointer to the destination bias tensor. Supported data types: same as @p input_ptr
- * @param[in]  fused_b_stride_x                       Stride of the bias source tensor in X dimension (in bytes)
- * @param[in]  fused_b_step_x                         fused_b_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  fused_b_offset_first_element_in_bytes  The offset of the first element in the destination tensor
- * @param[in]  conv_b_ptr                             Pointer to the source bias tensor. Supported data types: same as @p input_ptr
- * @param[in]  conv_b_stride_x                        Stride of the beta source tensor in X dimension (in bytes)
- * @param[in]  conv_b_step_x                          conv_b_beta_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  conv_b_offset_first_element_in_bytes   The offset of the first element in the source bias tensor
- * @param[in]  bn_beta_ptr                            Pointer to the beta source tensor. Supported data types: same as @p input_ptr
- * @param[in]  bn_beta_stride_x                       Stride of the beta source tensor in X dimension (in bytes)
- * @param[in]  bn_beta_step_x                         bn_beta_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  bn_beta_offset_first_element_in_bytes  The offset of the first element in the beta source tensor
- * @param[in]  bn_gamma_ptr                           Pointer to the gamma source tensor. Supported data types: same as @p input_ptr
- * @param[in]  bn_gamma_stride_x                      Stride of the gamma source tensor in X dimension (in bytes)
- * @param[in]  bn_gamma_step_x                        bn_gamma_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  bn_gamma_offset_first_element_in_bytes The offset of the first element in the gamma source tensor
- * @param[in]  epsilon                                Epsilon parameter in the batch normalization equation
+ * @param[in]  w_ptr                                 Pointer to the weights tensor. Supported data types: F16/F32
+ * @param[in]  w_stride_x                            Stride of the weights tensor in X dimension (in bytes)
+ * @param[in]  w_step_x                              w_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  w_stride_y                            Stride of the weights tensor in Y dimension (in bytes)
+ * @param[in]  w_step_y                              w_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  w_stride_z                            Stride of the weights tensor in Z dimension (in bytes)
+ * @param[in]  w_step_z                              w_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  w_offset_first_element_in_bytes       The offset of the first element in the weights tensor
+ * @param[in]  b_ptr                                 (Optional) Pointer to the bias tensor. Supported data types: same as @p w_ptr
+ * @param[in]  b_stride_x                            (Optional) Stride of the bias tensor in X dimension (in bytes)
+ * @param[in]  b_step_x                              (Optional) b_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  b_stride_y                            (Optional) Stride of the bias tensor in Y dimension (in bytes)
+ * @param[in]  b_step_y                              (Optional) b_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  b_stride_z                            (Optional) Stride of the bias tensor in Z dimension (in bytes)
+ * @param[in]  b_step_z                              (Optional) b_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  b_offset_first_element_in_bytes       (Optional) The offset of the first element in the bias tensor
+ * @param[in]  mean_ptr                              Pointer to the mean source tensor. Supported data types: same as @p w_ptr
+ * @param[in]  mean_stride_x                         Stride of the mean source tensor in X dimension (in bytes)
+ * @param[in]  mean_step_x                           mean_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  mean_offset_first_element_in_bytes    The offset of the first element in the mean source tensor
+ * @param[in]  var_ptr                               Pointer to the var tensor. Supported data types: same as @p w_ptr
+ * @param[in]  var_stride_x                          Stride of the var tensor in X dimension (in bytes)
+ * @param[in]  var_step_x                            var_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  var_offset_first_element_in_bytes     The offset of the first element in the var source tensor
+ * @param[out] w_fused_ptr                           (Optional) Pointer to the destination weights tensor. Supported data types: same as @p w_ptr
+ * @param[in]  w_fused_stride_x                      (Optional) Stride of the destination weights tensor in X dimension (in bytes)
+ * @param[in]  w_fused_step_x                        (Optional) w_fused_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  w_fused_stride_y                      (Optional) Stride of the destination weights tensor in Y dimension (in bytes)
+ * @param[in]  w_fused_step_y                        (Optional) w_fused_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  w_fused_stride_z                      (Optional) Stride of the destination weights tensor in Z dimension (in bytes)
+ * @param[in]  w_fused_step_z                        (Optional) w_fused_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  w_fused_offset_first_element_in_bytes (Optional) The offset of the first element in the destination weights tensor
+ * @param[in]  b_fused_ptr                           (Optional) Pointer to the destination bias tensor. Supported data types: same as @p w_ptr
+ * @param[in]  b_fused_stride_x                      (Optional) Stride of the destination bias tensor in X dimension (in bytes)
+ * @param[in]  b_fused_step_x                        (Optional) b_fused_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  b_fused_offset_first_element_in_bytes (Optional) The offset of the first element in the destination bias tensor
+ * @param[in]  beta_ptr                              (Optional) Pointer to the beta source tensor. Supported data types: same as @p w_ptr
+ * @param[in]  beta_stride_x                         (Optional) Stride of the beta source tensor in X dimension (in bytes)
+ * @param[in]  beta_step_x                           (Optional) beta_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  beta_offset_first_element_in_bytes    (Optional) The offset of the first element in the beta source tensor
+ * @param[in]  gamma_ptr                             (Optional) Pointer to the gamma source tensor. Supported data types: same as @p w_ptr
+ * @param[in]  gamma_stride_x                        (Optional) Stride of the gamma source tensor in X dimension (in bytes)
+ * @param[in]  gamma_step_x                          (Optional) gamma_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  gamma_offset_first_element_in_bytes   (Optional) The offset of the first element in the gamma source tensor
  */
-__kernel void fuse_batchnormalization_layer(TENSOR4D_DECLARATION(conv_w),
-                                            VECTOR_DECLARATION(bn_mean),
-                                            VECTOR_DECLARATION(bn_var)
+__kernel void fuse_batchnormalization_conv_layer(TENSOR3D_DECLARATION(w),
+#if defined(BIAS)
+                                                 VECTOR_DECLARATION(b),
+#endif // defined(BIAS)
+                                                 VECTOR_DECLARATION(mean),
+                                                 VECTOR_DECLARATION(var)
 #ifndef IN_PLACE_W
-                                            ,
-                                            TENSOR4D_DECLARATION(fused_w)
-#endif /* not IN_PLACE_W */
+                                                 ,
+                                                 TENSOR3D_DECLARATION(w_fused)
+#endif // ifndef IN_PLACE_W
 #ifndef IN_PLACE_B
-                                            ,
-                                            VECTOR_DECLARATION(fused_b)
-#endif /* not IN_PLACE_B */
-#ifdef HAS_BIAS
-                                            ,
-                                            VECTOR_DECLARATION(conv_b)
-#endif /* HAS_BIAS */
-#ifndef USE_DEFAULT_BETA
-                                            ,
-                                            VECTOR_DECLARATION(bn_beta)
-#endif /* USE_DEFAULT_BETA */
-#ifndef USE_DEFAULT_GAMMA
-                                            ,
-                                            VECTOR_DECLARATION(bn_gamma)
-#endif /* USE_DEFAULT_GAMMA */
-                                           )
+                                                 ,
+                                                 VECTOR_DECLARATION(b_fused)
+#endif // ifndef IN_PLACE_B
+#if defined(BETA)
+                                                 ,
+                                                 VECTOR_DECLARATION(beta)
+#endif // defined(BETA)
+#if defined(GAMMA)
+                                                 ,
+                                                 VECTOR_DECLARATION(gamma)
+#endif // defined(GAMMA)
+                                                )
 {
-    Tensor4D conv_w  = CONVERT_TO_TENSOR4D_STRUCT(conv_w, NUM_CHANNELS);
-    Vector   bn_mean = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_mean);
-    Vector   bn_var  = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_var);
+    int x  = get_global_id(0);
+    int y  = get_global_id(1);
+    int z  = get_global_id(2);
+    int c0 = z % DIM2;
+    int c1 = z / DIM2;
 
-    // Conditional ops
-#ifdef HAS_BIAS
-    Vector conv_b = CONVERT_TO_VECTOR_STRUCT_NO_STEP(conv_b);
-#endif /* HAS_BIAS */
-#ifndef USE_DEFAULT_BETA
-    Vector bn_beta = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_beta);
-#endif /* USE_DEFAULT_BETA */
-#ifndef USE_DEFAULT_GAMMA
-    Vector bn_gamma = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_gamma);
-#endif /* USE_DEFAULT_GAMMA */
+    int w_offset = x * sizeof(DATA_TYPE) + y * w_stride_y + z * w_stride_z;
+    int v_offset = c1 * sizeof(DATA_TYPE);
 
-    // In-place ops
-#ifdef IN_PLACE_W
-    Tensor4D fused_w          = conv_w;
-    uint     fused_w_stride_x = conv_w_stride_x;
-#else  /* IN_PLACE_W */
-    Tensor4D  fused_w                      = CONVERT_TO_TENSOR4D_STRUCT(fused_w, NUM_CHANNELS);
-#endif /* IN_PLACE_W */
-#ifdef IN_PLACE_B
-    Vector fused_b = conv_b;
-#else  /* IN_PLACE_B */
-    Vector    fused_b                      = CONVERT_TO_VECTOR_STRUCT_NO_STEP(fused_b);
-#endif /* IN_PLACE_B */
+    DATA_TYPE w_old = 0.0f;
+    DATA_TYPE b_old = 0.0f;
+    DATA_TYPE w_new = 0.0f;
+    DATA_TYPE b_new = 0.0f;
+    DATA_TYPE gamma = 1.0f;
+    DATA_TYPE mean  = 0.0f;
+    DATA_TYPE var   = 1.0f;
+    DATA_TYPE beta  = 0.0f;
 
-    const int current_slice = get_global_id(2) / NUM_CHANNELS;
+    w_old = *((__global DATA_TYPE *)(w_ptr + w_offset + w_offset_first_element_in_bytes));
+    var   = *((__global DATA_TYPE *)(var_ptr + v_offset + var_offset_first_element_in_bytes));
+    mean  = *((__global DATA_TYPE *)(mean_ptr + v_offset + mean_offset_first_element_in_bytes));
 
-#if defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
-    // Check if access on width gets out of bounds
-    // If it does shift access vector to access elements within bounds
-    const int xi = (int)(get_global_id(0) * VEC_SIZE);
-    conv_w.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * conv_w_stride_x;
-    fused_w.ptr -= max(xi - (int)LAST_ACCESSED_X, 0) * fused_w_stride_x;
+#if defined(GAMMA)
+    gamma = *((__global DATA_TYPE *)(gamma_ptr + v_offset + gamma_offset_first_element_in_bytes));
+#endif // defined(GAMMA)
 
-    // Load W
-    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
-    wn = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)conv_w.ptr);
-#else  // !defined(VEC_SIZE) || !defined(LAST_ACCESSED_X)
-    DATA_TYPE wn                           = *((__global DATA_TYPE *)(conv_w.ptr));
-#endif // defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
+    // Compute new weight
+    w_new = (gamma * w_old) / (sqrt(var + EPSILON));
 
-    // rvar = 1 / sqrt(var + epsilon)
-    const DATA_TYPE var  = *((__global DATA_TYPE *)(bn_var.ptr + current_slice * bn_var.stride_x));
-    const DATA_TYPE rvar = INVSQRT_OP(ADD_OP(var, SQCVT_SAT((float)EPSILON)));
-    wn *= rvar;
+#if defined(IN_PLACE_W)
+    *((__global DATA_TYPE *)(w_ptr + w_offset + w_offset_first_element_in_bytes)) = w_new;
+#else  // defined(IN_PLACE_W)
+    *((__global DATA_TYPE *)(w_fused_ptr + w_offset + w_fused_offset_first_element_in_bytes))     = w_new;
+#endif // defined(IN_PLACE_W)
 
-    // Load b
-    const DATA_TYPE mean = *((__global DATA_TYPE *)(bn_mean.ptr + current_slice * bn_mean.stride_x));
-    DATA_TYPE bn         = 0;
-#ifdef HAS_BIAS
-    bn = *((__global DATA_TYPE *)(conv_b.ptr + current_slice * conv_b.stride_x));
-#endif /* HAS_BIAS */
-    bn = (bn - mean) * rvar;
+    // Compute bias
+    if(x == 0 && y == 0 && c0 == 0)
+    {
+#if defined(BIAS)
+        b_old = *((__global DATA_TYPE *)(b_ptr + v_offset + b_offset_first_element_in_bytes));
+#endif // defined(BIAS)
+#if defined(BETA)
+        beta = *((__global DATA_TYPE *)(beta_ptr + v_offset + beta_offset_first_element_in_bytes));
+#endif // defined(BETA)
 
-#ifndef USE_DEFAULT_GAMMA
-    const DATA_TYPE gamma_scalar = *((__global DATA_TYPE *)(bn_gamma.ptr + current_slice * bn_gamma.stride_x));
-    wn *= gamma_scalar;
-    bn *= gamma_scalar;
-#endif /* USE_DEFAULT_GAMMA */
+        b_new = ((gamma * (b_old - mean)) / (sqrt(var + EPSILON))) + beta;
 
-#ifndef USE_DEFAULT_BETA
-    const DATA_TYPE beta_scalar = *((__global DATA_TYPE *)(bn_beta.ptr + current_slice * bn_beta.stride_x));
-    bn += beta_scalar;
-#endif /* USE_DEFAULT_BETA */
+#if defined(BIAS)
 
-#if defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
-    // Store updated weights
-    VSTORE(VEC_SIZE)
-    (wn, 0, (__global DATA_TYPE *)fused_w.ptr);
-#else  // !defined(VEC_SIZE) || !defined(LAST_ACCESSED_X)
-    *((__global DATA_TYPE *)(fused_w.ptr)) = wn;
-#endif // defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
+#if defined(IN_PLACE_B)
+        *((__global DATA_TYPE *)(b_ptr + v_offset + b_offset_first_element_in_bytes)) = b_new;
+#else  // defined(IN_PLACE_B)
+        *((__global DATA_TYPE *)(b_fused_ptr + v_offset + b_fused_offset_first_element_in_bytes)) = b_new;
+#endif // defined(IN_PLACE_B)
 
-    // Store updated bias
-    *((__global DATA_TYPE *)(fused_b.ptr + current_slice * fused_b.stride_x)) = bn;
+#else // defined(BIAS)
+
+#ifndef IN_PLACE_B
+        *((__global DATA_TYPE *)(b_fused_ptr + v_offset + b_fused_offset_first_element_in_bytes)) = b_new;
+#endif // ifndef IN_PLACE_B
+
+#endif // defined(BIAS)
+    }
 }
-#endif /* defined(NUM_CHANNELS) && defined(DATA_TYPE) && defined(EPSILON) */
+#endif // defined(DIM2) && defined(DATA_TYPE) && defined(EPSILON)
\ No newline at end of file
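
For reference, the fusion computed by fuse_batchnormalization_conv_layer folds the batch normalization statistics into the convolution parameters. Per output feature map (index c1 in the kernel), with weights w, optional convolution bias b, batch normalization mean and variance var, and optional beta/gamma (defaulting to 0 and 1), the kernel computes:

    w_fused = (gamma * w) / sqrt(var + EPSILON)
    b_fused = (gamma * (b - mean)) / sqrt(var + EPSILON) + beta

so that a convolution with w_fused and b_fused is equivalent to the original convolution followed by batch normalization.
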
diff --git a/src/core/CL/kernels/CLFuseBatchNormalizationKernel.cpp b/src/core/CL/kernels/CLFuseBatchNormalizationKernel.cpp
index 150d9b6..16ad7d9 100644
--- a/src/core/CL/kernels/CLFuseBatchNormalizationKernel.cpp
+++ b/src/core/CL/kernels/CLFuseBatchNormalizationKernel.cpp
@@ -48,9 +48,9 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(conv_weights, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_var);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(conv_weights, bn_mean, bn_var);
-
-    unsigned int kernels_idx = get_data_layout_dimension_index(conv_weights->data_layout(), DataLayoutDimension::BATCHES);
-    ARM_COMPUTE_RETURN_ERROR_ON(conv_weights->dimension(kernels_idx) != bn_mean->dimension(0));
+    ARM_COMPUTE_RETURN_ERROR_ON(conv_bias == nullptr && fused_bias == nullptr);
+    ARM_COMPUTE_RETURN_ERROR_ON(conv_weights->dimension(3) != bn_mean->dimension(0));
+    ARM_COMPUTE_RETURN_ERROR_ON(bn_mean->num_dimensions() > 1);
 
     // Validate bias
     if(conv_bias != nullptr)
@@ -70,7 +70,6 @@
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mean, bn_gamma);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(conv_weights, bn_gamma);
     }
-
     // Validate output weights
     if(fused_weights != nullptr && fused_weights->total_size() != 0)
     {
@@ -113,20 +112,18 @@
     _epsilon       = epsilon;
 
     _run_in_place_weights = (fused_weights == nullptr) || (fused_weights == conv_weights);
-    _run_in_place_bias    = (fused_bias == nullptr) || (conv_bias != nullptr && fused_bias == conv_bias);
+    _run_in_place_bias    = (conv_bias != nullptr && fused_bias == nullptr) || (conv_bias != nullptr && fused_bias == conv_bias);
 
     // Auto initialize outputs
     if(_fused_weights != nullptr)
     {
         // Output tensor auto initialization if not yet initialized
         auto_init_if_empty(*_fused_weights->info(), *_conv_weights->info()->clone());
-        fused_weights->info()->set_valid_region(conv_weights->info()->valid_region());
     }
     if(_fused_bias != nullptr)
     {
         // Output tensor auto initialization if not yet initialized
         auto_init_if_empty(*_fused_bias->info(), *_bn_mean->info()->clone());
-        _fused_bias->info()->set_valid_region(bn_mean->info()->valid_region());
     }
 
     // Validate arguments
@@ -139,35 +136,22 @@
                                                   epsilon));
 
     // Configure kernel window
-    const unsigned int num_elems_processed_per_iteration_x = 4;
-    const int          output_width_x                      = conv_weights->info()->tensor_shape().x();
-    const bool         multi_access_x                      = (output_width_x / num_elems_processed_per_iteration_x > 0);
-
     Window win = calculate_max_window(*conv_weights->info());
-    if(multi_access_x)
-    {
-        win.set(Window::DimX, Window::Dimension(win.x().start(),
-                                                ceil_to_multiple(win.x().end(), num_elems_processed_per_iteration_x),
-                                                num_elems_processed_per_iteration_x));
-    }
     ICLKernel::configure_internal(win);
 
     // Set build options
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(conv_weights->info()->data_type()));
-    build_opts.add_option("-DSELECT_DATA_TYPE=" + get_cl_select_type_from_data_type(conv_weights->info()->data_type()));
-    build_opts.add_option("-DNUM_CHANNELS=" + support::cpp11::to_string(conv_weights->info()->dimension(2)));
+    build_opts.add_option("-DDIM2=" + support::cpp11::to_string(conv_weights->info()->dimension(2)));
     build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(epsilon));
-    build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration_x));
-    build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - num_elems_processed_per_iteration_x, 0)));
     build_opts.add_option_if(_run_in_place_weights, "-DIN_PLACE_W");
     build_opts.add_option_if(_run_in_place_bias, "-DIN_PLACE_B");
-    build_opts.add_option_if(conv_bias != nullptr, "-DHAS_BIAS");
-    build_opts.add_option_if(bn_beta == nullptr, "-DUSE_DEFAULT_BETA");
-    build_opts.add_option_if(bn_gamma == nullptr, "-DUSE_DEFAULT_GAMMA");
+    build_opts.add_option_if(conv_bias != nullptr, "-DBIAS");
+    build_opts.add_option_if(bn_beta != nullptr, "-DBETA");
+    build_opts.add_option_if(bn_gamma != nullptr, "-DGAMMA");
 
     // Create kernel
-    _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("fuse_batchnormalization_layer", build_opts.options()));
+    _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("fuse_batchnormalization_conv_layer", build_opts.options()));
 }
 
 Status CLFuseBatchNormalizationKernel::validate(const ITensorInfo *conv_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
@@ -185,37 +169,35 @@
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
 
     // Create window slice
-    Window collapsed_window = window.collapse_if_possible(window, Window::DimZ);
-    Window slice            = collapsed_window.first_slice_window_4D();
-
-    Window vector_slice = window.first_slice_window_1D();
-    vector_slice.set(Window::DimX, Window::Dimension(0, 0, 0));
+    Window collapsed_window = window.collapse(window, Window::DimZ);
+    Window slice_1d         = window.first_slice_window_1D();
+    Window slice_3d         = collapsed_window.first_slice_window_3D();
 
     // Add kernel arguments
     unsigned int idx = 0;
-    add_4D_tensor_argument(idx, _conv_weights, slice);
-    add_1D_tensor_argument(idx, _bn_mean, vector_slice);
-    add_1D_tensor_argument(idx, _bn_var, vector_slice);
+    add_3D_tensor_argument(idx, _conv_weights, slice_3d);
+    if(_conv_bias != nullptr)
+    {
+        add_1D_tensor_argument(idx, _conv_bias, slice_1d);
+    }
+    add_1D_tensor_argument(idx, _bn_mean, slice_1d);
+    add_1D_tensor_argument(idx, _bn_var, slice_1d);
     if(!_run_in_place_weights)
     {
-        add_4D_tensor_argument(idx, _fused_weights, slice);
+        add_3D_tensor_argument(idx, _fused_weights, slice_3d);
     }
     if(!_run_in_place_bias)
     {
-        add_1D_tensor_argument(idx, _fused_bias, vector_slice);
-    }
-    if(_conv_bias != nullptr)
-    {
-        add_1D_tensor_argument(idx, _conv_bias, vector_slice);
+        add_1D_tensor_argument(idx, _fused_bias, slice_1d);
     }
     if(_bn_beta != nullptr)
     {
-        add_1D_tensor_argument(idx, _bn_beta, vector_slice);
+        add_1D_tensor_argument(idx, _bn_beta, slice_1d);
     }
     if(_bn_gamma != nullptr)
     {
-        add_1D_tensor_argument(idx, _bn_gamma, vector_slice);
+        add_1D_tensor_argument(idx, _bn_gamma, slice_1d);
     }
-    enqueue(queue, *this, slice);
+    enqueue(queue, *this, slice_3d);
 }
 } // namespace arm_compute
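
A minimal host-side sketch of the same per-OFM fusion, illustrative only: the helper name and the flattened std::vector layout below are assumptions for the example and are not part of this patch or of the library's test framework.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Fuse batch normalization into convolution weights/bias on the host.
    // Assumes the weights of each OFM are stored contiguously (OFM is the
    // slowest-varying dimension), matching the per-OFM index (c1) used by the CL kernel.
    void fuse_batch_norm_reference(std::vector<float> &w, std::vector<float> &b,
                                   const std::vector<float> &mean, const std::vector<float> &var,
                                   const std::vector<float> &beta, const std::vector<float> &gamma,
                                   float epsilon)
    {
        const size_t num_ofm         = mean.size();
        const size_t weights_per_ofm = w.size() / num_ofm;

        for(size_t ofm = 0; ofm < num_ofm; ++ofm)
        {
            const float scale = gamma[ofm] / std::sqrt(var[ofm] + epsilon);

            // w_fused = (gamma * w) / sqrt(var + EPSILON)
            for(size_t i = 0; i < weights_per_ofm; ++i)
            {
                w[ofm * weights_per_ofm + i] *= scale;
            }

            // b_fused = (gamma * (b - mean)) / sqrt(var + EPSILON) + beta
            b[ofm] = (b[ofm] - mean[ofm]) * scale + beta[ofm];
        }
    }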