Add support for dynamic weights in the CPU fully connected layer

Resolves: COMPMID-5917
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I073067b490f2a1b96b81a037ea431c9a2e5c7503
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9322
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 1e1598a..af63015 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -166,7 +166,8 @@
       _is_prepared(false),
       _enable_fast_math(false),
       _fixed_format(false),
-      _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
+      _weight_format(arm_compute::WeightFormat::UNSPECIFIED),
+      _dynamic_weights(false)
 {
 }
 
@@ -189,7 +190,7 @@
         const Status            status = get_gemmlowp_output_stage_info(&src_info, &weights_info, dst, act, gemmlowp_output_stage_info);
         ARM_COMPUTE_ERROR_ON(status.error_code() != ErrorCode::OK);
 
-        GEMMInfo gemm_info;
+        GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */);
         gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info);
         gemm_info.set_activation_info(act);
         gemm_info.set_fast_math(_enable_fast_math);
@@ -199,7 +200,7 @@
     else
     {
         // Configure matrix multiply kernel
-        GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+        GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */);
         gemm_info.set_activation_info(act);
         gemm_info.set_fast_math(_enable_fast_math);
         gemm_info.set_fixed_format(_fixed_format);
@@ -256,6 +257,7 @@
     _enable_fast_math         = fc_info.enable_fast_math;
     _fixed_format             = weights_info.weight_format() != WeightFormat::UNSPECIFIED;
     _weight_format            = weights_info.weight_format();
+    _dynamic_weights          = !weights->are_values_constant() && _needs_weights_reshape;
 
     // With the Fully Connected layer we can have 4 different cases:
     //  1) Convolution layer -> Fully Connected layer without batches
@@ -329,15 +331,32 @@
     {
         // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
         // Do not release them if biases are dynamic and data type is quantized, since the weights tensor will be used for biases offset calculation
-        _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric && biases
-                                                                                     && !(biases->are_values_constant())) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare,
-                                                 _reshaped_weights.total_size());
-        _aux_mem[ConvertedWeights]  = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
+        // With dynamic weights, keep all the auxiliary tensors available at run time, because they are recomputed on every run.
+        _aux_mem[TransposedWeights] = MemoryInfo(
+            offset_int_vec(TransposedWeights),
+            _dynamic_weights                                                         ? MemoryLifetime::Temporary  :
+            (_is_quantized_asymmetric && biases && !(biases->are_values_constant())) ? MemoryLifetime::Persistent :
+                                                                                       MemoryLifetime::Prepare,
+            _reshaped_weights.total_size());
+
+        _aux_mem[ConvertedWeights] = MemoryInfo(
+            offset_int_vec(ConvertedWeights),
+            _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
+            _converted_weights.total_size());
     }
     else
     {
-        _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), _needs_weights_conversion ? MemoryLifetime::Prepare : MemoryLifetime::Persistent, _reshaped_weights.total_size());
-        _aux_mem[ConvertedWeights]  = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Persistent, _converted_weights.total_size());
+        _aux_mem[TransposedWeights] = MemoryInfo(
+            offset_int_vec(TransposedWeights),
+            _dynamic_weights          ? MemoryLifetime::Temporary :
+            _needs_weights_conversion ? MemoryLifetime::Prepare   :
+                                        MemoryLifetime::Persistent,
+            _reshaped_weights.total_size());
+
+        _aux_mem[ConvertedWeights] = MemoryInfo(
+            offset_int_vec(ConvertedWeights),
+            _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Persistent,
+            _converted_weights.total_size());
     }
     _aux_mem[FlattenedSrc] = MemoryInfo(offset_int_vec(FlattenedSrc), MemoryLifetime::Temporary, _flattened_src.total_size());
 }
@@ -375,7 +394,6 @@
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
     ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
-    ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights));
 
     bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
     bool is_fc_after_conv = true;
@@ -459,6 +477,11 @@
 {
     prepare(tensors);
 
+#ifdef ARM_COMPUTE_ASSERTS_ENABLED
+    ++_asrt_run_count;
+    ARM_COMPUTE_ERROR_ON(_dynamic_weights && _asrt_prepare_count != _asrt_run_count);
+#endif // ARM_COMPUTE_ASSERTS_ENABLED
+
     auto src = tensors.get_const_tensor(ACL_SRC_0);
 
     CpuAuxTensorHandler flattened_src(offset_int_vec(FlattenedSrc), _flattened_src, tensors, false);
@@ -491,8 +514,13 @@
 
 void CpuFullyConnected::prepare(ITensorPack &tensors)
 {
-    if(!_is_prepared)
+    if(!_is_prepared || _dynamic_weights)
     {
+#ifdef ARM_COMPUTE_ASSERTS_ENABLED
+        ++_asrt_prepare_count;
+        ARM_COMPUTE_ERROR_ON(!_dynamic_weights && _asrt_prepare_count > 1);
+#endif // ARM_COMPUTE_ASSERTS_ENABLED
+
         auto weights = tensors.get_const_tensor(ACL_SRC_1);
 
         CpuAuxTensorHandler reshaped_weights(offset_int_vec(TransposedWeights), _reshaped_weights, tensors, false);