COMPMID-1386: Add support for converting weights for CL.

Change-Id: I62e3ead903366baeeb1488f233a9b8b0c388c9de
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140403
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
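
The patch replaces the three loose booleans on configure()/validate() with a single
FullyConnectedLayerInfo. A minimal caller sketch (tensor names and shapes are
illustrative, not part of this patch; the field names are the ones the patch reads):

    using namespace arm_compute;

    CLTensor src, weights, bias, dst;
    // ... init the tensor infos; src runs in NHWC, weights come from an
    //     NCHW-trained model ...

    FullyConnectedLayerInfo fc_info;
    fc_info.weights_trained_layout = DataLayout::NCHW; // layout the weights were trained in
    fc_info.transpose_weights      = true;             // weights still need the transpose pass
    fc_info.are_weights_reshaped   = false;

    CLFullyConnectedLayer fc;
    fc.configure(&src, &weights, &bias, &dst, fc_info);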
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 273ef96..ccd7813 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -73,8 +73,9 @@
 }
 
 CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _im2col_kernel(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(),
-      _im2col_output(), _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _original_weights(nullptr)
+    : _memory_group(memory_manager), _im2col_kernel(), _convert_weights(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
+      _accumulate_biases_kernel(), _im2col_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
+      _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
 {
 }
 
@@ -112,7 +113,7 @@
 
     // Initialize output tensor for im2col
     TensorShape shape_im2col = compute_im2col_fc_shape(input->info());
-    _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
+    _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col).set_data_layout(DataLayout::NCHW));
 
     // Configure im2col kernel
     _memory_group.manage(&_im2col_output);
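
Pinning the im2col output to NCHW keeps the flattened 2D matrix layout-agnostic for
the GEMM that follows; only the weights need to know the trained layout. A worked
shape example (sizes assumed), using the same shape helper the file already calls:

    // e.g. input [7, 7, 512, 2] collapses to [7*7*512, 2] = [25088, 2]
    const TensorShape shape_im2col = compute_im2col_fc_shape(input->info());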
@@ -134,8 +135,8 @@
     configure_mm(input, weights, output);
 }
 
-void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights, bool are_weights_reshaped,
-                                      bool retain_internal_weights)
+void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output,
+                                      FullyConnectedLayerInfo fc_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
 
@@ -144,15 +145,15 @@
                                                                weights->info(),
                                                                biases != nullptr ? biases->info() : nullptr,
                                                                output->info(),
-                                                               transpose_weights,
-                                                               are_weights_reshaped,
-                                                               retain_internal_weights));
+                                                               fc_info));
 
-    _are_weights_reshaped = transpose_weights ? are_weights_reshaped : true;
-    _is_fc_after_conv     = true;
-    _accumulate_biases    = false;
-    _is_quantized         = is_data_type_quantized_asymmetric(input->info()->data_type());
-    _original_weights     = weights;
+    _are_weights_converted = true;
+    _are_weights_reshaped  = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
+    _is_fc_after_conv      = true;
+    _accumulate_biases     = false;
+    _is_quantized          = is_data_type_quantized_asymmetric(input->info()->data_type());
+    _is_prepared           = false;
+    _original_weights      = weights;
 
     // Configure gemmlowp output
     if(_is_quantized)
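
The ternary above folds the two caller flags into one internal flag: a reshape pass is
scheduled only when the caller asks for a transpose and the weights are not already
reshaped. Equivalently (a restatement for clarity, not code from the patch):

    const bool needs_reshape = fc_info.transpose_weights && !fc_info.are_weights_reshaped;
    _are_weights_reshaped    = !needs_reshape;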
@@ -172,25 +173,16 @@
         _accumulate_biases_kernel.configure(output, biases);
     }
 
+    const ICLTensor *weights_to_use = weights;
+
     // With the Fully Connected layer we can have 4 different cases:
     //  1) Convolution layer -> Fully Connected layer without batches
     //  2) Fully Connected layer -> Fully Connected layer without batches
     //  3) Convolution layer -> Fully Connected layer with batches
     //  4) Fully Connected layer -> Fully Connected layer with batches
 
-    const ICLTensor *weights_to_use = weights;
-
-    if(!_are_weights_reshaped)
-    {
-        weights_to_use = &_reshape_weights_output;
-
-        // Reshape the weights
-        _reshape_weights_kernel.configure(weights, &_reshape_weights_output);
-    }
-
     // Check if we have a fully connected layer with batches
     const bool is_batched_fc_layer = output->info()->dimension(1) > 1;
-
     if(is_batched_fc_layer)
     {
         _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->info()->tensor_shape().cbegin() + 3,
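
Illustrative shapes for the four cases (assumed, not from the patch). The batched
cases are detected via output dimension 1, the conv-to-FC cases via the extra input
dimensions:

    //  1) conv -> FC, no batches : input [7, 7, 512]       output [4096]
    //  2) FC   -> FC, no batches : input [25088]           output [4096]
    //  3) conv -> FC, batches    : input [7, 7, 512, 8]    output [4096, 8]
    //  4) FC   -> FC, batches    : input [25088, 8]        output [4096, 8]
    const bool is_batched_fc_layer = output->info()->dimension(1) > 1; // true for 3) and 4)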
@@ -202,6 +194,28 @@
         _is_fc_after_conv = input->info()->num_dimensions() > 1;
     }
 
+    // Reshape weights if needed
+    if(!_are_weights_reshaped)
+    {
+        // Reshape the weights
+        _reshape_weights_kernel.configure(weights, &_reshape_weights_output);
+        weights_to_use = &_reshape_weights_output;
+    }
+
+    // Convert weights if needed
+    if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
+    {
+        // Convert weights
+        _convert_weights.configure(weights_to_use,
+                                   &_converted_weights_output,
+                                   input->info()->tensor_shape(),
+                                   fc_info.weights_trained_layout);
+
+        weights_to_use         = &_converted_weights_output;
+        _are_weights_converted = false;
+    }
+
+    // Configure fc core
     ICLTensor *tmp_output = (_is_quantized) ? &_gemmlowp_output : output;
     if(_is_fc_after_conv)
     {
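
NCHW and NHWC flatten a conv feature map in different element orders, so weights
trained against one layout must have their rows permuted before the GEMM when the
runtime layout differs. A standalone sketch of the same function (tensor names and
shape assumed; argument order as used in the hunk above):

    CLConvertFullyConnectedWeights convert;
    convert.configure(&trained_weights,           // 2D FC weights, already reshaped
                      &converted_weights,         // destination tensor
                      TensorShape(7U, 7U, 512U),  // un-flattened shape of the FC input
                      DataLayout::NCHW);          // layout the weights were trained in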
@@ -224,26 +238,26 @@
         _gemmlowp_output.allocator()->allocate();
     }
 
-    _are_weights_reshaped = _are_weights_reshaped || retain_internal_weights;
+    _are_weights_reshaped = _are_weights_reshaped || fc_info.retain_internal_weights;
 }
 
-Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights, bool are_weights_reshaped,
-                                       bool retain_internal_weights)
+Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+                                       FullyConnectedLayerInfo fc_info)
 {
-    ARM_COMPUTE_UNUSED(retain_internal_weights);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
 
-    bool            weights_reshaped = transpose_weights ? are_weights_reshaped : true;
+    bool            weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
     bool            is_fc_after_conv = true;
     bool            is_quantized     = is_data_type_quantized_asymmetric(input->data_type());
     const GPUTarget gpu_target       = CLScheduler::get().target();
 
-    const ITensorInfo &im2col_input     = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input)));
-    const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
-    const ITensorInfo &gemmlowp_output  = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
+    const ITensorInfo &im2col_input      = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input)).set_data_layout(DataLayout::NCHW));
+    const ITensorInfo &reshaped_weights  = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
+    const ITensorInfo &converted_weights = TensorInfo(reshaped_weights.clone()->set_is_resizable(true).reset_padding());
+    const ITensorInfo &gemmlowp_output   = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
 
     // Configure accumulate biases kernel for non quantized asymmetric types
     if(biases != nullptr && !is_quantized)
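
The validate() overload mirrors configure() and can be called up front without
allocating anything, following the usual library pattern (a sketch; the tensors are
assumed to carry initialized infos):

    const Status s = CLFullyConnectedLayer::validate(src.info(), weights.info(),
                                                     bias.info(), dst.info(), fc_info);
    ARM_COMPUTE_ERROR_THROW_ON(s);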
@@ -262,16 +276,8 @@
     const ITensorInfo *weights_to_use = weights;
     const ITensorInfo *tmp_output     = (is_quantized) ? &gemmlowp_output : output;
 
-    if(!weights_reshaped)
-    {
-        // Validate reshape weights kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights));
-        weights_to_use = &reshaped_weights;
-    }
-
     // Check if we have a fully connected layer with batches
     const bool is_batched_fc_layer = output->dimension(1) > 1;
-
     if(is_batched_fc_layer)
     {
         is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->tensor_shape().cbegin() + 3,
@@ -283,6 +289,23 @@
         is_fc_after_conv = input->num_dimensions() > 1;
     }
 
+    if(!weights_reshaped)
+    {
+        // Validate reshape weights kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights));
+        weights_to_use = &reshaped_weights;
+    }
+
+    if(is_fc_after_conv && (input->data_layout() != fc_info.weights_trained_layout))
+    {
+        // Validate convert weights kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(CLConvertFullyConnectedWeights::validate(weights_to_use,
+                                                                             &converted_weights,
+                                                                             input->tensor_shape(),
+                                                                             fc_info.weights_trained_layout));
+        weights_to_use = &converted_weights;
+    }
+
     if(is_fc_after_conv)
     {
         // Fully Connected layer after a Convolution Layer without batches
@@ -349,27 +372,57 @@
 
 void CLFullyConnectedLayer::prepare()
 {
-    // Reshape of the weights (happens only once)
-    if(!_are_weights_reshaped)
+    if(!_is_prepared)
     {
         ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
 
-        // Run reshape weights kernel and mark weights as unused
-        _reshape_weights_output.allocator()->allocate();
-        _reshape_weights_kernel.run();
-        _original_weights->mark_as_unused();
+        auto release_unused = [](CLTensor * w)
+        {
+            if(!w->is_used())
+            {
+                CLScheduler::get().queue().finish();
+                w->allocator()->free();
+            }
+        };
+
+        // Pointer to current weights
+        const ICLTensor *cur_weights = _original_weights;
+
+        // Reshape of the weights if needed (happens only once)
+        if(!_are_weights_reshaped)
+        {
+            // Run reshape weights kernel and mark weights as unused
+            _reshape_weights_output.allocator()->allocate();
+            _reshape_weights_kernel.run();
+
+            cur_weights->mark_as_unused();
+            cur_weights           = &_reshape_weights_output;
+            _are_weights_reshaped = true;
+        }
+
+        // Convert weights if needed (happens only once)
+        if(!_are_weights_converted)
+        {
+            _converted_weights_output.allocator()->allocate();
+            _convert_weights.run();
+
+            cur_weights->mark_as_unused();
+            _are_weights_converted = true;
+        }
+
+        // Release reshaped weights if unused
+        release_unused(&_reshape_weights_output);
 
         // Prepare GEMM and release unused weights
         if(!_is_quantized)
         {
             _mm_gemm.prepare();
-            if(!_reshape_weights_output.is_used())
-            {
-                _reshape_weights_output.allocator()->free();
-            }
         }
 
-        CLScheduler::get().queue().finish();
-        _are_weights_reshaped = true;
+        // Release reshaped and converted weights if marked unused by GEMM prepare
+        release_unused(&_reshape_weights_output);
+        release_unused(&_converted_weights_output);
+
+        _is_prepared = true;
     }
 }
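
With the new _is_prepared guard, the one-off reshape and convert passes run at most
once and their temporaries are freed as soon as nothing references them. A sketch of
the lifecycle from the caller's side (run() triggers prepare() on first use, so the
explicit call below only releases the original weights earlier):

    fc.configure(&src, &weights, &bias, &dst, fc_info);
    src.allocator()->allocate();
    weights.allocator()->allocate();
    dst.allocator()->allocate();
    // ... upload input data and trained weights ...
    fc.prepare();              // reshape + convert once, free unused intermediates
    fc.run();
    CLScheduler::get().sync();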