COMPMID-1051 - Fix validate method in NEGEMMConvolutionLayer

Change-Id: I10e8e1267a09246cac77e677f1c087bb1d80a61b
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127517
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index c339947..7f25c2e 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -466,55 +466,6 @@
         optimised_kernel = true;
     }
 
-    // Reshape weights if needed
-    if(optimised_kernel)
-    {
-        if(are_weights_reshaped)
-        {
-            mat_weights_cols = weights_info.num_kernels();
-            mat_weights_rows = weights->dimension(1);
-        }
-        else
-        {
-            TensorShape reshaped_weights_shape{ mat_weights_cols, mat_weights_rows };
-
-            // Create tensor to store the reshaped weights
-            reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution));
-            ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */));
-            weights = reshaped_weights.get();
-        }
-    }
-    else
-    {
-        if(are_weights_reshaped)
-        {
-            const unsigned int transpose_width = 16 / input->element_size();
-            mat_weights_cols                   = weights_info.num_kernels();
-            mat_weights_rows                   = weights->dimension(0) / transpose_width + (append_bias ? 1 : 0);
-        }
-        else
-        {
-            TensorShape reshaped_weights_shape;
-
-            if(is_fully_connected_convolution || is_quantized)
-            {
-                reshaped_weights_shape = TensorShape{ mat_weights_cols, mat_weights_rows };
-            }
-            else
-            {
-                // Create tensor to store transposed weights
-                const float transpose_width = 16.0f / input->element_size();
-                reshaped_weights_shape      = TensorShape{ mat_weights_rows *static_cast<unsigned int>(transpose_width),
-                                                           static_cast<unsigned int>(std::ceil(mat_weights_cols / transpose_width)) };
-            }
-
-            // Create tensor to store the reshaped weights
-            reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution));
-            ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */));
-            weights = reshaped_weights.get();
-        }
-    }
-
     // Validate im2col
     const unsigned int mat_input_cols = mat_weights_rows;
     const unsigned int mat_input_rows = conv_w * conv_h;
@@ -531,19 +482,52 @@
     shape_gemm.set(1, mat_input_rows);
     TensorInfo gemm_output_info = input->clone()->set_tensor_shape(shape_gemm);
 
-    // Validate GEMM interleave and multiply
-    if(is_interleaved)
+    // Reshape weights if needed
+    if(optimised_kernel)
     {
-        TensorShape shape_interleaved = shape_im2col;
-        shape_interleaved.set(0, shape_interleaved.x() * 4);
-        shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f));
-        TensorInfo input_interleaved_info = input->clone()->set_tensor_shape(shape_interleaved);
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(&im2_col_info, &input_interleaved_info));
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&input_interleaved_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo()));
+        ARM_COMPUTE_RETURN_ERROR_ON(are_weights_reshaped);
+
+        // Create tensor to store the reshaped weights
+        reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution));
+        ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */));
     }
     else
     {
-        ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&im2_col_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo()));
+        TensorShape reshaped_weights_shape;
+
+        if(is_fully_connected_convolution || is_quantized)
+        {
+            reshaped_weights_shape = TensorShape{ mat_weights_cols, mat_weights_rows };
+        }
+        else
+        {
+            // Create tensor to store transposed weights
+            const float transpose_width = 16.0f / input->element_size();
+            reshaped_weights_shape      = TensorShape{ mat_weights_rows *static_cast<unsigned int>(transpose_width),
+                                                       static_cast<unsigned int>(std::ceil(mat_weights_cols / transpose_width)) };
+        }
+
+        // Create tensor to store the reshaped weights
+        reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution));
+        ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */));
+        weights = reshaped_weights.get();
+
+        // Validate GEMM interleave and multiply
+        if(is_interleaved)
+        {
+            TensorShape shape_interleaved = shape_im2col;
+            shape_interleaved.set(0, shape_interleaved.x() * 4);
+            shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f));
+            TensorInfo input_interleaved_info = input->clone()->set_tensor_shape(shape_interleaved);
+            ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(&im2_col_info, &input_interleaved_info));
+            ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&input_interleaved_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo(shape_im2col[1],            // m
+                                                                             weights->tensor_shape()[0], // n
+                                                                             shape_im2col[0]) /* k */));
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&im2_col_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo()));
+        }
     }
 
     ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h)));