[ONCPUML-1451] Add matmul kernel to enable bf16 to bf16 operations via PyTorch® autocast() function

The full range of tests will be added as part of the [MLINFSW-482] epic, because the required reordering kernels are not yet implemented in ACL.

Co-Authored-By: David Mansell <David.Mansell@arm.com>
Change-Id: I820d316295a1ec94fdc89c37e4144a268f914c36
Signed-off-by: Renato Arantes <renato.arantes@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11169
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp
index 8908712..f68ae98 100644
--- a/src/cpu/operators/CpuMatMul.cpp
+++ b/src/cpu/operators/CpuMatMul.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -102,8 +102,8 @@
                            const ActivationLayerInfo &act_info)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8,
-                                                         DataType::QASYMM8_SIGNED);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::BFLOAT16,
+                                                         DataType::QASYMM8, DataType::QASYMM8_SIGNED);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs->are_values_constant(), "LHS Tensor must be dynamic.");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs->are_values_constant(), "RHS Tensor must be dynamic.");
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs);
@@ -120,6 +120,7 @@
     auto gemm_info            = AsmGemmInfo();
     gemm_info.activation_info = act_info;
     gemm_info.fast_mode       = settings.fast_math();
+    gemm_info.fixed_format    = settings.fixed_format();
 
     // Validate and then permute a/b
     if (adj_lhs)
@@ -157,6 +158,14 @@
                                                                    gemm_info.activation_info, gemm_info.output_stage));
     }
 
+    if (gemm_info.fixed_format)
+    {
+        gemm_info.weight_format                          = WeightFormat::ANY;
+        arm_compute::WeightFormat expected_weight_format = WeightFormat::ANY;
+        ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, lhs_to_use,
+                                                                               rhs_to_use, nullptr, dst, gemm_info));
+    }
+
     cpu::CpuGemmAssemblyDispatch::validate(lhs_to_use, rhs_to_use, nullptr, dst, gemm_info);
 
     return Status{};
@@ -221,6 +230,7 @@
     // Fill AsmGemmInfo class object before configuration
     _gemm_info.activation_info = act_info;
     _gemm_info.fast_mode       = settings.fast_math();
+    _gemm_info.fixed_format    = settings.fixed_format();
     _gemm_info.negated_offsets = false;
 
     lhs_to_use = (_adj_lhs) ? _lhs_transposed : lhs_to_use;
@@ -233,6 +243,18 @@
                                        _gemm_info.output_stage);
     }
 
+    if (_gemm_info.fixed_format)
+    {
+        _gemm_info.weight_format                         = WeightFormat::ANY;
+        arm_compute::WeightFormat expected_weight_format = WeightFormat::ANY;
+        ARM_COMPUTE_ERROR_THROW_ON(cpu::CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, &lhs_to_use,
+                                                                              &rhs_to_use, nullptr, dst, _gemm_info));
+        // Set gemm weights info to the one returned by has_opt_impl
+        _gemm_info.weight_format = expected_weight_format;
+        // has_opt_impl may return a non fast math kernel, even if we requested one
+        _gemm_info.fast_mode = arm_compute::is_fixed_format_fast_math(expected_weight_format);
+    }
+
     // Configure Asm Kernel
     _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
     _asm_glue->configure(&lhs_to_use, &rhs_to_use, nullptr, &dst_to_use,