COMPMID-617: Add validate support for NEON FullyConnectedLayer

Change-Id: I08987022c8d4cc335c00b8af27bd3edb8fe64d3b
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111596
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Alexander Gilday <alexander.gilday@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index 633f78d..4fa329b 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -32,6 +32,8 @@
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Validate.h"
 
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+
 #include <arm_neon.h>
 #include <cstddef>
 #include <cstdint>
@@ -42,14 +44,34 @@
 
 namespace
 {
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
+                          bool has_bias, bool is_fully_connected, bool is_flatten) // is_flatten is tested first, so it takes precedence over is_fully_connected
 {
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
     ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::QASYMM8 && has_bias);
-    ARM_COMPUTE_UNUSED(kernel_dims);
-    ARM_COMPUTE_UNUSED(conv_info);
+
+    if(is_flatten) /* Called by FlattenLayer: output dimension 0 must equal the product of the input's x, y and z extents */
+    {
+        size_t flatten_shape = input->tensor_shape().x() * input->tensor_shape().y() * input->tensor_shape().z();
+        ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != flatten_shape);
+    }
+    else if(!is_fully_connected) /* Called by ConvolutionLayer: each output dimension is checked against kernel_dims and conv_info */
+    {
+        std::pair<unsigned int, unsigned int> out_dims = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_dims.width, kernel_dims.height, conv_info); // presumably the convolved output extents (width, height) - TODO confirm against scaled_dimensions()
+        ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != (input->dimension(2) * kernel_dims.area() + (has_bias ? 1 : 0))); // input dimension 2 times the kernel area, plus one extra element when has_bias is set
+        ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != (out_dims.first * out_dims.second)); // product of the two values returned by scaled_dimensions()
+        ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(2) != 1);
+    }
+    else /* Called by FullyConnectedLayer: the whole output shape is compared with the shape computed by compute_im2col_shape() */
+    {
+        const int num_batch_dimensions = std::max(0, static_cast<int>(output->tensor_shape().num_dimensions()) - 1); // every output dimension after the first is counted as a batch dimension; clamped at 0
+        const int num_input_dimensions = input->tensor_shape().num_dimensions() - num_batch_dimensions; // NOTE(review): no guard here against num_batch_dimensions exceeding the input's dimension count - confirm callers prevent this
+
+        TensorInfo expected_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_im2col_shape(input, num_input_dimensions)); // copy of the output info with only the tensor shape replaced
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output, output);
+    }
 
     return Status{};
 }
@@ -291,12 +313,15 @@
 {
 }
 
-void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
+                               bool has_bias, bool is_fully_connected, bool is_flatten)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
 
     // Perform validation step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias));
+    ARM_COMPUTE_UNUSED(is_fully_connected);
+    ARM_COMPUTE_UNUSED(is_flatten);
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten));
 
     _input          = input;
     _output         = output;
@@ -382,9 +407,10 @@
     IKernel::configure(window);
 }
 
-Status NEIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+Status NEIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
+                                bool has_bias, bool is_fully_connected, bool is_flatten) // takes ITensorInfo descriptors, whereas configure() takes ITensor objects
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten)); // all checks are delegated to validate_arguments() in the anonymous namespace, with every argument forwarded unchanged
     return Status{};
 }