Add support for non-constant weights and biases in CpuFullyConnected

Change the approach for specifying that weights and biases tensors are
non-constant: this is now a property queried from TensorInfo
(are_values_constant()) rather than an option passed to the functions
(e.g. FullyConnectedLayerInfo::constant_weights).
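
A minimal sketch of the intended caller-side usage (assuming the
set_are_values_constant() setter that pairs with the
are_values_constant() getter used below; tensor shapes are
illustrative):

    // Mark weights/biases whose values may change between runs
    TensorInfo weights_info(TensorShape(64U, 128U), 1, DataType::F32);
    TensorInfo biases_info(TensorShape(64U), 1, DataType::F32);
    weights_info.set_are_values_constant(false);
    biases_info.set_are_values_constant(false);

    // Per the new validate() check, non-constant weights must already
    // be in the reshaped layout, with no transpose requested
    FullyConnectedLayerInfo fc_info{};
    fc_info.transpose_weights    = false;
    fc_info.are_weights_reshaped = true;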

Resolves: COMPMID-4222, COMPMID-4811

Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Change-Id: I9b0081ccbcf8271ce029ba6755563d64c59e1d32
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6313
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 378f104..ece9ca5 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -212,6 +212,9 @@
 
     /*** "Pretransposed" interface ***/
 
+    /* Compute the column sums of B, used to requantize the bias when only the bias values change between runs */
+    /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
+    virtual void requantize_bias(void *, const To *, const int, const int) {};
+
     /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
     /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
     virtual void pretranspose_B_array(void *, const To *, const int, const int) {};
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 4133d9e..03c53b0 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -314,9 +314,14 @@
 
     if(_aux_mem[Pretranspose].size > 0)
     {
-        // Release permuted weights at the of prepare as they are further transposed by the assembly dispatch
-        _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), MemoryLifetime::Prepare, _reshaped_weights.total_size());
-        _aux_mem[ConvertedWeights]  = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
+        // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+        // Do not release them if the biases are dynamic and the data type is quantized, since the transposed weights are needed to recompute the bias offsets on every run
+        _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights),
+                                                 (_is_quantized_asymmetric && biases && !(biases->are_values_constant())) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare,
+                                                 _reshaped_weights.total_size());
+        _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
     }
     else
     {
@@ -334,10 +339,9 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst);
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(biases != nullptr && biases->num_dimensions() > 1);
     ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!fc_info.constant_weights, "Non-constant weights are currently not supported");
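+    // Non-constant weights are only supported when they are already reshaped and no transpose is requested, as those steps would otherwise be cached during prepare()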
+    ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights));
 
     bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
     bool is_fc_after_conv = true;
@@ -358,6 +362,19 @@
     // Check if we have a fully connected layer with batches
     const bool is_batched_fc_layer = dst->dimension(1) > 1;
 
+    if(biases != nullptr)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+        if(is_data_type_quantized(src->data_type()))
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases);
+        }
+    }
+
     if(is_batched_fc_layer)
     {
         is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3,
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 97893b0..23095d8 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -206,6 +206,9 @@
     std::vector<TypeInput>           _indirect_pad{};
     arm_gemm::ConvolutionParameters  _cp{};
     experimental::MemoryRequirements _aux_mem{ Count };
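+    /* Whether B requires pretransposing, and whether B (weights) and C (bias) hold constant values; non-constant tensors must be re-processed on every run */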
+    bool                             _B_pretranspose_required{ false };
+    bool                             _is_b_constant{ true };
+    bool                             _is_c_constant{ true };
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -348,6 +351,10 @@
                                                              const OutputStage &os)
 {
     ARM_COMPUTE_UNUSED(c);
+
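+    // Record the constness of B (weights) and C (bias) so that run() knows whether prepare()'s results can be reused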
+    _is_b_constant = b->are_values_constant();
+    _is_c_constant = c ? c->are_values_constant() : true;
+
     arm_gemm::GemmConfig gemm_cfg;
     _kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
     if(_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
@@ -391,6 +398,7 @@
         const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
         _pretranspose_info                     = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
         _aux_mem[Pretranspose]                 = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
+        _B_pretranspose_required               = true;
     }
 
     // Handle indirect GEMM convolution
@@ -485,6 +493,35 @@
         in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
     }
 
+    // If the weights (B) or a quantized bias (C) are non-constant, prepare()'s results cannot be reused: re-run the bias/pretranspose processing on every execution
+    if((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
+    {
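+        // Refresh the quantized bias pointer so the assembly kernel picks up the latest bias values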
+        if(c && c->info()->data_type() == DataType::S32)
+        {
+            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
+        }
+
+        // Pretranspose B if required
+        if(_B_pretranspose_required)
+        {
+            const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+            const auto b_ptr          = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+            const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+
+            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
+            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
+
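+            // If B is constant it was already pretransposed in prepare(), so only the bias column sums need recomputing; otherwise redo the full pretranspose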
+            if(_is_b_constant)
+            {
+                _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+            }
+            else
+            {
+                _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+            }
+        }
+    }
+
     const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());
 
     // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads