COMPMID-417: Cleanup NEON FullyConnectedLayer

Change-Id: Ie02a0a1a28ca2771e29a5e6552242caf0f6db1cf
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/83555
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h
index 6cf08de..8d15c50 100644
--- a/arm_compute/core/TensorShape.h
+++ b/arm_compute/core/TensorShape.h
@@ -138,17 +138,28 @@
     }
     /** Collapses given dimension and above.
      *
-     * @note Precondition: dimension < TensorShape::num_max_dimensions
-     *
      * @param[in] dimension Size of the wanted dimension
      *
      * @return The linear size of the collapsed dimensions
      */
     size_t total_size_upper(size_t dimension) const
     {
+        ARM_COMPUTE_ERROR_ON(dimension >= TensorShape::num_max_dimensions);
         return std::accumulate(_id.begin() + dimension, _id.end(), 1, std::multiplies<size_t>());
     }
 
+    /** Compute size of dimensions lower than the given one.
+     *
+     * @param[in] dimension Upper boundary.
+     *
+     * @return The linear size of the collapsed dimensions.
+     */
+    size_t total_size_lower(size_t dimension) const
+    {
+        ARM_COMPUTE_ERROR_ON(dimension > TensorShape::num_max_dimensions);
+        return std::accumulate(_id.begin(), _id.begin() + dimension, 1, std::multiplies<size_t>());
+    }
+
 private:
     /** Remove trailing dimensions of size 1 from the reported number of dimensions. */
     void apply_dimension_correction()
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index af571d1..08099b8 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -97,11 +97,6 @@
     void run() override;
 
 private:
-    void configure_fc_fc_wb(const ITensor *input, const ITensor *weights, ITensor *output);
-    void configure_fc_fc_nb(const ITensor *input, const ITensor *weights, ITensor *output);
-    void configure_conv_fc_wb(const ITensor *input, const ITensor *weights, ITensor *output);
-    void configure_conv_fc_nb(const ITensor *input, const ITensor *weights, ITensor *output);
-
     NEIm2ColKernel                      _im2col_kernel;
     NEFullyConnectedLayerReshapeWeights _reshape_weights_kernel;
     NEGEMMInterleave4x4Kernel           _interleave4x4_kernel;
@@ -111,8 +106,8 @@
     Tensor                              _interleave4x4_output;
     Tensor                              _reshape_weights_output;
     bool                                _are_weights_reshaped;
-    bool                                _is_fc_after_conv;
     bool                                _is_batched_fc_layer;
+    bool                                _linearize_input;
     bool                                _accumulate_biases;
 };
 }
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp
index a4fc494..6ed3791 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.cpp
@@ -48,7 +48,7 @@
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(accum, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(biases, accum);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(biases, accum);
-    ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() != 1);
+    ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() > 1);
 
     _biases = biases;
     _accum  = accum;
diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
index 8381dd8..8a2a481 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
@@ -23,6 +23,7 @@
  */
 #include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
 
+#include "arm_compute/core/AccessWindowStatic.h"
 #include "arm_compute/core/AccessWindowTranspose.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
@@ -1462,7 +1463,7 @@
         AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration_x);
 
         update_window_and_padding(win,
-                                  AccessWindowHorizontal(input0->info(), 0, num_elems_processed_per_iteration_x),
+                                  AccessWindowStatic(input0->info(), 0, 0, input0->info()->tensor_shape().x(), 1),
                                   AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration_x),
                                   output_access);
 
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index e4de60d..6e15f82 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -291,7 +291,10 @@
                                         _conv_info);
     _has_bias = has_bias;
 
-    unsigned int pad_x, pad_y, stride_x, stride_y = 0;
+    unsigned int pad_x    = 0;
+    unsigned int pad_y    = 0;
+    unsigned int stride_x = 0;
+    unsigned int stride_y = 0;
     std::tie(pad_x, pad_y)       = conv_info.pad();
     std::tie(stride_x, stride_y) = conv_info.stride();
 
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 4d9ee85..39983bf 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -30,8 +30,8 @@
 #include <algorithm>
 #include <cmath>
 
-using namespace arm_compute;
-
+namespace arm_compute
+{
 NEFullyConnectedLayerReshapeWeights::NEFullyConnectedLayerReshapeWeights()
     : _transpose_kernel(), _transpose1xW_kernel(), _transpose_output(), _transpose_weights(false), _is_batched_fc_layer(false)
 {
@@ -40,11 +40,11 @@
 void NEFullyConnectedLayerReshapeWeights::configure(const ITensor *input, ITensor *output, bool transpose_weights, bool is_batched_fc_layer)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON(input->info()->num_dimensions() > 2);
     ARM_COMPUTE_ERROR_ON(output == nullptr);
-    ARM_COMPUTE_ERROR_ON(input->info()->num_dimensions() != 2);
-    ARM_COMPUTE_ERROR_ON((transpose_weights == false) && (is_batched_fc_layer == false));
+    ARM_COMPUTE_ERROR_ON(!transpose_weights && !is_batched_fc_layer);
 
-    const DataType dt                   = input->info()->data_type();
+    const DataType data_type            = input->info()->data_type();
     const int      fixed_point_position = input->info()->fixed_point_position();
 
     _transpose_weights   = transpose_weights;
@@ -57,7 +57,7 @@
         {
             // Initialize the output tensor for transpose
             TensorShape shape_transposed(input->info()->dimension(1), input->info()->dimension(0));
-            _transpose_output.allocator()->init(TensorInfo(shape_transposed, 1, dt, fixed_point_position));
+            _transpose_output.allocator()->init(TensorInfo(shape_transposed, 1, data_type, fixed_point_position));
             _transpose_kernel.configure(input, &_transpose_output);
 
             // Configure transpose 1xW kernel
@@ -91,6 +91,7 @@
     {
         NEScheduler::get().schedule(&_transpose_kernel, Window::DimY);
     }
+
     if(_is_batched_fc_layer)
     {
         NEScheduler::get().schedule(&_transpose1xW_kernel, Window::DimY);
@@ -99,216 +100,142 @@
 
 NEFullyConnectedLayer::NEFullyConnectedLayer()
     : _im2col_kernel(), _reshape_weights_kernel(), _interleave4x4_kernel(), _mm_kernel(), _accumulate_biases_kernel(), _im2col_output(), _interleave4x4_output(), _reshape_weights_output(),
-      _are_weights_reshaped(false), _is_fc_after_conv(false), _is_batched_fc_layer(false), _accumulate_biases(false)
+      _are_weights_reshaped(false), _is_batched_fc_layer(false), _linearize_input(false), _accumulate_biases(false)
 {
 }
 
-void NEFullyConnectedLayer::configure_conv_fc_wb(const ITensor *input, const ITensor *weights, ITensor *output)
-{
-    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2) * (16 / weights->info()->element_size())));
-
-    const DataType dt                   = input->info()->data_type();
-    const int      fixed_point_position = input->info()->fixed_point_position();
-
-    // If the fully connected layer is called after a convolution layer, the input tensor must be linearized
-
-    // Initialize output tensor for im2col
-    TensorShape shape_im2col;
-    shape_im2col.set(0, input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2));
-    shape_im2col.set(1, input->info()->dimension(3));
-    shape_im2col.set(2, input->info()->dimension(4));
-    shape_im2col.set(3, input->info()->dimension(5));
-    _im2col_output.allocator()->init(TensorInfo(shape_im2col, 1, dt, fixed_point_position));
-
-    // Initialize output tensor for interleave 4x4
-    TensorShape shape_interleaved = _im2col_output.info()->tensor_shape();
-    shape_interleaved.set(0, shape_interleaved.x() * 4);
-    shape_interleaved.set(1, std::ceil(static_cast<float>(shape_interleaved.y()) / 4));
-    _interleave4x4_output.allocator()->init(TensorInfo(shape_interleaved, 1, dt, fixed_point_position));
-
-    // Configure im2col kernel
-    _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
-
-    // Configure interleave4x4 kernel
-    _interleave4x4_kernel.configure(&_im2col_output, &_interleave4x4_output);
-
-    // Configure matrix multiply kernel
-    _mm_kernel.configure(&_interleave4x4_output, weights, output, 1.0f);
-
-    // Allocate the tensors once all the configure methods have been called
-    _im2col_output.allocator()->allocate();
-    _interleave4x4_output.allocator()->allocate();
-}
-
-void NEFullyConnectedLayer::configure_fc_fc_wb(const ITensor *input, const ITensor *weights, ITensor *output)
-{
-    const DataType dt                   = input->info()->data_type();
-    const int      fixed_point_position = input->info()->fixed_point_position();
-
-    // Initialize output tensor for interleave 4x4
-    TensorShape shape_interleaved = input->info()->tensor_shape();
-    shape_interleaved.set(0, shape_interleaved.x() * 4);
-    shape_interleaved.set(1, std::ceil(static_cast<float>(shape_interleaved.y()) / 4));
-    _interleave4x4_output.allocator()->init(TensorInfo(shape_interleaved, 1, dt, fixed_point_position));
-
-    // Configure interleave4x4 kernel
-    _interleave4x4_kernel.configure(input, &_interleave4x4_output);
-
-    // Configure matrix multiply kernel
-    _mm_kernel.configure(&_interleave4x4_output, weights, output, 1.0f);
-
-    // Allocate the tensors once all the configure methods have been called
-    _interleave4x4_output.allocator()->allocate();
-}
-
-void NEFullyConnectedLayer::configure_conv_fc_nb(const ITensor *input, const ITensor *weights, ITensor *output)
-{
-    ARM_COMPUTE_ERROR_ON((weights->info()->dimension(1) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2))));
-
-    const DataType dt                   = input->info()->data_type();
-    const int      fixed_point_position = input->info()->fixed_point_position();
-
-    // If the fully connected layer is called after a convolution layer, the input tensor must be linearized
-
-    // Initialize output tensor for im2col
-    TensorShape shape_im2col;
-    shape_im2col.set(0, input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2));
-    shape_im2col.set(1, 1);
-    _im2col_output.allocator()->init(TensorInfo(shape_im2col, 1, dt, fixed_point_position));
-
-    // Configure im2col kernel
-    _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
-
-    // Configure matrix multiply kernel
-    _mm_kernel.configure(&_im2col_output, weights, output, 1.0f);
-
-    // Allocate the output tensor for im2col once all the configure methods have been called
-    _im2col_output.allocator()->allocate();
-}
-
-void NEFullyConnectedLayer::configure_fc_fc_nb(const ITensor *input, const ITensor *weights, ITensor *output)
-{
-    ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1));
-
-    // Configure matrix multiply kernel
-    _mm_kernel.configure(input, weights, output, 1.0f);
-}
-
 void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, bool transpose_weights, bool are_weights_reshaped)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, weights, output);
-    ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() != 2);
-
-    const DataType dt                   = input->info()->data_type();
-    const int      fixed_point_position = input->info()->fixed_point_position();
-
-    _are_weights_reshaped = are_weights_reshaped;
-    _is_fc_after_conv     = true;
-    _is_batched_fc_layer  = false;
-    _accumulate_biases    = false;
-
-    if(biases != nullptr)
-    {
-        ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
-
-        _accumulate_biases = true;
-
-        // Configure accumulate biases kernel
-        _accumulate_biases_kernel.configure(output, biases);
-    }
-
     // With the Fully Connected layer we can have 4 different cases:
     //  1) Convolution layer -> Fully Connected layer without batches
     //  2) Fully Connected layer -> Fully Connected layer without batches
     //  3) Convolution layer -> Fully Connected layer with batches
     //  4) Fully Connected layer -> Fully Connected layer with batches
 
-    // Check if we have a fully connected layer with batches
-    _is_batched_fc_layer = (output->info()->dimension(1) > 1);
+    // Expected shape before transpose and reshaping
+    // Input: In x B (In and B can be multi-dimensional)
+    // Weights: flat(In) x Out
+    // Biases: Out
+    // Output: Out x B (B can be multi-dimensional)
 
-    const ITensor *weights_to_use = weights;
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, weights, output);
 
-    if(!are_weights_reshaped)
+    const DataType data_type            = input->info()->data_type();
+    const int      fixed_point_position = input->info()->fixed_point_position();
+    const int      num_batch_dimensions = std::max(0, static_cast<int>(output->info()->tensor_shape().num_dimensions()) - 1);
+    const int      num_input_dimensions = input->info()->tensor_shape().num_dimensions() - num_batch_dimensions;
+    const size_t   linear_input_size    = input->info()->tensor_shape().total_size_lower(num_input_dimensions);
+
+    _linearize_input      = input->info()->tensor_shape().x() != linear_input_size;
+    _are_weights_reshaped = are_weights_reshaped;
+    _accumulate_biases    = biases != nullptr;
+    _is_batched_fc_layer  = num_batch_dimensions > 0;
+
+    // Check if number of batches match
+    ARM_COMPUTE_ERROR_ON(input->info()->tensor_shape().total_size_upper(num_input_dimensions) != output->info()->tensor_shape().total_size_upper(1));
+    ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 2);
+
+    const size_t   interleave_width = 16 / input->info()->element_size();
+    const ITensor *weights_to_use   = weights;
+
+    if(!are_weights_reshaped && (transpose_weights || _is_batched_fc_layer))
     {
-        if((transpose_weights || _is_batched_fc_layer))
+        weights_to_use = &_reshape_weights_output;
+
+        TensorShape reshaped_weights_shape(weights->info()->tensor_shape());
+
+        // Transpose weights if the user hasn't done it
+        if(transpose_weights)
         {
-            weights_to_use = &_reshape_weights_output;
-
-            if(transpose_weights)
-            {
-                if(_is_batched_fc_layer)
-                {
-                    const float transpose_width = 16.0f / input->info()->element_size();
-                    TensorShape shape_wt(weights->info()->dimension(0) * static_cast<unsigned int>(transpose_width), static_cast<unsigned int>(std::ceil(weights->info()->dimension(1) / transpose_width)));
-                    TensorInfo  info_wt(shape_wt, 1, dt, fixed_point_position);
-                    _reshape_weights_output.allocator()->init(info_wt);
-                }
-                else
-                {
-                    TensorShape shape_wt(weights->info()->dimension(1), weights->info()->dimension(0));
-                    TensorInfo  info_wt(shape_wt, 1, dt, fixed_point_position);
-                    _reshape_weights_output.allocator()->init(info_wt);
-                }
-            }
-            else
-            {
-                ARM_COMPUTE_ERROR_ON(!_is_batched_fc_layer);
-
-                const float transpose_width = 16.0f / input->info()->element_size();
-                TensorShape shape_wt(weights->info()->dimension(1) * static_cast<unsigned int>(transpose_width), static_cast<unsigned int>(std::ceil(weights->info()->dimension(0) / transpose_width)));
-                TensorInfo  info_wt(shape_wt, 1, dt, fixed_point_position);
-                _reshape_weights_output.allocator()->init(info_wt);
-            }
-
-            // Reshape the weights
-            _reshape_weights_kernel.configure(weights, &_reshape_weights_output, transpose_weights, _is_batched_fc_layer);
+            const size_t shape_x = reshaped_weights_shape.x();
+            reshaped_weights_shape.set(0, reshaped_weights_shape.y());
+            reshaped_weights_shape.set(1, shape_x);
         }
+
+        // If the we run multiple batches we need 1xW transpose, too.
+        if(_is_batched_fc_layer)
+        {
+            const float shape_x = reshaped_weights_shape.x();
+            reshaped_weights_shape.set(0, reshaped_weights_shape.y() * interleave_width);
+            reshaped_weights_shape.set(1, static_cast<unsigned int>(std::ceil(shape_x / interleave_width)));
+        }
+
+        _reshape_weights_output.allocator()->init(TensorInfo(reshaped_weights_shape, 1, data_type, fixed_point_position));
+
+        // Reshape the weights
+        _reshape_weights_kernel.configure(weights, &_reshape_weights_output, transpose_weights, _is_batched_fc_layer);
+    }
+
+    // Check correct shape of weights
+    if(_is_batched_fc_layer)
+    {
+        // Transpose + Transpose1xW
+        ARM_COMPUTE_ERROR_ON(weights_to_use->info()->tensor_shape().x() != linear_input_size * interleave_width);
+        ARM_COMPUTE_ERROR_ON(weights_to_use->info()->tensor_shape().y() != static_cast<unsigned int>(std::ceil(static_cast<float>(output->info()->tensor_shape().x()) / interleave_width)));
+    }
+    else
+    {
+        // Transpose
+        ARM_COMPUTE_ERROR_ON(weights_to_use->info()->tensor_shape().x() != output->info()->tensor_shape().x());
+        ARM_COMPUTE_ERROR_ON(weights_to_use->info()->tensor_shape().y() != linear_input_size);
+    }
+
+    const ITensor *multiply_input = input;
+
+    if(_linearize_input)
+    {
+        TensorShape shape_im2col(input->info()->tensor_shape());
+        shape_im2col.collapse(num_input_dimensions);
+        _im2col_output.allocator()->init(TensorInfo(shape_im2col, 1, data_type, fixed_point_position));
+
+        // Configure im2col kernel
+        _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
+
+        multiply_input = &_im2col_output;
     }
 
     if(_is_batched_fc_layer)
     {
-        _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->info()->tensor_shape().cbegin() + 3,
-                                                                                  input->info()->tensor_shape().cend(),
-                                                                                  output->info()->tensor_shape().cbegin() + 1));
+        TensorShape shape_interleaved(multiply_input->info()->tensor_shape());
+        shape_interleaved.set(0, shape_interleaved.x() * 4);
+        shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f));
+        _interleave4x4_output.allocator()->init(TensorInfo(shape_interleaved, 1, data_type, fixed_point_position));
 
-        if(_is_fc_after_conv)
-        {
-            // Fully Connected layer after a Convolution Layer with batches
-            configure_conv_fc_wb(input, weights_to_use, output);
-        }
-        else
-        {
-            // Fully Connected layer after a Fully Connected Layer with batches
-            configure_fc_fc_wb(input, weights_to_use, output);
-        }
+        // Configure interleave4x4 kernel
+        _interleave4x4_kernel.configure(multiply_input, &_interleave4x4_output);
+
+        multiply_input = &_interleave4x4_output;
     }
-    else
-    {
-        // In case of not batched fully connected layer, the weights will not be reshaped using transposed1xW
-        _is_fc_after_conv = ((weights_to_use->info()->dimension(1)) == (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2)));
 
-        if(_is_fc_after_conv)
-        {
-            // Fully Connected layer after a Convolution Layer without batches
-            configure_conv_fc_nb(input, weights_to_use, output);
-        }
-        else
-        {
-            // Fully Connected layer after a Fully Connected Layer without batches
-            configure_fc_fc_nb(input, weights_to_use, output);
-        }
+    // Configure matrix multiply kernel
+    _mm_kernel.configure(multiply_input, weights_to_use, output, 1.0f);
+
+    if(_accumulate_biases)
+    {
+        ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
+        ARM_COMPUTE_ERROR_ON(biases->info()->tensor_shape().x() != output->info()->tensor_shape().x());
+
+        // Configure accumulate biases kernel
+        _accumulate_biases_kernel.configure(output, biases);
     }
 
     // Allocate the transpose tensor if the are_weights_reshaped flag is false and once all the configure methods have been called
-    if(!are_weights_reshaped)
+    if(!are_weights_reshaped && (transpose_weights || _is_batched_fc_layer))
     {
-        if(transpose_weights || _is_batched_fc_layer)
-        {
-            // Allocate the tensor for the weights reshaped
-            _reshape_weights_output.allocator()->allocate();
-        }
+        // Allocate the tensor for the weights reshaped
+        _reshape_weights_output.allocator()->allocate();
+    }
+
+    if(_linearize_input)
+    {
+        _im2col_output.allocator()->allocate();
+    }
+
+    if(_is_batched_fc_layer)
+    {
+        _interleave4x4_output.allocator()->allocate();
     }
 }
 
@@ -321,8 +248,8 @@
         _reshape_weights_kernel.run();
     }
 
-    // Linearize input if comes from a convolutional layer
-    if(_is_fc_after_conv)
+    // Linearize input if it comes from a convolutional layer
+    if(_linearize_input)
     {
         NEScheduler::get().schedule(&_im2col_kernel, Window::DimY);
     }
@@ -342,3 +269,4 @@
         NEScheduler::get().schedule(&_accumulate_biases_kernel, Window::DimY);
     }
 }
+} // namespace arm_compute