Fix dynamic weights for CPU fully connected layer

The fully connected and GEMM operators can no longer assume that
matrix B (the weights) is reshaped only on the first run when the
weights are dynamic. Derive the reshape-on-first-run behaviour from
ITensorInfo::are_values_constant() instead of GEMMInfo, propagate the
constness flag to the reshaped and converted weight tensors, and keep
batched matmul with non-constant B off the optimized assembly path,
which handles batching differently.
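
For reference, a minimal standalone C++ sketch of the new gating
condition follows. It is not part of this patch: "TensorLike" and
"can_use_optimised_gemm" are illustrative stand-ins for ITensorInfo
and the real CpuGemmAssemblyDispatch::validate() check.

    #include <cstddef>
    #include <iostream>

    struct TensorLike
    {
        bool        values_constant; // stand-in for ITensorInfo::are_values_constant()
        std::size_t batch;           // stand-in for tensor_shape().z()
    };

    // Mirrors the condition added in CpuGemm: skip the optimized assembly
    // path when B is non-constant (dynamic) and batched, because the
    // optimized GeMM handles batching differently.
    bool can_use_optimised_gemm(const TensorLike &b, bool asm_kernel_validated)
    {
        const bool dynamic_batched_b = !b.values_constant && b.batch > 1;
        return asm_kernel_validated && !dynamic_batched_b;
    }

    int main()
    {
        const TensorLike static_weights{true, 4};   // constant B, batched
        const TensorLike dynamic_batched{false, 4}; // dynamic B, batched

        std::cout << can_use_optimised_gemm(static_weights, true) << '\n';  // prints 1
        std::cout << can_use_optimised_gemm(dynamic_batched, true) << '\n'; // prints 0
        return 0;
    }
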
Resolves: COMPMID-5995
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I707b8918bebee7e70d4de5207ef555c806e7a305
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9405
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index af63015..70584a6 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -136,7 +136,7 @@
     }
     else
     {
-        GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+        GEMMInfo gemm_info;
         gemm_info.set_weight_format(weight_format);
         gemm_info.set_fixed_format(weight_format != WeightFormat::UNSPECIFIED);
         gemm_info.set_fast_math(enable_fast_math);
@@ -190,7 +190,7 @@
         const Status            status = get_gemmlowp_output_stage_info(&src_info, &weights_info, dst, act, gemmlowp_output_stage_info);
         ARM_COMPUTE_ERROR_ON(status.error_code() != ErrorCode::OK);
 
-        GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */);
+        GEMMInfo gemm_info;
         gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info);
         gemm_info.set_activation_info(act);
         gemm_info.set_fast_math(_enable_fast_math);
@@ -200,7 +200,7 @@
     else
     {
         // Configure matrix multiply kernel
-        GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */);
+        GEMMInfo gemm_info;
         gemm_info.set_activation_info(act);
         gemm_info.set_fast_math(_enable_fast_math);
         gemm_info.set_fixed_format(_fixed_format);
@@ -284,6 +284,8 @@
         // Reshape the weights
         _transpose_weights = std::make_unique<kernels::CpuTransposeKernel>();
         _transpose_weights->configure(weights, &_reshaped_weights);
+        _reshaped_weights.set_are_values_constant(weights->are_values_constant());
+
         weights_to_use     = &_reshaped_weights;
         _trans_weights_idx = AuxTensorIdx::TransposedWeights;
     }
@@ -297,6 +299,7 @@
                                     &_converted_weights,
                                     src->tensor_shape(),
                                     fc_info.weights_trained_layout);
+        _converted_weights.set_are_values_constant(weights_to_use->are_values_constant());
 
         weights_to_use            = &_converted_weights;
         _needs_weights_conversion = true;
@@ -364,7 +367,7 @@
 Status CpuFullyConnected::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights,
                                        const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info, WeightsInfo weights_info)
 {
-    GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+    GEMMInfo gemm_info;
     gemm_info.set_activation_info(fc_info.activation_info);
     gemm_info.set_fast_math(fc_info.enable_fast_math);
     gemm_info.set_fixed_format(weights_info.weight_format() != WeightFormat::UNSPECIFIED);
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index f914bce..b9d18c4 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -65,11 +65,13 @@
 
     const cpu::AsmGemmInfo asm_info      = init_assembly_metadata(gemm_info);
     const bool             is_c_bias     = beta == 1 && c != nullptr;
-    bool                   run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run();
+    bool                   run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) &&
+                                           (c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
+                                           !(!b->are_values_constant() && b->tensor_shape().z() > 1); // Disable batch matmul as optimized GeMM handles batching differently.
 
     // Check if we need to reshape the matrix B only on the first run
     _is_prepared                      = false;
-    _reshape_b_only_on_first_run      = gemm_info.reshape_b_only_on_first_run();
+    _reshape_b_only_on_first_run      = b->are_values_constant();
     _run_vector_matrix_multiplication = a->dimension(1) < 2;
     _run_alpha_scale                  = alpha != 1.f;
     _run_bias_addition                = is_c_bias;
@@ -211,7 +213,9 @@
 
     // Check if we need to run the optimized assembly kernel
     cpu::AsmGemmInfo asm_info      = init_assembly_metadata(gemm_info);
-    const bool       run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info));
+    const bool       run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info)) &&
+                                     (c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
+                                     !(!b->are_values_constant() && b->tensor_shape().z() > 1); // Disable batch matmul as optimized GeMM handles batching differently.
 
     if(!run_optimised)
     {
@@ -221,7 +225,7 @@
         // Check if the first input tensor is a vector.
         const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
         // Check if we need to reshape the matrix A and matrix B
-        const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());
+        const bool run_interleave_transpose = !run_vector_matrix_multiplication && !b->are_values_constant();
 
         // Arguments used by GEMMReshapeInfo
         // If we pass the matrix A and matrix B reshaped to CpuGemmMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
@@ -259,7 +263,7 @@
         auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
         ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
 
-        if(c != nullptr && gemm_info.reshape_b_only_on_first_run())
+        if(is_c_bias)
         {
             ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuAdd::validate(&tmp_output_info, c, d, ConvertPolicy::SATURATE));
         }
@@ -294,7 +298,7 @@
     {
         // Pass c to asm dispatch only if it's the bias tensor
         ITensorPack asm_pack = tensors;
-        asm_pack.add_const_tensor(ACL_SRC_2, (_reshape_b_only_on_first_run) ? c : nullptr);
+        asm_pack.add_const_tensor(ACL_SRC_2, _run_bias_addition ? c : nullptr);
         _asm_glue->run(asm_pack);
         if(_run_alpha_scale)
         {
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 9bf6ed1..ebf2ebc 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -169,7 +169,7 @@
     {
         // Configure matrix multiply function
         _mm_gemm = std::make_unique<CpuGemm>();
-        _mm_gemm->configure(src, weights, biases, dst, 1.0f, 0.0f, gemm_info);
+        _mm_gemm->configure(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
         auto mm_mem_req = _mm_gemm->workspace();
         for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
         {
@@ -235,7 +235,7 @@
     else
     {
         // Perform validation step on Matrix multiply function
-        return CpuGemm::validate(src, weights, nullptr, dst, 1.0f, 0.0f, gemm_info);
+        return CpuGemm::validate(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
     }
 }
 
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index aec9da1..8ca128f 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -65,7 +65,6 @@
     asm_info.activation_info             = info.activation_info();
     asm_info.output_stage                = info.gemmlowp_output_stage();
     asm_info.fast_mode                   = info.fast_math();
-    asm_info.reshape_b_only_on_first_run = info.reshape_b_only_on_first_run();
 
     return asm_info;
 }
@@ -120,7 +119,7 @@
     _a_offset                         = a->quantization_info().uniform().offset;
     _b_offset                         = b->quantization_info().uniform().offset;
     _run_vector_matrix_multiplication = a->dimension(1) < 2;
-    _reshape_b_only_on_first_run      = info.reshape_b_only_on_first_run();
+    _reshape_b_only_on_first_run      = b->are_values_constant();
     _is_prepared                      = false;
     _fused_assembly_path              = false;
     _flip_signedness                  = is_data_type_quantized_per_channel(b->data_type()) && (a->data_type() == DataType::QASYMM8) && _reshape_b_only_on_first_run;
@@ -167,31 +166,34 @@
     // Initialize assembly kernel meta-data
     const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
 #ifdef __aarch64__
-    switch(a->data_type())
+    if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
     {
-        case DataType::QASYMM8:
-        case DataType::QASYMM8_SIGNED:
-        case DataType::U8:
-        case DataType::S8:
+        switch(a->data_type())
         {
-            if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+            case DataType::QASYMM8:
+            case DataType::QASYMM8_SIGNED:
+            case DataType::U8:
+            case DataType::S8:
             {
-                auto c_info_to_use = c == nullptr ? nullptr : c;
-                _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info);
-                _fused_assembly_path = _asm_glue->is_configured();
+                if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+                {
+                    auto c_info_to_use = c == nullptr ? nullptr : c;
+                    _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info);
+                    _fused_assembly_path = _asm_glue->is_configured();
+                }
+                else
+                {
+                    auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst);
+                    _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info);
+                }
+                _assembly_path = _asm_glue->is_configured();
+                break;
             }
-            else
+            default:
             {
-                auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst);
-                _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info);
+                ARM_COMPUTE_ERROR("Datatype not supported");
+                break;
             }
-            _assembly_path = _asm_glue->is_configured();
-            break;
-        }
-        default:
-        {
-            ARM_COMPUTE_ERROR("Datatype not supported");
-            break;
         }
     }
 #endif /* __aarch64__ */
@@ -371,14 +373,18 @@
     // Check if we need to run the optimized assembly kernel
     bool run_optimised             = false;
     bool run_optimised_requantized = false;
-    if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+
+    if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
     {
-        run_optimised             = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info));
-        run_optimised_requantized = run_optimised;
-    }
-    else
-    {
-        run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info));
+        if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+        {
+            run_optimised             = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info));
+            run_optimised_requantized = run_optimised;
+        }
+        else
+        {
+            run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info));
+        }
     }
 
     if(run_optimised)