Rework DepthwiseConvolution heuristic on OpenCL

Resolves COMPMID-5632

Change-Id: I2bdbe69a610ca2510fbd74d5d412842679299762
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8365
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp
index 94675d6..b318642 100644
--- a/src/core/CL/CLHelpers.cpp
+++ b/src/core/CL/CLHelpers.cpp
@@ -441,7 +441,7 @@
     ARM_COMPUTE_ERROR_ON(err != CL_SUCCESS);
 }
 
-bool export_weights_to_cl_image(const ITensorInfo *tensor)
+bool export_to_cl_image(const ITensorInfo *tensor)
 {
     if(tensor->tensor_shape()[0] % 4)
     {
diff --git a/src/core/CL/DefaultLWSHeuristics.cpp b/src/core/CL/DefaultLWSHeuristics.cpp
index c082d7f..c739b9d 100644
--- a/src/core/CL/DefaultLWSHeuristics.cpp
+++ b/src/core/CL/DefaultLWSHeuristics.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -68,6 +68,21 @@
         return cl::NDRange(8, 4, 1);
     }
 }
+
+cl::NDRange get_dwc_lws(size_t gws_x, size_t gws_y, size_t gws_z) // Picks the NDRange returned for CLKernelType::DEPTHWISE kernels
+{
+    ARM_COMPUTE_UNUSED(gws_y); // The choice below depends on gws_x only
+    ARM_COMPUTE_UNUSED(gws_z); // Also unused; argument list mirrors the get_winograd_lws call in this file
+    // Selection is a single fixed threshold on gws_x; the rationale for the constants is not stated here
+    if(gws_x < 32)
+    { // NOTE(review): if used as a local size this is up to 31*4*4 = 496 work-items per group - confirm against the device limit
+        return cl::NDRange(gws_x, 4, 4); // gws_x below 32: first NDRange component is gws_x itself
+    }
+    else
+    {
+        return cl::NDRange(8, 4, 2); // gws_x of 32 or more: fixed (8, 4, 2)
+    }
+}
 } // namespace
 
 namespace arm_compute
@@ -92,6 +107,10 @@
         {
             return get_winograd_lws(gws_x, gws_y, gws_z);
         }
+        case CLKernelType::DEPTHWISE:
+        {
+            return get_dwc_lws(gws_x, gws_y, gws_z);
+        }
         default:
         {
             return CLKernelLibrary::get().default_ndrange();
diff --git a/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl b/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl
index 8b14b27..8a84587 100644
--- a/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl
+++ b/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl
@@ -145,7 +145,7 @@
         })
 
         // Load tile from the src tensor (TILE A)
-        T_LOAD_NHWC_WITH_DILATION(SRC_DATA_TYPE, 1, _IM0_A, _IN0_A, SRC_TENSOR_TYPE, src, bout, yi + yk * DILATION_Y, xi, (cout / DEPTH_MULTIPLIER), src_w, src_h, DILATION_X, 1, _IBOUNDARY_CHECK, a);
+        T_LOAD_NHWC_WITH_DILATION(SRC_DATA_TYPE, 1, _IM0_A, _IN0_A, SRC_TENSOR_TYPE, src, bout, yi + yk * DILATION_Y, xi, (cout / DEPTH_MULTIPLIER), SRC_WIDTH, SRC_HEIGHT, DILATION_X, 1, _IBOUNDARY_CHECK, a);
 
         TILE(WEI_DATA_TYPE, _IM0_B, _IN0_B, b);
 
@@ -185,7 +185,7 @@
     {
         LOOP_UNROLLING(int, m0, 0, 1, M0,
         {
-            int xi_out = min(xo + M0 - 1 - m0, (int)(dst_w) - 1);
+            int xi_out = min(xo + M0 - 1 - m0, (int)(DST_WIDTH) - 1);
             VSTORE_PARTIAL(N0, PARTIAL_N0)
             (c[M0 - 1 - m0].v, 0, (__global DST_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + cout * sizeof(DST_DATA_TYPE) + (uint)xi_out * dst_stride_y + (uint)yo * dst_stride_z + (uint)bout * dst_stride_w));
         })
@@ -194,7 +194,7 @@
     {
         LOOP_UNROLLING(int, m0, 0, 1, M0,
         {
-            int xi_out = min(xo + M0 - 1 - m0, (int)(dst_w) - 1);
+            int xi_out = min(xo + M0 - 1 - m0, (int)(DST_WIDTH) - 1);
             VSTORE(N0)
             (c[M0 - 1 - m0].v, 0, (__global DST_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + cout * sizeof(DST_DATA_TYPE) + (uint)xi_out * dst_stride_y + (uint)yo * dst_stride_z + (uint)bout * dst_stride_w));
         })
diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp
index 277cba4..cded319 100644
--- a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp
@@ -59,7 +59,8 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON(conv_info.pad_stride_info.stride().first > 1 && dwc_info.m0 != 1);
     ARM_COMPUTE_RETURN_ERROR_ON(conv_info.dilation.x() > 1 && dwc_info.m0 != 1);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((dwc_info.export_weights_to_cl_image == true) && (export_weights_to_cl_image(weights) == false), "Export to cl_image not supported!");
+    ARM_COMPUTE_RETURN_ERROR_ON((dwc_info.export_input_to_cl_image == true));
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((dwc_info.export_weights_to_cl_image == true) && (export_to_cl_image(weights) == false), "Weights cannot be exported to cl_image!");
     ARM_COMPUTE_RETURN_ERROR_ON((dwc_info.export_weights_to_cl_image == true) && ((dwc_info.n0 % 4) != 0));
     ARM_COMPUTE_RETURN_ERROR_ON(conv_info.pad_stride_info.stride().first < 1);
     ARM_COMPUTE_RETURN_ERROR_ON(conv_info.pad_stride_info.stride().second < 1);
@@ -161,7 +162,8 @@
       _depth_multiplier(1),
       _output_multipliers(nullptr),
       _output_shifts(nullptr),
-      _export_to_cl_image(false),
+      _export_input_to_cl_image(false),
+      _export_weights_to_cl_image(false),
       _is_quantized(false)
 {
     _type = CLKernelType::DEPTHWISE;
@@ -192,15 +194,16 @@
     const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_depthwise_convolution_shape(*(input->info()), *(weights->info()), conv_info);
     auto_init_if_empty(*(output->info()), input->info()->clone()->set_tensor_shape(output_shape).set_quantization_info(output->info()->quantization_info()));
 
-    _input              = input;
-    _output             = output;
-    _weights            = weights;
-    _biases             = biases;
-    _depth_multiplier   = conv_info.depth_multiplier;
-    _output_multipliers = output_multipliers;
-    _output_shifts      = output_shifts;
-    _export_to_cl_image = dwc_info.export_weights_to_cl_image;
-    _is_quantized       = is_data_type_quantized(input->info()->data_type());
+    _input                      = input;
+    _output                     = output;
+    _weights                    = weights;
+    _biases                     = biases;
+    _depth_multiplier           = conv_info.depth_multiplier;
+    _output_multipliers         = output_multipliers;
+    _output_shifts              = output_shifts;
+    _export_input_to_cl_image   = dwc_info.export_input_to_cl_image;
+    _export_weights_to_cl_image = dwc_info.export_weights_to_cl_image;
+    _is_quantized               = is_data_type_quantized(input->info()->data_type());
 
     const unsigned int n0          = adjust_vec_size(dwc_info.n0, output->info()->dimension(0));
     const unsigned int m0          = std::min(dwc_info.m0, (unsigned int)output->info()->dimension(1));
@@ -208,8 +211,13 @@
 
     CLBuildOptions build_opts;
 
-    // Update the padding for the weights tensor if we can export to cl_image
-    if(_export_to_cl_image)
+    // Update the padding for the input/weights tensor if we can export to cl_image
+    if(_export_input_to_cl_image)
+    {
+        arm_compute::opencl::kernels::gemm::update_padding_for_cl_image(input->info());
+    }
+
+    if(_export_weights_to_cl_image)
     {
         arm_compute::opencl::kernels::gemm::update_padding_for_cl_image(weights->info());
     }
@@ -234,14 +242,18 @@
 
     build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_function)));
     build_opts.add_option("-DDEPTH_MULTIPLIER=" + support::cpp11::to_string(conv_info.depth_multiplier));
-    build_opts.add_option("-DSRC_TENSOR_TYPE=BUFFER");
+    build_opts.add_option_if_else(_export_input_to_cl_image, "-DSRC_TENSOR_TYPE=IMAGE", "-DSRC_TENSOR_TYPE=BUFFER");
     // Note: SRC_DATA_TYPE must have the same data type of WEI_DATA_TYPE. In quantized, we could
     // have a case where the data types for the activation and weights are different. However, since the implementation
     // only works when both have same data type, we have to change the offset to take into account this aspect
     build_opts.add_option("-DSRC_DATA_TYPE=" + get_cl_type_from_data_type(_input->info()->data_type()));
     build_opts.add_option("-DDST_TENSOR_TYPE=BUFFER");
     build_opts.add_option("-DDST_DATA_TYPE=" + get_cl_type_from_data_type(dst_data_type));
-    build_opts.add_option_if_else(_export_to_cl_image, "-DWEI_TENSOR_TYPE=IMAGE", "-DWEI_TENSOR_TYPE=BUFFER");
+    build_opts.add_option_if_else(_export_weights_to_cl_image, "-DWEI_TENSOR_TYPE=IMAGE", "-DWEI_TENSOR_TYPE=BUFFER");
+    build_opts.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(_input->info()->dimension(1)));
+    build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(_input->info()->dimension(2)));
+    build_opts.add_option("-DDST_WIDTH=" + support::cpp11::to_string(_output->info()->dimension(1)));
+    build_opts.add_option("-DDST_HEIGHT=" + support::cpp11::to_string(_output->info()->dimension(2)));
     build_opts.add_option("-DWEI_WIDTH=" + support::cpp11::to_string(_weights->info()->dimension(1)));
     build_opts.add_option("-DWEI_HEIGHT=" + support::cpp11::to_string(_weights->info()->dimension(2)));
     build_opts.add_option("-DWEI_DATA_TYPE=" + get_cl_type_from_data_type(_weights->info()->data_type()));
@@ -353,24 +365,39 @@
 
     Window slice = window_collapsed.first_slice_window_4D();
 
+    cl::Image2D input_cl_image;
     cl::Image2D weights_cl_image;
 
-    if(_export_to_cl_image)
+    if(_export_input_to_cl_image || _export_weights_to_cl_image)
     {
-        const size_t      image_w = _weights->info()->dimension(0) / 4;
-        const size_t      image_h = _weights->info()->dimension(1) * _weights->info()->dimension(2) * _weights->info()->dimension(3);
-        const TensorShape shape2d(image_w, image_h);
-        const size_t      image_row_pitch = _weights->info()->strides_in_bytes()[1];
-
         // Export cl_buffer to cl_image
-        weights_cl_image = create_image2d_from_buffer(CLKernelLibrary::get().context(), _weights->cl_buffer(), shape2d, _weights->info()->data_type(), image_row_pitch);
+        if(_export_input_to_cl_image)
+        {
+            const size_t      image_w = _input->info()->dimension(0) / 4;
+            const size_t      image_h = _input->info()->dimension(1) * _input->info()->dimension(2) * _input->info()->dimension(3);
+            const TensorShape shape2d(image_w, image_h);
+            const size_t      image_row_pitch = _input->info()->strides_in_bytes()[1];
+            input_cl_image                    = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input->cl_buffer(), shape2d, _input->info()->data_type(), image_row_pitch);
+        }
+
+        if(_export_weights_to_cl_image)
+        {
+            const size_t      image_w = _weights->info()->dimension(0) / 4;
+            const size_t      image_h = _weights->info()->dimension(1) * _weights->info()->dimension(2) * _weights->info()->dimension(3);
+            const TensorShape shape2d(image_w, image_h);
+            const size_t      image_row_pitch = _weights->info()->strides_in_bytes()[1];
+            weights_cl_image                  = create_image2d_from_buffer(CLKernelLibrary::get().context(), _weights->cl_buffer(), shape2d, _weights->info()->data_type(), image_row_pitch);
+        }
     }
 
     unsigned int idx = 0;
+    if(_export_input_to_cl_image)
+    {
+        _kernel.setArg(idx++, input_cl_image);
+    }
     add_4d_tensor_nhwc_argument(idx, _input);
     add_4d_tensor_nhwc_argument(idx, _output);
-
-    if(_export_to_cl_image)
+    if(_export_weights_to_cl_image)
     {
         _kernel.setArg(idx++, weights_cl_image);
     }
diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.h b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.h
index eeed115..5352f68 100644
--- a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.h
+++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -103,7 +103,8 @@
     unsigned int     _depth_multiplier{ 0 };
     const ICLTensor *_output_multipliers{};
     const ICLTensor *_output_shifts{};
-    bool             _export_to_cl_image { true };
+    bool             _export_input_to_cl_image{ false };
+    bool             _export_weights_to_cl_image{ true };
     bool             _is_quantized{ false };
 };
 } // namespace arm_compute
diff --git a/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp b/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp
index 722c802..fd14f00 100644
--- a/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp
+++ b/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp
@@ -94,7 +94,7 @@
         {
             ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.k0 != 4 && desc.k0 != 8 && desc.k0 != 16,
                                             "K0 can only be: 4, 8, and 16");
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!export_weights_to_cl_image(weights),
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!export_to_cl_image(weights),
                                             "Export to CLImage is not supported for this weight configuration");
         }
     }
diff --git a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp
index 4ea1981..ba176f8 100644
--- a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp
+++ b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp
@@ -159,7 +159,7 @@
 
         desc.k0 = 8;
 
-        desc.export_weights_to_cl_image = export_weights_to_cl_image(wei);
+        desc.export_weights_to_cl_image = export_to_cl_image(wei);
     }
 
     return desc;
@@ -183,7 +183,7 @@
 
         desc.k0 = 8;
 
-        desc.export_weights_to_cl_image = export_weights_to_cl_image(wei);
+        desc.export_weights_to_cl_image = export_to_cl_image(wei);
     }
 
     return desc;
diff --git a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp
index d87cada..ad94678 100644
--- a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp
+++ b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp
@@ -77,15 +77,15 @@
     if(src->data_layout() == DataLayout::NHWC)
     {
         // Get the output shape
-        const TensorShape wei_shape          = wei->tensor_shape();
-        const TensorShape dst_shape          = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info);
-        const bool        export_to_cl_image = export_weights_to_cl_image(wei);
+        const TensorShape wei_shape                  = wei->tensor_shape();
+        const TensorShape dst_shape                  = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info);
+        const bool        export_weights_to_cl_image = export_to_cl_image(wei);
 
         const int32_t ofm          = dst_shape[0];
         const int32_t m            = dst_shape[1] * dst_shape[2];
         const bool    is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1;
 
-        desc.export_weights_to_cl_image = export_to_cl_image;
+        desc.export_weights_to_cl_image = export_weights_to_cl_image;
 
         if(dst_shape[0] <= 4)
         {
@@ -138,15 +138,15 @@
     if(src->data_layout() == DataLayout::NHWC)
     {
         // Get the output shape
-        const TensorShape wei_shape          = wei->tensor_shape();
-        const TensorShape dst_shape          = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info);
-        const bool        export_to_cl_image = export_weights_to_cl_image(wei);
+        const TensorShape wei_shape                  = wei->tensor_shape();
+        const TensorShape dst_shape                  = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info);
+        const bool        export_weights_to_cl_image = export_to_cl_image(wei);
 
         const int32_t ofm          = dst_shape[0];
         const int32_t m            = dst_shape[1] * dst_shape[2];
         const bool    is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1;
 
-        desc.export_weights_to_cl_image = export_to_cl_image;
+        desc.export_weights_to_cl_image = export_weights_to_cl_image;
 
         if(dst_shape[0] <= 4)
         {
@@ -232,14 +232,14 @@
     if(src->data_layout() == DataLayout::NHWC)
     {
         // Get the output shape
-        const TensorShape wei_shape          = wei->tensor_shape();
-        const TensorShape dst_shape          = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info);
-        const bool        export_to_cl_image = export_weights_to_cl_image(wei);
+        const TensorShape wei_shape                  = wei->tensor_shape();
+        const TensorShape dst_shape                  = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info);
+        const bool        export_weights_to_cl_image = export_to_cl_image(wei);
 
         const int32_t m            = dst_shape[1] * dst_shape[2];
         const bool    is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1;
 
-        desc.export_weights_to_cl_image = export_to_cl_image;
+        desc.export_weights_to_cl_image = export_weights_to_cl_image;
 
         if(dst_shape[0] <= 4)
         {
@@ -292,15 +292,15 @@
     if(src->data_layout() == DataLayout::NHWC)
     {
         // Get the output shape
-        const TensorShape wei_shape          = wei->tensor_shape();
-        const TensorShape dst_shape          = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info);
-        const bool        export_to_cl_image = export_weights_to_cl_image(wei);
+        const TensorShape wei_shape                  = wei->tensor_shape();
+        const TensorShape dst_shape                  = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info);
+        const bool        export_weights_to_cl_image = export_to_cl_image(wei);
 
         const int32_t ofm          = dst_shape[0];
         const int32_t m            = dst_shape[1] * dst_shape[2];
         const bool    is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1;
 
-        desc.export_weights_to_cl_image = export_to_cl_image;
+        desc.export_weights_to_cl_image = export_weights_to_cl_image;
 
         if(dst_shape[0] <= 4)
         {
diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
index 8546471..3eadaee 100644
--- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
@@ -44,7 +44,7 @@
 {
 bool export_weights_to_cl_image_heuristic(const ITensorInfo *weights, unsigned int depth_multiplier, GPUTarget gpu_target)
 {
-    if(!export_weights_to_cl_image(weights))
+    if(!export_to_cl_image(weights))
     {
         return false;
     }
@@ -75,9 +75,12 @@
     return true;
 }
 
-void initialize_dwc_native_compute_info(DWCComputeKernelInfo &dwc_compute_info, const ITensorInfo *weights, const PadStrideInfo &conv_info, const Size2D &dilation, unsigned int depth_multiplier,
+void initialize_dwc_native_compute_info(DWCComputeKernelInfo &dwc_compute_info, const ITensorInfo *input, const ITensorInfo *weights, const PadStrideInfo &conv_info, const Size2D &dilation,
+                                        unsigned int depth_multiplier,
                                         GPUTarget gpu_target)
 {
+    ARM_COMPUTE_UNUSED(input);
+
     if(!is_data_type_float(weights->data_type()))
     {
         dwc_compute_info.export_weights_to_cl_image = false;
@@ -97,6 +100,7 @@
     // Floating point path
 
     // First check if we can export to cl_image.
+    dwc_compute_info.export_input_to_cl_image   = false;
     dwc_compute_info.export_weights_to_cl_image = export_weights_to_cl_image_heuristic(weights, depth_multiplier, gpu_target);
 
     // Set n0
@@ -135,7 +139,28 @@
         const size_t idx_w    = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH);
         const size_t kernel_w = weights->tensor_shape()[idx_w];
 
-        dwc_compute_info.m0 = (kernel_w >= 9) || (kernel_w == 1) ? 1 : 2;
+        if((kernel_w >= 9) || (kernel_w == 1))
+        {
+            dwc_compute_info.m0 = 1;
+        }
+        else
+        {
+            if(weights->data_type() == DataType::F16)
+            {
+                if((input->dimension(1) % 5) == 0)
+                {
+                    dwc_compute_info.m0 = 5;
+                }
+                else
+                {
+                    dwc_compute_info.m0 = 4;
+                }
+            }
+            else
+            {
+                dwc_compute_info.m0 = 2;
+            }
+        }
     }
     else
     {
@@ -237,7 +262,7 @@
     }
 
     DWCComputeKernelInfo dwc_native_compute_info;
-    initialize_dwc_native_compute_info(dwc_native_compute_info, weights_to_use->info(), conv_info, dilation, depth_multiplier, gpu_target);
+    initialize_dwc_native_compute_info(dwc_native_compute_info, input->info(), weights_to_use->info(), conv_info, dilation, depth_multiplier, gpu_target);
 
     const ConvolutionInfo conv_kernel_info{ conv_info, depth_multiplier, act_info, dilation };
 
@@ -322,7 +347,7 @@
         ARM_COMPUTE_RETURN_ON_ERROR(CLPermute::validate(weights, &permuted_weights, PermutationVector(2U, 0U, 1U)));
 
         DWCComputeKernelInfo dwc_native_compute_info;
-        initialize_dwc_native_compute_info(dwc_native_compute_info, &permuted_weights, conv_info, dilation, depth_multiplier, gpu_target);
+        initialize_dwc_native_compute_info(dwc_native_compute_info, input, &permuted_weights, conv_info, dilation, depth_multiplier, gpu_target);
 
         ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayerNativeKernel::validate(&permuted_input, &permuted_weights, biases, &permuted_output,
                                                                                       dwc_native_compute_info, conv_kernel_info, &output_multipliers_shifts_info, &output_multipliers_shifts_info));
@@ -331,7 +356,7 @@
     else
     {
         DWCComputeKernelInfo dwc_native_compute_info;
-        initialize_dwc_native_compute_info(dwc_native_compute_info, weights, conv_info, dilation, depth_multiplier, gpu_target);
+        initialize_dwc_native_compute_info(dwc_native_compute_info, input, weights, conv_info, dilation, depth_multiplier, gpu_target);
         ARM_COMPUTE_RETURN_ON_ERROR(CLDepthwiseConvolutionLayerNativeKernel::validate(input, weights, biases, output, dwc_native_compute_info, conv_kernel_info, &output_multipliers_shifts_info,
                                                                                       &output_multipliers_shifts_info));
     }