COMPMID-815: Updated NEWinogradLayer with the latest code from Research.

Change-Id: I86d7f53b5f5d1dbc22078aea5c32b08a25d1f49e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/116634
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/core/NEON/kernels/winograd/alloc.hpp b/arm_compute/core/NEON/kernels/winograd/alloc.hpp
index ef6f2b5..799e95d 100644
--- a/arm_compute/core/NEON/kernels/winograd/alloc.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/alloc.hpp
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
 #pragma once
 
 #ifdef ALLOC_ALIGN
diff --git a/arm_compute/core/NEON/kernels/winograd/arm.hpp b/arm_compute/core/NEON/kernels/winograd/arm.hpp
new file mode 100644
index 0000000..90e7828
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/arm.hpp
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/** Sets the macro __arm_any__ if compiling for AArch32 or AArch64.
+ *  Includes `arm_neon.h` if compiling for either architecture.
+ */
+
+#ifdef __arm__
+#define __arm_any__
+#endif  // __arm__
+
+#ifdef __aarch64__
+#define __arm_any__
+#endif  // __aarch64__
+
+#ifdef __arm_any__
+#include <arm_neon.h>
+#endif  // __arm_any__
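
Illustrative sketch (not part of the patch): a consumer of arm.hpp can guard NEON-only paths on the combined __arm_any__ macro instead of testing both architectures; the helper name below is hypothetical.

#include "arm.hpp"

#ifdef __arm_any__
// NEON intrinsics are available on both AArch32 and AArch64 builds, since
// arm.hpp has already pulled in arm_neon.h.
static inline float32x4_t add_f32x4(const float32x4_t a, const float32x4_t b)
{
  return vaddq_f32(a, b);
}
#endif  // __arm_any__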
diff --git a/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp
new file mode 100644
index 0000000..663b3c4
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+namespace winograd
+{
+
+template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
+class BatchedBlockedGemm
+{
+  public:
+    /** Create a new batched blocked GEMM operator. */
+    BatchedBlockedGemm(
+      const unsigned int n_gemms,
+      const int M, const int K, const int N,
+      const int a_matrix_stride,
+      const int a_row_stride,
+      const int b_matrix_stride,
+      const int b_row_stride,
+      const int c_matrix_stride,
+      const int c_row_stride,
+      const TIn* const a_ptr,
+      const TIn* const b_ptr,
+      TOut* const c_ptr
+    );
+
+    BatchedBlockedGemm(const BatchedBlockedGemm&) = delete;
+    BatchedBlockedGemm operator=(const BatchedBlockedGemm&) = delete;
+
+    /** Get a window of work performed by the operator. */
+    unsigned int get_window() const;
+
+    /** Perform a portion of the work of the operator. */
+    void run(const unsigned int start, const unsigned int stop);
+
+  private:
+    const unsigned int n_gemms;
+    const int M, N, K;
+    const int a_matrix_stride, a_row_stride;
+    const int b_matrix_stride, b_row_stride;
+    const int c_matrix_stride, c_row_stride;
+    const TIn* const a_ptr;
+    const TIn* const b_ptr;
+    TOut* const c_ptr;
+};
+
+}  // namespace winograd
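
A minimal usage sketch for the windowed interface declared above, assuming the strides are given in elements and that get_window() returns the number of work items accepted by run(start, stop); both are assumptions, and the wrapper function is hypothetical.

#include "batched_blocked_gemm.hpp"

void run_batched_gemm(const unsigned int n_gemms,
                      const int M, const int K, const int N,
                      const float* const a, const float* const b, float* const c)
{
  // One GEMM per matrix; matrices stored back to back, rows densely packed.
  winograd::BatchedBlockedGemm<4, 16, float, float> gemm(
    n_gemms, M, K, N,
    M * K, K,   // A: matrix stride, row stride
    K * N, N,   // B: matrix stride, row stride
    M * N, N,   // C: matrix stride, row stride
    a, b, c
  );

  // Process the whole window at once; a scheduler could instead split
  // [0, window) into disjoint [start, stop) ranges across threads.
  const unsigned int window = gemm.get_window();
  gemm.run(0, window);
}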
diff --git a/arm_compute/core/NEON/kernels/winograd/convolution.hpp b/arm_compute/core/NEON/kernels/winograd/convolution.hpp
new file mode 100644
index 0000000..2ab2597
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/convolution.hpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+enum PaddingType {
+  PADDING_SAME, PADDING_VALID
+};
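
The enum above only names the two padding schemes; a sketch of the output-size rule they conventionally denote for stride-1 convolution (the helper is illustrative and not declared anywhere in this patch).

#include "convolution.hpp"

// SAME keeps the spatial size, VALID shrinks it by (kernel - 1).
inline int output_dim(const int input_dim, const int kernel_dim, const PaddingType padding)
{
  return (padding == PADDING_SAME) ? input_dim : input_dim - (kernel_dim - 1);
}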
diff --git a/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp b/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp
new file mode 100644
index 0000000..725f6ca
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "convolution.hpp"
+#include "tensor.hpp"
+
+void direct_convolution(
+  const Tensor4D<Tensor4DShape, float>& input,
+  const Tensor4D<KernelShape, float>& kernel,
+  Tensor4D<Tensor4DShape, float>& output,
+  const PaddingType padding
+);
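
direct_convolution is only declared here; as a point of reference, a scalar sketch of the computation a direct (im2col-free) convolution performs on NHWC data with unit stride and VALID padding. Plain arrays are used because the Tensor4D accessors are not shown in this patch; the function name and layout assumptions (NHWC input/output, HWIO weights) are illustrative.

// out[n][y][x][m] = sum over (i, j, c) of in[n][y+i][x+j][c] * w[i][j][c][m]
void direct_conv_nhwc_valid(
  const float* const in, const float* const w, float* const out,
  const int batches, const int in_rows, const int in_cols, const int in_channels,
  const int kern_rows, const int kern_cols, const int out_channels)
{
  const int out_rows = in_rows - kern_rows + 1;
  const int out_cols = in_cols - kern_cols + 1;
  for (int n = 0; n < batches; n++)
    for (int y = 0; y < out_rows; y++)
      for (int x = 0; x < out_cols; x++)
        for (int m = 0; m < out_channels; m++)
        {
          float acc = 0.0f;
          for (int i = 0; i < kern_rows; i++)
            for (int j = 0; j < kern_cols; j++)
              for (int c = 0; c < in_channels; c++)
                acc += in[((n*in_rows + y + i)*in_cols + x + j)*in_channels + c]
                     * w[((i*kern_cols + j)*in_channels + c)*out_channels + m];
          out[((n*out_rows + y)*out_cols + x)*out_channels + m] = acc;
        }
}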
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm.hpp
new file mode 100644
index 0000000..e48d31b
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/gemm.hpp
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "utils.hpp"
+
+template <typename TIn, typename TOut>
+inline void Gemm(const TIn* const a, const TIn* const b, TOut *c,
+          const int M, const int K, const int N,
+          const int a_row_stride,
+          const int b_row_stride,
+          const int c_row_stride,
+          const bool a_transposed=false,
+          const bool b_transposed=false) {
+  // Array access methods
+  const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn {
+    return a[(!a_transposed) ? i*a_row_stride + j : i + j*M];
+  };
+
+  const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn {
+    return b[(!b_transposed) ? i*b_row_stride + j : i + j*N];
+  };
+
+  const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+    return c[i*c_row_stride + j];
+  };
+
+  // Perform the matrix multiplication
+  for (int i = 0; i < M; i++) {
+    for (int j = 0; j < N; j++) {
+      for (int k = 0; k < K; k++) {
+        C(i, j) += A(i, k) * B(k, j);
+      }
+    }
+  }
+}
+
+template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
+inline void BlockedGemm(
+  const TIn* const a, const TIn* const b, TOut *c,
+  const int M, const int K, const int N,
+  const int a_row_stride,
+  const int b_row_stride,
+  const int c_row_stride
+) {
+  // Array access methods
+  const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn {
+    return a[i*a_row_stride + j];
+  };
+
+  const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn {
+    return b[i*b_row_stride + j];
+  };
+
+  const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+    return c[i*c_row_stride + j];
+  };
+
+  const int M_BLOCKS = iceildiv(M, M_BLOCK);
+  const int N_BLOCKS = iceildiv(N, N_BLOCK);
+
+  // For each block of output rows
+  for (int mblock = 0; mblock < M_BLOCKS; mblock++) {
+    // For each block of output columns
+    for (int nblock = 0; nblock < N_BLOCKS; nblock++) {
+      // Create an appropriately sized block of accumulators
+      TOut accum[M_BLOCK][N_BLOCK];
+      for (int i = 0; i < M_BLOCK; i++) {
+        for (int j = 0; j < N_BLOCK; j++) {
+          accum[i][j] = static_cast<TOut>(0);
+        }
+      }
+
+      // Perform this portion of the matrix multiply
+      for (int k = 0; k < K; k++) {
+        // Load elements of A
+        TIn elems_a[M_BLOCK];
+        for (int i = 0; i < M_BLOCK; i++) {
+          elems_a[i] = A(mblock*M_BLOCK + i, k);
+        }
+
+        // Load elements of B
+        TIn elems_b[N_BLOCK];
+        for (int j = 0; j < N_BLOCK; j++) {
+          elems_b[j] = B(k, nblock*N_BLOCK + j);
+        }
+
+        // Perform the partial matrix multiply
+        for (int i = 0; i < M_BLOCK; i++) {
+          for (int j = 0; j < N_BLOCK; j++) {
+            accum[i][j] += elems_a[i] * elems_b[j];
+          }
+        }
+      }
+
+      // Store the partial product
+      for (int i = 0; i < M_BLOCK; i++) {
+        for (int j = 0; j < N_BLOCK; j++) {
+          C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j];
+        }
+      }
+    }
+  }
+}
+
+#include "gemm/a64_sgemm.hpp"
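
Note that Gemm accumulates into c (C += A*B) while BlockedGemm overwrites it; a minimal sketch of calling the reference path accordingly (wrapper name is illustrative, strides in elements for densely packed row-major matrices).

#include <algorithm>
#include <vector>
#include "gemm.hpp"

void reference_gemm(const std::vector<float>& a, const std::vector<float>& b,
                    std::vector<float>& c, const int M, const int K, const int N)
{
  // Gemm accumulates, so clear the output first.
  std::fill(c.begin(), c.end(), 0.0f);
  Gemm<float, float>(a.data(), b.data(), c.data(), M, K, N, K, N, N);
}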
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
new file mode 100644
index 0000000..caeb48f
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
@@ -0,0 +1,355 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include <cassert>
+#include "../utils.hpp"
+
+#ifdef __aarch64__
+
+template <>
+inline void BlockedGemm<8, 12, float, float>(
+  const float* const a, const float* const b, float *c,
+  const int M, const int K, const int N,
+  const int a_row_stride,
+  const int b_row_stride,
+  const int c_row_stride
+) {
+  const int M_BLOCK = 8;
+  const int N_BLOCK = 12;
+
+  const int m_blocks = iceildiv(M, M_BLOCK);
+  const int n_blocks = iceildiv(N, N_BLOCK);
+
+  // For each block of output rows
+  for (int mblock = 0; mblock < m_blocks; mblock++) {
+    // For each block of output columns
+    for (int nblock = 0; nblock < n_blocks; nblock++) {
+      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+      const float *bptr = b + nblock*N_BLOCK;
+      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+      int k = K;
+
+      asm volatile (
+          // Create an 8x12 block of accumulators
+          " A_1 .req v27\n"
+          "sA_1 .req s27\n"
+          " A_2 .req v28\n"
+          "sA_2 .req s28\n"
+          " A_3 .req v29\n"
+          "sA_3 .req s29\n"
+          " A_4 .req v30\n"
+          "sA_4 .req s30\n"
+
+          " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n"
+          "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n"
+
+          " C_11 .req  v0\n" " C_12 .req  v1\n" " C_13 .req  v2\n"
+          " C_21 .req  v3\n" " C_22 .req  v4\n" " C_23 .req  v5\n"
+          " C_31 .req  v6\n" " C_32 .req  v7\n" " C_33 .req  v8\n"
+          " C_41 .req  v9\n" " C_42 .req v10\n" " C_43 .req v11\n"
+          " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n"
+          " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n"
+          " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n"
+          " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n"
+
+          "qC_11 .req  q0\n" "qC_12 .req  q1\n" "qC_13 .req  q2\n"
+          "qC_21 .req  q3\n" "qC_22 .req  q4\n" "qC_23 .req  q5\n"
+          "qC_31 .req  q6\n" "qC_32 .req  q7\n" "qC_33 .req  q8\n"
+          "qC_41 .req  q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n"
+          "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n"
+          "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n"
+          "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n"
+          "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n"
+
+          "aptr1 .req x17\n"
+          "aptr2 .req x18\n"
+          "aptr3 .req x19\n"
+          "aptr4 .req x20\n"
+          "aptr5 .req x21\n"
+          "aptr6 .req x22\n"
+          "aptr7 .req x23\n"
+
+          // Initialise accumulators with 0
+          // Initialise pointers
+          "movi C_11.4s, #0\n"
+          "add aptr1, %x[aptr], %x[a_row_stride]\n"
+          "movi C_12.4s, #0\n"
+          "add aptr2,    aptr1, %x[a_row_stride]\n"
+          "movi C_13.4s, #0\n"
+          "add aptr3,    aptr2, %x[a_row_stride]\n"
+          "movi C_21.4s, #0\n"
+          "add aptr4,    aptr3, %x[a_row_stride]\n"
+          "movi C_22.4s, #0\n"
+          "add aptr5,    aptr4, %x[a_row_stride]\n"
+          "movi C_23.4s, #0\n"
+          "add aptr6,    aptr5, %x[a_row_stride]\n"
+          "movi C_31.4s, #0\n"
+          "add aptr7,    aptr6, %x[a_row_stride]\n"
+          "movi C_32.4s, #0\n"
+          "ldr qB_1, [%x[bptr]]\n"
+          "movi C_33.4s, #0\n"
+          "ldr qB_2, [%x[bptr], #0x10]\n"
+          "movi C_41.4s, #0\n"
+          "prfm pldl1keep, [%x[bptr], #0x00]\n"
+          "movi C_42.4s, #0\n"
+          "prfm pldl1keep, [%x[bptr], #0x10]\n"
+          "movi C_43.4s, #0\n"
+          "prfm pldl1keep, [%x[bptr], #0x20]\n"
+          "movi C_51.4s, #0\n"
+          "prfm pldl1keep, [%x[aptr], #0x00]\n"
+          "movi C_52.4s, #0\n"
+          "prfm pldl1keep, [   aptr1, #0x00]\n"
+          "movi C_53.4s, #0\n"
+          "prfm pldl1keep, [   aptr2, #0x00]\n"
+          "movi C_61.4s, #0\n"
+          "prfm pldl1keep, [   aptr3, #0x00]\n"
+          "movi C_62.4s, #0\n"
+          "prfm pldl1keep, [   aptr4, #0x00]\n"
+          "movi C_63.4s, #0\n"
+          "prfm pldl1keep, [   aptr5, #0x00]\n"
+          "movi C_71.4s, #0\n"
+          "prfm pldl1keep, [   aptr6, #0x00]\n"
+          "movi C_72.4s, #0\n"
+          "prfm pldl1keep, [   aptr7, #0x00]\n"
+          "movi C_73.4s, #0\n"
+          "ldr sA_1, [%x[aptr]], #0x4\n"
+          "movi C_81.4s, #0\n"
+          "ldr sA_2, [   aptr1], #0x4\n"
+          "movi C_82.4s, #0\n"
+          "ldr sA_3, [   aptr2], #0x4\n"
+          "movi C_83.4s, #0\n"
+          "subs %x[k], %x[k], #1\n"
+          "beq 2f\n"
+
+          "1:"
+            "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
+            "ldr qB_3, [%x[bptr], #0x20]\n"
+            "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
+            "ldr sA_4, [   aptr3], #0x4\n"
+            "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
+            "ldr sA_1, [   aptr4], #0x04\n"
+
+            "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
+            "add %x[bptr], %x[bptr], %x[b_row_stride]\n"
+            "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
+            "prfm pldl1keep, [   aptr3, #0x10]\n"
+            "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
+            "ldr sA_2, [   aptr5], #0x04\n"
+
+            "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
+            "prfm pldl1keep, [%x[bptr], #0x00]\n"
+            "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
+            "prfm pldl1keep, [%x[bptr], #0x10]\n"
+            "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
+            "ldr sA_3, [   aptr6], #0x04\n"
+
+            "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
+            "prfm pldl1keep, [%x[bptr], #0x20]\n"
+            "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
+            "prfm pldl1keep, [   aptr4, #0x10]\n"
+            "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
+            "ldr sA_4, [   aptr7], #0x04\n"
+
+            "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
+            "prfm pldl1keep, [   aptr5, #0x10]\n"
+            "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
+            "prfm pldl1keep, [   aptr6, #0x10]\n"
+            "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
+            "ldr sA_1, [%x[aptr]], #0x04\n"
+
+            "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
+            "prfm pldl1keep, [   aptr7, #0x10]\n"
+            "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
+            "subs %x[k], %x[k], #1\n"
+            "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
+            "ldr sA_2, [   aptr1], #0x04\n"
+
+            "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
+            "prfm pldl1keep, [%x[aptr], #0x10]\n"
+            "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
+            "prfm pldl1keep, [   aptr1, #0x10]\n"
+            "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
+            "ldr sA_3, [   aptr2], #0x04\n"
+
+            "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
+            "prfm pldl1keep, [   aptr2, #0x10]\n"
+            "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
+            "ldp qB_1, qB_2, [%x[bptr]]\n"
+            "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
+            "bne 1b\n"
+
+          "2:"
+            "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
+            "ldr qB_3, [%x[bptr], #0x20]\n"
+            "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
+            "stp qC_11, qC_12, [%x[cptr]]\n"
+            "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
+            "str qC_13, [%x[cptr], #0x20]\n"
+            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+            "ldr sA_1, [   aptr4], #0x04\n"
+
+            "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
+            "ldr sA_4, [   aptr3], #0x4\n"
+            "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
+            "stp qC_21, qC_22, [%x[cptr]]\n"
+            "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
+            "str qC_23, [%x[cptr], #0x20]\n"
+            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+            "ldr sA_2, [   aptr5], #0x04\n"
+
+            "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
+            "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
+            "stp qC_31, qC_32, [%x[cptr]]\n"
+            "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
+            "str qC_33, [%x[cptr], #0x20]\n"
+            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+            "ldr sA_3, [   aptr6], #0x04\n"
+
+            "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
+            "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
+            "stp qC_41, qC_42, [%x[cptr]]\n"
+            "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
+            "str qC_43, [%x[cptr], #0x20]\n"
+            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+            "ldr sA_4, [   aptr7], #0x04\n"
+
+            "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
+            "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
+            "stp qC_51, qC_52, [%x[cptr]]\n"
+            "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
+            "str qC_53, [%x[cptr], #0x20]\n"
+            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+            "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
+            "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
+            "stp qC_61, qC_62, [%x[cptr]]\n"
+            "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
+            "str qC_63, [%x[cptr], #0x20]\n"
+            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+            "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
+            "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
+            "stp qC_71, qC_72, [%x[cptr]]\n"
+            "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
+            "str qC_73, [%x[cptr], #0x20]\n"
+            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+            "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
+            "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
+            "stp qC_81, qC_82, [%x[cptr]]\n"
+            "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
+            "str qC_83, [%x[cptr], #0x20]\n"
+            "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+          // Clear aliases
+          ".unreq aptr1\n"
+          ".unreq aptr2\n"
+          ".unreq aptr3\n"
+          ".unreq aptr4\n"
+          ".unreq aptr5\n"
+          ".unreq aptr6\n"
+          ".unreq aptr7\n"
+
+          ".unreq  A_1\n" ".unreq  A_2\n" ".unreq  A_3\n" ".unreq  A_4\n"
+          ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n"
+
+          ".unreq  B_1\n" ".unreq  B_2\n" ".unreq  B_3\n"
+          ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n"
+
+          ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n"
+          ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n"
+          ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n"
+          ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n"
+          ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n"
+          ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n"
+          ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n"
+          ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n"
+
+          ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n"
+          ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n"
+          ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n"
+          ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n"
+          ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n"
+          ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n"
+          ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n"
+          ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n"
+          : [aptr] "+r" (aptr),
+            [bptr] "+r" (bptr),
+            [cptr] "+r" (cptr),
+            [k] "+r" (k)
+          : [a_row_stride] "r" (a_row_stride * sizeof(float)),
+            [b_row_stride] "r" (b_row_stride * sizeof(float)),
+            [c_row_stride] "r" (c_row_stride * sizeof(float))
+          : "cc", "memory",
+            "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+            "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+            "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
+            "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23"
+      );
+    }
+  }
+}
+
+/*****************************************************************************/
+/* 4x16 blocked GEMM with specialised tails
+ */
+#include "a64_sgemm_4x16.hpp"
+
+template <>
+inline void BlockedGemm<4, 16, float, float>(
+  const float* const a, const float* const b, float *c,
+  const int M, const int K, const int N,
+  const int a_row_stride,
+  const int b_row_stride,
+  const int c_row_stride
+) {
+  // Despatch based on tail of K
+  switch (K % 4) {
+    case 3:
+      sgemm_4x16_impl<3>(
+        a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+      );
+      break;
+    case 2:
+      sgemm_4x16_impl<2>(
+        a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+      );
+      break;
+    case 1:
+      sgemm_4x16_impl<1>(
+        a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+      );
+      break;
+    case 0:
+      sgemm_4x16_impl<0>(
+        a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+      );
+      break;
+    default:
+      assert(false);
+  }
+}
+
+#endif  // __aarch64__
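
The switch above folds the runtime remainder K % 4 into a compile-time template parameter so each tail length gets its own fully unrolled epilogue. The kernels are reached through the same BlockedGemm entry point as the generic C++ code; a sketch (wrapper name illustrative, strides in elements, M and N assumed multiples of the 4x16 block sizes):

#include "gemm.hpp"

// On AArch64 the hand-written 4x16 specialisation above is selected and the
// K % 4 tail is resolved inside; elsewhere the generic loop in gemm.hpp runs.
void sgemm_4x16(const float* const a, const float* const b, float* const c,
                const int M, const int K, const int N)
{
  BlockedGemm<4, 16, float, float>(a, b, c, M, K, N,
                                   /* a_row_stride */ K,
                                   /* b_row_stride */ N,
                                   /* c_row_stride */ N);
}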
diff --git a/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
new file mode 100644
index 0000000..5cd37de
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
@@ -0,0 +1,1446 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+template <const unsigned int tail>
+inline void sgemm_4x16_impl(
+  const float* const a, const float* const b, float *c,
+  const int M, const int K, const int N,
+  const int a_row_stride,
+  const int b_row_stride,
+  const int c_row_stride
+);
+
+template <>
+inline void sgemm_4x16_impl<0>(
+  const float* const a, const float* const b, float *c,
+  const int M, const int K, const int N,
+  const int a_row_stride,
+  const int b_row_stride,
+  const int c_row_stride
+) {
+  const int TAIL_SIZE = 0;
+  const int M_BLOCK = 4;
+  const int N_BLOCK = 16;
+
+  const int m_blocks = iceildiv(M, M_BLOCK);
+  const int n_blocks = iceildiv(N, N_BLOCK);
+
+  // For each block of output rows
+  for (int mblock = 0; mblock < m_blocks; mblock++) {
+    // For each block of output columns
+    for (int nblock = 0; nblock < n_blocks; nblock++) {
+      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+      const float *bptr = b + nblock*N_BLOCK;
+      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+      int k = (K - TAIL_SIZE) / 4;
+
+      asm volatile(
+        "aptr2 .req X20\n"
+        "aptr3 .req X21\n"
+        "aptr4 .req X22\n"
+        "vC11 .req  v0\n" "vC12 .req  v1\n" "vC13 .req  v2\n" "vC14 .req  v3\n"
+        "qC11 .req  q0\n" "qC12 .req  q1\n" "qC13 .req  q2\n" "qC14 .req  q3\n"
+        "vC21 .req  v4\n" "vC22 .req  v5\n" "vC23 .req  v6\n" "vC24 .req  v7\n"
+        "qC21 .req  q4\n" "qC22 .req  q5\n" "qC23 .req  q6\n" "qC24 .req  q7\n"
+        "vC31 .req  v8\n" "vC32 .req  v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+        "qC31 .req  q8\n" "qC32 .req  q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+        "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+        "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+        "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+        "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+        "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+        "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+        "vB1 .req v20\n" "qB1 .req q20\n"
+        "vB2 .req v21\n" "qB2 .req q21\n"
+        "vB3 .req v22\n" "qB3 .req q22\n"
+        "vB4 .req v23\n" "qB4 .req q23\n"
+
+        // Clear accumulators, initialise pointers
+        "movi vC11.4s, #0\n"
+        "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+        "movi vC12.4s, #0\n"
+        "add aptr3,    aptr2, %x[a_row_stride_bytes]\n"
+        "movi vC13.4s, #0\n"
+        "add aptr4,    aptr3, %x[a_row_stride_bytes]\n"
+        "movi vC14.4s, #0\n"
+        "ldr qA1, [%x[aptr]], #0x10\n"
+        "movi vC21.4s, #0\n"
+        "ldr qA2, [   aptr2], #0x10\n"
+        "movi vC22.4s, #0\n"
+        "ldr qB1, [%x[bptr], #0x00]\n"
+        "movi vC23.4s, #0\n"
+        "ldr qB2, [%x[bptr], #0x10]\n"
+        "movi vC24.4s, #0\n"
+        "ldr qB3, [%x[bptr], #0x20]\n"
+        "movi vC31.4s, #0\n"
+        "movi vC32.4s, #0\n"
+        "movi vC33.4s, #0\n"
+        "movi vC34.4s, #0\n"
+        "movi vC41.4s, #0\n"
+        "movi vC42.4s, #0\n"
+        "movi vC43.4s, #0\n"
+        "movi vC44.4s, #0\n"
+        "subs %x[k], %x[k], #1\n"
+        "beq 2f\n"
+
+        "1:"  // Loop proper
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qA3, [   aptr3], #0x10\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr qA4, [   aptr4], #0x10\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+          "subs %x[k], %x[k], #1\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+          "ldr qA1, [%x[aptr]], #0x10\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+          "ldr qA2, [   aptr2], #0x10\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+          "bne 1b\n"
+
+        "2:"  // Tail
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qA3, [   aptr3], #0x10\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr qA4, [   aptr4], #0x10\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+          "stp qC11, qC12, [%x[cptr], #0x00]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+          "stp qC13, qC14, [%x[cptr], #0x20]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+          "stp qC21, qC22, [%x[cptr], #0x00]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+          "stp qC23, qC24, [%x[cptr], #0x20]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+          "stp qC31, qC32, [%x[cptr], #0x00]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+          "stp qC33, qC34, [%x[cptr], #0x20]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+          "stp qC41, qC42, [%x[cptr], #0x00]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+          "stp qC43, qC44, [%x[cptr], #0x20]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+        ".unreq vB4\n" ".unreq qB4\n"
+        ".unreq vB3\n" ".unreq qB3\n"
+        ".unreq vB2\n" ".unreq qB2\n"
+        ".unreq vB1\n" ".unreq qB1\n"
+        ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+        ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+        ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+        ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+        ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+        ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+        ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+        ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+        ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+        ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+        ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+        ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+        ".unreq aptr2\n"
+        ".unreq aptr3\n"
+        ".unreq aptr4\n"
+
+        : [aptr] "+r" (aptr),
+          [bptr] "+r" (bptr),
+          [cptr] "+r" (cptr),
+          [k] "+r" (k)
+        : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+          [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+          [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+        : "cc", "memory", "x20", "x21", "x22",
+          "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+          "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+          "v21", "v22", "v23"
+      );
+    }
+  }
+}
+
+template <>
+inline void sgemm_4x16_impl<1>(
+  const float* const a, const float* const b, float *c,
+  const int M, const int K, const int N,
+  const int a_row_stride,
+  const int b_row_stride,
+  const int c_row_stride
+) {
+  const int TAIL_SIZE = 1;
+  const int M_BLOCK = 4;
+  const int N_BLOCK = 16;
+
+  const int m_blocks = iceildiv(M, M_BLOCK);
+  const int n_blocks = iceildiv(N, N_BLOCK);
+
+  // For each block of output rows
+  for (int mblock = 0; mblock < m_blocks; mblock++) {
+    // For each block of output columns
+    for (int nblock = 0; nblock < n_blocks; nblock++) {
+      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+      const float *bptr = b + nblock*N_BLOCK;
+      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+      int k = (K - TAIL_SIZE) / 4;
+
+      asm volatile(
+        "aptr2 .req X20\n"
+        "aptr3 .req X21\n"
+        "aptr4 .req X22\n"
+        "vC11 .req  v0\n" "vC12 .req  v1\n" "vC13 .req  v2\n" "vC14 .req  v3\n"
+        "qC11 .req  q0\n" "qC12 .req  q1\n" "qC13 .req  q2\n" "qC14 .req  q3\n"
+        "vC21 .req  v4\n" "vC22 .req  v5\n" "vC23 .req  v6\n" "vC24 .req  v7\n"
+        "qC21 .req  q4\n" "qC22 .req  q5\n" "qC23 .req  q6\n" "qC24 .req  q7\n"
+        "vC31 .req  v8\n" "vC32 .req  v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+        "qC31 .req  q8\n" "qC32 .req  q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+        "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+        "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+        "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+        "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+        "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+        "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+        "vB1 .req v20\n" "qB1 .req q20\n"
+        "vB2 .req v21\n" "qB2 .req q21\n"
+        "vB3 .req v22\n" "qB3 .req q22\n"
+        "vB4 .req v23\n" "qB4 .req q23\n"
+
+        // Clear accumulators, initialise pointers
+        "movi vC11.4s, #0\n"
+        "ldr qB1, [%x[bptr], #0x00]\n"
+        "movi vC12.4s, #0\n"
+        "ldr qB2, [%x[bptr], #0x10]\n"
+        "movi vC13.4s, #0\n"
+        "ldr qB3, [%x[bptr], #0x20]\n"
+        "movi vC14.4s, #0\n"
+        "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+        "movi vC21.4s, #0\n"
+        "add aptr3,    aptr2, %x[a_row_stride_bytes]\n"
+        "movi vC22.4s, #0\n"
+        "add aptr4,    aptr3, %x[a_row_stride_bytes]\n"
+        "movi vC23.4s, #0\n"
+        "cbnz %x[k], 3f\n"
+
+        // Prepare for tail in K
+        "movi vC24.4s, #0\n"
+        "ldr sA1, [%x[aptr]], #0x04\n"
+        "movi vC31.4s, #0\n"
+        "ldr sA2, [   aptr2], #0x04\n"
+        "movi vC32.4s, #0\n"
+        "movi vC33.4s, #0\n"
+        "movi vC34.4s, #0\n"
+        "movi vC41.4s, #0\n"
+        "movi vC42.4s, #0\n"
+        "movi vC43.4s, #0\n"
+        "movi vC44.4s, #0\n"
+        "b 2f\n"  // Jump to tail
+
+        "3:"  // Prepare for loop over K
+          "movi vC24.4s, #0\n"
+          "ldr qA1, [%x[aptr]], #0x10\n"
+          "movi vC31.4s, #0\n"
+          "ldr qA2, [   aptr2], #0x10\n"
+          "movi vC32.4s, #0\n"
+          "movi vC33.4s, #0\n"
+          "movi vC34.4s, #0\n"
+          "movi vC41.4s, #0\n"
+          "movi vC42.4s, #0\n"
+          "movi vC43.4s, #0\n"
+          "movi vC44.4s, #0\n"
+          "subs %x[k], %x[k], #1\n"
+          "beq 4f\n"
+
+        "1:"  // Loop proper
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qA3, [   aptr3], #0x10\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr qA4, [   aptr4], #0x10\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+          "subs %x[k], %x[k], #1\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+          "ldr qA1, [%x[aptr]], #0x10\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+          "ldr qA2, [   aptr2], #0x10\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+          "bne 1b\n"
+
+        "4:"  // Tail iteration
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qA3, [   aptr3], #0x10\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr qA4, [   aptr4], #0x10\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+          "ldr sA1, [%x[aptr]], #0x04\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+          "ldr sA2, [   aptr2], #0x04\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+        "2:"  // Common tail
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "stp qC11, qC12, [%x[cptr], #0x00]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "ldr sA3, [   aptr3], #0x04\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "stp qC13, qC14, [%x[cptr], #0x20]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "stp qC21, qC22, [%x[cptr], #0x00]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "ldr sA4, [   aptr4], #0x04\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "stp qC23, qC24, [%x[cptr], #0x20]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "stp qC31, qC32, [%x[cptr], #0x00]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "stp qC33, qC34, [%x[cptr], #0x20]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "stp qC41, qC42, [%x[cptr], #0x00]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+          "stp qC43, qC44, [%x[cptr], #0x20]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+        ".unreq vB4\n" ".unreq qB4\n"
+        ".unreq vB3\n" ".unreq qB3\n"
+        ".unreq vB2\n" ".unreq qB2\n"
+        ".unreq vB1\n" ".unreq qB1\n"
+        ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+        ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+        ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+        ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+        ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+        ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+        ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+        ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+        ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+        ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+        ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+        ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+        ".unreq aptr2\n"
+        ".unreq aptr3\n"
+        ".unreq aptr4\n"
+
+        : [aptr] "+r" (aptr),
+          [bptr] "+r" (bptr),
+          [cptr] "+r" (cptr),
+          [k] "+r" (k)
+        : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+          [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+          [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+        : "cc", "memory", "x20", "x21", "x22",
+          "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+          "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+          "v21", "v22", "v23"
+      );
+    }
+  }
+}
+
+template <>
+inline void sgemm_4x16_impl<2>(
+  const float* const a, const float* const b, float *c,
+  const int M, const int K, const int N,
+  const int a_row_stride,
+  const int b_row_stride,
+  const int c_row_stride
+) {
+  const int TAIL_SIZE = 2;
+  const int M_BLOCK = 4;
+  const int N_BLOCK = 16;
+
+  const int m_blocks = iceildiv(M, M_BLOCK);
+  const int n_blocks = iceildiv(N, N_BLOCK);
+
+  // For each block of output rows
+  for (int mblock = 0; mblock < m_blocks; mblock++) {
+    // For each block of output columns
+    for (int nblock = 0; nblock < n_blocks; nblock++) {
+      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+      const float *bptr = b + nblock*N_BLOCK;
+      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+      int k = (K - TAIL_SIZE) / 4;
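+      // `k` counts the full unroll-by-4 iterations along K; the remaining
+      // TAIL_SIZE accumulations are handled by the tail of the assembly below
+      // (label 3 prepares the unrolled loop at label 1, label 4 runs its
+      // final iteration, and label 2 performs the common tail while the
+      // results are written out).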
+
+      asm volatile(
+        "aptr2 .req X20\n"
+        "aptr3 .req X21\n"
+        "aptr4 .req X22\n"
+        "vC11 .req  v0\n" "vC12 .req  v1\n" "vC13 .req  v2\n" "vC14 .req  v3\n"
+        "qC11 .req  q0\n" "qC12 .req  q1\n" "qC13 .req  q2\n" "qC14 .req  q3\n"
+        "vC21 .req  v4\n" "vC22 .req  v5\n" "vC23 .req  v6\n" "vC24 .req  v7\n"
+        "qC21 .req  q4\n" "qC22 .req  q5\n" "qC23 .req  q6\n" "qC24 .req  q7\n"
+        "vC31 .req  v8\n" "vC32 .req  v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+        "qC31 .req  q8\n" "qC32 .req  q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+        "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+        "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+        "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+        "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+        "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+        "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+        "vB1 .req v20\n" "qB1 .req q20\n"
+        "vB2 .req v21\n" "qB2 .req q21\n"
+        "vB3 .req v22\n" "qB3 .req q22\n"
+        "vB4 .req v23\n" "qB4 .req q23\n"
+
+        // Clear accumulators, initialise pointers
+        "movi vC11.4s, #0\n"
+        "ldr qB1, [%x[bptr], #0x00]\n"
+        "movi vC12.4s, #0\n"
+        "ldr qB2, [%x[bptr], #0x10]\n"
+        "movi vC13.4s, #0\n"
+        "ldr qB3, [%x[bptr], #0x20]\n"
+        "movi vC14.4s, #0\n"
+        "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+        "movi vC21.4s, #0\n"
+        "add aptr3,    aptr2, %x[a_row_stride_bytes]\n"
+        "movi vC22.4s, #0\n"
+        "add aptr4,    aptr3, %x[a_row_stride_bytes]\n"
+        "movi vC23.4s, #0\n"
+        "cbnz %x[k], 3f\n"
+
+        // Prepare for tail in K
+        "movi vC24.4s, #0\n"
+        "ldr dA1, [%x[aptr]], #0x08\n"
+        "movi vC31.4s, #0\n"
+        "ldr dA2, [   aptr2], #0x08\n"
+        "movi vC32.4s, #0\n"
+        "movi vC33.4s, #0\n"
+        "movi vC34.4s, #0\n"
+        "movi vC41.4s, #0\n"
+        "movi vC42.4s, #0\n"
+        "movi vC43.4s, #0\n"
+        "movi vC44.4s, #0\n"
+        "b 2f\n"  // Jump to tail
+
+        "3:"  // Prepare for loop over K
+          "movi vC24.4s, #0\n"
+          "ldr qA1, [%x[aptr]], #0x10\n"
+          "movi vC31.4s, #0\n"
+          "ldr qA2, [   aptr2], #0x10\n"
+          "movi vC32.4s, #0\n"
+          "movi vC33.4s, #0\n"
+          "movi vC34.4s, #0\n"
+          "movi vC41.4s, #0\n"
+          "movi vC42.4s, #0\n"
+          "movi vC43.4s, #0\n"
+          "movi vC44.4s, #0\n"
+          "subs %x[k], %x[k], #1\n"
+          "beq 4f\n"
+
+        "1:"  // Loop proper
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qA3, [   aptr3], #0x10\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr qA4, [   aptr4], #0x10\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+          "subs %x[k], %x[k], #1\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+          "ldr qA1, [%x[aptr]], #0x10\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+          "ldr qA2, [   aptr2], #0x10\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+          "bne 1b\n"
+
+        "4:"  // Tail iteration
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qA3, [   aptr3], #0x10\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr qA4, [   aptr4], #0x10\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+          "ldr dA1, [%x[aptr]], #0x08\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+          "ldr dA2, [   aptr2], #0x08\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+        "2:"  // Common tail
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr dA3, [   aptr3], #0x08\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr dA4, [   aptr4], #0x08\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "stp qC11, qC12, [%x[cptr], #0x00]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "stp qC13, qC14, [%x[cptr], #0x20]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "stp qC21, qC22, [%x[cptr], #0x00]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "stp qC23, qC24, [%x[cptr], #0x20]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "stp qC31, qC32, [%x[cptr], #0x00]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "stp qC33, qC34, [%x[cptr], #0x20]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "stp qC41, qC42, [%x[cptr], #0x00]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+          "stp qC43, qC44, [%x[cptr], #0x20]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+        ".unreq vB4\n" ".unreq qB4\n"
+        ".unreq vB3\n" ".unreq qB3\n"
+        ".unreq vB2\n" ".unreq qB2\n"
+        ".unreq vB1\n" ".unreq qB1\n"
+        ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+        ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+        ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+        ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+        ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+        ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+        ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+        ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+        ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+        ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+        ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+        ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+        ".unreq aptr2\n"
+        ".unreq aptr3\n"
+        ".unreq aptr4\n"
+
+        : [aptr] "+r" (aptr),
+          [bptr] "+r" (bptr),
+          [cptr] "+r" (cptr),
+          [k] "+r" (k)
+        : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+          [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+          [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+        : "cc", "memory", "x20", "x21", "x22",
+          "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+          "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+          "v21", "v22", "v23"
+      );
+    }
+  }
+}
+
+template <>
+inline void sgemm_4x16_impl<3>(
+  const float* const a, const float* const b, float *c,
+  const int M, const int K, const int N,
+  const int a_row_stride,
+  const int b_row_stride,
+  const int c_row_stride
+) {
+  const int TAIL_SIZE = 3;
+  const int M_BLOCK = 4;
+  const int N_BLOCK = 16;
+
+  const int m_blocks = iceildiv(M, M_BLOCK);
+  const int n_blocks = iceildiv(N, N_BLOCK);
+
+  // For each block of output rows
+  for (int mblock = 0; mblock < m_blocks; mblock++) {
+    // For each block of output columns
+    for (int nblock = 0; nblock < n_blocks; nblock++) {
+      const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+      const float *bptr = b + nblock*N_BLOCK;
+      float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+      int k = (K - TAIL_SIZE) / 4;
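+      // As above, `k` counts the full unroll-by-4 iterations along K; the
+      // remaining TAIL_SIZE accumulations are handled by the tail of the
+      // assembly below (label 3 prepares the unrolled loop at label 1,
+      // label 4 runs its final iteration, and label 2 performs the common
+      // tail while the results are written out).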
+
+      asm volatile(
+        "aptr2 .req X20\n"
+        "aptr3 .req X21\n"
+        "aptr4 .req X22\n"
+        "vC11 .req  v0\n" "vC12 .req  v1\n" "vC13 .req  v2\n" "vC14 .req  v3\n"
+        "qC11 .req  q0\n" "qC12 .req  q1\n" "qC13 .req  q2\n" "qC14 .req  q3\n"
+        "vC21 .req  v4\n" "vC22 .req  v5\n" "vC23 .req  v6\n" "vC24 .req  v7\n"
+        "qC21 .req  q4\n" "qC22 .req  q5\n" "qC23 .req  q6\n" "qC24 .req  q7\n"
+        "vC31 .req  v8\n" "vC32 .req  v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+        "qC31 .req  q8\n" "qC32 .req  q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+        "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+        "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+        "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+        "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+        "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+        "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+        "vB1 .req v20\n" "qB1 .req q20\n"
+        "vB2 .req v21\n" "qB2 .req q21\n"
+        "vB3 .req v22\n" "qB3 .req q22\n"
+        "vB4 .req v23\n" "qB4 .req q23\n"
+
+        // Clear accumulators, initialise pointers
+        "movi vC11.4s, #0\n"
+        "ldr qB1, [%x[bptr], #0x00]\n"
+        "movi vC12.4s, #0\n"
+        "ldr qB2, [%x[bptr], #0x10]\n"
+        "movi vC13.4s, #0\n"
+        "ldr qB3, [%x[bptr], #0x20]\n"
+        "movi vC14.4s, #0\n"
+        "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+        "movi vC21.4s, #0\n"
+        "add aptr3,    aptr2, %x[a_row_stride_bytes]\n"
+        "movi vC22.4s, #0\n"
+        "add aptr4,    aptr3, %x[a_row_stride_bytes]\n"
+        "movi vC23.4s, #0\n"
+        "cbnz %x[k], 3f\n"
+
+        // Prepare for tail in K
+        "movi vC24.4s, #0\n"
+        "ldr dA1, [%x[aptr]], #0x08\n"
+        "movi vC31.4s, #0\n"
+        "ldr dA2, [   aptr2], #0x08\n"
+        "movi vC32.4s, #0\n"
+        "movi vC33.4s, #0\n"
+        "movi vC34.4s, #0\n"
+        "movi vC41.4s, #0\n"
+        "movi vC42.4s, #0\n"
+        "movi vC43.4s, #0\n"
+        "movi vC44.4s, #0\n"
+        "b 2f\n"  // Jump to tail
+
+        "3:"  // Prepare for loop over K
+          "movi vC24.4s, #0\n"
+          "ldr qA1, [%x[aptr]], #0x10\n"
+          "movi vC31.4s, #0\n"
+          "ldr qA2, [   aptr2], #0x10\n"
+          "movi vC32.4s, #0\n"
+          "movi vC33.4s, #0\n"
+          "movi vC34.4s, #0\n"
+          "movi vC41.4s, #0\n"
+          "movi vC42.4s, #0\n"
+          "movi vC43.4s, #0\n"
+          "movi vC44.4s, #0\n"
+          "subs %x[k], %x[k], #1\n"
+          "beq 4f\n"
+
+        "1:"  // Loop proper
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qA3, [   aptr3], #0x10\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr qA4, [   aptr4], #0x10\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+          "subs %x[k], %x[k], #1\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+          "ldr qA1, [%x[aptr]], #0x10\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+          "ldr qA2, [   aptr2], #0x10\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+          "bne 1b\n"
+
+        "4:"  // Tail iteration
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qA3, [   aptr3], #0x10\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr qA4, [   aptr4], #0x10\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+          "ldr dA1, [%x[aptr]], #0x08\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+          "ldr dA2, [   aptr2], #0x08\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+        "2:"  // Common tail
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr dA3, [   aptr3], #0x08\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "ldr dA4, [   aptr4], #0x08\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+          "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+          "ldr qB1, [%x[bptr], #0x00]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+          "ldr qB2, [%x[bptr], #0x10]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+          "ldr sA1, [%x[aptr]], #0x04\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+          "ldr sA2, [   aptr2], #0x04\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+          "ldr qB3, [%x[bptr], #0x20]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+          "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+          "ldr qB4, [%x[bptr], #0x30]\n"
+          "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+          "stp qC11, qC12, [%x[cptr], #0x00]\n"
+          "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+          "ldr sA3, [   aptr3], #0x04\n"
+          "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+          "stp qC13, qC14, [%x[cptr], #0x20]\n"
+          "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+          "stp qC21, qC22, [%x[cptr], #0x00]\n"
+          "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+          "ldr sA4, [   aptr4], #0x04\n"
+          "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+          "stp qC23, qC24, [%x[cptr], #0x20]\n"
+          "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+          "stp qC31, qC32, [%x[cptr], #0x00]\n"
+          "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+          "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+          "stp qC33, qC34, [%x[cptr], #0x20]\n"
+          "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+          "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+          "stp qC41, qC42, [%x[cptr], #0x00]\n"
+          "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+          "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+          "stp qC43, qC44, [%x[cptr], #0x20]\n"
+          "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+        ".unreq vB4\n" ".unreq qB4\n"
+        ".unreq vB3\n" ".unreq qB3\n"
+        ".unreq vB2\n" ".unreq qB2\n"
+        ".unreq vB1\n" ".unreq qB1\n"
+        ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+        ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+        ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+        ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+        ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+        ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+        ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+        ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+        ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+        ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+        ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+        ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+        ".unreq aptr2\n"
+        ".unreq aptr3\n"
+        ".unreq aptr4\n"
+
+        : [aptr] "+r" (aptr),
+          [bptr] "+r" (bptr),
+          [cptr] "+r" (cptr),
+          [k] "+r" (k)
+        : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+          [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+          [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+        : "cc", "memory", "x20", "x21", "x22",
+          "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+          "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+          "v21", "v22", "v23"
+      );
+    }
+  }
+}
diff --git a/arm_compute/core/NEON/kernels/winograd/perf.h b/arm_compute/core/NEON/kernels/winograd/perf.h
new file mode 100644
index 0000000..0cdf742
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/perf.h
@@ -0,0 +1,9 @@
+#pragma once
+
+/* Prototypes from perf.c */
+
+void start_counter(int fd);
+long long get_counter(int fd);
+long long stop_counter(int fd);
+int open_instruction_counter(void);
+int open_cycle_counter(void);
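+
+/* Illustrative usage (a sketch only; see profiler.hpp for the real call
+ * sites):
+ *
+ *   const int fd = open_cycle_counter();
+ *   start_counter(fd);
+ *   // ... code being measured ...
+ *   const long long cycles = stop_counter(fd);
+ *   close(fd);  // requires <unistd.h>
+ */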
diff --git a/arm_compute/core/NEON/kernels/winograd/profiler.hpp b/arm_compute/core/NEON/kernels/winograd/profiler.hpp
new file mode 100644
index 0000000..01fafa9
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/profiler.hpp
@@ -0,0 +1,326 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <cstdio>
+#include <map>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+#include "perf.h"
+#include <unistd.h>
+
+#ifdef CYCLE_PROFILING
+class EventIDContainer
+{
+  public:
+  EventIDContainer() : container_lock(), event_ids()
+  {
+  }
+
+  int get_event_id(const char *id)
+  {
+    std::lock_guard<std::mutex> lock(container_lock);
+    if (!event_ids.count(id)) {
+      event_ids.emplace(id, event_ids.size());
+    }
+    return event_ids[id];
+  }
+
+  unsigned int size() const
+  {
+    return event_ids.size();
+  }
+
+  auto begin()
+  {
+    return event_ids.begin();
+  }
+
+  auto end()
+  {
+    return event_ids.end();
+  }
+
+  private:
+  std::mutex container_lock;
+  std::map<const char *, int> event_ids;
+};
+
+
+class ThreadEventCounterContainer
+{
+  public:
+  ThreadEventCounterContainer() : container_lock(), thread_counter_fds()
+  {
+  }
+
+  int get_counter_fd()
+  {
+    const auto id = std::this_thread::get_id();
+    std::lock_guard<std::mutex> lock(container_lock);
+    if (!thread_counter_fds.count(id))
+    {
+      thread_counter_fds.emplace(id, open_cycle_counter());
+    }
+    return thread_counter_fds[id];
+  }
+
+  ~ThreadEventCounterContainer()
+  {
+    // Close all counter file descriptors
+    for (auto& fd : thread_counter_fds)
+    {
+      close(fd.second);
+    }
+  }
+
+  private:
+  std::mutex container_lock;
+  std::map<std::thread::id, int> thread_counter_fds;
+};
+#endif  // CYCLE_PROFILING
+
+
+class profiler {
+private:
+#ifdef CYCLE_PROFILING
+    struct ProfileEntry {
+      int event_id;
+      long int bytes_read, ops, bytes_written;
+      long int duration;
+    };
+
+    static const int maxevents = 10000;
+    ProfileEntry events[maxevents];
+    int currentevent;
+    std::mutex event_lock;
+
+    EventIDContainer event_ids;
+    ThreadEventCounterContainer thread_counter_fds;
+
+    int get_event_id(const char *id)
+    {
+      return event_ids.get_event_id(id);
+    }
+#endif  // CYCLE_PROFILING
+
+public:
+#ifdef CYCLE_PROFILING
+    profiler() :
+      currentevent(0),
+      event_lock(),
+      event_ids(),
+      thread_counter_fds()
+    {
+    }
+
+    ~profiler() {
+      std::lock_guard<std::mutex> lock_events(event_lock);
+
+        // Compute performance from recorded events
+        struct ProfileResult {
+          ProfileResult() : total_calls(0),
+                            total_duration(0),
+                            total_bytes_read(0),
+                            total_ops(0),
+                            total_bytes_written(0) {
+          }
+
+          void operator+=(const ProfileEntry &rhs) {
+            total_calls++;
+            total_duration += rhs.duration;
+            total_bytes_read += rhs.bytes_read;
+            total_ops += rhs.ops;
+            total_bytes_written += rhs.bytes_written;
+          }
+
+          float avg_duration(void) const {
+            return static_cast<float>(total_duration) /
+                   static_cast<float>(total_calls);
+          }
+
+          float bytes_read_per_cycle(void) const {
+            return static_cast<float>(total_bytes_read) /
+                   static_cast<float>(total_duration);
+          }
+
+          float ops_per_cycle(void) const {
+            return static_cast<float>(total_ops) /
+                   static_cast<float>(total_duration);
+          }
+
+          float bytes_written_per_cycle(void) const {
+            return static_cast<float>(total_bytes_written) /
+                   static_cast<float>(total_duration);
+          }
+
+          long int total_calls,
+                   total_duration,
+                   total_bytes_read,
+                   total_ops,
+                   total_bytes_written;
+        };
+
+        std::vector<ProfileResult> totals;
+        totals.resize(event_ids.size());
+        for (int i = 0; i < currentevent; i++) {
+          const auto &event = events[i];
+          totals[event.event_id] += event;
+        }
+
+        // Get the longest label
+        int len_label = 0;
+        for (const auto &kv : event_ids) {
+          len_label = std::max(len_label, static_cast<int>(strlen(kv.first)));
+        }
+
+        // Get the longest values for every other field
+        const auto get_length_of_field =
+          [totals] (const char *title, auto f, auto len) -> size_t {
+            size_t l = strlen(title);
+            for (const auto &v : totals) {
+              l = std::max(l, len(f(v)));
+            }
+            return l;
+        };
+
+        // Get the strlen for an int
+        const auto intlen = [] (long int x) -> size_t {
+          size_t len = 0;
+          do {
+            x /= 10;
+            len++;
+          } while (x);
+          return len;
+        };
+
+        // Get the strlen for a float
+        const auto floatlen = [] (const int precision) {
+          return [precision] (float x) {
+            size_t len = 0;
+
+            if (!std::isfinite(x)) {
+              return static_cast<size_t>(3);
+            }
+
+            do {
+              x /= 10.0f;
+              len++;
+            } while (x > 1.0f);
+            return len + 1 + precision;
+          };
+        };
+
+        const int len_calls = get_length_of_field(
+            "Calls", [] (const auto &v) {return v.total_calls;},
+            intlen
+        );
+        const int len_duration = get_length_of_field(
+            "Duration", [] (const auto &v) {return v.total_duration;},
+            intlen
+        );
+        const int len_average_duration = get_length_of_field(
+            "Average", [] (const auto &v) {return v.avg_duration();},
+            floatlen(2)
+        );
+        const int len_reads_per_cycle = get_length_of_field(
+            "Reads / cycle",
+            [] (const auto &v) {return v.bytes_read_per_cycle();},
+            floatlen(6)
+        );
+        const int len_ops_per_cycle = get_length_of_field(
+            "Ops / cycle",
+            [] (const auto &v) {return v.ops_per_cycle();},
+            floatlen(6)
+        );
+        const int len_writes_per_cycle = get_length_of_field(
+            "Writes / cycle",
+            [] (const auto &v) {return v.bytes_written_per_cycle();},
+            floatlen(6)
+        );
+
+        // Print header
+        printf(
+          "%*s    %*s    %*s    %*s    %*s    %*s    %*s\n",
+          len_label, "",
+          len_calls, "Calls",
+          len_duration, "Duration",
+          len_average_duration, "Average",
+          len_reads_per_cycle, "Reads / cycle",
+          len_ops_per_cycle, "Ops / cycle",
+          len_writes_per_cycle, "Writes / cycle"
+        );
+        for (const auto &kv : event_ids) {
+          const auto id = kv.second;
+          printf(
+            "%*s    %*ld    %*ld    %*.2f    %*.6f    %*.6f    %*.6f\n",
+            len_label, kv.first,
+            len_calls, totals[id].total_calls,
+            len_duration, totals[id].total_duration,
+            len_average_duration, totals[id].avg_duration(),
+            len_reads_per_cycle, totals[id].bytes_read_per_cycle(),
+            len_ops_per_cycle, totals[id].ops_per_cycle(),
+            len_writes_per_cycle, totals[id].bytes_written_per_cycle()
+          );
+        }
+        printf("\n");
+    }
+#endif  // CYCLE_PROFILING
+
+    template <typename T>
+    void operator() (const char * event,
+                     T func,
+                     long int bytes_read = 0,
+                     long int ops = 0,
+                     long int bytes_written = 0) {
+#ifdef CYCLE_PROFILING
+        if (currentevent==maxevents) {
+            func();
+        } else {
+            const auto countfd = thread_counter_fds.get_counter_fd();
+            start_counter(countfd);
+            func();
+            long long cycs = stop_counter(countfd);
+
+            // Store the profiling data
+            std::lock_guard<std::mutex> lock_events(event_lock);
+            events[currentevent++] = {
+              get_event_id(event), bytes_read, ops, bytes_written, cycs
+            };
+        }
+#else
+      (void) event;
+      (void) bytes_read;
+      (void) ops;
+      (void) bytes_written;
+      func();
+#endif  // CYCLE_PROFILING
+    }
+};
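+
+// Illustrative usage (a sketch only; the label, functor and byte/op counts
+// below are placeholders rather than real call sites):
+//
+//   profiler prof;
+//   prof("gemm", [&] () { run_gemm(); }, bytes_read, n_ops, bytes_written);
+//
+// With CYCLE_PROFILING defined each call is timed with a per-thread cycle
+// counter and a summary table is printed when `prof` is destroyed; otherwise
+// the functor is simply invoked.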
diff --git a/arm_compute/core/NEON/kernels/winograd/shims.hpp b/arm_compute/core/NEON/kernels/winograd/shims.hpp
new file mode 100644
index 0000000..09e1457
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/shims.hpp
@@ -0,0 +1,747 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include <cstdint>
+#include "arm.hpp"
+
+namespace reorder {
+/** Re-order a tensor from NCHW format to NHWC.
+ *
+ * @note The stride parameters are optional and are provided to allow padding in either input or output tensors.
+ *
+ * @param[in] in Input tensor in NCHW format.
+ * @param[out] out Output tensor, to be written in NHWC format.
+ * @param n_batches Number of batches in the tensors.
+ * @param n_channels Number of channels in the tensors.
+ * @param n_rows Height of the tensor.
+ * @param n_cols Width of the tensor.
+ * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_channels * in_channel_stride`.
+ * @param in_channel_stride Stride over channels in the input tensor. If `0` defaults to `n_rows * in_row_stride`.
+ * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols`.
+ * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_rows * out_row_stride`.
+ * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols * out_col_stride`.
+ * @param out_col_stride Stride over columns in the output tensor. If `0` defaults to `n_channels`.
+ */
+template <typename T>
+inline void nchw_to_nhwc(
+  const T* const in,
+  T* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride=0,
+  int in_channel_stride=0,
+  int in_row_stride=0,
+  int out_batch_stride=0,
+  int out_row_stride=0,
+  int out_col_stride=0
+);
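+
+// For example, a densely packed tensor can be re-ordered by leaving every
+// stride at zero so that the defaults described above are used (a sketch;
+// `src` and `dst` are assumed to each hold
+// n_batches * n_channels * n_rows * n_cols elements):
+//
+//   reorder::nchw_to_nhwc(src, dst, n_batches, n_channels, n_rows, n_cols);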
+
+/** Re-order a tensor from NHWC format to NCHW.
+ *
+ * @note The stride parameters are optional and are provided to allow padding in either input or output tensors.
+ *
+ * @param[in] in Input tensor in NHWC format.
+ * @param[out] out Output tensor, to be written in NCHW format.
+ * @param n_batches Number of batches in the tensors.
+ * @param n_rows Height of the tensor.
+ * @param n_cols Width of the tensor.
+ * @param n_channels Number of channels in the tensors.
+ * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_rows * in_row_stride`.
+ * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols * in_col_stride`.
+ * @param in_col_stride Stride over columns in the input tensor. If `0` defaults to `n_channels`.
+ * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_channels * out_channel_stride`.
+ * @param out_channel_stride Stride over channels in the output tensor. If `0` defaults to `n_rows * out_row_stride`.
+ * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols`.
+ */
+template <typename T>
+inline void nhwc_to_nchw(
+  const T* const in,  // Input data in NHWC form
+  T* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride=0,
+  int in_row_stride=0,
+  int in_col_stride=0,
+  int out_batch_stride=0,
+  int out_channel_stride=0,
+  int out_row_stride=0
+);
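+
+// The inverse re-ordering follows the same convention (again a sketch with
+// dense tensors and default strides); note the different order of the size
+// arguments:
+//
+//   reorder::nhwc_to_nchw(src, dst, n_batches, n_rows, n_cols, n_channels);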
+
+/** Re-order a weight tensor from [Output feature map x Input feature map x
+ *  Height x Width] format to [Height x Width x Input feature map x Output
+ *  feature map] format.
+ */
+template <typename T>
+inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
+  const T* const in,  // Input in [Output x Input x Height x Width] form
+  T* const out,       // Output in [Height x Width x Input x Output] form
+  const int n_output_feature_maps,
+  const int n_input_feature_maps,
+  const int n_rows,
+  const int n_cols,
+  int in_output_feature_map_stride=0,
+  int in_input_feature_map_stride=0,
+  int in_row_stride=0,
+  int out_row_stride=0,
+  int out_col_stride=0,
+  int out_input_feature_map_stride=0
+);
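+
+// For a densely packed weight tensor the default strides apply (a sketch;
+// `weights_oihw` and `weights_hwio` are illustrative buffers):
+//
+//   reorder::ofm_ifm_h_w_to_h_w_ifm_ofm(
+//     weights_oihw, weights_hwio,
+//     n_output_feature_maps, n_input_feature_maps, n_rows, n_cols);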
+
+/** Re-order a weight tensor from [Height x Width x Input feature map x Output
+ *  feature map] format to [Output feature map x Input feature map x Height x
+ *  Width] format.
+ */
+template <typename T>
+inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
+  const T* const in,  // Input in [Height x Width x Input x Output] form
+  T* const out,       // Output in [Output x Input x Height x Width] form
+  const int n_rows,
+  const int n_cols,
+  const int n_input_feature_maps,
+  const int n_output_feature_maps,
+  int in_row_stride=0,
+  int in_col_stride=0,
+  int in_input_feature_map_stride=0,
+  int out_output_feature_map_stride=0,
+  int out_input_feature_map_stride=0,
+  int out_row_stride=0
+);
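+
+// The inverse transformation takes its sizes in the opposite order (again a
+// sketch with default strides):
+//
+//   reorder::h_w_ifm_ofm_to_ofm_ifm_h_w(
+//     weights_hwio, weights_oihw,
+//     n_rows, n_cols, n_input_feature_maps, n_output_feature_maps);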
+
+/*****************************************************************************/
+/* 32-bit implementation : NCHW -> NHWC
+ */
+template <>
+inline void nchw_to_nhwc(
+  const int32_t* const in,
+  int32_t* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride,
+  int in_channel_stride,
+  int in_row_stride,
+  int out_batch_stride,
+  int out_row_stride,
+  int out_col_stride
+)
+{
+  typedef int32_t T;
+
+  // Fill in the stride values
+  in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
+  in_channel_stride = (in_channel_stride) ? in_channel_stride
+                                          : n_rows * in_row_stride;
+  in_batch_stride = (in_batch_stride) ? in_batch_stride
+                                      : n_channels * in_channel_stride;
+
+  out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
+  out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
+  out_batch_stride = (out_batch_stride) ? out_batch_stride
+                                        : n_rows * out_row_stride;
+
+  // Perform the re-ordering
+  for (int n = 0; n < n_batches; n++)
+  {
+    const T* const in_batch = in + n*in_batch_stride;
+    T* const out_batch = out + n*out_batch_stride;
+
+    for (int i = 0; i < n_rows; i++)
+    {
+      const T* const in_row = in_batch + i*in_row_stride;
+      T* const out_row = out_batch + i*out_row_stride;
+
+      int j = 0, j_remaining = n_cols;
+#ifdef __arm_any__
+      for (; j_remaining >= 4; j += 4, j_remaining -= 4)
+      {
+        int c = 0, c_remaining = n_channels;
+        for (; c_remaining >= 4; c += 4, c_remaining -= 4)
+        {
+          // Read 4 channels worth of 4 columns, then zip to produce 4 columns
+          // worth of 4 channels.
+          int32x4_t channel_pixels[4];
+          channel_pixels[0] = vld1q_s32(in_row + (c + 0)*in_channel_stride + j);
+          channel_pixels[1] = vld1q_s32(in_row + (c + 1)*in_channel_stride + j);
+          channel_pixels[2] = vld1q_s32(in_row + (c + 2)*in_channel_stride + j);
+          channel_pixels[3] = vld1q_s32(in_row + (c + 3)*in_channel_stride + j);
+
+          const auto zip1 = vzipq_s32(channel_pixels[0], channel_pixels[2]);
+          const auto zip2 = vzipq_s32(channel_pixels[1], channel_pixels[3]);
+          const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]);
+          const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]);
+
+          vst1q_s32(out_row + (j + 0)*out_col_stride + c, out_0.val[0]);
+          vst1q_s32(out_row + (j + 1)*out_col_stride + c, out_0.val[1]);
+          vst1q_s32(out_row + (j + 2)*out_col_stride + c, out_1.val[0]);
+          vst1q_s32(out_row + (j + 3)*out_col_stride + c, out_1.val[1]);
+        }
+        for (; c_remaining; c++, c_remaining--)
+        {
+          for (int _j = 0; _j < 4; _j++)
+          {
+            const T* const in_col = in_row + j + _j;
+            T* const out_col = out_row + (j + _j)*out_col_stride;
+            const T* const in_channel = in_col + c*in_channel_stride;
+            out_col[c] = *(in_channel);
+          }
+        }
+      }
+      for (; j_remaining >= 2; j += 2, j_remaining -= 2)
+      {
+        int c = 0, c_remaining = n_channels;
+        for (; c_remaining >= 2; c += 2, c_remaining -= 2)
+        {
+          // Read 2 channels worth of 2 columns, then zip to produce 2 columns
+          // worth of 2 channels.
+          int32x2_t channel_pixels[2];
+          channel_pixels[0] = vld1_s32(in_row + (c + 0)*in_channel_stride + j);
+          channel_pixels[1] = vld1_s32(in_row + (c + 1)*in_channel_stride + j);
+
+          const auto output = vzip_s32(channel_pixels[0], channel_pixels[1]);
+
+          vst1_s32(out_row + (j + 0)*out_col_stride + c, output.val[0]);
+          vst1_s32(out_row + (j + 1)*out_col_stride + c, output.val[1]);
+        }
+        for (; c_remaining; c++, c_remaining--)
+        {
+          for (int _j = 0; _j < 2; _j++)
+          {
+            const T* const in_col = in_row + j + _j;
+            T* const out_col = out_row + (j + _j)*out_col_stride;
+            const T* const in_channel = in_col + c*in_channel_stride;
+            out_col[c] = *(in_channel);
+          }
+        }
+      }
+#endif  // __arm_any__
+      for (; j_remaining; j++, j_remaining--)
+      {
+        const T* const in_col = in_row + j;
+        T* const out_col = out_row + j*out_col_stride;
+
+        for (int c = 0; c < n_channels; c++)
+        {
+          const T* const in_channel = in_col + c*in_channel_stride;
+          out_col[c] = *(in_channel);
+        }
+      }
+    }
+  }
+}
+
+template <>
+inline void nchw_to_nhwc(
+  const uint32_t* const in,
+  uint32_t* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride,
+  int in_channel_stride,
+  int in_row_stride,
+  int out_batch_stride,
+  int out_row_stride,
+  int out_col_stride
+)
+{
+  nchw_to_nhwc(
+    reinterpret_cast<const int32_t*>(in),
+    reinterpret_cast<int32_t*>(out),
+    n_batches, n_channels, n_rows, n_cols,
+    in_batch_stride, in_channel_stride, in_row_stride,
+    out_batch_stride, out_row_stride, out_col_stride
+  );
+}
+
+template <>
+inline void nchw_to_nhwc(
+  const float* const in,
+  float* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride,
+  int in_channel_stride,
+  int in_row_stride,
+  int out_batch_stride,
+  int out_row_stride,
+  int out_col_stride
+)
+{
+  nchw_to_nhwc(
+    reinterpret_cast<const int32_t*>(in),
+    reinterpret_cast<int32_t*>(out),
+    n_batches, n_channels, n_rows, n_cols,
+    in_batch_stride, in_channel_stride, in_row_stride,
+    out_batch_stride, out_row_stride, out_col_stride
+  );
+}
+
+/*****************************************************************************/
+/* Generic implementation : NCHW -> NHWC
+ */
+template <typename T>
+inline void nchw_to_nhwc(
+  const T* const in,
+  T* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride,
+  int in_channel_stride,
+  int in_row_stride,
+  int out_batch_stride,
+  int out_row_stride,
+  int out_col_stride
+)
+{
+  // Fill in the stride values
+  in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
+  in_channel_stride = (in_channel_stride) ? in_channel_stride
+                                          : n_rows * in_row_stride;
+  in_batch_stride = (in_batch_stride) ? in_batch_stride
+                                      : n_channels * in_channel_stride;
+
+  out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
+  out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
+  out_batch_stride = (out_batch_stride) ? out_batch_stride
+                                        : n_rows * out_row_stride;
+
+  // Perform the re-ordering
+  for (int n = 0; n < n_batches; n++)
+  {
+    const T* const in_batch = in + n*in_batch_stride;
+    T* const out_batch = out + n*out_batch_stride;
+
+    for (int i = 0; i < n_rows; i++)
+    {
+      const T* const in_row = in_batch + i*in_row_stride;
+      T* const out_row = out_batch + i*out_row_stride;
+
+      for (int j = 0; j < n_cols; j++)
+      {
+        const T* const in_col = in_row + j;
+        T* const out_col = out_row + j*out_col_stride;
+
+        for (int c = 0; c < n_channels; c++)
+        {
+          const T* const in_channel = in_col + c*in_channel_stride;
+          out_col[c] = *(in_channel);
+        }
+      }
+    }
+  }
+}
+
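A minimal usage sketch of the generic re-ordering above (the wrapper name and shapes are hypothetical, and the call may need qualification if these helpers sit inside the reorder namespace closed at the end of this header): passing zero strides selects the packed defaults computed at the top of the function.

    void example_nchw_to_nhwc()
    {
      // Dense 1 x 3 x 2 x 2 NCHW tensor re-ordered into 1 x 2 x 2 x 3 NHWC.
      float in[1 * 3 * 2 * 2], out[1 * 2 * 2 * 3];
      for (int i = 0; i < 12; i++) in[i] = static_cast<float>(i);

      // Zero strides request the packed defaults derived inside the function.
      nchw_to_nhwc(in, out, 1, 3, 2, 2, 0, 0, 0, 0, 0, 0);

      // Now out[(i*2 + j)*3 + c] == in[(c*2 + i)*2 + j] for all i, j, c.
    }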
+/*****************************************************************************/
+/* 32-bit implementation : NHWC -> NCHW
+ */
+template <>
+inline void nhwc_to_nchw(
+  const int32_t* const in,  // Input data in NHWC form
+  int32_t* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride,
+  int in_row_stride,
+  int in_col_stride,
+  int out_batch_stride,
+  int out_channel_stride,
+  int out_row_stride
+)
+{
+  typedef int32_t T;
+
+  // Fill in stride values
+  in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
+  in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
+  in_batch_stride = (in_batch_stride) ? in_batch_stride
+                                      : n_rows * in_row_stride;
+
+  out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
+  out_channel_stride = (out_channel_stride) ? out_channel_stride
+                                            : n_rows * out_row_stride;
+  out_batch_stride = (out_batch_stride) ? out_batch_stride
+                                        : n_channels * out_channel_stride;
+
+  // Perform the re-ordering
+  // For every batch
+  for (int n = 0; n < n_batches; n++)
+  {
+    const T* const in_batch = in + n*in_batch_stride;
+    T* const out_batch = out + n*out_batch_stride;
+
+    // For every row
+    for (int i = 0; i < n_rows; i++)
+    {
+      const T* const in_i = in_batch + i*in_row_stride;
+      T* const out_i = out_batch + i*out_row_stride;
+
+      // For every column, beginning with chunks of 4
+      int j = 0, j_remaining = n_cols;
+#ifdef __arm_any__
+      for (; j_remaining >= 4; j += 4, j_remaining -= 4)
+      {
+        // For every channel, beginning with chunks of 4
+        int c = 0, c_remaining = n_channels;
+        for (; c_remaining >= 4; c += 4, c_remaining -= 4)
+        {
+          // Read 4 columns worth of 4 channels then zip to produce 4 channels
+          // worth of 4 columns.
+          int32x4_t pixel_channels[4];
+          pixel_channels[0] = vld1q_s32(in_i + (j + 0)*in_col_stride + c);
+          pixel_channels[1] = vld1q_s32(in_i + (j + 1)*in_col_stride + c);
+          pixel_channels[2] = vld1q_s32(in_i + (j + 2)*in_col_stride + c);
+          pixel_channels[3] = vld1q_s32(in_i + (j + 3)*in_col_stride + c);
+
+          const auto zip1 = vzipq_s32(pixel_channels[0], pixel_channels[2]);
+          const auto zip2 = vzipq_s32(pixel_channels[1], pixel_channels[3]);
+          const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]);
+          const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]);
+
+          vst1q_s32(out_i + j + (c + 0)*out_channel_stride, out_0.val[0]);
+          vst1q_s32(out_i + j + (c + 1)*out_channel_stride, out_0.val[1]);
+          vst1q_s32(out_i + j + (c + 2)*out_channel_stride, out_1.val[0]);
+          vst1q_s32(out_i + j + (c + 3)*out_channel_stride, out_1.val[1]);
+        }
+        for (; c_remaining; c++, c_remaining--)
+        {
+          for (int _j = 0; _j < 4; _j++)
+          {
+            const T* const in_j = in_i + (j + _j)*in_col_stride;
+            T* const out_j = out_i + (j + _j);
+
+            const T* const in_channel = in_j + c;
+            T* const out_channel = out_j + c*out_channel_stride;
+            *(out_channel) = *(in_channel);
+          }
+        }
+      }
+      for (; j_remaining >= 2; j += 2, j_remaining -= 2)
+      {
+        int c = 0, c_remaining = n_channels;
+        for (; c_remaining >= 2; c += 2, c_remaining -= 2)
+        {
+          // Read 2 columns worth of 2 channels then zip to produce 2 channels
+          // worth of 2 columns.
+          int32x2_t pixel_channels[2];
+          pixel_channels[0] = vld1_s32(in_i + (j + 0)*in_col_stride + c);
+          pixel_channels[1] = vld1_s32(in_i + (j + 1)*in_col_stride + c);
+
+          const auto output = vzip_s32(pixel_channels[0], pixel_channels[1]);
+
+          vst1_s32(out_i + j + (c + 0)*out_channel_stride, output.val[0]);
+          vst1_s32(out_i + j + (c + 1)*out_channel_stride, output.val[1]);
+        }
+        for (; c_remaining; c++, c_remaining--)
+        {
+          for (int _j = 0; _j < 2; _j++)
+          {
+            const T* const in_j = in_i + (j + _j)*in_col_stride;
+            T* const out_j = out_i + (j + _j);
+
+            const T* const in_channel = in_j + c;
+            T* const out_channel = out_j + c*out_channel_stride;
+            *(out_channel) = *(in_channel);
+          }
+        }
+      }
+#endif  // __arm_any__
+      for (; j_remaining; j++, j_remaining--)
+      {
+        const T* const in_j = in_i + j*in_col_stride;
+        T* const out_j = out_i + j;
+
+        // For every channel
+        for (int c = 0; c < n_channels; c++)
+        {
+          const T* const in_channel = in_j + c;
+          T* const out_channel = out_j + c*out_channel_stride;
+          *(out_channel) = *(in_channel);
+        }
+      }
+    }
+  }
+}
+
+template <>
+inline void nhwc_to_nchw(
+  const uint32_t* const in,  // Input data in NHWC form
+  uint32_t* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride,
+  int in_row_stride,
+  int in_col_stride,
+  int out_batch_stride,
+  int out_channel_stride,
+  int out_row_stride
+)
+{
+  // Redirect to generic 32-bit implementation
+  nhwc_to_nchw(
+    reinterpret_cast<const int32_t*>(in),
+    reinterpret_cast<int32_t*>(out),
+    n_batches, n_rows, n_cols, n_channels,
+    in_batch_stride, in_row_stride, in_col_stride,
+    out_batch_stride, out_channel_stride, out_row_stride
+  );
+}
+
+template <>
+inline void nhwc_to_nchw(
+  const float* const in,  // Input data in NHWC form
+  float* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride,
+  int in_row_stride,
+  int in_col_stride,
+  int out_batch_stride,
+  int out_channel_stride,
+  int out_row_stride
+)
+{
+  // Redirect to generic 32-bit implementation
+  nhwc_to_nchw(
+    reinterpret_cast<const int32_t*>(in),
+    reinterpret_cast<int32_t*>(out),
+    n_batches, n_rows, n_cols, n_channels,
+    in_batch_stride, in_row_stride, in_col_stride,
+    out_batch_stride, out_channel_stride, out_row_stride
+  );
+}
+
+/*****************************************************************************/
+/* Generic implementation : NHWC -> NCHW
+ */
+template <typename T>
+inline void nhwc_to_nchw(
+  const T* const in,  // Input data in NHWC form
+  T* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride,
+  int in_row_stride,
+  int in_col_stride,
+  int out_batch_stride,
+  int out_channel_stride,
+  int out_row_stride
+)
+{
+  // Fill in stride values
+  in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
+  in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
+  in_batch_stride = (in_batch_stride) ? in_batch_stride
+                                      : n_rows * in_row_stride;
+
+  out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
+  out_channel_stride = (out_channel_stride) ? out_channel_stride
+                                            : n_rows * out_row_stride;
+  out_batch_stride = (out_batch_stride) ? out_batch_stride
+                                        : n_channels * out_channel_stride;
+
+  // Perform the re-ordering
+  // For every batch
+  for (int n = 0; n < n_batches; n++)
+  {
+    const T* const in_batch = in + n*in_batch_stride;
+    T* const out_batch = out + n*out_batch_stride;
+
+    // For every row
+    for (int i = 0; i < n_rows; i++)
+    {
+      const T* const in_i = in_batch + i*in_row_stride;
+      T* const out_i = out_batch + i*out_row_stride;
+
+      // For every column
+      for (int j = 0; j < n_cols; j++)
+      {
+        const T* const in_j = in_i + j*in_col_stride;
+        T* const out_j = out_i + j;
+
+        // For every channel
+        for (int c = 0; c < n_channels; c++)
+        {
+          const T* const in_channel = in_j + c;
+          T* const out_channel = out_j + c*out_channel_stride;
+          *(out_channel) = *(in_channel);
+        }
+      }
+    }
+  }
+}
+
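As a sanity check on the two generic shims, NHWC -> NCHW followed by NCHW -> NHWC with packed (zero) strides reproduces the original data; a hedged sketch with hypothetical sizes (note the differing argument orders):

    void example_layout_round_trip()
    {
      float a[2*3*4*5], b[2*5*3*4], c[2*3*4*5];
      for (int i = 0; i < 2*3*4*5; i++) a[i] = static_cast<float>(i);

      nhwc_to_nchw(a, b, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0);  // batches, rows, cols, channels
      nchw_to_nhwc(b, c, 2, 5, 3, 4, 0, 0, 0, 0, 0, 0);  // batches, channels, rows, cols

      // a[i] == c[i] for every i.
    }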
+/*****************************************************************************/
+/* Generic weight re-order : [Output x Input x Height x Width] ->
+ * [Height x Width x Input x Output].
+ */
+template <typename T>
+inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
+  const T* const in,  // Input in [Output x Input x Height x Width] form
+  T* const out,       // Output in [Height x Width x Input x Output] form
+  const int n_output_feature_maps,
+  const int n_input_feature_maps,
+  const int n_rows,
+  const int n_cols,
+  int in_output_feature_map_stride,
+  int in_input_feature_map_stride,
+  int in_row_stride,
+  int out_row_stride,
+  int out_col_stride,
+  int out_input_feature_map_stride
+)
+{
+  // Fill in stride values
+  in_row_stride = (in_row_stride)
+    ? in_row_stride
+    : n_cols;
+  in_input_feature_map_stride = (in_input_feature_map_stride)
+    ? in_input_feature_map_stride
+    : n_rows * in_row_stride;
+  in_output_feature_map_stride = (in_output_feature_map_stride)
+    ? in_output_feature_map_stride
+    : n_input_feature_maps * in_input_feature_map_stride;
+
+  out_input_feature_map_stride = (out_input_feature_map_stride)
+    ? out_input_feature_map_stride
+    : n_output_feature_maps;
+  out_col_stride = (out_col_stride)
+    ? out_col_stride
+    : n_input_feature_maps * out_input_feature_map_stride;
+  out_row_stride = (out_row_stride)
+    ? out_row_stride
+    : n_cols * out_col_stride;
+
+  // Perform the re-ordering
+  for (int i = 0; i < n_rows; i++)
+  {
+    const T* const in_row = in + i * in_row_stride;
+    T* out_row = out + i * out_row_stride;
+
+    for (int j = 0; j < n_cols; j++)
+    {
+      const T* const in_col = in_row + j;
+      T* const out_col = out_row + j * out_col_stride;
+
+      for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
+      {
+        const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
+        T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
+
+        for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
+        {
+          const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride;
+          T* const out_ofm = out_ifm + ofm;
+          *(out_ofm) = *(in_ofm);
+        }
+      }
+    }
+  }
+}
+
+/*****************************************************************************/
+/* Generic weight re-order : [Height x Width x Input x Output] ->
+ * [Output x Input x Height x Width].
+ */
+template <typename T>
+inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
+  const T* const in,  // Input in [Height x Width x Input x Output] form
+  T* const out,       // Output in [Output x Input x Height x Width] form
+  const int n_rows,
+  const int n_cols,
+  const int n_input_feature_maps,
+  const int n_output_feature_maps,
+  int in_row_stride,
+  int in_col_stride,
+  int in_input_feature_map_stride,
+  int out_output_feature_map_stride,
+  int out_input_feature_map_stride,
+  int out_row_stride
+)
+{
+  // Fill in the stride values
+  in_input_feature_map_stride = (in_input_feature_map_stride)
+    ? in_input_feature_map_stride
+    : n_output_feature_maps;
+  in_col_stride = (in_col_stride)
+    ? in_col_stride
+    : n_input_feature_maps * in_input_feature_map_stride;
+  in_row_stride = (in_row_stride)
+    ? in_row_stride
+    : n_cols * in_col_stride;
+
+  out_row_stride = (out_row_stride)
+    ? out_row_stride
+    : n_cols;
+  out_input_feature_map_stride = (out_input_feature_map_stride)
+    ? out_input_feature_map_stride
+    : n_rows * out_row_stride;
+  out_output_feature_map_stride = (out_output_feature_map_stride)
+    ? out_output_feature_map_stride
+    : n_input_feature_maps * out_input_feature_map_stride;
+
+  // Perform the re-ordering
+  for (int i = 0; i < n_rows; i++)
+  {
+    const T* const in_row = in + i * in_row_stride;
+    T* const out_row = out + i * out_row_stride;
+
+    for (int j = 0; j < n_cols; j++)
+    {
+      const T* const in_col = in_row + j * in_col_stride;
+      T* const out_col = out_row + j;
+
+      for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
+      {
+        const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
+        T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
+
+        for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
+        {
+          const T* const in_ofm = in_ifm + ofm;
+          T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride;
+          *(out_ofm) = *(in_ofm);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reorder
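The two weight re-order helpers above are inverses of one another; a hedged round-trip sketch with hypothetical sizes (qualify the calls with reorder:: if required by the enclosing scope):

    void example_weight_round_trip()
    {
      // 4 output channels, 3 input channels, 3x3 kernels.
      float oihw[4*3*3*3], hwio[3*3*3*4], back[4*3*3*3];
      for (int i = 0; i < 4*3*3*3; i++) oihw[i] = static_cast<float>(i);

      ofm_ifm_h_w_to_h_w_ifm_ofm(oihw, hwio, 4, 3, 3, 3, 0, 0, 0, 0, 0, 0);
      h_w_ifm_ofm_to_ofm_ifm_h_w(hwio, back, 3, 3, 3, 4, 0, 0, 0, 0, 0, 0);

      // back[i] == oihw[i] for every i.
    }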
diff --git a/arm_compute/core/NEON/kernels/winograd/tensor.hpp b/arm_compute/core/NEON/kernels/winograd/tensor.hpp
index 70ef65d..6567eeb 100644
--- a/arm_compute/core/NEON/kernels/winograd/tensor.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/tensor.hpp
@@ -23,39 +23,44 @@
  */
 
 #pragma once
-#include <cstdio>
 #include <cstdlib>
 #include <random>
 
 #include "alloc.hpp"
 
-/*****************************************************************************/
-/* Padding definitions */
-enum PaddingType {
-  PADDING_SAME, PADDING_VALID
+enum TensorOrder
+{
+  NHWC,  ///< [Batch x Height x Width x Channels]
+  NCHW,  ///< [Batch x Channels x Height x Width]
 };
 
-/*****************************************************************************/
-/* Shape of a kernel */
-struct KernelShape {
-  int n_output_channels, n_rows, n_cols, n_input_channels;
+struct Tensor4DShape
+{
+  int n_batches, n_rows, n_cols, n_channels;
+  TensorOrder ordering;
 
-  int size(void) const {
-    return n_output_channels * n_rows * n_cols * n_input_channels;
+  // Create a new tensor shape with the default (NHWC) ordering
+  inline Tensor4DShape(
+    const int n_batches,
+    const int n_rows,
+    const int n_cols,
+    const int n_channels,
+    const TensorOrder ordering=NHWC
+  ) : n_batches(n_batches),
+      n_rows(n_rows),
+      n_cols(n_cols),
+      n_channels(n_channels),
+      ordering(ordering)
+  {
   }
-};
 
-struct Tensor4DShape {
-  int n_batches,
-      n_rows,
-      n_cols,
-      n_channels;
-
-  int size() const {
+  inline int size() const
+  {
     return n_batches * n_rows * n_cols * n_channels;
   }
 
-  bool TestEq(const Tensor4DShape& other) const {
+  inline bool TestEq(const Tensor4DShape& other) const
+  {
     return (n_batches == other.n_batches &&
             n_rows == other.n_rows &&
             n_cols == other.n_cols &&
@@ -63,148 +68,110 @@
   }
 };
 
+
+enum WeightOrder
+{
+  HWIO,  ///< [Height x Width x Input channels x Output channels]
+  OIHW,  ///< [Output channels x Input channels x Height x Width]
+};
+
+struct KernelShape
+{
+  int n_output_channels, n_rows, n_cols, n_input_channels;
+  WeightOrder ordering;
+
+  inline KernelShape(
+    const int n_output_channels,
+    const int n_rows,
+    const int n_cols,
+    const int n_input_channels,
+    const WeightOrder ordering=HWIO
+  ) : n_output_channels(n_output_channels),
+      n_rows(n_rows),
+      n_cols(n_cols),
+      n_input_channels(n_input_channels),
+      ordering(ordering)
+  {
+  }
+
+  inline int size(void) const
+  {
+    return n_output_channels * n_rows * n_cols * n_input_channels;
+  }
+};
+
+
 template <typename ShapeT, typename T>
-class Tensor4D final {
+class Tensor4D final
+{
   public:
     Tensor4D(ShapeT shape) :
-      _shape(shape),
-      _data(reinterpret_cast<T*>(ALLOCATE(size_bytes()))) {
+      shape(shape),
+      _data(reinterpret_cast<T*>(ALLOCATE(size_bytes())))
+    {
         Clear();
     }
 
+    Tensor4D(const Tensor4D<ShapeT, T>&) = delete;
+    Tensor4D operator=(const Tensor4D<ShapeT, T>&) = delete;
+
     ~Tensor4D() {
       free(_data);
     }
 
-    T* ptr() const {
+    inline T* ptr() const {
       return _data;
     }
 
-    const ShapeT& shape() const {
-      return _shape;
+    inline size_t size_bytes() const {
+      return shape.size() * sizeof(T);
     }
 
-    size_t size_bytes() const {
-      return _shape.size() * sizeof(T);
-    }
+    inline T& element(int, int, int, int) const;
 
-    bool TestEq(Tensor4D<ShapeT, T>& other) const;
-    T& element(int, int, int, int) const;
-    void Print() const;
-
-    void Clear() {
+    inline void Clear() {
       Fill(static_cast<T>(0));
     }
 
-    void Fill(T val) {
-      for (int i = 0; i < _shape.size(); i++)
+    inline void Fill(T val) {
+      for (int i = 0; i < shape.size(); i++)
         _data[i] = val;
     }
 
-    void TestPattern() {
-      for (int i = 0; i < _shape.size(); i++)
-        _data[i] = static_cast<T>(i);
-    }
-
-    void Rand(const int seed=2311) {
-      std::mt19937 gen(seed);
-      std::uniform_int_distribution<> dis(-50, +50);
-
-      for (int i = 0; i < _shape.size(); i++) {
-        _data[i] = static_cast<T>(dis(gen));
-      }
-    }
-    Tensor4D(const Tensor4D &) = delete;
-    /** Prevent instances of this class from being copied (As this class contains pointers) */
-    Tensor4D &operator=(const Tensor4D &) = delete;
-    /** Allow instances of this class to be moved */
-    Tensor4D(Tensor4D &&) = default;
-    /** Allow instances of this class to be moved */
-    Tensor4D &operator=(Tensor4D &&) = default;
-
+    const ShapeT shape;
 
   private:
-    const ShapeT _shape;
     T* const _data;
 };
 
 
 template <>
-inline float& Tensor4D<Tensor4DShape, float>::element(int n, int i, int j, int c) const {
-  int index = ((n*_shape.n_rows + i)*_shape.n_cols + j)*_shape.n_channels + c;
+inline float& Tensor4D<Tensor4DShape, float>::element(int n, int i, int j, int c) const
+{
+  int index;
+  if (shape.ordering == NHWC)
+  {
+    index = ((n*shape.n_rows + i)*shape.n_cols + j)*shape.n_channels + c;
+  }
+  else  // NCHW
+  {
+    index = ((n*shape.n_channels + c)*shape.n_rows + i)*shape.n_cols + j;
+  }
   return _data[index];
 }
 
 
 template <>
-inline float& Tensor4D<KernelShape, float>::element(int oc, int i, int j, int ic) const {
-  int index = ((i*_shape.n_cols + j)*_shape.n_input_channels + ic)*_shape.n_output_channels + oc;
+inline float& Tensor4D<KernelShape, float>::element(int oc, int i, int j, int ic) const
+{
+  int index;
+  if (shape.ordering == HWIO)
+  {
+    index = ((i*shape.n_cols + j)*shape.n_input_channels + ic)*shape.n_output_channels + oc;
+  }
+  else  // OIHW
+  {
+    index = ((oc*shape.n_input_channels + ic)*shape.n_rows + i)*shape.n_cols + j;
+  }
   return _data[index];
 }
-
-template <>
-inline bool Tensor4D<Tensor4DShape, float>::TestEq(Tensor4D<Tensor4DShape, float>& other) const {
-  // Test equivalence, printing errors
-  // First test the shapes are the same
-  if (!_shape.TestEq(other.shape())) {
-    printf("Tensors have different shapes.\n");
-    return false;
-  } else {
-    int incorrects = 0;
-
-    for (int n = 0; n < _shape.n_batches; n++) {
-      for (int i = 0; i < _shape.n_rows; i++) {
-        for (int j = 0; j < _shape.n_cols; j++) {
-          for (int c = 0; c < _shape.n_channels; c++) {
-            // Check elements for equivalence
-            const auto a = this->element(n, i, j, c);
-            const auto b = other.element(n, i, j, c);
-
-            if (a != b) {
-              printf("Difference at element {%d, %d, %d, %d}: %.3f != %.3f\n", n, i, j, c, a, b);
-
-              if (++incorrects > 100) {
-                printf("More than 100 incorrect values, stopping test.\n");
-                return false;
-              }
-            }
-          }
-        }
-      }
-    }
-
-    return incorrects == 0;
-  }
-}
-
-
-template <>
-inline void Tensor4D<Tensor4DShape, float>::Print() const {
-  for (int n = 0; n < _shape.n_batches; n++) {
-    for (int c = 0; c < _shape.n_channels; c++) {
-      for (int i = 0; i < _shape.n_rows; i++) {
-        for (int j = 0; j < _shape.n_cols; j++) {
-          printf("%5.2f ", element(n, i, j, c));
-        }
-        printf("\n");
-      }
-      printf("\n");
-    }
-  }
-}
-
-
-template <>
-inline void Tensor4D<KernelShape, float>::Print() const {
-  for (int oc = 0; oc < _shape.n_output_channels; oc++) {
-    for (int ic = 0; ic < _shape.n_input_channels; ic++) {
-      for (int i = 0; i < _shape.n_rows; i++) {
-        for (int j = 0; j < _shape.n_cols; j++) {
-          printf("%5.2f ", element(oc, i, j, ic));
-        }
-        printf("\n");
-      }
-      printf("\n");
-    }
-  }
-}
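To make the ordering-dependent indexing in element() concrete, a hedged sketch (the function name is illustrative; assumes this header is included):

    void example_element_indexing()
    {
      // The same logical element (n=0, i=2, j=3, c=5) maps to two different flat indices.
      Tensor4D<Tensor4DShape, float> t_nhwc(Tensor4DShape(1, 4, 6, 8));        // NHWC by default
      Tensor4D<Tensor4DShape, float> t_nchw(Tensor4DShape(1, 4, 6, 8, NCHW));

      t_nhwc.element(0, 2, 3, 5) = 1.0f;  // index ((0*4 + 2)*6 + 3)*8 + 5 == 125
      t_nchw.element(0, 2, 3, 5) = 1.0f;  // index ((0*8 + 5)*4 + 2)*6 + 3 == 135
    }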
diff --git a/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp b/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp
new file mode 100644
index 0000000..68a5c6a
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "tensor.hpp"
+
+// Methods to print tensors and weights
+void PrintTensor(const Tensor4D<Tensor4DShape, float>& tensor);
+void PrintWeights(const Tensor4D<KernelShape, float>& weights);
+
+// Test the equivalence of two tensors
+bool CmpTensors(const Tensor4D<Tensor4DShape, float>& a,
+                const Tensor4D<Tensor4DShape, float>& b,
+                const float max_delta=0.0f);
+
+// Fill the tensor with a test pattern
+void TestPattern(Tensor4D<Tensor4DShape, float>& tensor);
+void TestPattern(Tensor4D<KernelShape, float>& weights);
+
+// Fill the tensor with random values
+void Randomise(Tensor4D<Tensor4DShape, float>& tensor, const int seed=0);
+void Randomise(Tensor4D<KernelShape, float>& weights, const int seed=0);
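The implementations of these utilities live in a separate source file not shown in this patch; for orientation, a hedged sketch of what an element-wise comparison with a tolerance could look like (illustrative only, not the library's implementation):

    #include <cmath>

    bool CmpTensorsSketch(const Tensor4D<Tensor4DShape, float>& a,
                          const Tensor4D<Tensor4DShape, float>& b,
                          const float max_delta=0.0f)
    {
      if (!a.shape.TestEq(b.shape))
        return false;  // Shapes differ

      for (int n = 0; n < a.shape.n_batches; n++)
        for (int i = 0; i < a.shape.n_rows; i++)
          for (int j = 0; j < a.shape.n_cols; j++)
            for (int c = 0; c < a.shape.n_channels; c++)
              if (std::fabs(a.element(n, i, j, c) - b.element(n, i, j, c)) > max_delta)
                return false;

      return true;
    }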
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp
new file mode 100644
index 0000000..39b4441
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "../winograd_gemm.hpp"
+
+namespace winograd
+{
+  /***************************************************************************/
+  /* Instance-less API */
+  template <int output_tile_rows, int output_tile_cols,
+            int kernel_rows, int kernel_cols>
+  template <typename T>
+  void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::InputTransform<T>::execute(
+    const T *inptr,
+    const Tensor4DShape& input_shape,
+    const PaddingType padding_type,
+    const int tile_M,
+    const int tile_N,
+    T *outptr_base,
+    const int matrix_stride,
+    const int matrix_batch_stride,
+    const int matrix_row_stride
+  )
+  {
+    // Compute the padding required on each edge of the image
+    const int base_padding = (padding_type == PADDING_SAME) ? 1 : 0;
+    const int pad_top = base_padding;
+    const int pad_left = base_padding;
+    const int tile_overlap = kernel_rows - 1;
+
+    // Compute striding values (assuming NHWC ordered data)
+    const int input_col_stride = input_shape.n_channels;
+    const int input_row_stride = input_shape.n_cols * input_col_stride;
+    const int input_batch_stride = input_shape.n_rows * input_row_stride;
+    const int output_col_stride = matrix_row_stride;
+    const int output_row_stride = tile_N * output_col_stride;
+
+    // Loop over batches
+    for (int batch = 0; batch < input_shape.n_batches; batch++)
+    {
+      // Pointer to the batch
+      const T* const input_base_batch = inptr + batch * input_batch_stride;
+      T* const outptr_base_batch = outptr_base + batch * matrix_batch_stride;
+
+      // Loop over rows of tiles
+      for (int tile_i = 0; tile_i < tile_M; tile_i++)
+      {
+        // Pointer to the row
+        const int row_offset = (tile_i == 0) ?
+          0 : ((padding_type == PADDING_VALID) ? 0 : 1);
+        const T* const input_base_row = (
+          input_base_batch + ((inner_tile_rows - 2)*tile_i - row_offset)*input_row_stride
+        );
+        T* const outptr_base_row = outptr_base_batch + tile_i*output_row_stride;
+
+        // Padding (top + bottom) for the row
+        const int row_top = tile_i*(inner_tile_rows - tile_overlap) - pad_top;
+        const int row_bottom = row_top + inner_tile_rows;
+        const int row_pad_top = (tile_i == 0) ? pad_top : 0;
+        const int row_pad_bottom = (row_bottom <= input_shape.n_rows) ? 0 : row_bottom - input_shape.n_rows;
+
+        // Process the row
+        process_tile_row(
+          tile_N, input_shape.n_channels,
+          input_base_row, input_row_stride, input_col_stride,
+          outptr_base_row, matrix_stride, matrix_row_stride,
+          row_pad_top, pad_left, row_pad_bottom, input_shape.n_cols
+        );
+      }
+    }
+  }
+
+  template <int output_tile_rows, int output_tile_cols,
+            int kernel_rows, int kernel_cols>
+  template <typename T>
+  void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::InputTransform<T>::process_tile_row(
+    const int tile_N,
+    int n_channels,
+    const T* const input_base,
+    const int input_row_stride,
+    const int input_col_stride,
+    T* const matrix_base,
+    const int matrix_stride,
+    const int matrix_row_stride,
+    const int pad_top,
+    const int row_pad_left,
+    const int pad_bottom,
+    const int n_cols
+  )
+  {
+    constexpr int tile_overlap = kernel_cols - 1;
+
+    // Loop over columns of tiles
+    for (int tile_j = 0; tile_j < tile_N; tile_j++)
+    {
+      // Padding (left + right) for the tile
+      const int t_pad_left = (tile_j == 0) ? row_pad_left : 0;
+      const int t_start = tile_j*(inner_tile_cols - tile_overlap) - row_pad_left;
+      const int t_end = t_start + inner_tile_cols;
+      const int t_pad_right = (t_end <= n_cols) ? 0 : t_end - n_cols;
+
+      // Get pointers into the inputs and outputs
+      const int col_offset = (tile_j == 0) ? 0 : row_pad_left;
+      const T* const input_base_col = (
+        input_base + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*input_col_stride
+      );
+      T* const outptr = matrix_base + tile_j*matrix_row_stride;
+
+      // Apply the specific tile processing function
+      tile_fns[pad_top][t_pad_left][pad_bottom][t_pad_right](
+        n_channels,
+        input_base_col,
+        input_row_stride,
+        input_col_stride,
+        outptr,
+        matrix_stride
+      );
+    }
+  }
+
+  /***************************************************************************/
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::InputTransform(
+    const T* const input,        /** Input tensor data */
+    const int n_batches,         /** Number of batches in input tensor. */
+    const int n_rows,            /** Number of rows in input tensor. */
+    const int n_cols,            /** Number of columns in input tensor. */
+    const int n_channels,        /** Number of channels in input tensor. */
+    const PaddingType padding,   /** Padding type. */
+    T* const output,             /** Base of output matrices. */
+    const int matrix_stride,     /** Stride between output matrices. */
+    const int matrix_row_stride  /** Stride within matrices. */
+  ) : _inptr(input), _outptr(output),
+      _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
+      _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
+      _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - 2, output_tile_rows)),
+      _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - 2, output_tile_cols)),
+      _padding_type(padding)
+  {
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  unsigned int WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::get_window() const
+  {
+    // TODO When the input transform supports multithreading, return the total
+    // number of tile rows (allowing for multiple batches). For now we return 1
+    // to indicate that the activations must be transformed as a single block.
+    return 1;  // TODO _tiles_M * _n_batches;
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  void WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::run(
+    const unsigned int start, const unsigned int stop
+  )
+  {
+    // TODO When the input transform supports multithreading call execute for a
+    // portion of the tile rows.
+    (void) start;
+    (void) stop;
+
+    // For now, just do all of the work.
+    const Tensor4DShape input_shape = {
+      _n_batches, _n_rows, _n_cols, _n_channels, NHWC
+    };
+    execute(
+      _inptr, input_shape, _padding_type, _tiles_M, _tiles_N, _outptr,
+      _matrix_stride, _matrix_row_stride * _tiles_M * _tiles_N, _matrix_row_stride
+    );
+  }
+}
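To make the tiling arithmetic used by the constructor concrete, a small hedged check (hypothetical input size; the function name is illustrative and iceildiv comes from utils.hpp):

    #include <cassert>

    void example_input_tile_counts()
    {
      // 2x2 output tiles (3x3 kernel, 4x4 inner tiles) over a 7-row x 9-column input.
      assert(iceildiv(7, 2) == 4 && iceildiv(9, 2) == 5);          // PADDING_SAME
      assert(iceildiv(7 - 2, 2) == 3 && iceildiv(9 - 2, 2) == 4);  // PADDING_VALID
    }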
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp
new file mode 100644
index 0000000..4b54dfd
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "winograd_gemm.hpp"
+using namespace winograd;
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::WeightsTransform(
+  const T* const input,
+  T* const output,
+  const int matrix_stride,      /** Stride across matrices in the output. */
+  const int matrix_row_stride,  /** Stride across rows of the matrix. */
+  const int n_output_channels,
+  const int n_input_channels
+) : inptr(input), outptr(output),
+    matrix_stride(matrix_stride), matrix_row_stride(matrix_row_stride),
+    n_output_channels(n_output_channels), n_input_channels(n_input_channels)
+{
+}
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+unsigned int WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::get_window() const
+{
+  // TODO When the weights transform supports multithreading, return the number
+  // of output channels. For now we return 1 to indicate that the weights must
+  // be transformed as a single block.
+  // return n_output_channels;
+  return 1;
+}
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+void WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::run(
+  const unsigned int start, const unsigned int stop
+)
+{
+  // TODO When the weights transform supports multithreading call execute for a
+  // portion of the output channels.
+  (void) start;
+  (void) stop;
+
+  // For now, just do all of the work.
+  execute(
+    n_output_channels,
+    n_input_channels,
+    inptr,
+    outptr,
+    matrix_stride,
+    matrix_row_stride
+  );
+}
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp
new file mode 100644
index 0000000..7fa5ee9
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "../winograd_gemm.hpp"
+
+namespace winograd
+{
+  template <int output_tile_rows, int output_tile_cols,
+            int kernel_rows, int kernel_cols>
+  template <typename T>
+  void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::OutputTransform<T>::execute(
+    const Tensor4DShape &output_shape,
+    const T* const matrix_base,
+    const int matrix_stride,
+    const int matrix_row_stride,
+    T* const output
+  )
+  {
+    // Compute the number of tiles and hence the padding required on the bottom
+    // and right of the image.
+    const int tile_M = iceildiv(output_shape.n_rows, output_tile_rows);
+    const int tile_N = iceildiv(output_shape.n_cols, output_tile_cols);
+    const int pad_bottom = output_tile_rows*tile_M - output_shape.n_rows;
+    const int pad_right = output_tile_cols*tile_N - output_shape.n_cols;
+
+    const int matrix_tile_row_stride = tile_N * matrix_row_stride;
+    const int matrix_batch_stride = tile_M * matrix_tile_row_stride;
+    const int output_col_stride = output_shape.n_channels;
+    const int output_row_stride = output_shape.n_cols * output_col_stride;
+    const int output_batch_stride = output_shape.n_rows * output_row_stride;
+
+    // Perform the output transformation for each batch
+    for (int batch = 0; batch < output_shape.n_batches; batch++)
+    {
+      // Get batch offset for input and outputs.
+      const T* const matrix_batch = matrix_base + batch*matrix_batch_stride;
+      T* const outptr_batch = output + batch*output_batch_stride;
+
+      // Perform the output transformation for each row of the output tensor.
+      for (int tile_i = 0; tile_i < tile_M; tile_i++)
+      {
+        // Compute properties of this row of output tiles
+        const int row_pad_bottom = (tile_i < tile_M - 1) ? 0: pad_bottom;
+        const T* const matrix_tile_row = matrix_batch + tile_i * matrix_tile_row_stride;
+        T* const outptr_row = outptr_batch + output_tile_rows*tile_i*output_row_stride;
+
+        // Process the row
+        process_tile_row(
+          tile_N, output_shape.n_channels, matrix_tile_row, matrix_stride,
+          matrix_row_stride, outptr_row, output_row_stride,
+          output_col_stride, row_pad_bottom, pad_right
+        );
+      }
+    }
+  }
+
+  template <int output_tile_rows, int output_tile_cols,
+            int kernel_rows, int kernel_cols>
+  template <typename T>
+  void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::OutputTransform<T>::process_tile_row(
+    const int tile_N,
+    const int n_channels,
+    const T* const matrix_base,
+    const int matrix_stride,
+    const int matrix_row_stride,
+    T* const output,
+    const int output_row_stride,
+    const int output_col_stride,
+    const int row_pad_bottom,
+    const int row_pad_right
+  )
+  {
+    // Loop over columns of tiles
+    for (int tile_j = 0; tile_j < tile_N; tile_j++)
+    {
+      // Properties of this tile
+      const int tile_pad_right = (tile_j < tile_N - 1) ? 0 : row_pad_right;
+      const T* const matrix_row = matrix_base + tile_j * matrix_row_stride;
+      T* const outptr = output + output_tile_cols*tile_j*output_col_stride;
+
+      // Perform the output transformation
+      tile_fns[row_pad_bottom][tile_pad_right](
+        n_channels, matrix_row, matrix_stride,
+        outptr, output_row_stride, output_col_stride
+      );
+    }
+  }
+
+  template <int output_tile_rows, int output_tile_cols, int kr, int kc>
+  template <typename T>
+  size_t WinogradGEMM<output_tile_rows, output_tile_cols, kr, kc>::OutputTransform<T>::bytes_read(const Tensor4DShape &shape)
+  {
+    const int M = iceildiv(shape.n_rows, output_tile_rows) *
+                  iceildiv(shape.n_cols, output_tile_cols);
+    const int N = shape.n_channels;
+    return inner_tile_rows * inner_tile_cols * M * N * sizeof(T);
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  size_t WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::bytes_written(const Tensor4DShape &shape)
+  {
+    return shape.size() * sizeof(T);
+  }
+
+  template <int output_tile_rows, int output_tile_cols, int kr, int kc>
+  template <typename T>
+  WinogradGEMM<output_tile_rows, output_tile_cols, kr, kc>::OutputTransform<T>::OutputTransform(
+    const T* const matrix_base,
+    const int matrix_stride,
+    const int matrix_row_stride,
+    T* const output,
+    const int n_batches,
+    const int n_rows,
+    const int n_cols,
+    const int n_channels
+  ) : _matrix_base(matrix_base), _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
+      _outptr(output), _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
+      _tile_M(iceildiv(n_rows, output_tile_rows)), _tile_N(iceildiv(n_cols, output_tile_cols))
+  {
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  unsigned int WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::get_window() const
+  {
+    // TODO When the output transform supports multithreading, return the total
+    // number of tile rows (allowing for multiple batches). For now we return 1
+    // to indicate that the activations must be transformed as a single block.
+    return 1;  // TODO _tile_M * _n_batches;
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  void WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::run(
+    const unsigned int start, const unsigned int stop
+  )
+  {
+    // TODO When the output transform supports multithreading call execute for a
+    // portion of the tile rows.
+    (void) start;
+    (void) stop;
+
+    // For now, just do all of the work.
+    const Tensor4DShape output_shape = {
+      _n_batches, _n_rows, _n_cols, _n_channels, NHWC
+    };
+    execute(
+      output_shape, _matrix_base, _matrix_stride, _matrix_row_stride, _outptr
+    );
+  }
+}  // namespace winograd
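Worked example of the anti-padding computed in execute() (hypothetical sizes): with 2x2 output tiles and a 7x9 output map, tile_M = iceildiv(7, 2) = 4 and tile_N = iceildiv(9, 2) = 5, so pad_bottom = 2*4 - 7 = 1 and pad_right = 2*5 - 9 = 1; process_tile_row applies this trimming only to the last row and column of tiles via row_pad_bottom and tile_pad_right.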
diff --git a/arm_compute/core/NEON/kernels/winograd/utils.hpp b/arm_compute/core/NEON/kernels/winograd/utils.hpp
new file mode 100644
index 0000000..d8b9c3b
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/utils.hpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+double TimeInUs(void);
+void PrintMatrix(const float* const m, const int M, const int N, const int row_stride);
+
+inline int iceildiv(const int a, const int b) {
+  return (a + b - 1) / b;
+}
+
+template <typename T>
+inline T roundup(const T a, const T b) {
+  // Round up to the next multiple of b; values already a multiple of b are unchanged.
+  return (a % b) ? a + b - (a % b) : a;
+}
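A brief hedged check of the helpers above (function name illustrative; roundup as corrected leaves exact multiples unchanged):

    #include <cassert>

    void example_utils()
    {
      assert(iceildiv(7, 2) == 4);
      assert(roundup(7, 4) == 8);
      assert(roundup(8, 4) == 8);  // already a multiple of 4: unchanged
    }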
diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
new file mode 100644
index 0000000..adca48a
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
@@ -0,0 +1,441 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "alloc.hpp"
+#include "convolution.hpp"
+#include "gemm.hpp"
+#include "profiler.hpp"
+#include "shims.hpp"
+#include "tensor.hpp"
+#include "utils.hpp"
+
+#include <thread>
+#include <utility>
+#include <vector>
+
+// Generic Winograd implementation using GEMM
+namespace winograd
+{
+
+template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+class WinogradGEMM
+{
+  public:
+    // Information about the specific Winograd instance
+    static constexpr int output_tile_rows = OutputTileRows;
+    static constexpr int output_tile_cols = OutputTileCols;
+    static constexpr int kernel_rows = KernelRows;
+    static constexpr int kernel_cols = KernelCols;
+    static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1;  // TODO Check
+    static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1;  // TODO Check
+    static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols;
+
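As a concrete instance of the constants above (hedged; the alias name is illustrative and assumes the complete header):

    // F(2x2, 3x3): 4x4 inner tiles, 16 batched GEMMs per convolution.
    using WinogradF2x2_3x3 = winograd::WinogradGEMM<2, 2, 3, 3>;
    static_assert(WinogradF2x2_3x3::inner_tile_rows == 4, "2 + 3 - 1");
    static_assert(WinogradF2x2_3x3::N_GEMMS == 16, "4 * 4");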
+    /** Transform weights from the spatial to the Winograd domain. */
+    template <typename T>
+    struct WeightsTransform
+    {
+      /** Get the bytes read during the transform. */
+      static inline size_t bytes_read(const KernelShape &shape)
+      {
+        return shape.size() * sizeof(T);
+      }
+
+      /** Get the bytes written during the transform. */
+      static inline size_t bytes_written(const KernelShape &shape)
+      {
+        const int inner_tile_size = inner_tile_rows * inner_tile_cols;
+        return (inner_tile_size * shape.n_input_channels *
+                shape.n_output_channels * sizeof(T));
+      }
+
+      /** Get the count of operations performed by the transform. */
+      static int ops_performed(const KernelShape &shape);
+
+      /** Apply the transform to a tensor. */
+      static void execute(
+        const int n_output_channels,
+        const int n_input_channels,
+        const T* const input,
+        T* const output,
+        const int matrix_stride,
+        const int matrix_row_stride
+      );
+
+      /** Create a WeightsTransform operator fixed on a given problem and set
+       * of pointers.
+       */
+      WeightsTransform(
+        const T* const input,
+        T* const output,
+        const int matrix_stride,       /** Stride across matrices in the output. */
+        const int matrix_row_stride,   /** Stride across rows of the matrix. */
+        const int n_output_channels,   /** Number of filters. */
+        const int n_input_channels     /** Number of channels in each filter. */
+      );
+
+      /** Get the window of work a given operator can perform. */
+      unsigned int get_window() const;
+
+      /** Perform work upon a window of the input. */
+      void run(const unsigned int start, const unsigned int stop);
+
+      private:
+        const T* const inptr;         /** Fixed pointer to input data. */
+        T* const outptr;              /** Fixed pointer to output memory. */
+        const int matrix_stride;      /** Stride between output matrices. */
+        const int matrix_row_stride;  /** Stride within output matrices. */
+        const int n_output_channels;  /** Number of filters. */
+        const int n_input_channels;   /** Number of channels in each filter. */
+    };
+
+    /** Transform input feature maps from the spatial to the Winograd domain.
+     */
+    template <typename T>
+    struct InputTransform
+    {
+      /** Get the bytes read during the transform. */
+      static size_t bytes_read(const Tensor4DShape &shape)
+      {
+        return shape.size() * sizeof(T);
+      }
+
+      /** Get the bytes written during the transform. */
+      static size_t bytes_written(const Tensor4DShape &shape)
+      {
+        const int M = iceildiv(shape.n_rows, inner_tile_rows) *
+                      iceildiv(shape.n_cols, inner_tile_cols);
+        const int K = shape.n_channels;
+        return inner_tile_rows * inner_tile_cols * M * K * sizeof(T);
+      }
+
+      /** Get the count of operations performed by the transform. */
+      static int ops_performed(const Tensor4DShape &shape);
+
+      /** Apply the transform to a tensor. */
+      static void execute(
+          const T *inptr,
+          const Tensor4DShape& input_shape,
+          const PaddingType padding_type,
+          const int tile_M,
+          const int tile_N,
+          T *outptr_base,
+          const int matrix_stride,
+          const int matrix_batch_stride,
+          const int matrix_row_stride
+      );
+
+      /***********************************************************************/
+      /** Create an InputTransform operator fixed on a given problem and set of
+       * pointers.
+       */
+      InputTransform(
+          const T* const input,        /** Input tensor data */
+          const int n_batches,         /** Number of batches in input tensor. */
+          const int n_rows,            /** Number of rows in input tensor. */
+          const int n_cols,            /** Number of columns in input tensor. */
+          const int n_channels,        /** Number of channels in input tensor. */
+          const PaddingType padding,   /** Padding type. */
+          T* const output,             /** Base of output matrices. */
+          const int matrix_stride,     /** Stride between output matrices. */
+          const int matrix_row_stride  /** Stride within matrices. */
+      );
+
+      /** Get the window of work a given operator can perform. */
+      unsigned int get_window() const;
+
+      /** Perform work upon a window of the input. */
+      void run(const unsigned int start, const unsigned int stop);
+      /***********************************************************************/
+
+      private:
+        static void process_tile_row(
+          const int tile_N,
+          int n_channels,
+          const T* const input_base,
+          const int input_row_stride,
+          const int input_col_stride,
+          T* const matrix_base,
+          const int matrix_stride,
+          const int matrix_row_stride,
+          const int row_pad_top,
+          const int row_pad_left,
+          const int row_pad_bottom,
+          const int row_pad_right
+        );
+
+        static constexpr int max_pad_bottom = inner_tile_rows - 1;
+        static constexpr int max_pad_right = inner_tile_cols - 1;
+
+        /** Process a single tile of the input tensor. */
+        template <int pad_top, int pad_left, int pad_bottom, int pad_right>
+        static void process_tile(int, const T*, int, int, T*, int);
+
+        // Array of methods to transform tiles of the input tensor.
+        typedef void (*TileFn)(int, const T*, int, int, T*, int);
+        static const TileFn tile_fns[2][2][max_pad_bottom][max_pad_right];
+
+        /* Member values for instance-based API. */
+        const T* const _inptr;
+        T* const _outptr;
+        const int _n_batches, _n_rows, _n_cols, _n_channels, _matrix_stride,
+                  _matrix_row_stride, _tiles_M, _tiles_N;
+        const PaddingType _padding_type;
+    };
+
+    /** Transform output feature maps from the Winograd to the spatial domain.
+     */
+    template <typename T>
+    struct OutputTransform
+    {
+      /** Get the bytes read during the transform. */
+      static size_t bytes_read(const Tensor4DShape &shape);
+
+      /** Get the bytes written during the transform. */
+      static size_t bytes_written(const Tensor4DShape &shape);
+
+      /** Get the count of operations performed by the transform. */
+      static int ops_performed(const Tensor4DShape &shape);
+
+      /** Apply the transform to create a tensor. */
+      static void execute(
+        const Tensor4DShape &output_shape,
+        const T* const matrix_base,
+        const int matrix_stride,
+        const int matrix_row_stride,
+        T* const output
+      );
+
+      /***********************************************************************/
+      /** Create an OutputTransform operator fixed on a given problem and set
+       * of pointers.
+       */
+      OutputTransform(
+        const T* const matrix_base,   /** Pointer to base of matrices. */
+        const int matrix_stride,      /** Stride between matrices. */
+        const int matrix_row_stride,  /** Stride within a matrix. */
+        T* const output,              /** Pointer to output tensor. */
+        const int n_batches,          /** Number of batches in output tensor. */
+        const int n_rows,             /** Number of rows in output tensor. */
+        const int n_cols,             /** Number of columns in output tensor. */
+        const int n_channels          /** Number of channels in output tensor. */
+      );
+
+      /** Get the window of work a given operator can perform. */
+      unsigned int get_window() const;
+
+      /** Perform work upon a window of the input. */
+      void run(const unsigned int start, const unsigned int stop);
+      /***********************************************************************/
+
+      private:
+        static void process_tile_row(
+          const int tile_N,
+          const int n_channels,
+          const T* const matrix_base,
+          const int matrix_stride,
+          const int matrix_row_stride,
+          T* const output,
+          const int output_row_stride,
+          const int output_col_stride,
+          const int row_pad_bottom,
+          const int row_pad_right
+        );
+
+        // Limits on the amount of anti-padding to be applied
+        static constexpr int max_pad_bottom = output_tile_rows;
+        static constexpr int max_pad_right = output_tile_cols;
+
+        /** Prepare a single tile of the output tensor. */
+        template <int pad_bottom, int pad_right>
+        static void process_tile(int, const T*, int, T*, int, int);
+
+        // Array of methods to produce tiles of output tensor.
+        typedef void (*TileFn)(int, const T*, int, T*, int, int);
+        static const TileFn tile_fns[max_pad_bottom][max_pad_right];
+
+        /** Member constants for instances of the transform. */
+        const T* const _matrix_base;
+        const int _matrix_stride, _matrix_row_stride;
+        T* const _outptr;
+        const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N;
+    };
+
+    /** Perform a convolution.
+     */
+    template <typename TOut, typename TIn>
+    class Convolution
+    {
+      public:
+        // Information about the typed Winograd instance
+        typedef TOut OutputType;
+        typedef TIn InputType;
+
+        /** Create a new Winograd operator. */
+        Convolution(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding,
+          void *kernel_storage=NULL
+        );
+
+        Convolution(const Convolution&) = delete;
+        Convolution operator=(const Convolution&) = delete;
+
+        /** Create a new Winograd operator and initialise the weights. */
+        Convolution(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding,
+          const TIn* const kernel,
+          void *kernel_storage=NULL,
+          void *transform_working_space=NULL
+        );
+
+        /** Clean up a convolution engine. */
+        ~Convolution();
+
+        /** Transform the weights into the Winograd domain. */
+        template <typename WeightsTransform=WeightsTransform<TIn>>
+        void transform_weights(
+          const TIn* const kernel,
+          void *transform_working_space=NULL
+        );
+
+        /* Apply the Winograd operator to some input. */
+        void execute(
+          TOut* const output,
+          const TIn* const input,
+          void* working_space=NULL,
+          const int n_threads=1
+        );
+
+        /* Apply the Winograd operator to some input. */
+        void execute(
+          TOut* const output,
+          const TIn* const input,
+          const int n_threads
+        );
+
+        /** Get the output shape of a convolution. */
+        static Tensor4DShape get_output_shape(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &in_shape,
+          const PaddingType padding
+        );
+
+        /** Get the memory required to transform the kernel. */
+        static size_t get_kernel_transform_working_size(const KernelShape &shape);
+
+        /** Get the memory required to store the kernel transformed into the
+         * Winograd domain.
+         */
+        static size_t get_kernel_storage_size(const KernelShape &shape);
+
+        /** Get the memory required to store the input tensor transformed into
+         * the Winograd domain.
+         */
+        static size_t get_input_storage_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /** Get the memory required to store the output tensor in the Winograd
+         * domain.
+         */
+        static size_t get_output_storage_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /** Get the memory required to apply a Winograd operator to some input.
+         */
+        static size_t get_working_space_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /* Get the memory required by a single "input" matrix.
+         */
+        static size_t get_input_matrix_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        static int get_input_matrix_stride(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /* Get the memory required by a single "output" matrix.
+         */
+        static size_t get_output_matrix_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        static int get_output_matrix_stride(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /* Get the memory required by a single "kernel" matrix.
+         */
+        static size_t get_kernel_matrix_size(const KernelShape &shape);
+        static int get_kernel_matrix_stride(const KernelShape &shape);
+
+        static constexpr int M_BLOCK = 4;   /** Size of block used by GEMM. */
+        static constexpr int N_BLOCK = 16;  /** Size of block used by GEMM. */
+
+      private:
+        const KernelShape kernel_shape;  /** Shape of the kernel to be applied. */
+        TIn *kernel_matrices[N_GEMMS];   /** Pointers into the kernel matrices. */
+        const int kernel_matrix_row_stride;  /** Stride within the kernel matrices. */
+
+        const bool manage_kernel_storage;  /** Kernel storage is managed by the instance. */
+        void* const _kernel_storage;       /** Base pointer for kernel storage. */
+
+        const Tensor4DShape input_shape;  /** Shape of the input tensor. */
+        const PaddingType padding;        /** Padding applied by the operator. */
+
+        const Tensor4DShape output_shape;  /** Output shape produced by the operator. */
+
+        const int tile_rows;  /** Number of rows of tiles. */
+        const int tile_cols;  /** Number of columns of tiles. */
+        const int M, K, N;    /** Sizes of underlying fundamental matrix multiplications. */
+
+        profiler prof;
+    };
+};
+
+}  // namespace winograd
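
For illustration only (not part of this patch): a minimal sketch of how the `Convolution` interface declared above might be driven. The 2x2 output-tile / 3x3 kernel instantiation, the float types, the aggregate field orders of `KernelShape` and `Tensor4DShape`, the `PADDING_SAME` enumerator and the byte-sized return values of the storage queries are assumptions; only member functions declared in the header above are used.

    #include <cstdlib>
    #include "winograd_gemm.hpp"

    // Hypothetical helper, not part of the patch: run one convolution with the
    // Convolution class declared above.
    void run_winograd_convolution(const float *weights, const float *input,
                                  float *output, const int n_threads)
    {
      using Conv = winograd::WinogradGEMM<2, 2, 3, 3>::Convolution<float, float>;

      const KernelShape kernel_shape {16, 3, 3, 8};    // 16 OFMs, 3x3 kernel, 8 IFMs (assumed field order)
      const Tensor4DShape input_shape {1, 56, 56, 8};  // NHWC: 1 batch, 56x56, 8 channels (assumed field order)
      const PaddingType padding = PADDING_SAME;        // assumed enumerator name

      // Query the storage requirements from the static helpers (assumed to be in bytes).
      void *kernel_storage = std::malloc(Conv::get_kernel_storage_size(kernel_shape));
      void *working_space  = std::malloc(Conv::get_working_space_size(kernel_shape, input_shape, padding));

      // Construct once, transform the weights into the Winograd domain once, then execute.
      Conv conv(kernel_shape, input_shape, padding, kernel_storage);
      conv.transform_weights(weights);
      conv.execute(output, input, working_space, n_threads);

      std::free(working_space);
      std::free(kernel_storage);
    }

Further inputs could be processed by repeating the `execute` call; the weight transform only needs to be run again if the weights change.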
diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp
new file mode 100644
index 0000000..a3b3db4
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include <utility>
+
+#include "batched_blocked_gemm.hpp"
+#include "winograd_gemm.hpp"
+
+/** Example of how to construct an ACL-like interface.
+ *
+ * Use `get_weight_storage_size`, `get_input_storage_size` and
+ * `get_output_storage_size` to allocate memory for the convolution engine.
+ * Then create a `WinogradConvolutionLayer`.
+ *
+ * Initialise the weights using `weights_transform.run(...)`.
+ *
+ * For each inference:
+ *   1. Transform the inputs to the Winograd domain using `input_transform.run(...)`
+ *   2. Perform a number of GEMMs using `gemms.run(...)`
+ *   3. Transform the output to the spatial domain using `output_transform.run(...)`
+ */
+template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols,
+          typename TIn, typename TOut>
+class WinogradConvolutionLayer
+{
+  private:
+    const KernelShape _kernel_shape;
+    const Tensor4DShape _input_shape;
+    const PaddingType _padding;
+    const Tensor4DShape _output_shape;
+    const int _n_output_rows, _n_output_cols;
+    const int _kernel_matrix_stride, _kernel_matrix_row_stride;
+    const int _input_matrix_stride, _input_matrix_row_stride;
+    const int _output_matrix_stride, _output_matrix_row_stride;
+    const int _tile_rows, _tile_cols;
+    const int _m, _k, _n;
+
+  public:
+    using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols>;
+    using WeightsTransform = typename WinogradBase::template WeightsTransform<TIn>;
+    using InputTransform = typename WinogradBase::template InputTransform<TIn>;
+    using WinogradConv = typename WinogradBase::template Convolution<TOut, TIn>;
+    using MultiGEMM = winograd::BatchedBlockedGemm<WinogradConv::M_BLOCK, WinogradConv::N_BLOCK, TIn, TOut>;
+    using OutputTransform = typename WinogradBase::template OutputTransform<TOut>;
+
+    /* Public member variables. */
+    WeightsTransform weights_transform;  /** Operator to transform weights to Winograd domain. */
+    InputTransform input_transform;      /** Operator to transform input to Winograd domain. */
+    MultiGEMM gemms;                     /** Operator to perform multiple GEMMs. */
+    OutputTransform output_transform;    /** Operator to transform output from Winograd domain. */
+
+    /** Determine how much memory (in units of TIn) to allocate for the
+     * transformed weights.
+     */
+    static unsigned int get_weight_storage_size(
+      const int n_output_channels,  /** Number of output feature maps. */
+      const int n_input_channels    /** Number of input feature maps. */
+    );
+
+    /** Determine how much memory (in units of TIn) to allocate for the
+     * transformed input.
+     */
+    static unsigned int get_input_storage_size(
+      const int n_batches,     /** Number of batches in the input tensor. */
+      const int n_channels,    /** Number of feature maps in the input tensor. */
+      const int n_rows,        /** Number of rows in each feature map. */
+      const int n_cols,        /** Number of columns in each feature map. */
+      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Determine how much memory (in units of TOut) to allocate for the
+     * (Winograd domain) output.
+     */
+    static unsigned int get_output_storage_size(
+      const int n_batches,          /** Number of batches in the output tensor. */
+      const int n_rows,             /** Number of rows in each feature map of the input tensor. */
+      const int n_cols,             /** Number of columns in each feature map of the input tensor. */
+      const int n_output_channels,  /** Number of feature maps in the output tensor. */
+      const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Get the shape (rows, cols) of a feature map of the output tensor. */
+    static std::pair<int, int> get_output_feature_map_shape(
+      const int n_input_rows,  /** Number of rows in the input feature map. */
+      const int n_input_cols,  /** Number of columns in the input feature map. */
+      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Create a new Winograd convolution layer.
+     */
+    WinogradConvolutionLayer(
+      const int n_batches,          /** Number of batches in the input and output tensors. */
+      const int n_input_channels,   /** Number of feature maps in a batch of the input tensor. */
+      const int n_input_rows,       /** Number of rows in a feature map of the input tensor. */
+      const int n_input_cols,       /** Number of columns in a feature map of the input tensor. */
+      const int n_output_channels,  /** Number of feature maps in the output tensor. */
+      const bool same_padding,      /** Use "SAME" padding, otherwise use "VALID". */
+      const TIn* const weights,     /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Width x Input Feature Maps x Output Feature Maps". */
+      TIn* const weights_storage,   /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
+      const TIn* const input,       /** Pointer to NHWC ordered input tensor, in the spatial domain. */
+      TIn* const winograd_input,    /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */
+      TOut* const output,           /** Pointer to NHWC ordered output tensor, in the spatial domain. */
+      TOut* const winograd_output   /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
+    );
+};
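
For illustration only (again not part of the patch): a sketch of the workflow described in the comment at the top of `winograd_layer.hpp`, instantiated for float with a 3x3 kernel and 2x2 output tiles, run single-threaded. The tensor dimensions are arbitrary placeholders, the storage sizes are in elements as documented above, and the `get_window()`/`run(start, stop)` interface on the weights/input transforms and on `gemms` is assumed to match the one declared for the output transform in `winograd_gemm.hpp`.

    #include <vector>
    #include "winograd_layer.hpp"

    // Hypothetical driver, not part of the patch.
    void run_winograd_layer(const float *weights, const float *input, float *output)
    {
      using Layer = WinogradConvolutionLayer<2, 2, 3, 3, float, float>;

      const int n_batches = 1, n_input_channels = 8, n_input_rows = 56,
                n_input_cols = 56, n_output_channels = 16;
      const bool same_padding = true;

      // Workspaces sized with the static helpers (element counts, per the header).
      std::vector<float> weight_storage(Layer::get_weight_storage_size(n_output_channels, n_input_channels));
      std::vector<float> winograd_input(Layer::get_input_storage_size(n_batches, n_input_channels, n_input_rows, n_input_cols, same_padding));
      std::vector<float> winograd_output(Layer::get_output_storage_size(n_batches, n_input_rows, n_input_cols, n_output_channels, same_padding));

      Layer layer(n_batches, n_input_channels, n_input_rows, n_input_cols,
                  n_output_channels, same_padding,
                  weights, weight_storage.data(),
                  input, winograd_input.data(),
                  output, winograd_output.data());

      // One-off: transform the weights into the Winograd domain.
      layer.weights_transform.run(0, layer.weights_transform.get_window());

      // Per inference: input transform -> batched GEMMs -> output transform.
      layer.input_transform.run(0, layer.input_transform.get_window());
      layer.gemms.run(0, layer.gemms.get_window());
      layer.output_transform.run(0, layer.output_transform.get_window());
    }

A caller wanting to parallelise a stage would split its `0 .. get_window()` range into disjoint sub-ranges and issue one `run(start, stop)` call per thread.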