COMPMID-1386: Add FC convert weights on NEON

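A fully connected layer that follows a convolution may receive weights
trained in a different data layout than the one used at runtime (for
example NCHW-trained weights executed on NHWC input). This patch wires
NEConvertFullyConnectedWeights into NEFullyConnectedLayer so that such
weights are re-ordered once, during prepare(), before they feed the GEMM.

Minimal usage sketch (input, weights, biases and output are assumed to be
already-configured tensors; the conversion path is taken when
weights_trained_layout differs from the input's layout):

    NEFullyConnectedLayer fc;
    FullyConnectedLayerInfo fc_info;
    fc_info.weights_trained_layout = DataLayout::NCHW; // layout the weights were trained in
    fc.configure(&input, &weights, &biases, &output, fc_info);
    fc.run(); // first run reshapes/converts the weights, then caches them
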
Change-Id: I7a3c6db9285e3899494f496b2562d80cec1b6521
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141407
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 9d3cb31..34cabb5 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -74,8 +74,9 @@
 }
 
 NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _im2col_kernel(), _reshape_weights_function(), _mm_gemm(), _mm_gemmlowp(), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _im2col_output(),
-      _gemmlowp_output(), _reshape_weights_output(), _original_weights(nullptr), _are_weights_reshaped(false), _is_fc_after_conv(false), _accumulate_biases(false), _is_quantized(false), _is_prepared(false)
+    : _memory_group(std::move(memory_manager)), _im2col_kernel(), _convert_weights(), _reshape_weights_function(), _mm_gemm(), _mm_gemmlowp(), _gemmlowp_output_stage(), _accumulate_biases_kernel(),
+      _im2col_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _original_weights(nullptr), _are_weights_converted(true), _are_weights_reshaped(false),
+      _is_fc_after_conv(false), _accumulate_biases(false), _is_quantized(false), _is_prepared(false)
 {
 }
 
@@ -146,11 +147,12 @@
                                                                output->info(),
                                                                fc_info));
 
-    _are_weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
-    _is_fc_after_conv     = true;
-    _accumulate_biases    = false;
-    _is_quantized         = is_data_type_quantized_asymmetric(input->info()->data_type());
-    _original_weights     = weights;
+    _are_weights_converted = true;
+    _are_weights_reshaped  = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
+    _is_fc_after_conv      = true;
+    _accumulate_biases     = false;
+    _is_quantized          = is_data_type_quantized_asymmetric(input->info()->data_type());
+    _original_weights      = weights;
 
     // Configure gemmlowp output
     if(_is_quantized)
@@ -175,17 +177,8 @@
 
     const ITensor *weights_to_use = weights;
 
-    if(!_are_weights_reshaped)
-    {
-        weights_to_use = &_reshape_weights_output;
-
-        // Reshape the weights
-        _reshape_weights_function.configure(weights, &_reshape_weights_output);
-    }
-
     // Check if we have a fully connected layer with batches
     const bool is_batched_fc_layer = output->info()->dimension(1) > 1;
-
     if(is_batched_fc_layer)
     {
         _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->info()->tensor_shape().cbegin() + 3,
@@ -197,6 +190,28 @@
         _is_fc_after_conv = input->info()->num_dimensions() > 1;
     }
 
+    // Reshape weights if needed
+    if(!_are_weights_reshaped)
+    {
+        // Reshape the weights
+        _reshape_weights_function.configure(weights, &_reshape_weights_output);
+        weights_to_use = &_reshape_weights_output;
+    }
+
+    // Convert weights if needed
+    if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
+    {
+        // Re-order the weights from the layout they were trained in to the runtime layout
+        _convert_weights.configure(weights_to_use,
+                                   &_converted_weights_output,
+                                   input->info()->tensor_shape(),
+                                   fc_info.weights_trained_layout);
+
+        weights_to_use         = &_converted_weights_output;
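+        // The data shuffling itself happens once, in prepare(); flag it as pending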
+        _are_weights_converted = false;
+    }
+
     ITensor *tmp_output = (_is_quantized) ? &_gemmlowp_output : output;
     if(_is_fc_after_conv)
     {
@@ -235,9 +249,10 @@
     bool is_fc_after_conv = true;
     bool is_quantized     = is_data_type_quantized_asymmetric(input->data_type());
 
-    const ITensorInfo &im2col_input     = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input)));
-    const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
-    const ITensorInfo &gemmlowp_output  = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
+    const ITensorInfo &im2col_input      = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input)));
+    const ITensorInfo &reshaped_weights  = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
+    const ITensorInfo &converted_weights = TensorInfo(reshaped_weights.clone()->set_is_resizable(true).reset_padding());
+    const ITensorInfo &gemmlowp_output   = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
 
     // Configure accumulate biases kernel for non quantized asymmetric types
     if(biases != nullptr && !is_quantized)
@@ -256,13 +271,6 @@
     const ITensorInfo *weights_to_use = weights;
     const ITensorInfo *tmp_output     = (is_quantized) ? &gemmlowp_output : output;
 
-    if(!weights_reshaped)
-    {
-        // Validate reshape weights kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights));
-        weights_to_use = &reshaped_weights;
-    }
-
     // Check if we have a fully connected layer with batches
     const bool is_batched_fc_layer = output->dimension(1) > 1;
 
@@ -277,6 +285,23 @@
         is_fc_after_conv = input->num_dimensions() > 1;
     }
 
+    if(!weights_reshaped)
+    {
+        // Validate reshape weights kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights));
+        weights_to_use = &reshaped_weights;
+    }
+
+    if(is_fc_after_conv && (input->data_layout() != fc_info.weights_trained_layout))
+    {
+        // Validate convert weights kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(NEConvertFullyConnectedWeights::validate(weights_to_use,
+                                                                             &converted_weights,
+                                                                             input->tensor_shape(),
+                                                                             fc_info.weights_trained_layout));
+        weights_to_use = &converted_weights;
+    }
+
     if(is_fc_after_conv)
     {
         // Fully Connected layer after a Convolution Layer without batches
@@ -345,29 +370,55 @@
 {
     if(!_is_prepared)
     {
+        ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+
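+        // Frees a tensor's backing memory once nothing references it any longer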
+        auto release_unused = [](Tensor * w)
+        {
+            if(!w->is_used())
+            {
+                w->allocator()->free();
+            }
+        };
+
+        // Pointer to current weights
+        const ITensor *cur_weights = _original_weights;
+
         // Reshape of the weights (happens only once)
         if(!_are_weights_reshaped)
         {
-            ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
             // Run reshape weights kernel and mark weights as unused
             _reshape_weights_output.allocator()->allocate();
             _reshape_weights_function.run();
-            _original_weights->mark_as_unused();
 
-            // Prepare GEMM prepare and release unused weights
-            if(!_is_quantized)
-            {
-                _mm_gemm.prepare();
-                if(!_reshape_weights_output.is_used())
-                {
-                    _reshape_weights_output.allocator()->free();
-                }
-            }
-
+            cur_weights->mark_as_unused();
+            cur_weights           = &_reshape_weights_output;
             _are_weights_reshaped = true;
         }
 
+        // Convert weights if needed (happens only once)
+        if(!_are_weights_converted)
+        {
+            _converted_weights_output.allocator()->allocate();
+            _convert_weights.run();
+
+            cur_weights->mark_as_unused();
+            _are_weights_converted = true;
+        }
+
+        // Release reshaped weights if unused
+        release_unused(&_reshape_weights_output);
+
+        // Prepare GEMM and release unused weights
+        if(!_is_quantized)
+        {
+            _mm_gemm.prepare();
+        }
+
+        // Release reshaped and converted weights if unused
+        release_unused(&_reshape_weights_output);
+        release_unused(&_converted_weights_output);
+
         _is_prepared = true;
     }
 }
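
For reference, a conceptual sketch of the re-ordering performed by
NEConvertFullyConnectedWeights (a simplification, not the actual kernel):
assuming NCHW-trained weights executed on NHWC input, and that each row of
the reshaped weight matrix corresponds to one position of the preceding
convolution's output volume of shape C x H x W, the rows must be permuted
because the two layouts flatten that volume in different orders.

    #include <algorithm>

    // Conceptual only: permute weight-matrix rows from NCHW to NHWC
    // flattening order. C, H, W describe the conv output volume and
    // row_len is the number of weights per row.
    void permute_fc_weight_rows(const float *src, float *dst,
                                unsigned int C, unsigned int H,
                                unsigned int W, unsigned int row_len)
    {
        for(unsigned int c = 0; c < C; ++c)
            for(unsigned int h = 0; h < H; ++h)
                for(unsigned int w = 0; w < W; ++w)
                {
                    const unsigned int nchw_row = (c * H + h) * W + w; // W fastest, then H, then C
                    const unsigned int nhwc_row = (h * W + w) * C + c; // C fastest, then W, then H
                    std::copy_n(src + nchw_row * row_len, row_len,
                                dst + nhwc_row * row_len);
                }
    }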