Remove OpenCL padding: CLPixelWiseMultiplicationKernel

- Change the kernel's vec_size to 16 / sizeof(output data type) and handle
  leftover elements with STORE_VECTOR_SELECT (see the sketch below)
- Change ICLKernel.cpp to handle broadcast without padding by passing zero
  strides for broadcast dimensions
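For context, a minimal standalone sketch of the vec_size / leftover
arithmetic the configure step now performs (adjust_vec_size below is a
simplified stand-in for the library helper of the same name, and the
numbers are illustrative only):

    #include <algorithm>
    #include <cstdio>

    // Simplified stand-in: clamp the vector width to the output's
    // x-dimension so a single vector never runs past the row.
    unsigned int adjust_vec_size(unsigned int vec_size, unsigned int dim_x)
    {
        return std::min(vec_size, dim_x);
    }

    int main()
    {
        const unsigned int element_size = 2;  // e.g. F16 output
        const unsigned int dim_x        = 37; // output width in elements

        // vec_size = 16 bytes / element size, adjusted to the output width
        const unsigned int vec_size          = adjust_vec_size(16 / element_size, dim_x);
        const unsigned int vec_size_leftover = dim_x % vec_size;

        // When vec_size_leftover != 0, the first work-item does a partial
        // store of vec_size_leftover elements via STORE_VECTOR_SELECT
        // instead of writing into padding; all other work-items store
        // full vectors.
        std::printf("vec_size=%u leftover=%u\n", vec_size, vec_size_leftover);
        return 0;
    }

These two values reach the kernels as the -DVEC_SIZE_OUT and
-DVEC_SIZE_LEFTOVER build options (see the
ClPixelWiseMultiplicationKernel.cpp hunk below).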

Resolve COMPMID-3913

Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Change-Id: I03e884b250ef5784dc109bff8cf2c96b345d119f
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5450
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/src/core/CL/ICLKernel.cpp b/src/core/CL/ICLKernel.cpp
index 1c6963f..9ba17d0 100644
--- a/src/core/CL/ICLKernel.cpp
+++ b/src/core/CL/ICLKernel.cpp
@@ -105,8 +105,8 @@
 
     for(unsigned int d = 0; d < dimension_size; ++d)
     {
-        _kernel.setArg<cl_uint>(idx++, strides[d]);
-        _kernel.setArg<cl_uint>(idx++, strides[d] * window[d].step());
+        _kernel.setArg<cl_uint>(idx++, window.is_broadcasted(d) ? 0 : strides[d]);
+        _kernel.setArg<cl_uint>(idx++, window.is_broadcasted(d) ? 0 : (strides[d] * window[d].step()));
     }
 
     _kernel.setArg<cl_uint>(idx++, offset_first_element);
diff --git a/src/core/CL/cl_kernels/pixelwise_mul_float.cl b/src/core/CL/cl_kernels/pixelwise_mul_float.cl
index 845e1c9..0016775 100644
--- a/src/core/CL/cl_kernels/pixelwise_mul_float.cl
+++ b/src/core/CL/cl_kernels/pixelwise_mul_float.cl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2020 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,6 +36,10 @@
 #include "activation_float_helpers.h"
 #endif // defined(ACTIVATION_TYPE)
 
+#define VEC_ACC_TYPE VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE_OUT)
+#define VEC_OUT_TYPE VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)
+#define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE_OUT)
+
 /** Performs a pixelwise multiplication with float scale of either integer or float inputs.
  *
  * @attention The inputs and output data types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT:
@@ -77,31 +81,30 @@
     const float scale)
 {
     // Get pixels pointer
-    Tensor3D in1 = CONVERT_TO_TENSOR3D_STRUCT(in1);
-    Tensor3D in2 = CONVERT_TO_TENSOR3D_STRUCT(in2);
-    Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
+    size_t x = max((int)(get_global_id(0) * VEC_SIZE_OUT - (VEC_SIZE_OUT - VEC_SIZE_LEFTOVER) % VEC_SIZE_OUT), 0);
+    size_t y = get_global_id(1);
+    size_t z = get_global_id(2);
+
+    __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + x * in1_stride_x + y * in1_stride_y + z * in1_stride_z;
+    __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + x * in2_stride_x + y * in2_stride_y + z * in2_stride_z;
+    __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + x * out_stride_x + y * out_stride_y + z * out_stride_z;
 
     // Load data
-    VEC_DATA_TYPE(ACC_DATA_TYPE, 16)
-    in1_data = CONVERT(vload16(0, (__global DATA_TYPE_IN1 *)in1.ptr), VEC_DATA_TYPE(ACC_DATA_TYPE, 16));
-    VEC_DATA_TYPE(ACC_DATA_TYPE, 16)
-    in2_data = CONVERT(vload16(0, (__global DATA_TYPE_IN2 *)in2.ptr), VEC_DATA_TYPE(ACC_DATA_TYPE, 16));
+    VEC_ACC_TYPE in1_data = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN1, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE_IN1 *)in1_addr)), VEC_ACC_TYPE);
+    VEC_ACC_TYPE in2_data = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN2, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE_IN2 *)in2_addr)), VEC_ACC_TYPE);
 
     // Perform multiplication
 #ifdef DATA_TYPE_FLOAT
-    VEC_DATA_TYPE(DATA_TYPE_OUT, 16)
-    res = CONVERT(in1_data * in2_data * (ACC_DATA_TYPE)scale, VEC_DATA_TYPE(DATA_TYPE_OUT, 16));
+    VEC_OUT_TYPE res0 = CONVERT(in1_data * in2_data * (ACC_DATA_TYPE)scale, VEC_OUT_TYPE);
 #else  /* DATA_TYPE_FLOAT */
-    VEC_DATA_TYPE(DATA_TYPE_OUT, 16)
-    res = CONVERT_OP_FLOAT(CONVERT_OP_FLOAT((convert_float16(in1_data * in2_data) * scale), VEC_DATA_TYPE(ACC_DATA_TYPE, 16), ROUND), VEC_DATA_TYPE(DATA_TYPE_OUT, 16), ROUND);
+    VEC_OUT_TYPE res0 = CONVERT_OP_FLOAT(CONVERT_OP_FLOAT((CONVERT(in1_data * in2_data, VEC_FLOAT) * scale), VEC_ACC_TYPE, ROUND), VEC_OUT_TYPE, ROUND);
 #endif /* DATA_TYPE_FLOAT */
 
 #if defined(ACTIVATION_TYPE)
-    vstore16(ACTIVATION(ACTIVATION_TYPE, DATA_TYPE_OUT, VEC_SIZE, res, A_VAL, B_VAL), 0, (__global DATA_TYPE_OUT *)out.ptr);
-#else  // defined(ACTIVATION_TYPE)
-    // Store result
-    vstore16(res, 0, (__global DATA_TYPE_OUT *)out.ptr);
+    res0 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE_OUT, VEC_SIZE_OUT, res0, A_VAL, B_VAL);
 #endif // defined(ACTIVATION_TYPE)
+
+    STORE_VECTOR_SELECT(res, DATA_TYPE_OUT, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0);
 }
 #endif /* defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(ACC_DATA_TYPE) && defined(DATA_TYPE_OUT) */
 
@@ -155,7 +158,7 @@
     res = { vin1.x *vin2.x - vin1.y * vin2.y, vin1.x *vin2.y + vin2.x * vin1.y };
 
 #if defined(ACTIVATION_TYPE)
-    vstore2(ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, res, A_VAL, B_VAL), 0, (__global DATA_TYPE *)out.ptr);
+    vstore2(ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE_OUT, res, A_VAL, B_VAL), 0, (__global DATA_TYPE *)out.ptr);
 #else  // defined(ACTIVATION_TYPE)
     // Store result
     vstore2(res, 0, (__global DATA_TYPE *)out.ptr);
diff --git a/src/core/CL/cl_kernels/pixelwise_mul_int.cl b/src/core/CL/cl_kernels/pixelwise_mul_int.cl
index b0bd338..92a7e6f 100644
--- a/src/core/CL/cl_kernels/pixelwise_mul_int.cl
+++ b/src/core/CL/cl_kernels/pixelwise_mul_int.cl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2020 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -36,6 +36,10 @@
 #define CONVERT_DOWN(x, type) CONVERT_RTE(x, type)
 
 #if defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(ACC_DATA_TYPE) && defined(DATA_TYPE_OUT)
+
+#define VEC_ACC_TYPE VEC_DATA_TYPE(ACC_DATA_TYPE, VEC_SIZE_OUT)
+#define VEC_OUT_TYPE VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)
+
 /** Performs a pixelwise multiplication with integer scale of integer inputs.
  *
  * @attention The inputs and output data types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT:
@@ -75,27 +79,29 @@
     TENSOR3D_DECLARATION(out),
     const uint scale)
 {
-    // Get pixels pointer
-    Tensor3D in1 = CONVERT_TO_TENSOR3D_STRUCT(in1);
-    Tensor3D in2 = CONVERT_TO_TENSOR3D_STRUCT(in2);
-    Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
+    size_t x = max((int)(get_global_id(0) * VEC_SIZE_OUT - (VEC_SIZE_OUT - VEC_SIZE_LEFTOVER) % VEC_SIZE_OUT), 0);
+    size_t y = get_global_id(1);
+    size_t z = get_global_id(2);
+
+    __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + x * in1_stride_x + y * in1_stride_y + z * in1_stride_z;
+    __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + x * in2_stride_x + y * in2_stride_y + z * in2_stride_z;
+    __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + x * out_stride_x + y * out_stride_y + z * out_stride_z;
 
     // Load data
-    VEC_DATA_TYPE(ACC_DATA_TYPE, 16)
-    in1_data = CONVERT(vload16(0, (__global DATA_TYPE_IN1 *)in1.ptr), VEC_DATA_TYPE(ACC_DATA_TYPE, 16));
-    VEC_DATA_TYPE(ACC_DATA_TYPE, 16)
-    in2_data = CONVERT(vload16(0, (__global DATA_TYPE_IN2 *)in2.ptr), VEC_DATA_TYPE(ACC_DATA_TYPE, 16));
+    VEC_ACC_TYPE in1_data = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN1, VEC_SIZE_OUT))VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE_IN1 *)in1_addr), VEC_ACC_TYPE);
+    VEC_ACC_TYPE in2_data = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN2, VEC_SIZE_OUT))VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE_IN2 *)in2_addr), VEC_ACC_TYPE);
 
     // Perform multiplication and store result
-    vstore16(MUL_OP(in1_data, in2_data, scale, DATA_TYPE_OUT, 16), 0, (__global DATA_TYPE_OUT *)out.ptr);
+    VEC_OUT_TYPE out_data0 = MUL_OP(in1_data, in2_data, scale, DATA_TYPE_OUT, VEC_SIZE_OUT);
+    STORE_VECTOR_SELECT(out_data, DATA_TYPE_OUT, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0);
 }
 #endif /* defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(ACC_DATA_TYPE) && defined(DATA_TYPE_OUT) */
 
-#if defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE_OUT) && defined(VEC_SIZE)
+#if defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE_OUT) && defined(VEC_SIZE_OUT)
 
-#define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE)
-#define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE)
-#define VEC_TYPE VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE)
+#define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE_OUT)
+#define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE_OUT)
+#define VEC_TYPE VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)
 
 /** Performs a pixelwise multiplication with float scale of quantized inputs.
  *
@@ -141,14 +147,17 @@
     TENSOR3D_DECLARATION(out),
     const float scale)
 {
-    // Get pixels pointer
-    Tensor3D in1 = CONVERT_TO_TENSOR3D_STRUCT(in1);
-    Tensor3D in2 = CONVERT_TO_TENSOR3D_STRUCT(in2);
-    Tensor3D out = CONVERT_TO_TENSOR3D_STRUCT(out);
+    size_t x = max((int)(get_global_id(0) * VEC_SIZE_OUT - (VEC_SIZE_OUT - VEC_SIZE_LEFTOVER) % VEC_SIZE_OUT), 0);
+    size_t y = get_global_id(1);
+    size_t z = get_global_id(2);
+
+    __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + x * in1_stride_x + y * in1_stride_y + z * in1_stride_z;
+    __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + x * in2_stride_x + y * in2_stride_y + z * in2_stride_z;
+    __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + x * out_stride_x + y * out_stride_y + z * out_stride_z;
 
     // Load data
-    VEC_INT in_a = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE_OUT *)in1.ptr), VEC_INT);
-    VEC_INT in_b = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE_OUT *)in2.ptr), VEC_INT);
+    VEC_INT in_a = CONVERT((VEC_TYPE)(VLOAD(VEC_SIZE_OUT)(0, (__global DATA_TYPE_OUT *)in1_addr)), VEC_INT);
+    VEC_INT in_b = CONVERT((VEC_TYPE)(VLOAD(VEC_SIZE_OUT)(0, (__global DATA_TYPE_OUT *)in2_addr)), VEC_INT);
 
     // Dequantize
 #if defined(OFFSET_IN1)
@@ -165,10 +174,9 @@
 #else  // defined(OFFSET_OUT)
     const VEC_FLOAT qresf32 = (in1f32 * in2f32 * scale) / ((VEC_FLOAT)(float)SCALE_OUT);
 #endif // defined(OFFSET_OUT)
-    const VEC_TYPE res = CONVERT_SAT(CONVERT_DOWN(qresf32, VEC_INT), VEC_TYPE);
+    const VEC_TYPE res0 = CONVERT_SAT(CONVERT_DOWN(qresf32, VEC_INT), VEC_TYPE);
 
     // Store result
-    VSTORE(VEC_SIZE)
-    (res, 0, (__global DATA_TYPE_OUT *)out.ptr);
+    STORE_VECTOR_SELECT(res, DATA_TYPE_OUT, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0);
 }
-#endif /* defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE_OUT) && defined(VEC_SIZE) */
+#endif /* defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE_OUT) && defined(VEC_SIZE_OUT) */
diff --git a/src/core/CL/cl_kernels/tile_helpers.h b/src/core/CL/cl_kernels/tile_helpers.h
index 496f2dd..8b6d530 100644
--- a/src/core/CL/cl_kernels/tile_helpers.h
+++ b/src/core/CL/cl_kernels/tile_helpers.h
@@ -83,18 +83,6 @@
  */
 #define GET_SPATIAL_IDX(IDX, N0, PARTIAL_N0) (max((int)(get_global_id(IDX) * N0 - (N0 - PARTIAL_N0) % N0), 0))
 
-/** Offset (in bytes) calculation for a 1D BUFFER (cl_buffer) tensor */
-#define OFFSET1D(base, data_type, x) (base##_offset_first_element_in_bytes + x * sizeof(data_type))
-
-/** Offset (in bytes) calculation for a 2D BUFFER (cl_buffer) tensor */
-#define OFFSET2D(base, data_type, x, y) (base##_offset_first_element_in_bytes + x * sizeof(data_type) + y * base##_stride_y)
-
-/** Offset (in bytes) calculation for a 3D BUFFER (cl_buffer) tensor */
-#define OFFSET3D(base, data_type, x, y, z) (base##_offset_first_element_in_bytes + x * sizeof(data_type) + y * base##_stride_y + z * base##_stride_z)
-
-/** Offset (in bytes) calculation for a 4D BUFFER (cl_buffer) tensor */
-#define OFFSET4D(base, data_type, x, y, z, w) (base##_offset_first_element_in_bytes + x * sizeof(data_type) + y * base##_stride_y + z * base##_stride_z + w * base##_stride_w)
-
 /** Dot product integet 8bit function
  *
  *  @note Performs: c += dot(a, b)
@@ -184,7 +172,7 @@
         LOOP_UNROLLING(int, _i, 0, HEIGHT, 1) \
         { \
             dst[_i].v = V_LOAD(DATA_TYPE, WIDTH, TENSOR_TYPE, TENSOR, X, ((Y) + _i * (int)(YI_MULTIPLIER)), STRIDE_Y); \
-        }                                                                                                                 \
+        }                                                                                                                  \
     })
 
 /** Load a tile from global memory (tensor) using an indirect Y index tile
diff --git a/src/core/gpu/cl/kernels/ClPixelWiseMultiplicationKernel.cpp b/src/core/gpu/cl/kernels/ClPixelWiseMultiplicationKernel.cpp
index 56997dc..14e45b2 100644
--- a/src/core/gpu/cl/kernels/ClPixelWiseMultiplicationKernel.cpp
+++ b/src/core/gpu/cl/kernels/ClPixelWiseMultiplicationKernel.cpp
@@ -42,8 +42,6 @@
 {
 namespace
 {
-constexpr unsigned int num_elems_processed_per_iteration = 16;
-
 Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float scale,
                           ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info)
 {
@@ -92,60 +90,6 @@
 
     return Status{};
 }
-
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst)
-{
-    const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
-
-    // Auto initialize dst if not initialized
-    {
-        set_shape_if_empty(*dst, out_shape);
-
-        if(src1->data_type() == DataType::S16 || src2->data_type() == DataType::S16)
-        {
-            set_format_if_unknown(*dst, Format::S16);
-        }
-        else if(src1->data_type() == DataType::F32 || src2->data_type() == DataType::F32)
-        {
-            set_format_if_unknown(*dst, Format::F32);
-        }
-        else if(src1->data_type() == DataType::QASYMM8)
-        {
-            set_data_type_if_unknown(*dst, DataType::QASYMM8);
-        }
-        else if(src1->data_type() == DataType::QASYMM8_SIGNED)
-        {
-            set_data_type_if_unknown(*dst, DataType::QASYMM8_SIGNED);
-        }
-        else if(src1->data_type() == DataType::QSYMM16)
-        {
-            set_data_type_if_unknown(*dst, DataType::QSYMM16);
-        }
-    }
-
-    Window win        = calculate_max_window(out_shape, Steps(num_elems_processed_per_iteration));
-    Window win_input1 = win.broadcast_if_dimension_le_one(*src1);
-    Window win_input2 = win.broadcast_if_dimension_le_one(*src2);
-
-    AccessWindowHorizontal input1_access(src1, 0, num_elems_processed_per_iteration);
-    AccessWindowHorizontal input2_access(src2, 0, num_elems_processed_per_iteration);
-    AccessWindowHorizontal output_access(dst, 0, num_elems_processed_per_iteration);
-
-    bool window_changed = update_window_and_padding(win_input1, input1_access)
-                          || update_window_and_padding(win_input2, input2_access)
-                          || update_window_and_padding(win, output_access);
-
-    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
-    return std::make_pair(err, win);
-}
-
-BorderSize calc_border_size(ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst)
-{
-    const unsigned int replicateSize = dst->dimension(0) - std::min(src1->dimension(0), src2->dimension(0));
-    const unsigned int border        = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
-
-    return BorderSize{ 0, border, 0, 0 };
-}
 } // namespace
 
 void ClPixelWiseMultiplicationKernel::configure(const CLCompileContext &compile_context, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float scale,
@@ -155,12 +99,10 @@
     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src1, src2, dst,
                                                   scale, overflow_policy, rounding_policy, act_info));
 
-    // Calculate border size
-    _border_size = calc_border_size(src1, src2, dst);
+    auto padding_info = get_padding_info({ src1, src2, dst });
 
-    // Configure kernel window
-    auto win_config = validate_and_configure_window(src1, src2, dst);
-    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+    const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
+    auto_init_if_empty(*dst, src1->clone()->set_tensor_shape(out_shape));
 
     int scale_int = -1;
     // Extract sign, exponent and mantissa
@@ -197,7 +139,9 @@
         }
     }
 
-    const bool is_quantized = is_data_type_quantized(src1->data_type());
+    const bool         is_quantized      = is_data_type_quantized(src1->data_type());
+    const unsigned int vec_size          = adjust_vec_size(16 / dst->element_size(), dst->dimension(0));
+    const unsigned int vec_size_leftover = dst->dimension(0) % vec_size;
 
     // Set kernel build options
     std::string    kernel_name = "pixelwise_mul";
@@ -205,7 +149,10 @@
     build_opts.add_option("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(src1->data_type()));
     build_opts.add_option("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(src2->data_type()));
     build_opts.add_option("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(dst->data_type()));
-    build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
+    build_opts.add_option("-DVEC_SIZE_IN1=" + ((dst->dimension(0) != 1 && src1->dimension(0) == 1) ? "1" : support::cpp11::to_string(vec_size)));
+    build_opts.add_option("-DVEC_SIZE_IN2=" + ((dst->dimension(0) != 1 && src2->dimension(0) == 1) ? "1" : support::cpp11::to_string(vec_size)));
+    build_opts.add_option("-DVEC_SIZE_OUT=" + support::cpp11::to_string(vec_size));
+    build_opts.add_option("-DVEC_SIZE_LEFTOVER=" + support::cpp11::to_string(vec_size_leftover));
     if(is_quantized && (dst->data_type() != DataType::S32))
     {
         const UniformQuantizationInfo iq1_info = src1->quantization_info().uniform();
@@ -252,7 +199,10 @@
         _kernel.setArg(idx++, scale);
     }
 
-    ICLKernel::configure_internal(win_config.second);
+    Window win = calculate_max_window(*dst, Steps(vec_size));
+    ICLKernel::configure_internal(win);
+
+    ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info));
 }
 
 Status ClPixelWiseMultiplicationKernel::validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float scale,
@@ -260,7 +210,6 @@
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src1, src2, dst, scale, overflow_policy, rounding_policy, act_info));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src1->clone().get(), src2->clone().get(), dst->clone().get()).first);
 
     return Status{};
 }
@@ -312,14 +261,9 @@
     while(collapsed.slide_window_slice_3D(slice));
 }
 
-BorderSize ClPixelWiseMultiplicationKernel::border_size() const
-{
-    return _border_size;
-}
-
 namespace
 {
-constexpr unsigned int num_elems_processed_per_iteration_complex = 1;
+constexpr unsigned int vec_size_complex = 1;
 
 Status validate_arguments_complex(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const ActivationLayerInfo &act_info)
 {
@@ -342,30 +286,6 @@
 
     return Status{};
 }
-
-std::pair<Status, Window> validate_and_configure_window_complex(ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst)
-{
-    const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
-
-    // Auto initialize dst if not initialized
-    const TensorInfo out_info(out_shape, src1->num_channels(), src1->data_type());
-    auto_init_if_empty(*dst, out_info);
-
-    Window win        = calculate_max_window(out_shape, Steps(num_elems_processed_per_iteration_complex));
-    Window win_input1 = win.broadcast_if_dimension_le_one(*src1);
-    Window win_input2 = win.broadcast_if_dimension_le_one(*src2);
-
-    AccessWindowHorizontal input1_access(src1, 0, num_elems_processed_per_iteration_complex);
-    AccessWindowHorizontal input2_access(src2, 0, num_elems_processed_per_iteration_complex);
-    AccessWindowHorizontal output_access(dst, 0, num_elems_processed_per_iteration_complex);
-
-    bool window_changed = update_window_and_padding(win_input1, input1_access)
-                          || update_window_and_padding(win_input2, input2_access)
-                          || update_window_and_padding(win, output_access);
-
-    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
-    return std::make_pair(err, win);
-}
 } // namespace
 
 void ClComplexPixelWiseMultiplicationKernel::configure(const CLCompileContext &compile_context, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, const ActivationLayerInfo &act_info)
@@ -373,12 +293,10 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(src1, src2, dst, act_info));
 
-    // Calculate border size
-    _border_size = calc_border_size(src1, src2, dst);
+    auto padding_info = get_padding_info({ src1, src2, dst });
 
-    // Configure kernel window
-    auto win_config = validate_and_configure_window_complex(src1, src2, dst);
-    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+    const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
+    auto_init_if_empty(*dst, src1->clone()->set_tensor_shape(out_shape));
 
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
@@ -392,14 +310,16 @@
     // Create kernel
     _kernel = create_kernel(compile_context, "pixelwise_mul_complex", build_opts.options());
 
-    ICLKernel::configure_internal(win_config.second);
+    Window win = calculate_max_window(*dst, Steps(vec_size_complex));
+    ICLKernel::configure_internal(win);
+
+    ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info));
 }
 
 Status ClComplexPixelWiseMultiplicationKernel::validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const ActivationLayerInfo &act_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(src1, src2, dst, act_info));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_complex(src1->clone().get(), src2->clone().get(), dst->clone().get()).first);
 
     return Status{};
 }
@@ -450,11 +370,6 @@
     }
     while(collapsed.slide_window_slice_3D(slice));
 }
-
-BorderSize ClComplexPixelWiseMultiplicationKernel::border_size() const
-{
-    return _border_size;
-}
 } // namespace kernels
 } // namespace opencl
 } // namespace arm_compute
diff --git a/src/core/gpu/cl/kernels/ClPixelWiseMultiplicationKernel.h b/src/core/gpu/cl/kernels/ClPixelWiseMultiplicationKernel.h
index 5889b84..5b82726 100644
--- a/src/core/gpu/cl/kernels/ClPixelWiseMultiplicationKernel.h
+++ b/src/core/gpu/cl/kernels/ClPixelWiseMultiplicationKernel.h
@@ -41,7 +41,7 @@
     /** Default constructor */
     ClPixelWiseMultiplicationKernel() = default;
     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClPixelWiseMultiplicationKernel);
-    /** Initialise the kernel's src, dst and border mode.
+    /** Initialise the kernel's src and dst.
      *
      * Valid configurations (Input1,Input2) -> Output :
      *
@@ -101,10 +101,6 @@
 
     // Inherited methods overridden:
     void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
-    BorderSize border_size() const override;
-
-public:
-    BorderSize _border_size{};
 };
 
 /** Interface for the complex pixelwise multiplication kernel. */
@@ -114,7 +110,7 @@
     /** Default constructor */
     ClComplexPixelWiseMultiplicationKernel() = default;
     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClComplexPixelWiseMultiplicationKernel);
-    /** Initialise the kernel's src, dst and border mode.
+    /** Initialise the kernel's src and dst.
      *
      * @param[in]  compile_context The compile context to be used.
      * @param[in]  src1            An src tensor info. Data types supported: F32. Number of channels supported: 2.
@@ -136,10 +132,6 @@
 
     // Inherited methods overridden:
     void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
-    BorderSize border_size() const override;
-
-public:
-    BorderSize _border_size{};
 };
 } // namespace kernels
 } // namespace opencl
diff --git a/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp b/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp
index 5ebaf5d..efebf2b 100644
--- a/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp
+++ b/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp
@@ -25,7 +25,7 @@
 
 #include "arm_compute/core/CL/ICLTensor.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
-#include "src/core/CL/kernels/CLFillBorderKernel.h"
+#include "src/core/CL/ICLKernel.h"
 #include "src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.h"
 
 #include <utility>
diff --git a/src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.cpp b/src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.cpp
index c4f11a4..137a0de 100644
--- a/src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.cpp
+++ b/src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.cpp
@@ -24,7 +24,6 @@
 #include "src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.h"
 
 #include "arm_compute/runtime/CL/CLScheduler.h"
-#include "src/core/CL/kernels/CLFillBorderKernel.h"
 #include "src/core/gpu/cl/ClCompileContext.h"
 #include "src/core/gpu/cl/kernels/ClPixelWiseMultiplicationKernel.h"
 
@@ -32,44 +31,12 @@
 {
 namespace opencl
 {
-namespace
-{
-ITensorPack select_border_input(ITensorPack &tensors)
-{
-    ITensorPack pack;
-    if(tensors.get_tensor(TensorType::ACL_DST)->info()->dimension(0) > 1)
-    {
-        if(tensors.get_const_tensor(TensorType::ACL_SRC_1)->info()->dimension(0) == 1)
-        {
-            pack.add_tensor(TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1));
-        }
-        else
-        {
-            pack.add_tensor(TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_0));
-        }
-    }
-    return pack;
-}
-} // namespace
-
 void ClPixelWiseMultiplication::configure(const CLCompileContext &compile_context, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float scale,
                                           ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info)
 {
     auto k = std::make_unique<kernels::ClPixelWiseMultiplicationKernel>();
     k->configure(compile_context, src1, src2, dst, scale, overflow_policy, rounding_policy, act_info);
     _kernel = std::move(k);
-
-    if(dst->dimension(0) > 1)
-    {
-        ITensorInfo *broadcasted_info = (src1->dimension(0) == 1) ? src1 : src2;
-
-        if(broadcasted_info->dimension(0) == 1)
-        {
-            auto b = std::make_unique<CLFillBorderKernel>();
-            b->configure(compile_context, broadcasted_info, _kernel->border_size(), BorderMode::REPLICATE);
-            _border_handler = std::move(b);
-        }
-    }
 }
 
 Status ClPixelWiseMultiplication::validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float scale,
@@ -78,48 +45,16 @@
     return kernels::ClPixelWiseMultiplicationKernel::validate(src1, src2, dst, scale, overflow_policy, rounding_policy, act_info);
 }
 
-void ClPixelWiseMultiplication::run(ITensorPack &tensors)
-{
-    if(_border_handler)
-    {
-        auto border_pack = select_border_input(tensors);
-        CLScheduler::get().enqueue_op(*_border_handler, border_pack);
-    }
-    ICLOperator::run(tensors);
-}
-
 void ClComplexPixelWiseMultiplication::configure(const CLCompileContext &compile_context, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, const ActivationLayerInfo &act_info)
 {
     auto k = std::make_unique<kernels::ClComplexPixelWiseMultiplicationKernel>();
     k->configure(compile_context, src1, src2, dst, act_info);
     _kernel = std::move(k);
-
-    if(dst->dimension(0) > 1)
-    {
-        ITensorInfo *broadcasted_info = (src1->dimension(0) == 1) ? src1 : src2;
-
-        if(broadcasted_info->dimension(0) == 1)
-        {
-            auto b = std::make_unique<CLFillBorderKernel>();
-            b->configure(compile_context, broadcasted_info, _kernel->border_size(), BorderMode::REPLICATE);
-            _border_handler = std::move(b);
-        }
-    }
 }
 
 Status ClComplexPixelWiseMultiplication::validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const ActivationLayerInfo &act_info)
 {
     return kernels::ClComplexPixelWiseMultiplicationKernel::validate(src1, src2, dst, act_info);
 }
-
-void ClComplexPixelWiseMultiplication::run(ITensorPack &tensors)
-{
-    if(_border_handler)
-    {
-        auto border_pack = select_border_input(tensors);
-        CLScheduler::get().enqueue_op(*_border_handler, border_pack);
-    }
-    ICLOperator::run(tensors);
-}
 } // namespace opencl
 } // namespace arm_compute
\ No newline at end of file
diff --git a/src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.h b/src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.h
index e9b3e4a..e1598cb 100644
--- a/src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.h
+++ b/src/runtime/gpu/cl/operators/ClPixelWiseMultiplication.h
@@ -99,12 +99,6 @@
      */
     static Status validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float scale,
                            ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info = ActivationLayerInfo());
-
-    // Inherited methods overridden:
-    void run(ITensorPack &tensors) override;
-
-private:
-    std::unique_ptr<ICLKernel> _border_handler{ nullptr };
 };
 
 /** Basic function to run @ref opencl::ClComplexPixelWiseMultiplication. */
@@ -132,12 +126,6 @@
      * @param[in] act_info (Optional) Activation layer information in case of a fused activation.
      */
     static Status validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const ActivationLayerInfo &act_info = ActivationLayerInfo());
-
-    // Inherited methods overridden:
-    void run(ITensorPack &tensors) override;
-
-private:
-    std::unique_ptr<ICLKernel> _border_handler{ nullptr };
 };
 } // namespace opencl
 } // namespace arm_compute