COMPMID-1277 - Optimize CLIm2ColKernel for NHWC

This patch includes:

- Optimize Im2Col for NHWC by using a new data layout
- Refactor CLIm2ColKernel, adding a validation method and output auto-initialization
- Remove im2col_reduced from CLIm2ColKernel and move that functionality into a new kernel, CLFlattenLayerKernel

Change-Id: I1620640b6796baa268324b33ae92cdd8de53e27c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141241
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
diff --git a/src/runtime/CL/functions/CLFlattenLayer.cpp b/src/runtime/CL/functions/CLFlattenLayer.cpp
index f5809a2..b372c35 100644
--- a/src/runtime/CL/functions/CLFlattenLayer.cpp
+++ b/src/runtime/CL/functions/CLFlattenLayer.cpp
@@ -23,8 +23,7 @@
  */
 #include "arm_compute/runtime/CL/functions/CLFlattenLayer.h"
 
-#include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
-#include "arm_compute/core/Size2D.h"
+#include "arm_compute/core/CL/kernels/CLFlattenLayerKernel.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #include "support/ToolchainSupport.h"
 
@@ -32,8 +31,13 @@
 
 void CLFlattenLayer::configure(const ICLTensor *input, ICLTensor *output)
 {
-    auto k = arm_compute::support::cpp14::make_unique<CLIm2ColKernel>();
-    k->configure(input, output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
+    auto k = arm_compute::support::cpp14::make_unique<CLFlattenLayerKernel>();
+    k->configure(input, output);
     _kernel = std::move(k);
     CLScheduler::get().tune_kernel_static(*_kernel);
 }
+
+Status CLFlattenLayer::validate(const ITensorInfo *input, const ITensorInfo *output)
+{
+    return CLFlattenLayerKernel::validate(input, output);
+}
\ No newline at end of file
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 6fd78a3..60c28a0 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -73,12 +73,11 @@
 }
 
 CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _im2col_kernel(), _convert_weights(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
-      _accumulate_biases_kernel(), _im2col_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
+    : _memory_group(memory_manager), _convert_weights(), _flatten_layer(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
+      _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
       _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
 {
 }
-
 void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
 {
     if(_is_quantized)
@@ -111,20 +110,19 @@
 
     // If the fully connected layer is called after a convolution layer, the input tensor must be linearized
 
-    // Initialize output tensor for im2col
-    TensorShape shape_im2col = compute_im2col_fc_shape(input->info());
-    _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col).set_data_layout(DataLayout::NCHW));
+    // Initialize output tensor for flatten
+    TensorShape shape_flatten = compute_flatten_shape(input->info());
+    _flatten_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_flatten).set_data_layout(DataLayout::NCHW));
 
-    // Configure im2col kernel
-    _memory_group.manage(&_im2col_output);
-    _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
-    CLScheduler::get().tune_kernel_static(_im2col_kernel);
+    // Configure flatten kernel
+    _memory_group.manage(&_flatten_output);
+    _flatten_layer.configure(input, &_flatten_output);
 
     // Configure matrix multiply kernel
-    configure_mm(&_im2col_output, weights, output);
+    configure_mm(&_flatten_output, weights, output);
 
-    // Allocate the output tensor for im2col once all the configure methods have been called
-    _im2col_output.allocator()->allocate();
+    // Allocate the output tensor for flatten once all the configure methods have been called
+    _flatten_output.allocator()->allocate();
 }
 
 void CLFullyConnectedLayer::configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
@@ -254,7 +252,7 @@
     bool            is_quantized     = is_data_type_quantized_asymmetric(input->data_type());
     const GPUTarget gpu_target       = CLScheduler::get().target();
 
-    const ITensorInfo &im2col_input      = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input)).set_data_layout(DataLayout::NCHW));
+    const ITensorInfo &flatten_input     = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(input)).set_data_layout(DataLayout::NCHW));
     const ITensorInfo &reshaped_weights  = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
     const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());
     const ITensorInfo &gemmlowp_output   = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
@@ -311,9 +309,9 @@
         // Fully Connected layer after a Convolution Layer without batches
         ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (input->dimension(0) * input->dimension(1) * input->dimension(2))));
 
-        // Validate im2col kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_input, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false));
-        input_to_use = &im2col_input;
+        // Validate flatten kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(CLFlattenLayer::validate(input, &flatten_input));
+        input_to_use = &flatten_input;
     }
     else
     {
@@ -341,7 +339,7 @@
     // Linearize input if it comes from a convolutional layer
     if(_is_fc_after_conv)
     {
-        CLScheduler::get().enqueue(_im2col_kernel, false);
+        _flatten_layer.run();
     }
 
     // Run matrix multiply
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 1d1b17b..a8d7058 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -171,6 +171,7 @@
 Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(output);
 
     // Check if we need to reshape the matrix B only on the first run
     const bool reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
@@ -180,7 +181,7 @@
 
     TensorInfo tmp_a_info{};
     TensorInfo tmp_b_info{};
-    TensorInfo tmp_output_info = *output->clone();
+    TensorInfo tmp_output_info{};
 
     // Get the GPU target
     const GPUTarget gpu_target = CLScheduler::get().target();
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index fb90415..49549a0 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -171,7 +171,6 @@
     const DataLayout data_layout = input->info()->data_layout();
     const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
     const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
-    const int        idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
     const int        idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
 
     const unsigned int kernel_width  = weights->info()->dimension(idx_width);
@@ -193,7 +192,6 @@
     ICLTensor       *gemm_output_to_use        = output;
     ICLTensor       *gemm_output_staged_to_use = output;
 
-    const unsigned   bias_element  = (_append_bias && !_skip_im2col) ? 1 : 0;
     const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
 
     // Get parameters from conv_info
@@ -212,7 +210,6 @@
                                                  dilation);
 
     unsigned int mat_weights_cols = weights->info()->dimension(idx_kernels);
-    unsigned int mat_weights_rows = weights->info()->dimension(idx_width) * weights->info()->dimension(idx_height) * weights->info()->dimension(idx_channel) + bias_element;
 
     // _weights_reshaped will be auto configured in the kernel.
     // Just append biases and do not transpose 1xW as it will be reshaped in CLGEMM
@@ -223,25 +220,13 @@
     // Create tensor to store im2col reshaped inputs
     if(!_skip_im2col)
     {
-        // Calculate im2col shape
-        // For OpenCL the batch size is on the third dimension
-        // TODO (giaiod01): Use auto-init COMPMID-1277
-        TensorShape shape_im2col = input->info()->tensor_shape();
-        if(shape_im2col.num_dimensions() >= 3)
-        {
-            shape_im2col.remove_dimension(2);
-        }
-        shape_im2col.set(0, mat_weights_rows);
-        shape_im2col.set(1, conv_w * conv_h);
-
-        // FIXME: input->clone() doesn't work with subtensors for grouped convolutions.
-        TensorInfo im2col_reshaped_info(shape_im2col, 1, data_type);
-        im2col_reshaped_info.set_quantization_info(input->info()->quantization_info());
-        _im2col_output.allocator()->init(im2col_reshaped_info);
         _memory_group.manage(&_im2col_output);
 
-        // Configure and tune im2col
+        // Configure and tune im2col. im2col output shape is auto-initialized
         _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation);
+
+        // Set quantization info
+        _im2col_output.info()->set_quantization_info(input->info()->quantization_info());
         CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
         // Update GEMM input
@@ -350,11 +335,10 @@
     const ITensorInfo *gemm_output_staged_to_use = output;
     const ITensorInfo *weights_to_use            = weights;
 
-    const bool     is_nhwc      = data_layout == DataLayout::NHWC;
-    const bool     is_quantized = is_data_type_quantized_asymmetric(data_type);
-    const bool     skip_im2col  = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !is_quantized;
-    const bool     append_bias  = (biases != nullptr) && (!is_quantized);
-    const unsigned bias_element = (append_bias && !skip_im2col) ? 1 : 0;
+    const bool is_nhwc      = data_layout == DataLayout::NHWC;
+    const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
+    const bool skip_im2col  = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !is_quantized;
+    const bool append_bias  = (biases != nullptr) && (!is_quantized);
 
     ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel));
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -391,7 +375,6 @@
                                                  dilation);
 
     unsigned int mat_weights_cols = weights->dimension(idx_kernels);
-    unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + bias_element;
 
     // Output tensor auto inizialitation if not yet initialized
     ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, is_quantized ? nullptr : biases, nullptr));
@@ -400,19 +383,14 @@
 
     if(!skip_im2col)
     {
-        // Create tensor info for im2col reshaped inputs
-        // For OpenCL the batch size is on the third dimension
-        // TODO (giaiod01): Use auto-init COMPMID-1277
-        TensorShape shape_im2col = input->tensor_shape();
-        if(input->tensor_shape().num_dimensions() >= 3)
-        {
-            shape_im2col.remove_dimension(2);
-        }
-        shape_im2col.set(0, mat_weights_rows);
-        shape_im2col.set(1, conv_w * conv_h);
-        im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type);
-        im2col_reshaped_info.set_quantization_info(input->quantization_info());
-        ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
+        const Size2D kernel_dims(kernel_width, kernel_height);
+
+        // Auto-initialize the im2col output tensor info if it is not yet initialized
+        TensorShape expected_output_shape = compute_im2col_conv_shape(input, kernel_dims, conv_info, append_bias, dilation, true);
+
+        auto_init_if_empty(im2col_reshaped_info, input->clone()->set_tensor_shape(expected_output_shape));
+
+        ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, kernel_dims, conv_info, append_bias, dilation));
         gemm_input_to_use = &im2col_reshaped_info;
     }
     else if(append_bias)
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 25b8adc..c2f0283 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -113,7 +113,7 @@
     // If the fully connected layer is called after a convolution layer, the input tensor must be linearized
 
     // Initialize output tensor for im2col
-    TensorShape shape_im2col = compute_im2col_fc_shape(input->info());
+    TensorShape shape_im2col = compute_flatten_shape(input->info());
     _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
 
     // Configure im2col kernel
@@ -249,7 +249,7 @@
     bool is_fc_after_conv = true;
     bool is_quantized     = is_data_type_quantized_asymmetric(input->data_type());
 
-    const ITensorInfo &im2col_input      = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input)));
+    const ITensorInfo &im2col_input      = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(input)));
     const ITensorInfo &reshaped_weights  = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
     const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());
     const ITensorInfo &gemmlowp_output   = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
@@ -420,4 +420,4 @@
 
         _is_prepared = true;
     }
-}
+}
\ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index c0a5d0a..df4a040 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -223,7 +223,7 @@
     {
         // Calculate im2col shape
         // For NEON the batch size is on the fourth dimension
-        // TODO (giaiod01): Use auto-init COMPMID-1277
+        // TODO (giaiod01): Auto-initialize the output shape of im2col COMPMID-1482
         TensorShape shape_im2col = input->info()->tensor_shape();
         shape_im2col.set(0, mat_weights_rows);
         shape_im2col.set(1, conv_w * conv_h);
@@ -232,7 +232,7 @@
         _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
         _memory_group.manage(&_im2col_output);
 
-        // Configure and tune im2col
+        // Configure im2col
         _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, false, false, dilation);
 
         // Update GEMM input