Add Bias to MatMul Kernels and add support for use in Fully Connected Layer

Resolves: [COMPMID-6316]
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Change-Id: I08e6bac9e6b46b76978da0dc6a48ccfe3dde5086
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9833
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/gpu/cl/operators/ClFullyConnected.cpp b/src/gpu/cl/operators/ClFullyConnected.cpp
index b7ba8b8..0be3f0f 100644
--- a/src/gpu/cl/operators/ClFullyConnected.cpp
+++ b/src/gpu/cl/operators/ClFullyConnected.cpp
@@ -113,22 +113,25 @@
 
 Status validate_mm(const ITensorInfo &src, const ITensorInfo &weights, const ITensorInfo *bias, const ITensorInfo &dst, const FullyConnectedLayerInfo &fc_info)
 {
-    // If weights are dynamic, data is not batched, and bias is nullptr validate using matmul.
-    const bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
-    const bool use_matmul       = !weights.are_values_constant() && !weights_reshaped && !(dst.dimension(1) > 1) && (bias == nullptr);
+    // Note: If weights are dynamic and data is not batched, use matmul, else use gemm
+    const bool transpose_weights = fc_info.transpose_weights ? !fc_info.are_weights_reshaped : false;
+    const bool use_matmul        = !weights.are_values_constant() && !(dst.dimension(1) > 1);
+    const bool use_dynamic_gemm  = !use_matmul && !weights.are_values_constant() && transpose_weights; // use dynamic gemm as fallback for matmul
+    const bool is_quantized      = is_data_type_quantized_asymmetric(src.data_type());
 
     if(use_matmul)
     {
-        MatMulInfo m_info{};
-        m_info.adj_rhs(fc_info.transpose_weights);
+        const MatMulInfo m_info = MatMulInfo().adj_rhs(transpose_weights);
 
-        // Note: Currently, shape is [M, B0, B1]
-        // LHS is reshaped here to match ClMatMul expectations of batch index in format - [M, 1, B0, B1, .. ]
-        TensorInfo lhs_to_use{ src };
-        lhs_to_use.set_tensor_shape(get_reshaped_matmul_tensor(src.tensor_shape()));
+        // Note: LHS is reshaped here to match ClMatMul expectations of batch index - From [M, B0, B1] to [M, 1, B0, B1]
+        TensorInfo lhs_to_use = src.clone()->set_tensor_shape(get_reshaped_matmul_tensor(src.tensor_shape()));
 
-        // Operator level validation.
-        ARM_COMPUTE_RETURN_ON_ERROR(ClMatMul::validate(&lhs_to_use, &weights, &dst, m_info, fc_info.activation_info));
+        const GPUTarget                                         gpu_target  = CLScheduler::get().target();
+        std::unique_ptr<cl_matmul::IClMatMulNativeKernelConfig> t           = cl_matmul::ClMatMulNativeKernelConfigurationFactory::create(gpu_target);
+        const MatMulKernelInfo                                  kernel_info = t->configure(&lhs_to_use, &weights, m_info);
+
+        return is_quantized ? kernels::ClMatMulLowpNativeKernel::validate(&lhs_to_use, &weights, bias, &dst, kernel_info, fc_info.activation_info) :
+               kernels::ClMatMulNativeKernel::validate(&lhs_to_use, &weights, bias, &dst, kernel_info, fc_info.activation_info);
     }
     else
     {
@@ -137,7 +140,7 @@
 
         const GEMMInfo &gemm_info = GEMMInfo(false,                           // is_a_reshaped
                                              false,                           // is_b_reshaped
-                                             true,                            // reshape_b_only_on_first_run
+                                             !use_dynamic_gemm,               // reshape_b_only_on_first_run
                                              0,                               // depth_output_gemm3d
                                              false,                           // reinterpret_input_as_3d
                                              fc_info.retain_internal_weights, // retain_internal_weights
@@ -147,7 +150,7 @@
                                              true,                            // broadcast_bias
                                              ActivationLayerInfo());          // activation_info
 
-        if(is_data_type_quantized_asymmetric(src.data_type()))
+        if(is_quantized)
         {
             const UniformQuantizationInfo iq_info = src.quantization_info().uniform();
             const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();
@@ -191,35 +194,33 @@
 void ClFullyConnected::configure_mm(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *dst,
                                     const FullyConnectedLayerInfo &fc_info)
 {
-    // If weights are dynamic, configure matmul operator - else use gemm
+    // If weights are dynamic and matmul is supported use matmul, else use gemm
     if(_use_matmul)
     {
-        // Transpose RHS as _are_weights_reshaped == false when mat_mul is used.
-        const MatMulInfo mat_info = MatMulInfo().adj_rhs(fc_info.transpose_weights);
+        // Specify whether transpose weights is necessary in matmul info
+        const MatMulInfo mat_info = MatMulInfo().adj_rhs(_transpose_weights);
 
         // Note: MatMul does not need offset negation unlike gemm
         // 1. Change shape when calling matmul to fit batch expectations.
-        _lhs_to_use = *src->clone();
-        _lhs_to_use.set_tensor_shape(get_reshaped_matmul_tensor(_lhs_to_use.tensor_shape())); // Collapse all dims > 2 into final dimension.
-        _is_quantized = is_data_type_quantized_asymmetric(_lhs_to_use.data_type());
+        _lhs_to_use = src->clone()->set_tensor_shape(get_reshaped_matmul_tensor(_lhs_to_use.tensor_shape()));
 
-        // 2. Call kernel for matmul directly.
+        // 2. Use heuristics to get kernel info object
         const GPUTarget                                         gpu_target    = CLScheduler::get().target();
         std::unique_ptr<cl_matmul::IClMatMulNativeKernelConfig> kernel_config = cl_matmul::ClMatMulNativeKernelConfigurationFactory::create(gpu_target);
+        MatMulKernelInfo                                        kernel_info   = kernel_config->configure(src, weights, mat_info);
 
-        // Configure relevant matmul kernel
-        MatMulKernelInfo kernel_info = kernel_config->configure(src, weights, mat_info);
+        // 3. Configure relevant matmul kernel
         if(_is_quantized)
         {
             _matmul_lowp_native_kernel = std::make_unique<kernels::ClMatMulLowpNativeKernel>();
             _matmul_lowp_native_kernel->set_target(gpu_target);
-            _matmul_lowp_native_kernel->configure(compile_context, src, weights, dst, kernel_info, fc_info.activation_info);
+            _matmul_lowp_native_kernel->configure(compile_context, src, weights, bias, dst, kernel_info, fc_info.activation_info);
         }
         else
         {
             _matmul_native_kernel = std::make_unique<kernels::ClMatMulNativeKernel>();
             _matmul_native_kernel->set_target(gpu_target);
-            _matmul_native_kernel->configure(compile_context, src, weights, dst, kernel_info, fc_info.activation_info);
+            _matmul_native_kernel->configure(compile_context, src, weights, bias, dst, kernel_info, fc_info.activation_info);
         }
     }
     else
@@ -230,7 +231,7 @@
 
         const GEMMInfo &gemm_info = GEMMInfo(false,                           // is_a_reshaped
                                              false,                           // is_b_reshaped
-                                             !_dynamic_weights,               // reshape_b_only_on_first_run
+                                             !_dynamic_gemm,                  // reshape_b_only_on_first_run
                                              0,                               // depth_output_gemm3d
                                              false,                           // reinterpret_input_as_3d
                                              fc_info.retain_internal_weights, // retain_internal_weights
@@ -269,7 +270,8 @@
 void ClFullyConnected::configure_conv_fc(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *dst,
                                          const FullyConnectedLayerInfo &fc_info)
 {
-    ARM_COMPUTE_ERROR_ON((weights->dimension((_use_matmul) ? 0 : 1) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
+    // MatMul fuses transpose operation, so we use the first dimension for comparison where appropriate.
+    ARM_COMPUTE_ERROR_ON((weights->dimension((_use_matmul && _transpose_weights) ? 0 : 1) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
 
     // If the fully connected layer is called after a convolution layer, the input tensor must be linearized
 
@@ -288,8 +290,8 @@
 void ClFullyConnected::configure_fc_fc(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *dst,
                                        const FullyConnectedLayerInfo &fc_info)
 {
-    // Compare first dimension when using matmul, as it performs transpose operation
-    ARM_COMPUTE_ERROR_ON(src->dimension(0) != weights->dimension((_use_matmul) ? 0 : 1));
+    // MatMul fuses transpose operation, so we use the first dimension for comparison where appropriate.
+    ARM_COMPUTE_ERROR_ON(src->dimension(0) != weights->dimension((_use_matmul && _transpose_weights) ? 0 : 1));
 
     // Configure matrix multiply kernel
     configure_mm(compile_context, src, weights, bias, dst, fc_info);
@@ -304,20 +306,18 @@
     ARM_COMPUTE_ERROR_THROW_ON(ClFullyConnected::validate(src, weights, biases, dst, fc_info));
     ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, fc_info);
 
-    _are_weights_converted = true;
-    _are_weights_reshaped  = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
-    _is_fc_after_conv      = true;
-    _is_quantized          = is_data_type_quantized_asymmetric(src->data_type());
-    _is_prepared           = fc_info.retain_internal_weights;
-    _weights_to_use        = TensorInfo(*weights);
-    _weights_to_use_idx    = ACL_SRC_1;
+    _transpose_weights  = fc_info.transpose_weights ? !fc_info.are_weights_reshaped : false;
+    _is_fc_after_conv   = true;
+    _is_quantized       = is_data_type_quantized_asymmetric(src->data_type());
+    _is_prepared        = fc_info.retain_internal_weights;
+    _weights_to_use     = TensorInfo(*weights);
+    _weights_to_use_idx = ACL_SRC_1;
 
     // When using dynamic weights - use matmul kernels.
-    // Note: We don't appear to support dynamic weights with pre-reshaped RHS.
-    // Note: No matmul with biases for the moment.
+    // Note: MatMul does not support broadcasting batch dimension, and therefore is disabled if fc is batched. Gemm is used as fallback.
     const bool is_batched_fc_layer = dst->dimension(1) > 1;
-    _dynamic_weights               = !weights->are_values_constant() && !_are_weights_reshaped;
-    _use_matmul                    = _dynamic_weights && !is_batched_fc_layer && (biases == nullptr);
+    _use_matmul                    = !weights->are_values_constant() && !is_batched_fc_layer;
+    _dynamic_gemm                  = !weights->are_values_constant() && _transpose_weights && !_use_matmul;
 
     // With the Fully Connected layer we can have 4 different cases:
     //  1) Convolution layer -> Fully Connected layer without batches
@@ -339,9 +339,8 @@
 
     ITensorInfo *weights_used = weights;
 
-    // Reshape weights if needed
-    // Not needed when matmul is in use -  MatMul has transpose RHS flags.
-    if(!_are_weights_reshaped && !_use_matmul)
+    // Reshape weights if needed - Not needed when matmul is in use as matmul fuses transpose op.
+    if(_transpose_weights && !_use_matmul)
     {
         // Reshape the weights
         _reshape_weights = std::make_unique<ClTranspose>();
@@ -361,9 +360,9 @@
                                     src->tensor_shape(),
                                     fc_info.weights_trained_layout);
 
-        weights_used           = &_converted_weights;
-        _weights_to_use_idx    = offset_int_vec(ConvertedWeights);
-        _are_weights_converted = false;
+        weights_used         = &_converted_weights;
+        _weights_to_use_idx  = offset_int_vec(ConvertedWeights);
+        _run_convert_weights = true;
     }
 
     if(_is_fc_after_conv)
@@ -398,11 +397,11 @@
             // Keep all the auxiliary tensors in case of dynamic weights as they are recalculated every time
             _aux_mem[TransposedWeights] = MemoryInfo(
                                               offset_int_vec(TransposedWeights),
-                                              _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
+                                              _dynamic_gemm ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
                                               _reshaped_weights.total_size());
             _aux_mem[ConvertedWeights] = MemoryInfo(
                                              offset_int_vec(ConvertedWeights),
-                                             _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
+                                             _dynamic_gemm ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
                                              _converted_weights.total_size());
         }
         else
@@ -413,11 +412,11 @@
 
             _aux_mem[TransposedWeights] = MemoryInfo(
                                               offset_int_vec(TransposedWeights),
-                                              _dynamic_weights ? MemoryLifetime::Temporary : transposed_wei_lft,
+                                              _dynamic_gemm ? MemoryLifetime::Temporary : transposed_wei_lft,
                                               _reshaped_weights.total_size());
             _aux_mem[ConvertedWeights] = MemoryInfo(
                                              offset_int_vec(ConvertedWeights),
-                                             _dynamic_weights ? MemoryLifetime::Temporary : converted_wei_lft,
+                                             _dynamic_gemm ? MemoryLifetime::Temporary : converted_wei_lft,
                                              _converted_weights.total_size());
         }
     }
@@ -434,19 +433,17 @@
     ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                 && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
 
-    const bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
-    bool       is_fc_after_conv = true;
+    const bool transpose_weights = fc_info.transpose_weights ? !fc_info.are_weights_reshaped : false;
+    bool       is_fc_after_conv  = true;
 
     // When using dynamic weights - use matmul kernels.
-    // Note: MatMul does not support broadcasting or biases so fallback with batched cases or when biases != nullptr.
-    // Note: Pre-Shaped RHS is a deprecated use case and is therefore not supported with matmul.
-    const bool dynamic_weights     = !weights->are_values_constant() && !weights_reshaped;
+    // Note: MatMul does not support broadcasting so fallback with batched cases.
     const bool is_batched_fc_layer = dst->dimension(1) > 1;
-    const bool use_matmul          = dynamic_weights && !is_batched_fc_layer && (biases == nullptr);
+    const bool use_matmul          = !weights->are_values_constant() && !is_batched_fc_layer;
 
     const ITensorInfo &flatten_src       = TensorInfo(src->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(src)).set_data_layout(DataLayout::NCHW));
     const ITensorInfo &reshaped_weights  = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
-    const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());
+    const ITensorInfo &converted_weights = transpose_weights ? TensorInfo(*reshaped_weights.clone()) : TensorInfo(weights->clone()->set_is_resizable(true).reset_padding());
 
     // With the Fully Connected layer we can have 4 different cases:
     //  1) Convolution layer -> Fully Connected layer without batches
@@ -482,7 +479,8 @@
         is_fc_after_conv = src->num_dimensions() > 1;
     }
 
-    if(!weights_reshaped && !use_matmul)
+    // Transpose kernel does not run when matmul is supported as matmul fuses transpose op.
+    if(transpose_weights && !use_matmul)
     {
         // Validate reshape weights kernel
         ARM_COMPUTE_RETURN_ON_ERROR(ClTranspose::validate(weights, &reshaped_weights));
@@ -502,14 +500,9 @@
     if(is_fc_after_conv)
     {
         // Fully Connected layer after a Convolution Layer without batches
-        if(use_matmul)
-        {
-            ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(0) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
-        }
-        else
-        {
-            ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
-        }
+        // K Index of matrix multiplication. MatMul performs transpose in kernel, so index is 0 when matmul and transpose enabled
+        const int weight_idx = (use_matmul && transpose_weights) ? 0 : 1;
+        ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(weight_idx) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
 
         // Validate flatten kernel
         ARM_COMPUTE_RETURN_ON_ERROR(ClFlatten::validate(src, &flatten_src));
@@ -518,7 +511,9 @@
     else
     {
         // Fully Connected layer after a Fully Connected Layer without batches
-        ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension((use_matmul) ? 0 : 1));
+        // K Index of matrix multiplication. MatMul performs transpose in kernel, so index is 0 when matmul and transpose enabled
+        const int weight_idx = (use_matmul && transpose_weights) ? 0 : 1;
+        ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension(weight_idx));
     }
 
     // Validate matrix multiply kernel
@@ -533,7 +528,7 @@
 
 #ifdef ARM_COMPUTE_ASSERTS_ENABLED
     ++_asrt_run_count;
-    ARM_COMPUTE_ERROR_ON(_dynamic_weights && _asrt_prepare_count != _asrt_run_count);
+    ARM_COMPUTE_ERROR_ON(_dynamic_gemm && _asrt_prepare_count != _asrt_run_count);
 #endif // ARM_COMPUTE_ASSERTS_ENABLED
 
     auto src = tensors.get_const_tensor(ACL_SRC_0);
@@ -584,11 +579,12 @@
 
 void ClFullyConnected::prepare(ITensorPack &tensors)
 {
-    if(!_is_prepared || _dynamic_weights)
+    // Note : Running prepare() each run when _use_matmul is true is unnecessary unless weights conversion is needed.
+    if(!_is_prepared || _dynamic_gemm || (_use_matmul && _run_convert_weights))
     {
 #ifdef ARM_COMPUTE_ASSERTS_ENABLED
         ++_asrt_prepare_count;
-        ARM_COMPUTE_ERROR_ON(!_dynamic_weights && _asrt_prepare_count > 1);
+        ARM_COMPUTE_ERROR_ON(!_dynamic_gemm && !_use_matmul && _asrt_prepare_count > 1);
 #endif // ARM_COMPUTE_ASSERTS_ENABLED
 
         auto weights = tensors.get_const_tensor(ACL_SRC_1);
@@ -599,8 +595,8 @@
         // Pointer to current weights
         const ITensor *cur_weights = weights;
 
-        // Reshape of the weights if needed
-        if(!_are_weights_reshaped && !_use_matmul)
+        // Reshape weights if needed. Disabled when matmul kernels are enabled as matmul fuses transpose.
+        if(_transpose_weights && !_use_matmul)
         {
             // Run reshape weights kernel and mark weights as unused
             ITensorPack transpose_pack{ { ACL_SRC, weights }, { ACL_DST, reshaped_weights.get() } };
@@ -611,7 +607,7 @@
         }
 
         // Convert weights if needed
-        if(!_are_weights_converted)
+        if(_run_convert_weights)
         {
             ITensorPack convert_pack{ { ACL_SRC, cur_weights }, { ACL_DST, converted_weights.get() } };
             _convert_weights->run(convert_pack);
@@ -623,8 +619,8 @@
         ITensorPack gemm_pack = tensors;
         gemm_pack.add_const_tensor(ACL_SRC_1, cur_weights);
 
-        // Prepare GEMM prepare and release unused weights (If not using matmul)
-        if(!_use_matmul)
+        // Run GEMM prepare and release unused weights (skipped when matmul is used)
+        if(_dynamic_gemm || !_use_matmul)
         {
             if(!_is_quantized)
             {