[ONCPUML-1451] Add matmul kernel to enable bf16 to bf16 operations via the PyTorch® autocast() function

The full range of tests will be added under the [MLINFSW-482] epic, due to the lack of reordering kernels currently implemented in ACL.
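For reference, a minimal sketch (not part of this patch) of how the newly accepted bf16-in/bf16-out combination can be queried through the public NEGEMM::validate() entry point. The helper name and shapes are illustrative only, and the path assumes a build with ARM_COMPUTE_ENABLE_BF16:

    // Illustrative only: check whether a bf16 x bf16 -> bf16 GEMM is
    // supported, via the public NEGEMM::validate() interface.
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEGEMM.h"

    using namespace arm_compute;

    // Hypothetical helper; M/N/K follow the usual GEMM convention.
    bool bf16_to_bf16_gemm_supported(unsigned int M, unsigned int N, unsigned int K)
    {
        // ACL tensor shapes are given fastest dimension first, i.e. (width, height).
        TensorInfo a(TensorShape(K, M), 1, DataType::BFLOAT16);
        TensorInfo b(TensorShape(N, K), 1, DataType::BFLOAT16);
        TensorInfo d(TensorShape(N, M), 1, DataType::BFLOAT16);

        // Before this patch the assembly dispatch rejected a BFLOAT16
        // destination for BFLOAT16 inputs; it now accepts it in addition
        // to the existing F32 destination.
        return static_cast<bool>(NEGEMM::validate(&a, &b, nullptr, &d, 1.f, 0.f));
    }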

Co-Authored-By: David Mansell <David.Mansell@arm.com>
Change-Id: I820d316295a1ec94fdc89c37e4144a268f914c36
Signed-off-by: Renato Arantes <renato.arantes@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11169
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 58ee68f..efe2a7a 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -581,7 +581,6 @@
             // Fixed format kernels need no pretranspose.
             ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
                 assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
-
             const int  ldb            = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
             const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
                                                                      b_to_use->info()->offset_first_element_in_bytes());
@@ -857,6 +856,7 @@
     arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
     arm_gemm::GemmArgs     args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
                                 info.fixed_format, info.fast_mode, &cfg);
+
     // TODO: Incorporate info.transpose_b COMPMID-6595
     switch (a->data_type())
     {
@@ -900,9 +900,18 @@
 #if defined(ARM_COMPUTE_ENABLE_BF16)
         case DataType::BFLOAT16:
         {
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-                !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
-                "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+            if (d->data_type() == DataType::BFLOAT16)
+            {
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+                    !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+                    "We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output");
+            }
+            else
+            {
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+                    !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+                    "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+            }
             break;
         }
 #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
@@ -958,8 +967,9 @@
                                     "Only F32 output supported for F32 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16,
                                     "Only F16 output supported for F16 input");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32,
-                                    "Only F32 output supported for BFLOAT16 input");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 &&
+                                        (d->data_type() != DataType::F32 && d->data_type() != DataType::BFLOAT16),
+                                    "Only F32/BFLOAT16 output supported for BFLOAT16 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32,
                                     "Only U32 output supported for U8 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
@@ -1030,7 +1040,14 @@
 #endif /* __aarch64__ */
 #if defined(ARM_COMPUTE_ENABLE_BF16)
         case DataType::BFLOAT16:
-            create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+            if (d->data_type() == DataType::BFLOAT16)
+            {
+                create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
+            }
+            else
+            {
+                create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+            }
             break;
 #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC