COMPMID-1406: Refactor gemm_interleaved to use our own types and scheduler

- Ported PrepareB kernel from gemm_interleaved
- Ported TransformA feature from gemm_interleaved
- Allocate reshaped a and b buffers
- Added memory_manager / memory_group
- MatrixMultiply kernel
- Interleave kernels execution.
- Fixed a few bugs: all nightly Convolution tests passing for threads=1
and threads=4
- Added Doxygen documentation and comments in the code
- Added support for all supported data types (F32, F16, U8/QASYMM8, S8)

Change-Id: Iffa1c09fda0bb9c61213bb83524d5a48e7ecb03c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141281
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index f17da7d..8ba620f 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -24,9 +24,13 @@
 #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
 
 #include "arm_compute/core/CPP/Validate.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
 #include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 #include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
+#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
 
 #include <arm_neon.h>
 
@@ -34,31 +38,96 @@
 {
 namespace
 {
-template <typename TypeInput, typename TypeOutput>
-std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                     std::shared_ptr<IMemoryManager> memory_manager)
+
 {
-    ARM_COMPUTE_UNUSED(method);
-    ARM_COMPUTE_UNUSED(a);
-    ARM_COMPUTE_UNUSED(b);
-    ARM_COMPUTE_UNUSED(d);
-    ARM_COMPUTE_UNUSED(alpha);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(pretranspose_hint);
-    return nullptr;
-}
-template <>
-std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
-{
-    ARM_COMPUTE_UNUSED(method);
-    ARM_COMPUTE_UNUSED(a);
-    ARM_COMPUTE_UNUSED(b);
-    ARM_COMPUTE_UNUSED(d);
-    ARM_COMPUTE_UNUSED(alpha);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
     switch(method)
     {
+        case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
+        {
+            if(!pretranspose_hint)
+            {
+                return nullptr;
+            }
+            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+            function->configure(a, b, d, alpha, beta, pretranspose_hint);
+            return std::move(function);
+        }
+        default:
+            return nullptr;
+    }
+}
+
+template <typename TypeInput, typename TypeOutput>
+std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                           std::shared_ptr<IMemoryManager> memory_manager)
+{
+    ARM_COMPUTE_UNUSED(method);
+    ARM_COMPUTE_UNUSED(a);
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(beta);
+    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    ARM_COMPUTE_UNUSED(memory_manager);
+    return nullptr;
+}
+
 #ifdef __aarch64__
+template <>
+std::unique_ptr<IFunction> create_function<int8_t, int32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                            std::shared_ptr<IMemoryManager> memory_manager)
+{
+    switch(method)
+    {
+        case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+        {
+            if(!pretranspose_hint)
+            {
+                return nullptr;
+            }
+            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+            function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+            return std::move(function);
+        }
+        default:
+            return nullptr;
+    }
+    return nullptr;
+}
+
+template <>
+std::unique_ptr<IFunction> create_function<uint8_t, uint32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                              std::shared_ptr<IMemoryManager> memory_manager)
+{
+    switch(method)
+    {
+        case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+        {
+            if(!pretranspose_hint)
+            {
+                return nullptr;
+            }
+            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+            function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+            return std::move(function);
+        }
+        default:
+            return nullptr;
+    }
+    return nullptr;
+}
+
+template <>
+std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                         std::shared_ptr<IMemoryManager> memory_manager)
+{
+    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    ARM_COMPUTE_UNUSED(memory_manager);
+    switch(method)
+    {
         case arm_gemm::GemmMethod::GEMM_NATIVE:
         {
             auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
@@ -67,11 +136,11 @@
             function->configure(std::move(kernel));
             return std::move(function);
         }
-#endif /* __aarch64__ */
         default:
             return nullptr;
     }
 }
+#endif /* __aarch64__ */
 
 /** Fallback in case ACL doesn't have a function */
 template <typename TypeInput, typename TypeOutput>
@@ -173,11 +242,11 @@
         // Pretranspose B if required
         if(_gemm_kernel_asm->B_pretranspose_required())
         {
+            ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
             const int  ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
             const auto in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
             const int  multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
 
-            ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
             _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
             _b->mark_as_unused();
         }
@@ -260,7 +329,7 @@
 
 template <typename TypeInput, typename TypeOutput>
 void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
-                                 ITensor *d, float alpha, float beta, bool pretranspose_hint)
+                                 ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
 {
     INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d);
     const CPUInfo               &ci          = NEScheduler::get().cpu_info();
@@ -269,7 +338,13 @@
     arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
 
     //Try to create an ACL function:
-    acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint);
+    acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
+    // If the type agnostic factory failed to create an ACL function, try the specialised one:
+    if(acl_function == nullptr)
+    {
+        acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
+    }
+    //If we still don't have an ACL function:
     if(acl_function == nullptr)
     {
         //Fallback onto arm_gemm function if ACL doesn't support this method.
@@ -282,7 +357,7 @@
 } //namespace
 
 NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
-    : _function(nullptr), _arm_gemm(nullptr), _memory_group(std::move(memory_manager))
+    : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
 {
 }
 
@@ -321,20 +396,20 @@
     switch(a->info()->data_type())
     {
         case DataType::F32:
-            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
-            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
             break;
         case DataType::S8:
-            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
             break;
 #endif /* __aarch64__ */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default:
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
new file mode 100644
index 0000000..434723c
--- /dev/null
+++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
@@ -0,0 +1,260 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+namespace arm_compute
+{
+NEGEMMInterleavedWrapper::NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager)
+    : _memory_group(std::move(memory_manager))
+{
+}
+void NEGEMMInterleavedWrapper::run()
+{
+    prepare();
+
+    _memory_group.acquire();
+    NEScheduler::get().run_workloads(_workloads);
+    _memory_group.release();
+}
+
+void NEGEMMInterleavedWrapper::prepare()
+{
+    if(!_is_prepared)
+    {
+        if(_pretranspose_b)
+        {
+            NEScheduler::get().schedule(_prepare_b.get(), Window::DimX);
+            _b->mark_as_unused();
+        }
+        else
+        {
+            _prepare_b->create_workloads(_b_workloads);
+        }
+        _transform_a->create_workloads(_a_workloads);
+        _matrix_multiply->create_workloads(_mm_workloads);
+
+        //Maximum number of workloads to create:
+        const unsigned int num_threads    = NEScheduler::get().num_threads();
+        const unsigned int max_iterations = num_threads == 1 ? 1 : num_threads * 4;
+        //Maximum number of iterations the parameters allow:
+        const unsigned int num_iterations = _batch_window.num_iterations_total();
+        // Keep the smallest of the two:
+        const unsigned int num_windows  = std::min(num_iterations, max_iterations);
+        const TensorShape  window_shape = _batch_window.shape();
+
+        // Create a 1D window to dynamically split the batch window:
+        Window win_1D;
+        win_1D.set(0, Window::Dimension(0, num_iterations));
+
+        // Create one workload for each sub-window:
+        for(unsigned int w = 0; w < num_windows; w++)
+        {
+            Window             win          = win_1D.split_window(0, w, num_windows);
+            const Coordinates  start_offset = index2coords(window_shape, win.x().start());
+            const Coordinates  end_offset   = index2coords(window_shape, win.x().end() - 1);
+            const unsigned int num_x_blocks = _block_walker.num_iterations(Window::DimX);
+
+            auto workload = [start_offset, end_offset, num_x_blocks, this](const ThreadInfo & info)
+            {
+                //For each block of rows in "M"
+                auto workload_mm = this->_mm_workloads.begin();
+                for(auto workload_a = this->_a_workloads.begin(); workload_a != this->_a_workloads.end(); workload_a++)
+                {
+                    // Transform one k_block from A:
+                    this->_transform_a->transform(*workload_a, info, this->_batch_window, start_offset, end_offset);
+                    // Then perform the matrix multiplication for each x block along N:
+                    for(unsigned int i = 0; i < num_x_blocks; i++)
+                    {
+                        ARM_COMPUTE_ERROR_ON(workload_mm == this->_mm_workloads.end());
+                        this->_matrix_multiply->transform(*workload_mm++, info, this->_batch_window, start_offset, end_offset);
+                    }
+                }
+            };
+            _workloads.push_back(workload);
+        }
+
+        _is_prepared = true;
+    }
+}
+
+namespace
+{
+// Factory to instantiate NEGEMMInterleavedPrepareBWrapperKernelTemplate:
+template <typename InputType, bool use_dot = false>
+std::unique_ptr<NEGEMMInterleavedPrepareBWrapperKernel> instantiate_prepareB(const ITensor *b, ITensor *transformed_b, const INEGEMMWrapperKernel::Params &params)
+{
+    auto prepare_b = support::cpp14::make_unique<NEGEMMInterleavedPrepareBWrapperKernelTemplate<InputType, use_dot>>();
+    prepare_b->configure(b, transformed_b, false, NEScheduler::get().cpu_info(), params);
+    return std::move(prepare_b);
+}
+
+// Factory to instantiate NEGEMMInterleavedTransformAWrapperTemplate:
+template <typename InputType, bool use_dot = false>
+std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a, ITensor *transformed_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
+{
+    auto transform_a = support::cpp14::make_unique<NEGEMMInterleavedTransformAWrapperTemplate<InputType, use_dot>>();
+    transform_a->configure(a, transformed_a, false, block_walker, params);
+    return std::move(transform_a);
+}
+
+// Factory to instantiate NEGEMMInterleavedMatrixMultiplyWrapperTemplate:
+template <typename InputType, typename OutputType, bool use_dot = false>
+std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker,
+                                                                                    const BlockSizes &block_sizes, const INEGEMMWrapperKernel::Params &params, bool pretranspose_b, float alpha, float beta)
+{
+    auto matrix_multiply = support::cpp14::make_unique<NEGEMMInterleavedMatrixMultiplyWrapperTemplate<InputType, OutputType, use_dot>>();
+    matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, NEScheduler::get().num_threads());
+    return std::move(matrix_multiply);
+}
+} // namespace
+
+void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b, bool use_dot)
+{
+    _params         = INEGEMMWrapperKernel::extract_parameters(a, b, c);
+    _a              = a;
+    _b              = b;
+    _c              = c;
+    _pretranspose_b = pretranspose_b;
+
+    DataType input_type = a->info()->data_type();
+
+    // Forcing 128-byte alignment (required by 32-bit kernels)
+    const unsigned int alignment = 128;
+    _transformed_b.allocator()->init(TensorInfo{}, alignment);
+    _tmp_c.allocator()->init(TensorInfo{}, alignment);
+    if(!_pretranspose_b)
+    {
+        // If B is transposed at every iteration then transformed_B can be managed:
+        _memory_group.manage(&_transformed_b);
+    }
+    switch(input_type)
+    {
+        case DataType::F32:
+            _prepare_b = instantiate_prepareB<float>(_b, &_transformed_b, _params);
+            break;
+#ifdef __aarch64__
+        case DataType::U8:
+        case DataType::QASYMM8:
+            if(use_dot)
+            {
+                _prepare_b = instantiate_prepareB<uint8_t, true>(_b, &_transformed_b, _params);
+            }
+            else
+            {
+                _prepare_b = instantiate_prepareB<uint8_t, false>(_b, &_transformed_b, _params);
+            }
+            break;
+        case DataType::S8:
+            if(use_dot)
+            {
+                _prepare_b = instantiate_prepareB<int8_t, true>(_b, &_transformed_b, _params);
+            }
+            else
+            {
+                _prepare_b = instantiate_prepareB<int8_t, false>(_b, &_transformed_b, _params);
+            }
+            break;
+#endif /* __aarch64__ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+            _prepare_b = instantiate_prepareB<__fp16>(_b, &_transformed_b, _params);
+            break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+        default:
+            ARM_COMPUTE_ERROR("DataType not supported");
+            break;
+    }
+    ARM_COMPUTE_ERROR_ON(_prepare_b == nullptr);
+
+    _block_sizes = _prepare_b->block_sizes();
+
+    _block_walker.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_params.N, _block_sizes.x_block), _block_sizes.x_block));
+    _block_walker.set(Window::DimY, Window::Dimension(0, ceil_to_multiple(_params.K, _block_sizes.k_block), _block_sizes.k_block));
+    _block_walker.set(Window::DimZ, Window::Dimension(0, _params.multis));
+
+    _batch_window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_block_sizes.m_round, _block_sizes.strategy_out_height), _block_sizes.strategy_out_height));
+    _batch_window.set(Window::DimY, Window::Dimension(0, _params.batches));
+
+    _transformed_a.allocator()->init(TensorInfo(TensorShape{ _block_sizes.k_block, _block_sizes.m_round, _params.batches }, 1, input_type), alignment);
+    _memory_group.manage(&_transformed_a);
+    _memory_group.manage(&_tmp_c);
+
+    switch(input_type)
+    {
+        case DataType::F32:
+            _transform_a     = instantiate_transformA<float>(_a, &_transformed_a, _block_walker, _params);
+            _matrix_multiply = instantiate_matrix_multiply<float, float>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            break;
+#ifdef __aarch64__
+        case DataType::U8:
+        case DataType::QASYMM8:
+            if(use_dot)
+            {
+                _transform_a     = instantiate_transformA<uint8_t, true>(_a, &_transformed_a, _block_walker, _params);
+                _matrix_multiply = instantiate_matrix_multiply<uint8_t, uint32_t, true>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            }
+            else
+            {
+                _transform_a     = instantiate_transformA<uint8_t, false>(_a, &_transformed_a, _block_walker, _params);
+                _matrix_multiply = instantiate_matrix_multiply<uint8_t, uint32_t, false>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            }
+            break;
+        case DataType::S8:
+            if(use_dot)
+            {
+                _transform_a     = instantiate_transformA<int8_t, true>(_a, &_transformed_a, _block_walker, _params);
+                _matrix_multiply = instantiate_matrix_multiply<int8_t, int32_t, true>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            }
+            else
+            {
+                _transform_a     = instantiate_transformA<int8_t, false>(_a, &_transformed_a, _block_walker, _params);
+                _matrix_multiply = instantiate_matrix_multiply<int8_t, int32_t, false>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            }
+            break;
+#endif /* __aarch64__ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+            _transform_a     = instantiate_transformA<__fp16>(_a, &_transformed_a, _block_walker, _params);
+            _matrix_multiply = instantiate_matrix_multiply<__fp16, __fp16>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
+            break;
+            break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+        default:
+            break;
+    }
+    ARM_COMPUTE_ERROR_ON(_transform_a == nullptr);
+    ARM_COMPUTE_ERROR_ON(_matrix_multiply == nullptr);
+    _transformed_a.allocator()->allocate();
+    _tmp_c.allocator()->allocate();
+    _transformed_b.allocator()->allocate();
+}
+} // namespace arm_compute