COMPMID-3520: Move ndrange.hpp header from arm_gemm to assembly

Change-Id: I6352a520ce38230cdfbad346b176cb659ab242a7
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3327
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/assembly/Helpers.cpp b/src/core/NEON/kernels/assembly/Helpers.cpp
index 93ea6c8..5990505 100644
--- a/src/core/NEON/kernels/assembly/Helpers.cpp
+++ b/src/core/NEON/kernels/assembly/Helpers.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -22,7 +22,7 @@
  * SOFTWARE.
  */
 
-#include "arm_compute/core/NEON/kernels/assembly/Helpers.h"
+#include "src/core/NEON/kernels/assembly/Helpers.h"
 
 namespace arm_compute
 {
diff --git a/src/core/NEON/kernels/assembly/Helpers.h b/src/core/NEON/kernels/assembly/Helpers.h
new file mode 100644
index 0000000..09c0446
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/Helpers.h
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2018-2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_ASSEMBLY_HELPERS_H
+#define ARM_COMPUTE_ASSEMBLY_HELPERS_H
+
+#include "arm_compute/core/CPP/CPPTypes.h"
+#include "arm_compute/core/Utils.h"
+
+#include "arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h"
+#include "arm_gemm.hpp"
+
+namespace arm_compute
+{
+/** Block sizes to use to break the M, N and K dimensions */
+struct BlockSizes
+{
+    unsigned int k_block{ 0 };             /**< Block size along the K dimension */
+    unsigned int x_block{ 0 };             /**< Block size along the N (x) dimension */
+    unsigned int m_round{ 0 };             /**< Block size along the M dimension (Must be a multiple of strategy_out_height) */
+    unsigned int strategy_out_height{ 0 }; /**< Number of rows (M) processed by the selected strategy */
+};
+
+/** Extracts the kernel description of the kernel selected by the GEMM backend heuristics
+ *
+ * @param[in] input_type        Data type of the input tensor.
+ * @param[in] ci                CPU information.
+ * @param[in] num_threads       Maximum number of threads that might be used for the calculations.
+ * @param[in] p                 M, N, K sizes.
+ * @param[in] activation        Activation struct
+ * @param[in] pretranspose_hint Whether B is also pretransposed.
+ *
+ * @return Kernel description that the assembly heuristics picked for the given configuration
+ */
+arm_gemm::KernelDescription get_gemm_info(DataType                            input_type,
+                                          const CPUInfo                      &ci,
+                                          const unsigned int                  num_threads,
+                                          const INEGEMMWrapperKernel::Params &p,
+                                          arm_gemm::Activation                activation,
+                                          bool                                pretranspose_hint);
+
+/** Calculate the recommended block sizes to use based on the CPU cache sizes and the strategy which will be used
+ *
+ * @param[in] ci CPU information.
+ * @param[in] M  M dimension.
+ * @param[in] N  N dimension.
+ * @param[in] K  K dimension.
+ *
+ * @return Recommended block sizes to use for the given M, N, K dimensions.
+ */
+template <typename strategy>
+BlockSizes calculate_block_sizes(const CPUInfo &ci, unsigned int M, unsigned int N, unsigned int K)
+{
+    BlockSizes bs;
+
+    using Toi = typename strategy::operand_type;
+
+    const unsigned int L1_size = ci.get_L1_cache_size();
+    const unsigned int L2_size = ci.get_L2_cache_size();
+
+    // Work out blocking parameters
+
+    // k_block: Find out how much of the larger array can be loaded into half the cache.
+    // This should account for associative caches.
+    bs.k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));
+
+    // Needs to be (at least a single) multiple of the K unroll level.
+    bs.k_block /= strategy::k_unroll();
+    bs.k_block = std::max(bs.k_block, 1U) * strategy::k_unroll();
+
+    // Now tune to presented problem size; this is how many blocks we need.
+    int num_k_blocks = DIV_CEIL(K, bs.k_block);
+
+    // So divide the space equally into that many blocks.
+    bs.k_block = DIV_CEIL(K, num_k_blocks);
+
+    // And round UP to the K unroll level required.
+    bs.k_block = ceil_to_multiple(bs.k_block, strategy::k_unroll());
+
+    // x_block: Work out how many rows (of length k_block) will fit in the L2
+    // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents.
+    bs.x_block = (((L2_size * 9) / 10) - (bs.k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) / (sizeof(Toi) * bs.k_block);
+
+    // Needs to be (at least a single) multiple of the kernel output width.
+    bs.x_block /= strategy::out_width();
+    bs.x_block = std::max(bs.x_block, 1U) * strategy::out_width();
+
+    // And tune to the presented problem size.
+    int num_x_blocks = DIV_CEIL(N, bs.x_block);
+    bs.x_block       = DIV_CEIL(N, num_x_blocks);
+
+    bs.x_block = ceil_to_multiple(bs.x_block, strategy::out_width());
+
+    // Work out the rounded size of M - needed for some buffers.
+    bs.m_round             = ceil_to_multiple(M, strategy::out_height());
+    bs.strategy_out_height = strategy::out_height();
+
+    return bs;
+}
+
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_ASSEMBLY_HELPERS_H */
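For reference, a minimal sketch of the contract that calculate_block_sizes() expects from its strategy template parameter. This is not part of the patch; ExampleStrategy and example_blocking() are hypothetical names used purely to show which members are read (the real strategies live inside arm_gemm).

#include "src/core/NEON/kernels/assembly/Helpers.h"

// Hypothetical strategy: only the members read by calculate_block_sizes() are shown.
struct ExampleStrategy
{
    using operand_type = float;                       // Toi: element type used when sizing cache blocks
    static unsigned int out_width()  { return 12; }   // columns produced per kernel iteration
    static unsigned int out_height() { return 8; }    // rows produced per kernel iteration
    static unsigned int k_unroll()   { return 1; }    // K values consumed per inner-loop iteration
};

void example_blocking(const arm_compute::CPUInfo &ci)
{
    // Blocking for a 1024x512x256 problem: k_block is rounded to a multiple of
    // k_unroll(), x_block to a multiple of out_width(), m_round to out_height().
    arm_compute::BlockSizes bs = arm_compute::calculate_block_sizes<ExampleStrategy>(ci, 1024, 512, 256);
    (void)bs;
}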
diff --git a/src/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h b/src/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h
new file mode 100644
index 0000000..2d3d805
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2018-2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H
+#define ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H
+
+#include "arm_compute/core/NEON/INEKernel.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_gemm_compute_iface.hpp"
+
+#include "gemm_common.hpp"
+
+namespace arm_compute
+{
+class ITensor;
+
+/** This class is a wrapper for the assembly kernels.
+  *
+  * Some kernels are written in assembly and highly optimised for specific CPUs like the Cortex-A53 or Cortex-A55.
+  * This class works as a wrapper for these assembly kernels. The Compute Library creates an instance
+  * of NEGEMMAssemblyWrapperKernel and other auxiliary data structures to execute a single assembly kernel
+  * in the context of an NEFunction.
+  *
+  * The wrapped kernel is of type
+  *         template<typename To, typename Tr> class GemmCommon
+  * where To and Tr correspond to the TypeInput and TypeOutput template parameters of this class.
+  *
+  */
+template <typename TypeInput, typename TypeOutput>
+class NEGEMMAssemblyWrapperKernel final : public INEKernel
+{
+public:
+    /** Constructor
+     */
+    NEGEMMAssemblyWrapperKernel()
+        : _kernel(nullptr), _name("NEGEMMAssemblyWrapperKernel")
+    {
+    }
+
+    NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &)  = delete;
+    NEGEMMAssemblyWrapperKernel(NEGEMMAssemblyWrapperKernel &&) = default;
+    NEGEMMAssemblyWrapperKernel &operator=(NEGEMMAssemblyWrapperKernel &) = delete;
+
+    const char *name() const override
+    {
+        return _name.c_str();
+    }
+
+    void run(const Window &window, const ThreadInfo &info) override
+    {
+        ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel)));
+        ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+
+        auto win = arm_gemm::to_ndcoord(window);
+
+        arm_gemm::ndcoord_t thread_locator{};
+
+        _kernel->execute(win, thread_locator, info.thread_id);
+    }
+
+    // Inherited methods overridden:
+    void run_nd(const Window &window, const ThreadInfo &info, const Window &thread_locator) override
+    {
+        ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(_kernel)));
+        ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+
+        //convert between arm_compute and arm_gemm types
+        auto ndc_win = arm_gemm::to_ndcoord(window);
+        auto ndc_tlc = arm_gemm::to_ndcoord(thread_locator);
+
+        _kernel->execute(ndc_win, ndc_tlc, info.thread_id);
+    }
+
+    /** Initialise the kernel's assembly implementation and name.
+     *
+     * @param[in] kernel          Pointer to an assembly kernel implementation.
+     * @param[in] kernel_name_tag Tag to be appended to the kernel's name.
+     */
+    void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag)
+    {
+        ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel)));
+        _kernel = kernel;
+
+        Window win = to_window(kernel->get_window_size());
+
+        INEKernel::configure(win);
+
+        if(!kernel_name_tag.empty())
+        {
+            _name += "/" + kernel_name_tag;
+        }
+    }
+
+private:
+    arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel;
+    std::string _name;
+};
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H */
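As a usage note (not part of the patch), the wrapper is configured with an already-created arm_gemm kernel; the tag string below is purely illustrative.

#include "src/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h"

void wrap_example(arm_gemm::GemmCommon<float, float> *assembly_kernel)
{
    arm_compute::NEGEMMAssemblyWrapperKernel<float, float> wrapper;

    // The tag is appended to the wrapper's name ("NEGEMMAssemblyWrapperKernel/<tag>").
    wrapper.configure(assembly_kernel, "example_tag");

    // configure() sets the kernel window from assembly_kernel->get_window_size();
    // a scheduler can now split that window and call run()/run_nd() on each piece.
}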
diff --git a/src/core/NEON/kernels/assembly/arm_gemm.hpp b/src/core/NEON/kernels/assembly/arm_gemm.hpp
new file mode 100644
index 0000000..7723224
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -0,0 +1,176 @@
+/*
+ * Copyright (c) 2018-2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#include <memory>
+#include <cstring>
+
+#include "arm_gemm_local.hpp"
+#include "gemm_common.hpp"
+
+namespace arm_gemm {
+
+enum class GemmMethod
+{
+    DEFAULT,
+    GEMV_BATCHED,
+    GEMV_PRETRANSPOSED,
+    GEMV_NATIVE_TRANSPOSED,
+    GEMM_NATIVE,
+    GEMM_HYBRID,
+    GEMM_INTERLEAVED,
+    GEMM_INTERLEAVED_2D,
+    QUANTIZE_WRAPPER,
+    GEMM_HYBRID_QUANTIZED
+};
+
+struct KernelDescription
+{
+    GemmMethod   method      = GemmMethod::DEFAULT;
+    std::string  name        = "";
+    bool         is_default  = false;
+
+    KernelDescription(GemmMethod m, std::string n, bool d=false) : method(m), name(n), is_default(d) { }
+    KernelDescription() noexcept  { }
+};
+
+struct GemmConfig
+{
+    GemmMethod   method           = GemmMethod::DEFAULT;
+    std::string  filter           = "";
+    unsigned int inner_block_size = 0;
+    unsigned int outer_block_size = 0;
+
+    GemmConfig(GemmMethod method) : method(method) { }
+    GemmConfig() { }
+};
+
+struct Activation
+{
+    enum class Type {
+        None,
+        ReLU,
+        BoundedReLU
+    };
+
+    Type    type;
+    float   param1;
+    float   param2;
+
+    Activation(Type type=Type::None, float p1=0.0f, float p2=0.0f) : type(type), param1(p1), param2(p2) { }
+};
+
+struct GemmArgs
+{
+public:
+    const CPUInfo    *_ci;
+    unsigned int      _Msize;
+    unsigned int      _Nsize;
+    unsigned int      _Ksize;
+    unsigned int      _nbatches;
+    unsigned int      _nmulti;
+    bool              _trA;
+    bool              _trB;
+    Activation        _act;
+    int               _maxthreads;
+    bool              _pretransposed_hint;
+    const GemmConfig *_cfg;
+
+    GemmArgs(const CPUInfo *ci, const unsigned int M, const unsigned int N,
+             const unsigned int K, const unsigned int nbatches,
+             const unsigned int nmulti, const bool trA, const bool trB,
+             Activation act, const int maxthreads,
+             const bool pretransposed_hint, const GemmConfig *cfg=nullptr ) :
+             _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti),
+             _trA(trA), _trB(trB), _act(act), _maxthreads(maxthreads),
+             _pretransposed_hint(pretransposed_hint), _cfg(cfg)
+    {
+    }
+};
+
+struct Requantize32
+{
+public:
+    const int32_t  *bias = nullptr;
+    size_t          bias_multi_stride = 0;
+    int32_t         a_offset = 0;
+    int32_t         b_offset = 0;
+    int32_t         c_offset = 0;
+    bool            per_channel_requant = false;
+    int32_t         per_layer_shift = 0;
+    int32_t         per_layer_mul = 0;
+    const int32_t  *per_channel_shifts = nullptr;
+    const int32_t  *per_channel_muls = nullptr;
+    int32_t         minval = 0;
+    int32_t         maxval = 0;
+
+    Requantize32() = default;
+
+    // Constructor for per-tensor quantization
+    Requantize32(const int32_t *bias, size_t bias_multi_stride,
+                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
+                 int32_t requant_shift, int32_t requant_mul,
+                 int32_t minv, int32_t maxv) :
+        bias(bias), bias_multi_stride(bias_multi_stride),
+        a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
+        per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul),
+        minval(minv), maxval(maxv)
+    {
+    }
+
+    // Constructor for per-channel quantization
+    Requantize32(const int32_t *bias, size_t bias_multi_stride,
+                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
+                 const int32_t *requant_shifts, const int32_t *requant_muls,
+                 int32_t minv, int32_t maxv) :
+        bias(bias), bias_multi_stride(bias_multi_stride),
+        a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
+        per_channel_requant(true), per_channel_shifts(requant_shifts), per_channel_muls(requant_muls),
+        minval(minv), maxval(maxv)
+    {
+    }
+};
+
+struct Nothing
+{
+};
+
+template<typename Top, typename Tret>
+using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >;
+
+/* Low level API calls.
+ * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
+
+/* get_gemm_method(): Given the templated types and provided parameters,
+ * which is the preferred method to implement this GEMM?  */
+template<typename Top, typename Tret, class OutputStage = Nothing>
+KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & ={});
+
+template<typename Top, typename Tret, class OutputStage = Nothing>
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & ={});
+
+template<typename Top, typename Tret, class OutputStage = Nothing>
+std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & ={});
+
+} // namespace arm_gemm
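A hedged sketch (not part of the patch) of how these declarations fit together: describe the problem with GemmArgs, ask the heuristics which kernel they would pick, then request an implementation. It assumes arm_gemm::CPUInfo refers to arm_compute::CPUInfo (as set up by arm_gemm_local.hpp) and that the library provides instantiations for the chosen types.

#include "src/core/NEON/kernels/assembly/arm_gemm.hpp"

void query_example(const arm_compute::CPUInfo *ci)
{
    arm_gemm::Activation act(arm_gemm::Activation::Type::ReLU);

    // M=128, N=128, K=64, single batch, single multi, no transposes, 4 threads,
    // no pretranspose hint; the optional GemmConfig pointer is left at nullptr.
    arm_gemm::GemmArgs args(ci, 128, 128, 64, 1, 1, false, false, act, 4, false);

    arm_gemm::KernelDescription desc = arm_gemm::get_gemm_method<float, float>(args);
    arm_gemm::UniqueGemmCommon<float, float> gemm = arm_gemm::gemm<float, float>(args);
    (void)desc;
    (void)gemm;
}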
diff --git a/src/core/NEON/kernels/assembly/arm_gemm_compute_iface.hpp b/src/core/NEON/kernels/assembly/arm_gemm_compute_iface.hpp
new file mode 100644
index 0000000..ab3a67c
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/arm_gemm_compute_iface.hpp
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#include "arm_compute/core/Window.h"
+#include "arm_compute/core/Dimensions.h"
+
+#include "ndrange.hpp"
+
+#include <cassert>
+
+/* This file contains the mappings between types used in arm_compute and arm_gemm.
+ * These two codebases both require a degree of separation for the sake of modularity,
+ * so each maintains its own types which represent similar information.
+ */
+
+namespace arm_gemm {
+
+// We want to unify the maximum number of dimensions used between arm_gemm and the Compute Library
+constexpr std::size_t ndrange_max =
+    arm_compute::Dimensions<unsigned int>::num_max_dimensions;
+
+using ndrange_t=NDRange<ndrange_max>;
+using ndcoord_t=NDCoordinate<ndrange_max>;
+
+/* Converts an `arm_gemm::ndrange_t` to an `arm_compute::Window`
+ *
+ * As `NDRange<T>` does not encode start positions, we specify
+ * the start to be zero in the produced `arm_compute::Window`
+ *
+ * @param [ndr] the `arm_gemm::ndrange_t` we wish to convert into an `arm_compute::Window`
+ * @returns an `arm_compute::Window` representing the same dimensional ranges as `ndr`
+ */
+inline arm_compute::Window to_window(const ndrange_t& ndr) {
+    arm_compute::Window win;
+
+    for(unsigned int i = 0; i!=ndrange_max; ++i) {
+        //populate the window with the dimensions of the NDRange
+        win.set(i, arm_compute::Window::Dimension(0, ndr.get_size(i)));
+    }
+
+    return win;
+}
+
+/*
+ * Converts an `arm_gemm::ndcoord_t` to an `arm_compute::Window`
+ *
+ * @param [ndc] the `arm_gemm::ndcoord_t` we wish to convert into an `arm_compute::Window`
+ * @returns an `arm_compute::Window` representing the same dimensional ranges as `ndc`
+ */
+inline arm_compute::Window to_window(const ndcoord_t& ndc) {
+    arm_compute::Window win;
+
+    for(unsigned int i = 0; i!=ndrange_max; ++i) {
+        const auto start = ndc.get_position(i);
+        const auto size  = ndc.get_size(i);
+        const auto stop  = start + size;
+
+        //populate the window with the start and end of each dimension of the NDCoordinate
+        win.set(i, arm_compute::Window::Dimension(start, stop));
+    }
+
+    return win;
+}
+
+/** Convert an `arm_compute::Window` to an `arm_gemm::NDRange` of the same max dimensions
+ *
+ * It should be noted that `arm_compute::Window` specifies a `start()` and an `end()`,
+ * whereas `arm_gemm::ndrange_t` only has sizes, so we store the difference between end and start
+ *
+ * @param [win] the `arm_compute::Window` we want to convert to `arm_gemm::ndrange_t`
+ * @return the resultant ndrange_t
+ */
+inline ndrange_t to_ndrange(const arm_compute::Window& win) {
+    return {
+        static_cast<unsigned int>(win[0].end() - win[0].start()),
+        static_cast<unsigned int>(win[1].end() - win[1].start()),
+        static_cast<unsigned int>(win[2].end() - win[2].start()),
+        static_cast<unsigned int>(win[3].end() - win[3].start()),
+        static_cast<unsigned int>(win[4].end() - win[4].start()),
+        static_cast<unsigned int>(win[5].end() - win[5].start())
+    };
+}
+
+/** Convert an `arm_compute::Window` to an `arm_gemm::NDCoordinate` of the same max dimensions
+ *
+ * @param [win] the `arm_compute::Window` we want to convert to `arm_gemm::ndcoord_t`
+ * @return the resultant ndcoord_t
+ */
+inline ndcoord_t to_ndcoord(const arm_compute::Window& win) {
+    return {
+        { static_cast<unsigned int>(win[0].start()), static_cast<unsigned int>(win[0].end() - win[0].start()) },
+        { static_cast<unsigned int>(win[1].start()), static_cast<unsigned int>(win[1].end() - win[1].start()) },
+        { static_cast<unsigned int>(win[2].start()), static_cast<unsigned int>(win[2].end() - win[2].start()) },
+        { static_cast<unsigned int>(win[3].start()), static_cast<unsigned int>(win[3].end() - win[3].start()) },
+        { static_cast<unsigned int>(win[4].start()), static_cast<unsigned int>(win[4].end() - win[4].start()) },
+        { static_cast<unsigned int>(win[5].start()), static_cast<unsigned int>(win[5].end() - win[5].start()) }
+    };
+}
+
+} //namespace arm_gemm
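A small illustration (not part of the patch) of the round trip between the two representations: starts survive to_ndcoord()/to_window(), while to_ndrange() keeps only the sizes.

#include "src/core/NEON/kernels/assembly/arm_gemm_compute_iface.hpp"

void conversion_example()
{
    arm_compute::Window win;
    win.set(0, arm_compute::Window::Dimension(4, 12)); // start 4, end 12 -> size 8
    win.set(1, arm_compute::Window::Dimension(0, 3));  // start 0, end 3  -> size 3

    arm_gemm::ndcoord_t coord = arm_gemm::to_ndcoord(win); // keeps starts and sizes
    arm_gemm::ndrange_t range = arm_gemm::to_ndrange(win); // keeps sizes only

    arm_compute::Window back = arm_gemm::to_window(coord); // dimension 0 is [4, 12) again
    (void)range;
    (void)back;
}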
diff --git a/src/core/NEON/kernels/assembly/gemm_common.hpp b/src/core/NEON/kernels/assembly/gemm_common.hpp
new file mode 100644
index 0000000..a44b774
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/gemm_common.hpp
@@ -0,0 +1,201 @@
+/*
+ * Copyright (c) 2017-2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#include "arm_gemm_compute_iface.hpp"
+
+#include <cstddef>
+#include <cassert>
+
+#define UNUSED(x)   (void)(x)
+
+namespace arm_gemm {
+
+// Abstract class for the GEMM/GEMV functions.
+//
+// GEMM implementations may be "native" (never require any input
+// permutation), "pretransposed" (require permutation up-front) or require
+// working space (permute as they go along).  This interface should support
+// all of them.
+
+// The real GemmCommon class is templated based on the operand and return
+// type.  This is an interface class which is independent of those types.
+class IGemmCommon {
+public:
+    /* Pass in the pointers to the arrays to be operated on and their
+     * strides.  This "generic" version uses void *s, the preferred version
+     * is the one provided by templated GemmCommon (below) which takes
+     * appropriately typed pointers.  If B is pretransposed (see below) then
+     * the settings for B here are ignored.
+     */
+    virtual void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+                                    const void *B, const int ldb, /* batches share B */     const int B_multi_stride,
+                                          void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+                                    const void *bias, /* no row or batch stride needed */   const int bias_multi_stride) = 0;
+
+    /** @returns an ndrange containing ranges of the compute space which can be
+     * broken up and parallelised over
+     */
+    virtual ndrange_t get_window_size() const = 0;
+
+    /* The maximum thread count is specified when the GEMM is created.  Some
+     * implementations need to know how many threads will actually run in
+     * order to work properly.
+     *
+     * In some cases, after creating the GEMM the number of threads needs to
+     * be reduced (e.g. not enough work to split across threads).  This
+     * method allows the number of actual threads to be run to be set (must
+     * be equal or lower).
+     *
+     * This has an empty default implementation, as GEMMs which don't care
+     * about thread count can safely ignore this.
+     */
+    virtual void set_nthreads(int) { };
+
+    /* Whether this GEMM can be dynamically scheduled or not. */
+    virtual bool supports_dynamic_scheduling() const { return false; }
+
+    /** Main execute member function
+     * @param [in] work_range     specifies the range of work we want to be computed, total range defined by get_window_size()
+     * @param [in] thread_locator where we are inside of the thread space
+     * @param [in] threadid       a unique thread ID
+     */
+    virtual void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) = 0;
+
+    /*** Working space interface (optional) ***/
+    /* Total number of bytes of temporary working space needed.  If zero, it's not necessary to call set_working_space(). */
+    virtual size_t get_working_size() const { return 0; }
+    /* Provide working space buffer - the void * passed in must remain allocated for the duration of any execute calls. */
+    virtual void set_working_space(void *) { };
+
+    /*** "Pretransposed" interface (optional) ***/
+    /* Is this object set up for pretranspose?  If so, pretranspose_B_array() needs to be called before execute(). */
+    virtual bool B_is_pretransposed() const { return false; }
+    /* Does pretranspose still need to be done? */
+    virtual bool B_pretranspose_required() const { return false; }
+    /* Total number of bytes of space needed for pretransposed arrays. */
+    virtual size_t get_B_pretransposed_array_size() const { return 0; }
+    /* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
+    /* The "real" version of this depends on the templated operand type (see below).  */
+    virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0;
+    /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
+    virtual void set_pretransposed_B_data(void *) { }
+
+    /*** "Quantized bias" interface (optional) ***/
+    /* Set the bias vector for quantized GEMMs */
+    virtual void set_quantized_bias(const int32_t *bias, size_t bias_multi_stride)
+    {
+        UNUSED(bias);
+        UNUSED(bias_multi_stride);
+    }
+
+    // Destructor
+    virtual ~IGemmCommon() { }
+};
+
+/* "Real" GemmCommon class which is templated on the operand and return types.
+ *
+ * In addition to correctly typed versions of the functions that operate on
+ * operand and return data, this class provides a default implementation of
+ * 'set_arrays' to capture the provided arguments in protected class
+ * members, as essentially any implementation will need these.
+ */
+template<typename To, typename Tr>
+class GemmCommon : public IGemmCommon {
+protected:
+    const To *_Aptr=nullptr;
+    int _lda=0;
+    int _A_batch_stride=0;
+    int _A_multi_stride=0;
+    const To *_Bptr=nullptr;
+    int _ldb=0;
+    int _B_multi_stride=0;
+    Tr *_Cptr=nullptr;
+    int _ldc=0;
+    int _C_batch_stride=0;
+    int _C_multi_stride=0;
+    const Tr *_bias=nullptr;
+    int _bias_multi_stride=0;
+
+public:
+    /* Pass in the pointers to the arrays to be operated on and their
+     * strides (templated version with appropriate types). */
+    virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+                            const To *B, const int ldb, /* batches share B */     const int B_multi_stride,
+                                  Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+                            const Tr *bias, /* no row or batch stride needed */   const int bias_multi_stride) {
+        _Aptr = A;
+        _lda = lda;
+        _A_batch_stride = A_batch_stride;
+        _A_multi_stride = A_multi_stride;
+        _Bptr = B;
+        _ldb = ldb;
+        _B_multi_stride = B_multi_stride;
+        _Cptr = C;
+        _ldc = ldc;
+        _C_batch_stride = C_batch_stride;
+        _C_multi_stride = C_multi_stride;
+        _bias = bias;
+        _bias_multi_stride = bias_multi_stride;
+    }
+
+    /* Implementation of the void * overload which casts its arguments to the appropriate type. */
+    void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+                            const void *B, const int ldb, /* batches share B */     const int B_multi_stride,
+                                  void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+                            const void *bias, /* no row or batch stride needed */   const int bias_multi_stride) override {
+        set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride,
+                   static_cast<const To *>(B), ldb, B_multi_stride,
+                   static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
+                   static_cast<const Tr *>(bias), bias_multi_stride);
+    }
+
+    /*** "Pretransposed" interface ***/
+
+    /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
+    /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
+    virtual void pretranspose_B_array(void *, const To *, const int, const int) { };
+
+    /* Implementation of the void * overload which casts its arguments to the appropriate type. */
+    void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override {
+        pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
+    }
+};
+
+template<typename GemmKernel>
+inline
+unsigned int get_total_window_size(const GemmKernel& kernel)
+{
+    auto window=kernel.get_window_size();
+
+    unsigned int total = 1;
+    for(unsigned i = 0; i != arm_gemm::ndrange_max; ++i)
+    {
+        total *= window.get_size(i);
+    }
+
+    return total;
+}
+
+} // namespace arm_gemm
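A speculative single-threaded driver sketch (not part of the patch) showing how the interface above is typically used: bind the operand arrays, then execute the full window as one chunk. Batch and multi strides are zero here because a single batch and a single multi are assumed.

#include "src/core/NEON/kernels/assembly/gemm_common.hpp"

template <typename To, typename Tr>
void run_single_threaded(arm_gemm::GemmCommon<To, Tr> &gemm,
                         const To *A, int lda, const To *B, int ldb, Tr *C, int ldc)
{
    // Single batch / single multi, no bias.
    gemm.set_arrays(A, lda, 0, 0,
                    B, ldb, 0,
                    C, ldc, 0, 0,
                    static_cast<const Tr *>(nullptr), 0);

    // get_window_size() describes the whole compute space; convert it to a
    // coordinate starting at zero and hand it to execute() in one go.
    auto full_coord = arm_gemm::to_ndcoord(arm_gemm::to_window(gemm.get_window_size()));

    gemm.execute(full_coord, arm_gemm::ndcoord_t{}, /*threadid=*/0);
}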
diff --git a/src/core/NEON/kernels/assembly/ndrange.hpp b/src/core/NEON/kernels/assembly/ndrange.hpp
new file mode 100644
index 0000000..d082a3e
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/ndrange.hpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2019-2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#include <array>
+#include <algorithm>
+#include <initializer_list>
+
+#include <cassert>
+
+namespace arm_gemm {
+
+template<unsigned int D>
+class NDRange {
+private:
+    std::array<unsigned int, D> m_sizes {};
+    std::array<unsigned int, D> m_totalsizes {};
+
+    class NDRangeIterator {
+    private:
+        const NDRange &m_parent;
+        unsigned int m_pos = 0;
+        unsigned int m_end = 0;
+
+    public:
+        NDRangeIterator(const NDRange &p, unsigned int s, unsigned int e) : m_parent(p), m_pos(s), m_end(e) { }
+
+        bool done() const {
+            return (m_pos >= m_end);
+        }
+
+        unsigned int dim(unsigned int d) const {
+            unsigned int r = m_pos;
+
+            if (d < (D - 1)) {
+                r %= m_parent.m_totalsizes[d];
+            }
+
+            if (d > 0) {
+                r /= m_parent.m_totalsizes[d-1];
+            }
+
+            return r;
+        }
+
+        bool next_dim0() {
+            m_pos++;
+
+            return !done();
+        }
+
+        bool next_dim1() {
+            m_pos += m_parent.m_sizes[0] - dim(0);
+
+            return !done();
+        }
+
+        unsigned int dim0_max() const {
+            unsigned int offset = std::min(m_end - m_pos, m_parent.m_sizes[0] - dim(0));
+
+            return dim(0) + offset;
+        }
+    };
+
+public:
+    NDRange& operator=(const NDRange& rhs)=default;
+    NDRange(const NDRange& rhs)           =default;
+
+    template <typename... T>
+    NDRange(T... ts)
+    : m_sizes{ts...}
+    {
+        unsigned int t=1;
+
+        for (unsigned int i=0; i<D; i++) {
+            t *= m_sizes[i];
+
+            m_totalsizes[i] = t;
+        }
+    }
+
+    NDRange(const std::array<unsigned int, D>& n)
+    : m_sizes(n)
+    {
+        unsigned int t=1;
+
+        for (unsigned int i=0; i<D; i++) {
+            t *= m_sizes[i];
+
+            m_totalsizes[i] = t;
+        }
+    }
+
+    NDRangeIterator iterator(unsigned int start, unsigned int end) const {
+        return NDRangeIterator(*this, start, end);
+    }
+
+    unsigned int total_size() const {
+        return m_totalsizes[D - 1];
+    }
+
+    unsigned int get_size(unsigned int v) const {
+        return m_sizes[v];
+    }
+};
+
+/** NDCoordinate builds upon a range, but specifies a starting position
+ * in addition to the sizes it inherits from NDRange
+ */
+template<unsigned int N>
+class NDCoordinate : public NDRange<N> {
+    using int_t     =unsigned int;
+    using ndrange_t = NDRange<N>;
+
+    std::array<int_t, N> m_positions {};
+public:
+    NDCoordinate& operator=(const NDCoordinate& rhs)=default;
+    NDCoordinate(const NDCoordinate& rhs)           =default;
+    NDCoordinate(const std::initializer_list<std::pair<int_t, int_t>>& list)
+    {
+        std::array<int_t, N> sizes{};
+
+        std::size_t i = 0;
+        for(auto& p : list) {
+            m_positions[i]= p.first;
+            sizes[i++]    = p.second;
+        }
+
+        //update the parent's sizes
+        static_cast<ndrange_t&>(*this) = ndrange_t(sizes);
+    }
+
+    int_t get_position(int_t d) const {
+        assert(d < m_positions.size());
+        return m_positions[d];
+    }
+
+    void set_position(int_t d, int_t v) {
+        assert(d < m_positions.size());
+        assert(v < ndrange_t::get_size(d));
+
+        m_positions[d] = v;
+    }
+
+    int_t get_position_end(int_t d) const {
+        return get_position(d) + NDRange<N>::get_size(d);
+    }
+}; //class NDCoordinate
+
+/** @returns the number of dimensions in the NDRange which have a size other than 1,
+ * i.e. the dimensions which contain actual work that can be broken up
+ */
+template<unsigned int N>
+std::size_t ndrange_popcount(const NDRange<N>& ndr) {
+    std::size_t count = 0;
+
+    for(unsigned int d = 0; d != N; ++d) {
+        if(ndr.get_size(d) != 1)
+            ++count;
+    }
+    return count;
+}
+
+} // namespace arm_gemm
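To illustrate the iterator (not part of the patch): a 4x3 range flattened to 12 positions, of which the first half is walked as a scheduler distributing work across threads might do. Only the public members shown above are used.

#include "src/core/NEON/kernels/assembly/ndrange.hpp"

void iterate_example()
{
    arm_gemm::NDRange<6> r(4u, 3u, 1u, 1u, 1u, 1u); // total_size() == 12

    // First half of the flattened space: positions [0, 6).
    for(auto it = r.iterator(0, 6); !it.done(); it.next_dim0())
    {
        unsigned int x = it.dim(0); // coordinate along dimension 0 (0..3)
        unsigned int y = it.dim(1); // coordinate along dimension 1 (0..2)
        (void)x;
        (void)y;
    }
}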