COMPMID-1047 Extract Flatten function from Im2Col for NEON

Change-Id: I80f3aaadc8cae8c9ca1a5a239e79bda302b89bd8
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144813
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index 98b1488..e5d3128 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -41,11 +41,12 @@
 #include <tuple>
 
 using namespace arm_compute;
+using namespace misc::shape_calculator;
 
 namespace
 {
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
-                          bool has_bias, const Size2D &dilation, unsigned int num_groups, bool is_fully_connected, bool is_flatten)
+                          bool has_bias, const Size2D &dilation, unsigned int num_groups)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
@@ -55,18 +56,7 @@
 
     if(output->total_size() > 0)
     {
-        TensorShape expected_output_shape;
-
-        if(is_flatten || is_fully_connected)
-        {
-            expected_output_shape = misc::shape_calculator::compute_flatten_shape(input);
-        }
-        else
-        {
-            expected_output_shape = misc::shape_calculator::compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, false);
-        }
-
-        TensorInfo expected_output = output->clone()->set_tensor_shape(expected_output_shape);
+        TensorInfo expected_output = output->clone()->set_tensor_shape(compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, false));
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
     }
@@ -74,6 +64,31 @@
     return Status{};
 }
 
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
+                                                        bool has_bias, const Size2D &dilation)
+{
+    const unsigned int width_idx   = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+    const unsigned int height_idx  = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+    const unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+
+    std::pair<unsigned int, unsigned int> convolved_dims = scaled_dimensions(input->dimension(width_idx), input->dimension(height_idx),
+                                                                             kernel_dims.width, kernel_dims.height,
+                                                                             conv_info, dilation);
+
+    // Auto-initialize the output tensor if it has not been initialized yet
+    auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, false)));
+
+    Window win = calculate_max_window(*input, Steps());
+    win.set(width_idx, Window::Dimension(0, convolved_dims.first, 1));
+    win.set(height_idx, Window::Dimension(0, convolved_dims.second, 1));
+    win.set(channel_idx, Window::Dimension(0, 1, 1));
+
+    // The NEIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped
+    output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
+
+    return std::make_pair(Status{}, win);
+}
+
 template <typename T, bool has_pads>
 inline void linearize_volume(const uint8_t *const in_ptr,
                              T                   *out_ptr,
@@ -174,7 +189,7 @@
 } // namespace
 
 template <typename T, bool has_pads>
-void NEIm2ColKernel::run_generic(const Window &window)
+void NEIm2ColKernel::run_im2col(const Window &window)
 {
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
@@ -244,66 +259,21 @@
     in, out);
 }
 
-template <typename T>
-void NEIm2ColKernel::run_reduced(const Window &window)
-{
-    const size_t in_width   = _input->info()->dimension(0);
-    const size_t in_height  = _input->info()->dimension(1);
-    const size_t out_step_x = in_width * _input->info()->element_size();
-    const size_t out_step_y = out_step_x * in_height;
-    const size_t out_width  = _output->info()->dimension(0);
-
-    Window in_window(window);
-    in_window.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-    Window out_window;
-    out_window.use_tensor_dimensions(_output->info()->tensor_shape());
-    out_window.set(Window::DimX, Window::Dimension(out_window.x().start(), out_window.x().end(), in_width));
-
-    Window in_slice  = in_window.first_slice_window_3D();
-    Window out_slice = out_window.first_slice_window_1D();
-
-    do
-    {
-        Iterator in(_input, in_slice);
-        Iterator out(_output, out_slice);
-
-        uint8_t *out_ptr = out.ptr();
-
-        execute_window_loop(in_slice, [&](const Coordinates & id)
-        {
-            memcpy(out_ptr + id.y() * out_step_x + id.z() * out_step_y, in.ptr(), out_step_x);
-        },
-        in);
-
-        // Add bias
-        if(_has_bias)
-        {
-            *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = static_cast<T>(1);
-        }
-    }
-    while(in_window.slide_window_slice_3D(in_slice) && out_window.slide_window_slice_1D(out_slice));
-}
-
 NEIm2ColKernel::NEIm2ColKernel()
     : _func(), _input(nullptr), _output(nullptr), _convolved_dims(), _conv_info(), _kernel_width(0), _kernel_height(0), _has_bias(false), _dilation(1U, 1U)
 {
 }
 
 void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
-                               bool has_bias, const Size2D &dilation, unsigned int num_groups, bool is_fully_connected, bool is_flatten)
+                               bool has_bias, const Size2D &dilation, unsigned int num_groups)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
-
-    // Perform validation step
-    ARM_COMPUTE_UNUSED(is_fully_connected, is_flatten);
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, dilation, num_groups));
     ARM_COMPUTE_UNUSED(num_groups);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, dilation, num_groups, is_fully_connected, is_flatten));
 
     const DataLayout   data_layout = input->info()->data_layout();
     const unsigned int width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
     const unsigned int height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
-    const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
 
     _input          = input;
     _output         = output;
@@ -316,73 +286,35 @@
                                         _conv_info, _dilation);
     _has_bias = has_bias;
 
-    unsigned int stride_x = 0;
-    unsigned int stride_y = 0;
-    std::tie(stride_x, stride_y) = conv_info.stride();
-
-    bool run_img2col_reduced = (output->info()->dimension(0) == (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2))) && (TensorShape::num_max_dimensions >= 4)
-                               && (std::equal(input->info()->tensor_shape().cbegin() + 3,
-                                              input->info()->tensor_shape().cend(),
-                                              output->info()->tensor_shape().cbegin() + 1))
-                               && ((stride_x == 1) && (stride_y == 1) && !conv_info.has_padding())
-                               && ((dilation.x() == 1) && (dilation.y() == 1));
-
-    Window window = calculate_max_window(*input->info(), Steps());
-
-    if(run_img2col_reduced)
+    switch(_input->info()->data_type())
     {
-        switch(_input->info()->data_type())
-        {
-            case DataType::F32:
-                _func = &NEIm2ColKernel::run_reduced<float>;
-                break;
+        case DataType::F32:
+            _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float, false> : &NEIm2ColKernel::run_im2col<float, true>;
+            break;
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            case DataType::F16:
-                _func = &NEIm2ColKernel::run_reduced<float16_t>;
-                break;
+        case DataType::F16:
+            _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<float16_t, false> : &NEIm2ColKernel::run_im2col<float16_t, true>;
+            break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-            case DataType::QASYMM8:
-                _func = &NEIm2ColKernel::run_reduced<qasymm8_t>;
-                break;
-            default:
-                ARM_COMPUTE_ERROR("Data type not supported");
-                break;
-        }
-    }
-    else
-    {
-        switch(_input->info()->data_type())
-        {
-            case DataType::F32:
-                _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_generic<float, false> : &NEIm2ColKernel::run_generic<float, true>;
-                break;
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            case DataType::F16:
-                _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_generic<float16_t, false> : &NEIm2ColKernel::run_generic<float16_t, true>;
-                break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-            case DataType::QASYMM8:
-                _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_generic<qasymm8_t, false> : &NEIm2ColKernel::run_generic<qasymm8_t, true>;
-                break;
-            default:
-                ARM_COMPUTE_ERROR("Data type not supported");
-                break;
-        }
-        window.set(width_idx, Window::Dimension(0, _convolved_dims.first, 1));
-        window.set(height_idx, Window::Dimension(0, _convolved_dims.second, 1));
-        window.set(channel_idx, Window::Dimension(0, 1, 1));
+        case DataType::QASYMM8:
+            _func = (!conv_info.has_padding()) ? &NEIm2ColKernel::run_im2col<qasymm8_t, false> : &NEIm2ColKernel::run_im2col<qasymm8_t, true>;
+            break;
+        default:
+            ARM_COMPUTE_ERROR("Data type not supported");
+            break;
     }
 
-    // The NEIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped
-    output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
-
-    IKernel::configure(window);
+    // Configure kernel window
+    auto win_config = validate_and_configure_window(input->info(), output->info(), kernel_dims, conv_info, has_bias, dilation);
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+    INEKernel::configure(win_config.second);
 }
 
 Status NEIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
-                                bool has_bias, const Size2D &dilation, unsigned int num_groups, bool is_fully_connected, bool is_flatten)
+                                bool has_bias, const Size2D &dilation, unsigned int num_groups)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, dilation, num_groups, is_fully_connected, is_flatten));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, dilation, num_groups));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), kernel_dims, conv_info, has_bias, dilation).first);
     return Status{};
 }