[arm_gemm] Use static validate to find arm_gemm kernels.

The static method `CpuGemmAssemblyDispatch::validate` should look into
the list of the available kernels to make sure the one requested by
the user was found.

Formatting changes in the files touched by the patch have been
automatically inserted by the formatting script.

Resolves: ONCPUML-840

Change-Id: Icd650a30e142284a942c64f8a2b72441ee7b3f4e
Signed-off-by: Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7375
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index f4af587..dd72fb5 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -144,6 +144,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index a502262..42f4528 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -108,6 +108,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 8b855ab..69a2803 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -232,6 +232,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 4d7f798..cb3ff7a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -236,6 +236,12 @@
 }
 
 template<typename Top, typename Tret, class OutputStage>
+bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) { // Query only: returns whatever find_implementation() reports for these arguments.
+    const GemmImplementation<Top, Tret, OutputStage> *impl; // Passed to find_implementation(); its value is not used in this function.
+    return find_implementation<Top, Tret, OutputStage>(args, os, impl);
+}
+
+template<typename Top, typename Tret, class OutputStage>
 UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) {
     const GemmImplementation<Top, Tret, OutputStage> *impl;
 
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index d650116..3915861 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -56,6 +56,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<int16_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index a113455..0c68e4d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -159,6 +159,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 1532816..6b813c7 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -230,6 +230,7 @@
 }
 
 template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
 template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index a80766b..95139c2 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -197,6 +197,7 @@
 }
 
 template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
 template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index d459df8..20cee55 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -56,6 +56,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index f2d46d5..a2d2cc8 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -157,6 +157,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index e38cc09..200e04f 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -187,4 +187,7 @@
 template <typename Top, typename Tret, class OutputStage = Nothing>
 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
 
+template <typename Top, typename Tret, class OutputStage = Nothing>
+bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {});
+
 } // namespace arm_gemm
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 657f3b8..496b55e 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -156,8 +156,8 @@
                                                                                             const std::vector<int32_t> &multipliers);
 
     // Inherited methods overridden:
-    void run(ITensorPack &tensors) override;
-    void prepare(ITensorPack &tensors) override;
+    void                             run(ITensorPack &tensors) override;
+    void                             prepare(ITensorPack &tensors) override;
     bool                             is_configured() const override;
     experimental::MemoryRequirements workspace() const override;
 
@@ -203,12 +203,12 @@
     /** Indirect buffer */
     std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
     std::unique_ptr<const TypeInput *, free_delete>        _indirect_buf{};
-    std::vector<TypeInput>           _indirect_pad{};
-    arm_gemm::ConvolutionParameters  _cp{};
-    experimental::MemoryRequirements _aux_mem{ Count };
-    bool                             _B_pretranspose_required{ false };
-    bool                             _is_b_constant{ true };
-    bool                             _is_c_constant{ true };
+    std::vector<TypeInput>                                 _indirect_pad{};
+    arm_gemm::ConvolutionParameters                        _cp{};
+    experimental::MemoryRequirements                       _aux_mem{ Count };
+    bool                                                   _B_pretranspose_required{ false };
+    bool                                                   _is_b_constant{ true };
+    bool                                                   _is_c_constant{ true };
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -635,6 +635,72 @@
 {
 }
 
+Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+    ARM_COMPUTE_UNUSED(c);
+    arm_gemm::Activation act         = assembly_utils::map_to_arm_gemm_activation(info.activation_info);
+    Params               p           = extract_parameters(a, b, d, info);
+    const CPUInfo       &ci          = NEScheduler::get().cpu_info();
+    unsigned int         num_threads = NEScheduler::get().num_threads();
+    // Describe the GEMM problem to arm_gemm; every branch below only queries for a kernel via has_opt_gemm().
+    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fast_mode);
+    switch(a->data_type())
+    {
+        case DataType::F32:
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(args, {})),
+                                            "We could not find an optimized kernel for F32 input");
+            break;
+#ifdef __aarch64__
+        case DataType::U8:
+        case DataType::QASYMM8:
+            if(d->data_type() == DataType::S32) // 32-bit output: query the kernels without an output stage
+            {
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(args, {})),
+                                                "We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
+            }
+            else // any other output type: query the requantizing (Requantize32) kernels
+            {
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(args, {})),
+                                                "We could not find an optimized kernel for U8 input and U8 output");
+            }
+            break;
+        case DataType::S8:
+        case DataType::QASYMM8_SIGNED:
+            if(d->data_type() == DataType::S32) // 32-bit output: query the kernels without an output stage
+            {
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(args, {})),
+                                                "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
+            }
+            else // any other output type: query the requantizing (Requantize32) kernels
+            {
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(args, {})),
+                                                "We could not find an optimized kernel for S8 input and S8 output");
+            }
+            break;
+#endif /* __aarch64__ */
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+        case DataType::BFLOAT16:
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(args, {})),
+                                            "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+            break;
+        }
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(args, {})),
+                                            "We could not find an optimized kernel for F16 input and F16 output");
+            break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+        default:
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Unsupported type. Could not find a kernel");
+            break;
+    }
+
+    return Status{};
+}
+
 Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
 {
     ARM_COMPUTE_UNUSED(c, info);
@@ -663,7 +729,7 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
-    return Status{};
+    return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info);
 }
 
 bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index a50f363..74359ee 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -68,11 +68,11 @@
     class IFallback
     {
     public:
-        virtual void run(ITensorPack &tensors)                         = 0;
-        virtual void prepare(ITensorPack &tensors)                     = 0;
-        virtual experimental::MemoryRequirements workspace() const     = 0;
-        virtual bool                             is_configured() const = 0;
-        virtual ~IFallback()                                           = default;
+        virtual void                             run(ITensorPack &tensors)     = 0;
+        virtual void                             prepare(ITensorPack &tensors) = 0;
+        virtual experimental::MemoryRequirements workspace() const             = 0;
+        virtual bool                             is_configured() const         = 0;
+        virtual ~IFallback()                                                   = default;
     };
 
 public:
@@ -97,6 +97,18 @@
      * @return a status.
      */
     static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
+
+    /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
+     *
+     * @param[in] a    Input tensor info (Matrix A). Its data type selects which family of assembly kernels is queried.
+     * @param[in] b    Input tensor info (Matrix B)
+     * @param[in] c    Input tensor info (Matrix C) used to pass the bias for quantized calculations. Currently not used by this check.
+     * @param[in] d    Output tensor info (Matrix D) for the result of the matrix multiplication. For 8-bit inputs, an S32 data type selects the non-requantizing kernels.
+     * @param[in] info GEMM meta-data
+     *
+     * @return a default-constructed status if an optimized kernel was found, an error status otherwise (including unsupported input data types).
+     */
+    static Status has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
     /** Checks if activation is supported by the gemm assembly dispatcher
      *
      * @param[in] activation Activation to check
@@ -111,8 +123,8 @@
     bool is_configured() const;
 
     // Inherited methods overridden:
-    void prepare(ITensorPack &tensors) override;
-    void run(ITensorPack &tensors) override;
+    void                             prepare(ITensorPack &tensors) override;
+    void                             run(ITensorPack &tensors) override;
     experimental::MemoryRequirements workspace() const override;
 
 private: