Add support for non-constant weights and biases in CpuFullyConnected

Change the approach for specifying that weights and bias tensors are
non-constant: the flag is now a property of ITensorInfo rather than an
option of the functions (FullyConnectedLayerInfo::constant_weights and
GEMMInfo::constant_weights are removed).
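
A minimal usage sketch (hypothetical tensor/function names, modelled on
the test fixture added below): non-constant weights are flagged on the
TensorInfo and must be handed to the function already reshaped and
non-transposed.

    // Mark the weights as non-constant; their values may change between runs.
    TensorInfo wei_info(tr_weights_shape, 1, DataType::F32);
    wei_info.set_are_values_constant(false);
    weights.allocator()->init(wei_info);

    FullyConnectedLayerInfo fc_info;
    fc_info.activation_info      = act_info;
    fc_info.are_weights_reshaped = true;  // caller supplies reshaped weights
    fc_info.transpose_weights    = false; // the function must not transpose them
    NEFullyConnectedLayer fc;
    fc.configure(&src, &weights, &bias, &dst, fc_info);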

Resolves: COMPMID-4222

Change-Id: I96e6f3868f51785c9700a3ef6a1fe7b05747862c
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6162
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h
index 0171e31..bc3a6be 100644
--- a/arm_compute/core/ITensorInfo.h
+++ b/arm_compute/core/ITensorInfo.h
@@ -240,6 +240,11 @@
      * @return True if its dynamic else false
      */
     virtual bool is_dynamic() const = 0;
+    /** Flag indicating whether the values of the tensor are constant, meaning that they cannot change on kernel/function execution.
+     *
+     * @return True if values are constant else false
+     */
+    virtual bool are_values_constant() const = 0;
     /** Set the flag whether the tensor size can be changed.
      *
      * @param[in] is_resizable Flag that marks the tensor if it can be changed or not.
@@ -247,6 +252,13 @@
      * @return Reference to this ITensorInfo object
      */
     virtual ITensorInfo &set_is_resizable(bool is_resizable) = 0;
+    /** Set the flag indicating whether the tensor values are constant, i.e. cannot change during kernel/function execution.
+     *
+     * @param[in] are_values_constant True if the tensor values are constant, false if they can change between executions.
+     *
+     * @return Reference to this ITensorInfo object
+     */
+    virtual ITensorInfo &set_are_values_constant(bool are_values_constant) = 0;
     /** Valid region of the tensor. All elements in the valid region have defined values, i.e. are not undefined.
      *
      * @return The valid region.
diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h
index 1b2278d..54836d0 100644
--- a/arm_compute/core/SubTensorInfo.h
+++ b/arm_compute/core/SubTensorInfo.h
@@ -196,12 +196,23 @@
         ARM_COMPUTE_ERROR_ON(_parent == nullptr);
         return _parent->is_dynamic();
     }
+    bool are_values_constant() const override
+    {
+        ARM_COMPUTE_ERROR_ON(_parent == nullptr);
+        return _parent->are_values_constant();
+    }
     ITensorInfo &set_is_resizable(bool is_resizable) override
     {
         ARM_COMPUTE_ERROR_ON(_parent == nullptr);
         _parent->set_is_resizable(is_resizable);
         return *this;
     }
+    ITensorInfo &set_are_values_constant(bool are_values_constant) override
+    {
+        ARM_COMPUTE_ERROR_ON(_parent == nullptr);
+        _parent->set_are_values_constant(are_values_constant);
+        return *this;
+    }
     ValidRegion valid_region() const override
     {
         return _valid_region;
diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h
index a433084..9bc8680 100644
--- a/arm_compute/core/TensorInfo.h
+++ b/arm_compute/core/TensorInfo.h
@@ -267,6 +267,10 @@
     {
         return std::find(std::cbegin(_dims_state), std::cend(_dims_state), get_dynamic_state_value()) != std::cend(_dims_state);
     }
+    bool are_values_constant() const override
+    {
+        return _are_values_constant;
+    }
     ITensorInfo &set_is_resizable(bool is_resizable) override
     {
         _is_resizable = is_resizable;
@@ -288,6 +292,11 @@
     {
         return _data_layout;
     }
+    ITensorInfo &set_are_values_constant(bool are_values_constant) override
+    {
+        _are_values_constant = are_values_constant;
+        return *this;
+    }
 
 private:
     /** Calculates strides, offset and total size resulting from the specified padding around the XY plane.
@@ -309,6 +318,7 @@
     PaddingSize      _padding;
     QuantizationInfo _quantization_info;
     DataLayout       _data_layout;
+    bool             _are_values_constant;
 };
 } // namespace arm_compute
 #endif /*ARM_COMPUTE_TENSORINFO_H */
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 9c00cbc..36b77b8 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1544,7 +1544,6 @@
     bool       transpose_weights{ true };                  /**<  Transpose weights if true. */
     bool       are_weights_reshaped{ false };              /**<  Reshape the weights tensor if false. */
     bool       retain_internal_weights{ false };           /**<  Retain internal reshaped weights. */
-    bool       constant_weights{ true };                   /**<  If false, weights can vary between runs. */
     /* Other parameters */
     bool fp_mixed_precision{ false }; /**<  Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. */
 
@@ -1951,9 +1950,8 @@
           _fast_math(false),
           _fp_mixed_precision(false),
           _broadcast_bias(false),
-          _pretranpose_B(true),
-          _activation_info(),
-          _constant_weights(true)
+          _pretranspose_B(true),
+          _activation_info()
     {
     }
     /** Constructor
@@ -1971,11 +1969,10 @@
      * @param[in] fast_math                   (Optional) Use a data type of shorter width to improve performance
      * @param[in] broadcast_bias              (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
      * @param[in] activation_info             (Optional) Activation to apply after the matrix multiplication
-     * @param[in] constant_weights            (Optional) Weights have constant values throughout multiple executions
      */
     GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
              GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false,
-             const ActivationLayerInfo &activation_info = ActivationLayerInfo(), bool constant_weights = true) noexcept
+             const ActivationLayerInfo &activation_info = ActivationLayerInfo()) noexcept
         : _is_a_reshaped(is_a_reshaped),
           _is_b_reshaped(is_b_reshaped),
           _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -1986,9 +1983,8 @@
           _fast_math(fast_math),
           _fp_mixed_precision(fp_mixed_precision),
           _broadcast_bias(broadcast_bias),
-          _pretranpose_B(reshape_b_only_on_first_run),
-          _activation_info(activation_info),
-          _constant_weights(constant_weights)
+          _pretranspose_B(reshape_b_only_on_first_run),
+          _activation_info(activation_info)
     {
     }
     /** Flag which specifies if the matrix A has been reshaped
@@ -2085,17 +2081,17 @@
      *
      * @return True if b should be pre-transposed else false.
      */
-    bool pretranpose_B() const
+    bool pretranspose_B() const
     {
-        return _pretranpose_B;
+        return _pretranspose_B;
     };
     /** Set pre-transpose b flag
      *
      * @param[in] flag Flag to set
      */
-    void set_pretranpose_B(bool flag)
+    void set_pretranspose_B(bool flag)
     {
-        _pretranpose_B = flag;
+        _pretranspose_B = flag;
     }
     /** Activation layer to apply after the matrix multiplication
      *
@@ -2113,14 +2109,6 @@
     {
         _activation_info = activation_info;
     }
-    /** Flag which specifies if the values of the weights tensor are constant throughout multiple executions or not
-     *
-     * @return True if the weights tensor is constant
-     */
-    bool constant_weights() const
-    {
-        return _constant_weights;
-    };
 
 private:
     bool                    _is_a_reshaped;
@@ -2133,9 +2121,8 @@
     bool                    _fast_math;
     bool                    _fp_mixed_precision;
     bool                    _broadcast_bias;
-    bool                    _pretranpose_B;
+    bool                    _pretranspose_B;
     ActivationLayerInfo     _activation_info;
-    bool                    _constant_weights;
 };
 
 /** Winograd information */
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index 5cbdf20..20c8230 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -523,7 +523,7 @@
         return size;
     }
 
-    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+    void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
         if (std::is_same<OutputStage, Requantize32>::value) {
             _col_bias = reinterpret_cast<int32_t *>(in_buffer);
 
@@ -534,6 +534,10 @@
                 compute_col_sums(*qp_ptr, _args._Nsize, _args._Ksize * _args._Ksections, B + (i * B_multi_stride), ldb, _col_bias + (i * _args._Nsize), _args._Ksize * _args._Ksections, i, 0);
             }
         }
+    }
+
+    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+        requantize_bias(in_buffer, B, ldb, B_multi_stride);
 
         // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0
         uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
index c72dca2..efb5bd1 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
@@ -269,12 +269,16 @@
         return get_col_sum_size() + (roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi));
     }
 
-    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+    void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
         col_bias = reinterpret_cast<int32_t *>(in_buffer);
 
         for (unsigned int i=0; i<_nmulti; i++) {
             compute_col_sums(_qp, _Nsize, _Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize),  _Ksize, i, 0);
         }
+    }
+
+    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+        requantize_bias(in_buffer, B, ldb, B_multi_stride);
 
         uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
         Toi *buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp
index 7376b5f..e84b58d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized_inline.hpp
@@ -219,12 +219,16 @@
         return get_col_sum_size() + (roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi));
     }
 
-    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+    void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
         col_bias = reinterpret_cast<int32_t *>(in_buffer);
 
         for (unsigned int i=0; i<_nmulti; i++) {
             compute_col_sums(_qp, _Nsize, _Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize),  _Ksize, i, 0);
         }
+    }
+
+    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+        requantize_bias(in_buffer, B, ldb, B_multi_stride);
 
         uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
         Toi *buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 5639cb4..c75c320 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -923,7 +923,7 @@
         return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size();
     }
 
-    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+    void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
         if (std::is_same<OutputStage, Requantize32>::value) {
             col_bias = reinterpret_cast<int32_t *>(in_buffer);
 
@@ -934,6 +934,10 @@
                 compute_col_sums(*qp_ptr, _Nsize, _Ksize * _Ksections, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize * _Ksections, i, 0);
             }
         }
+    }
+
+    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+        requantize_bias(in_buffer, B, ldb, B_multi_stride);
 
         // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0
         uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index d4348be..f0b4e5d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -201,11 +201,11 @@
         return _buffer_per_multi * _args._nmulti * sizeof(To) + get_col_sum_size();
     }
 
-    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
+    void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
         // Column sums go on the front of the pretransposed buffer in requantized cases.
         // We could optimize here in case we don't actually need to sum the columns, but this code is only run on setup.
         if (std::is_same<OutputStage, Requantize32>::value) {
-            col_bias = reinterpret_cast<int32_t *>(buffer);
+            col_bias = reinterpret_cast<int32_t *>(in_buffer);
 
             Requantize32 *qp_ptr = reinterpret_cast<Requantize32 *>(&_os);
 
@@ -213,6 +213,10 @@
                 compute_col_sums(*qp_ptr, _args._Nsize, _args._Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _args._Nsize), _args._Ksize, i, 0);
             }
         }
+    }
+
+    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
+        requantize_bias(buffer, B, ldb, B_multi_stride);
 
         // The actual transposed buffer goes after the column sums (if any)
         uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
index 1e2a9ac..ce72703 100644
--- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
@@ -179,13 +179,16 @@
         return _subgemm->get_B_pretransposed_array_size() + col_sum_size();
     }
 
+    void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+        _col_sums = reinterpret_cast<int32_t *>(in_buffer);
+        col_sums_pretransposed(B, ldb, B_multi_stride);
+    }
+
     void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
         uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
         _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride);
 
-        _col_sums = reinterpret_cast<int32_t *>(buffer);
-
-        col_sums_pretransposed(B, ldb, B_multi_stride);
+        requantize_bias(buffer, B, ldb, B_multi_stride);
     }
 
     void set_pretransposed_B_data(void *buffer) override {
diff --git a/src/core/TensorInfo.cpp b/src/core/TensorInfo.cpp
index c471615..e441ddb 100644
--- a/src/core/TensorInfo.cpp
+++ b/src/core/TensorInfo.cpp
@@ -31,11 +31,11 @@
 
 #include <memory>
 
-using namespace arm_compute;
-
+namespace arm_compute
+{
 TensorInfo::TensorInfo()
     : _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _dims_state(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true },
-      _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW)
+      _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW), _are_values_constant(true)
 {
 }
 
@@ -55,6 +55,7 @@
     _padding                       = info.padding();
     _quantization_info             = info.quantization_info();
     _data_layout                   = info.data_layout();
+    _are_values_constant           = info.are_values_constant();
 }
 
 TensorInfo::TensorInfo(Format format)
@@ -377,3 +378,4 @@
 
     return offset;
 }
+} // namespace arm_compute
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 378f104..ece9ca5 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -212,6 +212,9 @@
 
     /*** "Pretransposed" interface ***/
 
+    /* Compute the column sums required for requantization (no-op by default). */
+    virtual void requantize_bias(void *, const To *, const int, const int) {};
+
     /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
     /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
     virtual void pretranspose_B_array(void *, const To *, const int, const int) {};
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index cafb348..d952724 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -312,9 +312,14 @@
 
     if(_aux_mem[Pretranspose].size > 0)
     {
-        // Release permuted weights at the of prepare as they are further transposed by the assembly dispatch
-        _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), MemoryLifetime::Prepare, _reshaped_weights.total_size());
-        _aux_mem[ConvertedWeights]  = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
+        // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+        // Do not release them if the biases are dynamic and the data type is quantized, since the weights tensor will be used for the bias offset calculation
+        _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric
+                                                                                     && !(biases->are_values_constant())) ?
+                                                 MemoryLifetime::Persistent :
+                                                 MemoryLifetime::Prepare,
+                                                 _reshaped_weights.total_size());
+        _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
     }
     else
     {
@@ -332,10 +337,9 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights, dst);
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
-    ARM_COMPUTE_RETURN_ERROR_ON(biases != nullptr && biases->num_dimensions() > 1);
     ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!fc_info.constant_weights, "Non-constant weights are currently not supported");
+    ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights));
 
     bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
     bool is_fc_after_conv = true;
@@ -356,6 +360,19 @@
     // Check if we have a fully connected layer with batches
     const bool is_batched_fc_layer = dst->dimension(1) > 1;
 
+    if(biases != nullptr)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+        if(is_data_type_quantized(src->data_type()))
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases);
+        }
+    }
+
     if(is_batched_fc_layer)
     {
         is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3,
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 97893b0..1dd6286 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -206,6 +206,7 @@
     std::vector<TypeInput>           _indirect_pad{};
     arm_gemm::ConvolutionParameters  _cp{};
     experimental::MemoryRequirements _aux_mem{ Count };
+    bool                             _B_pretranspose_required{ false };
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -391,6 +392,7 @@
         const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
         _pretranspose_info                     = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
         _aux_mem[Pretranspose]                 = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
+        _B_pretranspose_required               = true;
     }
 
     // Handle indirect GEMM convolution
@@ -485,6 +487,35 @@
         in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
     }
 
+    // If either the weights or the biases are non-constant, refresh the bias column sums and, when required, re-run the B pretranspose on every execution
+    if((b && !b->info()->are_values_constant()) || (c && !c->info()->are_values_constant() && c->info()->data_type() == DataType::S32))
+    {
+        if(c && c->info()->data_type() == DataType::S32)
+        {
+            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
+        }
+
+        // Pretranspose B if required
+        if(_B_pretranspose_required)
+        {
+            const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+            const auto b_ptr          = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+            const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+
+            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
+            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
+
+            if(b->info()->are_values_constant())
+            {
+                _gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+            }
+            else
+            {
+                _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+            }
+        }
+    }
+
     const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());
 
     // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
diff --git a/src/gpu/cl/operators/ClFullyConnected.cpp b/src/gpu/cl/operators/ClFullyConnected.cpp
index 8b7e336..bd2fdda 100644
--- a/src/gpu/cl/operators/ClFullyConnected.cpp
+++ b/src/gpu/cl/operators/ClFullyConnected.cpp
@@ -169,8 +169,7 @@
                                          fc_info.fp_mixed_precision,      // fp_mixed_precision
                                          false,                           // fast_math
                                          true,                            // broadcast_bias
-                                         fc_info.activation_info,         // activation_info
-                                         fc_info.constant_weights);       // constant_weights
+                                         fc_info.activation_info);        // activation_info
 
     if(_is_quantized)
     {
@@ -333,7 +332,7 @@
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
     ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
-    ARM_COMPUTE_RETURN_ERROR_ON(!fc_info.constant_weights && (!fc_info.are_weights_reshaped || fc_info.transpose_weights));
+    ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights));
 
     bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
     bool is_fc_after_conv = true;
@@ -351,6 +350,19 @@
     const ITensorInfo *src_to_use     = src;
     const ITensorInfo *weights_to_use = weights;
 
+    if(biases != nullptr)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+        if(is_data_type_quantized(src->data_type()))
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases);
+        }
+    }
+
     // Check if we have a fully connected layer with batches
     const bool is_batched_fc_layer = dst->dimension(1) > 1;
     if(is_batched_fc_layer)
diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp
index 625c057..292f531 100644
--- a/src/gpu/cl/operators/ClGemm.cpp
+++ b/src/gpu/cl/operators/ClGemm.cpp
@@ -574,7 +574,7 @@
 
     // Select GEMMType
     _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
-                                                gemm_info.constant_weights());
+                                                b->are_values_constant());
 
     const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
 
@@ -623,7 +623,7 @@
     {
         CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
     },
-    gemm_info.reshape_b_only_on_first_run(), gemm_info.constant_weights());
+    gemm_info.reshape_b_only_on_first_run(), b->are_values_constant());
 
     const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
 
diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp
index 413250f..5639fb4 100644
--- a/tests/validation/NEON/FullyConnectedLayer.cpp
+++ b/tests/validation/NEON/FullyConnectedLayer.cpp
@@ -290,6 +290,10 @@
 using NEFullyConnectedLayerFixture = FullyConnectedLayerValidationFixture<Tensor, Accessor, NEFullyConnectedLayer, T>;
 template <typename T>
 using NEFullyConnectedLayerMixedDataLayoutFixture = FullyConnectedLayerValidationFixture<Tensor, Accessor, NEFullyConnectedLayer, T, true>;
+template <typename T>
+using NEFullyConnectedLayerDynamicWeightsFixture = FullyConnectedWithDynamicWeightsFixture<Tensor, Accessor, NEFullyConnectedLayer, T>;
+template <typename T>
+using NEFullyConnectedLayerDynamicBiasFixture = FullyConnectedWithDynamicBiasFixture<Tensor, Accessor, NEFullyConnectedLayer, T>;
 
 TEST_SUITE(Float)
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -358,6 +362,11 @@
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32);
 }
+FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+                       framework::dataset::make("DataType", DataType::F32)),
+                       framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))))
+{
+}
 TEST_SUITE_END()
 TEST_SUITE_END()
 
@@ -413,6 +422,12 @@
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
+
+FIXTURE_DATA_TEST_CASE(RunDynamicBias, NEFullyConnectedLayerDynamicBiasFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+                       framework::dataset::make("DataType", DataType::QASYMM8)),
+                       framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))))
+{
+}
 TEST_SUITE_END()
 TEST_SUITE(QASYMM8_SIGNED)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 7d76764..ccd9182 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -273,7 +273,7 @@
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class FullyConnectedWithDynamicWeightsFixture : public framework::Fixture
+class FullyConnectedWithDynamicTensorsFixture : public framework::Fixture
 {
 private:
     template <typename U>
@@ -289,6 +289,16 @@
             std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
             library->fill(tensor, distribution, i);
         }
+        else if(_data_type == DataType::QASYMM8)
+        {
+            std::uniform_int_distribution<uint8_t> distribution(0, 30);
+            library->fill(tensor, distribution, i);
+        }
+        else if(_data_type == DataType::S32)
+        {
+            std::uniform_int_distribution<int32_t> distribution(-50, 50);
+            library->fill(tensor, distribution, i);
+        }
         else
         {
             library->fill_tensor_uniform(tensor, i);
@@ -324,6 +334,11 @@
             constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.0001f);
             validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32);
         }
+        else if(_data_type == DataType::QASYMM8)
+        {
+            constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1);
+            validate(AccessorType(target), ref, tolerance_qasymm8);
+        }
         else
         {
             validate(AccessorType(target), ref);
@@ -331,32 +346,51 @@
     }
 
 public:
+    using TDecay = typename std::decay<T>::type;
+    using TBias  = typename std::conditional < (std::is_same<TDecay, uint8_t>::value || std::is_same<TDecay, int8_t>::value), int32_t, T >::type;
+
     template <typename...>
     void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
-               DataType data_type, ActivationLayerInfo activation_info)
+               DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias)
     {
         _data_type = data_type;
 
+        const bool is_quantized = is_data_type_quantized(data_type);
+
+        const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type;
+
+        const QuantizationInfo src_qinfo     = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo();
+        const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo();
+        const QuantizationInfo dst_qinfo     = is_quantized ? QuantizationInfo(0.2f, 5) : QuantizationInfo();
+
         // Setup tensor meta-data
-        TensorInfo src_info(src_shape, 1, data_type);
+        const TensorInfo src_info(src_shape, 1, data_type, src_qinfo);
         _src.allocator()->init(src_info);
 
-        TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
-        TensorInfo  wei_info(tr_weights_shape, 1, data_type);
+        TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
+        if(!constant_weights)
+        {
+            const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
+            wei_info.set_tensor_shape(tr_weights_shape);
+        }
+        wei_info.set_are_values_constant(constant_weights);
         _weights.allocator()->init(wei_info);
 
-        TensorInfo bias_info(bias_shape, 1, data_type);
+        TensorInfo bias_info(bias_shape, 1, bias_data_type);
+        bias_info.set_are_values_constant(constant_bias);
         _bias.allocator()->init(bias_info);
 
-        TensorInfo dst_info(dst_shape, 1, data_type);
+        const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
         _dst.allocator()->init(dst_info);
 
         // Configure FC layer and mark the weights as non constant
         FullyConnectedLayerInfo fc_info;
-        fc_info.activation_info      = activation_info;
-        fc_info.are_weights_reshaped = true;
-        fc_info.transpose_weights    = false;
-        fc_info.constant_weights     = false;
+        fc_info.activation_info = activation_info;
+        if(!constant_weights)
+        {
+            fc_info.are_weights_reshaped = true;
+            fc_info.transpose_weights    = false;
+        }
         FunctionType fc;
         fc.configure(&_src, &_weights, &_bias, &_dst, fc_info);
 
@@ -369,29 +403,55 @@
         // Run multiple iterations with different inputs
         constexpr int num_iterations    = 5;
         int           randomizer_offset = 0;
+
+        // Create reference tensors
+        SimpleTensor<T>     src{ src_shape, data_type, 1, src_qinfo };
+        SimpleTensor<T>     weights{ weights_shape, data_type, 1, weights_qinfo };
+        SimpleTensor<TBias> bias{ bias_shape, bias_data_type };
+
+        // Fill weights and/or bias if they remain constant
+        if(constant_weights)
+        {
+            fill(AccessorType(_weights), 1);
+            fill(weights, 1);
+        }
+        if(constant_bias)
+        {
+            fill(AccessorType(_bias), 2);
+            fill(bias, 2);
+        }
+
         for(int i = 0; i < num_iterations; ++i)
         {
             // Run target
             {
                 fill(AccessorType(_src), randomizer_offset);
-                fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
-                fill(AccessorType(_bias), randomizer_offset + 2);
+                if(!constant_weights)
+                {
+                    fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
+                }
+                if(!constant_bias)
+                {
+                    fill(AccessorType(_bias), randomizer_offset + 2);
+                }
 
                 fc.run();
             }
 
             // Run reference and compare
             {
-                SimpleTensor<T> src{ src_shape, data_type };
-                SimpleTensor<T> weights{ weights_shape, data_type };
-                SimpleTensor<T> bias{ bias_shape, data_type };
-
                 // Fill reference
                 fill(src, randomizer_offset);
-                fill(weights, randomizer_offset + 1);
-                fill(bias, randomizer_offset + 2);
+                if(!constant_weights)
+                {
+                    fill(weights, randomizer_offset + 1);
+                }
+                if(!constant_bias)
+                {
+                    fill(bias, randomizer_offset + 2);
+                }
 
-                auto dst = reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, dst_shape), activation_info);
+                auto dst = reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, dst_shape), activation_info, dst_qinfo);
 
                 // Validate
                 validate_with_tolerance(_dst, dst);
@@ -405,6 +465,32 @@
     TensorType _src{}, _weights{}, _bias{}, _dst{};
     DataType   _data_type{ DataType::UNKNOWN };
 };
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+    template <typename...>
+    void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
+               DataType data_type, ActivationLayerInfo activation_info)
+    {
+        FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
+                                                                                                  dst_shape, data_type, activation_info, false, true);
+    }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedWithDynamicBiasFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+    template <typename...>
+    void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
+               DataType data_type, ActivationLayerInfo activation_info)
+    {
+        FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
+                                                                                                  dst_shape, data_type, activation_info, true, false);
+    }
+};
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 58ddb3f..248c973 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -1158,7 +1158,7 @@
     os << "retain_internal_weights=" << info.retain_internal_weights() << ",";
     os << "fp_mixed_precision=" << info.fp_mixed_precision() << ",";
     os << "broadcast_bias=" << info.broadcast_bias() << ",";
-    os << "pretranpose_B=" << info.pretranpose_B() << ",";
+    os << "pretranspose_B=" << info.pretranspose_B() << ",";
 
     return os;
 }