COMPMID-2063: New Winograd implementation

Refactor the Winograd code, reducing the size of the binaries
by about 8x.

Change-Id: If8845bda324573e1a5cf436f354ac8603e88a92e
Signed-off-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/959
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Anthony Barbier <Anthony.barbier@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/arm_compute/core/NEON/kernels/convolution/common/padding.hpp b/arm_compute/core/NEON/kernels/convolution/common/padding.hpp
index 33f77d7..97b21e0 100644
--- a/arm_compute/core/NEON/kernels/convolution/common/padding.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/common/padding.hpp
@@ -71,4 +71,21 @@
     );
 };
 
+template <typename T>
+void crop_and_copy_tile(
+  unsigned int tile_rows,
+  unsigned int tile_cols,
+  unsigned int n_channels,
+  const T *inptr,
+  unsigned int in_row_stride,
+  unsigned int in_col_stride,
+  T *outptr,
+  unsigned int out_row_stride,
+  unsigned int out_col_stride,
+  unsigned int crop_top,
+  unsigned int crop_left,
+  unsigned int crop_bottom,
+  unsigned int crop_right
+);
+
 }
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp
deleted file mode 100644
index 663b3c4..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-
-namespace winograd
-{
-
-template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
-class BatchedBlockedGemm
-{
-  public:
-    /** Create a new batched blocked GEMM operator. */
-    BatchedBlockedGemm(
-      const unsigned int n_gemms,
-      const int M, const int K, const int N,
-      const int a_matrix_stride,
-      const int a_row_stride,
-      const int b_matrix_stride,
-      const int b_row_stride,
-      const int c_matrix_stride,
-      const int c_row_stride,
-      const TIn* const a_ptr,
-      const TIn* const b_ptr,
-      TOut* const c_ptr
-    );
-
-    BatchedBlockedGemm(const BatchedBlockedGemm&) = delete;
-    BatchedBlockedGemm operator=(const BatchedBlockedGemm&) = delete;
-
-    /** Get a window of work performed by the operator. */
-    unsigned int get_window() const;
-
-    /** Perform a portion of the work of the operator. */
-    void run(const unsigned int start, const unsigned int stop);
-
-  private:
-    const unsigned int n_gemms;
-    const int M, N, K;
-    const int a_matrix_stride, a_row_stride;
-    const int b_matrix_stride, b_row_stride;
-    const int c_matrix_stride, c_row_stride;
-    const TIn* const a_ptr;
-    const TIn* const b_ptr;
-    TOut* const c_ptr;
-};
-
-}  // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp
deleted file mode 100644
index 6e06db3..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
-
-template <typename TIn, typename TOut>
-inline void Gemm(const TIn* const a, const TIn* const b, TOut *c,
-          const int M, const int K, const int N,
-          const int a_row_stride,
-          const int b_row_stride,
-          const int c_row_stride,
-          const bool a_transposed=false,
-          const bool b_transposed=false) {
-  // Array access methods
-  const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn {
-    return a[(!a_transposed) ? i*a_row_stride + j : i + j*M];
-  };
-
-  const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn {
-    return b[(!b_transposed) ? i*b_row_stride + j : i + j*N];
-  };
-
-  const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
-    return c[i*c_row_stride + j];
-  };
-
-  // Perform the matrix multiplication
-  for (int i = 0; i < M; i++) {
-    for (int j = 0; j < N; j++) {
-      for (int k = 0; k < K; k++) {
-        C(i, j) += A(i, k) * B(k, j);
-      }
-    }
-  }
-}
-
-template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
-inline void BlockedGemm(
-  const TIn* const a, const TIn* const b, TOut *c,
-  const int M, const int K, const int N,
-  const int a_row_stride,
-  const int b_row_stride,
-  const int c_row_stride
-) {
-  // Array access methods
-  const auto A = [a, a_row_stride] (const int i, const int j) -> TIn {
-    return a[i*a_row_stride + j];
-  };
-
-  const auto B = [b, b_row_stride] (const int i, const int j) -> TIn {
-    return b[i*b_row_stride + j];
-  };
-
-  const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
-    return c[i*c_row_stride + j];
-  };
-
-  const int M_BLOCKS = iceildiv(M, M_BLOCK);
-  const int N_BLOCKS = iceildiv(N, N_BLOCK);
-
-  // For each block of output rows
-  for (int mblock = 0; mblock < M_BLOCKS; mblock++) {
-    // For each block of output columns
-    for (int nblock = 0; nblock < N_BLOCKS; nblock++) {
-      // Create an appropriately sized block of accumulators
-      TOut accum[M_BLOCK][N_BLOCK];
-      for (int i = 0; i < M_BLOCK; i++) {
-        for (int j = 0; j < N_BLOCK; j++) {
-          accum[i][j] = static_cast<TOut>(0);
-        }
-      }
-
-      // Perform this portion of the matrix multiply
-      for (int k = 0; k < K; k++) {
-        // Load elements of A
-        TIn elems_a[M_BLOCK];
-        for (int i = 0; i < M_BLOCK; i++) {
-          elems_a[i] = A(mblock*M_BLOCK + i, k);
-        }
-
-        // Load elements of B
-        TIn elems_b[N_BLOCK];
-        for (int j = 0; j < N_BLOCK; j++) {
-          elems_b[j] = B(k, nblock*N_BLOCK + j);
-        }
-
-        // Perform the partial matrix multiply
-        for (int i = 0; i < M_BLOCK; i++) {
-          for (int j = 0; j < N_BLOCK; j++) {
-            accum[i][j] += elems_a[i] * elems_b[j];
-          }
-        }
-      }
-
-      // Store the partial product
-      for (int i = 0; i < M_BLOCK; i++) {
-        for (int j = 0; j < N_BLOCK; j++) {
-          C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j];
-        }
-      }
-    }
-  }
-}
-
-#include "gemm/a64_sgemm.hpp"
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp
deleted file mode 100644
index 8073cb1..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp
+++ /dev/null
@@ -1,355 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-#include <cassert>
-#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
-
-#ifdef __aarch64__
-
-template <>
-inline void BlockedGemm<8, 12, float, float>(
-  const float* const a, const float* const b, float *c,
-  const int M, const int K, const int N,
-  const int a_row_stride,
-  const int b_row_stride,
-  const int c_row_stride
-) {
-  const int M_BLOCK = 8;
-  const int N_BLOCK = 12;
-
-  const int m_blocks = iceildiv(M, M_BLOCK);
-  const int n_blocks = iceildiv(N, N_BLOCK);
-
-  // For each block of output rows
-  for (int mblock = 0; mblock < m_blocks; mblock++) {
-    // For each block of output columns
-    for (int nblock = 0; nblock < n_blocks; nblock++) {
-      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
-      const float *bptr = b + nblock*N_BLOCK;
-      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
-      int k = K;
-
-      asm volatile (
-          // Create an 8x12 block of accumulators
-          " A_1 .req v27\n"
-          "sA_1 .req s27\n"
-          " A_2 .req v28\n"
-          "sA_2 .req s28\n"
-          " A_3 .req v29\n"
-          "sA_3 .req s29\n"
-          " A_4 .req v30\n"
-          "sA_4 .req s30\n"
-
-          " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n"
-          "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n"
-
-          " C_11 .req  v0\n" " C_12 .req  v1\n" " C_13 .req  v2\n"
-          " C_21 .req  v3\n" " C_22 .req  v4\n" " C_23 .req  v5\n"
-          " C_31 .req  v6\n" " C_32 .req  v7\n" " C_33 .req  v8\n"
-          " C_41 .req  v9\n" " C_42 .req v10\n" " C_43 .req v11\n"
-          " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n"
-          " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n"
-          " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n"
-          " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n"
-
-          "qC_11 .req  q0\n" "qC_12 .req  q1\n" "qC_13 .req  q2\n"
-          "qC_21 .req  q3\n" "qC_22 .req  q4\n" "qC_23 .req  q5\n"
-          "qC_31 .req  q6\n" "qC_32 .req  q7\n" "qC_33 .req  q8\n"
-          "qC_41 .req  q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n"
-          "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n"
-          "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n"
-          "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n"
-          "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n"
-
-          "aptr1 .req x17\n"
-          "aptr2 .req x18\n"
-          "aptr3 .req x19\n"
-          "aptr4 .req x20\n"
-          "aptr5 .req x21\n"
-          "aptr6 .req x22\n"
-          "aptr7 .req x23\n"
-
-          // Initialise accumulators with 0
-          // Initialise pointers
-          "movi C_11.4s, #0\n"
-          "add aptr1, %x[aptr], %x[a_row_stride]\n"
-          "movi C_12.4s, #0\n"
-          "add aptr2,    aptr1, %x[a_row_stride]\n"
-          "movi C_13.4s, #0\n"
-          "add aptr3,    aptr2, %x[a_row_stride]\n"
-          "movi C_21.4s, #0\n"
-          "add aptr4,    aptr3, %x[a_row_stride]\n"
-          "movi C_22.4s, #0\n"
-          "add aptr5,    aptr4, %x[a_row_stride]\n"
-          "movi C_23.4s, #0\n"
-          "add aptr6,    aptr5, %x[a_row_stride]\n"
-          "movi C_31.4s, #0\n"
-          "add aptr7,    aptr6, %x[a_row_stride]\n"
-          "movi C_32.4s, #0\n"
-          "ldr qB_1, [%x[bptr]]\n"
-          "movi C_33.4s, #0\n"
-          "ldr qB_2, [%x[bptr], #0x10]\n"
-          "movi C_41.4s, #0\n"
-          "prfm pldl1keep, [%x[bptr], #0x00]\n"
-          "movi C_42.4s, #0\n"
-          "prfm pldl1keep, [%x[bptr], #0x10]\n"
-          "movi C_43.4s, #0\n"
-          "prfm pldl1keep, [%x[bptr], #0x20]\n"
-          "movi C_51.4s, #0\n"
-          "prfm pldl1keep, [%x[aptr], #0x00]\n"
-          "movi C_52.4s, #0\n"
-          "prfm pldl1keep, [   aptr1, #0x00]\n"
-          "movi C_53.4s, #0\n"
-          "prfm pldl1keep, [   aptr2, #0x00]\n"
-          "movi C_61.4s, #0\n"
-          "prfm pldl1keep, [   aptr3, #0x00]\n"
-          "movi C_62.4s, #0\n"
-          "prfm pldl1keep, [   aptr4, #0x00]\n"
-          "movi C_63.4s, #0\n"
-          "prfm pldl1keep, [   aptr5, #0x00]\n"
-          "movi C_71.4s, #0\n"
-          "prfm pldl1keep, [   aptr6, #0x00]\n"
-          "movi C_72.4s, #0\n"
-          "prfm pldl1keep, [   aptr7, #0x00]\n"
-          "movi C_73.4s, #0\n"
-          "ldr sA_1, [%x[aptr]], #0x4\n"
-          "movi C_81.4s, #0\n"
-          "ldr sA_2, [   aptr1], #0x4\n"
-          "movi C_82.4s, #0\n"
-          "ldr sA_3, [   aptr2], #0x4\n"
-          "movi C_83.4s, #0\n"
-          "subs %x[k], %x[k], #1\n"
-          "beq 2f\n"
-
-          "1:"
-            "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
-            "ldr qB_3, [%x[bptr], #0x20]\n"
-            "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
-            "ldr sA_4, [   aptr3], #0x4\n"
-            "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
-            "ldr sA_1, [   aptr4], #0x04\n"
-
-            "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
-            "add %x[bptr], %x[bptr], %x[b_row_stride]\n"
-            "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
-            "prfm pldl1keep, [   aptr3, #0x10]\n"
-            "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
-            "ldr sA_2, [   aptr5], #0x04\n"
-
-            "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
-            "prfm pldl1keep, [%x[bptr], #0x00]\n"
-            "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
-            "prfm pldl1keep, [%x[bptr], #0x10]\n"
-            "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
-            "ldr sA_3, [   aptr6], #0x04\n"
-
-            "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
-            "prfm pldl1keep, [%x[bptr], #0x20]\n"
-            "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
-            "prfm pldl1keep, [   aptr4, #0x10]\n"
-            "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
-            "ldr sA_4, [   aptr7], #0x04\n"
-
-            "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
-            "prfm pldl1keep, [   aptr5, #0x10]\n"
-            "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
-            "prfm pldl1keep, [   aptr6, #0x10]\n"
-            "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
-            "ldr sA_1, [%x[aptr]], #0x04\n"
-
-            "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
-            "prfm pldl1keep, [   aptr7, #0x10]\n"
-            "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
-            "subs %x[k], %x[k], #1\n"
-            "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
-            "ldr sA_2, [   aptr1], #0x04\n"
-
-            "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
-            "prfm pldl1keep, [%x[aptr], #0x10]\n"
-            "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
-            "prfm pldl1keep, [   aptr1, #0x10]\n"
-            "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
-            "ldr sA_3, [   aptr2], #0x04\n"
-
-            "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
-            "prfm pldl1keep, [   aptr2, #0x10]\n"
-            "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
-            "ldp qB_1, qB_2, [%x[bptr]]\n"
-            "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
-            "bne 1b\n"
-
-          "2:"
-            "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
-            "ldr qB_3, [%x[bptr], #0x20]\n"
-            "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
-            "stp qC_11, qC_12, [%x[cptr]]\n"
-            "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
-            "str qC_13, [%x[cptr], #0x20]\n"
-            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
-            "ldr sA_1, [   aptr4], #0x04\n"
-
-            "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
-            "ldr sA_4, [   aptr3], #0x4\n"
-            "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
-            "stp qC_21, qC_22, [%x[cptr]]\n"
-            "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
-            "str qC_23, [%x[cptr], #0x20]\n"
-            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
-            "ldr sA_2, [   aptr5], #0x04\n"
-
-            "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
-            "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
-            "stp qC_31, qC_32, [%x[cptr]]\n"
-            "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
-            "str qC_33, [%x[cptr], #0x20]\n"
-            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
-            "ldr sA_3, [   aptr6], #0x04\n"
-
-            "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
-            "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
-            "stp qC_41, qC_42, [%x[cptr]]\n"
-            "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
-            "str qC_43, [%x[cptr], #0x20]\n"
-            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
-            "ldr sA_4, [   aptr7], #0x04\n"
-
-            "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
-            "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
-            "stp qC_51, qC_52, [%x[cptr]]\n"
-            "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
-            "str qC_53, [%x[cptr], #0x20]\n"
-            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
-
-            "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
-            "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
-            "stp qC_61, qC_62, [%x[cptr]]\n"
-            "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
-            "str qC_63, [%x[cptr], #0x20]\n"
-            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
-
-            "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
-            "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
-            "stp qC_71, qC_72, [%x[cptr]]\n"
-            "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
-            "str qC_73, [%x[cptr], #0x20]\n"
-            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
-
-            "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
-            "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
-            "stp qC_81, qC_82, [%x[cptr]]\n"
-            "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
-            "str qC_83, [%x[cptr], #0x20]\n"
-            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
-
-          // Clear aliases
-          ".unreq aptr1\n"
-          ".unreq aptr2\n"
-          ".unreq aptr3\n"
-          ".unreq aptr4\n"
-          ".unreq aptr5\n"
-          ".unreq aptr6\n"
-          ".unreq aptr7\n"
-
-          ".unreq  A_1\n" ".unreq  A_2\n" ".unreq  A_3\n" ".unreq  A_4\n"
-          ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n"
-
-          ".unreq  B_1\n" ".unreq  B_2\n" ".unreq  B_3\n"
-          ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n"
-
-          ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n"
-          ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n"
-          ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n"
-          ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n"
-          ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n"
-          ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n"
-          ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n"
-          ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n"
-
-          ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n"
-          ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n"
-          ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n"
-          ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n"
-          ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n"
-          ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n"
-          ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n"
-          ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n"
-          : [aptr] "+r" (aptr),
-            [bptr] "+r" (bptr),
-            [cptr] "+r" (cptr),
-            [k] "+r" (k)
-          : [a_row_stride] "r" (a_row_stride * sizeof(float)),
-            [b_row_stride] "r" (b_row_stride * sizeof(float)),
-            [c_row_stride] "r" (c_row_stride * sizeof(float))
-          : "cc", "memory",
-            "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
-            "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
-            "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
-            "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23"
-      );
-    }
-  }
-}
-
-/*****************************************************************************/
-/* 4x16 blocked GEMM with specialised tails
- */
-#include "a64_sgemm_4x16.hpp"
-
-template <>
-inline void BlockedGemm<4, 16, float, float>(
-  const float* const a, const float* const b, float *c,
-  const int M, const int K, const int N,
-  const int a_row_stride,
-  const int b_row_stride,
-  const int c_row_stride
-) {
-  // Despatch based on tail of K
-  switch (K % 4) {
-    case 3:
-      sgemm_4x16_impl<3>(
-        a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
-      );
-      break;
-    case 2:
-      sgemm_4x16_impl<2>(
-        a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
-      );
-      break;
-    case 1:
-      sgemm_4x16_impl<1>(
-        a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
-      );
-      break;
-    case 0:
-      sgemm_4x16_impl<0>(
-        a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
-      );
-      break;
-    default:
-      assert(false);
-  }
-}
-
-#endif  // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp
deleted file mode 100644
index 5cd37de..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp
+++ /dev/null
@@ -1,1446 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-template <const unsigned int tail>
-inline void sgemm_4x16_impl(
-  const float* const a, const float* const b, float *c,
-  const int M, const int K, const int N,
-  const int a_row_stride,
-  const int b_row_stride,
-  const int c_row_stride
-);
-
-template <>
-inline void sgemm_4x16_impl<0>(
-  const float* const a, const float* const b, float *c,
-  const int M, const int K, const int N,
-  const int a_row_stride,
-  const int b_row_stride,
-  const int c_row_stride
-) {
-  const int TAIL_SIZE = 0;
-  const int M_BLOCK = 4;
-  const int N_BLOCK = 16;
-
-  const int m_blocks = iceildiv(M, M_BLOCK);
-  const int n_blocks = iceildiv(N, N_BLOCK);
-
-  // For each block of output rows
-  for (int mblock = 0; mblock < m_blocks; mblock++) {
-    // For each block of output columns
-    for (int nblock = 0; nblock < n_blocks; nblock++) {
-      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
-      const float *bptr = b + nblock*N_BLOCK;
-      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
-      int k = (K - TAIL_SIZE) / 4;
-
-      asm volatile(
-        "aptr2 .req X20\n"
-        "aptr3 .req X21\n"
-        "aptr4 .req X22\n"
-        "vC11 .req  v0\n" "vC12 .req  v1\n" "vC13 .req  v2\n" "vC14 .req  v3\n"
-        "qC11 .req  q0\n" "qC12 .req  q1\n" "qC13 .req  q2\n" "qC14 .req  q3\n"
-        "vC21 .req  v4\n" "vC22 .req  v5\n" "vC23 .req  v6\n" "vC24 .req  v7\n"
-        "qC21 .req  q4\n" "qC22 .req  q5\n" "qC23 .req  q6\n" "qC24 .req  q7\n"
-        "vC31 .req  v8\n" "vC32 .req  v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
-        "qC31 .req  q8\n" "qC32 .req  q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
-        "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
-        "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
-        "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
-        "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
-        "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
-        "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
-        "vB1 .req v20\n" "qB1 .req q20\n"
-        "vB2 .req v21\n" "qB2 .req q21\n"
-        "vB3 .req v22\n" "qB3 .req q22\n"
-        "vB4 .req v23\n" "qB4 .req q23\n"
-
-        // Clear accumulators, initialise pointers
-        "movi vC11.4s, #0\n"
-        "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
-        "movi vC12.4s, #0\n"
-        "add aptr3,    aptr2, %x[a_row_stride_bytes]\n"
-        "movi vC13.4s, #0\n"
-        "add aptr4,    aptr3, %x[a_row_stride_bytes]\n"
-        "movi vC14.4s, #0\n"
-        "ldr qA1, [%x[aptr]], #0x10\n"
-        "movi vC21.4s, #0\n"
-        "ldr qA2, [   aptr2], #0x10\n"
-        "movi vC22.4s, #0\n"
-        "ldr qB1, [%x[bptr], #0x00]\n"
-        "movi vC23.4s, #0\n"
-        "ldr qB2, [%x[bptr], #0x10]\n"
-        "movi vC24.4s, #0\n"
-        "ldr qB3, [%x[bptr], #0x20]\n"
-        "movi vC31.4s, #0\n"
-        "movi vC32.4s, #0\n"
-        "movi vC33.4s, #0\n"
-        "movi vC34.4s, #0\n"
-        "movi vC41.4s, #0\n"
-        "movi vC42.4s, #0\n"
-        "movi vC43.4s, #0\n"
-        "movi vC44.4s, #0\n"
-        "subs %x[k], %x[k], #1\n"
-        "beq 2f\n"
-
-        "1:"  // Loop proper
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qA3, [   aptr3], #0x10\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr qA4, [   aptr4], #0x10\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
-          "subs %x[k], %x[k], #1\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr qA1, [%x[aptr]], #0x10\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr qA2, [   aptr2], #0x10\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
-          "bne 1b\n"
-
-        "2:"  // Tail
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qA3, [   aptr3], #0x10\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr qA4, [   aptr4], #0x10\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
-          "stp qC11, qC12, [%x[cptr], #0x00]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "stp qC13, qC14, [%x[cptr], #0x20]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
-          "stp qC21, qC22, [%x[cptr], #0x00]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "stp qC23, qC24, [%x[cptr], #0x20]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
-          "stp qC31, qC32, [%x[cptr], #0x00]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
-          "stp qC33, qC34, [%x[cptr], #0x20]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
-          "stp qC41, qC42, [%x[cptr], #0x00]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
-          "stp qC43, qC44, [%x[cptr], #0x20]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-
-        ".unreq vB4\n" ".unreq qB4\n"
-        ".unreq vB3\n" ".unreq qB3\n"
-        ".unreq vB2\n" ".unreq qB2\n"
-        ".unreq vB1\n" ".unreq qB1\n"
-        ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
-        ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
-        ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
-        ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
-        ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
-        ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
-        ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
-        ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
-        ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
-        ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
-        ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
-        ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
-        ".unreq aptr2\n"
-        ".unreq aptr3\n"
-        ".unreq aptr4\n"
-
-        : [aptr] "+r" (aptr),
-          [bptr] "+r" (bptr),
-          [cptr] "+r" (cptr),
-          [k] "+r" (k)
-        : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
-          [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
-          [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
-        : "cc", "memory", "x20", "x21", "x22",
-          "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
-          "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
-          "v21", "v22", "v23"
-      );
-    }
-  }
-}
-
-template <>
-inline void sgemm_4x16_impl<1>(
-  const float* const a, const float* const b, float *c,
-  const int M, const int K, const int N,
-  const int a_row_stride,
-  const int b_row_stride,
-  const int c_row_stride
-) {
-  const int TAIL_SIZE = 1;
-  const int M_BLOCK = 4;
-  const int N_BLOCK = 16;
-
-  const int m_blocks = iceildiv(M, M_BLOCK);
-  const int n_blocks = iceildiv(N, N_BLOCK);
-
-  // For each block of output rows
-  for (int mblock = 0; mblock < m_blocks; mblock++) {
-    // For each block of output columns
-    for (int nblock = 0; nblock < n_blocks; nblock++) {
-      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
-      const float *bptr = b + nblock*N_BLOCK;
-      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
-      int k = (K - TAIL_SIZE) / 4;
-
-      asm volatile(
-        "aptr2 .req X20\n"
-        "aptr3 .req X21\n"
-        "aptr4 .req X22\n"
-        "vC11 .req  v0\n" "vC12 .req  v1\n" "vC13 .req  v2\n" "vC14 .req  v3\n"
-        "qC11 .req  q0\n" "qC12 .req  q1\n" "qC13 .req  q2\n" "qC14 .req  q3\n"
-        "vC21 .req  v4\n" "vC22 .req  v5\n" "vC23 .req  v6\n" "vC24 .req  v7\n"
-        "qC21 .req  q4\n" "qC22 .req  q5\n" "qC23 .req  q6\n" "qC24 .req  q7\n"
-        "vC31 .req  v8\n" "vC32 .req  v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
-        "qC31 .req  q8\n" "qC32 .req  q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
-        "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
-        "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
-        "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
-        "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
-        "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
-        "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
-        "vB1 .req v20\n" "qB1 .req q20\n"
-        "vB2 .req v21\n" "qB2 .req q21\n"
-        "vB3 .req v22\n" "qB3 .req q22\n"
-        "vB4 .req v23\n" "qB4 .req q23\n"
-
-        // Clear accumulators, initialise pointers
-        "movi vC11.4s, #0\n"
-        "ldr qB1, [%x[bptr], #0x00]\n"
-        "movi vC12.4s, #0\n"
-        "ldr qB2, [%x[bptr], #0x10]\n"
-        "movi vC13.4s, #0\n"
-        "ldr qB3, [%x[bptr], #0x20]\n"
-        "movi vC14.4s, #0\n"
-        "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
-        "movi vC21.4s, #0\n"
-        "add aptr3,    aptr2, %x[a_row_stride_bytes]\n"
-        "movi vC22.4s, #0\n"
-        "add aptr4,    aptr3, %x[a_row_stride_bytes]\n"
-        "movi vC23.4s, #0\n"
-        "cbnz %x[k], 3f\n"
-
-        // Prepare for tail in K
-        "movi vC24.4s, #0\n"
-        "ldr sA1, [%x[aptr]], #0x04\n"
-        "movi vC31.4s, #0\n"
-        "ldr sA2, [   aptr2], #0x04\n"
-        "movi vC32.4s, #0\n"
-        "movi vC33.4s, #0\n"
-        "movi vC34.4s, #0\n"
-        "movi vC41.4s, #0\n"
-        "movi vC42.4s, #0\n"
-        "movi vC43.4s, #0\n"
-        "movi vC44.4s, #0\n"
-        "b 2f\n"  // Jump to tail
-
-        "3:"  // Prepare for loop over K
-          "movi vC24.4s, #0\n"
-          "ldr qA1, [%x[aptr]], #0x10\n"
-          "movi vC31.4s, #0\n"
-          "ldr qA2, [   aptr2], #0x10\n"
-          "movi vC32.4s, #0\n"
-          "movi vC33.4s, #0\n"
-          "movi vC34.4s, #0\n"
-          "movi vC41.4s, #0\n"
-          "movi vC42.4s, #0\n"
-          "movi vC43.4s, #0\n"
-          "movi vC44.4s, #0\n"
-          "subs %x[k], %x[k], #1\n"
-          "beq 4f\n"
-
-        "1:"  // Loop proper
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qA3, [   aptr3], #0x10\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr qA4, [   aptr4], #0x10\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
-          "subs %x[k], %x[k], #1\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr qA1, [%x[aptr]], #0x10\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr qA2, [   aptr2], #0x10\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
-          "bne 1b\n"
-
-        "4:"  // Tail iteration
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qA3, [   aptr3], #0x10\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr qA4, [   aptr4], #0x10\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr sA1, [%x[aptr]], #0x04\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr sA2, [   aptr2], #0x04\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
-
-        "2:"  // Common tail
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "stp qC11, qC12, [%x[cptr], #0x00]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "ldr sA3, [   aptr3], #0x04\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "stp qC13, qC14, [%x[cptr], #0x20]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "stp qC21, qC22, [%x[cptr], #0x00]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "ldr sA4, [   aptr4], #0x04\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "stp qC23, qC24, [%x[cptr], #0x20]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "stp qC31, qC32, [%x[cptr], #0x00]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "stp qC33, qC34, [%x[cptr], #0x20]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "stp qC41, qC42, [%x[cptr], #0x00]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-          "stp qC43, qC44, [%x[cptr], #0x20]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-
-        ".unreq vB4\n" ".unreq qB4\n"
-        ".unreq vB3\n" ".unreq qB3\n"
-        ".unreq vB2\n" ".unreq qB2\n"
-        ".unreq vB1\n" ".unreq qB1\n"
-        ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
-        ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
-        ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
-        ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
-        ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
-        ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
-        ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
-        ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
-        ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
-        ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
-        ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
-        ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
-        ".unreq aptr2\n"
-        ".unreq aptr3\n"
-        ".unreq aptr4\n"
-
-        : [aptr] "+r" (aptr),
-          [bptr] "+r" (bptr),
-          [cptr] "+r" (cptr),
-          [k] "+r" (k)
-        : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
-          [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
-          [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
-        : "cc", "memory", "x20", "x21", "x22",
-          "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
-          "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
-          "v21", "v22", "v23"
-      );
-    }
-  }
-}
-
-template <>
-inline void sgemm_4x16_impl<2>(
-  const float* const a, const float* const b, float *c,
-  const int M, const int K, const int N,
-  const int a_row_stride,
-  const int b_row_stride,
-  const int c_row_stride
-) {
-  const int TAIL_SIZE = 2;
-  const int M_BLOCK = 4;
-  const int N_BLOCK = 16;
-
-  const int m_blocks = iceildiv(M, M_BLOCK);
-  const int n_blocks = iceildiv(N, N_BLOCK);
-
-  // For each block of output rows
-  for (int mblock = 0; mblock < m_blocks; mblock++) {
-    // For each block of output columns
-    for (int nblock = 0; nblock < n_blocks; nblock++) {
-      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
-      const float *bptr = b + nblock*N_BLOCK;
-      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
-      int k = (K - TAIL_SIZE) / 4;
-
-      asm volatile(
-        "aptr2 .req X20\n"
-        "aptr3 .req X21\n"
-        "aptr4 .req X22\n"
-        "vC11 .req  v0\n" "vC12 .req  v1\n" "vC13 .req  v2\n" "vC14 .req  v3\n"
-        "qC11 .req  q0\n" "qC12 .req  q1\n" "qC13 .req  q2\n" "qC14 .req  q3\n"
-        "vC21 .req  v4\n" "vC22 .req  v5\n" "vC23 .req  v6\n" "vC24 .req  v7\n"
-        "qC21 .req  q4\n" "qC22 .req  q5\n" "qC23 .req  q6\n" "qC24 .req  q7\n"
-        "vC31 .req  v8\n" "vC32 .req  v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
-        "qC31 .req  q8\n" "qC32 .req  q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
-        "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
-        "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
-        "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
-        "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
-        "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
-        "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
-        "vB1 .req v20\n" "qB1 .req q20\n"
-        "vB2 .req v21\n" "qB2 .req q21\n"
-        "vB3 .req v22\n" "qB3 .req q22\n"
-        "vB4 .req v23\n" "qB4 .req q23\n"
-
-        // Clear accumulators, initialise pointers
-        "movi vC11.4s, #0\n"
-        "ldr qB1, [%x[bptr], #0x00]\n"
-        "movi vC12.4s, #0\n"
-        "ldr qB2, [%x[bptr], #0x10]\n"
-        "movi vC13.4s, #0\n"
-        "ldr qB3, [%x[bptr], #0x20]\n"
-        "movi vC14.4s, #0\n"
-        "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
-        "movi vC21.4s, #0\n"
-        "add aptr3,    aptr2, %x[a_row_stride_bytes]\n"
-        "movi vC22.4s, #0\n"
-        "add aptr4,    aptr3, %x[a_row_stride_bytes]\n"
-        "movi vC23.4s, #0\n"
-        "cbnz %x[k], 3f\n"
-
-        // Prepare for tail in K
-        "movi vC24.4s, #0\n"
-        "ldr dA1, [%x[aptr]], #0x08\n"
-        "movi vC31.4s, #0\n"
-        "ldr dA2, [   aptr2], #0x08\n"
-        "movi vC32.4s, #0\n"
-        "movi vC33.4s, #0\n"
-        "movi vC34.4s, #0\n"
-        "movi vC41.4s, #0\n"
-        "movi vC42.4s, #0\n"
-        "movi vC43.4s, #0\n"
-        "movi vC44.4s, #0\n"
-        "b 2f\n"  // Jump to tail
-
-        "3:"  // Prepare for loop over K
-          "movi vC24.4s, #0\n"
-          "ldr qA1, [%x[aptr]], #0x10\n"
-          "movi vC31.4s, #0\n"
-          "ldr qA2, [   aptr2], #0x10\n"
-          "movi vC32.4s, #0\n"
-          "movi vC33.4s, #0\n"
-          "movi vC34.4s, #0\n"
-          "movi vC41.4s, #0\n"
-          "movi vC42.4s, #0\n"
-          "movi vC43.4s, #0\n"
-          "movi vC44.4s, #0\n"
-          "subs %x[k], %x[k], #1\n"
-          "beq 4f\n"
-
-        "1:"  // Loop proper
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qA3, [   aptr3], #0x10\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr qA4, [   aptr4], #0x10\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
-          "subs %x[k], %x[k], #1\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr qA1, [%x[aptr]], #0x10\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr qA2, [   aptr2], #0x10\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
-          "bne 1b\n"
-
-        "4:"  // Tail iteration
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qA3, [   aptr3], #0x10\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr qA4, [   aptr4], #0x10\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr dA1, [%x[aptr]], #0x08\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr dA2, [   aptr2], #0x08\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
-
-        "2:"  // Common tail
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr dA3, [   aptr3], #0x08\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr dA4, [   aptr4], #0x08\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "stp qC11, qC12, [%x[cptr], #0x00]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "stp qC13, qC14, [%x[cptr], #0x20]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "stp qC21, qC22, [%x[cptr], #0x00]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "stp qC23, qC24, [%x[cptr], #0x20]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "stp qC31, qC32, [%x[cptr], #0x00]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "stp qC33, qC34, [%x[cptr], #0x20]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "stp qC41, qC42, [%x[cptr], #0x00]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-          "stp qC43, qC44, [%x[cptr], #0x20]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-
-        ".unreq vB4\n" ".unreq qB4\n"
-        ".unreq vB3\n" ".unreq qB3\n"
-        ".unreq vB2\n" ".unreq qB2\n"
-        ".unreq vB1\n" ".unreq qB1\n"
-        ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
-        ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
-        ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
-        ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
-        ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
-        ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
-        ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
-        ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
-        ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
-        ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
-        ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
-        ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
-        ".unreq aptr2\n"
-        ".unreq aptr3\n"
-        ".unreq aptr4\n"
-
-        : [aptr] "+r" (aptr),
-          [bptr] "+r" (bptr),
-          [cptr] "+r" (cptr),
-          [k] "+r" (k)
-        : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
-          [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
-          [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
-        : "cc", "memory", "x20", "x21", "x22",
-          "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
-          "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
-          "v21", "v22", "v23"
-      );
-    }
-  }
-}
-
-template <>
-inline void sgemm_4x16_impl<3>(
-  const float* const a, const float* const b, float *c,
-  const int M, const int K, const int N,
-  const int a_row_stride,
-  const int b_row_stride,
-  const int c_row_stride
-) {
-  const int TAIL_SIZE = 3;
-  const int M_BLOCK = 4;
-  const int N_BLOCK = 16;
-
-  const int m_blocks = iceildiv(M, M_BLOCK);
-  const int n_blocks = iceildiv(N, N_BLOCK);
-
-  // For each block of output rows
-  for (int mblock = 0; mblock < m_blocks; mblock++) {
-    // For each block of output columns
-    for (int nblock = 0; nblock < n_blocks; nblock++) {
-      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
-      const float *bptr = b + nblock*N_BLOCK;
-      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
-      int k = (K - TAIL_SIZE) / 4;
-
-      asm volatile(
-        "aptr2 .req X20\n"
-        "aptr3 .req X21\n"
-        "aptr4 .req X22\n"
-        "vC11 .req  v0\n" "vC12 .req  v1\n" "vC13 .req  v2\n" "vC14 .req  v3\n"
-        "qC11 .req  q0\n" "qC12 .req  q1\n" "qC13 .req  q2\n" "qC14 .req  q3\n"
-        "vC21 .req  v4\n" "vC22 .req  v5\n" "vC23 .req  v6\n" "vC24 .req  v7\n"
-        "qC21 .req  q4\n" "qC22 .req  q5\n" "qC23 .req  q6\n" "qC24 .req  q7\n"
-        "vC31 .req  v8\n" "vC32 .req  v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
-        "qC31 .req  q8\n" "qC32 .req  q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
-        "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
-        "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
-        "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
-        "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
-        "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
-        "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
-        "vB1 .req v20\n" "qB1 .req q20\n"
-        "vB2 .req v21\n" "qB2 .req q21\n"
-        "vB3 .req v22\n" "qB3 .req q22\n"
-        "vB4 .req v23\n" "qB4 .req q23\n"
-
-        // Clear accumulators, initialise pointers
-        "movi vC11.4s, #0\n"
-        "ldr qB1, [%x[bptr], #0x00]\n"
-        "movi vC12.4s, #0\n"
-        "ldr qB2, [%x[bptr], #0x10]\n"
-        "movi vC13.4s, #0\n"
-        "ldr qB3, [%x[bptr], #0x20]\n"
-        "movi vC14.4s, #0\n"
-        "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
-        "movi vC21.4s, #0\n"
-        "add aptr3,    aptr2, %x[a_row_stride_bytes]\n"
-        "movi vC22.4s, #0\n"
-        "add aptr4,    aptr3, %x[a_row_stride_bytes]\n"
-        "movi vC23.4s, #0\n"
-        "cbnz %x[k], 3f\n"
-
-        // Prepare for tail in K
-        "movi vC24.4s, #0\n"
-        "ldr dA1, [%x[aptr]], #0x08\n"
-        "movi vC31.4s, #0\n"
-        "ldr dA2, [   aptr2], #0x08\n"
-        "movi vC32.4s, #0\n"
-        "movi vC33.4s, #0\n"
-        "movi vC34.4s, #0\n"
-        "movi vC41.4s, #0\n"
-        "movi vC42.4s, #0\n"
-        "movi vC43.4s, #0\n"
-        "movi vC44.4s, #0\n"
-        "b 2f\n"  // Jump to tail
-
-        "3:"  // Prepare for loop over K
-          "movi vC24.4s, #0\n"
-          "ldr qA1, [%x[aptr]], #0x10\n"
-          "movi vC31.4s, #0\n"
-          "ldr qA2, [   aptr2], #0x10\n"
-          "movi vC32.4s, #0\n"
-          "movi vC33.4s, #0\n"
-          "movi vC34.4s, #0\n"
-          "movi vC41.4s, #0\n"
-          "movi vC42.4s, #0\n"
-          "movi vC43.4s, #0\n"
-          "movi vC44.4s, #0\n"
-          "subs %x[k], %x[k], #1\n"
-          "beq 4f\n"
-
-        "1:"  // Loop proper
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qA3, [   aptr3], #0x10\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr qA4, [   aptr4], #0x10\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
-          "subs %x[k], %x[k], #1\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr qA1, [%x[aptr]], #0x10\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr qA2, [   aptr2], #0x10\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
-          "bne 1b\n"
-
-        "4:"  // Tail iteration
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qA3, [   aptr3], #0x10\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr qA4, [   aptr4], #0x10\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr dA1, [%x[aptr]], #0x08\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr dA2, [   aptr2], #0x08\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
-
-        "2:"  // Common tail
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr dA3, [   aptr3], #0x08\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr dA4, [   aptr4], #0x08\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
-          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
-          "ldr qB1, [%x[bptr], #0x00]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
-          "ldr qB2, [%x[bptr], #0x10]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "ldr sA1, [%x[aptr]], #0x04\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "ldr sA2, [   aptr2], #0x04\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
-          "ldr qB3, [%x[bptr], #0x20]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
-
-          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr qB4, [%x[bptr], #0x30]\n"
-          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
-          "stp qC11, qC12, [%x[cptr], #0x00]\n"
-          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "ldr sA3, [   aptr3], #0x04\n"
-          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
-          "stp qC13, qC14, [%x[cptr], #0x20]\n"
-          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
-          "stp qC21, qC22, [%x[cptr], #0x00]\n"
-          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "ldr sA4, [   aptr4], #0x04\n"
-          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
-          "stp qC23, qC24, [%x[cptr], #0x20]\n"
-          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
-          "stp qC31, qC32, [%x[cptr], #0x00]\n"
-          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
-          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
-          "stp qC33, qC34, [%x[cptr], #0x20]\n"
-          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
-          "stp qC41, qC42, [%x[cptr], #0x00]\n"
-          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
-          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
-          "stp qC43, qC44, [%x[cptr], #0x20]\n"
-          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
-
-        ".unreq vB4\n" ".unreq qB4\n"
-        ".unreq vB3\n" ".unreq qB3\n"
-        ".unreq vB2\n" ".unreq qB2\n"
-        ".unreq vB1\n" ".unreq qB1\n"
-        ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
-        ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
-        ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
-        ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
-        ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
-        ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
-        ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
-        ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
-        ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
-        ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
-        ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
-        ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
-        ".unreq aptr2\n"
-        ".unreq aptr3\n"
-        ".unreq aptr4\n"
-
-        : [aptr] "+r" (aptr),
-          [bptr] "+r" (bptr),
-          [cptr] "+r" (cptr),
-          [k] "+r" (k)
-        : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
-          [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
-          [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
-        : "cc", "memory", "x20", "x21", "x22",
-          "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
-          "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
-          "v21", "v22", "v23"
-      );
-    }
-  }
-}
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
deleted file mode 100644
index b813bbb..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
+++ /dev/null
@@ -1,349 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-#include "../winograd_gemm.hpp"
-
-namespace winograd
-{
-  /***************************************************************************/
-  /* Instance-less API */
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  void InputTransformImpl<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::execute(
-    const T* const input,        /** Input tensor data */
-    const int n_batches,         /** Number of batches in input tensor. */
-    const int in_batch_stride,   /** Stride between batches of the input. */
-    const int n_rows,            /** Number of rows in input tensor. */
-    const int in_row_stride,     /** Stride between rows of the input. */
-    const int n_cols,            /** Number of columns in input tensor. */
-    const int in_col_stride,     /** Stride between columns of the input. */
-    const int n_channels,        /** Number of channels in input tensor. */
-    const PaddingType padding,   /** Padding type. */
-    const int tile_M,
-    const int tile_N,
-    T* const output,             /** Base of output matrices. */
-    const int matrix_stride,     /** Stride between output matrices. */
-    const int matrix_batch_stride,  /** Stride between batches within the matrix. */
-    const int matrix_row_stride  /** Stride within matrices. */
-  )
-  {
-    // Compute the padding required on each edge of the image
-    const int pad_top = (padding == PADDING_SAME) ? (KernelRows - 1) / 2 : 0;
-    const int pad_left = (padding == PADDING_SAME) ? (KernelCols - 1) / 2 : 0;
-
-    // Compute striding values (assuming NHWC ordered data)
-    const int output_col_stride = matrix_row_stride;
-    const int output_row_stride = tile_N * output_col_stride;
-
-    // Loop over batches
-    for (int batch = 0; batch < n_batches; batch++)
-    {
-      // Pointer to the batch
-      const T* const input_base_batch = input + batch * in_batch_stride;
-      T* const outptr_base_batch = output + batch * matrix_batch_stride;
-
-      // Loop over rows of tiles
-      for (int tile_i = 0; tile_i < tile_M; tile_i++)
-      {
-        // Padding (top + bottom) for the row
-        const int row_top = tile_i*(InnerTileRows - overlap_rows) - pad_top;
-        const int row_bottom = row_top + InnerTileRows;
-        const int row_pad_top = std::max(0, pad_top - tile_i*(InnerTileRows - overlap_rows));
-        const int row_pad_bottom = (row_bottom <= n_rows) ? 0 : row_bottom - n_rows;
-
-        // Pointer to the row
-        const int row_offset = std::min(0, row_pad_top - pad_top);
-        const T* const input_base_row = (
-          input_base_batch + ((InnerTileRows - overlap_rows)*tile_i + row_offset)*in_row_stride
-        );
-        T* const outptr_base_row = outptr_base_batch + tile_i*output_row_stride;
-
-        // Process the row
-        process_tile_row(
-          tile_N, n_channels,
-          input_base_row, in_row_stride, in_col_stride,
-          outptr_base_row, matrix_stride, matrix_row_stride,
-          row_pad_top, pad_left, row_pad_bottom, n_cols
-        );
-      }
-    }
-  }
-
-
-  template <int KernelRows, int InnerTileRows, typename T>
-  void InputTransformImpl<KernelRows, 1, InnerTileRows, 1, T>::execute(
-    const T* const input,        /** Input tensor data */
-    const int n_batches,         /** Number of batches in input tensor. */
-    const int in_batch_stride,   /** Stride between batches of the input. */
-    const int n_rows,            /** Number of rows in input tensor. */
-    const int in_row_stride,     /** Stride between rows of the input. */
-    const int n_cols,            /** Number of columns in input tensor. */
-    const int in_col_stride,     /** Stride between columns of the input. */
-    const int n_channels,        /** Number of channels in input tensor. */
-    const PaddingType padding,   /** Padding type. */
-    const int tile_M,
-    const int tile_N,
-    T* const output,             /** Base of output matrices. */
-    const int matrix_stride,     /** Stride between output matrices. */
-    const int matrix_batch_stride,  /** Stride between batches within the matrix. */
-    const int matrix_row_stride  /** Stride within matrices. */
-  )
-  {
-    // If an Nx1 kernel then transpose and redirect to the 1xN implementation
-    InputTransformImpl<1, KernelRows, 1, InnerTileRows, T>::execute(
-      input,
-      n_batches, in_batch_stride,
-      n_cols, in_col_stride,
-      n_rows, in_row_stride,
-      n_channels, padding,
-      tile_N, tile_M,
-      output, matrix_stride, matrix_batch_stride, matrix_row_stride
-    );
-  }
-
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  void InputTransformImpl<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::process_tile_row(
-    const int tile_N,
-    int n_channels,
-    const T* const input_base,
-    const int input_row_stride,
-    const int input_col_stride,
-    T* const matrix_base,
-    const int matrix_stride,
-    const int matrix_row_stride,
-    const int pad_top,
-    const int row_pad_left,
-    const int pad_bottom,
-    const int n_cols
-  )
-  {
-    // Loop over columns of tiles
-    for (int tile_j = 0; tile_j < tile_N; tile_j++)
-    {
-      // Padding (left + right) for the tile
-      const int t_start = tile_j*(InnerTileCols - overlap_cols) - row_pad_left;
-      const int t_end = t_start + InnerTileCols;
-      const int t_pad_left = std::max(0, row_pad_left - tile_j*(InnerTileCols - overlap_cols));
-      const int t_pad_right = (t_end <= n_cols) ? 0 : t_end - n_cols;
-
-      // Get pointers into the inputs and outputs
-      const int col_offset = std::min(0, t_pad_left - row_pad_left);
-      const T* const input_base_col = (
-        input_base + ((InnerTileCols - overlap_cols)*tile_j + col_offset)*input_col_stride
-      );
-      T* const outptr = matrix_base + tile_j*matrix_row_stride;
-
-      // Apply the specific tile processing function
-      const typename Tiles::TileFn tilefn = Tiles::get_tile_specialization(
-        pad_top, t_pad_left, pad_bottom, t_pad_right
-      );
-
-      tilefn(
-        n_channels,
-        input_base_col, input_row_stride, input_col_stride,
-        outptr, matrix_stride,
-        pad_top, t_pad_left, pad_bottom, t_pad_right
-      );
-    }
-  }
-
-  /***************************************************************************/
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  InputTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::InputTransform(
-    const T* const input,        /** Input tensor data */
-    const int n_batches,         /** Number of batches in input tensor. */
-    const int n_rows,            /** Number of rows in input tensor. */
-    const int n_cols,            /** Number of columns in input tensor. */
-    const int n_channels,        /** Number of channels in input tensor. */
-    const PaddingType padding,   /** Padding type. */
-    T* const output,             /** Base of output matrices. */
-    const int matrix_stride,     /** Stride between output matrices. */
-    const int matrix_row_stride, /** Stride within matrices. */
-    const int in_batch_stride,   /** Stride between input batches. */
-    const int in_row_stride,     /** Stride between input rows. */
-    const int in_col_stride      /** Stride between input columns. */
-  ) : _inptr(input), _outptr(output),
-      _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
-      _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
-      _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - KernelRows + 1,
-                        InnerTileRows - KernelRows + 1)),
-      _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - KernelCols + 1,
-                        InnerTileCols - KernelCols + 1)),
-      _in_col_stride(in_col_stride ? in_col_stride : n_channels),
-      _in_row_stride(in_row_stride ? in_row_stride : n_cols * _in_col_stride),
-      _in_batch_stride(in_batch_stride ? in_batch_stride : n_rows * _in_row_stride),
-      _padding_type(padding)
-  {
-  }
-
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  unsigned int InputTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::get_window() const
-  {
-    // The final window includes the tail, all other windows will be a multiple
-    // of the window block in size.
-    return iceildiv(_n_channels, WINDOW_BLOCK);
-  }
-
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  void InputTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::run(
-    const unsigned int start, const unsigned int stop
-  )
-  {
-    if (start >= get_window())
-    {
-      return;
-    }
-
-    // Determine the window of work to perform
-    const unsigned int start_channel = start * WINDOW_BLOCK;
-    const unsigned int stop_channel = std::min<const unsigned int>(
-      stop * WINDOW_BLOCK, _n_channels
-    );
-    const unsigned int n_channels = stop_channel - start_channel;
-
-    // Perform the work
-    execute(
-      _inptr + start_channel,
-      _n_batches, _in_batch_stride,
-      _n_rows, _in_row_stride,
-      _n_cols, _in_col_stride,
-      n_channels,
-      _padding_type,
-      _tiles_M,
-      _tiles_N,
-      _outptr + start_channel,
-      _matrix_stride,
-      _matrix_row_stride * _tiles_M * _tiles_N,
-      _matrix_row_stride
-    );
-  }
-
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  void InputTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::execute(
-    const T* const input,        /** Input tensor data */
-    const int n_batches,         /** Number of batches in input tensor. */
-    const int in_batch_stride,   /** Stride between batches of the input. */
-    const int n_rows,            /** Number of rows in input tensor. */
-    const int in_row_stride,     /** Stride between rows of the input. */
-    const int n_cols,            /** Number of columns in input tensor. */
-    const int in_col_stride,     /** Stride between columns of the input. */
-    const int n_channels,        /** Number of channels in input tensor. */
-    const PaddingType padding,   /** Padding type. */
-    const int tile_M,
-    const int tile_N,
-    T* const output,             /** Base of output matrices. */
-    const int matrix_stride,     /** Stride between output matrices. */
-    const int matrix_batch_stride,  /** Stride between batches within the matrix. */
-    const int matrix_row_stride  /** Stride within matrices. */
-  )
-  {
-    Transform::execute(
-      input, n_batches, in_batch_stride, n_rows, in_row_stride, n_cols,
-      in_col_stride, n_channels, padding, tile_M, tile_N, output,
-      matrix_stride, matrix_batch_stride, matrix_row_stride
-    );
-  }
-
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  typename InputTransformImplTiles<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::TileFn
-    InputTransformImplTiles<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::
-      get_tile_specialization(
-        const int pad_top,
-        const int pad_left,
-        const int pad_bottom,
-        const int pad_right
-      )
-  {
-    if (!(pad_top || pad_left || pad_bottom || pad_right))
-    {
-      // No padding, return unpadded specialisation
-      return tilefn_unpadded;
-    }
-    else if (pad_top && !(pad_left || pad_bottom || pad_right))
-    {
-      // Top padding only
-      const int index = (pad_top - min_pad_top) / (InnerTileRows - overlap_rows);
-      return tilefn_top_padded[index];
-    }
-    else if (!(pad_top) && pad_left && !(pad_bottom || pad_right))
-    {
-      // Left padding only
-      const int index = (pad_left - min_pad_left) / (InnerTileCols - overlap_cols);
-      return tilefn_left_padded[index];
-    }
-    else if (!(pad_top || pad_left) && pad_bottom && !(pad_right))
-    {
-      // Bottom padding only
-      return tilefn_bottom_padded[pad_bottom - 1];
-    }
-    else if (!(pad_top || pad_left || pad_bottom) && pad_right)
-    {
-      // Right padding only
-      return tilefn_right_padded[pad_right - 1];
-    }
-    else
-    {
-      // Combination of paddings, return an unspecialised method
-      return tilefn_generic;
-    }
-  }
-
-  template <int KernelCols, int InnerTileCols, typename T>
-  typename InputTransformImplTiles<1, KernelCols, 1, InnerTileCols, T>::TileFn
-    InputTransformImplTiles<1, KernelCols, 1, InnerTileCols, T>::
-      get_tile_specialization(
-        const int pad_top,
-        const int pad_left,
-        const int pad_bottom,
-        const int pad_right
-      )
-  {
-    (void) pad_top;
-    (void) pad_bottom;
-
-    if (!(pad_left || pad_right))
-    {
-      // No padding, return unpadded specialisation
-      return tilefn_unpadded;
-    }
-    else if (pad_left && !pad_right)
-    {
-      // Left padding only
-      const int index = (pad_left - min_pad_left) / (InnerTileCols - overlap_cols);
-      return tilefn_left_padded[index];
-    }
-    else if (!pad_left && pad_right)
-    {
-      // Right padding only
-      return tilefn_right_padded[pad_right - 1];
-    }
-    else
-    {
-      // Combination of paddings, return an unspecialised method
-      return tilefn_generic;
-    }
-  }
-}
-
-
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp
deleted file mode 100644
index bad3ef2..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
-using namespace winograd;
-
-
-template <int otr, int otc, int kr, int kc>
-template <typename T>
-WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::WeightsTransform(
-  const T* const input,
-  T* const output,
-  const int matrix_stride,      /** Stride across matrices in the output. */
-  const int matrix_row_stride,  /** Stride across rows of the matrix. */
-  const int n_output_channels,
-  const int n_input_channels
-) : inptr(input), outptr(output),
-    matrix_stride(matrix_stride), matrix_row_stride(matrix_row_stride),
-    n_output_channels(n_output_channels), n_input_channels(n_input_channels)
-{
-}
-
-
-template <int otr, int otc, int kr, int kc>
-template <typename T>
-unsigned int WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::get_window() const
-{
-  // TODO When the weights transform supports multithreading, return the number
-  // of output channels. For now we return 1 to indicate that the weights must
-  // be transformed as a single block.
-  // return n_output_channels;
-  return 1;
-}
-
-
-template <int otr, int otc, int kr, int kc>
-template <typename T>
-void WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::run(
-  const unsigned int start, const unsigned int stop
-)
-{
-  // TODO When the weights transform supports multithreading call execute for a
-  // portion of the output channels.
-  (void) start;
-  (void) stop;
-
-  // For now, just do all of the work.
-  execute(
-    n_output_channels,
-    n_input_channels,
-    inptr,
-    outptr,
-    matrix_stride,
-    matrix_row_stride
-  );
-}
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp
deleted file mode 100644
index 77cd9de..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp
+++ /dev/null
@@ -1,278 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-#include "../winograd_gemm.hpp"
-
-namespace winograd
-{
-/***************************************************************************/
-  /* Instance-less API */
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  void OutputTransformImpl<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::execute(
-    const int n_batches,
-    const int output_batch_stride,
-    const int n_rows,
-    const int output_row_stride,
-    const int n_cols,
-    const int output_col_stride,
-    const int n_channels,
-    const T* const matrix_base,
-    const int matrix_stride,
-    const int matrix_row_stride,
-    const T* const biases,
-    T* const output
-  )
-  {
-    // Compute the number of tiles and hence the padding required on the bottom
-    // and right of the image.
-    const int tile_M = iceildiv(n_rows, OutputTileRows);
-    const int tile_N = iceildiv(n_cols, OutputTileCols);
-    const int pad_bottom = OutputTileRows*tile_M - n_rows;
-    const int pad_right = OutputTileCols*tile_N - n_cols;
-
-    const int matrix_tile_row_stride = tile_N * matrix_row_stride;
-    const int matrix_batch_stride = tile_M * matrix_tile_row_stride;
-
-    // Perform the output transformation for each batch
-    for (int batch = 0; batch < n_batches; batch++)
-    {
-      // Get batch offset for input and outputs.
-      const T* const matrix_batch = matrix_base + batch*matrix_batch_stride;
-      T* const outptr_batch = output + batch*output_batch_stride;
-
-      // Perform the output transformation for each row of the output tensor.
-      for (int tile_i = 0; tile_i < tile_M; tile_i++)
-      {
-        // Compute properties of this row of output tiles
-        const int row_pad_bottom = (tile_i < tile_M - 1) ? 0: pad_bottom;
-        const T* const matrix_tile_row = matrix_batch + tile_i * matrix_tile_row_stride;
-        T* const outptr_row = outptr_batch + OutputTileRows*tile_i*output_row_stride;
-
-        // Process the row
-        process_tile_row(
-          tile_N, n_channels, matrix_tile_row, matrix_stride,
-          matrix_row_stride, biases,
-          outptr_row, output_row_stride, output_col_stride, row_pad_bottom,
-          pad_right
-        );
-      }
-    }
-  }
-
-template <int KernelRows, int InnerTileRows, typename T>
-  void OutputTransformImpl<KernelRows, 1, InnerTileRows, 1, T>::execute(
-    const int n_batches,
-    const int output_batch_stride,
-    const int n_rows,
-    const int output_row_stride,
-    const int n_cols,
-    const int output_col_stride,
-    const int n_channels,
-    const T* const matrix_base,
-    const int matrix_stride,
-    const int matrix_row_stride,
-    const T* const biases,
-    T* const output
-  )
-  {
-    // If an Nx1 kernel then transpose and redirect to the 1xN implementation.
-    OutputTransformImpl<1, KernelRows, 1, InnerTileRows, T>::execute(
-        n_batches,
-        output_batch_stride,
-        n_cols, output_col_stride,
-        n_rows, output_row_stride,
-        n_channels,
-        matrix_base, matrix_stride, matrix_row_stride,
-        biases, output
-      );
-  }
-
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  void OutputTransformImpl<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::process_tile_row(
-    const int tile_N,
-    const int n_channels,
-    const T* const matrix_base,
-    const int matrix_stride,
-    const int matrix_row_stride,
-    const T* const biases,
-    T* const output,
-    const int output_row_stride,
-    const int output_col_stride,
-    const int row_pad_bottom,
-    const int row_pad_right
-  )
-  {
-    // Loop over columns of tiles
-    for (int tile_j = 0; tile_j < tile_N; tile_j++)
-    {
-      // Properties of this tile
-      const int tile_pad_right = (tile_j < tile_N - 1) ? 0 : row_pad_right;
-      const T* const matrix_row = matrix_base + tile_j * matrix_row_stride;
-      T* const outptr = output + OutputTileCols *tile_j*output_col_stride;
-
-      // Perform the output transformation
-      const typename Tiles::TileFn tilefn = Tiles::get_tile_specialization(row_pad_bottom, tile_pad_right);
-      tilefn(
-        n_channels, matrix_row, matrix_stride, biases,
-        outptr, output_row_stride, output_col_stride,
-        row_pad_bottom, tile_pad_right
-      );
-    }
-  }
-
-/***************************************************************************/
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  OutputTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::OutputTransform(
-    const T* const matrix_base,
-    const int matrix_stride,
-    const int matrix_row_stride,
-    const T* const biases,
-    T* const output,
-    const int n_batches,
-    const int n_rows,
-    const int n_cols,
-    const int n_channels,
-    const int out_batch_stride,
-    const int out_row_stride,
-    const int out_col_stride
-  ) : _matrix_base(matrix_base), _biases(biases),
-      _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
-      _outptr(output), _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols),
-      _n_channels(n_channels), _tile_M(iceildiv(n_rows, OutputTileRows)),
-      _tile_N(iceildiv(n_cols, OutputTileCols)),
-      _out_col_stride(out_col_stride ? out_col_stride : n_channels),
-      _out_row_stride(out_row_stride ? out_row_stride : n_cols * _out_col_stride),
-      _out_batch_stride(out_batch_stride ? out_batch_stride : n_rows * _out_row_stride)
-  {
-  }
-
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  unsigned int OutputTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::get_window() const
-  {
-    // The final window includes the tail, all other windows will be a multiple
-    // of the window block in size.
-    return iceildiv(_n_channels, WINDOW_BLOCK);
-  }
-
-template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  void OutputTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::run(
-    const unsigned int start, const unsigned int stop
-  )
-  {
-    if (start >= get_window())
-    {
-      return;
-    }
-
-    // Determine the window of work to perform
-    const unsigned int start_channel = start * WINDOW_BLOCK;
-    const unsigned int stop_channel = std::min<const unsigned int>(
-      stop * WINDOW_BLOCK, _n_channels
-    );
-    const unsigned int n_channels = stop_channel - start_channel;
-
-    execute(
-      _n_batches,
-      _out_batch_stride,
-      _n_rows,
-      _out_row_stride,
-      _n_cols,
-      _out_col_stride,
-      n_channels,
-      _matrix_base + start_channel,
-      _matrix_stride,
-      _matrix_row_stride,
-      (_biases != nullptr) ? _biases + start_channel : nullptr,
-      _outptr + start_channel
-    );
-  }
-
- template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  void OutputTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::execute(
-    const int n_batches,
-    const int out_batch_stride,
-    const int n_rows,
-    const int out_row_stride,
-    const int n_cols,
-    const int out_col_stride,
-    const int n_channels,
-    const T* const matrix_base,
-    const int matrix_stride,
-    const int matrix_row_stride,
-    const T* const biases,
-    T* const output
-  )
-  {
-    Transform::execute(
-      n_batches, out_batch_stride,
-      n_rows, out_row_stride,
-      n_cols, out_col_stride, n_channels,
-      matrix_base, matrix_stride, matrix_row_stride,
-      biases, output
-    );
-  }
-
-  template <int KernelCols, int InnerTileCols, typename T>
-  typename OutputTransformImplTiles<1, KernelCols, 1, InnerTileCols, T>::TileFn
-    OutputTransformImplTiles<1, KernelCols, 1, InnerTileCols, T>::
-      get_tile_specialization(const int pad_bottom, const int pad_right)
-  {
-    (void) pad_bottom;
-
-    if (!pad_right)
-    {
-      // No padding, return unpadded specialisation
-      return tilefn_unpadded;
-    }
-    else
-    {
-      return tilefn_right_padded[pad_right - 1];
-    }
-  }
-
-  template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-  typename OutputTransformImplTiles<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::TileFn
-    OutputTransformImplTiles<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>::
-      get_tile_specialization(const int pad_bottom, const int pad_right)
-  {
-    if (!(pad_bottom || pad_right))
-    {
-      // No padding, return unpadded specialisation
-      return tilefn_unpadded;
-    }
-    else if (pad_bottom && !pad_right)
-    {
-      return tilefn_bottom_padded[pad_bottom - 1];
-    }
-    else if (!pad_bottom && pad_right)
-    {
-      return tilefn_right_padded[pad_right - 1];
-    }
-    else
-    {
-      return tilefn_generic;
-    }
-  }
-}  // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp
new file mode 100644
index 0000000..183c9c1
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd.hpp
@@ -0,0 +1,610 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "convolution.hpp"
+#include "tensor.hpp"
+#include "utils.hpp"
+
+namespace winograd
+{
+
+class ITransform
+{
+  public:
+    virtual ~ITransform() = default;
+
+    /**
+     * Get the working space required to perform the transformation.
+     *
+     * Note, the working space is only required when performing the
+     * transformation - hence it can be reused whenever the transformation is
+     * not running.
+     *
+     * @param nthreads The greatest number of threads that will be used to execute the transform.
+     * @return Size of working space required in bytes.
+     */
+    virtual size_t get_working_space_size(unsigned int nthreads=1) const = 0;
+
+    /**
+     * Set the working space to be used by the transformation.
+     *
+     * Note, the working space is only required when performing the
+     * transformation - hence it can be reused whenever the transformation is
+     * not running.
+     *
+     * @param buffer Pointer to the working space.
+     */
+    virtual void set_working_space(void *buffer) = 0;
+
+    /**
+     * Get the window of work a given operator can perform.
+     */
+    virtual unsigned int get_window() const = 0;
+
+    /**
+     * Perform work upon a window of the transform.
+     */
+    virtual void run(unsigned int start, unsigned int stop, unsigned int threadid=0) = 0;
+};
+
+class IInputTransform : public ITransform
+{
+  public:
+    virtual ~IInputTransform() = default;
+
+    /**
+     * Set the pointer to the (NHWC-ordered) tensor to be transformed.
+     */
+    virtual void set_input_tensor(const void *input) = 0;
+
+    /**
+     * Set the pointer to the (NHWC-ordered) tensor to be transformed.
+     * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
+     */
+    virtual void set_input_tensor(const void *input, int col_stride) = 0;
+
+    /**
+     * Set the pointer to the (NHWC-ordered) tensor to be transformed.
+     * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
+     * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
+     */
+    virtual void set_input_tensor(const void *input, int row_stride, int col_stride) = 0;
+
+    /**
+     * Set the pointer to the (NHWC-ordered) tensor to be transformed.
+     * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes).
+     * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
+     * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
+     */
+    virtual void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) = 0;
+
+    /**
+     * Set pointers to the matrices written by the transform.
+     * @param matrices Pointer to the start of the first matrix representing the transformed input.
+     * @param inter_matrix_stride Stride (in elements) between matrices.
+     * @param matrix_row_stride Stride (in elements) between the rows within a single matrix.
+     */
+    virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0;
+};
+
+class IOutputTransform : public ITransform
+{
+  public:
+    virtual ~IOutputTransform() = default;
+
+    /**
+     * Set pointers to the matrices read by the transform.
+     * @param matrices Pointer to the start of the first matrix representing the input to the transform.
+     * @param inter_matrix_stride Stride (in elements) between matrices.
+     * @param matrix_row_stride Stride (in elements) between the rows within a single matrix.
+     */
+    virtual void set_input_matrices(const void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0;
+
+    /**
+     * Set pointer to the bias tensor (can be ignored or called with nullptr for no bias).
+     */
+    virtual void set_bias(const void *bias=nullptr) = 0;
+
+    /**
+     * Set pointer to the output tensor produced by the transform.
+     */
+    virtual void set_output_tensor(void *output) = 0;
+
+    /**
+     * Set pointer to the output tensor produced by the transform.
+     * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
+     */
+    virtual void set_output_tensor(void *output, int col_stride) = 0;
+
+    /**
+     * Set pointer to the output tensor produced by the transform.
+     * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
+     * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
+     */
+    virtual void set_output_tensor(void *output, int row_stride, int col_stride) = 0;
+
+    /**
+     * Set pointer to the output tensor produced by the transform.
+     * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes).
+     * @param row_stride Stride between rows of the tensor, measured in elements (not bytes).
+     * @param col_stride Stride between columns of the tensor, measured in elements (not bytes).
+     */
+    virtual void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) = 0;
+};
+
+class IWeightTransform : public ITransform
+{
+  public:
+    virtual ~IWeightTransform() = default;
+
+    /** Set pointer to the weight tensor read by the transform. */
+    virtual void set_weight_tensor(const void *weights) = 0;
+
+    /**
+     * Set pointers to the matrices written by the transform.
+     * @param matrices Pointer to the start of the first matrix representing the transformed input.
+     * @param inter_matrix_stride Stride (in elements) between matrices.
+     * @param matrix_row_stride Stride (in elements) between the rows within a single matrix.
+     */
+    virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0;
+};
+
+enum class WinogradRoots
+{
+  Integers,
+};
+
+template <int InnerTileRows, int InnerTileCols, typename TIn, typename TOut, WinogradRoots Roots>
+class InputTransform : public IInputTransform
+{
+  public:
+    /** Create an InputTransform operator fixed on a given problem and set of
+     * pointers.
+     */
+    InputTransform(
+        int kernel_rows,     /**< Number of rows in the kernel */
+        int kernel_cols,     /**< Number of columns in the kernel */
+        int n_batches,       /**< Number of batches in input tensor. */
+        int n_rows,          /**< Number of rows in input tensor. */
+        int n_cols,          /**< Number of columns in input tensor. */
+        int n_channels,      /**< Number of channels in input tensor. */
+        int padding_top,     /**< Padding to apply to the top of the image. */
+        int padding_left,    /**< Padding to apply to the left of the image. */
+        int padding_bottom,  /**< Padding to apply to the bottom of the image. */
+        int padding_right    /**< Padding to apply to the right of the image. */
+    );
+
+    InputTransform(const InputTransform&) = delete;             // Copy construction is disabled.
+    InputTransform& operator=(const InputTransform&) = delete;  // Copy assignment is disabled.
+
+    /** Set pointers to the input tensor read by the transform. */
+    void set_input_tensor(const void *input) override;
+    void set_input_tensor(const void *input, int col_stride) override;
+    void set_input_tensor(const void *input, int row_stride, int col_stride) override;
+    void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override;
+
+    /** Set pointers to the matrices written by the transform (strides are measured in elements). */
+    void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) override;
+
+    /** Get the working space required to perform the transformation. */
+    size_t get_working_space_size(unsigned int nthreads=1) const override;
+    void set_working_space(void *buffer) override;
+
+    /** Get the window of work a given operator can perform. */
+    unsigned int get_window() const override;
+    static constexpr unsigned int WINDOW_BLOCK = 16;  // Base size of window
+
+    /** Perform work upon a window of the input. */
+    void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override;
+
+  protected:
+    const int _n_batches, _n_rows, _n_cols, _n_channels;
+
+  private:
+    void transform_unpadded_tile(
+      unsigned int threadid,
+      int n_channels,
+      TOut *outptr,
+      const TIn *inptr
+    );
+
+    void transform_padded_tile(
+      unsigned int threadid,
+      int n_channels,
+      TOut *outptr,
+      const TIn *inptr,
+      int padding_top,
+      int padding_left,
+      int padding_bottom,
+      int padding_right
+    );
+    
+    /* Tile implementation */
+    static void transform_tile(
+      int n_channels,         /** @param[in] Number of channels in the tensor. */
+      const TIn* inptr_base,  /** @param[in] Pointer to the base of the input tile. */
+      int input_row_stride,   /** @param[in] Stride between rows of the input tensor. */
+      int input_col_stride,   /** @param[in] Stride between columns of the input tensor. */
+      TOut* mptr_base,        /** @param[out] Base pointer to transformed input matrices. */
+      int matrix_stride       /** @param[in] Stride between matrices in the input space. */
+    );
+
+    /** Get the working space for a thread. */
+    void * get_working_space(unsigned int threadid) const;
+
+    const TIn* _inptr;
+    TOut* _outptr;
+
+    const int _overlap_rows, _overlap_cols;
+    const int _padding_top, _padding_left, _padding_bottom, _padding_right;
+    const int _tiles_M, _tiles_N;
+    int _matrix_stride, _matrix_row_stride, _matrix_batch_stride;
+    int _in_col_stride, _in_row_stride, _in_batch_stride;
+
+    const int _working_space_col_stride, _working_space_row_stride;
+    TIn *_working_space;
+};
+
+template <int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots>
+class InputTransform<InnerTileRows, 1, TIn, TOut, Roots> :
+  public InputTransform<1, InnerTileRows, TIn, TOut, Roots>
+{
+  using Base = InputTransform<1, InnerTileRows, TIn, TOut, Roots>;
+
+  public:
+    InputTransform(
+      int kernel_rows,     /**< Number of rows in the kernel. */
+      int kernel_cols,     /**< Number of columns in the kernel. */
+      int n_batches,       /**< Number of batches in input tensor. */
+      int n_rows,          /**< Number of rows in input tensor. */
+      int n_cols,          /**< Number of columns in input tensor. */
+      int n_channels,      /**< Number of channels in input tensor. */
+      int padding_top,     /**< Padding to apply to the top of the image. */
+      int padding_left,    /**< Padding to apply to the left of the image. */
+      int padding_bottom,  /**< Padding to apply to the bottom of the image. */
+      int padding_right    /**< Padding to apply to the right of the image. */
+    );
+
+    /** Set pointers to the input tensor read by the transform. */
+    void set_input_tensor(const void *input) override;
+    void set_input_tensor(const void *input, int col_stride) override;
+    void set_input_tensor(const void *input, int row_stride, int col_stride) override;
+    void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override;
+};
+
+template <
+  int KernelRows, int KernelCols,
+  int InnerTileRows, int InnerTileCols,
+  typename TIn, typename TOut,
+  WinogradRoots Roots
+>
+class OutputTransform : public IOutputTransform
+{
+  public:
+    OutputTransform(
+      int n_batches,  /**< Number of batches in output tensor. */
+      int n_rows,     /**< Number of rows in output tensor. */
+      int n_cols,     /**< Number of columns in output tensor. */
+      int n_channels  /**< Number of channels in output tensor. */
+    );
+
+    OutputTransform(const OutputTransform&) = delete;             // Copy construction is disabled.
+    OutputTransform& operator=(const OutputTransform&) = delete;  // Copy assignment is disabled.
+
+    /** Set pointers to the matrices read by the transform (strides are measured in elements). */
+    void set_input_matrices(const void *matrices, int inter_matrix_stride, int matrix_row_stride) override;
+
+    /** Set pointer to the bias tensor (can be ignored or called with nullptr for no bias */
+    void set_bias(const void *bias=nullptr) override;
+
+    /** Set pointers to the output tensor written by the transform. */
+    void set_output_tensor(void *output) override;
+    void set_output_tensor(void *output, int col_stride) override;
+    void set_output_tensor(void *output, int row_stride, int col_stride) override;
+    void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override;
+
+    /** Get the working space required to perform the transformation. */
+    size_t get_working_space_size(unsigned int nthreads=1) const override;
+    void set_working_space(void *buffer) override;
+
+    /** Get the window of work a given operator can perform. */
+    unsigned int get_window() const override;
+    static constexpr unsigned int WINDOW_BLOCK = 16;  // Base size of window
+
+    /** Perform work upon a window of the input. */
+    void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override;
+
+  protected:
+    static constexpr int inner_tile_rows = InnerTileRows;
+    static constexpr int inner_tile_cols = InnerTileCols;
+    static constexpr int output_tile_rows = InnerTileRows - KernelRows + 1;
+    static constexpr int output_tile_cols = InnerTileCols - KernelCols + 1;
+
+    const int _n_batches, _n_rows, _n_cols, _n_channels;
+
+  private:
+    void transform_uncropped_tile(
+      unsigned int threadid,
+      int n_channels,
+      TOut *outptr,
+      const TIn *inptr,
+      const TOut *biases
+    );
+
+    void transform_cropped_tile(
+      unsigned int threadid,
+      int n_channels,
+      TOut *outptr,
+      const TIn *inptr,
+      const TOut *biases,
+      int pad_bottom,
+      int pad_right
+    );
+
+    /** Implementation of the tile transformation method. */
+    static void transform_tile(
+      int n_channels,
+      const TIn* matrix_base,
+      int matrix_stride,
+      const TOut* biases,
+      TOut* output,
+      int output_row_stride,
+      int output_col_stride
+    );
+
+    /** Get the working space for a thread. */
+    void * get_working_space(unsigned int threadid) const;
+
+    const TIn* _matrix_base;
+    const TOut* _biases;
+    int _matrix_stride, _matrix_row_stride, _matrix_batch_stride;
+    TOut* _outptr;
+    const int _tiles_M, _tiles_N;
+    int _out_col_stride, _out_row_stride, _out_batch_stride;
+
+    const int _working_space_col_stride, _working_space_row_stride;
+    TOut *_working_space;
+};
+
+template <
+  int KernelRows,
+  int InnerTileRows,
+  typename TIn, typename TOut,
+  WinogradRoots Roots
+>
+class OutputTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots> :
+  public OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>
+{
+  using Base = OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>;
+
+  public:  // Nx1 case: derives from the 1xN transform (Base) with row/col template arguments swapped.
+    OutputTransform(
+      int n_batches,  /**< Number of batches in output tensor. */
+      int n_rows,     /**< Number of rows in output tensor. */
+      int n_cols,     /**< Number of columns in output tensor. */
+      int n_channels  /**< Number of channels in output tensor. */
+    );
+
+    /** Set the output tensor pointer, with optional explicit column/row/batch strides. */
+    void set_output_tensor(void *output) override;
+    void set_output_tensor(void *output, int col_stride) override;
+    void set_output_tensor(void *output, int row_stride, int col_stride) override;
+    void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override;
+};
+
+/** Transform weights from the spatial to the Winograd domain. */
+template <
+  int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols,
+  typename TIn, typename TOut,
+  WinogradRoots Roots
+>
+class WeightTransform : public IWeightTransform
+{
+  public:
+    WeightTransform(
+      int n_output_channels,  /**< Number of output channels in the kernel. */
+      int n_input_channels    /**< Number of input channels in the kernel. */
+    );
+
+    WeightTransform(const WeightTransform&) = delete;  // Non-copyable.
+    WeightTransform& operator=(const WeightTransform&) = delete;
+
+    /** Set pointer to the weight tensor read by the transform. */
+    void set_weight_tensor(const void *weights) override;
+
+    /** Set pointer to the matrices written by the transform. */
+    void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) override;
+
+    /** Get the working space required to perform the transformation. */
+    size_t get_working_space_size(unsigned int nthreads=1) const override;
+    void set_working_space(void *buffer) override;
+
+    /** Get the window of work a given operator can perform. */
+    unsigned int get_window() const override;
+    static constexpr unsigned int WINDOW_BLOCK = 16;  // Base size of window
+
+    /** Perform work upon a window of the input. */
+    void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override;
+
+  protected:
+    static const int kernel_rows = KernelRows;
+    static const int kernel_cols = KernelCols;
+    static const int inner_tile_rows = InnerTileRows;
+    static const int inner_tile_cols = InnerTileCols;
+
+  private:
+    /** Apply the transform to a tensor. */
+    static void execute(
+      int n_output_channels,
+      int n_input_channels,
+      const TIn* input,
+      TOut* output,
+      int matrix_stride,
+      int matrix_row_stride
+    );
+
+    const int _n_output_channels, _n_input_channels;
+    TOut *_matrices;
+    int _matrix_stride, _matrix_row_stride;
+    const TIn *_weights;
+};
+
+template <int KernelRows, int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots>
+class WeightTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots> :
+  public WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>
+{
+  public:  // Nx1 case: reuses the 1xN transform, inheriting its constructors.
+    using WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::WeightTransform;
+};
+
+template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols, WinogradRoots Roots>
+class WinogradGEMM
+{
+  public:
+    // Compile-time geometry of this Winograd instance
+    static constexpr int output_tile_rows = OutputTileRows;
+    static constexpr int output_tile_cols = OutputTileCols;
+    static constexpr int kernel_rows = KernelRows;
+    static constexpr int kernel_cols = KernelCols;
+    static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1;
+    static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1;
+    static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols;
+
+    /** Transform weights from the spatial to the Winograd domain. */
+    template <typename TIn, typename TOut>
+    using WeightsTransform = WeightTransform<
+      KernelRows, KernelCols, inner_tile_rows, inner_tile_cols,
+      TIn, TOut, Roots
+    >;
+
+    /** Transform input feature maps from the spatial to the Winograd domain.
+     */
+    template <typename TIn, typename TOut>
+    using InputTransform = InputTransform<
+      inner_tile_rows, inner_tile_cols, TIn, TOut, Roots
+    >;
+
+    /** Transform output feature maps from the Winograd to the spatial domain.
+     */
+    template <typename TIn, typename TOut>
+    using OutputTransform = OutputTransform<
+      KernelRows, KernelCols, inner_tile_rows, inner_tile_cols,
+      TIn, TOut, Roots
+    >;
+
+    /** Typed view of a convolution: element types plus static shape, storage
+     * size and matrix stride queries. */
+    template <typename TOut, typename TIn, typename TInGEMM=TIn, typename TOutGEMM=TOut>
+    class Convolution
+    {
+      public:
+        // Element types of the typed Winograd instance
+        typedef TOut OutputType;
+        typedef TOutGEMM GemmOutputType;
+        typedef TInGEMM GemmInputType;
+        typedef TIn InputType;
+
+        /** Get the output shape of a convolution. */
+        static Tensor4DShape get_output_shape(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &in_shape,
+          const PaddingType padding
+        );
+
+        /** Get the memory required to transform the kernel.
+         */
+        static size_t get_kernel_transform_working_size(const KernelShape &shape);
+
+        /** Get the memory required to store the kernel transformed into the
+         * Winograd domain.
+         */
+        static size_t get_kernel_storage_size(const KernelShape &shape);
+
+        /** Get the memory required to store the input tensor transformed into
+         * the Winograd domain.
+         */
+        static size_t get_input_storage_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /** Get the memory required to store the output tensor in the Winograd
+         * domain.
+         */
+        static size_t get_output_storage_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /** Get the memory required to apply a Winograd operator to some input.
+         */
+        static size_t get_working_space_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /** Get the memory required by a single "input" matrix.
+         */
+        static size_t get_input_matrix_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        static int get_input_matrix_stride(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /** Get the memory required by a single "output" matrix.
+         */
+        static size_t get_output_matrix_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        static int get_output_matrix_stride(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /** Get the memory required by a single "kernel" matrix.
+         */
+        static size_t get_kernel_matrix_size(const KernelShape &shape);
+        static int get_kernel_matrix_stride(const KernelShape &shape);
+
+        static constexpr int M_BLOCK = 4;   /**< Size of block used by GEMM. */
+        static constexpr int N_BLOCK = 16;  /**< Size of block used by GEMM. */
+    };
+};
+
+}  // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
deleted file mode 100644
index 71b5fd5..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
+++ /dev/null
@@ -1,226 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-
-#include "arm_compute/core/NEON/kernels/convolution/common/alloc.hpp"
-#include "arm_compute/core/NEON/kernels/convolution/common/convolution.hpp"
-#include "gemm.hpp"
-#include "arm_compute/core/NEON/kernels/convolution/common/shims.hpp"
-#include "arm_compute/core/NEON/kernels/convolution/common/tensor.hpp"
-#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
-#include "winograd_input_transform.hpp"
-#include "winograd_output_transform.hpp"
-
-#include <thread>
-#include <utility>
-#include <vector>
-
-// Generic Winograd implementation using GEMM
-namespace winograd
-{
-
-template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
-class WinogradGEMM
-{
-  public:
-    // Information about the specific Winograd instance
-    static constexpr int output_tile_rows = OutputTileRows;
-    static constexpr int output_tile_cols = OutputTileCols;
-    static constexpr int kernel_rows = KernelRows;
-    static constexpr int kernel_cols = KernelCols;
-    static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1;
-    static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1;
-    static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols;
-
-    /** Transform weights from the spatial to the Winograd domain. */
-    template <typename T>
-    struct WeightsTransform
-    {
-      /** Get the bytes read during the transform. */
-      static inline size_t bytes_read(const KernelShape &shape)
-      {
-        return shape.size() * sizeof(T);
-      }
-
-      /** Get the bytes written during the transform. */
-      static inline size_t bytes_written(const KernelShape &shape)
-      {
-        const int inner_tile_size = inner_tile_rows * inner_tile_cols;
-        return (inner_tile_size * shape.n_input_channels *
-                shape.n_output_channels * sizeof(T));
-      }
-
-      /** Get the count of operations performed by the transform. */
-      static int ops_performed(const KernelShape &shape);
-
-      /** Apply the transform to a tensor. */
-      static void execute(
-        const int n_output_channels,
-        const int n_input_channels,
-        const T* const input,
-        T* const output,
-        const int matrix_stride,
-        const int matrix_row_stride
-      );
-
-      /** Create a WeightsTransform operator fixed on a given problem and set
-       * of pointers.
-       */
-      WeightsTransform(
-        const T* const input,
-        T* const output,
-        const int matrix_stride,       /** Stride across matrices in the output. */
-        const int matrix_row_stride,   /** Stride across rows of the matrix. */
-        const int n_output_channels,   /** Number of filters. */
-        const int n_input_channels     /** Number of channels in each filter. */
-      );
-
-      /** Get the window of work a given operator can perform. */
-      unsigned int get_window() const;
-
-      /** Perform work upon a window of the input. */
-      void run(const unsigned int start, const unsigned int stop);
-
-      private:
-        const T* const inptr;         /** Fixed pointer to input data. */
-        T* const outptr;              /** Fixed pointer to output memory. */
-        const int matrix_stride;      /** Stride between output matrices. */
-        const int matrix_row_stride;  /** Stride within output matrices. */
-        const int n_output_channels;  /** Number of filters. */
-        const int n_input_channels;   /** Number of channels in each filter. */
-    };
-
-    /** Transform input feature maps from the spatial to the Winograd domain.
-     */
-    template <typename T>
-    using InputTransform = InputTransform<
-      KernelRows, KernelCols,
-      (OutputTileRows + KernelRows - 1),
-      (OutputTileCols + KernelCols - 1),
-      T
-    >;
-
-    /** Transform output feature maps from the Winograd to the spatial domain.
-     */
-    template <typename T>
-     using OutputTransform = OutputTransform<
-      KernelRows, KernelCols,
-      (OutputTileRows + KernelRows - 1),
-      (OutputTileCols + KernelCols - 1),
-      T
-    >;
-
-
-    /** Perform a convolution.
-     */
-    template <typename TOut, typename TIn>
-    class Convolution
-    {
-      public:
-        // Information about the typed Winograd instance
-        typedef TOut OutputType;
-        typedef TIn InputType;
-
-        /** Get the output shape of a convolution. */
-        static Tensor4DShape get_output_shape(
-          const KernelShape &kernel_shape,
-          const Tensor4DShape &in_shape,
-          const PaddingType padding
-        );
-
-        /* Get the memory required to transform the kernel.
-         */
-        static size_t get_kernel_transform_working_size(const KernelShape &shape);
-
-        /** Get the memory required to store the kernel transformed into the
-         * Winograd domain.
-         */
-        static size_t get_kernel_storage_size(const KernelShape &shape);
-
-        /** Get the memory required to store the input tensor transformed into
-         * the Winograd domain.
-         */
-        static size_t get_input_storage_size(
-          const KernelShape &kernel_shape,
-          const Tensor4DShape &input_shape,
-          const PaddingType padding_type
-        );
-
-        /** Get the memory required to store the output tensor in the Winograd
-         * domain.
-         */
-        static size_t get_output_storage_size(
-          const KernelShape &kernel_shape,
-          const Tensor4DShape &input_shape,
-          const PaddingType padding_type
-        );
-
-        /** Get the memory required to apply a Winograd operator to some input.
-         */
-        static size_t get_working_space_size(
-          const KernelShape &kernel_shape,
-          const Tensor4DShape &input_shape,
-          const PaddingType padding_type
-        );
-
-        /* Get the memory required by a single "input" matrix.
-         */
-        static size_t get_input_matrix_size(
-          const KernelShape &kernel_shape,
-          const Tensor4DShape &input_shape,
-          const PaddingType padding_type
-        );
-
-        static int get_input_matrix_stride(
-          const KernelShape &kernel_shape,
-          const Tensor4DShape &input_shape,
-          const PaddingType padding_type
-        );
-
-        /* Get the memory required by a single "output" matrix.
-         */
-        static size_t get_output_matrix_size(
-          const KernelShape &kernel_shape,
-          const Tensor4DShape &input_shape,
-          const PaddingType padding_type
-        );
-
-        static int get_output_matrix_stride(
-          const KernelShape &kernel_shape,
-          const Tensor4DShape &input_shape,
-          const PaddingType padding_type
-        );
-
-        /* Get the memory required by a single "kernel" matrix.
-         */
-        static size_t get_kernel_matrix_size(const KernelShape &shape);
-        static int get_kernel_matrix_stride(const KernelShape &shape);
-
-        static constexpr int M_BLOCK = 4;   /** Size of block used by GEMM. */
-        static constexpr int N_BLOCK = 16;  /** Size of block used by GEMM. */
-    };
-};
-
-}  // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_input_transform.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_input_transform.hpp
deleted file mode 100644
index 995554d..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_input_transform.hpp
+++ /dev/null
@@ -1,271 +0,0 @@
-/*
- * Copyright (c) 2018 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-
-namespace winograd
-{
-
-namespace
-{
-
-template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-class InputTransformImplTiles
-{
-  public:
-    /** Method to transform a tile of the input tensor into the Winograd domain. */
-    typedef void (*TileFn)(
-      const int n_channels,        /** @param[in] Number of channels in the tensor. */
-      const T* const inptr_base,   /** @param[in] Pointer to the base of the input tile. */
-      const int input_row_stride,  /** @param[in] Stride between rows of the input tensor. */
-      const int input_col_stride,  /** @param[in] Stride between columns of the input tensor. */
-      T* const mptr_base,          /** @param[out] Base pointer to transformed input matrices. */
-      const int matrix_stride,     /** @param[in] Stride between matrices in the input space. */
-      const int _pad_top,          /** @param[in] Top padding for unspecialised tiles. */
-      const int _pad_left,         /** @param[in] Left padding for unspecialised tiles. */
-      const int _pad_bottom,       /** @param[in] Bottom padding for unspecialised tiles. */
-      const int _pad_right         /** @param[in] Right padding for unspecialised tiles. */
-    );
-
-    static TileFn get_tile_specialization(
-      const int pad_top,
-      const int pad_left,
-      const int pad_bottom,
-      const int pad_right
-    );
-
-    // Tile overlaps
-    static constexpr int overlap_rows = KernelRows - 1;
-    static constexpr int overlap_cols = KernelCols - 1;
-
-  private:
-
-    // Maximum padding and number of distinct paddings
-    static constexpr int max_pad_top = KernelRows / 2;
-    static constexpr int min_pad_top = KernelRows % (InnerTileRows - overlap_rows);
-    static constexpr int n_pad_top = iceildiv(max_pad_top, InnerTileRows - overlap_rows);
-
-    static constexpr int max_pad_left = KernelCols / 2;
-    static constexpr int min_pad_left = KernelCols % (InnerTileCols - overlap_cols);
-    static constexpr int n_pad_left = iceildiv(max_pad_left, InnerTileCols - overlap_cols);
-
-    static constexpr int n_pad_bottom = InnerTileRows;
-    static constexpr int n_pad_right = InnerTileCols;
-
-    // Pointers to methods implementing a generically padded tile and a totally unpadded tile.
-    static const TileFn tilefn_generic;   /** Generic tile processing function. */
-    static const TileFn tilefn_unpadded;  /** Tile processor for unpadded tiles. */
-
-    // Arrays of methods covering tiles which are padded only on a single side.
-    static const TileFn tilefn_top_padded[n_pad_top];
-    static const TileFn tilefn_left_padded[n_pad_left];
-    static const TileFn tilefn_bottom_padded[n_pad_bottom];
-    static const TileFn tilefn_right_padded[n_pad_right];
-};
-
-
-template < int KernelCols, int InnerTileCols, typename T>
-class InputTransformImplTiles<1, KernelCols, 1, InnerTileCols, T>
-{
-  public:
-    /** Method to transform a tile of the input tensor into the Winograd domain. */
-    typedef void (*TileFn)(
-      const int n_channels,        /** @param[in] Number of channels in the tensor. */
-      const T* const inptr_base,   /** @param[in] Pointer to the base of the input tile. */
-      const int input_row_stride,  /** @param[in] Stride between rows of the input tensor. */
-      const int input_col_stride,  /** @param[in] Stride between columns of the input tensor. */
-      T* const mptr_base,          /** @param[out] Base pointer to transformed input matrices. */
-      const int matrix_stride,     /** @param[in] Stride between matrices in the input space. */
-      const int _pad_top,          /** @param[in] Top padding for unspecialised tiles. */
-      const int _pad_left,         /** @param[in] Left padding for unspecialised tiles. */
-      const int _pad_bottom,       /** @param[in] Bottom padding for unspecialised tiles. */
-      const int _pad_right         /** @param[in] Right padding for unspecialised tiles. */
-    );
-
-    static TileFn get_tile_specialization(
-      const int pad_top,
-      const int pad_left,
-      const int pad_bottom,
-      const int pad_right
-    );
-
-    // Tile overlaps
-    static constexpr int overlap_rows = 0;
-    static constexpr int overlap_cols = KernelCols - 1;
-
-  private:
-    // Maximum padding and number of distinct paddings
-    static constexpr int max_pad_left = KernelCols / 2;
-    static constexpr int min_pad_left = KernelCols % (InnerTileCols - overlap_cols);
-    static constexpr int n_pad_left = iceildiv(max_pad_left, InnerTileCols - overlap_cols);
-
-    static constexpr int n_pad_right = InnerTileCols;
-
-    // Pointers to methods implementing a generically padded tile and a totally unpadded tile.
-    static const TileFn tilefn_generic;   /** Generic tile processing function. */
-    static const TileFn tilefn_unpadded;  /** Tile processor for unpadded tiles. */
-
-    // Arrays of methods covering tiles which are padded only on a single side.
-    static const TileFn tilefn_left_padded[n_pad_left];
-    static const TileFn tilefn_right_padded[n_pad_right];
-};
-
-
-
-template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-class InputTransformImpl
-{
-  public:
-    /** Apply the transform to a tensor. */
-    static void execute(
-        const T* const input,        /** Input tensor data */
-        const int n_batches,         /** Number of batches in input tensor. */
-        const int in_batch_stride,   /** Stride between batches of the input. */
-        const int n_rows,            /** Number of rows in input tensor. */
-        const int in_row_stride,     /** Stride between rows of the input. */
-        const int n_cols,            /** Number of columns in input tensor. */
-        const int in_col_stride,     /** Stride between columns of the input. */
-        const int n_channels,        /** Number of channels in input tensor. */
-        const PaddingType padding,   /** Padding type. */
-        const int tile_M,
-        const int tile_N,
-        T* const output,             /** Base of output matrices. */
-        const int matrix_stride,     /** Stride between output matrices. */
-        const int matrix_batch_stride,  /** Stride between batches within the matrix. */
-        const int matrix_row_stride  /** Stride within matrices. */
-    );
-
-  private:
-    static void process_tile_row(
-      const int tile_N,
-      int n_channels,
-      const T* const input_base,
-      const int input_row_stride,
-      const int input_col_stride,
-      T* const matrix_base,
-      const int matrix_stride,
-      const int matrix_row_stride,
-      const int row_pad_top,
-      const int row_pad_left,
-      const int row_pad_bottom,
-      const int n_cols
-    );
-
-    using Tiles = InputTransformImplTiles<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>;
-
-    static constexpr int overlap_rows = Tiles::overlap_rows;
-    static constexpr int overlap_cols = Tiles::overlap_cols;
-
-
-    };
-
-
-template <int KernelRows, int InnerTileRows, typename T>
-class InputTransformImpl<KernelRows, 1, InnerTileRows, 1, T>
-{
-  public:
-    /** Apply the transform to a tensor. */
-    static void execute(
-        const T* const input,        /** Input tensor data */
-        const int n_batches,         /** Number of batches in input tensor. */
-        const int in_batch_stride,   /** Stride between batches of the input. */
-        const int n_rows,            /** Number of rows in input tensor. */
-        const int in_row_stride,     /** Stride between rows of the input. */
-        const int n_cols,            /** Number of columns in input tensor. */
-        const int in_col_stride,     /** Stride between columns of the input. */
-        const int n_channels,        /** Number of channels in input tensor. */
-        const PaddingType padding,   /** Padding type. */
-        const int tile_M,
-        const int tile_N,
-        T* const output,             /** Base of output matrices. */
-        const int matrix_stride,     /** Stride between output matrices. */
-        const int matrix_batch_stride,  /** Stride between batches within the matrix. */
-        const int matrix_row_stride  /** Stride within matrices. */
-    );
-};
-
-}  // namespace (anonymous)
-
-template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-class InputTransform
-{
-  public:
-  /***********************************************************************/
-  /** Create an InputTransform operator fixed on a given problem and set of
-   * pointers.
-   */
-  InputTransform(
-      const T* const input,        /** Input tensor data */
-      const int n_batches,         /** Number of batches in input tensor. */
-      const int n_rows,            /** Number of rows in input tensor. */
-      const int n_cols,            /** Number of columns in input tensor. */
-      const int n_channels,        /** Number of channels in input tensor. */
-      const PaddingType padding,   /** Padding type. */
-      T* const output,             /** Base of output matrices. */
-      const int matrix_stride,     /** Stride between output matrices. */
-      const int matrix_row_stride, /** Stride within matrices. */
-      const int in_batch_stride=0, /** Stride between input batches. */
-      const int in_row_stride=0,   /** Stride between input rows. */
-      const int in_col_stride=0    /** Stride between input columns. */
-  );
-
-  /** Get the window of work a given operator can perform. */
-  unsigned int get_window() const;
-  static constexpr unsigned int WINDOW_BLOCK = 16;  // Base size of window
-
-  /** Perform work upon a window of the input. */
-  void run(const unsigned int start, const unsigned int stop);
-
-  /** Apply the transform to a tensor. */
-  static void execute(
-      const T* const input,        /** Input tensor data */
-      const int n_batches,         /** Number of batches in input tensor. */
-      const int in_batch_stride,   /** Stride between batches of the input. */
-      const int n_rows,            /** Number of rows in input tensor. */
-      const int in_row_stride,     /** Stride between rows of the input. */
-      const int n_cols,            /** Number of columns in input tensor. */
-      const int in_col_stride,     /** Stride between columns of the input. */
-      const int n_channels,        /** Number of channels in input tensor. */
-      const PaddingType padding,   /** Padding type. */
-      const int tile_M,
-      const int tile_N,
-      T* const output,             /** Base of output matrices. */
-      const int matrix_stride,     /** Stride between output matrices. */
-      const int matrix_batch_stride,  /** Stride between batches within the matrix. */
-      const int matrix_row_stride  /** Stride within matrices. */
-  );
-
-  protected:
-    using Transform = InputTransformImpl<KernelRows, KernelCols, InnerTileRows, InnerTileCols, T>;
-
-    /* Member values for instance-based API. */
-    const T* const _inptr;
-    T* const _outptr;
-    const int _n_batches, _n_rows, _n_cols, _n_channels, _matrix_stride,
-              _matrix_row_stride, _tiles_M, _tiles_N;
-    const int _in_col_stride, _in_row_stride, _in_batch_stride;
-    const PaddingType _padding_type;
-};
-
-}  // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp
new file mode 100644
index 0000000..9d418be
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_layer.hpp
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include <utility>
+
+#include "arm_gemm_local.hpp"
+#include "arm_gemm.hpp"
+#include "winograd.hpp"
+
+namespace winograd
+{
+
+
+class IWinogradConvolutionLayer  // Abstract interface; every operation below is pure virtual
+{
+  public:
+    virtual ~IWinogradConvolutionLayer() = default;
+
+    virtual unsigned int weight_transform_get_window() const = 0;
+    virtual void weight_transform_run(unsigned int start, unsigned int stop) = 0;
+
+    virtual ITransform& input_transform() = 0;   // Access the input transform
+    virtual ITransform& output_transform() = 0;  // Access the output transform
+    virtual arm_gemm::IGemmCommon *gemm() = 0;   // Access the underlying GEMM
+};
+
+/** Example of how to construct an ACL-like interface.
+ *
+ * Use `get_weight_storage_size`, `get_input_storage_size` and
+ * `get_output_storage_size` to allocate memory for the convolution engine.
+ * Then create a `WinogradConvolutionLayer`.
+ *
+ * Initialise the weights using `weights_transform` (see `weight_transform_run(...)`).
+ *
+ * For each inference:
+ *   1. Transform the inputs to the Winograd domain using the transform returned by `input_transform()`
+ *   2. Perform a number of GEMMs using the GEMM object returned by `gemm()`
+ *   3. Transform the output to the spatial domain using the transform returned by `output_transform()`
+ */
+template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols,
+          typename TIn, typename TInGEMM, typename TOutGEMM, typename TOut,
+          WinogradRoots Roots>
+class WinogradConvolutionLayer : public IWinogradConvolutionLayer
+{
+  private:
+    static constexpr int InnerTileRows = OutputTileRows + KernelRows - 1;  // Rows of the inner tile: output tile plus kernel overlap
+    static constexpr int InnerTileCols = OutputTileCols + KernelCols - 1;  // Columns of the inner tile: output tile plus kernel overlap
+    static constexpr int N_GEMMS = InnerTileRows * InnerTileCols;          // One entry per element of the inner tile
+
+    const KernelShape _kernel_shape;
+    const Tensor4DShape _input_shape;
+    const PaddingType _padding;
+    const Tensor4DShape _output_shape;
+    const int _n_output_rows, _n_output_cols;
+    const int _kernel_matrix_stride, _kernel_matrix_row_stride;
+    const int _input_matrix_stride, _input_matrix_row_stride;
+    const int _output_matrix_stride, _output_matrix_row_stride;
+    const int _tile_rows, _tile_cols;
+    const int _m, _k, _n;  // NOTE(review): presumably the GEMM M, K and N dimensions; confirm against the constructor
+
+  public:
+    using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols, Roots>;
+    using WeightsTransform = typename WinogradBase::template WeightsTransform<TIn, TInGEMM>;
+    using InputTransform = typename WinogradBase::template InputTransform<TIn, TInGEMM>;
+    using WinogradConv = typename WinogradBase::template Convolution<TOut, TIn, TInGEMM, TOutGEMM>;
+    using OutputTransform = typename WinogradBase::template OutputTransform<TOutGEMM, TOut>;
+
+    /* Public member variables. */
+    WeightsTransform weights_transform;  /** Operator to transform weights to Winograd domain. */
+    InputTransform _input_transform;      /** Operator to transform input to Winograd domain. */
+    arm_gemm::UniqueGemmCommon<TInGEMM, TOutGEMM> gemms;    /** Operator to perform multiple GEMMs. */
+    OutputTransform _output_transform;    /** Operator to transform output from Winograd domain. */
+
+    /** Determine how much memory (in units of TIn) to allocate for the
+     * transformed weights.
+     */
+    static unsigned int get_weight_storage_size(
+      const int n_output_channels,  /** Number of output feature maps. */
+      const int n_input_channels    /** Number of input feature maps. */
+    );
+
+    static unsigned int get_weight_stride(  // NOTE(review): presumably the row stride of a transformed-weights matrix; confirm
+      const int n_output_channels,  /** Number of output feature maps. */
+      const int n_input_channels    /** Number of input feature maps. */
+    );
+
+    static unsigned int get_weight_multi_stride(  // NOTE(review): presumably the stride between transformed-weights matrices; confirm
+      const int n_output_channels,  /** Number of output feature maps. */
+      const int n_input_channels    /** Number of input feature maps. */
+    );
+
+    /** Determine how much memory (in units of TIn) to allocate for the
+     * transformed input.
+     */
+    static unsigned int get_input_storage_size(
+      const int n_batches,     /** Number of batches in the input tensor. */
+      const int n_channels,    /** Number of feature maps in the input tensor. */
+      const int n_rows,        /** Number of rows in each feature map. */
+      const int n_cols,        /** Number of columns in each feature map. */
+      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Get the row stride for the A matrix in the Winograd domain. */
+    static unsigned int get_input_stride(
+      const int n_batches,     /** Number of batches in the input tensor. */
+      const int n_channels,    /** Number of feature maps in the input tensor. */
+      const int n_rows,        /** Number of rows in each feature map. */
+      const int n_cols,        /** Number of columns in each feature map. */
+      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Get the stride between A matrices in the Winograd domain. */
+    static unsigned int get_input_multi_stride(
+      const int n_batches,     /** Number of batches in the input tensor. */
+      const int n_channels,    /** Number of feature maps in the input tensor. */
+      const int n_rows,        /** Number of rows in each feature map. */
+      const int n_cols,        /** Number of columns in each feature map. */
+      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Determine how much memory (in units of TOut) to allocate for the
+     * (Winograd domain) output.
+     */
+    static unsigned int get_output_storage_size(
+      const int n_batches,          /** Number of batches in the output tensor. */
+      const int n_rows,             /** Number of rows in each feature map of the input tensor. */
+      const int n_cols,             /** Number of columns in each feature map of the input tensor. */
+      const int n_output_channels,  /** Number of feature maps in the output tensor. */
+      const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    static unsigned int get_output_stride(  // NOTE(review): presumably the row stride of a Winograd-domain output matrix; confirm
+      const int n_batches,          /** Number of batches in the output tensor. */
+      const int n_rows,             /** Number of rows in each feature map of the input tensor. */
+      const int n_cols,             /** Number of columns in each feature map of the input tensor. */
+      const int n_output_channels,  /** Number of feature maps in the output tensor. */
+      const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    static unsigned int get_output_multi_stride(  // NOTE(review): presumably the stride between Winograd-domain output matrices; confirm
+      const int n_batches,          /** Number of batches in the output tensor. */
+      const int n_rows,             /** Number of rows in each feature map of the input tensor. */
+      const int n_cols,             /** Number of columns in each feature map of the input tensor. */
+      const int n_output_channels,  /** Number of feature maps in the output tensor. */
+      const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Get the shape (rows, cols) of a feature map of the output tensor. */
+    static std::pair<int, int> get_output_feature_map_shape(
+      const int n_input_rows,  /** Number of rows in the input feature map. */
+      const int n_input_cols,  /** Number of columns in the input feature map. */
+      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Create a new Winograd convolution layer.
+     */
+    WinogradConvolutionLayer(
+      const arm_gemm::CPUInfo &cpuinfo,       /** Describes CPU properties. */
+      const int n_threads,          /** Maximum number of threads used to execute the convolution. */
+      const int n_batches,          /** Number of batches in the input and output tensors. */
+      const int n_input_channels,   /** Number of feature maps in a batch of the input tensor. */
+      const int n_input_rows,       /** Number of rows in a feature map of the input tensor. */
+      const int n_input_cols,       /** Number of columns in a feature map of the input tensor. */
+      const int n_output_channels,  /** Number of feature maps in the output tensor. */
+      const bool same_padding,      /** Use "SAME" padding, otherwise use "VALID". */
+      const TIn* const weights,     /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps". */
+      TInGEMM* const weights_storage,  /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
+      const TIn* const input,       /** Pointer to NHWC ordered input tensor, in the spatial domain. */
+      TInGEMM* const winograd_input,    /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */
+      const TOut* const biases,     /** Pointer to biases vector. Pass nullptr if no bias is provided. */
+      TOut* const output,           /** Pointer to NHWC ordered output tensor, in the spatial domain. */
+      TOutGEMM* const winograd_output,  /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
+      const bool pretranspose_B=true,         /** Hint that the B matrix can be pretransposed. */
+      arm_gemm::GemmConfig *gemm_cfg=nullptr  /** Pointer to GEMM configuration. */
+    );
+
+    /* IWinogradConvolutionLayer overrides: utility methods for interacting with the layer. */
+    unsigned int weight_transform_get_window(void) const override;
+    void weight_transform_run(const unsigned int start, const unsigned int stop) override;
+
+    ITransform& input_transform(void) override;
+    ITransform& output_transform(void) override;
+
+    /* Get a pointer to the GEMM underlying the Winograd transform. */
+    arm_gemm::IGemmCommon *gemm(void) override;
+};
+
+}
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_output_transform.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_output_transform.hpp
deleted file mode 100644
index 07a0b86..0000000
--- a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_output_transform.hpp
+++ /dev/null
@@ -1,232 +0,0 @@
-/*
- * Copyright (c) 2018 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-
-namespace winograd
-{
-
-
-namespace
-{
-
-template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-class OutputTransformImplTiles
-{
-  public:
-    typedef void (*TileFn)(
-      const int n_channels,         /** @param[in] Number of channels in output tensor */
-      const T* const matrix_base,   /** @param[in] Base pointer to Winograd output matrices. */
-      const int matrix_stride,      /** @param[in] Stride between matrices in the output space. */
-      const T* const biases,        /** @param[in] Pointer to bias vector (may be nullptr). */
-      T* const output,              /** @param[out] Pointer to output tensor. */
-      const int output_row_stride,  /** @param[in] Stride across rows of the output tensor. */
-      const int output_col_stride,  /** @param[in] Stride between columns of the output tensor. */
-      const int _pad_bottom,        /** @param[in] Bottom padding for unspecialised tiles. */
-      const int _pad_right          /** @param[in] Right padding for unspecialised tiles. */
-    );
-
-    static TileFn get_tile_specialization(
-      const int pad_bottom,
-      const int pad_right
-    );
-
-    static constexpr unsigned int OutputTileRows = InnerTileRows - KernelRows + 1;
-    static constexpr unsigned int OutputTileCols = InnerTileCols - KernelCols + 1;
-
-  private:
-    static constexpr unsigned int n_pad_bottom = OutputTileRows - 1;
-    static constexpr unsigned int n_pad_right = OutputTileCols - 1;
-
-    static const TileFn tilefn_generic;   /** Generic tile processing function. */
-    static const TileFn tilefn_unpadded;  /** Tile processor for unpadded tiles. */
-    static const TileFn tilefn_bottom_padded[n_pad_bottom];  /** Bottom padding only. */
-    static const TileFn tilefn_right_padded[n_pad_right];    /** Right padding only. */
-};
-
-template <int KernelCols, int InnerTileCols, typename T>
-class OutputTransformImplTiles<1, KernelCols, 1, InnerTileCols, T>
-{
-  public:
-    typedef void (*TileFn)(
-      const int n_channels,         /** @param[in] Number of channels in output tensor */
-      const T* const matrix_base,   /** @param[in] Base pointer to Winograd output matrices. */
-      const int matrix_stride,      /** @param[in] Stride between matrices in the output space. */
-      const T* const biases,        /** @param[in] Pointer to bias vector (may be nullptr). */
-      T* const output,              /** @param[out] Pointer to output tensor. */
-      const int output_row_stride,  /** @param[in] Stride across rows of the output tensor. */
-      const int output_col_stride,  /** @param[in] Stride between columns of the output tensor. */
-      const int _pad_bottom,        /** @param[in] Bottom padding for unspecialised tiles. */
-      const int _pad_right          /** @param[in] Right padding for unspecialised tiles. */
-    );
-
-    static TileFn get_tile_specialization(
-      const int pad_bottom,
-      const int pad_right
-    );
-
-    static constexpr unsigned int OutputTileRows = 1;
-    static constexpr unsigned int OutputTileCols = InnerTileCols - KernelCols + 1;
-
-  private:
-    static constexpr unsigned int n_pad_right = OutputTileCols - 1;
-
-    static const TileFn tilefn_unpadded;  /** Tile processor for unpadded tiles. */
-    static const TileFn tilefn_right_padded[n_pad_right];    /** Right padding only. */
-};
-
-template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-class OutputTransformImpl
-{
-  private:
-    static void process_tile_row(
-      const int tile_N,
-      const int n_channels,
-      const T* const matrix_base,
-      const int matrix_stride,
-      const int matrix_row_stride,
-      const T* const biases,
-      T* const output,
-      const int output_row_stride,
-      const int output_col_stride,
-      const int row_pad_bottom,
-      const int row_pad_right
-    );
-
-    using Tiles = OutputTransformImplTiles<
-      KernelRows, KernelCols, InnerTileRows, InnerTileCols, T
-    >;
-
-  public:
-    /** Apply the output transform to a tensor. */
-    static void execute(
-      const int n_batches,
-      const int out_batch_stride,
-      const int n_rows,
-      const int out_row_stride,
-      const int n_cols,
-      const int out_col_stride,
-      const int n_channels,
-      const T* const matrix_base,
-      const int matrix_stride,
-      const int matrix_row_stride,
-      const T* const biases,
-      T* const output
-    );
-
-    static constexpr unsigned int OutputTileRows = Tiles::OutputTileRows;
-    static constexpr unsigned int OutputTileCols = Tiles::OutputTileCols;
-};
-
-template <int KernelRows, int InnerTileRows, typename T>
-class OutputTransformImpl<KernelRows, 1, InnerTileRows, 1, T>
-{
-  public:
-    /** Apply the output transform to a tensor. */
-    static void execute(
-      const int n_batches,
-      const int out_batch_stride,
-      const int n_rows,
-      const int out_row_stride,
-      const int n_cols,
-      const int out_col_stride,
-      const int n_channels,
-      const T* const matrix_base,
-      const int matrix_stride,
-      const int matrix_row_stride,
-      const T* const biases,
-      T* const output
-    );
-
-    static constexpr unsigned int OutputTileRows = InnerTileRows - KernelRows + 1;
-    static constexpr unsigned int OutputTileCols = 1;
-};
-
-}  // namespace (anonymous)
-
-template <int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename T>
-class OutputTransform
-{
-  public:
-    /***********************************************************************/
-    /** Create an OutputTransform operator fixed on a given problem and set
-     * of pointers.
-     */
-    OutputTransform(
-      const T* const matrix_base,   /** Pointer to base of matrices. */
-      const int matrix_stride,      /** Stride between matrices. */
-      const int matrix_row_stride,  /** Stride within a matrix. */
-      const T* const biases,        /** Pointer to biases vector. */
-      T* const output,              /** Pointer to output tensor. */
-      const int n_batches,          /** Number of batches in output tensor. */
-      const int n_rows,             /** Number of rows in output tensor. */
-      const int n_cols,             /** Number of columns in output tensor. */
-      const int n_channels,         /** Number of channels in output tensor. */
-      const int out_batch_stride=0, /** Output batch stride. */
-      const int out_row_stride=0,   /** Output row stride. */
-      const int out_col_stride=0    /** Output column stride. */
-    );
-
-    /** Get the window of work a given operator can perform. */
-    unsigned int get_window() const;
-    static constexpr unsigned int WINDOW_BLOCK = 16;  // Base size of window
-
-    /** Perform work upon a window of the input. */
-    void run(const unsigned int start, const unsigned int stop);
-
-    /** Apply the transform to create a tensor. */
-    static void execute(
-      const int n_batches,
-      const int out_batch_stride,
-      const int n_rows,
-      const int out_row_stride,
-      const int n_cols,
-      const int out_col_stride,
-      const int n_channels,
-      const T* const matrix_base,
-      const int matrix_stride,
-      const int matrix_row_stride,
-      const T* const biases,
-      T* const output
-    );
-
-  private:
-    using Transform = OutputTransformImpl<
-      KernelRows, KernelCols, InnerTileRows, InnerTileCols, T
-    >;
-
-    static constexpr unsigned int OutputTileRows = Transform::OutputTileRows;
-    static constexpr unsigned int OutputTileCols = Transform::OutputTileCols;
-
-    /** Member constants for instances of the transform. */
-    const T* const _matrix_base;
-    const T* const _biases;
-    const int _matrix_stride, _matrix_row_stride;
-    T* const _outptr;
-    const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N;
-    const int _out_col_stride, _out_row_stride, _out_batch_stride;
-};
-
-}  // namespace winograd
-