COMPMID-1419: Make NEGEMMAssemblyDispatch dynamically typed instead of templated

This makes it easier to integrate into GEMMLowpMatrixMultiplyCore

Change-Id: Ibf80803f016a2e6a24d943ffafb50b48f04ec545
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140868
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
index 1c9ecb0..382ef1c 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
@@ -35,7 +35,6 @@
 namespace arm_compute
 {
 /** Assembly kernel glue */
-template <typename TypeInput, typename TypeOutput>
 class NEGEMMAssemblyDispatch : public IFunction
 {
 public:
@@ -43,12 +42,21 @@
     NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
 
     /** Prevent instances of this class from being copy constructed */
-    NEGEMMAssemblyDispatch(const NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &) = delete;
+    NEGEMMAssemblyDispatch(const NEGEMMAssemblyDispatch &) = delete;
     /** Prevent instances of this class from being copied */
-    NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &operator=(const NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &) = delete;
-    NEGEMMAssemblyDispatch(NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &&) = default;
-    NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &operator=(NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &&) = default;
-    ~NEGEMMAssemblyDispatch() = default;
+    NEGEMMAssemblyDispatch &operator=(const NEGEMMAssemblyDispatch &) = delete;
+    NEGEMMAssemblyDispatch(NEGEMMAssemblyDispatch &&)                 = default;
+    NEGEMMAssemblyDispatch &operator=(NEGEMMAssemblyDispatch &&) = default;
+    ~NEGEMMAssemblyDispatch()                                    = default;
+    /** Interface for the arm_gemm fallback; NEGEMMAssemblyDispatch owns its implementation through a std::unique_ptr<IFallback> */
+    class IFallback
+    {
+    public:
+        virtual void run()                 = 0;       /**< Run the arm_gemm fallback */
+        virtual void prepare()             = 0;       /**< Prepare the arm_gemm fallback for execution */
+        virtual bool is_configured() const = 0;       /**< @return True if the fallback has been configured */
+        virtual ~IFallback()               = default; /**< Virtual so implementations are destroyed correctly when deleted through std::unique_ptr<IFallback> */
+    };
 
 private:
     /** ACL Function */
@@ -68,53 +76,9 @@
      */
     bool create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint);
 
-    //Fallback: use arm_gemm's AssemblyGemm:
-    class Fallback
-    {
-#ifndef DOXYGEN_SKIP_THIS
-    public:
-        /** Configures the arrays pointers and strides in the assembly kernel and executes the assembly kernel.
-         *  The call to set_arrays is needed to deal with the input sizes containing batches (dims > 2)
-         */
-        void run();
-        void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group);
-        void prepare();
-        bool is_configured() const;
-#endif /* DOXYGEN_SKIP_THIS */
-
-    private:
-        /** Allocate a workspace tensor.
-         *
-         * @param[in] workspace_size Size to allocate.
-         * @param[in] memory_group   Tensor memory group.
-         * @param[in] alignment      Workspace memory alignment.
-         */
-        void allocate_workspace(size_t workspace_size, MemoryGroup *memory_group, size_t alignment);
-
-        /** Assembly Gemm kernel */
-        std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
-        /** Optimised NEON kernel */
-        std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
-        /** Input A */
-        const ITensor *_a
-        {
-            nullptr
-        };
-        /** Input B */
-        const ITensor *_b
-        {
-            nullptr
-        };
-        /** Output */
-        ITensor *_d{ nullptr };
-        /** GEMM workspace */
-        Tensor _workspace{};
-        /** Pre-transpose tensor */
-        Tensor _pretranspose{};
-        /** Prepared flag */
-        bool _is_prepared{ false };
-    } _arm_gemm;               /**< Fallback in case ACL doesn't have a function */
-    MemoryGroup _memory_group; /**< Function memory group */
+    /** Interface for the arm_gemm fallback */
+    std::unique_ptr<IFallback> _arm_gemm;
+    MemoryGroup                _memory_group; /**< Function memory group */
 public:
     /** If supported create an ACL function else fallback to the arm_gemm function.
      *
@@ -126,6 +90,19 @@
      * @param[in]  pretranspose_hint Can the B tensor can be pretransposed (ie shared across invocations)?
      */
     void configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint);
+
+    /** Indicates whether or not this function can be used to process the given parameters.
+     *
+     * @param[in] a                 Input tensor info (Matrix A). Data type supported: F32/U8/S8.
+     * @param[in] b                 Input tensor info (Matrix B). Data type supported: same as @p a.
+     * @param[in] d                 Output tensor info to store the result of matrix multiplication. Data type supported: F32 if @p a is F32, U32 if @p a is U8, S32 if @p a is S8.
+     * @param[in] alpha             Scalar multiplier to apply to AB matrix product.
+     * @param[in] beta              Scalar multiplier to apply to input D matrix before adding product.
+     * @param[in] pretranspose_hint Can the B tensor be pretransposed (ie shared across invocations)?
+     *
+     * @return a status.
+     */
+    static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint);
     /** Was the function successfully configured ?
      *
      * @return True if the function is configured and ready to run
@@ -137,11 +114,5 @@
     void run() override;
 };
 
-/** Float 32 assembly dispatch kernel */
-using NEGEMMAssemblyDispatchF32 = NEGEMMAssemblyDispatch<float, float>;
-/** Uint 8 to Uint 32 assembly dispatch kernel */
-using NEGEMMAssemblyDispatchU8U32 = NEGEMMAssemblyDispatch<uint8_t, uint32_t>;
-/** Int 8 to Int 32 assembly dispatch kernel */
-using NEGEMMAssemblyDispatchS8S32 = NEGEMMAssemblyDispatch<int8_t, int32_t>;
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_NEGEMMASSEMBLYDISPATCH_H__ */