COMPMID-815: Updated NEWinogradLayer with the lastest code from Research.

Change-Id: I86d7f53b5f5d1dbc22078aea5c32b08a25d1f49e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/116634
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/SConscript b/SConscript
index c7779ca..c9f6d08 100644
--- a/SConscript
+++ b/SConscript
@@ -175,6 +175,11 @@
     core_files += Glob('src/core/NEON/*.cpp')
     core_files += Glob('src/core/NEON/kernels/*.cpp')
 
+    # build winograd sources for either v7a / v8a
+    core_files += Glob('src/core/NEON/kernels/winograd/*.cpp')
+    core_files += Glob('src/core/NEON/kernels/winograd/transforms/*.cpp')
+    arm_compute_env.Append(CPPPATH = ["arm_compute/core/NEON/kernels/winograd/"])
+
     if env['arch'] == "armv7a":
         core_files += Glob('src/core/NEON/kernels/arm32/*.cpp')
 
diff --git a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
index 73b7e8d..9526192 100644
--- a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,6 +25,7 @@
 #define __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__
 
 #include "arm_compute/core/NEON/INEKernel.h"
+#include "arm_compute/core/NEON/kernels/winograd/convolution.hpp"
 #include "arm_compute/core/NEON/kernels/winograd/tensor.hpp"
 
 namespace arm_compute
@@ -36,11 +37,25 @@
 {
 public:
     friend class NEWinogradLayerKernel;
-    Winograd3x3F32(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage);
+    Winograd3x3F32(
+        const int          n_batches,         /** Number of batches in the input and output tensors. */
+        const int          n_input_channels,  /** Number of feature maps in a batch of the input tensor. */
+        const int          n_input_rows,      /** Number of rows in a feature map of the input tensor. */
+        const int          n_input_cols,      /** Number of columns in a feature map of the input tensor. */
+        const int          n_output_channels, /** Number of feature maps in the output tensor. */
+        const bool         same_padding,      /** Use "SAME" padding, otherwise use "VALID". */
+        const float *const weights,           /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */
+        float *const       weights_storage,   /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
+        const float *const input,             /** Pointer to NHWC ordered input tensor, in the spatial domain. */
+        float *const       winograd_input,    /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */
+        float *const       output,            /** Pointer to NHWC ordered output tensor, in the spatial domain. */
+        float *const       winograd_output    /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
+    );
+
     ~Winograd3x3F32();
-    void transform_weights(const void *const kernel, void *transform_working_space);
-    void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const void *const input, void *working_space);
-    void reshape_output(const Tensor4DShape &input_shape, const PaddingType padding_type, void *const output);
+    void transform_weights();
+    void transform_input();
+    void transform_output();
 
 private:
     class Private;
@@ -75,15 +90,29 @@
 
     /* Get the memory required to instantiate a new Winograd operator.
        */
-    static size_t get_kernel_storage_size(const KernelShape &shape);
+    static size_t get_weight_storage_size(
+        const int n_output_channels, /** Number of output feature maps. */
+        const int n_input_channels   /** Number of input feature maps. */
+    );
 
-    /* Get the memory required to apply a Winograd operator to some input.
-       */
-    static size_t get_working_space_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, const PaddingType padding);
+    static unsigned int get_input_storage_size(
+        const int  n_batches,   /** Number of batches in the input tensor. */
+        const int  n_channels,  /** Number of feature maps in the input tensor. */
+        const int  n_rows,      /** Number of rows in each feature map. */
+        const int  n_cols,      /** Number of columns in each feature map. */
+        const bool same_padding /** Use "SAME" padding, otherwise use "VALID". */
+    );
 
-    /* Get the memory required to transform the kernel.
-       */
-    static size_t get_kernel_transform_working_size(const KernelShape &shape);
+    /** Determine how much memory (in units of TOut) to allocate for the
+     * (Winograd domain) output.
+     */
+    static unsigned int get_output_storage_size(
+        const int  n_batches,         /** Number of batches in the output tensor. */
+        const int  n_rows,            /** Number of rows in each feature map of the input tensor. */
+        const int  n_cols,            /** Number of columns in each feature map of the input tensor. */
+        const int  n_output_channels, /** Number of feature maps in the output tensor. */
+        const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
+    );
 
 protected:
     Winograd3x3F32 *_convolver;
diff --git a/arm_compute/core/NEON/kernels/winograd/alloc.hpp b/arm_compute/core/NEON/kernels/winograd/alloc.hpp
index ef6f2b5..799e95d 100644
--- a/arm_compute/core/NEON/kernels/winograd/alloc.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/alloc.hpp
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
 #pragma once
 
 #ifdef ALLOC_ALIGN
diff --git a/src/core/NEON/kernels/winograd/transforms.hpp b/arm_compute/core/NEON/kernels/winograd/arm.hpp
similarity index 77%
copy from src/core/NEON/kernels/winograd/transforms.hpp
copy to arm_compute/core/NEON/kernels/winograd/arm.hpp
index 8546ee9..90e7828 100644
--- a/src/core/NEON/kernels/winograd/transforms.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/arm.hpp
@@ -22,8 +22,18 @@
  * SOFTWARE.
  */
 
-#pragma once
+/** Sets the macro __arm_any__ if compiling for Aarch32 or Aarch64.
+ *  Includes `arm_neon.h` if compiling for either architecture.
+ */
 
-#include "transforms/input_2x2_3x3.hpp"
-#include "transforms/kernel_2x2_3x3.hpp"
-#include "transforms/output_2x2_3x3.hpp"
+#ifdef __arm__
+#define __arm_any__
+#endif  // __arm__
+
+#ifdef __aarch64__
+#define __arm_any__
+#endif  // __aarch64__
+
+#ifdef __arm_any__
+#include <arm_neon.h>
+#endif  // __arm_any__
diff --git a/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp
new file mode 100644
index 0000000..663b3c4
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/batched_blocked_gemm.hpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+namespace winograd
+{
+
+template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
+class BatchedBlockedGemm
+{
+  public:
+    /** Create a new batched blocked GEMM operator. */
+    BatchedBlockedGemm(
+      const unsigned int n_gemms,
+      const int M, const int K, const int N,
+      const int a_matrix_stride,
+      const int a_row_stride,
+      const int b_matrix_stride,
+      const int b_row_stride,
+      const int c_matrix_stride,
+      const int c_row_stride,
+      const TIn* const a_ptr,
+      const TIn* const b_ptr,
+      TOut* const c_ptr
+    );
+
+    BatchedBlockedGemm(const BatchedBlockedGemm&) = delete;
+    BatchedBlockedGemm operator=(const BatchedBlockedGemm&) = delete;
+
+    /** Get a window of work performed by the operator. */
+    unsigned int get_window() const;
+
+    /** Perform a portion of the work of the operator. */
+    void run(const unsigned int start, const unsigned int stop);
+
+  private:
+    const unsigned int n_gemms;
+    const int M, N, K;
+    const int a_matrix_stride, a_row_stride;
+    const int b_matrix_stride, b_row_stride;
+    const int c_matrix_stride, c_row_stride;
+    const TIn* const a_ptr;
+    const TIn* const b_ptr;
+    TOut* const c_ptr;
+};
+
+}  // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/transforms.hpp b/arm_compute/core/NEON/kernels/winograd/convolution.hpp
similarity index 89%
rename from src/core/NEON/kernels/winograd/transforms.hpp
rename to arm_compute/core/NEON/kernels/winograd/convolution.hpp
index 8546ee9..2ab2597 100644
--- a/src/core/NEON/kernels/winograd/transforms.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/convolution.hpp
@@ -24,6 +24,6 @@
 
 #pragma once
 
-#include "transforms/input_2x2_3x3.hpp"
-#include "transforms/kernel_2x2_3x3.hpp"
-#include "transforms/output_2x2_3x3.hpp"
+enum PaddingType {
+  PADDING_SAME, PADDING_VALID
+};
diff --git a/src/core/NEON/kernels/winograd/transforms.hpp b/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp
similarity index 82%
copy from src/core/NEON/kernels/winograd/transforms.hpp
copy to arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp
index 8546ee9..725f6ca 100644
--- a/src/core/NEON/kernels/winograd/transforms.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/direct_convolution.hpp
@@ -23,7 +23,12 @@
  */
 
 #pragma once
+#include "convolution.hpp"
+#include "tensor.hpp"
 
-#include "transforms/input_2x2_3x3.hpp"
-#include "transforms/kernel_2x2_3x3.hpp"
-#include "transforms/output_2x2_3x3.hpp"
+void direct_convolution(
+  const Tensor4D<Tensor4DShape, float>& input,
+  const Tensor4D<KernelShape, float>& kernel,
+  Tensor4D<Tensor4DShape, float>& output,
+  const PaddingType padding
+);
diff --git a/src/core/NEON/kernels/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm.hpp
similarity index 93%
rename from src/core/NEON/kernels/winograd/gemm.hpp
rename to arm_compute/core/NEON/kernels/winograd/gemm.hpp
index 111e196..e48d31b 100644
--- a/src/core/NEON/kernels/winograd/gemm.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/gemm.hpp
@@ -1,4 +1,3 @@
-
 /*
  * Copyright (c) 2017 ARM Limited.
  *
@@ -22,11 +21,12 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
 #pragma once
 #include "utils.hpp"
 
 template <typename TIn, typename TOut>
-void Gemm(const TIn* const a, const TIn* const b, TOut *c,
+inline void Gemm(const TIn* const a, const TIn* const b, TOut *c,
           const int M, const int K, const int N,
           const int a_row_stride,
           const int b_row_stride,
@@ -57,7 +57,7 @@
 }
 
 template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
-void BlockedGemm(
+inline void BlockedGemm(
   const TIn* const a, const TIn* const b, TOut *c,
   const int M, const int K, const int N,
   const int a_row_stride,
@@ -65,11 +65,11 @@
   const int c_row_stride
 ) {
   // Array access methods
-  const auto A = [a, a_row_stride] (const int i, const int j) -> TIn {
+  const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn {
     return a[i*a_row_stride + j];
   };
 
-  const auto B = [b, b_row_stride] (const int i, const int j) -> TIn {
+  const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn {
     return b[i*b_row_stride + j];
   };
 
diff --git a/src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
similarity index 99%
rename from src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
rename to arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
index e1b7488..caeb48f 100644
--- a/src/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm.hpp
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
 #pragma once
 #include <cassert>
 #include "../utils.hpp"
@@ -347,8 +348,7 @@
       );
       break;
     default:
-      assert(0);
-      break;
+      assert(false);
   }
 }
 
diff --git a/src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
similarity index 98%
rename from src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
rename to arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
index e74610e..5cd37de 100644
--- a/src/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/gemm/a64_sgemm_4x16.hpp
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
 template <const unsigned int tail>
 inline void sgemm_4x16_impl(
   const float* const a, const float* const b, float *c,
@@ -604,9 +605,9 @@
           "ldr qB2, [%x[bptr], #0x10]\n"
           "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
           "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr sA1, [%x[aptr]], #0x10\n"
+          "ldr sA1, [%x[aptr]], #0x04\n"
           "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr sA2, [   aptr2], #0x10\n"
+          "ldr sA2, [   aptr2], #0x04\n"
           "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
           "ldr qB3, [%x[bptr], #0x20]\n"
           "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
@@ -617,7 +618,7 @@
           "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
           "stp qC11, qC12, [%x[cptr], #0x00]\n"
           "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "ldr sA3, [   aptr3], #0x10\n"
+          "ldr sA3, [   aptr3], #0x04\n"
           "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
           "stp qC13, qC14, [%x[cptr], #0x20]\n"
           "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
@@ -625,7 +626,7 @@
           "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
           "stp qC21, qC22, [%x[cptr], #0x00]\n"
           "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "ldr sA4, [   aptr4], #0x10\n"
+          "ldr sA4, [   aptr4], #0x04\n"
           "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
           "stp qC23, qC24, [%x[cptr], #0x20]\n"
           "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
@@ -951,18 +952,18 @@
           "ldr qB2, [%x[bptr], #0x10]\n"
           "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
           "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr dA1, [%x[aptr]], #0x10\n"
+          "ldr dA1, [%x[aptr]], #0x08\n"
           "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr dA2, [   aptr2], #0x10\n"
+          "ldr dA2, [   aptr2], #0x08\n"
           "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
           "ldr qB3, [%x[bptr], #0x20]\n"
           "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
 
         "2:"  // Common tail
           "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr dA3, [   aptr3], #0x10\n"
+          "ldr dA3, [   aptr3], #0x08\n"
           "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr dA4, [   aptr4], #0x10\n"
+          "ldr dA4, [   aptr4], #0x08\n"
           "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
           "ldr qB4, [%x[bptr], #0x30]\n"
           "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
@@ -1320,18 +1321,18 @@
           "ldr qB2, [%x[bptr], #0x10]\n"
           "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
           "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
-          "ldr dA1, [%x[aptr]], #0x10\n"
+          "ldr dA1, [%x[aptr]], #0x08\n"
           "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
-          "ldr dA2, [   aptr2], #0x10\n"
+          "ldr dA2, [   aptr2], #0x08\n"
           "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
           "ldr qB3, [%x[bptr], #0x20]\n"
           "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
 
         "2:"  // Common tail
           "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
-          "ldr dA3, [   aptr3], #0x10\n"
+          "ldr dA3, [   aptr3], #0x08\n"
           "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
-          "ldr dA4, [   aptr4], #0x10\n"
+          "ldr dA4, [   aptr4], #0x08\n"
           "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
           "ldr qB4, [%x[bptr], #0x30]\n"
           "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
@@ -1369,9 +1370,9 @@
           "ldr qB2, [%x[bptr], #0x10]\n"
           "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
           "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
-          "ldr sA1, [%x[aptr]], #0x10\n"
+          "ldr sA1, [%x[aptr]], #0x04\n"
           "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
-          "ldr sA2, [   aptr2], #0x10\n"
+          "ldr sA2, [   aptr2], #0x04\n"
           "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
           "ldr qB3, [%x[bptr], #0x20]\n"
           "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
@@ -1381,7 +1382,7 @@
           "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
           "stp qC11, qC12, [%x[cptr], #0x00]\n"
           "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
-          "ldr sA3, [   aptr3], #0x10\n"
+          "ldr sA3, [   aptr3], #0x04\n"
           "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
           "stp qC13, qC14, [%x[cptr], #0x20]\n"
           "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
@@ -1389,7 +1390,7 @@
           "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
           "stp qC21, qC22, [%x[cptr], #0x00]\n"
           "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
-          "ldr sA4, [   aptr4], #0x10\n"
+          "ldr sA4, [   aptr4], #0x04\n"
           "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
           "stp qC23, qC24, [%x[cptr], #0x20]\n"
           "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
diff --git a/arm_compute/core/NEON/kernels/winograd/perf.h b/arm_compute/core/NEON/kernels/winograd/perf.h
new file mode 100644
index 0000000..0cdf742
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/perf.h
@@ -0,0 +1,9 @@
+#pragma once
+
+/* Prototypes from perf.c */
+
+void start_counter(int fd);
+long long get_counter(int fd);
+long long stop_counter(int fd);
+int open_instruction_counter(void);
+int open_cycle_counter(void);
diff --git a/src/core/NEON/kernels/winograd/profiler.hpp b/arm_compute/core/NEON/kernels/winograd/profiler.hpp
similarity index 79%
rename from src/core/NEON/kernels/winograd/profiler.hpp
rename to arm_compute/core/NEON/kernels/winograd/profiler.hpp
index 143192b..01fafa9 100644
--- a/src/core/NEON/kernels/winograd/profiler.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/profiler.hpp
@@ -1,4 +1,3 @@
-
 /*
  * Copyright (c) 2017 ARM Limited.
  *
@@ -22,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
 #pragma once
 
 #include <algorithm>
@@ -29,11 +29,85 @@
 #include <cstring>
 #include <cstdio>
 #include <map>
+#include <mutex>
+#include <thread>
 #include <vector>
 
 #include "perf.h"
 #include <unistd.h>
 
+#ifdef CYCLE_PROFILING
+class EventIDContainer
+{
+  public:
+  EventIDContainer() : container_lock(), event_ids()
+  {
+  }
+
+  int get_event_id(const char *id)
+  {
+    std::lock_guard<std::mutex> lock(container_lock);
+    if (!event_ids.count(id)) {
+      event_ids.emplace(id, event_ids.size());
+    }
+    return event_ids[id];
+  }
+
+  unsigned int size() const
+  {
+    return event_ids.size();
+  }
+
+  auto begin()
+  {
+    return event_ids.begin();
+  }
+
+  auto end()
+  {
+    return event_ids.end();
+  }
+
+  private:
+  std::mutex container_lock;
+  std::map<const char *, int> event_ids;
+};
+
+
+class ThreadEventCounterContainer
+{
+  public:
+  ThreadEventCounterContainer() : container_lock(), thread_counter_fds()
+  {
+  }
+
+  int get_counter_fd()
+  {
+    const auto id = std::this_thread::get_id();
+    std::lock_guard<std::mutex> lock(container_lock);
+    if (!thread_counter_fds.count(id))
+    {
+      thread_counter_fds.emplace(id, open_cycle_counter());
+    }
+    return thread_counter_fds[id];
+  }
+
+  ~ThreadEventCounterContainer()
+  {
+    // Close all counter file descriptors
+    for (auto& fd : thread_counter_fds)
+    {
+      close(fd.second);
+    }
+  }
+
+  private:
+  std::mutex container_lock;
+  std::map<std::thread::id, int> thread_counter_fds;
+};
+#endif  // CYCLE_PROFILING
+
+
 class profiler {
 private:
 #ifdef CYCLE_PROFILING
@@ -46,27 +120,29 @@
     static const int maxevents = 10000;
     ProfileEntry events[maxevents];
     int currentevent;
-    int countfd;
+    std::mutex event_lock;
 
-    std::map<const char *, int> event_ids;
+    EventIDContainer event_ids;
+    ThreadEventCounterContainer thread_counter_fds;
 
-    int get_event_id(const char *id) {
-      if (!event_ids.count(id)) {
-        event_ids.emplace(id, event_ids.size());
-      }
-      return event_ids[id];
+    int get_event_id(const char *id)
+    {
+      return event_ids.get_event_id(id);
     }
 #endif  // CYCLE_PROFILING
 
 public:
 #ifdef CYCLE_PROFILING
-    profiler() {
-        currentevent = 0;
-        countfd = open_cycle_counter();
+    profiler() :
+      currentevent(0),
+      event_lock(),
+      event_ids(),
+      thread_counter_fds()
+    {
     }
 
     ~profiler() {
-        close(countfd);
+      std::lock_guard<std::mutex> lock_events(event_lock);
 
         // Compute performance from recorded events
         struct ProfileResult {
@@ -228,16 +304,22 @@
         if (currentevent==maxevents) {
             func();
         } else {
+            const auto countfd = thread_counter_fds.get_counter_fd();
             start_counter(countfd);
             func();
             long long cycs = stop_counter(countfd);
 
             // Store the profiling data
+            std::lock_guard<std::mutex> lock_events(event_lock);
             events[currentevent++] = {
               get_event_id(event), bytes_read, ops, bytes_written, cycs
             };
         }
 #else
+      (void) event;
+      (void) bytes_read;
+      (void) ops;
+      (void) bytes_written;
       func();
 #endif  // CYCLE_PROFILING
     }
diff --git a/arm_compute/core/NEON/kernels/winograd/shims.hpp b/arm_compute/core/NEON/kernels/winograd/shims.hpp
new file mode 100644
index 0000000..09e1457
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/shims.hpp
@@ -0,0 +1,747 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include <cstdint>
+#include "arm.hpp"
+
+namespace reorder {
+/** Re-order a tensor from NCHW format to NHWC.
+ *
+ * @note The stride parameters are optional and are provided to allow padding in either input or output tensors.
+ *
+ * @param[in] in Input tensor in NCHW format.
+ * @param[out] out Output tensor, to be written in NHWC format.
+ * @param n_batches Number of batches in the tensors.
+ * @param n_channels Number of channels in the tensors
+ * @param n_rows Height of the tensor
+ * @param n_cols Width of the tensor
+ * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_channels * in_channel_stride`.
+ * @param in_channel_stride Stride over channels in the input tensor. If `0` defaults to `n_rows * in_row_stride`.
+ * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols`.
+ * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_rows * out_row_stride`.
+ * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols * out_col_stride`.
+ * @param out_col_stride Stride over columns in the output tensor. If `0` defaults to `n_channels`.
+ */
+template <typename T>
+inline void nchw_to_nhwc(
+  const T* const in,
+  T* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride=0,
+  int in_channel_stride=0,
+  int in_row_stride=0,
+  int out_batch_stride=0,
+  int out_row_stride=0,
+  int out_col_stride=0
+);
+
+/** Re-order a tensor from NHWC format to NCHW.
+ *
+ * @note The stride parameters are optional and are provided to allow padding in either input or output tensors.
+ *
+ * @param[in] in Input tensor in NHWC format.
+ * @param[out] out Output tensor, to be written in NCHW format.
+ * @param n_batches Number of batches in the tensors.
+ * @param n_rows Height of the tensor
+ * @param n_cols Width of the tensor
+ * @param n_channels Number of channels in the tensors
+ * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_rows * in_row_stride`.
+ * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols * in_col_stride`.
+ * @param in_col_stride Stride over columns in the input tensor. If `0` defaults to `n_channels`.
+ * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_channels * out_channel_stride`.
+ * @param out_channel_stride Stride over channels in the output tensor. If `0` defaults to `n_rows * out_row_stride`.
+ * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols`.
+ */
+template <typename T>
+inline void nhwc_to_nchw(
+  const T* const in,  // Input data in NHWC form
+  T* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride=0,
+  int in_row_stride=0,
+  int in_col_stride=0,
+  int out_batch_stride=0,
+  int out_channel_stride=0,
+  int out_row_stride=0
+);
+
+/** Re-order a weight tensor from [Output feature map x Input feature map x
+ *  Height x Width] format to [Height x Width x Input feature map x Output
+ *  feature map] format.
+ */
+template <typename T>
+inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
+  const T* const in,  // Input in [Output x Input x Height x Width] form
+  T* const out,       // Output in [Height x Width x Input x Output] form
+  const int n_output_feature_maps,
+  const int n_input_feature_maps,
+  const int n_rows,
+  const int n_cols,
+  int in_output_feature_map_stride=0,
+  int in_input_feature_map_stride=0,
+  int in_row_stride=0,
+  int out_row_stride=0,
+  int out_col_stride=0,
+  int out_input_feature_map_stride=0
+);
+
+/** Re-order a weight tensor from [Height x Width x Input feature map x Output
+ *  feature map] format to [Output feature map x Input feature map x Height x
+ *  Width] format.
+ */
+template <typename T>
+inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
+  const T* const in,  // Input in [Height x Width x Input x Output] form
+  T* const out,       // Output in [Output x Input x Height x Width] form
+  const int n_rows,
+  const int n_cols,
+  const int n_input_feature_maps,
+  const int n_output_feature_maps,
+  int in_row_stride=0,
+  int in_col_stride=0,
+  int in_input_feature_map_stride=0,
+  int out_output_feature_map_stride=0,
+  int out_input_feature_map_stride=0,
+  int out_row_stride=0
+);
+
+/*****************************************************************************/
+/* 32-bit implementation : NCHW -> NHWC
+ */
+template <>
+inline void nchw_to_nhwc(
+  const int32_t* const in,
+  int32_t* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride,
+  int in_channel_stride,
+  int in_row_stride,
+  int out_batch_stride,
+  int out_row_stride,
+  int out_col_stride
+)
+{
+  typedef int32_t T;
+
+  // Fill in the stride values
+  in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
+  in_channel_stride = (in_channel_stride) ? in_channel_stride
+                                          : n_rows * in_row_stride;
+  in_batch_stride = (in_batch_stride) ? in_batch_stride
+                                      : n_channels * in_channel_stride;
+
+  out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
+  out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
+  out_batch_stride = (out_batch_stride) ? out_batch_stride
+                                        : n_rows * out_row_stride;
+
+  // Perform the re-ordering
+  for (int n = 0; n < n_batches; n++)
+  {
+    const T* const in_batch = in + n*in_batch_stride;
+    T* const out_batch = out + n*out_batch_stride;
+
+    for (int i = 0; i < n_rows; i++)
+    {
+      const T* const in_row = in_batch + i*in_row_stride;
+      T* const out_row = out_batch + i*out_row_stride;
+
+      int j = 0, j_remaining = n_cols;
+#ifdef __arm_any__
+      for (; j_remaining >= 4; j += 4, j_remaining -= 4)
+      {
+        int c = 0, c_remaining = n_channels;
+        for (; c_remaining >= 4; c += 4, c_remaining -= 4)
+        {
+          // Read 4 channels worth of 4 columns, then zip to produce 4 columns
+          // worth of 4 channels.
+          int32x4_t channel_pixels[4];
+          channel_pixels[0] = vld1q_s32(in_row + (c + 0)*in_channel_stride + j);
+          channel_pixels[1] = vld1q_s32(in_row + (c + 1)*in_channel_stride + j);
+          channel_pixels[2] = vld1q_s32(in_row + (c + 2)*in_channel_stride + j);
+          channel_pixels[3] = vld1q_s32(in_row + (c + 3)*in_channel_stride + j);
+
+          const auto zip1 = vzipq_s32(channel_pixels[0], channel_pixels[2]);
+          const auto zip2 = vzipq_s32(channel_pixels[1], channel_pixels[3]);
+          const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]);
+          const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]);
+
+          vst1q_s32(out_row + (j + 0)*out_col_stride + c, out_0.val[0]);
+          vst1q_s32(out_row + (j + 1)*out_col_stride + c, out_0.val[1]);
+          vst1q_s32(out_row + (j + 2)*out_col_stride + c, out_1.val[0]);
+          vst1q_s32(out_row + (j + 3)*out_col_stride + c, out_1.val[1]);
+        }
+        for (; c_remaining; c++, c_remaining--)
+        {
+          for (int _j = 0; _j < 4; _j++)
+          {
+            const T* const in_col = in_row + j + _j;
+            T* const out_col = out_row + (j + _j)*out_col_stride;
+            const T* const in_channel = in_col + c*in_channel_stride;
+            out_col[c] = *(in_channel);
+          }
+        }
+      }
+      for (; j_remaining >= 2; j += 2, j_remaining -= 2)
+      {
+        int c = 0, c_remaining = n_channels;
+        for (; c_remaining >= 2; c += 2, c_remaining -= 2)
+        {
+          // Read 2 channels worth of 2 columns, then zip to produce 2 columns
+          // worth of 2 channels.
+          int32x2_t channel_pixels[2];
+          channel_pixels[0] = vld1_s32(in_row + (c + 0)*in_channel_stride + j);
+          channel_pixels[1] = vld1_s32(in_row + (c + 1)*in_channel_stride + j);
+
+          const auto output = vzip_s32(channel_pixels[0], channel_pixels[1]);
+
+          vst1_s32(out_row + (j + 0)*out_col_stride + c, output.val[0]);
+          vst1_s32(out_row + (j + 1)*out_col_stride + c, output.val[1]);
+        }
+        for (; c_remaining; c++, c_remaining--)
+        {
+          for (int _j = 0; _j < 2; _j++)
+          {
+            const T* const in_col = in_row + j + _j;
+            T* const out_col = out_row + (j + _j)*out_col_stride;
+            const T* const in_channel = in_col + c*in_channel_stride;
+            out_col[c] = *(in_channel);
+          }
+        }
+      }
+#endif  // __arm_any__
+      for (; j_remaining; j++, j_remaining--)
+      {
+        const T* const in_col = in_row + j;
+        T* const out_col = out_row + j*out_col_stride;
+
+        for (int c = 0; c < n_channels; c++)
+        {
+          const T* const in_channel = in_col + c*in_channel_stride;
+          out_col[c] = *(in_channel);
+        }
+      }
+    }
+  }
+}
+
+template <>
+inline void nchw_to_nhwc(
+  const uint32_t* const in,
+  uint32_t* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride,
+  int in_channel_stride,
+  int in_row_stride,
+  int out_batch_stride,
+  int out_row_stride,
+  int out_col_stride
+)
+{
+  nchw_to_nhwc(
+    reinterpret_cast<const int32_t*>(in),
+    reinterpret_cast<int32_t*>(out),
+    n_batches, n_channels, n_rows, n_cols,
+    in_batch_stride, in_channel_stride, in_row_stride,
+    out_batch_stride, out_row_stride, out_col_stride
+  );
+}
+
+template <>
+inline void nchw_to_nhwc(
+  const float* const in,
+  float* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride,
+  int in_channel_stride,
+  int in_row_stride,
+  int out_batch_stride,
+  int out_row_stride,
+  int out_col_stride
+)
+{
+  nchw_to_nhwc(
+    reinterpret_cast<const int32_t*>(in),
+    reinterpret_cast<int32_t*>(out),
+    n_batches, n_channels, n_rows, n_cols,
+    in_batch_stride, in_channel_stride, in_row_stride,
+    out_batch_stride, out_row_stride, out_col_stride
+  );
+}
+
+/*****************************************************************************/
+/* Generic implementation : NCHW -> NHWC
+ */
+template <typename T>
+inline void nchw_to_nhwc(
+  const T* const in,
+  T* const out,
+  const int n_batches,
+  const int n_channels,
+  const int n_rows,
+  const int n_cols,
+  int in_batch_stride,
+  int in_channel_stride,
+  int in_row_stride,
+  int out_batch_stride,
+  int out_row_stride,
+  int out_col_stride
+)
+{
+  // Fill in the stride values
+  in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
+  in_channel_stride = (in_channel_stride) ? in_channel_stride
+                                          : n_rows * in_row_stride;
+  in_batch_stride = (in_batch_stride) ? in_batch_stride
+                                      : n_channels * in_channel_stride;
+
+  out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
+  out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
+  out_batch_stride = (out_batch_stride) ? out_batch_stride
+                                        : n_rows * out_row_stride;
+
+  // Perform the re-ordering
+  for (int n = 0; n < n_batches; n++)
+  {
+    const T* const in_batch = in + n*in_batch_stride;
+    T* const out_batch = out + n*out_batch_stride;
+
+    for (int i = 0; i < n_rows; i++)
+    {
+      const T* const in_row = in_batch + i*in_row_stride;
+      T* const out_row = out_batch + i*out_row_stride;
+
+      for (int j = 0; j < n_cols; j++)
+      {
+        const T* const in_col = in_row + j;
+        T* const out_col = out_row + j*out_col_stride;
+
+        for (int c = 0; c < n_channels; c++)
+        {
+          const T* const in_channel = in_col + c*in_channel_stride;
+          out_col[c] = *(in_channel);
+        }
+      }
+    }
+  }
+}
+
+/*****************************************************************************/
+/* 32-bit implementation : NHWC -> NCHW
+ */
+template <>
+inline void nhwc_to_nchw(
+  const int32_t* const in,  // Input data in NHWC form
+  int32_t* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride,
+  int in_row_stride,
+  int in_col_stride,
+  int out_batch_stride,
+  int out_channel_stride,
+  int out_row_stride
+)
+{
+  typedef int32_t T;
+
+  // Fill in stride values
+  in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
+  in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
+  in_batch_stride = (in_batch_stride) ? in_batch_stride
+                                      : n_rows * in_row_stride;
+
+  out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
+  out_channel_stride = (out_channel_stride) ? out_channel_stride
+                                            : n_rows * out_row_stride;
+  out_batch_stride = (out_batch_stride) ? out_batch_stride
+                                        : n_channels * out_channel_stride;
+
+  // Perform the re-ordering
+  // For every batch
+  for (int n = 0; n < n_batches; n++)
+  {
+    const T* const in_batch = in + n*in_batch_stride;
+    T* const out_batch = out + n*out_batch_stride;
+
+    // For every row
+    for (int i = 0; i < n_rows; i++)
+    {
+      const T* const in_i = in_batch + i*in_row_stride;
+      T* const out_i = out_batch + i*out_row_stride;
+
+      // For every column, beginning with chunks of 4
+      int j = 0, j_remaining = n_cols;
+#ifdef __arm_any__
+      for (; j_remaining >= 4; j += 4, j_remaining -=4)
+      {
+        // For every channel, beginning with chunks of 4
+        int c = 0, c_remaining = n_channels;
+        for (; c_remaining >= 4; c += 4, c_remaining -= 4)
+        {
+          // Read 4 columns worth of 4 channels then zip to produce 4 channels
+          // worth of 4 columns.
+          int32x4_t pixel_channels[4];
+          pixel_channels[0] = vld1q_s32(in_i + (j + 0)*in_col_stride + c);
+          pixel_channels[1] = vld1q_s32(in_i + (j + 1)*in_col_stride + c);
+          pixel_channels[2] = vld1q_s32(in_i + (j + 2)*in_col_stride + c);
+          pixel_channels[3] = vld1q_s32(in_i + (j + 3)*in_col_stride + c);
+
+          const auto zip1 = vzipq_s32(pixel_channels[0], pixel_channels[2]);
+          const auto zip2 = vzipq_s32(pixel_channels[1], pixel_channels[3]);
+          const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]);
+          const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]);
+
+          vst1q_s32(out_i + j + (c + 0)*out_channel_stride, out_0.val[0]);
+          vst1q_s32(out_i + j + (c + 1)*out_channel_stride, out_0.val[1]);
+          vst1q_s32(out_i + j + (c + 2)*out_channel_stride, out_1.val[0]);
+          vst1q_s32(out_i + j + (c + 3)*out_channel_stride, out_1.val[1]);
+        }
+        for (; c_remaining; c++, c_remaining--)
+        {
+          for (int _j = 0; _j < 4; _j++)
+          {
+            const T* const in_j = in_i + (j + _j)*in_col_stride;
+            T* const out_j = out_i + (j + _j);
+
+            const T* const in_channel = in_j + c;
+            T* const out_channel = out_j + c*out_channel_stride;
+            *(out_channel) = *(in_channel);
+          }
+        }
+      }
+      for (; j_remaining >= 2; j += 2, j_remaining -=2)
+      {
+        int c = 0, c_remaining = n_channels;
+        for (; c_remaining >= 2; c += 2, c_remaining -= 2)
+        {
+          // Read 2 columns worth of 2 channels then zip to produce 2 channels
+          // worth of 2 columns.
+          int32x2_t pixel_channels[2];
+          pixel_channels[0] = vld1_s32(in_i + (j + 0)*in_col_stride + c);
+          pixel_channels[1] = vld1_s32(in_i + (j + 1)*in_col_stride + c);
+
+          const auto output = vzip_s32(pixel_channels[0], pixel_channels[1]);
+
+          vst1_s32(out_i + j + (c + 0)*out_channel_stride, output.val[0]);
+          vst1_s32(out_i + j + (c + 1)*out_channel_stride, output.val[1]);
+        }
+        for (; c_remaining; c++, c_remaining--)
+        {
+          for (int _j = 0; _j < 2; _j++)
+          {
+            const T* const in_j = in_i + (j + _j)*in_col_stride;
+            T* const out_j = out_i + (j + _j);
+
+            const T* const in_channel = in_j + c;
+            T* const out_channel = out_j + c*out_channel_stride;
+            *(out_channel) = *(in_channel);
+          }
+        }
+      }
+#endif  // __arm_any__
+      for (; j_remaining; j++, j_remaining--)
+      {
+        const T* const in_j = in_i + j*in_col_stride;
+        T* const out_j = out_i + j;
+
+        // For every channel
+        for (int c = 0; c < n_channels; c++)
+        {
+          const T* const in_channel = in_j + c;
+          T* const out_channel = out_j + c*out_channel_stride;
+          *(out_channel) = *(in_channel);
+        }
+      }
+    }
+  }
+}
+
+template <>
+inline void nhwc_to_nchw(
+  const uint32_t* const in,  // Input data in NHWC form
+  uint32_t* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride,
+  int in_row_stride,
+  int in_col_stride,
+  int out_batch_stride,
+  int out_channel_stride,
+  int out_row_stride
+)
+{
+  // Redirect to generic 32-bit implementation
+  nhwc_to_nchw(
+    reinterpret_cast<const int32_t*>(in),
+    reinterpret_cast<int32_t*>(out),
+    n_batches, n_rows, n_cols, n_channels,
+    in_batch_stride, in_row_stride, in_col_stride,
+    out_batch_stride, out_channel_stride, out_row_stride
+  );
+}
+
+template <>
+inline void nhwc_to_nchw(
+  const float* const in,  // Input data in NHWC form
+  float* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride,
+  int in_row_stride,
+  int in_col_stride,
+  int out_batch_stride,
+  int out_channel_stride,
+  int out_row_stride
+)
+{
+  // Redirect to generic 32-bit implementation
+  nhwc_to_nchw(
+    reinterpret_cast<const int32_t*>(in),
+    reinterpret_cast<int32_t*>(out),
+    n_batches, n_rows, n_cols, n_channels,
+    in_batch_stride, in_row_stride, in_col_stride,
+    out_batch_stride, out_channel_stride, out_row_stride
+  );
+}
+
+/*****************************************************************************/
+/* Generic implementation : NHWC -> NCHW
+ */
+template <typename T>
+inline void nhwc_to_nchw(
+  const T* const in,  // Input data in NHWC form
+  T* const out,       // Output data in NCHW form
+  const int n_batches,
+  const int n_rows,
+  const int n_cols,
+  const int n_channels,
+  int in_batch_stride,
+  int in_row_stride,
+  int in_col_stride,
+  int out_batch_stride,
+  int out_channel_stride,
+  int out_row_stride
+)
+{
+  // Fill in stride values
+  in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
+  in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
+  in_batch_stride = (in_batch_stride) ? in_batch_stride
+                                      : n_rows * in_row_stride;
+
+  out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
+  out_channel_stride = (out_channel_stride) ? out_channel_stride
+                                            : n_rows * out_row_stride;
+  out_batch_stride = (out_batch_stride) ? out_batch_stride
+                                        : n_channels * out_channel_stride;
+
+  // Perform the re-ordering
+  // For every batch
+  for (int n = 0; n < n_batches; n++)
+  {
+    const T* const in_batch = in + n*in_batch_stride;
+    T* const out_batch = out + n*out_batch_stride;
+
+    // For every row
+    for (int i = 0; i < n_rows; i++)
+    {
+      const T* const in_i = in_batch + i*in_row_stride;
+      T* const out_i = out_batch + i*out_row_stride;
+
+      // For every column
+      for (int j = 0; j < n_cols; j++)
+      {
+        const T* const in_j = in_i + j*in_col_stride;
+        T* const out_j = out_i + j;
+
+        // For every channel
+        for (int c = 0; c < n_channels; c++)
+        {
+          const T* const in_channel = in_j + c;
+          T* const out_channel = out_j + c*out_channel_stride;
+          *(out_channel) = *(in_channel);
+        }
+      }
+    }
+  }
+}
+
+/*****************************************************************************/
+/* Generic weight re-order implementation.
+ */
+template <typename T>
+inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
+  const T* const in,  // Input in [Output x Input x Height x Width] form
+  T* const out,       // Output in [Height x Width x Input x Output] form
+  const int n_output_feature_maps,
+  const int n_input_feature_maps,
+  const int n_rows,
+  const int n_cols,
+  int in_output_feature_map_stride,
+  int in_input_feature_map_stride,
+  int in_row_stride,
+  int out_row_stride,
+  int out_col_stride,
+  int out_input_feature_map_stride
+)
+{
+  // Fill in stride values
+  in_row_stride = (in_row_stride)
+    ? in_row_stride
+    : n_cols;
+  in_input_feature_map_stride = (in_input_feature_map_stride)
+    ? in_input_feature_map_stride
+    : n_rows * in_row_stride;
+  in_output_feature_map_stride = (in_output_feature_map_stride)
+    ? in_output_feature_map_stride
+    : n_input_feature_maps * in_input_feature_map_stride;
+
+  out_input_feature_map_stride = (out_input_feature_map_stride)
+    ? out_input_feature_map_stride
+    : n_output_feature_maps;
+  out_col_stride = (out_col_stride)
+    ? out_col_stride
+    : n_input_feature_maps * out_input_feature_map_stride;
+  out_row_stride = (out_row_stride)
+    ? out_row_stride
+    : n_cols * out_col_stride;
+
+  // Perform the re-ordering
+  for (int i = 0; i < n_rows; i++)
+  {
+    const T* const in_row = in + i * in_row_stride;
+    T* out_row = out + i * out_row_stride;
+
+    for (int j = 0; j < n_cols; j++)
+    {
+      const T* const in_col = in_row + j;
+      T* const out_col = out_row + j * out_col_stride;
+
+      for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
+      {
+        const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
+        T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
+
+        for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
+        {
+          const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride;
+          T* const out_ofm = out_ifm + ofm;
+          *(out_ofm) = *(in_ofm);
+        }
+      }
+    }
+  }
+}
+
+/*****************************************************************************/
+/* Generic weight re-order implementation.
+ */
+template <typename T>
+inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
+  const T* const in,  // Input in [Height x Width x Input x Output] form
+  T* const out,       // Output in [Output x Input x Height x Width] form
+  const int n_rows,
+  const int n_cols,
+  const int n_input_feature_maps,
+  const int n_output_feature_maps,
+  int in_row_stride,
+  int in_col_stride,
+  int in_input_feature_map_stride,
+  int out_output_feature_map_stride,
+  int out_input_feature_map_stride,
+  int out_row_stride
+)
+{
+  // Fill in the stride values
+  in_input_feature_map_stride = (in_input_feature_map_stride)
+    ? in_input_feature_map_stride
+    : n_output_feature_maps;
+  in_col_stride = (in_col_stride)
+    ? in_col_stride
+    : n_input_feature_maps * in_input_feature_map_stride;
+  in_row_stride = (in_row_stride)
+    ? in_row_stride
+    : n_cols * in_col_stride;
+
+  out_row_stride = (out_row_stride)
+    ? out_row_stride
+    : n_cols;
+  out_input_feature_map_stride = (out_input_feature_map_stride)
+    ? out_input_feature_map_stride
+    : n_rows * out_row_stride;
+  out_output_feature_map_stride = (out_output_feature_map_stride)
+    ? out_output_feature_map_stride
+    : n_input_feature_maps * out_input_feature_map_stride;
+
+  // Perform the re-ordering
+  for (int i = 0; i < n_rows; i++)
+  {
+    const T* const in_row = in + i * in_row_stride;
+    T* const out_row = out + i * out_row_stride;
+
+    for (int j = 0; j < n_cols; j++)
+    {
+      const T* const in_col = in_row + j * in_col_stride;
+      T* const out_col = out_row + j;
+
+      for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
+      {
+        const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
+        T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
+
+        for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
+        {
+          const T* const in_ofm = in_ifm + ofm;
+          T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride;
+          *(out_ofm) = *(in_ofm);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reorder
diff --git a/arm_compute/core/NEON/kernels/winograd/tensor.hpp b/arm_compute/core/NEON/kernels/winograd/tensor.hpp
index 70ef65d..6567eeb 100644
--- a/arm_compute/core/NEON/kernels/winograd/tensor.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/tensor.hpp
@@ -23,39 +23,44 @@
  */
 
 #pragma once
-#include <cstdio>
 #include <cstdlib>
 #include <random>
 
 #include "alloc.hpp"
 
-/*****************************************************************************/
-/* Padding definitions */
-enum PaddingType {
-  PADDING_SAME, PADDING_VALID
+enum TensorOrder
+{
+  NHWC,  ///< [Batch x Height x Width x Channels]
+  NCHW,  ///< [Batch x Channels x Height x Width]
 };
 
-/*****************************************************************************/
-/* Shape of a kernel */
-struct KernelShape {
-  int n_output_channels, n_rows, n_cols, n_input_channels;
+struct Tensor4DShape
+{
+  int n_batches, n_rows, n_cols, n_channels;
+  TensorOrder ordering;
 
-  int size(void) const {
-    return n_output_channels * n_rows * n_cols * n_input_channels;
+  // Create a new tensor with the default (NHWC) ordering
+  inline Tensor4DShape(
+    const int n_batches,
+    const int n_rows,
+    const int n_cols,
+    const int n_channels,
+    const TensorOrder ordering=NHWC
+  ) : n_batches(n_batches),
+      n_rows(n_rows),
+      n_cols(n_cols),
+      n_channels(n_channels),
+      ordering(ordering)
+  {
   }
-};
 
-struct Tensor4DShape {
-  int n_batches,
-      n_rows,
-      n_cols,
-      n_channels;
-
-  int size() const {
+  inline int size() const
+  {
     return n_batches * n_rows * n_cols * n_channels;
   }
 
-  bool TestEq(const Tensor4DShape& other) const {
+  inline bool TestEq(const Tensor4DShape& other) const
+  {
     return (n_batches == other.n_batches &&
             n_rows == other.n_rows &&
             n_cols == other.n_cols &&
@@ -63,148 +68,110 @@
   }
 };
 
+
+enum WeightOrder
+{
+  HWIO,  ///< [Height x Width x Input channels x Output channels]
+  OIHW,  ///< [Output channels x Input channels x Height x Width]
+};
+
+struct KernelShape
+{
+  int n_output_channels, n_rows, n_cols, n_input_channels;
+  WeightOrder ordering;
+
+  inline KernelShape(
+    const int n_output_channels,
+    const int n_rows,
+    const int n_cols,
+    const int n_input_channels,
+    const WeightOrder ordering=HWIO
+  ) : n_output_channels(n_output_channels),
+      n_rows(n_rows),
+      n_cols(n_cols),
+      n_input_channels(n_input_channels),
+      ordering(ordering)
+  {
+  }
+
+  inline int size(void) const
+  {
+    return n_output_channels * n_rows * n_cols * n_input_channels;
+  }
+};
+
+
 template <typename ShapeT, typename T>
-class Tensor4D final {
+class Tensor4D final
+{
   public:
     Tensor4D(ShapeT shape) :
-      _shape(shape),
-      _data(reinterpret_cast<T*>(ALLOCATE(size_bytes()))) {
+      shape(shape),
+      _data(reinterpret_cast<T*>(ALLOCATE(size_bytes())))
+    {
         Clear();
     }
 
+    Tensor4D(const Tensor4D<ShapeT, T>&) = delete;
+    Tensor4D operator=(const Tensor4D<ShapeT, T>&) = delete;
+
     ~Tensor4D() {
       free(_data);
     }
 
-    T* ptr() const {
+    inline T* ptr() const {
       return _data;
     }
 
-    const ShapeT& shape() const {
-      return _shape;
+    inline size_t size_bytes() const {
+      return shape.size() * sizeof(T);
     }
 
-    size_t size_bytes() const {
-      return _shape.size() * sizeof(T);
-    }
+    inline T& element(int, int, int, int) const;
 
-    bool TestEq(Tensor4D<ShapeT, T>& other) const;
-    T& element(int, int, int, int) const;
-    void Print() const;
-
-    void Clear() {
+    inline void Clear() {
       Fill(static_cast<T>(0));
     }
 
-    void Fill(T val) {
-      for (int i = 0; i < _shape.size(); i++)
+    inline void Fill(T val) {
+      for (int i = 0; i < shape.size(); i++)
         _data[i] = val;
     }
 
-    void TestPattern() {
-      for (int i = 0; i < _shape.size(); i++)
-        _data[i] = static_cast<T>(i);
-    }
-
-    void Rand(const int seed=2311) {
-      std::mt19937 gen(seed);
-      std::uniform_int_distribution<> dis(-50, +50);
-
-      for (int i = 0; i < _shape.size(); i++) {
-        _data[i] = static_cast<T>(dis(gen));
-      }
-    }
-    Tensor4D(const Tensor4D &) = delete;
-    /** Prevent instances of this class from being copied (As this class contains pointers) */
-    Tensor4D &operator=(const Tensor4D &) = delete;
-    /** Allow instances of this class to be moved */
-    Tensor4D(Tensor4D &&) = default;
-    /** Allow instances of this class to be moved */
-    Tensor4D &operator=(Tensor4D &&) = default;
-
+    const ShapeT shape;
 
   private:
-    const ShapeT _shape;
     T* const _data;
 };
 
 
 template <>
-inline float& Tensor4D<Tensor4DShape, float>::element(int n, int i, int j, int c) const {
-  int index = ((n*_shape.n_rows + i)*_shape.n_cols + j)*_shape.n_channels + c;
+inline float& Tensor4D<Tensor4DShape, float>::element(int n, int i, int j, int c) const
+{
+  int index;
+  if (shape.ordering == NHWC)
+  {
+    index = ((n*shape.n_rows + i)*shape.n_cols + j)*shape.n_channels + c;
+  }
+  else  // NCHW
+  {
+    index = ((n*shape.n_channels + c)*shape.n_rows + i)*shape.n_cols + j;
+  }
   return _data[index];
 }
 
 
 template <>
-inline float& Tensor4D<KernelShape, float>::element(int oc, int i, int j, int ic) const {
-  int index = ((i*_shape.n_cols + j)*_shape.n_input_channels + ic)*_shape.n_output_channels + oc;
+inline float& Tensor4D<KernelShape, float>::element(int oc, int i, int j, int ic) const
+{
+  int index;
+  if (shape.ordering == HWIO)
+  {
+    index = ((i*shape.n_cols + j)*shape.n_input_channels + ic)*shape.n_output_channels + oc;
+  }
+  else  // OIHW
+  {
+    index = ((oc*shape.n_input_channels + ic)*shape.n_rows + i)*shape.n_cols + j;
+  }
   return _data[index];
 }
-
-template <>
-inline bool Tensor4D<Tensor4DShape, float>::TestEq(Tensor4D<Tensor4DShape, float>& other) const {
-  // Test equivalence, printing errors
-  // First test the shapes are the same
-  if (!_shape.TestEq(other.shape())) {
-    printf("Tensors have different shapes.\n");
-    return false;
-  } else {
-    int incorrects = 0;
-
-    for (int n = 0; n < _shape.n_batches; n++) {
-      for (int i = 0; i < _shape.n_rows; i++) {
-        for (int j = 0; j < _shape.n_cols; j++) {
-          for (int c = 0; c < _shape.n_channels; c++) {
-            // Check elements for equivalence
-            const auto a = this->element(n, i, j, c);
-            const auto b = other.element(n, i, j, c);
-
-            if (a != b) {
-              printf("Difference at element {%d, %d, %d, %d}: %.3f != %.3f\n", n, i, j, c, a, b);
-
-              if (++incorrects > 100) {
-                printf("More than 100 incorrect values, stopping test.\n");
-                return false;
-              }
-            }
-          }
-        }
-      }
-    }
-
-    return incorrects == 0;
-  }
-}
-
-
-template <>
-inline void Tensor4D<Tensor4DShape, float>::Print() const {
-  for (int n = 0; n < _shape.n_batches; n++) {
-    for (int c = 0; c < _shape.n_channels; c++) {
-      for (int i = 0; i < _shape.n_rows; i++) {
-        for (int j = 0; j < _shape.n_cols; j++) {
-          printf("%5.2f ", element(n, i, j, c));
-        }
-        printf("\n");
-      }
-      printf("\n");
-    }
-  }
-}
-
-
-template <>
-inline void Tensor4D<KernelShape, float>::Print() const {
-  for (int oc = 0; oc < _shape.n_output_channels; oc++) {
-    for (int ic = 0; ic < _shape.n_input_channels; ic++) {
-      for (int i = 0; i < _shape.n_rows; i++) {
-        for (int j = 0; j < _shape.n_cols; j++) {
-          printf("%5.2f ", element(oc, i, j, ic));
-        }
-        printf("\n");
-      }
-      printf("\n");
-    }
-  }
-}
diff --git a/src/core/NEON/kernels/winograd/utils.hpp b/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp
similarity index 61%
copy from src/core/NEON/kernels/winograd/utils.hpp
copy to arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp
index 14e709f..68a5c6a 100644
--- a/src/core/NEON/kernels/winograd/utils.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/tensor_utils.hpp
@@ -1,4 +1,3 @@
-
 /*
  * Copyright (c) 2017 ARM Limited.
  *
@@ -22,34 +21,23 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
 #pragma once
-#include <ctime>
+#include "tensor.hpp"
 
-inline double TimeInUs(void) {
-#ifdef CYCLE_PROFILING
-  timespec t;
-  clock_gettime(CLOCK_THREAD_CPUTIME_ID, &t);
-  return 1e6*t.tv_sec + 1e-3*t.tv_nsec;
-#else
-  return 0;
-#endif
-}
+// Methods to print tensors and weights
+void PrintTensor(const Tensor4D<Tensor4DShape, float>& tensor);
+void PrintWeights(const Tensor4D<KernelShape, float>& weights);
 
-inline int iceildiv(const int a, const int b) {
-  return (a + b - 1) / b;
-}
+// Test the equivalence of two tensors
+bool CmpTensors(const Tensor4D<Tensor4DShape, float>& a,
+                const Tensor4D<Tensor4DShape, float>& b,
+                const float max_delta=0.0f);
 
-template <typename T>
-inline T roundup(const T a, const T b) {
-  return a + b - (a % b);
-}
+// Fill the tensor with a test pattern
+void TestPattern(Tensor4D<Tensor4DShape, float>& tensor);
+void TestPattern(Tensor4D<KernelShape, float>& weights);
 
-inline void PrintMatrix(const float* const m, const int M, const int N, const int row_stride) {
-  for (int i = 0; i < M; i++) {
-    for (int j = 0; j < N; j++) {
-      printf("%.3f ", m[i*row_stride + j]);
-    }
-    printf("\n");
-  }
-  printf("\n");
-}
+// Fill the tensor with random values
+void Randomise(Tensor4D<Tensor4DShape, float>& tensor, const int seed=0);
+void Randomise(Tensor4D<KernelShape, float>& weights, const int seed=0);
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp
new file mode 100644
index 0000000..39b4441
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/input.hpp
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "../winograd_gemm.hpp"
+
+namespace winograd
+{
+  /***************************************************************************/
+  /* Instance-less API */
+  template <int output_tile_rows, int output_tile_cols,
+            int kernel_rows, int kernel_cols>
+  template <typename T>
+  void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::InputTransform<T>::execute(
+    const T *inptr,
+    const Tensor4DShape& input_shape,
+    const PaddingType padding_type,
+    const int tile_M,
+    const int tile_N,
+    T *outptr_base,
+    const int matrix_stride,
+    const int matrix_batch_stride,
+    const int matrix_row_stride
+  )
+  {
+    // Compute the padding required on each edge of the image
+    const bool base_padding = (padding_type == PADDING_SAME) ? 1 : 0;
+    const int pad_top = base_padding;
+    const int pad_left = base_padding;
+    const int tile_overlap = kernel_rows - 1;
+
+    // Compute striding values (assuming NHWC ordered data)
+    const int input_col_stride = input_shape.n_channels;
+    const int input_row_stride = input_shape.n_cols * input_col_stride;
+    const int input_batch_stride = input_shape.n_rows * input_row_stride;
+    const int output_col_stride = matrix_row_stride;
+    const int output_row_stride = tile_N * output_col_stride;
+
+    // Loop over batches
+    for (int batch = 0; batch < input_shape.n_batches; batch++)
+    {
+      // Pointer to the batch
+      const T* const input_base_batch = inptr + batch * input_batch_stride;
+      T* const outptr_base_batch = outptr_base + batch * matrix_batch_stride;
+
+      // Loop over rows of tiles
+      for (int tile_i = 0; tile_i < tile_M; tile_i++)
+      {
+        // Pointer to the row
+        const int row_offset = (tile_i == 0) ?
+          0 : ((padding_type == PADDING_VALID) ? 0 : 1);
+        const T* const input_base_row = (
+          input_base_batch + ((inner_tile_rows - 2)*tile_i - row_offset)*input_row_stride
+        );
+        T* const outptr_base_row = outptr_base_batch + tile_i*output_row_stride;
+
+        // Padding (top + bottom) for the row
+        const int row_top = tile_i*(inner_tile_rows - tile_overlap) - pad_top;
+        const int row_bottom = row_top + inner_tile_rows;
+        const int row_pad_top = (tile_i == 0) ? pad_top : 0;
+        const int row_pad_bottom = (row_bottom <= input_shape.n_rows) ? 0 : row_bottom - input_shape.n_rows;
+
+        // Process the row
+        process_tile_row(
+          tile_N, input_shape.n_channels,
+          input_base_row, input_row_stride, input_col_stride,
+          outptr_base_row, matrix_stride, matrix_row_stride,
+          row_pad_top, pad_left, row_pad_bottom, input_shape.n_cols
+        );
+      }
+    }
+  }
+
+  template <int output_tile_rows, int output_tile_cols,
+            int kernel_rows, int kernel_cols>
+  template <typename T>
+  void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::InputTransform<T>::process_tile_row(
+    const int tile_N,
+    int n_channels,
+    const T* const input_base,
+    const int input_row_stride,
+    const int input_col_stride,
+    T* const matrix_base,
+    const int matrix_stride,
+    const int matrix_row_stride,
+    const int pad_top,
+    const int row_pad_left,
+    const int pad_bottom,
+    const int n_cols
+  )
+  {
+    constexpr int tile_overlap = kernel_cols - 1;
+
+    // Loop over columns of tiles
+    for (int tile_j = 0; tile_j < tile_N; tile_j++)
+    {
+      // Padding (left + right) for the tile
+      const int t_pad_left = (tile_j == 0) ? row_pad_left : 0;
+      const int t_start = tile_j*(inner_tile_cols - tile_overlap) - row_pad_left;
+      const int t_end = t_start + inner_tile_cols;
+      const int t_pad_right = (t_end <= n_cols) ? 0 : t_end - n_cols;
+
+      // Get pointers into the inputs and outputs
+      const int col_offset = (tile_j == 0) ? 0 : row_pad_left;
+      const T* const input_base_col = (
+        input_base + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*input_col_stride
+      );
+      T* const outptr = matrix_base + tile_j*matrix_row_stride;
+
+      // Apply the specific tile processing function
+      tile_fns[pad_top][t_pad_left][pad_bottom][t_pad_right](
+        n_channels,
+        input_base_col,
+        input_row_stride,
+        input_col_stride,
+        outptr,
+        matrix_stride
+      );
+    }
+  }
+
+  /***************************************************************************/
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::InputTransform(
+    const T* const input,        /** Input tensor data */
+    const int n_batches,         /** Number of batches in input tensor. */
+    const int n_rows,            /** Number of rows in input tensor. */
+    const int n_cols,            /** Number of columns in input tensor. */
+    const int n_channels,        /** Number of channels in input tensor. */
+    const PaddingType padding,   /** Padding type. */
+    T* const output,             /** Base of output matrices. */
+    const int matrix_stride,     /** Stride between output matrices. */
+    const int matrix_row_stride  /** Stride within matrices. */
+  ) : _inptr(input), _outptr(output),
+      _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
+      _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
+      _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - 2, output_tile_rows)),
+      _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - 2, output_tile_cols)),
+      _padding_type(padding)
+  {
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  unsigned int WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::get_window() const
+  {
+    // TODO When the input transform supports multithreading, return the total
+    // number of tile rows (allowing for multiple batches). For now we return 1
+    // to indicate that the activations must be transformed as a single block.
+    return 1;  // TODO _tiles_M * _n_batches;
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  void WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::run(
+    const unsigned int start, const unsigned int stop
+  )
+  {
+    // TODO When the input transform supports multithreading call execute for a
+    // portion of the tile rows.
+    (void) start;
+    (void) stop;
+
+    // For now, just do all of the work.
+    const Tensor4DShape input_shape = {
+      _n_batches, _n_rows, _n_cols, _n_channels, NHWC
+    };
+    execute(
+      _inptr, input_shape, _padding_type, _tiles_M, _tiles_N, _outptr,
+      _matrix_stride, _matrix_row_stride * _tiles_M * _tiles_N, _matrix_row_stride
+    );
+  }
+}
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp
new file mode 100644
index 0000000..4b54dfd
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel.hpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "winograd_gemm.hpp"
+using namespace winograd;
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::WeightsTransform(
+  const T* const input,
+  T* const output,
+  const int matrix_stride,      /** Stride across matrices in the output. */
+  const int matrix_row_stride,  /** Stride across rows of the matrix. */
+  const int n_output_channels,
+  const int n_input_channels
+) : inptr(input), outptr(output),
+    matrix_stride(matrix_stride), matrix_row_stride(matrix_row_stride),
+    n_output_channels(n_output_channels), n_input_channels(n_input_channels)
+{
+}
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+unsigned int WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::get_window() const
+{
+  // TODO When the weights transform supports multithreading, return the number
+  // of output channels. For now we return 1 to indicate that the weights must
+  // be transformed as a single block.
+  // return n_output_channels;
+  return 1;
+}
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+void WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::run(
+  const unsigned int start, const unsigned int stop
+)
+{
+  // TODO When the weights transform supports multithreading call execute for a
+  // portion of the output channels.
+  (void) start;
+  (void) stop;
+
+  // For now, just do all of the work.
+  execute(
+    n_output_channels,
+    n_input_channels,
+    inptr,
+    outptr,
+    matrix_stride,
+    matrix_row_stride
+  );
+}
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp
new file mode 100644
index 0000000..7fa5ee9
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/output.hpp
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "../winograd_gemm.hpp"
+
+namespace winograd
+{
+  template <int output_tile_rows, int output_tile_cols,
+            int kernel_rows, int kernel_cols>
+  template <typename T>
+  void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::OutputTransform<T>::execute(
+    const Tensor4DShape &output_shape,
+    const T* const matrix_base,
+    const int matrix_stride,
+    const int matrix_row_stride,
+    T* const output
+  )
+  {
+    // Compute the number of tiles and hence the padding required on the bottom
+    // and right of the image.
+    const int tile_M = iceildiv(output_shape.n_rows, output_tile_rows);
+    const int tile_N = iceildiv(output_shape.n_cols, output_tile_cols);
+    const int pad_bottom = output_tile_rows*tile_M - output_shape.n_rows;
+    const int pad_right = output_tile_cols*tile_N - output_shape.n_cols;
+
+    const int matrix_tile_row_stride = tile_N * matrix_row_stride;
+    const int matrix_batch_stride = tile_M * matrix_tile_row_stride;
+    const int output_col_stride = output_shape.n_channels;
+    const int output_row_stride = output_shape.n_cols * output_col_stride;
+    const int output_batch_stride = output_shape.n_rows * output_row_stride;
+
+    // Perform the output transformation for each batch
+    for (int batch = 0; batch < output_shape.n_batches; batch++)
+    {
+      // Get batch offset for input and outputs.
+      const T* const matrix_batch = matrix_base + batch*matrix_batch_stride;
+      T* const outptr_batch = output + batch*output_batch_stride;
+
+      // Perform the output transformation for each row of the output tensor.
+      for (int tile_i = 0; tile_i < tile_M; tile_i++)
+      {
+        // Compute properties of this row of output tiles
+        const int row_pad_bottom = (tile_i < tile_M - 1) ? 0: pad_bottom;
+        const T* const matrix_tile_row = matrix_batch + tile_i * matrix_tile_row_stride;
+        T* const outptr_row = outptr_batch + output_tile_rows*tile_i*output_row_stride;
+
+        // Process the row
+        process_tile_row(
+          tile_N, output_shape.n_channels, matrix_tile_row, matrix_stride,
+          matrix_row_stride, outptr_row, output_row_stride,
+          output_col_stride, row_pad_bottom, pad_right
+        );
+      }
+    }
+  }
+
+  template <int output_tile_rows, int output_tile_cols,
+            int kernel_rows, int kernel_cols>
+  template <typename T>
+  void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::OutputTransform<T>::process_tile_row(
+    const int tile_N,
+    const int n_channels,
+    const T* const matrix_base,
+    const int matrix_stride,
+    const int matrix_row_stride,
+    T* const output,
+    const int output_row_stride,
+    const int output_col_stride,
+    const int row_pad_bottom,
+    const int row_pad_right
+  )
+  {
+    // Loop over columns of tiles
+    for (int tile_j = 0; tile_j < tile_N; tile_j++)
+    {
+      // Properties of this tile
+      const int tile_pad_right = (tile_j < tile_N - 1) ? 0 : row_pad_right;
+      const T* const matrix_row = matrix_base + tile_j * matrix_row_stride;
+      T* const outptr = output + output_tile_cols*tile_j*output_col_stride;
+
+      // Perform the output transformation
+      tile_fns[row_pad_bottom][tile_pad_right](
+        n_channels, matrix_row, matrix_stride,
+        outptr, output_row_stride, output_col_stride
+      );
+    }
+  }
+
+  template <int output_tile_rows, int output_tile_cols, int kr, int kc>
+  template <typename T>
+  size_t WinogradGEMM<output_tile_rows, output_tile_cols, kr, kc>::OutputTransform<T>::bytes_read(const Tensor4DShape &shape)
+  {
+    const int M = iceildiv(shape.n_rows, output_tile_rows) *
+                  iceildiv(shape.n_cols, output_tile_cols);
+    const int N = shape.n_channels;
+    return inner_tile_rows * inner_tile_cols * M * N * sizeof(T);
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  size_t WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::bytes_written(const Tensor4DShape &shape)
+  {
+    return shape.size() * sizeof(T);
+  }
+
+  template <int output_tile_rows, int output_tile_cols, int kr, int kc>
+  template <typename T>
+  WinogradGEMM<output_tile_rows, output_tile_cols, kr, kc>::OutputTransform<T>::OutputTransform(
+    const T* const matrix_base,
+    const int matrix_stride,
+    const int matrix_row_stride,
+    T* const output,
+    const int n_batches,
+    const int n_rows,
+    const int n_cols,
+    const int n_channels
+  ) : _matrix_base(matrix_base), _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
+      _outptr(output), _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
+      _tile_M(iceildiv(n_rows, output_tile_rows)), _tile_N(iceildiv(n_cols, output_tile_cols))
+  {
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  unsigned int WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::get_window() const
+  {
+    // TODO When the output transform supports multithreading, return the total
+    // number of tile rows (allowing for multiple batches). For now we return 1
+    // to indicate that the activations must be transformed as a single block.
+    return 1;  // TODO _tile_M * _n_batches;
+  }
+
+  template <int otr, int otc, int kr, int kc>
+  template <typename T>
+  void WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::run(
+    const unsigned int start, const unsigned int stop
+  )
+  {
+    // TODO When the output transform supports multithreading call execute for a
+    // portion of the tile rows.
+    (void) start;
+    (void) stop;
+
+    // For now, just do all of the work.
+    const Tensor4DShape output_shape = {
+      _n_batches, _n_rows, _n_cols, _n_channels, NHWC
+    };
+    execute(
+      output_shape, _matrix_base, _matrix_stride, _matrix_row_stride, _outptr
+    );
+  }
+}  // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/transforms.hpp b/arm_compute/core/NEON/kernels/winograd/utils.hpp
similarity index 80%
copy from src/core/NEON/kernels/winograd/transforms.hpp
copy to arm_compute/core/NEON/kernels/winograd/utils.hpp
index 8546ee9..d8b9c3b 100644
--- a/src/core/NEON/kernels/winograd/transforms.hpp
+++ b/arm_compute/core/NEON/kernels/winograd/utils.hpp
@@ -24,6 +24,14 @@
 
 #pragma once
 
-#include "transforms/input_2x2_3x3.hpp"
-#include "transforms/kernel_2x2_3x3.hpp"
-#include "transforms/output_2x2_3x3.hpp"
+double TimeInUs(void);
+void PrintMatrix(const float* const m, const int M, const int N, const int row_stride);
+
+inline int iceildiv(const int a, const int b) {
+  return (a + b - 1) / b;
+}
+
+template <typename T>
+inline T roundup(const T a, const T b) {
+  return a + b - (a % b);
+}
diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
new file mode 100644
index 0000000..adca48a
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp
@@ -0,0 +1,441 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "alloc.hpp"
+#include "convolution.hpp"
+#include "gemm.hpp"
+#include "profiler.hpp"
+#include "shims.hpp"
+#include "tensor.hpp"
+#include "utils.hpp"
+
+#include <thread>
+#include <utility>
+#include <vector>
+
+// Generic Winograd implementation using GEMM
+namespace winograd
+{
+
+template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+class WinogradGEMM
+{
+  public:
+    // Information about the specific Winograd instance
+    static constexpr int output_tile_rows = OutputTileRows;
+    static constexpr int output_tile_cols = OutputTileCols;
+    static constexpr int kernel_rows = KernelRows;
+    static constexpr int kernel_cols = KernelCols;
+    static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1;  // TODO Check
+    static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1;  // TODO Check
+    static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols;
+
+    /** Transform weights from the spatial to the Winograd domain. */
+    template <typename T>
+    struct WeightsTransform
+    {
+      /** Get the bytes read during the transform. */
+      static inline size_t bytes_read(const KernelShape &shape)
+      {
+        return shape.size() * sizeof(T);
+      }
+
+      /** Get the bytes written during the transform. */
+      static inline size_t bytes_written(const KernelShape &shape)
+      {
+        const int inner_tile_size = inner_tile_rows * inner_tile_cols;
+        return (inner_tile_size * shape.n_input_channels *
+                shape.n_output_channels * sizeof(T));
+      }
+
+      /** Get the count of operations performed by the transform. */
+      static int ops_performed(const KernelShape &shape);
+
+      /** Apply the transform to a tensor. */
+      static void execute(
+        const int n_output_channels,
+        const int n_input_channels,
+        const T* const input,
+        T* const output,
+        const int matrix_stride,
+        const int matrix_row_stride
+      );
+
+      /** Create a WeightsTransform operator fixed on a given problem and set
+       * of pointers.
+       */
+      WeightsTransform(
+        const T* const input,
+        T* const output,
+        const int matrix_stride,       /** Stride across matrices in the output. */
+        const int matrix_row_stride,   /** Stride across rows of the matrix. */
+        const int n_output_channels,   /** Number of filters. */
+        const int n_input_channels     /** Number of channels in each filter. */
+      );
+
+      /** Get the window of work a given operator can perform. */
+      unsigned int get_window() const;
+
+      /** Perform work upon a window of the input. */
+      void run(const unsigned int start, const unsigned int stop);
+
+      private:
+        const T* const inptr;         /** Fixed pointer to input data. */
+        T* const outptr;              /** Fixed pointer to output memory. */
+        const int matrix_stride;      /** Stride between output matrices. */
+        const int matrix_row_stride;  /** Stride within output matrices. */
+        const int n_output_channels;  /** Number of filters. */
+        const int n_input_channels;   /** Number of channels in each filter. */
+    };
+
+    /** Transform input feature maps from the spatial to the Winograd domain.
+     */
+    template <typename T>
+    struct InputTransform
+    {
+      /** Get the bytes read during the transform. */
+      static size_t bytes_read(const Tensor4DShape &shape)
+      {
+        return shape.size() * sizeof(T);
+      }
+
+      /** Get the bytes written during the transform. */
+      static size_t bytes_written(const Tensor4DShape &shape)
+      {
+        const int M = iceildiv(shape.n_rows, inner_tile_rows) *
+                      iceildiv(shape.n_cols, inner_tile_cols);
+        const int K = shape.n_channels;
+        return inner_tile_rows * inner_tile_cols * M * K * sizeof(T);
+      }
+
+      /** Get the count of operations performed by the transform. */
+      static int ops_performed(const Tensor4DShape &shape);
+
+      /** Apply the transform to a tensor. */
+      static void execute(
+          const T *inptr,
+          const Tensor4DShape& input_shape,
+          const PaddingType padding_type,
+          const int tile_M,
+          const int tile_N,
+          T *outptr_base,
+          const int matrix_stride,
+          const int matrix_batch_stride,
+          const int matrix_row_stride
+      );
+
+      /***********************************************************************/
+      /** Create an InputTransform operator fixed on a given problem and set of
+       * pointers.
+       */
+      InputTransform(
+          const T* const input,        /** Input tensor data */
+          const int n_batches,         /** Number of batches in input tensor. */
+          const int n_rows,            /** Number of rows in input tensor. */
+          const int n_cols,            /** Number of columns in input tensor. */
+          const int n_channels,        /** Number of channels in input tensor. */
+          const PaddingType padding,   /** Padding type. */
+          T* const output,             /** Base of output matrices. */
+          const int matrix_stride,     /** Stride between output matrices. */
+          const int matrix_row_stride  /** Stride within matrices. */
+      );
+
+      /** Get the winodw of work a given operator can perform. */
+      unsigned int get_window() const;
+
+      /** Perform work upon a window of the input. */
+      void run(const unsigned int start, const unsigned int stop);
+      /***********************************************************************/
+
+      private:
+        static void process_tile_row(
+          const int tile_N,
+          int n_channels,
+          const T* const input_base,
+          const int input_row_stride,
+          const int input_col_stride,
+          T* const matrix_base,
+          const int matrix_stride,
+          const int matrix_row_stride,
+          const int row_pad_top,
+          const int row_pad_left,
+          const int row_pad_bottom,
+          const int row_pad_right
+        );
+
+        static constexpr int max_pad_bottom = inner_tile_rows - 1;
+        static constexpr int max_pad_right = inner_tile_cols - 1;
+
+        /** Process a single tile of the input tensor. */
+        template <int pad_top, int pad_left, int pad_bottom, int pad_right>
+        static void process_tile(int, const T*, int, int, T*, int);
+
+        // Array of methods to transform tiles of the input tensor.
+        typedef void (*TileFn)(int, const T*, int, int, T*, int);
+        static const TileFn tile_fns[2][2][max_pad_bottom][max_pad_right];
+
+        /* Member values for instance-based API. */
+        const T* const _inptr;
+        T* const _outptr;
+        const int _n_batches, _n_rows, _n_cols, _n_channels, _matrix_stride,
+                  _matrix_row_stride, _tiles_M, _tiles_N;
+        const PaddingType _padding_type;
+    };
+
+    /** Transform output feature maps from the Winograd to the spatial domain.
+     */
+    template <typename T>
+    struct OutputTransform
+    {
+      /** Get the bytes read during the transform. */
+      static size_t bytes_read(const Tensor4DShape &shape);
+
+      /** Get the bytes written during the transform. */
+      static size_t bytes_written(const Tensor4DShape &shape);
+
+      /** Get the count of operations performed by the transform. */
+      static int ops_performed(const Tensor4DShape &shape);
+
+      /** Apply the transform to create a tensor. */
+      static void execute(
+        const Tensor4DShape &output_shape,
+        const T* const matrix_base,
+        const int matrix_stride,
+        const int matrix_row_stride,
+        T* const output
+      );
+
+      /***********************************************************************/
+      /** Create an OutputTransform operator fixed on a given problem and set
+       * of pointers.
+       */
+      OutputTransform(
+        const T* const matrix_base,   /** Pointer to base of matrices. */
+        const int matrix_stride,      /** Stride between matrices. */
+        const int matrix_row_stride,  /** Stride within a matrix. */
+        T* const output,              /** Pointer to output tensor. */
+        const int n_batches,          /** Number of batches in output tensor. */
+        const int n_rows,             /** Number of rows in output tensor. */
+        const int n_cols,             /** Number of columns in output tensor. */
+        const int n_channels          /** Number of channels in output tensor. */
+      );
+
+      /** Get the window of work a given operator can perform. */
+      unsigned int get_window() const;
+
+      /** Perform work upon a window of the input. */
+      void run(const unsigned int start, const unsigned int stop);
+      /***********************************************************************/
+
+      private:
+        static void process_tile_row(
+          const int tile_N,
+          const int n_channels,
+          const T* const matrix_base,
+          const int matrix_stride,
+          const int matrix_row_stride,
+          T* const output,
+          const int output_row_stride,
+          const int output_col_stride,
+          const int row_pad_bottom,
+          const int row_pad_right
+        );
+
+        // Limits on the amount of anti-padding to be applied
+        static constexpr int max_pad_bottom = output_tile_rows;
+        static constexpr int max_pad_right = output_tile_cols;
+
+        /** Prepare a single tile of the output tensor. */
+        template <int pad_bottom, int pad_right>
+        static void process_tile(int, const T*, int, T*, int, int);
+
+        // Array of methods to produce tiles of output tensor.
+        typedef void (*TileFn)(int, const T*, int, T*, int, int);
+        static const TileFn tile_fns[max_pad_bottom][max_pad_right];
+
+        /** Member constants for instances of the transform. */
+        const T* const _matrix_base;
+        const int _matrix_stride, _matrix_row_stride;
+        T* const _outptr;
+        const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N;
+    };
+
+    /** Perform a convolution.
+     */
+    template <typename TOut, typename TIn>
+    class Convolution
+    {
+      public:
+        // Information about the typed Winograd instance
+        typedef TOut OutputType;
+        typedef TIn InputType;
+
+        /** Create a new Winograd operator. */
+        Convolution(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding,
+          void *kernel_storage=NULL
+        );
+
+        Convolution(const Convolution&) = delete;
+        Convolution operator=(const Convolution&) = delete;
+
+        /** Create a new Winograd operator and initialise the weights. */
+        Convolution(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding,
+          const TIn* const kernel,
+          void *kernel_storage=NULL,
+          void *transform_working_space=NULL
+        );
+
+        /** Clean up a convolution engine. */
+        ~Convolution();
+
+        /** Transform the weights into the Winograd domain. */
+        template <typename WeightsTransform=WeightsTransform<TIn>>
+        void transform_weights(
+          const TIn* const kernel,
+          void *transform_working_space=NULL
+        );
+
+        /* Apply the Winograd operator to some input. */
+        void execute(
+          TOut* const output,
+          const TIn* const input,
+          void* working_space=NULL,
+          const int n_threads=1
+        );
+
+        /* Apply the Winograd operator to some input. */
+        void execute(
+          TOut* const output,
+          const TIn* const input,
+          const int n_threads
+        );
+
+        /** Get the output shape of a convolution. */
+        static Tensor4DShape get_output_shape(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &in_shape,
+          const PaddingType padding
+        );
+
+        /* Get the memory required to transform the kernel.
+         */
+        static size_t get_kernel_transform_working_size(const KernelShape &shape);
+
+        /** Get the memory required to store the kernel transformed into the
+         * Winograd domain.
+         */
+        static size_t get_kernel_storage_size(const KernelShape &shape);
+
+        /** Get the memory required to store the input tensor transformed into
+         * the Winograd domain.
+         */
+        static size_t get_input_storage_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /** Get the memory required to store the output tensor in the Winograd
+         * domain.
+         */
+        static size_t get_output_storage_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /** Get the memory required to apply a Winograd operator to some input.
+         */
+        static size_t get_working_space_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /* Get the memory required by a single "input" matrix.
+         */
+        static size_t get_input_matrix_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        static int get_input_matrix_stride(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /* Get the memory required by a single "output" matrix.
+         */
+        static size_t get_output_matrix_size(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        static int get_output_matrix_stride(
+          const KernelShape &kernel_shape,
+          const Tensor4DShape &input_shape,
+          const PaddingType padding_type
+        );
+
+        /* Get the memory required by a single "kernel" matrix.
+         */
+        static size_t get_kernel_matrix_size(const KernelShape &shape);
+        static int get_kernel_matrix_stride(const KernelShape &shape);
+
+        static constexpr int M_BLOCK = 4;   /** Size of block used by GEMM. */
+        static constexpr int N_BLOCK = 16;  /** Size of block used by GEMM. */
+
+      private:
+        const KernelShape kernel_shape;  /** Shape of the kernel to be applied. */
+        TIn *kernel_matrices[N_GEMMS];   /** Pointers into the kernel matrices. */
+        const int kernel_matrix_row_stride;  /** Stride within the kernel matrices. */
+
+        const bool manage_kernel_storage;  /** Kernel storage is managed by the instance. */
+        void* const _kernel_storage;       /** Base pointer for kernel storage. */
+
+        const Tensor4DShape input_shape;  /** Shape of the input tensor. */
+        const PaddingType padding;        /** Padding applied by the operator. */
+
+        const Tensor4DShape output_shape;  /** Output shape produced by the operator. */
+
+        const int tile_rows;  /** Number of rows of tiles. */
+        const int tile_cols;  /** Number of columns of tiles. */
+        const int M, K, N;    /** Sizes of underlying fundamental matrix multiplications. */
+
+        profiler prof;
+    };
+};
+
+}  // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp b/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp
new file mode 100644
index 0000000..a3b3db4
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include <utility>
+
+#include "batched_blocked_gemm.hpp"
+#include "winograd_gemm.hpp"
+
+/** Example of how to construct an ACL-like interface.
+ *
+ * Use `get_weight_storage_size`, `get_input_storage_size` and
+ * `get_output_storage_size` to allocate memory for the convolution engine.
+ * Then create a `WinogradConvolutionLayer`.
+ *
+ * Initialise the weights using `weights_transform.run(...)`.
+ *
+ * For each inference:
+ *   1. Transform the inputs to the Winograd domain using `input_transform.run(...)`
+ *   2. Perform a number of GEMMs using `gemms.run(...)`
+ *   3. Transform the output to the spatial domain using `output_transform.run(...)`
+ */
+template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols,
+          typename TIn, typename TOut>
+class WinogradConvolutionLayer
+{
+  private:
+    const KernelShape _kernel_shape;
+    const Tensor4DShape _input_shape;
+    const PaddingType _padding;
+    const Tensor4DShape _output_shape;
+    const int _n_output_rows, _n_output_cols;
+    const int _kernel_matrix_stride, _kernel_matrix_row_stride;
+    const int _input_matrix_stride, _input_matrix_row_stride;
+    const int _output_matrix_stride, _output_matrix_row_stride;
+    const int _tile_rows, _tile_cols;
+    const int _m, _k, _n;
+
+  public:
+    using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols>;
+    using WeightsTransform = typename WinogradBase::template WeightsTransform<TIn>;
+    using InputTransform = typename WinogradBase::template InputTransform<TIn>;
+    using WinogradConv = typename WinogradBase::template Convolution<TOut, TIn>;
+    using MultiGEMM = winograd::BatchedBlockedGemm<WinogradConv::M_BLOCK, WinogradConv::N_BLOCK, TIn, TOut>;
+    using OutputTransform = typename WinogradBase::template OutputTransform<TOut>;
+
+    /* Public member variables. */
+    WeightsTransform weights_transform;  /** Operator to transform weights to Winograd domain. */
+    InputTransform input_transform;      /** Operator to transform input to Winograd domain. */
+    MultiGEMM gemms;                     /** Operator to perform multiple GEMMs. */
+    OutputTransform output_transform;    /** Operator to transform output from Winograd domain. */
+
+    /** Determine how much memory (in units of TIn) to allocate for the
+     * transformed weights.
+     */
+    static unsigned int get_weight_storage_size(
+      const int n_output_channels,  /** Number of output feature maps. */
+      const int n_input_channels    /** Number of input feature maps. */
+    );
+
+    /** Determine how much memory (in units of TIn) to allocate for the
+     * transformed input.
+     */
+    static unsigned int get_input_storage_size(
+      const int n_batches,     /** Number of batches in the input tensor. */
+      const int n_channels,    /** Number of feature maps in the input tensor. */
+      const int n_rows,        /** Number of rows in each feature map. */
+      const int n_cols,        /** Number of columns in each feature map. */
+      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Determine how much memory (in units of TOut) to allocate for the
+     * (Winograd domain) output.
+     */
+    static unsigned int get_output_storage_size(
+      const int n_batches,          /** Number of batches in the output tensor. */
+      const int n_rows,             /** Number of rows in each feature map of the input tensor. */
+      const int n_cols,             /** Number of columns in each feature map of the input tensor. */
+      const int n_output_channels,  /** Number of feature maps in the output tensor. */
+      const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Get the shape (rows, cols) of a feature map of the output tensor. */
+    static std::pair<int, int> get_output_feature_map_shape(
+      const int n_input_rows,  /** Number of rows in the input feature map. */
+      const int n_input_cols,  /** Number of columns in the input feature map. */
+      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+    );
+
+    /** Create a new Winograd convolution layer.
+     */
+    WinogradConvolutionLayer(
+      const int n_batches,          /** Number of batches in the input and output tensors. */
+      const int n_input_channels,   /** Number of feature maps in a batch of the input tensor. */
+      const int n_input_rows,       /** Number of rows in a feature map of the input tensor. */
+      const int n_input_cols,       /** Number of columns in a feature map of the input tensor. */
+      const int n_output_channels,  /** Number of feature maps in the output tensor. */
+      const bool same_padding,      /** Use "SAME" padding, otherwise use "VALID". */
+      const TIn* const weights,     /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */
+      TIn* const weights_storage,   /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
+      const TIn* const input,       /** Pointer to NHWC ordered input tensor, in the spatial domain. */
+      TIn* const winograd_input,    /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */
+      TOut* const output,           /** Pointer to NHWC ordered output tensor, in the spatial domain. */
+      TOut* const winograd_output   /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
+    );
+};
diff --git a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h
index 6fecf08..60cdc97 100644
--- a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -73,7 +73,8 @@
     CPPPermute                      _permute_input;
     CPPPermute                      _permute_weights;
     CPPPermute                      _permute_output;
-    Tensor                          _workspace;
+    Tensor                          _input_workspace;
+    Tensor                          _output_workspace;
     Tensor                          _kernel_storage;
     Tensor                          _input_nhwc;
     Tensor                          _output_nhwc;
diff --git a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp
index d17630a..24d72ed 100644
--- a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,11 +29,11 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "support/ToolchainSupport.h"
 
-#include "src/core/NEON/kernels/winograd/winograd_gemm.hpp"
+#include "arm_compute/core/NEON/kernels/winograd/winograd_layer.hpp"
 
 namespace
 {
-using T = winograd::Winograd2x2_3x3GEMM<float, float>;
+using T = WinogradConvolutionLayer<2, 2, 3, 3, float, float>;
 } // namespace
 
 namespace arm_compute
@@ -41,11 +41,23 @@
 class Winograd3x3F32::Private
 {
 public:
-    Private(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage)
-        : convolver(kernel_shape, input_shape, padding_type, kernel_storage)
+    Private(
+        const int          n_batches,         /** Number of batches in the input and output tensors. */
+        const int          n_input_channels,  /** Number of feature maps in a batch of the input tensor. */
+        const int          n_input_rows,      /** Number of rows in a feature map of the input tensor. */
+        const int          n_input_cols,      /** Number of columns in a feature map of the input tensor. */
+        const int          n_output_channels, /** Number of feature maps in the output tensor. */
+        const bool         same_padding,      /** Use "SAME" padding, otherwise use "VALID". */
+        const float *const weights,           /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */
+        float *const       weights_storage,   /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
+        const float *const input,             /** Pointer to NHWC ordered input tensor, in the spatial domain. */
+        float *const       winograd_input,    /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */
+        float *const       output,            /** Pointer to NHWC ordered output tensor, in the spatial domain. */
+        float *const       winograd_output    /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
+    )
+        : convolver(n_batches, n_input_channels, n_input_rows, n_input_cols, n_output_channels, same_padding, weights, weights_storage, input, winograd_input, output, winograd_output)
     {
     }
-
     T convolver;
 };
 
@@ -53,46 +65,62 @@
 {
 }
 
-void Winograd3x3F32::transform_weights(const void *const kernel, void *transform_working_space)
+void Winograd3x3F32::transform_output()
 {
-    _pimpl->convolver.transform_weights(reinterpret_cast<const float *>(kernel), transform_working_space);
+    auto win = _pimpl->convolver.output_transform.get_window();
+    _pimpl->convolver.output_transform.run(0, win);
 }
 
-void Winograd3x3F32::reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const void *const input, void *working_space)
+void Winograd3x3F32::transform_input()
 {
-    _pimpl->convolver.reshape_input(input_shape, padding_type, reinterpret_cast<const float *>(input), working_space);
+    auto win = _pimpl->convolver.input_transform.get_window();
+    _pimpl->convolver.input_transform.run(0, win);
 }
 
-void Winograd3x3F32::reshape_output(const Tensor4DShape &input_shape, const PaddingType padding_type, void *const output)
+void Winograd3x3F32::transform_weights()
 {
-#if defined(__aarch64__)
-    _pimpl->convolver.reshape_output(input_shape, padding_type, reinterpret_cast<float *const>(output));
-#else  /* __aarch64__ */
-    ARM_COMPUTE_UNUSED(input_shape);
-    ARM_COMPUTE_UNUSED(padding_type);
-    ARM_COMPUTE_UNUSED(output);
-    ARM_COMPUTE_ERROR("Not implemented");
-#endif /* __aarch64__ */
+    auto win = _pimpl->convolver.weights_transform.get_window();
+    _pimpl->convolver.weights_transform.run(0, win);
 }
 
-Winograd3x3F32::Winograd3x3F32(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage)
-    : _pimpl(support::cpp14::make_unique<Private>(kernel_shape, input_shape, padding_type, kernel_storage))
+Winograd3x3F32::Winograd3x3F32(
+    const int          n_batches,         /** Number of batches in the input and output tensors. */
+    const int          n_input_channels,  /** Number of feature maps in a batch of the input tensor. */
+    const int          n_input_rows,      /** Number of rows in a feature map of the input tensor. */
+    const int          n_input_cols,      /** Number of columns in a feature map of the input tensor. */
+    const int          n_output_channels, /** Number of feature maps in the output tensor. */
+    const bool         same_padding,      /** Use "SAME" padding, otherwise use "VALID". */
+    const float *const weights,           /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */
+    float *const       weights_storage,   /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
+    const float *const input,             /** Pointer to NHWC ordered input tensor, in the spatial domain. */
+    float *const       winograd_input,    /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */
+    float *const       output,            /** Pointer to NHWC ordered output tensor, in the spatial domain. */
+    float *const       winograd_output    /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
+)
+    : _pimpl(support::cpp14::make_unique<Private>(n_batches, n_input_channels, n_input_rows, n_input_cols, n_output_channels, same_padding, weights, weights_storage, input, winograd_input, output,
+                                                  winograd_output))
 {
 }
 
-size_t NEWinogradLayerKernel::get_kernel_storage_size(const KernelShape &shape)
+unsigned int NEWinogradLayerKernel::get_input_storage_size(const int n_batches, const int n_channels, const int n_rows, const int n_cols, const bool same_padding)
 {
-    return T::get_kernel_storage_size(shape);
+    return T::get_input_storage_size(n_batches, n_channels, n_rows, n_cols, same_padding);
 }
 
-size_t NEWinogradLayerKernel::get_working_space_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, const PaddingType padding)
+unsigned int NEWinogradLayerKernel::get_output_storage_size(
+    const int  n_batches,         /** Number of batches in the output tensor. */
+    const int  n_rows,            /** Number of rows in each feature map of the input tensor. */
+    const int  n_cols,            /** Number of columns in each feature map of the input tensor. */
+    const int  n_output_channels, /** Number of feature maps in the output tensor. */
+    const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
+)
 {
-    return T::get_working_space_size(input_shape, k_shape, padding);
+    return T::get_output_storage_size(n_batches, n_rows, n_cols, n_output_channels, same_padding);
 }
 
-size_t NEWinogradLayerKernel::get_kernel_transform_working_size(const KernelShape &shape)
+size_t NEWinogradLayerKernel::get_weight_storage_size(const int n_output_channels, const int n_input_channels)
 {
-    return T::get_kernel_transform_working_size(shape);
+    return T::get_weight_storage_size(n_output_channels, n_input_channels);
 }
 
 NEWinogradLayerKernel::NEWinogradLayerKernel()
@@ -105,7 +133,8 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(convolver);
     _convolver = convolver;
     Window win;
-    win.set(Window::DimX, Window::Dimension(0, 15, 1));
+    auto   win_last = _convolver->_pimpl->convolver.gemms.get_window();
+    win.set(Window::DimX, Window::Dimension(0, win_last, 1));
     INEKernel::configure(win);
 }
 
@@ -115,6 +144,6 @@
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     const size_t first_gemm = window.x().start();
     const size_t last_gemm  = window.x().end();
-    _convolver->_pimpl->convolver.execute(first_gemm, last_gemm);
+    _convolver->_pimpl->convolver.gemms.run(first_gemm, last_gemm);
 }
 } // namespace arm_compute
diff --git a/src/core/NEON/kernels/winograd/batched_blocked_gemm.cpp b/src/core/NEON/kernels/winograd/batched_blocked_gemm.cpp
new file mode 100644
index 0000000..52c2db8
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/batched_blocked_gemm.cpp
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "batched_blocked_gemm.hpp"
+#include "gemm.hpp"
+using namespace winograd;
+
+template <const int MB, const int NB, typename TIn, typename TOut>
+BatchedBlockedGemm<MB, NB, TIn, TOut>::BatchedBlockedGemm(
+  const unsigned int n_gemms,
+  const int M, const int K, const int N,
+  const int a_matrix_stride,
+  const int a_row_stride,
+  const int b_matrix_stride,
+  const int b_row_stride,
+  const int c_matrix_stride,
+  const int c_row_stride,
+  const TIn* const a_ptr,
+  const TIn* const b_ptr,
+  TOut* const c_ptr
+) : n_gemms(n_gemms), M(M), N(N), K(K),
+    a_matrix_stride(a_matrix_stride),
+    a_row_stride(a_row_stride),
+    b_matrix_stride(b_matrix_stride),
+    b_row_stride(b_row_stride),
+    c_matrix_stride(c_matrix_stride),
+    c_row_stride(c_row_stride),
+    a_ptr(a_ptr), b_ptr(b_ptr), c_ptr(c_ptr)
+{
+}
+
+template <const int MBlock, const int NBlock, typename TIn, typename TOut>
+unsigned int BatchedBlockedGemm<MBlock, NBlock, TIn, TOut>::get_window() const
+{
+  return n_gemms;
+}
+
+template <const int MBlock, const int NBlock, typename TIn, typename TOut>
+void BatchedBlockedGemm<MBlock, NBlock, TIn, TOut>::run(
+  const unsigned int start, const unsigned int stop
+)
+{
+  // Perform the specified GEMMs
+  for (unsigned int i = start; i < stop; i++)
+  {
+    // Get pointers to the relevant matrices
+    const TIn* const mtr_a = a_ptr + i*a_matrix_stride;
+    const TIn* const mtr_b = b_ptr + i*b_matrix_stride;
+    TOut* const mtr_c = c_ptr + i*c_matrix_stride;
+
+    // Perform the GEMM
+    BlockedGemm<MBlock, NBlock, TIn, TOut>(
+      mtr_a, mtr_b, mtr_c, M, K, N,
+      a_row_stride, b_row_stride, c_row_stride
+    );
+  }
+}
+
+template class winograd::BatchedBlockedGemm<4, 16, float, float>;
+
diff --git a/src/core/NEON/kernels/winograd/perf.h b/src/core/NEON/kernels/winograd/perf.h
deleted file mode 100644
index 11fb0c4..0000000
--- a/src/core/NEON/kernels/winograd/perf.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-/* Prototypes from perf.c */
-
-void start_counter(int fd);
-long long get_counter(int fd);
-long long stop_counter(int fd);
-int open_instruction_counter(void);
-int open_cycle_counter(void);
diff --git a/src/core/NEON/kernels/winograd/shims.hpp b/src/core/NEON/kernels/winograd/shims.hpp
deleted file mode 100644
index 249e575..0000000
--- a/src/core/NEON/kernels/winograd/shims.hpp
+++ /dev/null
@@ -1,319 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-
-/** Re-order a weight tensor from [Output feature map x Input feature map x
- *  Height x Width] format to [Height x Width x Input feature map x Output
- *  feature map] format.
- */
-template <typename T>
-inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
-  const T* const in,  // Input in [Output x Input x Height x Width] form
-  T* const out,       // Output in [Height x Width x Input x Output] form
-  const int n_output_feature_maps,
-  const int n_input_feature_maps,
-  const int n_rows,
-  const int n_cols,
-  int in_output_feature_map_stride=0,
-  int in_input_feature_map_stride=0,
-  int in_row_stride=0,
-  int out_row_stride=0,
-  int out_col_stride=0,
-  int out_input_feature_map_stride=0
-);
-
-/** Re-order a weight tensor from [Height x Width x Input feature map x Output
- *  feature map] format to [Output feature map x Input feature map x Height x
- *  Width] format.
- */
-template <typename T>
-inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
-  const T* const in,  // Input in [Height x Width x Input x Output] form
-  T* const out,       // Output in [Output x Input x Height x Width] form
-  const int n_rows,
-  const int n_cols,
-  const int n_input_feature_maps,
-  const int n_output_feature_maps,
-  int in_row_stride=0,
-  int in_col_stride=0,
-  int in_input_feature_map_stride=0,
-  int out_output_feature_map_stride=0,
-  int out_input_feature_map_stride=0,
-  int out_row_stride=0
-);
-
-
-/* Re-order a tensor from NCHW format to NHWC.
- */
-template <typename T>
-inline void nchw_to_nhwc(
-  const T* const in,
-  T* const out,
-  const int n_batches,
-  const int n_channels,
-  const int n_rows,
-  const int n_cols,
-  int in_batch_stride=0,
-  int in_channel_stride=0,
-  int in_row_stride=0,
-  int out_batch_stride=0,
-  int out_row_stride=0,
-  int out_col_stride=0
-)
-{
-  // Fill in the stride values
-  in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
-  in_channel_stride = (in_channel_stride) ? in_channel_stride
-                                          : n_rows * in_row_stride;
-  in_batch_stride = (in_batch_stride) ? in_batch_stride
-                                      : n_channels * in_channel_stride;
-
-  out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
-  out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
-  out_batch_stride = (out_batch_stride) ? out_batch_stride
-                                        : n_rows * out_row_stride;
-
-  // Perform the re-ordering
-  for (int n = 0; n < n_batches; n++)
-  {
-    const T* const in_batch = in + n*in_batch_stride;
-    T* const out_batch = out + n*out_batch_stride;
-
-    for (int i = 0; i < n_rows; i++)
-    {
-      const T* const in_row = in_batch + i*in_row_stride;
-      T* const out_row = out_batch + i*out_row_stride;
-
-      for (int j = 0; j < n_cols; j++)
-      {
-        const T* const in_col = in_row + j;
-        T* const out_col = out_row + j*out_col_stride;
-
-        for (int c = 0; c < n_channels; c++)
-        {
-          const T* const in_channel = in_col + c*in_channel_stride;
-          out_col[c] = *(in_channel);
-        }
-      }
-    }
-  }
-}
-
-/* Re-order a tensor from NHWC format to NCHW.
- */
-template <typename T>
-inline void nhwc_to_nchw(
-  const T* const in,  // Input data in NHWC form
-  T* const out,       // Output data in NCHW form
-  const int n_batches,
-  const int n_rows,
-  const int n_cols,
-  const int n_channels,
-  int in_batch_stride=0,
-  int in_row_stride=0,
-  int in_col_stride=0,
-  int out_batch_stride=0,
-  int out_channel_stride=0,
-  int out_row_stride=0
-)
-{
-  // Fill in stride values
-  in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
-  in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
-  in_batch_stride = (in_batch_stride) ? in_batch_stride
-                                      : n_rows * in_row_stride;
-
-  out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
-  out_channel_stride = (out_channel_stride) ? out_channel_stride
-                                            : n_rows * out_row_stride;
-  out_batch_stride = (out_batch_stride) ? out_batch_stride
-                                        : n_channels * out_channel_stride;
-
-  // Perform the re-ordering
-  // For every batch
-  for (int n = 0; n < n_batches; n++)
-  {
-    const T* const in_batch = in + n*in_batch_stride;
-    T* const out_batch = out + n*out_batch_stride;
-
-    // For every row
-    for (int i = 0; i < n_rows; i++)
-    {
-      const T* const in_i = in_batch + i*in_row_stride;
-      T* const out_i = out_batch + i*out_row_stride;
-
-      // For every column
-      for (int j = 0; j < n_cols; j++)
-      {
-        const T* const in_j = in_i + j*in_col_stride;
-        T* const out_j = out_i + j;
-
-        // For every channel
-        for (int c = 0; c < n_channels; c++)
-        {
-          const T* const in_channel = in_j + c;
-          T* const out_channel = out_j + c*out_channel_stride;
-          *(out_channel) = *(in_channel);
-        }
-      }
-    }
-  }
-}
-
-
-/*****************************************************************************/
-/* Generic weight re-order implementation.
- */
-template <typename T>
-inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
-  const T* const in,  // Input in [Output x Input x Height x Width] form
-  T* const out,       // Output in [Height x Width x Input x Output] form
-  const int n_output_feature_maps,
-  const int n_input_feature_maps,
-  const int n_rows,
-  const int n_cols,
-  int in_output_feature_map_stride,
-  int in_input_feature_map_stride,
-  int in_row_stride,
-  int out_row_stride,
-  int out_col_stride,
-  int out_input_feature_map_stride
-)
-{
-  // Fill in stride values
-  in_row_stride = (in_row_stride)
-    ? in_row_stride
-    : n_cols;
-  in_input_feature_map_stride = (in_input_feature_map_stride)
-    ? in_input_feature_map_stride
-    : n_rows * in_row_stride;
-  in_output_feature_map_stride = (in_output_feature_map_stride)
-    ? in_output_feature_map_stride
-    : n_input_feature_maps * in_input_feature_map_stride;
-
-  out_input_feature_map_stride = (out_input_feature_map_stride)
-    ? out_input_feature_map_stride
-    : n_output_feature_maps;
-  out_col_stride = (out_col_stride)
-    ? out_col_stride
-    : n_input_feature_maps * out_input_feature_map_stride;
-  out_row_stride = (out_row_stride)
-    ? out_row_stride
-    : n_cols * out_col_stride;
-
-  // Perform the re-ordering
-  for (int i = 0; i < n_rows; i++)
-  {
-    const T* const in_row = in + i * in_row_stride;
-    T* out_row = out + i * out_row_stride;
-
-    for (int j = 0; j < n_cols; j++)
-    {
-      const T* const in_col = in_row + j;
-      T* const out_col = out_row + j * out_col_stride;
-
-      for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
-      {
-        const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
-        T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
-
-        for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
-        {
-          const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride;
-          T* const out_ofm = out_ifm + ofm;
-          *(out_ofm) = *(in_ofm);
-        }
-      }
-    }
-  }
-}
-
-/*****************************************************************************/
-/* Generic weight re-order implementation.
- */
-template <typename T>
-inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
-  const T* const in,  // Input in [Height x Width x Input x Output] form
-  T* const out,       // Output in [Output x Input x Height x Width] form
-  const int n_rows,
-  const int n_cols,
-  const int n_input_feature_maps,
-  const int n_output_feature_maps,
-  int in_row_stride,
-  int in_col_stride,
-  int in_input_feature_map_stride,
-  int out_output_feature_map_stride,
-  int out_input_feature_map_stride,
-  int out_row_stride
-)
-{
-  // Fill in the stride values
-  in_input_feature_map_stride = (in_input_feature_map_stride)
-    ? in_input_feature_map_stride
-    : n_output_feature_maps;
-  in_col_stride = (in_col_stride)
-    ? in_col_stride
-    : n_input_feature_maps * in_input_feature_map_stride;
-  in_row_stride = (in_row_stride)
-    ? in_row_stride
-    : n_cols * in_col_stride;
-
-  out_row_stride = (out_row_stride)
-    ? out_row_stride
-    : n_cols;
-  out_input_feature_map_stride = (out_input_feature_map_stride)
-    ? out_input_feature_map_stride
-    : n_rows * out_row_stride;
-  out_output_feature_map_stride = (out_output_feature_map_stride)
-    ? out_output_feature_map_stride
-    : n_input_feature_maps * out_input_feature_map_stride;
-
-  // Perform the re-ordering
-  for (int i = 0; i < n_rows; i++)
-  {
-    const T* const in_row = in + i * in_row_stride;
-    T* const out_row = out + i * out_row_stride;
-
-    for (int j = 0; j < n_cols; j++)
-    {
-      const T* const in_col = in_row + j * in_col_stride;
-      T* const out_col = out_row + j;
-
-      for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
-      {
-        const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
-        T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
-
-        for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
-        {
-          const T* const in_ofm = in_ifm + ofm;
-          T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride;
-          *(out_ofm) = *(in_ofm);
-        }
-      }
-    }
-  }
-}
-
diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp
deleted file mode 100644
index ca8d012..0000000
--- a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3.hpp
+++ /dev/null
@@ -1,639 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-#include "arm_compute/core/NEON/kernels/winograd/tensor.hpp"
-
-
-namespace winograd {
-  /* Transform an input tensor into the Winograd domain.
-   */
-  template <typename T>
-  struct Winograd2x2_3x3GemmInput {
-    static void execute(
-        const T *inptr,
-        const Tensor4DShape& input_shape,
-        const PaddingType padding_type,
-        const int tile_M,
-        const int tile_N,
-        T *outptr_base,
-        const int matrix_stride,
-        const int matrix_batch_stride,
-        const int matrix_row_stride
-    );
-
-    static size_t bytes_read(const Tensor4DShape &input_shape,
-                           const Tensor4DShape &output_shape) {
-      const int tile_rows = iceildiv(output_shape.n_rows, 2);
-      const int tile_cols = iceildiv(output_shape.n_cols, 2);
-      return input_shape.n_batches * tile_rows * (16 + 8*(tile_cols - 1)) * input_shape.n_channels * sizeof(T);
-    }
-
-    static int flops_performed(const Tensor4DShape &input_shape,
-                                const Tensor4DShape &output_shape) {
-      const int tile_rows = iceildiv(output_shape.n_rows, 2);
-      const int tile_cols = iceildiv(output_shape.n_cols, 2);
-      return input_shape.n_batches * tile_rows * (32 + 24*(tile_cols - 1)) * input_shape.n_channels;
-    }
-
-    static size_t bytes_written(const Tensor4DShape &input_shape,
-                              const Tensor4DShape &output_shape) {
-      const int tile_rows = iceildiv(output_shape.n_rows, 2);
-      const int tile_cols = iceildiv(output_shape.n_cols, 2);
-      const int M = input_shape.n_batches * tile_rows * tile_cols;
-      return 16 * M * input_shape.n_channels * sizeof(T);
-    }
-
-    protected:
-    template <const PaddingType padding, const int pad_bottom, const int pad_right>
-    static void process_tile_tensor(
-        const int tile_M,      // Number of rows of tiles
-        const int tile_N,      // Number of columns of tiles
-        int n_channels,  // Number of input channels
-        const T* const input,  // Base input pointer (appropriate to batch and channel)
-        const int input_row_stride,  // Stride between rows of the input
-        const int input_col_stride,  // Stride between columns of the input
-        T* const matrix,              // 1st output matrix (appropriate to batch and channel)
-        const int matrix_stride,      // Stride between matrices
-        const int matrix_row_stride   // Stride between rows of the output matrix
-    );
-
-    template <const int pad_top, const int pad_left,
-              const int pad_bottom, const int pad_right,
-              const int proc_channels>
-    static void process_tile_row(
-        const int tile_N,      // Number of tiles in the row
-        const T* const input,  // Base input pointer (appropriate to batch, channel and row)
-        const int input_row_stride,  // Stride between rows of the input
-        const int input_col_stride,  // Stride between columns of the input
-        T* const matrix,              // 1st output matrix (appropriate to batch, channel and row)
-        const int matrix_stride,      // Stride between matrices
-        const int matrix_row_stride   // Stride between rows of the output matrix
-    );
-  };
-
-  template <typename T>
-  struct Winograd2x2_3x3GemmInputChannelwise {
-    static void execute(
-        const T *inptr,
-        const Tensor4DShape& input_shape,
-        const PaddingType padding_type,
-        const int tile_M,
-        const int tile_N,
-        T *outptr_base,
-        const int matrix_stride,
-        const int matrix_batch_stride,
-        const int matrix_row_stride
-    );
-
-    static size_t bytes_read(const Tensor4DShape &input_shape,
-                           const Tensor4DShape &output_shape) {
-      // We read as many bytes as we write
-      return bytes_written(input_shape, output_shape);
-    }
-
-    static int flops_performed(const Tensor4DShape &input_shape,
-                                const Tensor4DShape &output_shape) {
-      const int tile_rows = iceildiv(output_shape.n_rows, 2);
-      const int tile_cols = iceildiv(output_shape.n_cols, 2);
-      return input_shape.n_batches * tile_rows * 32 * tile_cols * input_shape.n_channels;
-    }
-
-    static size_t bytes_written(const Tensor4DShape &input_shape,
-                              const Tensor4DShape &output_shape) {
-      return winograd::Winograd2x2_3x3GemmInput<T>::bytes_written(input_shape, output_shape);
-    }
-
-    protected:
-    typedef void (*tilefunc)(int, const T*, int, int, T*, int);
-    template <const int pad_top,
-              const int pad_left,
-              const int pad_bottom,
-              const int pad_right>
-    static void process_tile(
-        int n_channels,  // Number of channels in the tile
-        const T* const input_base,
-        const int input_row_stride,
-        const int input_col_stride,
-        T* const matrix_base,
-        const int matrix_stride
-    );
-
-    private:
-    template <const int pad_top,
-              const int pad_left,
-              const int pad_bottom,
-              const int pad_right,
-              const int proc_channels>
-    static void _process_tile(
-        int &n_channels, const T* &inptr,
-        const int input_row_stride, const int input_col_stride,
-        T* &outptr, const int matrix_stride
-    );
-  };
-}
-
-/*****************************************************************************/
-// Include specialised implementations here
-#include "input_2x2_3x3/a64_float.hpp"
-#include "input_2x2_3x3/a64_float_channelwise.hpp"
-/*****************************************************************************/
-
-/*****************************************************************************/
-template <typename T>
-void winograd::Winograd2x2_3x3GemmInput<T>::execute(
-    const T *inptr_base,
-    const Tensor4DShape& input_shape,
-    const PaddingType padding_type,
-    const int tile_M,
-    const int tile_N,
-    T *outptr_base,
-    const int matrix_stride,
-    const int matrix_batch_stride,
-    const int matrix_row_stride
-) {
-  // Select an appropriate matrix processing method for the shape and padding
-  // of the input tensor.
-  typedef void (*tensorfunc)(int, int, int, const T*, int, int, T*, int, int);
-  const auto process_tensor = [&padding_type, &input_shape] () -> tensorfunc {
-    if (padding_type == PADDING_VALID) {
-      const int pad_bottom = input_shape.n_rows % 2;
-      const int pad_right = input_shape.n_cols % 2;
-
-      if (pad_bottom == 0 && pad_right == 0) {
-        return process_tile_tensor<PADDING_VALID, 0, 0>;
-      } else if (pad_bottom == 0 && pad_right == 1) {
-        return process_tile_tensor<PADDING_VALID, 0, 1>;
-      } else if (pad_bottom == 1 && pad_right == 0) {
-        return process_tile_tensor<PADDING_VALID, 1, 0>;
-      } else if (pad_bottom == 1 && pad_right == 1) {
-        return process_tile_tensor<PADDING_VALID, 1, 1>;
-      }
-    } else {  // PADDING_SAME
-      const int pad_bottom = 1 + input_shape.n_rows % 2;
-      const int pad_right = 1 + input_shape.n_cols % 2;
-
-      if (pad_bottom == 1 && pad_right == 1) {
-        return process_tile_tensor<PADDING_SAME, 1, 1>;
-      } else if (pad_bottom == 1 && pad_right == 2) {
-        return process_tile_tensor<PADDING_SAME, 1, 2>;
-      } else if (pad_bottom == 2 && pad_right == 1) {
-        return process_tile_tensor<PADDING_SAME, 2, 1>;
-      } else if (pad_bottom == 2 && pad_right == 2) {
-        return process_tile_tensor<PADDING_SAME, 2, 2>;
-      }
-    }
-
-    printf("%s::%u Uncovered case.\n", __FILE__, __LINE__);
-    exit(-1);
-    return NULL;  // No function found
-  } ();
-
-  // Compute strides
-  const int input_row_stride = input_shape.n_cols * input_shape.n_channels;
-  const int input_col_stride = input_shape.n_channels;
-
-  // Process each batch of the tensor in turn.
-  for (int batch = 0; batch < input_shape.n_batches; batch++) {
-    // Work out pointers
-    const T *inptr = inptr_base + (batch * input_shape.n_rows *
-                                   input_shape.n_cols * input_shape.n_channels);
-    T *outptr = outptr_base + batch * matrix_batch_stride;
-
-    // Delegate doing the actual work
-    process_tensor(
-      tile_M, tile_N, input_shape.n_channels,
-      inptr, input_row_stride, input_col_stride,
-      outptr, matrix_stride, matrix_row_stride
-    );
-  }
-}
-
-/*****************************************************************************/
-template <typename T>
-template <const PaddingType padding, const int pad_bottom, const int pad_right>
-void winograd::Winograd2x2_3x3GemmInput<T>::process_tile_tensor(
-    const int tile_M,      // Number of rows of tiles
-    const int tile_N,      // Number of columns of tiles
-    int n_channels,  // Number of input channels
-    const T* const input,  // Base input pointer (appropriate to batch and channel)
-    const int input_row_stride,  // Stride between rows of the input
-    const int input_col_stride,  // Stride between columns of the input
-    T* const matrix,              // 1st output matrix (appropriate to batch and channel)
-    const int matrix_stride,      // Stride between matrices
-    const int matrix_row_stride   // Stride between rows of the output matrix
-) {
-  // Base row processing functions
-  typedef void (*rowfunc)(int, const T*, int, int, T*, int, int);
-  const rowfunc process_top_row[3] = {
-    (padding == PADDING_VALID)
-      ? process_tile_row<0, 0, 0, pad_right, 1>
-      : process_tile_row<1, 1, 0, pad_right, 1>,
-    (padding == PADDING_VALID)
-      ? process_tile_row<0, 0, 0, pad_right, 2>
-      : process_tile_row<1, 1, 0, pad_right, 2>,
-    (padding == PADDING_VALID)
-      ? process_tile_row<0, 0, 0, pad_right, 4>
-      : process_tile_row<1, 1, 0, pad_right, 4>,
-  };
-  const rowfunc process_middle_row[3] = {
-    (padding == PADDING_VALID)
-      ? process_tile_row<0, 0, 0, pad_right, 1>
-      : process_tile_row<0, 1, 0, pad_right, 1>,
-    (padding == PADDING_VALID)
-      ? process_tile_row<0, 0, 0, pad_right, 2>
-      : process_tile_row<0, 1, 0, pad_right, 2>,
-    (padding == PADDING_VALID)
-      ? process_tile_row<0, 0, 0, pad_right, 4>
-      : process_tile_row<0, 1, 0, pad_right, 4>,
-  };
-  const rowfunc process_bottom_row[3] = {
-    (padding == PADDING_VALID)
-      ? process_tile_row<0, 0, pad_bottom, pad_right, 1>
-      : process_tile_row<0, 1, pad_bottom, pad_right, 1>,
-    (padding == PADDING_VALID)
-      ? process_tile_row<0, 0, pad_bottom, pad_right, 2>
-      : process_tile_row<0, 1, pad_bottom, pad_right, 2>,
-    (padding == PADDING_VALID)
-      ? process_tile_row<0, 0, pad_bottom, pad_right, 4>
-      : process_tile_row<0, 1, pad_bottom, pad_right, 4>,
-  };
-
-  // Method to get an input pointer for the given tile row
-  const auto get_inptr = [&input, &input_row_stride] (const int tile_i) {
-    if (padding == PADDING_VALID) {
-      return input + 2 * tile_i * input_row_stride;
-    } else {
-      return input + (2 * tile_i - (tile_i ? 1 : 0)) * input_row_stride;
-    }
-  };
-
-  // Wrapper to process a row of tiles, covering all channels.
-  const auto process_row =
-    [tile_N, input_row_stride, input_col_stride, matrix_stride, matrix_row_stride, n_channels]
-    (const rowfunc f[3], const T *inptr, T *outptr) {
-      int rem_channels = n_channels;
-
-      // While there remain channels to process continue to process the
-      // row.
-      for (; rem_channels >= 4; rem_channels -= 4, inptr += 4, outptr += 4) {
-        f[2](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride);
-      }
-      for (; rem_channels >= 2; rem_channels -= 2, inptr += 2, outptr += 2) {
-        f[1](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride);
-      }
-      if (rem_channels) {
-        f[0](tile_N, inptr, input_row_stride, input_col_stride, outptr, matrix_stride, matrix_row_stride);
-      }
-  };
-
-  // Process all rows of tiles in the tensor
-  for (int tile_i = 0; tile_i < tile_M; tile_i++) {
-    T* const m_row = matrix + tile_i * tile_N * matrix_row_stride;
-    const T *row_inptr = get_inptr(tile_i);
-
-    if (tile_i == 0) {
-      // Top row of the input
-      process_row(process_top_row, row_inptr, m_row);
-    } else if (tile_i == tile_M - 1) {
-      // Bottom row of the input
-      process_row(process_bottom_row, row_inptr, m_row);
-    } else {
-      // Any other row of the input
-      process_row(process_middle_row, row_inptr, m_row);
-    }
-  }
-}
-
-/*****************************************************************************/
-template <typename T>
-template <const int pad_top, const int pad_left,
-          const int pad_bottom, const int pad_right,
-          const int proc_channels>
-void winograd::Winograd2x2_3x3GemmInput<T>::process_tile_row(
-    const int tile_N,      // Number of tiles in the row
-    const T* const input,  // Base input pointer (appropriate to batch, channel and row)
-    const int input_row_stride,  // Stride between rows of the input
-    const int input_col_stride,  // Stride between columns of the input
-    T* const matrix,              // 1st output matrix (appropriate to batch, channel and row)
-    const int matrix_stride,      // Stride between matrices
-    const int matrix_row_stride   // Stride between rows of the output matrix
-) {
-  // Construct copies of the pointers
-  const T *inptr = input;
-  T *outptr = matrix;
-
-  // Storage for the tensors x, X.T x, and X.T x X.
-  T x[4][4][proc_channels], XTx[4][4][proc_channels], XTxX[4][4][proc_channels];
-
-  // For every tile in the row
-  for (int tile_j = 0; tile_j < tile_N; tile_j++) {
-    // Determine the padding for the tile
-    const int tile_pad_left = (tile_j == 0) ? pad_left : 0;
-    const int tile_pad_right = (tile_j == tile_N - 1) ? pad_right : 0;
-
-    // Load tile values. If this is the first tile in the row then we must load
-    // all values, otherwise we can just load the final two columns of the input.
-    for (int i = 0; i < 4; i++) {
-      for (int j = ((tile_j == 0) ? 0 : 2); j < 4; j++) {
-        // Fill with padding if required
-        if (i < pad_top || 4 - pad_bottom <= i ||
-            j < tile_pad_left || 4 - tile_pad_right <= j) {
-          for (int c = 0; c < proc_channels; c++) {
-            x[i][j][c] = static_cast<T>(0);  // Padding
-          }
-        } else {
-          // Load values, note that the initial padding offsets the pointer we
-          // were provided.
-          for (int c = 0; c < proc_channels; c++) {
-            const int row_offset = (i - pad_top) * input_row_stride;
-            const int col_offset = (j - tile_pad_left) * input_col_stride;
-            x[i][j][c] = inptr[row_offset + col_offset + c];
-          }
-        }
-      }
-    }
-
-    // Compute the matrix X.T x.  Note, can elide operations depending on the
-    // padding. Furthermore, if this isn't the left-most tile we can skip half
-    // of the operations by copying results from the previous version of X.T x.
-    // This latter optimisation can be simplified by unrolling the outermost
-    // loop by two and by renaming the registers containing XTx.
-    if (tile_j == 0) {
-      for (int j = 0; j < 4; j++) {
-        for (int c = 0; c < proc_channels; c++) {
-          XTx[0][j][c] =  x[0][j][c] - x[2][j][c];
-          XTx[1][j][c] =  x[1][j][c] + x[2][j][c];
-          XTx[2][j][c] = -x[1][j][c] + x[2][j][c];
-          XTx[3][j][c] =  x[1][j][c] - x[3][j][c];
-        }
-      }
-    } else {
-      for (int j = 0; j < 2; j++) {
-        for (int c = 0; c < proc_channels; c++) {
-          XTx[0][j][c] = XTx[0][j + 2][c];
-          XTx[1][j][c] = XTx[1][j + 2][c];
-          XTx[2][j][c] = XTx[2][j + 2][c];
-          XTx[3][j][c] = XTx[3][j + 2][c];
-        }
-      }
-      for (int j = 2; j < 4; j++) {
-        for (int c = 0; c < proc_channels; c++) {
-          XTx[0][j][c] =  x[0][j][c] - x[2][j][c];
-          XTx[1][j][c] =  x[1][j][c] + x[2][j][c];
-          XTx[2][j][c] = -x[1][j][c] + x[2][j][c];
-          XTx[3][j][c] =  x[1][j][c] - x[3][j][c];
-        }
-      }
-    }
-
-    // Compute the matrix X.T x X. Note, can elide operations based on the
-    // padding.
-    for (int i = 0; i < 4; i++) {
-      for (int c = 0; c < proc_channels; c++) {
-        XTxX[i][0][c] =  XTx[i][0][c] - XTx[i][2][c];
-        XTxX[i][1][c] =  XTx[i][1][c] + XTx[i][2][c];
-        XTxX[i][2][c] = -XTx[i][1][c] + XTx[i][2][c];
-        XTxX[i][3][c] =  XTx[i][1][c] - XTx[i][3][c];
-      }
-    }
-
-    // Store the output matrix (X.T x X)
-    for (int i = 0; i < 4; i++) {
-      for (int j = 0; j < 4; j++) {
-        // Get a pointer to the relevant output matrix
-        T *mptr = outptr + (i*4 + j)*matrix_stride;
-
-        // Write out the channels
-        for (int c = 0; c < proc_channels; c++) {
-          mptr[c] = XTxX[i][j][c];
-        }
-      }
-    }
-
-    // Update the pointers
-    inptr += input_col_stride * ((tile_j == 0 && pad_left) ? 1 : 2);
-    outptr += matrix_row_stride;
-  }
-}
-
-/*****************************************************************************/
-template <typename T>
-void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::execute(
-    const T *inptr,
-    const Tensor4DShape& input_shape,
-    const PaddingType padding_type,
-    const int tile_M,
-    const int tile_N,
-    T *outptr_base,
-    const int matrix_stride,
-    const int matrix_batch_stride,
-    const int matrix_row_stride
-) {
-  const int n_channels = input_shape.n_channels;
-  const int input_col_stride = n_channels;
-  const int input_row_stride = input_shape.n_cols * input_col_stride;
-
-  // Determine the padding and hence select appropriate methods for each tile.
-  tilefunc fs[3][3];
-
-  if (padding_type == PADDING_VALID) {
-    constexpr int pad_top = 0;
-    constexpr int pad_left = 0;
-    const int pad_right = input_shape.n_cols % 2 == 0;
-
-    fs[0][0] = process_tile<pad_top, pad_left, 0, 0>;
-    fs[0][1] = process_tile<pad_top, 0, 0, 0>;
-    fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 0> : process_tile<pad_top, 0, 0, 1>;
-
-    fs[1][0] = process_tile<0, pad_left, 0, 0>;
-    fs[1][1] = process_tile<0, 0, 0, 0>;
-    fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 0> : process_tile<0, 0, 0, 1>;
-
-    if (input_shape.n_rows % 2 == 0) {
-      constexpr int pad_bottom = 0;
-      fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
-      fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
-      fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>;
-    } else {
-      constexpr int pad_bottom = 1;
-      fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
-      fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
-      fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 0> : process_tile<0, 0, pad_bottom, 1>;
-    }
-  } else {
-    constexpr int pad_top = 1;
-    constexpr int pad_left = 1;
-    const int pad_right = input_shape.n_cols % 2 == 0;
-
-    fs[0][0] = process_tile<pad_top, pad_left, 0, 0>;
-    fs[0][1] = process_tile<pad_top, 0, 0, 0>;
-    fs[0][2] = (pad_right) ? process_tile<pad_top, 0, 0, 1> : process_tile<pad_top, 0, 0, 2>;
-
-    fs[1][0] = process_tile<0, pad_left, 0, 0>;
-    fs[1][1] = process_tile<0, 0, 0, 0>;
-    fs[1][2] = (pad_right) ? process_tile<0, 0, 0, 1> : process_tile<0, 0, 0, 2>;
-
-    if (input_shape.n_rows % 2 == 0) {
-      constexpr int pad_bottom = 1;
-      fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
-      fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
-      fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>;
-    } else {
-      constexpr int pad_bottom = 2;
-      fs[2][0] = process_tile<0, pad_left, pad_bottom, 0>;
-      fs[2][1] = process_tile<0, 0, pad_bottom, 0>;
-      fs[2][2] = (pad_right) ? process_tile<0, 0, pad_bottom, 1> : process_tile<0, 0, pad_bottom, 2>;
-    }
-  }
-
-  // Process each tile in turn
-  for (int batch = 0; batch < input_shape.n_batches; batch++) {
-    const T* const input_base_batch = inptr + batch*input_shape.n_rows*input_shape.n_cols*n_channels;
-
-    for (int tile_i = 0; tile_i < tile_M; tile_i++) {
-      const int row_offset = (tile_i == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
-      const T* const input_base_row = input_base_batch + (2*tile_i - row_offset)*input_shape.n_cols*n_channels;
-
-      // Select the set of functions for the row
-      const int fs_i = (tile_i == 0) ? 0 : ((tile_i < tile_M - 1) ? 1 : 2);
-
-      for (int tile_j = 0; tile_j < tile_N; tile_j++) {
-        // Select the function for the column
-        const int fs_j = (tile_j == 0) ? 0 : ((tile_j < tile_N - 1) ? 1 : 2);
-        const auto f = fs[fs_i][fs_j];
-
-        // Get pointers into the input and outputs
-        const int col_offset = (tile_j == 0) ? 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
-        const T* const input_base_col = input_base_row + (2*tile_j - col_offset)*n_channels;
-        T* const matrix_base = outptr_base + batch*matrix_batch_stride + (tile_i*tile_N + tile_j)*matrix_row_stride;
-        f(n_channels, input_base_col, input_row_stride, input_col_stride,
-          matrix_base, matrix_stride);
-      }
-    }
-  }
-}
-
-template <typename T>
-template <const int pad_top,
-          const int pad_left,
-          const int pad_bottom,
-          const int pad_right>
-void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::process_tile(
-    int n_channels,  // Number of channels in the tile
-    const T* const input_base,
-    const int input_row_stride,
-    const int input_col_stride,
-    T* const matrix_base,
-    const int matrix_stride
-) {
-  // Copy pointers
-  const T *inptr = input_base;
-  T *outptr = matrix_base;
-
-  // Process channels (modifies inptr, outptr and n_channels)
-  _process_tile<pad_top, pad_left, pad_bottom, pad_right, 4>(
-    n_channels, inptr, input_row_stride, input_col_stride,
-    outptr, matrix_stride
-  );
-  _process_tile<pad_top, pad_left, pad_bottom, pad_right, 2>(
-    n_channels, inptr, input_row_stride, input_col_stride,
-    outptr, matrix_stride
-  );
-  _process_tile<pad_top, pad_left, pad_bottom, pad_right, 1>(
-    n_channels, inptr, input_row_stride, input_col_stride,
-    outptr, matrix_stride
-  );
-}
-
-template <typename T>
-template <const int pad_top,
-          const int pad_left,
-          const int pad_bottom,
-          const int pad_right,
-          const int proc_channels>
-void winograd::Winograd2x2_3x3GemmInputChannelwise<T>::_process_tile(
-    int &n_channels,
-    const T* &inptr, const int input_row_stride, const int input_col_stride,
-    T* &outptr, const int matrix_stride
-) {
-  // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
-  // offsets to access the intermediate matrices.
-  T* outptrs[4] = {
-    outptr,
-    outptr + matrix_stride * 4,
-    outptr + matrix_stride * 8,
-    outptr + matrix_stride * 12
-  };
-
-  // The matrix X; zeroed to account for padding.
-  T x[4][4];
-  for (int i = 0; i < 4; i++) {
-    for (int j = 0; j < 4; j++) {
-      x[i][j] = 0;
-    }
-  }
-
-  // The matrices X.T x and U
-  T XTx[4][4], U[4][4];
-
-  // Now progress through each channel
-  for (; n_channels >= proc_channels; n_channels -= proc_channels) {
-    for (int n = 0; n < proc_channels; n++) {
-      // Load the matrix X
-      for (int cell_i = pad_top, i = 0; cell_i < 4 - pad_bottom; cell_i++, i++) {
-        for (int cell_j = pad_left, j = 0; cell_j < 4 - pad_right; cell_j++, j++) {
-          x[cell_i][cell_j] = inptr[i*input_row_stride + j*input_col_stride];
-        }
-      }
-      inptr++;
-
-      // Compute the matrix X.T
-      for (int j = 0; j < 4; j++) {
-        XTx[0][j] = x[0][j] - x[2][j];
-        XTx[1][j] = x[1][j] + x[2][j];
-        XTx[2][j] = x[2][j] - x[1][j];
-        XTx[3][j] = x[1][j] - x[3][j];
-      }
-
-      // Hence compute the matrix U
-      for (int i = 0; i < 4; i++) {
-        U[i][0] = XTx[i][0] - XTx[i][2];
-        U[i][1] = XTx[i][1] + XTx[i][2];
-        U[i][2] = XTx[i][2] - XTx[i][1];
-        U[i][3] = XTx[i][1] - XTx[i][3];
-      }
-
-      // Store the matrix U
-      for (int i = 0; i < 4; i++) {
-        for (int j = 0; j < 4; j++) {
-          outptrs[i][j * matrix_stride] = U[i][j];
-        }
-        outptrs[i]++;
-      }
-    }
-  }
-
-  // Update the output pointer for future calls
-  outptr = outptrs[0];
-}
diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp
deleted file mode 100644
index a99cbe3..0000000
--- a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float.hpp
+++ /dev/null
@@ -1,1498 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-#include "../input_2x2_3x3.hpp"
-
-#ifdef __aarch64__
-namespace winograd {
-
-// Pad left by one column, pad right by one column, no upper or lower padding, 4 channels
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<0, 1, 0, 1, 4>(
-    const int tile_N,            // Number of tiles in the row
-    const float* const input,    // Base input pointer (appropriate to batch, channel and row)
-    const int input_row_stride,  // Stride between rows of the input
-    const int input_col_stride,  // Stride between columns of the input
-    float* const matrix,         // 1st output matrix (appropriate to batch, channel and row)
-    const int matrix_stride,     // Stride between matrices
-    const int matrix_row_stride  // Stride between rows of the output matrix
-) {
-  /* SIMD register allocation
-   * ========================
-   *
-   * In the following code we read 4x4 tiles of a matrix `x`, with which we
-   * compute another matrix `X.T x` where:
-   *
-   *         /  1  0  0  0 \
-   *     X = |  0  1 -1  1 |
-   *         | -1  1  1  0 |
-   *         \  0  0  0 -1 /
-   *
-   * Hence, `X.T` is a program which operates upon rows of the matrix `X`.
-   * We subsequently compute and store the matrix `U = (X.T x) X`.
-   *
-   * Importantly, each iteration of the loop below loads a new matrix `x'`
-   * where the final two columns of `x'` are the first two columns of the
-   * previous `x`. That is:
-   *
-   *   x11  x12  x13  x14
-   *   x21  x22  x23  x24
-   *   x31  x32  x33  x34
-   *   x41  x42  x43  x44
-   *
-   *            x'11 x'12 x'13 x'14
-   *            x'21 x'22 x'23 x'24
-   *            x'31 x'32 x'33 x'34
-   *            x'41 x'42 x'43 x'44
-   *
-   * Consequently, while the first iteration of the below loop must load 16
-   * values for `x`, the second need load only 8. *Furthermore*, since we noted
-   * above that the operation `X.T x` was a program which operated upon *rows*
-   * of the matrix `x` it follows that that the relation that `x'[i][1] =
-   * x[i][3]` and `x'[i][2] = x[i][4]` applies also the matrices `X.T x'` and
-   * `X.T x`. That is:
-   *
-   *   (X.T x)11  (X.T x)12  (X.T x)13  (X.T x)14
-   *   (X.T x)21  (X.T x)22  (X.T x)23  (X.T x)24
-   *   (X.T x)31  (X.T x)32  (X.T x)33  (X.T x)34
-   *   (X.T x)41  (X.T x)42  (X.T x)43  (X.T x)44
-   *
-   *                        (X.T x')11 (X.T x')12 (X.T x')13 (X.T x')14
-   *                        (X.T x')12 (X.T x')12 (X.T x')12 (X.T x')12
-   *                        (X.T x')13 (X.T x')13 (X.T x')13 (X.T x')13
-   *                        (X.T x')14 (X.T x')14 (X.T x')14 (X.T x')14
-   *
-   * Hence, as well as not needing to load new values for x'[i][1..2] it is
-   * also unnecessary to recompute values for (X.T x')[i][1..2].
-   *
-   * Following this we break the registers into blocks `A` and `B` used by the
-   * two stages of the unrolled loop. These registers named such that the
-   * latter columns of `A` become the earlier columns of `B` and vice-versa:
-   *
-   *  AXTx11 AXTx12 > AXTx13 AXTx14 |
-   *  AXTx21 AXTx22 > AXTx23 AXTx24 |
-   *  AXTx31 AXTx32 > AXTx33 AXTx34 |
-   *  AXTx41 AXTx42 > AXTx43 AXTx44 |
-   *
-   *  BXTx13 BXTx14 | BXTx11 BXTx12 >
-   *  BXTx23 BXTx24 | BXTx21 BXTx22 >
-   *  BXTx33 BXTx34 | BXTx31 BXTx32 >
-   *  BXTx43 BXTx44 | BXTx41 BXTx42 >
-   *
-   * These 32 named registers require only 16 architectural registers. 1
-   * additional architectural register is used as scratch space and 8
-   * architectural registers are used to load in the values x[1..4][3,4].
-   *
-   * Input and output addressing
-   * ===========================
-   * TODO Description
-   */
-  const float *inptr0 = input;
-  const float *inptr1 = input + input_row_stride;
-  const float *inptr2 = input + input_row_stride * 2;
-  const float *inptr3 = input + input_row_stride * 3;
-
-  float *outptr0 = matrix;
-  float *outptr4 = matrix + matrix_stride * 4;
-  float *outptr8 = matrix + matrix_stride * 8;
-  float *outptr12 = matrix + matrix_stride * 12;
-
-  int tile_j = tile_N;  // Tiles to process
-
-  asm volatile (
-      // Named SIMD registers according to the policy given above
-      // Registers into which to load the latter two columns of `x`
-      "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n"
-      "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
-      "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
-      "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n"
-
-      // Registers for storing X.T x (both A and B halves)
-      "AXTx11 .req  v8\n" "BXTx13 .req  v8\n"
-      "AXTx12 .req  v9\n" "BXTx14 .req  v9\n" "qAXTx12 .req  q9\n"
-      "AXTx21 .req v10\n" "BXTx23 .req v10\n"
-      "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
-      "AXTx31 .req v12\n" "BXTx33 .req v12\n"
-      "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
-      "AXTx41 .req v14\n" "BXTx43 .req v14\n"
-      "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
-      "AXTx13 .req v16\n" "BXTx11 .req v16\n"
-      "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
-      "AXTx23 .req v18\n" "BXTx21 .req v18\n"
-      "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
-      "AXTx33 .req v20\n" "BXTx31 .req v20\n"
-      "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
-      "AXTx43 .req v22\n" "BXTx41 .req v22\n"
-      "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"
-
-      // Result register (TODO Does using more registers yield better
-      // performance)
-      "U .req v24\n qU .req q24\n"
-
-      // ----------------------------------------------------------------------
-      // Head of loop
-      //   Loads a complete 4x4 tile of x, computes X.T x, computes and stores
-      //   `U = X.T x X`. Prepares for the 'A' half of the loop.
-      //   NOTE: Since the first tile has the leftmost column padded we can
-      //   skip 4 loads and 4 calculations for the matrix X.T x X.
-
-      // Temporarily alias registers for computing the first (non-padded)
-      // column of x.
-      "x_12 .req v0\n qx_12 .req q0\n"
-      "x_22 .req v1\n qx_22 .req q1\n"
-      "x_32 .req v2\n qx_32 .req q2\n"
-      "x_42 .req v3\n qx_42 .req q3\n"
-
-      "ldr qx_12, [%x[inptr0]]\n"
-      "ldr qx_22, [%x[inptr1]]\n"
-      "ldr qx_32, [%x[inptr2]]\n"
-      "ldr qx_42, [%x[inptr3]]\n"
-
-      "fsub BXTx12.4s, x_12.4s, x_32.4s\n"
-      "fadd BXTx22.4s, x_22.4s, x_32.4s\n"
-      "fsub BXTx32.4s, x_32.4s, x_22.4s\n"
-      "fsub BXTx42.4s, x_22.4s, x_42.4s\n"
-
-      ".unreq x_12\n .unreq qx_12\n"
-      ".unreq x_22\n .unreq qx_22\n"
-      ".unreq x_32\n .unreq qx_32\n"
-      ".unreq x_42\n .unreq qx_42\n"
-
-      // Load and compute latter two columns of the first tile. Progress the
-      // input pointers (by three columns so that the each points are the
-      // second column of the next tile, that is, each points at the first
-      // column which must be read for the next tile.
-      "ldr qx_13, [%x[inptr0], %x[colstride1]]\n"
-      "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
-      "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
-      "ldr qx_43, [%x[inptr3], %x[colstride1]]\n"
-
-      "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
-      "ldr qx_14, [%x[inptr0], %x[colstride2]]\n"
-
-      "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
-      "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"
-
-      "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
-      "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"
-
-      "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
-      "ldr qx_44, [%x[inptr3], %x[colstride2]]\n"
-
-      "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
-      "add %x[inptr0],  %x[inptr0], %x[colstride3]\n"
-
-      "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
-      "add %x[inptr1], %x[inptr1], %x[colstride3]\n"
-
-      "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
-      "add %x[inptr2], %x[inptr2], %x[colstride3]\n"
-
-      "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
-      "add %x[inptr3], %x[inptr3], %x[colstride3]\n"
-
-      // Compute and store U for the first tile
-      // First row
-      "fneg U.4s, BXTx13.4s\n"
-      "str qU, [%x[outptr0]]\n"
-      "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
-      "str qU, [%x[outptr0], %x[mstride1]]\n"
-      "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
-      "str qU, [%x[outptr0], %x[mstride2]]\n"
-      "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
-      "str qU, [%x[outptr0], %x[mstride3]]\n"
-      "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
-      // Second row
-      "fneg U.4s, BXTx23.4s\n"
-      "str qU, [%x[outptr4]]\n"
-      "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
-      "str qU, [%x[outptr4], %x[mstride1]]\n"
-      "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
-      "str qU, [%x[outptr4], %x[mstride2]]\n"
-      "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
-      "str qU, [%x[outptr4], %x[mstride3]]\n"
-      "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
-      // Third row
-      "fneg U.4s, BXTx33.4s\n"
-      "str qU, [%x[outptr8]]\n"
-      "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
-      "str qU, [%x[outptr8], %x[mstride1]]\n"
-      "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
-      "str qU, [%x[outptr8], %x[mstride2]]\n"
-      "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
-      "str qU, [%x[outptr8], %x[mstride3]]\n"
-      "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
-      // Fourth row, simultaneously load the first column of inputs for the
-      // next tile.
-      "fneg U.4s, BXTx43.4s\n"
-      "str qU, [%x[outptr12]]\n"
-      "ldr qx_13, [%x[inptr0]]\n"
-
-      "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
-      "str qU, [%x[outptr12], %x[mstride1]]\n"
-      "ldr qx_23, [%x[inptr1]]\n"
-
-      "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
-      "str qU, [%x[outptr12], %x[mstride2]]\n"
-      "ldr qx_33, [%x[inptr2]]\n"
-
-      "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
-      "str qU, [%x[outptr12], %x[mstride3]]\n"
-      "ldr qx_43, [%x[inptr3]]\n"
-
-      "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
-      // Update the loop counter, subtract two to account for both the head and
-      // the tail.
-      "subs %x[tile_j], %x[tile_j], #2\n"
-      "beq 2f\n"  // Jump to "A" tail if out of tiles
-
-      // ----------------------------------------------------------------------
-      "1:"
-        // Start part A
-        // Load last column of this tile (the first column has already been
-        // loaded) and compute latter two columns of X.T x.
-        "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
-        "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
-        "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
-        "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
-        "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
-        "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
-        "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
-        "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
-        "fsub AXTx14.4s, x_14.4s, x_34.4s\n"
-        "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
-        "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
-        "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
-        "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
-        "fsub AXTx44.4s, x_24.4s, x_44.4s\n"
-        "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
-
-        // Compute and store U.
-        // First row
-        "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
-        // Second row
-        "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
-        // Third row
-        "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
-        // Fourth row
-        "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "ldr qx_13, [%x[inptr0]]\n"
-
-        "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "ldr qx_23, [%x[inptr1]]\n"
-
-        "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "ldr qx_33, [%x[inptr2]]\n"
-
-        "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-        "ldr qx_43, [%x[inptr3]]\n"
-
-        "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
-        "subs %x[tile_j], %x[tile_j], #1\n"
-        "beq 3f\n"  // Jump to 'B' tail
-
-        // Start part B
-        // Load last column of this tile (the first column has already been
-        // loaded) and compute latter two columns of X.T x.
-        "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
-        "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
-        "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
-        "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
-        "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
-        "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
-        "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
-        "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
-        "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
-        "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
-        "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
-        "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
-        "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
-        "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
-        "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
-
-        // Compute and store U.
-        // First row
-        "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
-        // Second row
-        "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
-        // Third row
-        "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
-        // Fourth row
-        "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "ldr qx_13, [%x[inptr0]]\n"
-
-        "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "ldr qx_23, [%x[inptr1]]\n"
-
-        "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "ldr qx_33, [%x[inptr2]]\n"
-
-        "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-        "ldr qx_43, [%x[inptr3]]\n"
-
-        "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-        "subs %x[tile_j], %x[tile_j], #1\n"
-        "bne 1b\n"  // Continue loop, otherwise flow into 'A' tail
-
-      // ----------------------------------------------------------------------
-      "2:"
-        // 'A' tail
-        // Since the final column is padding and the last-but-one column has
-        // already been loaded just compute the 3rd column of `X.T x'.
-        "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
-        "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
-        "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
-        "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
-
-        // Compute and store U. Modified to account for the final column of X.T
-        // x containing padding. Note, it is also unnecessary to update the
-        // output pointers.
-        // First row
-        "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"
-
-        // Second row
-        "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"
-
-        // Third row
-        "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"
-
-        // Fourth row
-        "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"
-
-        "b 4f\n"  // Jump to end of function
-
-      // ----------------------------------------------------------------------
-      "3:"
-        // 'B' tail
-        // Since the final column is padding and the last-but-one column has
-        // already been loaded just compute the 3rd column of `X.T x'.
-        "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
-        "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
-        "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
-        "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
-
-        // Compute and store U. Modified to account for the final column of X.T
-        // x containing padding. Note, it is also unnecessary to update the
-        // output pointers.
-        // First row
-        "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"
-
-        // Second row
-        "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"
-
-        // Third row
-        "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"
-
-        // Fourth row
-        "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"
-
-      // ----------------------------------------------------------------------
-      "4:"
-        // End of function
-
-      // Clear names
-      ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n"
-      ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
-      ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
-      ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n"
-      ".unreq AXTx11\n" ".unreq BXTx13\n"
-      ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
-      ".unreq AXTx21\n" ".unreq BXTx23\n"
-      ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
-      ".unreq AXTx31\n" ".unreq BXTx33\n"
-      ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
-      ".unreq AXTx41\n" ".unreq BXTx43\n"
-      ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
-      ".unreq AXTx13\n" ".unreq BXTx11\n"
-      ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
-      ".unreq AXTx23\n" ".unreq BXTx21\n"
-      ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
-      ".unreq AXTx33\n" ".unreq BXTx31\n"
-      ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
-      ".unreq AXTx43\n" ".unreq BXTx41\n"
-      ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
-      ".unreq U\n" ".unreq qU\n"
-    : [inptr0] "+r" (inptr0),
-      [inptr1] "+r" (inptr1),
-      [inptr2] "+r" (inptr2),
-      [inptr3] "+r" (inptr3),
-      [outptr0] "+r" (outptr0),
-      [outptr4] "+r" (outptr4),
-      [outptr8] "+r" (outptr8),
-      [outptr12] "+r" (outptr12),
-      [tile_j] "+r" (tile_j)  // Tile counter
-    : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
-      [colstride2] "r" (2 * input_col_stride * sizeof(float)),
-      [colstride3] "r" (3 * input_col_stride * sizeof(float)),
-      [mstride1] "r" (1 * matrix_stride * sizeof(float)),
-      [mstride2] "r" (2 * matrix_stride * sizeof(float)),
-      [mstride3] "r" (3 * matrix_stride * sizeof(float)),
-      [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
-    : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
-      "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
-      "v22", "v23", "v24"
-  );
-}
-
-// Pad top, left and right by 1.
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<1, 1, 0, 1, 4>(
-    const int tile_N,
-    const float* const input,
-    const int input_row_stride,
-    const int input_col_stride,
-    float* const matrix,
-    const int matrix_stride,
-    const int matrix_row_stride
-) {
-  const float *inptr0 = input;
-  const float *inptr1 = input + input_row_stride;
-  const float *inptr2 = input + input_row_stride * 2;
-
-  float *outptr0 = matrix;
-  float *outptr4 = matrix + matrix_stride * 4;
-  float *outptr8 = matrix + matrix_stride * 8;
-  float *outptr12 = matrix + matrix_stride * 12;
-
-  int tile_j = tile_N;  // Tiles to process
-
-  asm volatile (
-      // Named SIMD registers according to the policy given above
-      // Registers into which to load the latter two columns of `x`
-      // NOTE: We need only load the latter three rows since we know that the
-      // first row is padded.
-      "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
-      "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
-      "x_43 .req v3\n qx_43 .req q3\n" "x_44 .req v7\n qx_44 .req q7\n"
-
-      // Registers for storing X.T x (both A and B halves)
-      "AXTx11 .req  v8\n" "BXTx13 .req  v8\n"
-      "AXTx12 .req  v9\n" "BXTx14 .req  v9\n" "qAXTx12 .req  q9\n"
-      "AXTx21 .req v10\n" "BXTx23 .req v10\n"
-      "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
-      "AXTx31 .req v12\n" "BXTx33 .req v12\n"
-      "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
-      "AXTx41 .req v14\n" "BXTx43 .req v14\n"
-      "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
-      "AXTx13 .req v16\n" "BXTx11 .req v16\n"
-      "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
-      "AXTx23 .req v18\n" "BXTx21 .req v18\n"
-      "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
-      "AXTx33 .req v20\n" "BXTx31 .req v20\n"
-      "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
-      "AXTx43 .req v22\n" "BXTx41 .req v22\n"
-      "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"
-
-      // Result register (TODO Does using more registers yield better
-      // performance)
-      "U .req v24\n qU .req q24\n"
-
-      // ----------------------------------------------------------------------
-      // Head of loop
-      //   Loads a complete 4x4 tile of x, computes X.T x, computes and stores
-      //   `U = X.T x X`. Prepares for the 'A' half of the loop.
-      //   NOTE: Since the first tile has the leftmost column padded we can
-      //   skip 4 loads and 4 calculations for the matrix X.T x X.
-
-      // Temporarily alias registers for computing the first (non-padded)
-      // column of x.
-      "x_22 .req v1\n qx_22 .req q1\n"
-      "x_32 .req v2\n qx_32 .req q2\n"
-      "x_42 .req v3\n qx_42 .req q3\n"
-
-      "ldr qx_22, [%x[inptr1]]\n"
-      "ldr qx_32, [%x[inptr2]]\n"
-      "ldr qx_42, [%x[inptr3]]\n"
-
-      "fneg BXTx12.4s,          x_32.4s\n"
-      "fadd BXTx22.4s, x_22.4s, x_32.4s\n"
-      "fsub BXTx32.4s, x_32.4s, x_22.4s\n"
-      "fsub BXTx42.4s, x_22.4s, x_42.4s\n"
-
-      ".unreq x_22\n .unreq qx_22\n"
-      ".unreq x_32\n .unreq qx_32\n"
-      ".unreq x_42\n .unreq qx_42\n"
-
-      // Load and compute latter two columns of the first tile. Progress the
-      // input pointers (by three columns so that the each points are the
-      // second column of the next tile, that is, each points at the first
-      // column which must be read for the next tile.
-      "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
-      "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
-      "ldr qx_43, [%x[inptr3], %x[colstride1]]\n"
-
-      "fneg BXTx13.4s,          x_33.4s\n"
-
-      "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
-      "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"
-
-      "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
-      "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"
-
-      "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
-      "ldr qx_44, [%x[inptr3], %x[colstride2]]\n"
-
-      "fneg BXTx14.4s,          x_34.4s\n"
-
-      "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
-      "add %x[inptr1], %x[inptr1], %x[colstride3]\n"
-
-      "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
-      "add %x[inptr2], %x[inptr2], %x[colstride3]\n"
-
-      "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
-      "add %x[inptr3], %x[inptr3], %x[colstride3]\n"
-
-      // Compute and store U for the first tile
-      // First row
-      "fneg U.4s, BXTx13.4s\n"
-      "str qU, [%x[outptr0]]\n"
-      "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
-      "str qU, [%x[outptr0], %x[mstride1]]\n"
-      "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
-      "str qU, [%x[outptr0], %x[mstride2]]\n"
-      "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
-      "str qU, [%x[outptr0], %x[mstride3]]\n"
-      "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
-      // Second row
-      "fneg U.4s, BXTx23.4s\n"
-      "str qU, [%x[outptr4]]\n"
-      "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
-      "str qU, [%x[outptr4], %x[mstride1]]\n"
-      "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
-      "str qU, [%x[outptr4], %x[mstride2]]\n"
-      "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
-      "str qU, [%x[outptr4], %x[mstride3]]\n"
-      "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
-      // Third row
-      "fneg U.4s, BXTx33.4s\n"
-      "str qU, [%x[outptr8]]\n"
-      "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
-      "str qU, [%x[outptr8], %x[mstride1]]\n"
-      "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
-      "str qU, [%x[outptr8], %x[mstride2]]\n"
-      "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
-      "str qU, [%x[outptr8], %x[mstride3]]\n"
-      "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
-      // Fourth row, simultaneously load the first column of inputs for the
-      // next tile.
-      "fneg U.4s, BXTx43.4s\n"
-      "str qU, [%x[outptr12]]\n"
-
-      "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
-      "str qU, [%x[outptr12], %x[mstride1]]\n"
-      "ldr qx_23, [%x[inptr1]]\n"
-
-      "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
-      "str qU, [%x[outptr12], %x[mstride2]]\n"
-      "ldr qx_33, [%x[inptr2]]\n"
-
-      "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
-      "str qU, [%x[outptr12], %x[mstride3]]\n"
-      "ldr qx_43, [%x[inptr3]]\n"
-
-      "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
-      // Update the loop counter, subtract two to account for both the head and
-      // the tail.
-      "subs %x[tile_j], %x[tile_j], #2\n"
-      "beq 2f\n"  // Jump to "A" tail if out of tiles
-
-      // ----------------------------------------------------------------------
-      "1:"
-        // Start part A
-        // Load last column of this tile (the first column has already been
-        // loaded) and compute latter two columns of X.T x.
-        "fneg AXTx13.4s,          x_33.4s\n"
-        "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
-        "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
-        "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
-        "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
-        "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
-        "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
-        "fneg AXTx14.4s,          x_34.4s\n"
-        "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
-        "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
-        "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
-        "fsub AXTx44.4s, x_24.4s, x_44.4s\n"
-        "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
-
-        // Compute and store U.
-        // First row
-        "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
-        // Second row
-        "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
-        // Third row
-        "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
-        // Fourth row
-        "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-
-        "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "ldr qx_23, [%x[inptr1]]\n"
-
-        "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "ldr qx_33, [%x[inptr2]]\n"
-
-        "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-        "ldr qx_43, [%x[inptr3]]\n"
-
-        "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
-        "subs %x[tile_j], %x[tile_j], #1\n"
-        "beq 3f\n"  // Jump to 'B' tail
-
-        // Start part B
-        // Load last column of this tile (the first column has already been
-        // loaded) and compute latter two columns of X.T x.
-        "fneg BXTx13.4s,          x_33.4s\n"
-        "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
-        "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
-        "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
-        "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
-        "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
-        "ldr qx_44, [%x[inptr3], %x[colstride1]]\n"
-        "fneg BXTx14.4s,          x_34.4s\n"
-        "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
-        "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
-        "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
-        "fsub BXTx44.4s, x_24.4s, x_44.4s\n"
-        "add %x[inptr3], %x[inptr3], %x[colstride2]\n"
-
-        // Compute and store U.
-        // First row
-        "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
-        // Second row
-        "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
-        // Third row
-        "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
-        // Fourth row
-        "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-
-        "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "ldr qx_23, [%x[inptr1]]\n"
-
-        "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "ldr qx_33, [%x[inptr2]]\n"
-
-        "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-        "ldr qx_43, [%x[inptr3]]\n"
-
-        "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-        "subs %x[tile_j], %x[tile_j], #1\n"
-        "bne 1b\n"  // Continue loop, otherwise flow into 'A' tail
-
-      // ----------------------------------------------------------------------
-      "2:"
-        // 'A' tail
-        // Since the final column is padding and the last-but-one column has
-        // already been loaded just compute the 3rd column of `X.T x'.
-        "fneg AXTx13.4s,          x_33.4s\n"
-        "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
-        "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
-        "fsub AXTx43.4s, x_23.4s, x_43.4s\n"
-
-        // Compute and store U. Modified to account for the final column of X.T
-        // x containing padding. Note, it is also unnecessary to update the
-        // output pointers.
-        // First row
-        "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"
-
-        // Second row
-        "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"
-
-        // Third row
-        "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"
-
-        // Fourth row
-        "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"
-
-        "b 4f\n"  // Jump to end of function
-
-      // ----------------------------------------------------------------------
-      "3:"
-        // 'B' tail
-        // Since the final column is padding and the last-but-one column has
-        // already been loaded just compute the 3rd column of `X.T x'.
-        "fneg BXTx13.4s,          x_33.4s\n"
-        "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
-        "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
-        "fsub BXTx43.4s, x_23.4s, x_43.4s\n"
-
-        // Compute and store U. Modified to account for the final column of X.T
-        // x containing padding. Note, it is also unnecessary to update the
-        // output pointers.
-        // First row
-        "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"
-
-        // Second row
-        "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"
-
-        // Third row
-        "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"
-
-        // Fourth row
-        "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"
-
-      // ----------------------------------------------------------------------
-      "4:"
-        // End of function
-
-      // Clear names
-      ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
-      ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
-      ".unreq x_43\n" ".unreq qx_43\n" ".unreq x_44\n" ".unreq qx_44\n"
-      ".unreq AXTx11\n" ".unreq BXTx13\n"
-      ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
-      ".unreq AXTx21\n" ".unreq BXTx23\n"
-      ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
-      ".unreq AXTx31\n" ".unreq BXTx33\n"
-      ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
-      ".unreq AXTx41\n" ".unreq BXTx43\n"
-      ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
-      ".unreq AXTx13\n" ".unreq BXTx11\n"
-      ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
-      ".unreq AXTx23\n" ".unreq BXTx21\n"
-      ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
-      ".unreq AXTx33\n" ".unreq BXTx31\n"
-      ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
-      ".unreq AXTx43\n" ".unreq BXTx41\n"
-      ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
-      ".unreq U\n" ".unreq qU\n"
-    : [inptr1] "+r" (inptr0),  // Offset to account for padded row
-      [inptr2] "+r" (inptr1),  // Offset to account for padded row
-      [inptr3] "+r" (inptr2),  // Offset to account for padded row
-      [outptr0] "+r" (outptr0),
-      [outptr4] "+r" (outptr4),
-      [outptr8] "+r" (outptr8),
-      [outptr12] "+r" (outptr12),
-      [tile_j] "+r" (tile_j)  // Tile counter
-    : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
-      [colstride2] "r" (2 * input_col_stride * sizeof(float)),
-      [colstride3] "r" (3 * input_col_stride * sizeof(float)),
-      [mstride1] "r" (1 * matrix_stride * sizeof(float)),
-      [mstride2] "r" (2 * matrix_stride * sizeof(float)),
-      [mstride3] "r" (3 * matrix_stride * sizeof(float)),
-      [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
-    : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
-      "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
-      "v22", "v23", "v24"
-  );
-}
-
-// Pad left, right and bottom by 1.
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInput<float>::process_tile_row<0, 1, 1, 1, 4>(
-    const int tile_N,
-    const float* const input,
-    const int input_row_stride,
-    const int input_col_stride,
-    float* const matrix,
-    const int matrix_stride,
-    const int matrix_row_stride
-) {
-  const float *inptr0 = input;
-  const float *inptr1 = input + input_row_stride;
-  const float *inptr2 = input + input_row_stride * 2;
-
-  float *outptr0 = matrix;
-  float *outptr4 = matrix + matrix_stride * 4;
-  float *outptr8 = matrix + matrix_stride * 8;
-  float *outptr12 = matrix + matrix_stride * 12;
-
-  int tile_j = tile_N;  // Tiles to process
-
-  asm volatile (
-      // Named SIMD registers according to the policy given above
-      // Registers into which to load the latter two columns of `x`
-      // NOTE: Bottom row is not required since since it is padded.
-      "x_13 .req v0\n qx_13 .req q0\n" "x_14 .req v4\n qx_14 .req q4\n"
-      "x_23 .req v1\n qx_23 .req q1\n" "x_24 .req v5\n qx_24 .req q5\n"
-      "x_33 .req v2\n qx_33 .req q2\n" "x_34 .req v6\n qx_34 .req q6\n"
-
-      // Registers for storing X.T x (both A and B halves)
-      "AXTx11 .req  v8\n" "BXTx13 .req  v8\n"
-      "AXTx12 .req  v9\n" "BXTx14 .req  v9\n" "qAXTx12 .req  q9\n"
-      "AXTx21 .req v10\n" "BXTx23 .req v10\n"
-      "AXTx22 .req v11\n" "BXTx24 .req v11\n" "qAXTx22 .req q11\n"
-      "AXTx31 .req v12\n" "BXTx33 .req v12\n"
-      "AXTx32 .req v13\n" "BXTx34 .req v13\n" "qAXTx32 .req q13\n"
-      "AXTx41 .req v14\n" "BXTx43 .req v14\n"
-      "AXTx42 .req v15\n" "BXTx44 .req v15\n" "qAXTx42 .req q15\n"
-      "AXTx13 .req v16\n" "BXTx11 .req v16\n"
-      "AXTx14 .req v17\n" "BXTx12 .req v17\n" "qBXTx12 .req q17\n"
-      "AXTx23 .req v18\n" "BXTx21 .req v18\n"
-      "AXTx24 .req v19\n" "BXTx22 .req v19\n" "qBXTx22 .req q19\n"
-      "AXTx33 .req v20\n" "BXTx31 .req v20\n"
-      "AXTx34 .req v21\n" "BXTx32 .req v21\n" "qBXTx32 .req q21\n"
-      "AXTx43 .req v22\n" "BXTx41 .req v22\n"
-      "AXTx44 .req v23\n" "BXTx42 .req v23\n" "qBXTx42 .req q23\n"
-
-      // Result register (TODO Does using more registers yield better
-      // performance)
-      "U .req v24\n qU .req q24\n"
-
-      // ----------------------------------------------------------------------
-      // Head of loop
-      //   Loads a complete 4x4 tile of x, computes X.T x, computes and stores
-      //   `U = X.T x X`. Prepares for the 'A' half of the loop.
-      //   NOTE: Since the first tile has the leftmost column padded we can
-      //   skip 4 loads and 4 calculations for the matrix X.T x X.
-
-      // Temporarily alias registers for computing the first (non-padded)
-      // column of x.
-      "x_12 .req v0\n qx_12 .req q0\n"
-      "x_22 .req v1\n qx_22 .req q1\n"
-      "x_32 .req v2\n qx_32 .req q2\n"
-
-      "ldr qx_12, [%x[inptr0]]\n"
-      "ldr qx_22, [%x[inptr1]]\n"
-      "ldr qx_32, [%x[inptr2]]\n"
-
-      "fsub BXTx12.4s,  x_12.4s, x_32.4s\n"
-      "fadd BXTx22.4s,  x_22.4s, x_32.4s\n"
-      "fsub BXTx32.4s,  x_32.4s, x_22.4s\n"
-      "mov  BXTx42.16b, x_22.16b\n"  // Probably should do better
-
-      ".unreq x_12\n .unreq qx_12\n"
-      ".unreq x_22\n .unreq qx_22\n"
-      ".unreq x_32\n .unreq qx_32\n"
-
-      // Load and compute latter two columns of the first tile. Progress the
-      // input pointers (by three columns so that the each points are the
-      // second column of the next tile, that is, each points at the first
-      // column which must be read for the next tile.
-      "ldr qx_13, [%x[inptr0], %x[colstride1]]\n"
-      "ldr qx_23, [%x[inptr1], %x[colstride1]]\n"
-      "ldr qx_33, [%x[inptr2], %x[colstride1]]\n"
-
-      "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
-      "ldr qx_14, [%x[inptr0], %x[colstride2]]\n"
-
-      "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
-      "ldr qx_24, [%x[inptr1], %x[colstride2]]\n"
-
-      "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
-      "ldr qx_34, [%x[inptr2], %x[colstride2]]\n"
-
-      "mov  BXTx43.16b, x_23.16b\n"
-      "fsub BXTx14.4s,  x_14.4s, x_34.4s\n"
-      "add %x[inptr0],  %x[inptr0], %x[colstride3]\n"
-
-      "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
-      "add %x[inptr1], %x[inptr1], %x[colstride3]\n"
-
-      "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
-      "add %x[inptr2], %x[inptr2], %x[colstride3]\n"
-
-      "mov BXTx44.16b, x_24.16b\n"
-
-      // Compute and store U for the first tile
-      // First row
-      "fneg U.4s, BXTx13.4s\n"
-      "str qU, [%x[outptr0]]\n"
-      "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
-      "str qU, [%x[outptr0], %x[mstride1]]\n"
-      "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
-      "str qU, [%x[outptr0], %x[mstride2]]\n"
-      "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
-      "str qU, [%x[outptr0], %x[mstride3]]\n"
-      "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
-      // Second row
-      "fneg U.4s, BXTx23.4s\n"
-      "str qU, [%x[outptr4]]\n"
-      "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
-      "str qU, [%x[outptr4], %x[mstride1]]\n"
-      "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
-      "str qU, [%x[outptr4], %x[mstride2]]\n"
-      "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
-      "str qU, [%x[outptr4], %x[mstride3]]\n"
-      "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
-      // Third row
-      "fneg U.4s, BXTx33.4s\n"
-      "str qU, [%x[outptr8]]\n"
-      "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
-      "str qU, [%x[outptr8], %x[mstride1]]\n"
-      "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
-      "str qU, [%x[outptr8], %x[mstride2]]\n"
-      "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
-      "str qU, [%x[outptr8], %x[mstride3]]\n"
-      "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
-      // Fourth row, simultaneously load the first column of inputs for the
-      // next tile.
-      "fneg U.4s, BXTx43.4s\n"
-      "str qU, [%x[outptr12]]\n"
-      "ldr qx_13, [%x[inptr0]]\n"
-
-      "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
-      "str qU, [%x[outptr12], %x[mstride1]]\n"
-      "ldr qx_23, [%x[inptr1]]\n"
-
-      "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
-      "str qU, [%x[outptr12], %x[mstride2]]\n"
-      "ldr qx_33, [%x[inptr2]]\n"
-
-      "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
-      "str qU, [%x[outptr12], %x[mstride3]]\n"
-
-      "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
-      // Update the loop counter, subtract two to account for both the head and
-      // the tail.
-      "subs %x[tile_j], %x[tile_j], #2\n"
-      "beq 2f\n"  // Jump to "A" tail if out of tiles
-
-      // ----------------------------------------------------------------------
-      "1:"
-        // Start part A
-        // Load last column of this tile (the first column has already been
-        // loaded) and compute latter two columns of X.T x.
-        "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
-        "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
-        "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
-        "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
-        "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
-        "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
-        "mov  AXTx43.16b, x_23.16b\n"
-
-        "fsub AXTx14.4s, x_14.4s, x_34.4s\n"
-        "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
-        "fadd AXTx24.4s, x_24.4s, x_34.4s\n"
-        "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
-        "fsub AXTx34.4s, x_34.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
-        "mov  AXTx44.16b, x_24.16b\n"
-
-        // Compute and store U.
-        // First row
-        "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, AXTx12.4s, AXTx14.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
-        // Second row
-        "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fsub U.4s, AXTx22.4s, AXTx24.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
-        // Third row
-        "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, AXTx32.4s, AXTx34.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
-        // Fourth row
-        "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "ldr qx_13, [%x[inptr0]]\n"
-
-        "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "ldr qx_23, [%x[inptr1]]\n"
-
-        "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "ldr qx_33, [%x[inptr2]]\n"
-
-        "fsub U.4s, AXTx42.4s, AXTx44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-
-        "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-
-        "subs %x[tile_j], %x[tile_j], #1\n"
-        "beq 3f\n"  // Jump to 'B' tail
-
-        // Start part B
-        // Load last column of this tile (the first column has already been
-        // loaded) and compute latter two columns of X.T x.
-        "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
-        "ldr qx_14, [%x[inptr0], %x[colstride1]]\n"
-        "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
-        "ldr qx_24, [%x[inptr1], %x[colstride1]]\n"
-        "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
-        "ldr qx_34, [%x[inptr2], %x[colstride1]]\n"
-        "mov BXTx43.16b, x_23.16b\n"
-
-        "fsub BXTx14.4s, x_14.4s, x_34.4s\n"
-        "add %x[inptr0], %x[inptr0], %x[colstride2]\n"
-        "fadd BXTx24.4s, x_24.4s, x_34.4s\n"
-        "add %x[inptr1], %x[inptr1], %x[colstride2]\n"
-        "fsub BXTx34.4s, x_34.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], %x[colstride2]\n"
-        "mov BXTx44.16b, x_24.16b\n"
-
-        // Compute and store U.
-        // First row
-        "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, BXTx12.4s, BXTx14.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], %x[matrix_row_stride]\n"
-
-        // Second row
-        "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fsub U.4s, BXTx22.4s, BXTx24.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], %x[matrix_row_stride]\n"
-
-        // Third row
-        "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, BXTx32.4s, BXTx34.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], %x[matrix_row_stride]\n"
-
-        // Fourth row
-        "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "ldr qx_13, [%x[inptr0]]\n"
-
-        "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "ldr qx_23, [%x[inptr1]]\n"
-
-        "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "ldr qx_33, [%x[inptr2]]\n"
-
-        "fsub U.4s, BXTx42.4s, BXTx44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-
-        "add %x[outptr12], %x[outptr12], %x[matrix_row_stride]\n"
-        "subs %x[tile_j], %x[tile_j], #1\n"
-        "bne 1b\n"  // Continue loop, otherwise flow into 'A' tail
-
-      // ----------------------------------------------------------------------
-      "2:"
-        // 'A' tail
-        // Since the final column is padding and the last-but-one column has
-        // already been loaded just compute the 3rd column of `X.T x'.
-        "fsub AXTx13.4s, x_13.4s, x_33.4s\n"
-        "fadd AXTx23.4s, x_23.4s, x_33.4s\n"
-        "fsub AXTx33.4s, x_33.4s, x_23.4s\n"
-        "mov  AXTx43.16b, x_23.16b\n"
-
-        // Compute and store U. Modified to account for the final column of X.T
-        // x containing padding. Note, it is also unnecessary to update the
-        // output pointers.
-        // First row
-        "fsub U.4s, AXTx11.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, AXTx12.4s, AXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, AXTx13.4s, AXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "str qAXTx12, [%x[outptr0], %x[mstride3]]\n"
-
-        // Second row
-        "fsub U.4s, AXTx21.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, AXTx22.4s, AXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, AXTx23.4s, AXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "str qAXTx22, [%x[outptr4], %x[mstride3]]\n"
-
-        // Third row
-        "fsub U.4s, AXTx31.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, AXTx32.4s, AXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, AXTx33.4s, AXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "str qAXTx32, [%x[outptr8], %x[mstride3]]\n"
-
-        // Fourth row
-        "fsub U.4s, AXTx41.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fadd U.4s, AXTx42.4s, AXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, AXTx43.4s, AXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "str qAXTx42, [%x[outptr12], %x[mstride3]]\n"
-
-        "b 4f\n"  // Jump to end of function
-
-      // ----------------------------------------------------------------------
-      "3:"
-        // 'B' tail
-        // Since the final column is padding and the last-but-one column has
-        // already been loaded just compute the 3rd column of `X.T x'.
-        "fsub BXTx13.4s, x_13.4s, x_33.4s\n"
-        "fadd BXTx23.4s, x_23.4s, x_33.4s\n"
-        "fsub BXTx33.4s, x_33.4s, x_23.4s\n"
-        "mov  BXTx43.16b, x_23.16b\n"
-
-        // Compute and store U. Modified to account for the final column of X.T
-        // x containing padding. Note, it is also unnecessary to update the
-        // output pointers.
-        // First row
-        "fsub U.4s, BXTx11.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fadd U.4s, BXTx12.4s, BXTx13.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, BXTx13.4s, BXTx12.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "str qBXTx12, [%x[outptr0], %x[mstride3]]\n"
-
-        // Second row
-        "fsub U.4s, BXTx21.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, BXTx22.4s, BXTx23.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fsub U.4s, BXTx23.4s, BXTx22.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "str qBXTx22, [%x[outptr4], %x[mstride3]]\n"
-
-        // Third row
-        "fsub U.4s, BXTx31.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fadd U.4s, BXTx32.4s, BXTx33.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, BXTx33.4s, BXTx32.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "str qBXTx32, [%x[outptr8], %x[mstride3]]\n"
-
-        // Fourth row
-        "fsub U.4s, BXTx41.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fadd U.4s, BXTx42.4s, BXTx43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, BXTx43.4s, BXTx42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "str qBXTx42, [%x[outptr12], %x[mstride3]]\n"
-
-      // ----------------------------------------------------------------------
-      "4:"
-        // End of function
-
-      // Clear names
-      ".unreq x_13\n" ".unreq qx_13\n" ".unreq x_14\n" ".unreq qx_14\n"
-      ".unreq x_23\n" ".unreq qx_23\n" ".unreq x_24\n" ".unreq qx_24\n"
-      ".unreq x_33\n" ".unreq qx_33\n" ".unreq x_34\n" ".unreq qx_34\n"
-      ".unreq AXTx11\n" ".unreq BXTx13\n"
-      ".unreq AXTx12\n" ".unreq BXTx14\n" ".unreq qAXTx12\n"
-      ".unreq AXTx21\n" ".unreq BXTx23\n"
-      ".unreq AXTx22\n" ".unreq BXTx24\n" ".unreq qAXTx22\n"
-      ".unreq AXTx31\n" ".unreq BXTx33\n"
-      ".unreq AXTx32\n" ".unreq BXTx34\n" ".unreq qAXTx32\n"
-      ".unreq AXTx41\n" ".unreq BXTx43\n"
-      ".unreq AXTx42\n" ".unreq BXTx44\n" ".unreq qAXTx42\n"
-      ".unreq AXTx13\n" ".unreq BXTx11\n"
-      ".unreq AXTx14\n" ".unreq BXTx12\n" ".unreq qBXTx12\n"
-      ".unreq AXTx23\n" ".unreq BXTx21\n"
-      ".unreq AXTx24\n" ".unreq BXTx22\n" ".unreq qBXTx22\n"
-      ".unreq AXTx33\n" ".unreq BXTx31\n"
-      ".unreq AXTx34\n" ".unreq BXTx32\n" ".unreq qBXTx32\n"
-      ".unreq AXTx43\n" ".unreq BXTx41\n"
-      ".unreq AXTx44\n" ".unreq BXTx42\n" ".unreq qBXTx42\n"
-      ".unreq U\n" ".unreq qU\n"
-    : [inptr0] "+r" (inptr0),
-      [inptr1] "+r" (inptr1),
-      [inptr2] "+r" (inptr2),
-      [outptr0] "+r" (outptr0),
-      [outptr4] "+r" (outptr4),
-      [outptr8] "+r" (outptr8),
-      [outptr12] "+r" (outptr12),
-      [tile_j] "+r" (tile_j)  // Tile counter
-    : [colstride1] "r" (1 * input_col_stride * sizeof(float)),
-      [colstride2] "r" (2 * input_col_stride * sizeof(float)),
-      [colstride3] "r" (3 * input_col_stride * sizeof(float)),
-      [mstride1] "r" (1 * matrix_stride * sizeof(float)),
-      [mstride2] "r" (2 * matrix_stride * sizeof(float)),
-      [mstride3] "r" (3 * matrix_stride * sizeof(float)),
-      [matrix_row_stride] "r" (matrix_row_stride * sizeof(float))
-    : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
-      "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
-      "v22", "v23", "v24"
-  );
-}
-}
-#endif  // __aarch64__
diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp
deleted file mode 100644
index ad1ad55..0000000
--- a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3/a64_float_channelwise.hpp
+++ /dev/null
@@ -1,961 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-#include "../input_2x2_3x3.hpp"
-
-#ifdef __aarch64__
-
-namespace winograd {
-
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 0, 0, 4>(
-    int &n_channels,  // Number of channels in the tile
-    const float* &inptr0,
-    const int input_row_stride,
-    const int input_col_stride,
-    float* &outptr0,
-    const int matrix_stride
-) {
-  // We use 4 pointers to point to the starting position on each row and use
-  // three offsets to extract elements from each of the other 3 columns.
-  auto inptr1 = inptr0 + 1*input_row_stride;
-  auto inptr2 = inptr0 + 2*input_row_stride;
-  auto inptr3 = inptr0 + 3*input_row_stride;
-
-  // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
-  // offsets to access the intermediate matrices.
-  auto outptr1 = outptr0 + matrix_stride * 4;
-  auto outptr2 = outptr0 + matrix_stride * 8;
-  auto outptr3 = outptr0 + matrix_stride * 12;
-
-  for (; n_channels > 3; n_channels -= 4) {
-    asm volatile (
-        "X_11 .req  v0\n"  "qX_11 .req  q0\n"
-        "X_12 .req  v1\n"  "qX_12 .req  q1\n"
-        "X_13 .req  v2\n"  "qX_13 .req  q2\n"
-        "X_14 .req  v3\n"  "qX_14 .req  q3\n"
-        "X_21 .req  v4\n"  "qX_21 .req  q4\n"
-        "X_22 .req  v5\n"  "qX_22 .req  q5\n"
-        "X_23 .req  v6\n"  "qX_23 .req  q6\n"
-        "X_24 .req  v7\n"  "qX_24 .req  q7\n"
-        "X_31 .req  v8\n"  "qX_31 .req  q8\n"
-        "X_32 .req  v9\n"  "qX_32 .req  q9\n"
-        "X_33 .req v10\n"  "qX_33 .req q10\n"
-        "X_34 .req v11\n"  "qX_34 .req q11\n"
-        "X_41 .req v12\n"  "qX_41 .req q12\n"
-        "X_42 .req v13\n"  "qX_42 .req q13\n"
-        "X_43 .req v14\n"  "qX_43 .req q14\n"
-        "X_44 .req v15\n"  "qX_44 .req q15\n"
-        "xX_11 .req v16\n"
-        "xX_12 .req v17\n"
-        "xX_13 .req v18\n"
-        "xX_14 .req v19\n"
-        "xX_21 .req v20\n"
-        "xX_22 .req v21\n"
-        "xX_23 .req v22\n"
-        "xX_24 .req v23\n"
-        "xX_31 .req v24\n"
-        "xX_32 .req v25\n"
-        "xX_33 .req v26\n"
-        "xX_34 .req v27\n"
-        "xX_41 .req v28\n"
-        "xX_42 .req v29\n"
-        "xX_43 .req v30\n"
-        "xX_44 .req v31\n"
-        " U .req v0\n"
-        "qU .req q0\n"
-
-        // Load the tile, and compute compute the matrix xX
-        "ldr qX_11, [%x[inptr0]]\n"
-        "ldr qX_12, [%x[inptr0], %x[colstride1]]\n"
-        "ldr qX_13, [%x[inptr0], %x[colstride2]]\n"
-        "ldr qX_14, [%x[inptr0], %x[colstride3]]\n"
-        "add %x[inptr0], %x[inptr0], #0x10\n"
-
-        "ldr qX_21, [%x[inptr1]]\n"
-        "fsub xX_11.4s, x_11.4s, x_13.4s\n"
-        "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
-        "fadd xX_12.4s, x_12.4s, x_13.4s\n"
-        "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
-        "fsub xX_13.4s, x_13.4s, x_12.4s\n"
-        "ldr qX_24, [%x[inptr1], %x[colstride3]]\n"
-        "fsub xX_14.4s, x_12.4s, x_14.4s\n"
-        "add %x[inptr1], %x[inptr1], #0x10\n"
-
-        "ldr qX_31, [%x[inptr2]]\n"
-        "fsub xX_21.4s, x_21.4s, x_23.4s\n"
-        "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
-        "fadd xX_22.4s, x_22.4s, x_23.4s\n"
-        "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
-        "fsub xX_23.4s, x_23.4s, x_22.4s\n"
-        "ldr qX_34, [%x[inptr2], %x[colstride3]]\n"
-        "fsub xX_24.4s, x_22.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], #0x10\n"
-
-        "ldr qX_41, [%x[inptr3]]\n"
-        "fsub xX_31.4s, x_31.4s, x_33.4s\n"
-        "ldr qX_42, [%x[inptr3], %x[colstride1]]\n"
-        "fadd xX_32.4s, x_32.4s, x_33.4s\n"
-        "ldr qX_43, [%x[inptr3], %x[colstride2]]\n"
-        "fsub xX_33.4s, x_33.4s, x_32.4s\n"
-        "ldr qX_44, [%x[inptr3], %x[colstride3]]\n"
-        "fsub xX_34.4s, x_32.4s, x_34.4s\n"
-        "add %x[inptr3], %x[inptr3], #0x10\n"
-
-        // Complete computing xX while beginning to compute and store
-        // $U = X.T x X$
-
-        "fsub xX_41.4s, x_41.4s, x_43.4s\n"
-
-        "fsub U.4s, xX_11.4s, xX_31.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fsub U.4s, xX_12.4s, xX_32.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, xX_13.4s, xX_33.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, xX_14.4s, xX_34.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], #0x10\n"
-
-        "fadd xX_42.4s, x_42.4s, x_43.4s\n"
-
-        "fadd U.4s, xX_21.4s, xX_31.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, xX_22.4s, xX_32.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fadd U.4s, xX_23.4s, xX_33.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fadd U.4s, xX_24.4s, xX_34.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], #0x10\n"
-
-        "fsub xX_43.4s, x_43.4s, x_42.4s\n"
-
-        "fsub U.4s, xX_31.4s, xX_21.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fsub U.4s, xX_32.4s, xX_22.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, xX_33.4s, xX_23.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, xX_34.4s, xX_24.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], #0x10\n"
-
-        "fsub xX_44.4s, x_42.4s, x_44.4s\n"
-
-        "fsub U.4s, xX_21.4s, xX_41.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fsub U.4s, xX_22.4s, xX_42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, xX_23.4s, xX_43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "fsub U.4s, xX_24.4s, xX_44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-        "add %x[outptr12], %x[outptr12], #0x10\n"
-
-        ".unreq qU\n"
-        ".unreq U\n"
-        ".unreq X_11\n"  ".unreq qX_11\n"
-        ".unreq X_12\n"  ".unreq qX_12\n"
-        ".unreq X_13\n"  ".unreq qX_13\n"
-        ".unreq X_14\n"  ".unreq qX_14\n"
-        ".unreq X_21\n"  ".unreq qX_21\n"
-        ".unreq X_22\n"  ".unreq qX_22\n"
-        ".unreq X_23\n"  ".unreq qX_23\n"
-        ".unreq X_24\n"  ".unreq qX_24\n"
-        ".unreq X_31\n"  ".unreq qX_31\n"
-        ".unreq X_32\n"  ".unreq qX_32\n"
-        ".unreq X_33\n"  ".unreq qX_33\n"
-        ".unreq X_34\n"  ".unreq qX_34\n"
-        ".unreq X_41\n"  ".unreq qX_41\n"
-        ".unreq X_42\n"  ".unreq qX_42\n"
-        ".unreq X_43\n"  ".unreq qX_43\n"
-        ".unreq X_44\n"  ".unreq qX_44\n"
-        ".unreq xX_11\n"
-        ".unreq xX_12\n"
-        ".unreq xX_13\n"
-        ".unreq xX_14\n"
-        ".unreq xX_21\n"
-        ".unreq xX_22\n"
-        ".unreq xX_23\n"
-        ".unreq xX_24\n"
-        ".unreq xX_31\n"
-        ".unreq xX_32\n"
-        ".unreq xX_33\n"
-        ".unreq xX_34\n"
-        ".unreq xX_41\n"
-        ".unreq xX_42\n"
-        ".unreq xX_43\n"
-        ".unreq xX_44\n"
-
-        : [inptr0] "+r" (inptr0),
-          [inptr1] "+r" (inptr1),
-          [inptr2] "+r" (inptr2),
-          [inptr3] "+r" (inptr3),
-          [outptr0] "+r" (outptr0),
-          [outptr4] "+r" (outptr1),
-          [outptr8] "+r" (outptr2),
-          [outptr12] "+r" (outptr3)
-        : [colstride1] "r" (input_col_stride * sizeof(float)),
-          [colstride2] "r" (input_col_stride * sizeof(float) * 2),
-          [colstride3] "r" (input_col_stride * sizeof(float) * 3),
-          [mstride1] "r" (matrix_stride * sizeof(float)),
-          [mstride2] "r" (matrix_stride * sizeof(float) * 2),
-          [mstride3] "r" (matrix_stride * sizeof(float) * 3)
-        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
-          "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
-          "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
-          "v30", "v31"
-    );
-  }
-}
-
-// Pad top by 1
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<1, 0, 0, 0, 4>(
-    int &n_channels,  // Number of channels in the tile
-    const float* &inptr0,
-    const int input_row_stride,
-    const int input_col_stride,
-    float* &outptr0,
-    const int matrix_stride
-) {
-  // We use 4 pointers to point to the starting position on each row and use
-  // three offsets to extract elements from each of the other 3 columns.
-  auto inptr1 = inptr0 + 0*input_row_stride;
-  auto inptr2 = inptr0 + 1*input_row_stride;
-
-  // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
-  // offsets to access the intermediate matrices.
-  auto outptr1 = outptr0 + matrix_stride * 4;
-  auto outptr2 = outptr0 + matrix_stride * 8;
-  auto outptr3 = outptr0 + matrix_stride * 12;
-
-  for (; n_channels > 3; n_channels -= 4) {
-    asm volatile (
-        "X_21 .req  v4\n"  "qX_21 .req  q4\n"
-        "X_22 .req  v5\n"  "qX_22 .req  q5\n"
-        "X_23 .req  v6\n"  "qX_23 .req  q6\n"
-        "X_24 .req  v7\n"  "qX_24 .req  q7\n"
-        "X_31 .req  v8\n"  "qX_31 .req  q8\n"
-        "X_32 .req  v9\n"  "qX_32 .req  q9\n"
-        "X_33 .req v10\n"  "qX_33 .req q10\n"
-        "X_34 .req v11\n"  "qX_34 .req q11\n"
-        "X_41 .req v12\n"  "qX_41 .req q12\n"
-        "X_42 .req v13\n"  "qX_42 .req q13\n"
-        "X_43 .req v14\n"  "qX_43 .req q14\n"
-        "X_44 .req v15\n"  "qX_44 .req q15\n"
-        "xX_21 .req v20\n"
-        "xX_22 .req v21\n"
-        "xX_23 .req v22\n"
-        "xX_24 .req v23\n"
-        "xX_31 .req v24\n"
-        "xX_32 .req v25\n"
-        "xX_33 .req v26\n"
-        "xX_34 .req v27\n"
-        "xX_41 .req v28\n"
-        "xX_42 .req v29\n"
-        "xX_43 .req v30\n"
-        "xX_44 .req v31\n"
-        " U .req v0\n"
-        "qU .req q0\n"
-
-        // Load the tile, and compute compute the matrix xX
-        "ldr qX_21, [%x[inptr1]]\n"
-        "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
-        "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
-        "ldr qX_24, [%x[inptr1], %x[colstride3]]\n"
-        "add %x[inptr1], %x[inptr1], #0x10\n"
-
-        "ldr qX_31, [%x[inptr2]]\n"
-        "fsub xX_21.4s, x_21.4s, x_23.4s\n"
-        "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
-        "fadd xX_22.4s, x_22.4s, x_23.4s\n"
-        "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
-        "fsub xX_23.4s, x_23.4s, x_22.4s\n"
-        "ldr qX_34, [%x[inptr2], %x[colstride3]]\n"
-        "fsub xX_24.4s, x_22.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], #0x10\n"
-
-        "ldr qX_41, [%x[inptr3]]\n"
-        "fsub xX_31.4s, x_31.4s, x_33.4s\n"
-        "ldr qX_42, [%x[inptr3], %x[colstride1]]\n"
-        "fadd xX_32.4s, x_32.4s, x_33.4s\n"
-        "ldr qX_43, [%x[inptr3], %x[colstride2]]\n"
-        "fsub xX_33.4s, x_33.4s, x_32.4s\n"
-        "ldr qX_44, [%x[inptr3], %x[colstride3]]\n"
-        "fsub xX_34.4s, x_32.4s, x_34.4s\n"
-        "add %x[inptr3], %x[inptr3], #0x10\n"
-
-        // Complete computing xX while beginning to compute and store
-        // $U = X.T x X$
-
-        "fsub xX_41.4s, x_41.4s, x_43.4s\n"
-
-        "fneg U.4s, xX_31.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fneg U.4s, xX_32.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fneg U.4s, xX_33.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fneg U.4s, xX_34.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], #0x10\n"
-
-        "fadd xX_42.4s, x_42.4s, x_43.4s\n"
-
-        "fadd U.4s, xX_21.4s, xX_31.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, xX_22.4s, xX_32.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fadd U.4s, xX_23.4s, xX_33.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fadd U.4s, xX_24.4s, xX_34.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], #0x10\n"
-
-        "fsub xX_43.4s, x_43.4s, x_42.4s\n"
-
-        "fsub U.4s, xX_31.4s, xX_21.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fsub U.4s, xX_32.4s, xX_22.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, xX_33.4s, xX_23.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, xX_34.4s, xX_24.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], #0x10\n"
-
-        "fsub xX_44.4s, x_42.4s, x_44.4s\n"
-
-        "fsub U.4s, xX_21.4s, xX_41.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fsub U.4s, xX_22.4s, xX_42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, xX_23.4s, xX_43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "fsub U.4s, xX_24.4s, xX_44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-        "add %x[outptr12], %x[outptr12], #0x10\n"
-
-        ".unreq qU\n"
-        ".unreq U\n"
-        ".unreq X_21\n"  ".unreq qX_21\n"
-        ".unreq X_22\n"  ".unreq qX_22\n"
-        ".unreq X_23\n"  ".unreq qX_23\n"
-        ".unreq X_24\n"  ".unreq qX_24\n"
-        ".unreq X_31\n"  ".unreq qX_31\n"
-        ".unreq X_32\n"  ".unreq qX_32\n"
-        ".unreq X_33\n"  ".unreq qX_33\n"
-        ".unreq X_34\n"  ".unreq qX_34\n"
-        ".unreq X_41\n"  ".unreq qX_41\n"
-        ".unreq X_42\n"  ".unreq qX_42\n"
-        ".unreq X_43\n"  ".unreq qX_43\n"
-        ".unreq X_44\n"  ".unreq qX_44\n"
-        ".unreq xX_21\n"
-        ".unreq xX_22\n"
-        ".unreq xX_23\n"
-        ".unreq xX_24\n"
-        ".unreq xX_31\n"
-        ".unreq xX_32\n"
-        ".unreq xX_33\n"
-        ".unreq xX_34\n"
-        ".unreq xX_41\n"
-        ".unreq xX_42\n"
-        ".unreq xX_43\n"
-        ".unreq xX_44\n"
-
-        : [inptr1] "+r" (inptr0),  // Offset for missing row
-          [inptr2] "+r" (inptr1),  // Offset for missing row
-          [inptr3] "+r" (inptr2),  // Offset for missing row
-          [outptr0] "+r" (outptr0),
-          [outptr4] "+r" (outptr1),
-          [outptr8] "+r" (outptr2),
-          [outptr12] "+r" (outptr3)
-        : [colstride1] "r" (input_col_stride * sizeof(float)),
-          [colstride2] "r" (input_col_stride * sizeof(float) * 2),
-          [colstride3] "r" (input_col_stride * sizeof(float) * 3),
-          [mstride1] "r" (matrix_stride * sizeof(float)),
-          [mstride2] "r" (matrix_stride * sizeof(float) * 2),
-          [mstride3] "r" (matrix_stride * sizeof(float) * 3)
-        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
-          "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
-          "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
-          "v30", "v31"
-    );
-  }
-}
-
-// Pad left by 1
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 1, 0, 0, 4>(
-    int &n_channels,  // Number of channels in the tile
-    const float* &inptr0,
-    const int input_row_stride,
-    const int input_col_stride,
-    float* &outptr0,
-    const int matrix_stride
-) {
-  // We use 4 pointers to point to the starting position on each row and use
-  // three offsets to extract elements from each of the other 3 columns.
-  auto inptr1 = inptr0 + 1*input_row_stride;
-  auto inptr2 = inptr0 + 2*input_row_stride;
-  auto inptr3 = inptr0 + 3*input_row_stride;
-
-  // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
-  // offsets to access the intermediate matrices.
-  auto outptr1 = outptr0 + matrix_stride * 4;
-  auto outptr2 = outptr0 + matrix_stride * 8;
-  auto outptr3 = outptr0 + matrix_stride * 12;
-
-  for (; n_channels > 3; n_channels -= 4) {
-    asm volatile (
-        "X_12 .req  v1\n"  "qX_12 .req  q1\n"
-        "X_13 .req  v2\n"  "qX_13 .req  q2\n"
-        "X_14 .req  v3\n"  "qX_14 .req  q3\n"
-        "X_22 .req  v5\n"  "qX_22 .req  q5\n"
-        "X_23 .req  v6\n"  "qX_23 .req  q6\n"
-        "X_24 .req  v7\n"  "qX_24 .req  q7\n"
-        "X_32 .req  v9\n"  "qX_32 .req  q9\n"
-        "X_33 .req v10\n"  "qX_33 .req q10\n"
-        "X_34 .req v11\n"  "qX_34 .req q11\n"
-        "X_42 .req v13\n"  "qX_42 .req q13\n"
-        "X_43 .req v14\n"  "qX_43 .req q14\n"
-        "X_44 .req v15\n"  "qX_44 .req q15\n"
-        "xX_11 .req v16\n"
-        "xX_12 .req v17\n"
-        "xX_13 .req v18\n"
-        "xX_14 .req v19\n"
-        "xX_21 .req v20\n"
-        "xX_22 .req v21\n"
-        "xX_23 .req v22\n"
-        "xX_24 .req v23\n"
-        "xX_31 .req v24\n"
-        "xX_32 .req v25\n"
-        "xX_33 .req v26\n"
-        "xX_34 .req v27\n"
-        "xX_41 .req v28\n"
-        "xX_42 .req v29\n"
-        "xX_43 .req v30\n"
-        "xX_44 .req v31\n"
-        " U .req v0\n"
-        "qU .req q0\n"
-
-        // Load the tile, and compute compute the matrix xX
-        "ldr qX_12, [%x[inptr0]]\n"
-        "ldr qX_13, [%x[inptr0], %x[colstride1]]\n"
-        "ldr qX_14, [%x[inptr0], %x[colstride2]]\n"
-        "add %x[inptr0], %x[inptr0], #0x10\n"
-
-        "fneg xX_11.4s, x_13.4s\n"
-        "ldr qX_22, [%x[inptr1]]\n"
-        "fadd xX_12.4s, x_12.4s, x_13.4s\n"
-        "ldr qX_23, [%x[inptr1], %x[colstride1]]\n"
-        "fsub xX_13.4s, x_13.4s, x_12.4s\n"
-        "ldr qX_24, [%x[inptr1], %x[colstride2]]\n"
-        "fsub xX_14.4s, x_12.4s, x_14.4s\n"
-        "add %x[inptr1], %x[inptr1], #0x10\n"
-
-        "fneg xX_21.4s, x_23.4s\n"
-        "ldr qX_32, [%x[inptr2]]\n"
-        "fadd xX_22.4s, x_22.4s, x_23.4s\n"
-        "ldr qX_33, [%x[inptr2], %x[colstride1]]\n"
-        "fsub xX_23.4s, x_23.4s, x_22.4s\n"
-        "ldr qX_34, [%x[inptr2], %x[colstride2]]\n"
-        "fsub xX_24.4s, x_22.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], #0x10\n"
-
-        "fneg xX_31.4s, x_33.4s\n"
-        "ldr qX_42, [%x[inptr3]]\n"
-        "fadd xX_32.4s, x_32.4s, x_33.4s\n"
-        "ldr qX_43, [%x[inptr3], %x[colstride1]]\n"
-        "fsub xX_33.4s, x_33.4s, x_32.4s\n"
-        "ldr qX_44, [%x[inptr3], %x[colstride2]]\n"
-        "fsub xX_34.4s, x_32.4s, x_34.4s\n"
-        "add %x[inptr3], %x[inptr3], #0x10\n"
-
-        // Complete computing xX while beginning to compute and store
-        // $U = X.T x X$
-
-        "fneg xX_41.4s, x_43.4s\n"
-
-        "fsub U.4s, xX_11.4s, xX_31.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fsub U.4s, xX_12.4s, xX_32.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, xX_13.4s, xX_33.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, xX_14.4s, xX_34.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], #0x10\n"
-
-        "fadd xX_42.4s, x_42.4s, x_43.4s\n"
-
-        "fadd U.4s, xX_21.4s, xX_31.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, xX_22.4s, xX_32.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fadd U.4s, xX_23.4s, xX_33.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fadd U.4s, xX_24.4s, xX_34.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], #0x10\n"
-
-        "fsub xX_43.4s, x_43.4s, x_42.4s\n"
-
-        "fsub U.4s, xX_31.4s, xX_21.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fsub U.4s, xX_32.4s, xX_22.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, xX_33.4s, xX_23.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, xX_34.4s, xX_24.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], #0x10\n"
-
-        "fsub xX_44.4s, x_42.4s, x_44.4s\n"
-
-        "fsub U.4s, xX_21.4s, xX_41.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fsub U.4s, xX_22.4s, xX_42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, xX_23.4s, xX_43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "fsub U.4s, xX_24.4s, xX_44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-        "add %x[outptr12], %x[outptr12], #0x10\n"
-
-        ".unreq X_12\n"  ".unreq qX_12\n"
-        ".unreq X_13\n"  ".unreq qX_13\n"
-        ".unreq X_14\n"  ".unreq qX_14\n"
-        ".unreq X_22\n"  ".unreq qX_22\n"
-        ".unreq X_23\n"  ".unreq qX_23\n"
-        ".unreq X_24\n"  ".unreq qX_24\n"
-        ".unreq X_32\n"  ".unreq qX_32\n"
-        ".unreq X_33\n"  ".unreq qX_33\n"
-        ".unreq X_34\n"  ".unreq qX_34\n"
-        ".unreq X_42\n"  ".unreq qX_42\n"
-        ".unreq X_43\n"  ".unreq qX_43\n"
-        ".unreq X_44\n"  ".unreq qX_44\n"
-        ".unreq xX_11\n"
-        ".unreq xX_12\n"
-        ".unreq xX_13\n"
-        ".unreq xX_14\n"
-        ".unreq xX_21\n"
-        ".unreq xX_22\n"
-        ".unreq xX_23\n"
-        ".unreq xX_24\n"
-        ".unreq xX_31\n"
-        ".unreq xX_32\n"
-        ".unreq xX_33\n"
-        ".unreq xX_34\n"
-        ".unreq xX_41\n"
-        ".unreq xX_42\n"
-        ".unreq xX_43\n"
-        ".unreq xX_44\n"
-        ".unreq U\n"
-        ".unreq qU\n"
-
-        : [inptr0] "+r" (inptr0),
-          [inptr1] "+r" (inptr1),
-          [inptr2] "+r" (inptr2),
-          [inptr3] "+r" (inptr3),
-          [outptr0] "+r" (outptr0),
-          [outptr4] "+r" (outptr1),
-          [outptr8] "+r" (outptr2),
-          [outptr12] "+r" (outptr3)
-        : [colstride1] "r" (input_col_stride * sizeof(float)),
-          [colstride2] "r" (input_col_stride * sizeof(float) * 2),
-          [mstride1] "r" (matrix_stride * sizeof(float)),
-          [mstride2] "r" (matrix_stride * sizeof(float) * 2),
-          [mstride3] "r" (matrix_stride * sizeof(float) * 3)
-        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
-          "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
-          "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
-          "v30", "v31"
-    );
-  }
-}
-
-// Pad bottom by 1
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 1, 0, 4>(
-    int &n_channels,  // Number of channels in the tile
-    const float* &inptr0,
-    const int input_row_stride,
-    const int input_col_stride,
-    float* &outptr0,
-    const int matrix_stride
-) {
-  // We use 4 pointers to point to the starting position on each row and use
-  // three offsets to extract elements from each of the other 3 columns.
-  auto inptr1 = inptr0 + 1*input_row_stride;
-  auto inptr2 = inptr0 + 2*input_row_stride;
-
-  // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
-  // offsets to access the intermediate matrices.
-  auto outptr1 = outptr0 + matrix_stride * 4;
-  auto outptr2 = outptr0 + matrix_stride * 8;
-  auto outptr3 = outptr0 + matrix_stride * 12;
-
-  for (; n_channels > 3; n_channels -= 4) {
-    asm volatile (
-        "X_11 .req  v0\n"  "qX_11 .req  q0\n"
-        "X_12 .req  v1\n"  "qX_12 .req  q1\n"
-        "X_13 .req  v2\n"  "qX_13 .req  q2\n"
-        "X_14 .req  v3\n"  "qX_14 .req  q3\n"
-        "X_21 .req  v4\n"  "qX_21 .req  q4\n"
-        "X_22 .req  v5\n"  "qX_22 .req  q5\n"
-        "X_23 .req  v6\n"  "qX_23 .req  q6\n"
-        "X_24 .req  v7\n"  "qX_24 .req  q7\n"
-        "X_31 .req  v8\n"  "qX_31 .req  q8\n"
-        "X_32 .req  v9\n"  "qX_32 .req  q9\n"
-        "X_33 .req v10\n"  "qX_33 .req q10\n"
-        "X_34 .req v11\n"  "qX_34 .req q11\n"
-        "xX_11 .req v16\n"
-        "xX_12 .req v17\n"
-        "xX_13 .req v18\n"
-        "xX_14 .req v19\n"
-        "xX_21 .req v20\n" "qxX_21 .req q20\n"
-        "xX_22 .req v21\n" "qxX_22 .req q21\n"
-        "xX_23 .req v22\n" "qxX_23 .req q22\n"
-        "xX_24 .req v23\n" "qxX_24 .req q23\n"
-        "xX_31 .req v24\n"
-        "xX_32 .req v25\n"
-        "xX_33 .req v26\n"
-        "xX_34 .req v27\n"
-        " U .req v0\n"
-        "qU .req q0\n"
-
-        // Load the tile, and compute compute the matrix xX
-        "ldr qX_11, [%x[inptr0]]\n"
-        "ldr qX_12, [%x[inptr0], %x[colstride1]]\n"
-        "ldr qX_13, [%x[inptr0], %x[colstride2]]\n"
-        "ldr qX_14, [%x[inptr0], %x[colstride3]]\n"
-        "add %x[inptr0], %x[inptr0], #0x10\n"
-
-        "ldr qX_21, [%x[inptr1]]\n"
-        "fsub xX_11.4s, x_11.4s, x_13.4s\n"
-        "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
-        "fadd xX_12.4s, x_12.4s, x_13.4s\n"
-        "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
-        "fsub xX_13.4s, x_13.4s, x_12.4s\n"
-        "ldr qX_24, [%x[inptr1], %x[colstride3]]\n"
-        "fsub xX_14.4s, x_12.4s, x_14.4s\n"
-        "add %x[inptr1], %x[inptr1], #0x10\n"
-
-        "ldr qX_31, [%x[inptr2]]\n"
-        "fsub xX_21.4s, x_21.4s, x_23.4s\n"
-        "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
-        "fadd xX_22.4s, x_22.4s, x_23.4s\n"
-        "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
-        "fsub xX_23.4s, x_23.4s, x_22.4s\n"
-        "ldr qX_34, [%x[inptr2], %x[colstride3]]\n"
-        "fsub xX_24.4s, x_22.4s, x_24.4s\n"
-        "add %x[inptr2], %x[inptr2], #0x10\n"
-
-        "fsub xX_31.4s, x_31.4s, x_33.4s\n"
-        "fadd xX_32.4s, x_32.4s, x_33.4s\n"
-        "fsub xX_33.4s, x_33.4s, x_32.4s\n"
-        "fsub xX_34.4s, x_32.4s, x_34.4s\n"
-
-        // Complete computing xX while beginning to compute and store
-        // $U = X.T x X$
-
-        "fsub U.4s, xX_11.4s, xX_31.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fsub U.4s, xX_12.4s, xX_32.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, xX_13.4s, xX_33.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, xX_14.4s, xX_34.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], #0x10\n"
-
-        "fadd U.4s, xX_21.4s, xX_31.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, xX_22.4s, xX_32.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fadd U.4s, xX_23.4s, xX_33.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fadd U.4s, xX_24.4s, xX_34.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], #0x10\n"
-
-        "fsub U.4s, xX_31.4s, xX_21.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fsub U.4s, xX_32.4s, xX_22.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, xX_33.4s, xX_23.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, xX_34.4s, xX_24.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], #0x10\n"
-
-        "str qxX_21, [%x[outptr12]]\n"
-        "str qxX_22, [%x[outptr12], %x[mstride1]]\n"
-        "str qxX_23, [%x[outptr12], %x[mstride2]]\n"
-        "str qxX_24, [%x[outptr12], %x[mstride3]]\n"
-        "add %x[outptr12], %x[outptr12], #0x10\n"
-
-        ".unreq qU\n"
-        ".unreq U\n"
-        ".unreq X_11\n"  ".unreq qX_11\n"
-        ".unreq X_12\n"  ".unreq qX_12\n"
-        ".unreq X_13\n"  ".unreq qX_13\n"
-        ".unreq X_14\n"  ".unreq qX_14\n"
-        ".unreq X_21\n"  ".unreq qX_21\n"
-        ".unreq X_22\n"  ".unreq qX_22\n"
-        ".unreq X_23\n"  ".unreq qX_23\n"
-        ".unreq X_24\n"  ".unreq qX_24\n"
-        ".unreq X_31\n"  ".unreq qX_31\n"
-        ".unreq X_32\n"  ".unreq qX_32\n"
-        ".unreq X_33\n"  ".unreq qX_33\n"
-        ".unreq X_34\n"  ".unreq qX_34\n"
-        ".unreq xX_11\n"
-        ".unreq xX_12\n"
-        ".unreq xX_13\n"
-        ".unreq xX_14\n"
-        ".unreq xX_21\n" ".unreq qxX_21\n"
-        ".unreq xX_22\n" ".unreq qxX_22\n"
-        ".unreq xX_23\n" ".unreq qxX_23\n"
-        ".unreq xX_24\n" ".unreq qxX_24\n"
-        ".unreq xX_31\n"
-        ".unreq xX_32\n"
-        ".unreq xX_33\n"
-        ".unreq xX_34\n"
-
-        : [inptr0] "+r" (inptr0),
-          [inptr1] "+r" (inptr1),
-          [inptr2] "+r" (inptr2),
-          [outptr0] "+r" (outptr0),
-          [outptr4] "+r" (outptr1),
-          [outptr8] "+r" (outptr2),
-          [outptr12] "+r" (outptr3)
-        : [colstride1] "r" (input_col_stride * sizeof(float)),
-          [colstride2] "r" (input_col_stride * sizeof(float) * 2),
-          [colstride3] "r" (input_col_stride * sizeof(float) * 3),
-          [mstride1] "r" (matrix_stride * sizeof(float)),
-          [mstride2] "r" (matrix_stride * sizeof(float) * 2),
-          [mstride3] "r" (matrix_stride * sizeof(float) * 3)
-        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
-          "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
-          "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
-          "v30", "v31"
-    );
-  }
-}
-
-// Pad right by 1
-template <>
-template <>
-inline void Winograd2x2_3x3GemmInputChannelwise<float>::_process_tile<0, 0, 0, 1, 4>(
-    int &n_channels,  // Number of channels in the tile
-    const float* &inptr0,
-    const int input_row_stride,
-    const int input_col_stride,
-    float* &outptr0,
-    const int matrix_stride
-) {
-  // We use 4 pointers to point to the starting position on each row and use
-  // three offsets to extract elements from each of the other 3 columns.
-  auto inptr1 = inptr0 + 1*input_row_stride;
-  auto inptr2 = inptr0 + 2*input_row_stride;
-  auto inptr3 = inptr0 + 3*input_row_stride;
-
-  // We use 4 pointers to point at matrices 0, 4, 8 and 12 and use three
-  // offsets to access the intermediate matrices.
-  auto outptr1 = outptr0 + matrix_stride * 4;
-  auto outptr2 = outptr0 + matrix_stride * 8;
-  auto outptr3 = outptr0 + matrix_stride * 12;
-
-  for (; n_channels > 3; n_channels -= 4) {
-    asm volatile (
-        "X_11 .req  v0\n"  "qX_11 .req  q0\n"
-        "X_12 .req  v1\n"  "qX_12 .req  q1\n"
-        "X_13 .req  v2\n"  "qX_13 .req  q2\n"
-        "X_21 .req  v4\n"  "qX_21 .req  q4\n"
-        "X_22 .req  v5\n"  "qX_22 .req  q5\n"
-        "X_23 .req  v6\n"  "qX_23 .req  q6\n"
-        "X_31 .req  v8\n"  "qX_31 .req  q8\n"
-        "X_32 .req  v9\n"  "qX_32 .req  q9\n"
-        "X_33 .req v10\n"  "qX_33 .req q10\n"
-        "X_41 .req v12\n"  "qX_41 .req q12\n"
-        "X_42 .req v13\n"  "qX_42 .req q13\n"
-        "X_43 .req v14\n"  "qX_43 .req q14\n"
-        "xX_11 .req v16\n"
-        "xX_12 .req v17\n"
-        "xX_13 .req v18\n"
-        "xX_14 .req x_12\n"
-        "xX_21 .req v20\n"
-        "xX_22 .req v21\n"
-        "xX_23 .req v22\n"
-        "xX_24 .req x_22\n"
-        "xX_31 .req v24\n"
-        "xX_32 .req v25\n"
-        "xX_33 .req v26\n"
-        "xX_34 .req x_32\n"
-        "xX_41 .req v28\n"
-        "xX_42 .req v29\n"
-        "xX_43 .req v30\n"
-        "xX_44 .req x_42\n"
-        " U .req v0\n"
-        "qU .req q0\n"
-
-        // Load the tile, and compute compute the matrix xX
-        "ldr qX_11, [%x[inptr0]]\n"
-        "ldr qX_12, [%x[inptr0], %x[colstride1]]\n"
-        "ldr qX_13, [%x[inptr0], %x[colstride2]]\n"
-        "add %x[inptr0], %x[inptr0], #0x10\n"
-
-        "ldr qX_21, [%x[inptr1]]\n"
-        "fsub xX_11.4s, x_11.4s, x_13.4s\n"
-        "ldr qX_22, [%x[inptr1], %x[colstride1]]\n"
-        "fadd xX_12.4s, x_12.4s, x_13.4s\n"
-        "ldr qX_23, [%x[inptr1], %x[colstride2]]\n"
-        "fsub xX_13.4s, x_13.4s, x_12.4s\n"
-        "add %x[inptr1], %x[inptr1], #0x10\n"
-
-        "ldr qX_31, [%x[inptr2]]\n"
-        "fsub xX_21.4s, x_21.4s, x_23.4s\n"
-        "ldr qX_32, [%x[inptr2], %x[colstride1]]\n"
-        "fadd xX_22.4s, x_22.4s, x_23.4s\n"
-        "ldr qX_33, [%x[inptr2], %x[colstride2]]\n"
-        "fsub xX_23.4s, x_23.4s, x_22.4s\n"
-        "add %x[inptr2], %x[inptr2], #0x10\n"
-
-        "ldr qX_41, [%x[inptr3]]\n"
-        "fsub xX_31.4s, x_31.4s, x_33.4s\n"
-        "ldr qX_42, [%x[inptr3], %x[colstride1]]\n"
-        "fadd xX_32.4s, x_32.4s, x_33.4s\n"
-        "ldr qX_43, [%x[inptr3], %x[colstride2]]\n"
-        "fsub xX_33.4s, x_33.4s, x_32.4s\n"
-        "add %x[inptr3], %x[inptr3], #0x10\n"
-
-        // Complete computing xX while beginning to compute and store
-        // $U = X.T x X$
-
-        "fsub xX_41.4s, x_41.4s, x_43.4s\n"
-
-        "fsub U.4s, xX_11.4s, xX_31.4s\n"
-        "str qU, [%x[outptr0]]\n"
-        "fsub U.4s, xX_12.4s, xX_32.4s\n"
-        "str qU, [%x[outptr0], %x[mstride1]]\n"
-        "fsub U.4s, xX_13.4s, xX_33.4s\n"
-        "str qU, [%x[outptr0], %x[mstride2]]\n"
-        "fsub U.4s, xX_14.4s, xX_34.4s\n"
-        "str qU, [%x[outptr0], %x[mstride3]]\n"
-        "add %x[outptr0], %x[outptr0], #0x10\n"
-
-        "fadd xX_42.4s, x_42.4s, x_43.4s\n"
-
-        "fadd U.4s, xX_21.4s, xX_31.4s\n"
-        "str qU, [%x[outptr4]]\n"
-        "fadd U.4s, xX_22.4s, xX_32.4s\n"
-        "str qU, [%x[outptr4], %x[mstride1]]\n"
-        "fadd U.4s, xX_23.4s, xX_33.4s\n"
-        "str qU, [%x[outptr4], %x[mstride2]]\n"
-        "fadd U.4s, xX_24.4s, xX_34.4s\n"
-        "str qU, [%x[outptr4], %x[mstride3]]\n"
-        "add %x[outptr4], %x[outptr4], #0x10\n"
-
-        "fsub xX_43.4s, x_43.4s, x_42.4s\n"
-
-        "fsub U.4s, xX_31.4s, xX_21.4s\n"
-        "str qU, [%x[outptr8]]\n"
-        "fsub U.4s, xX_32.4s, xX_22.4s\n"
-        "str qU, [%x[outptr8], %x[mstride1]]\n"
-        "fsub U.4s, xX_33.4s, xX_23.4s\n"
-        "str qU, [%x[outptr8], %x[mstride2]]\n"
-        "fsub U.4s, xX_34.4s, xX_24.4s\n"
-        "str qU, [%x[outptr8], %x[mstride3]]\n"
-        "add %x[outptr8], %x[outptr8], #0x10\n"
-
-        "fsub U.4s, xX_21.4s, xX_41.4s\n"
-        "str qU, [%x[outptr12]]\n"
-        "fsub U.4s, xX_22.4s, xX_42.4s\n"
-        "str qU, [%x[outptr12], %x[mstride1]]\n"
-        "fsub U.4s, xX_23.4s, xX_43.4s\n"
-        "str qU, [%x[outptr12], %x[mstride2]]\n"
-        "fsub U.4s, xX_24.4s, xX_44.4s\n"
-        "str qU, [%x[outptr12], %x[mstride3]]\n"
-        "add %x[outptr12], %x[outptr12], #0x10\n"
-
-        ".unreq qU\n"
-        ".unreq U\n"
-        ".unreq X_11\n"  ".unreq qX_11\n"
-        ".unreq X_12\n"  ".unreq qX_12\n"
-        ".unreq X_13\n"  ".unreq qX_13\n"
-        ".unreq X_21\n"  ".unreq qX_21\n"
-        ".unreq X_22\n"  ".unreq qX_22\n"
-        ".unreq X_23\n"  ".unreq qX_23\n"
-        ".unreq X_31\n"  ".unreq qX_31\n"
-        ".unreq X_32\n"  ".unreq qX_32\n"
-        ".unreq X_33\n"  ".unreq qX_33\n"
-        ".unreq X_41\n"  ".unreq qX_41\n"
-        ".unreq X_42\n"  ".unreq qX_42\n"
-        ".unreq X_43\n"  ".unreq qX_43\n"
-        ".unreq xX_11\n"
-        ".unreq xX_12\n"
-        ".unreq xX_13\n"
-        ".unreq xX_14\n"
-        ".unreq xX_21\n"
-        ".unreq xX_22\n"
-        ".unreq xX_23\n"
-        ".unreq xX_24\n"
-        ".unreq xX_31\n"
-        ".unreq xX_32\n"
-        ".unreq xX_33\n"
-        ".unreq xX_34\n"
-        ".unreq xX_41\n"
-        ".unreq xX_42\n"
-        ".unreq xX_43\n"
-        ".unreq xX_44\n"
-
-        : [inptr0] "+r" (inptr0),
-          [inptr1] "+r" (inptr1),
-          [inptr2] "+r" (inptr2),
-          [inptr3] "+r" (inptr3),
-          [outptr0] "+r" (outptr0),
-          [outptr4] "+r" (outptr1),
-          [outptr8] "+r" (outptr2),
-          [outptr12] "+r" (outptr3)
-        : [colstride1] "r" (input_col_stride * sizeof(float)),
-          [colstride2] "r" (input_col_stride * sizeof(float) * 2),
-          [mstride1] "r" (matrix_stride * sizeof(float)),
-          [mstride2] "r" (matrix_stride * sizeof(float) * 2),
-          [mstride3] "r" (matrix_stride * sizeof(float) * 3)
-        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
-          "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
-          "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
-          "v30", "v31"
-    );
-  }
-}
-}
-#endif
diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp
new file mode 100644
index 0000000..381ae92
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/transforms/input_2x2_3x3_fp32.cpp
@@ -0,0 +1,409 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "transforms/input.hpp"
+#include "winograd_gemm.hpp"
+#include "arm.hpp"
+
+namespace winograd
+{
+
+using Transform = WinogradGEMM<2, 2, 3, 3>::InputTransform<float>;
+
+/******************************************************************************
+ * Cost methods for the input transform.
+ * =====================================
+ */
+template <>
+template <>
+int Transform::ops_performed(const Tensor4DShape &input_shape)
+{
+  // NOTE: Cost in FLOPs rather than instructions or uops.
+  const int tile_M = iceildiv(input_shape.n_rows, inner_tile_rows);
+  const int tile_N = iceildiv(input_shape.n_cols, inner_tile_cols);
+  return 16 * 16 * tile_M * tile_N * input_shape.n_channels;
+}
+/*****************************************************************************/
+
+/*****************************************************************************
+* F(2x2, 3x3) implies the use of a 4x4 input tile. Such tiles can require a
+* variety of padding types. For example, tiles at the top and left of an image
+* can require one row or column of padding on their top and left sides if the
+* padding type is SAME (where X represents a padded value):
+*
+*      _______    _______
+*     |X X X X|  |X X X X|
+*     |X      |  |       |   . . .
+*     |X      |  |       |
+*     |X______|  |_______|
+*      _______
+*     |X      |             .
+*     |X      |   . . .       .
+*     |X      |                 .
+*     |X______|
+*
+* For tiles near the right or bottom of the image it is more complicated.  Such
+* tiles might require padding by 0 or 1 rows or columns if the padding type is
+* VALID or 1 or 2 rows or columns if the padding type is SAME:
+*
+*      _______    _______    _______    _______
+*     |X X X X|  |X X X X|  |X X X X|  |X X X X|
+*     |X      |  |       |  |      X|  |    X X|
+*     |X      |  |       |  |      X|  |    X X|
+*     |X______|  |_______|  |______X|  |____X_X|
+*      _______    _______    _______    _______
+*     |X      |  |       |  |      X|  |    X X|
+*     |X      |  |       |  |      X|  |    X X|
+*     |X      |  |       |  |      X|  |    X X|
+*     |X______|  |_______|  |______X|  |____X_X|
+*      _______    _______    _______    _______
+*     |X      |  |       |  |      X|  |    X X|
+*     |X      |  |       |  |      X|  |    X X|
+*     |X      |  |       |  |      X|  |    X X|
+*     |X_X_X_X|  |X_X_X_X|  |X_X_X_X|  |X_X_X_X|
+*      _______    _______    _______    _______
+*     |X      |  |       |  |      X|  |    X X|
+*     |X      |  |       |  |      X|  |    X X|
+*     |X X X X|  |X X X X|  |X X X X|  |X X X X|
+*     |X_X_X_X|  |X_X_X_X|  |X_X_X_X|  |X_X_X_X|
+*
+* Additional tiles are required for especially small input images.
+*
+* Build an array of the specialised methods that deal with each of the
+* different padding combinations which may be required. These padding
+* constraints are the space:
+*
+*     Padding top in {0, 1}
+*     Padding left in {0, 1}
+*     Padding bottom in {0, 1, 2}
+*     Padding right in {0, 1, 2}
+*/
+template <>
+template <>
+template <int pad_top, int pad_left, int pad_bottom, int pad_right>
+void Transform::process_tile(
+  int n_channels,
+  const float* const input_base,
+  const int input_row_stride,
+  const int input_col_stride,
+  float* const matrix_base,
+  const int matrix_stride
+)
+{
+  constexpr int inner_tile_i = 4, inner_tile_j = 4;
+  constexpr int cells_i = inner_tile_i - pad_bottom;
+  constexpr int cells_j = inner_tile_i - pad_right;
+
+  float *outptr = matrix_base;
+
+  // Get pointers into the input tile
+  const float *x_ptrs[inner_tile_i][inner_tile_j];
+  for (int i = pad_top, xi = 0; i < cells_i; i++, xi++)
+  {
+    // Get a pointer into the row
+    const float* const row_ptr = input_base + xi*input_row_stride;
+
+    for (int j = pad_left, xj = 0; j < cells_j; j++, xj++)
+    {
+      x_ptrs[i][j] = row_ptr + xj*input_col_stride;
+    }
+  }
+
+  // Matrices used/computed in this kernel.
+  float x[inner_tile_i][inner_tile_j];
+  float XTx[inner_tile_i][inner_tile_j];
+  float U[inner_tile_i][inner_tile_j];
+
+  for (int i = 0; i < inner_tile_i; i++)
+  {
+    for (int j = 0; j < inner_tile_j; j++)
+    {
+      x[i][j] = XTx[i][j] = 0.0f;
+    }
+  }
+
+  // Perform the Winograd input transformation for each channel in the input
+  // tensor.
+  int channels_remaining = n_channels;
+#ifdef __aarch64__
+  for (; channels_remaining >= 4; channels_remaining -= 4)
+  {
+    // Matrices used/computed in this kernel.
+    float32x4_t x[inner_tile_i][inner_tile_j];
+    float32x4_t XTx[inner_tile_i][inner_tile_j];
+    float32x4_t U[inner_tile_i][inner_tile_j];
+
+    for (int i = 0; i < inner_tile_i; i++)
+    {
+      for (int j = 0; j < inner_tile_j; j++)
+      {
+        x[i][j] = vdupq_n_f32(0.0f);
+        XTx[i][j] = vdupq_n_f32(0.0f);
+      }
+    }
+
+    // Load x
+    for (int i = pad_top; i < cells_i; i++)
+    {
+      for (int j = pad_left; j < cells_j; j++)
+      {
+        x[i][j] = vld1q_f32(x_ptrs[i][j]);
+        x_ptrs[i][j] += 4;
+      }
+    }
+
+    // Compute XT . x
+    for (int j = pad_left; j < cells_j; j++)
+    {
+      // XTx[0][j] = x[0][j] - x[2][j];
+      XTx[0][j] = vsubq_f32(x[0][j], x[2][j]);
+
+      // XTx[1][j] = x[1][j] + x[2][j];
+      XTx[1][j] = vaddq_f32(x[1][j], x[2][j]);
+
+      // XTx[2][j] = x[2][j] - x[1][j];
+      XTx[2][j] = vsubq_f32(x[2][j], x[1][j]);
+
+      // XTx[3][j] = x[1][j] - x[3][j];
+      XTx[3][j] = vsubq_f32(x[1][j], x[3][j]);
+    }
+
+    // Compute U = XT . x . X
+    for (int i = 0; i < inner_tile_i; i++)
+    {
+      // U[i][0] = XTx[i][0] - XTx[i][2];
+      U[i][0] = vsubq_f32(XTx[i][0], XTx[i][2]);
+
+      // U[i][1] = XTx[i][1] + XTx[i][2];
+      U[i][1] = vaddq_f32(XTx[i][1], XTx[i][2]);
+
+      // U[i][2] = XTx[i][2] - XTx[i][1];
+      U[i][2] = vsubq_f32(XTx[i][2], XTx[i][1]);
+
+      // U[i][3] = XTx[i][1] - XTx[i][3];
+      U[i][3] = vsubq_f32(XTx[i][1], XTx[i][3]);
+    }
+
+    // Store the transformed matrix
+    for (int i = 0, m = 0; i < inner_tile_i; i++)
+    {
+      for (int j = 0; j < inner_tile_j; j++, m++)
+      {
+        vst1q_f32(outptr + m*matrix_stride, U[i][j]);
+      }
+    }
+    outptr += 4;
+  }
+#endif  // __aarch64__
+#ifdef __arm_any__
+  for (; channels_remaining >= 2; channels_remaining -= 2)
+  {
+    // Matrices used/computed in this kernel.
+    float32x2_t x[inner_tile_i][inner_tile_j];
+    float32x2_t XTx[inner_tile_i][inner_tile_j];
+    float32x2_t U[inner_tile_i][inner_tile_j];
+
+    for (int i = 0; i < inner_tile_i; i++)
+    {
+      for (int j = 0; j < inner_tile_j; j++)
+      {
+        x[i][j] = vdup_n_f32(0.0f);
+        XTx[i][j] = vdup_n_f32(0.0f);
+      }
+    }
+
+    // Load x
+    for (int i = pad_top; i < cells_i; i++)
+    {
+      for (int j = pad_left; j < cells_j; j++)
+      {
+        x[i][j] = vld1_f32(x_ptrs[i][j]);
+        x_ptrs[i][j] += 2;
+      }
+    }
+
+    // Compute XT . x
+    for (int j = pad_left; j < cells_j; j++)
+    {
+      // XTx[0][j] = x[0][j] - x[2][j];
+      XTx[0][j] = vsub_f32(x[0][j], x[2][j]);
+
+      // XTx[1][j] = x[1][j] + x[2][j];
+      XTx[1][j] = vadd_f32(x[1][j], x[2][j]);
+
+      // XTx[2][j] = x[2][j] - x[1][j];
+      XTx[2][j] = vsub_f32(x[2][j], x[1][j]);
+
+      // XTx[3][j] = x[1][j] - x[3][j];
+      XTx[3][j] = vsub_f32(x[1][j], x[3][j]);
+    }
+
+    // Compute U = XT . x . X
+    for (int i = 0; i < inner_tile_i; i++)
+    {
+      // U[i][0] = XTx[i][0] - XTx[i][2];
+      U[i][0] = vsub_f32(XTx[i][0], XTx[i][2]);
+
+      // U[i][1] = XTx[i][1] + XTx[i][2];
+      U[i][1] = vadd_f32(XTx[i][1], XTx[i][2]);
+
+      // U[i][2] = XTx[i][2] - XTx[i][1];
+      U[i][2] = vsub_f32(XTx[i][2], XTx[i][1]);
+
+      // U[i][3] = XTx[i][1] - XTx[i][3];
+      U[i][3] = vsub_f32(XTx[i][1], XTx[i][3]);
+    }
+
+    // Store the transformed matrix
+    for (int i = 0, m = 0; i < inner_tile_i; i++)
+    {
+      for (int j = 0; j < inner_tile_j; j++, m++)
+      {
+        vst1_f32(outptr + m*matrix_stride, U[i][j]);
+      }
+    }
+    outptr += 2;
+  }
+#endif  // __arm_any__
+  for (; channels_remaining; channels_remaining--)
+  {
+    // Load x
+    for (int i = pad_top; i < cells_i; i++)
+    {
+      for (int j = pad_left; j < cells_j; j++)
+      {
+        x[i][j] = *(x_ptrs[i][j]++);
+      }
+    }
+
+    // Compute XT . x
+    for (int j = pad_left; j < cells_j; j++)
+    {
+      XTx[0][j] = x[0][j] - x[2][j];
+      XTx[1][j] = x[1][j] + x[2][j];
+      XTx[2][j] = x[2][j] - x[1][j];
+      XTx[3][j] = x[1][j] - x[3][j];
+    }
+
+    // Compute U = XT . x . X
+    for (int i = 0; i < inner_tile_i; i++)
+    {
+      U[i][0] = XTx[i][0] - XTx[i][2];
+      U[i][1] = XTx[i][1] + XTx[i][2];
+      U[i][2] = XTx[i][2] - XTx[i][1];
+      U[i][3] = XTx[i][1] - XTx[i][3];
+    }
+
+    // Store the transformed matrix
+    for (int i = 0, m = 0; i < inner_tile_i; i++)
+    {
+      for (int j = 0; j < inner_tile_j; j++, m++)
+      {
+        *(outptr + m*matrix_stride) = U[i][j];
+      }
+    }
+    outptr++;
+  }
+}
+
+template <>
+template <>
+const Transform::TileFn Transform::tile_fns[2][2][max_pad_bottom][max_pad_right] =
+{
+  {
+    {
+      {
+        Transform::template process_tile<0, 0, 0, 0>,  // No padding
+        Transform::template process_tile<0, 0, 0, 1>,  // Right
+        Transform::template process_tile<0, 0, 0, 2>,  // Right
+      },
+      {
+        Transform::template process_tile<0, 0, 1, 0>,  // Bottom
+        Transform::template process_tile<0, 0, 1, 1>,  // Bottom-right
+        Transform::template process_tile<0, 0, 1, 2>,  // Bottom-right
+      },
+      {
+        Transform::template process_tile<0, 0, 2, 0>,  // Bottom
+        Transform::template process_tile<0, 0, 2, 1>,  // Bottom-right
+        Transform::template process_tile<0, 0, 2, 2>,  // Bottom-right
+      }
+    },
+    {
+      {
+        Transform::template process_tile<0, 1, 0, 0>,  // Left
+        Transform::template process_tile<0, 1, 0, 1>,  // Left AND right
+        Transform::template process_tile<0, 1, 0, 2>,  // Left AND right
+      },
+      {
+        Transform::template process_tile<0, 1, 1, 0>,  // Left-bottom
+        Transform::template process_tile<0, 1, 1, 1>,  // Left, bottom AND right
+        Transform::template process_tile<0, 1, 1, 2>,  // Left, bottom AND right
+      },
+      {
+        Transform::template process_tile<0, 1, 2, 0>,  // Left-bottom
+        Transform::template process_tile<0, 1, 2, 1>,  // Left, bottom AND right
+        Transform::template process_tile<0, 1, 2, 2>,  // Left, bottom AND right
+      }
+    },
+  },
+  {
+    {
+      {
+        Transform::template process_tile<1, 0, 0, 0>,  // Top
+        Transform::template process_tile<1, 0, 0, 1>,  // Top-right
+        Transform::template process_tile<1, 0, 0, 2>,  // Top-right
+      },
+      {
+        Transform::template process_tile<1, 0, 1, 0>,  // Top AND bottom
+        Transform::template process_tile<1, 0, 1, 1>,  // Top, bottom AND right
+        Transform::template process_tile<1, 0, 1, 2>,  // Top, bottom AND right
+      },
+      {
+        Transform::template process_tile<1, 0, 2, 0>,  // Top AND bottom
+        Transform::template process_tile<1, 0, 2, 1>,  // Top, bottom AND right
+        Transform::template process_tile<1, 0, 2, 2>,  // Top, bottom AND right
+      }
+    },
+    {
+      {
+        Transform::template process_tile<1, 1, 0, 0>,  // Top-left
+        Transform::template process_tile<1, 1, 0, 1>,  // Top, left AND right
+        Transform::template process_tile<1, 1, 0, 2>,  // Top, left AND right
+      },
+      {
+        Transform::template process_tile<1, 1, 1, 0>,  // Top, left AND bottom
+        Transform::template process_tile<1, 1, 1, 1>,  // All padded
+        Transform::template process_tile<1, 1, 1, 2>,  // All padded
+      },
+      {
+        Transform::template process_tile<1, 1, 2, 0>,  // Top, left AND bottom
+        Transform::template process_tile<1, 1, 2, 1>,  // All padded
+        Transform::template process_tile<1, 1, 2, 2>,  // All padded
+      }
+    }
+  }
+};
+
+template struct WinogradGEMM<2, 2, 3, 3>::InputTransform<float>;
+}  // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp
new file mode 100644
index 0000000..477aaaf
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/transforms/input_4x4_3x3_fp32.cpp
@@ -0,0 +1,486 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "transforms/input.hpp"
+#include "winograd_gemm.hpp"
+#include "arm.hpp"
+
+namespace winograd
+{
+
+using Transform = WinogradGEMM<4, 4, 3, 3>::InputTransform<float>;
+
+template <>
+template <>
+int Transform::ops_performed(const Tensor4DShape &input_shape)
+{
+  // NOTE: Cost in FLOPs rather than instructions or uops.
+  const int tile_M = iceildiv(input_shape.n_rows, inner_tile_rows);
+  const int tile_N = iceildiv(input_shape.n_cols, inner_tile_cols);
+  return 12 * 24 * tile_M * tile_N * input_shape.n_channels;
+}
+
+/* F(4x4, 3x3) implies the use of a 6x6 input tile. Such tiles can require a
+* variety of padding types. For example, tiles at the top and left of an
+* image can require one row or column of padding on their top and left sides
+* if the padding type is SAME (where X represents a padded value):
+*
+*      ___________    ___________
+*     |X X X X X X|  |X X X X X X|
+*     |X          |  |           |
+*     |X          |  |           |
+*     |X          |  |           |
+*     |X          |  |           |
+*     |X__________|  |___________|
+*      ___________
+*     |X          |
+*     |X          |
+*     |X          |
+*     |X          |
+*     |X          |
+*     |X__________|
+*
+* For tiles near the right or bottom of the image it is more complicated.
+* Such tiles might require padding by 0, 1, 2 or 3 rows or columns if the
+* padding type is VALID or 1, 2, 3 or 4 rows or columns if the padding
+* type is SAME.
+*
+* Build an array of the specialised methods that deal with each of the
+* different padding combinations which may be required. These padding
+* constraints are the space:
+*
+*     Padding top in {0, 1}
+*     Padding left in {0, 1}
+*     Padding bottom in {0, 1, 2, 3, 4}
+*     Padding right in {0, 1, 2, 3, 4}
+*/
+template <>
+template <>
+template <int pad_top, int pad_left, int pad_bottom, int pad_right>
+void Transform::process_tile(
+  int n_channels,
+  const float* const input_base,
+  const int input_row_stride,
+  const int input_col_stride,
+  float* const matrix_base,
+  const int matrix_stride
+)
+{
+  constexpr int cells_i = 6 - pad_bottom;
+  constexpr int cells_j = 6 - pad_right;
+
+  float *outptr = matrix_base;
+
+  // Get pointers into the input tile
+  const float *x_ptrs[6][6];
+  for (int i = pad_top, xi = 0; i < cells_i; i++, xi++)
+  {
+    // Get a pointer into the row
+    const float* const row_ptr = input_base + xi*input_row_stride;
+
+    for (int j = pad_left, xj = 0; j < cells_j; j++, xj++)
+    {
+      x_ptrs[i][j] = row_ptr + xj*input_col_stride;
+    }
+  }
+
+  // Matrices used/computed in this kernel.
+  float x[6][6], XTx[6][6], U[6][6];
+  for (int i = 0; i < 6; i++)
+  {
+    for (int j = 0; j < 6; j++)
+    {
+      x[i][j] = XTx[i][j] = 0.0f;
+    }
+  }
+
+  // Perform the Winograd input transformation for each channel in the input
+  // tensor.
+  int channels_remaining = n_channels;
+#ifdef __aarch64__
+  for (; channels_remaining >= 4; channels_remaining -= 4)
+  {
+    // Matrices used/computed in this kernel
+    float32x4_t x[6][6], XTx[6][6], U[6][6];
+    for (int i = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++)
+      {
+        x[i][j] = vdupq_n_f32(0.0f);
+        XTx[i][j] = vdupq_n_f32(0.0f);
+      }
+    }
+
+    // Read a 6x6 tile in the Winograd domain
+    for (int i = pad_top; i < cells_i; i++)
+    {
+      for (int j = pad_left; j < cells_j; j++)
+      {
+        x[i][j] = vld1q_f32(x_ptrs[i][j]);
+        x_ptrs[i][j] += 4;
+      }
+    }
+
+    // Compute XT . x
+    for (int j = pad_left; j < cells_j; j++)
+    {
+      // XTx[0][j] =  4*x[0][j] + -5*x[2][j] +  1*x[4][j];
+      XTx[0][j] = vmlsq_n_f32(vmlaq_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f);
+
+      // XTx[1][j] = -4*x[1][j] + -4*x[2][j] +  1*x[3][j] +  1*x[4][j];
+      XTx[1][j] = vmlsq_n_f32(vaddq_f32(x[3][j], x[4][j]), vaddq_f32(x[1][j], x[2][j]), 4.0f);
+
+      // XTx[2][j] =  4*x[1][j] + -4*x[2][j] + -1*x[3][j] +  1*x[4][j];
+      XTx[2][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[3][j]), vsubq_f32(x[1][j], x[2][j]), 4.0f);
+
+      // XTx[3][j] = -2*x[1][j] + -1*x[2][j] +  2*x[3][j] +  1*x[4][j];
+      XTx[3][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[2][j]), vsubq_f32(x[3][j], x[1][j]), 2.0f);
+
+      // XTx[4][j] =  2*x[1][j] + -1*x[2][j] + -2*x[3][j] +  1*x[4][j];
+      XTx[4][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[2][j]), vsubq_f32(x[1][j], x[3][j]), 2.0f);
+
+      // XTx[5][j] =  4*x[1][j] + -5*x[3][j] +  1*x[5][j];
+      XTx[5][j] = vmlsq_n_f32(vmlaq_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f);
+    }
+
+    // Compute U = XT . x . X
+    for (int i = 0; i < 6; i++)
+    {
+      // U[i][0] =  4*XTx[i][0] + -5*XTx[i][2] +  1*XTx[i][4];
+      U[i][0] = vmlsq_n_f32(vmlaq_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f);
+
+      // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] +  1*XTx[i][3] +  1*XTx[i][4];
+      U[i][1] = vmlsq_n_f32(vaddq_f32(XTx[i][3], XTx[i][4]), vaddq_f32(XTx[i][1], XTx[i][2]), 4.0f);
+
+      // U[i][2] =  4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] +  1*XTx[i][4];
+      U[i][2] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][3]), vsubq_f32(XTx[i][1], XTx[i][2]), 4.0f);
+
+      // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] +  2*XTx[i][3] +  1*XTx[i][4];
+      U[i][3] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][2]), vsubq_f32(XTx[i][3], XTx[i][1]), 2.0f);
+
+      // U[i][4] =  2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] +  1*XTx[i][4];
+      U[i][4] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][2]), vsubq_f32(XTx[i][1], XTx[i][3]), 2.0f);
+
+      // U[i][5] =  4*XTx[i][1] + -5*XTx[i][3] +  1*XTx[i][5];
+      U[i][5] = vmlsq_n_f32(vmlaq_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f);
+    }
+
+    // Store the transformed matrix
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        vst1q_f32(outptr + m*matrix_stride, U[i][j]);
+      }
+    }
+    outptr += 4;
+  }
+#endif  // __aarch64__
+#ifdef __arm_any__
+  for (; channels_remaining >= 2; channels_remaining -= 2)
+  {
+    // Matrices used/computed in this kernel
+    float32x2_t x[6][6], XTx[6][6], U[6][6];
+    for (int i = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++)
+      {
+        x[i][j] = vdup_n_f32(0.0f);
+        XTx[i][j] = vdup_n_f32(0.0f);
+      }
+    }
+
+    // Read a 6x6 tile in the Winograd domain
+    for (int i = pad_top; i < cells_i; i++)
+    {
+      for (int j = pad_left; j < cells_j; j++)
+      {
+        x[i][j] = vld1_f32(x_ptrs[i][j]);
+        x_ptrs[i][j] += 2;
+      }
+    }
+
+    // Compute XT . x
+    for (int j = pad_left; j < cells_j; j++)
+    {
+      // XTx[0][j] =  4*x[0][j] + -5*x[2][j] +  1*x[4][j];
+      XTx[0][j] = vmls_n_f32(vmla_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f);
+
+      // XTx[1][j] = -4*x[1][j] + -4*x[2][j] +  1*x[3][j] +  1*x[4][j];
+      XTx[1][j] = vmls_n_f32(vadd_f32(x[3][j], x[4][j]), vadd_f32(x[1][j], x[2][j]), 4.0f);
+
+      // XTx[2][j] =  4*x[1][j] + -4*x[2][j] + -1*x[3][j] +  1*x[4][j];
+      XTx[2][j] = vmla_n_f32(vsub_f32(x[4][j], x[3][j]), vsub_f32(x[1][j], x[2][j]), 4.0f);
+
+      // XTx[3][j] = -2*x[1][j] + -1*x[2][j] +  2*x[3][j] +  1*x[4][j];
+      XTx[3][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[3][j], x[1][j]), 2.0f);
+
+      // XTx[4][j] =  2*x[1][j] + -1*x[2][j] + -2*x[3][j] +  1*x[4][j];
+      XTx[4][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[1][j], x[3][j]), 2.0f);
+
+      // XTx[5][j] =  4*x[1][j] + -5*x[3][j] +  1*x[5][j];
+      XTx[5][j] = vmls_n_f32(vmla_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f);
+    }
+
+    // Compute U = XT . x . X
+    for (int i = 0; i < 6; i++)
+    {
+      // U[i][0] =  4*XTx[i][0] + -5*XTx[i][2] +  1*XTx[i][4];
+      U[i][0] = vmls_n_f32(vmla_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f);
+
+      // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] +  1*XTx[i][3] +  1*XTx[i][4];
+      U[i][1] = vmls_n_f32(vadd_f32(XTx[i][3], XTx[i][4]), vadd_f32(XTx[i][1], XTx[i][2]), 4.0f);
+
+      // U[i][2] =  4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] +  1*XTx[i][4];
+      U[i][2] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][3]), vsub_f32(XTx[i][1], XTx[i][2]), 4.0f);
+
+      // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] +  2*XTx[i][3] +  1*XTx[i][4];
+      U[i][3] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][3], XTx[i][1]), 2.0f);
+
+      // U[i][4] =  2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] +  1*XTx[i][4];
+      U[i][4] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][1], XTx[i][3]), 2.0f);
+
+      // U[i][5] =  4*XTx[i][1] + -5*XTx[i][3] +  1*XTx[i][5];
+      U[i][5] = vmls_n_f32(vmla_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f);
+    }
+
+    // Store the transformed matrix
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        vst1_f32(outptr + m*matrix_stride, U[i][j]);
+      }
+    }
+    outptr += 2;
+  }
+#endif  // __arm_any__
+  for (; channels_remaining; channels_remaining--)
+  {
+    // Load x
+    for (int i = pad_top; i < cells_i; i++)
+    {
+      for (int j = pad_left; j < cells_j; j++)
+      {
+        x[i][j] = *(x_ptrs[i][j]++);
+      }
+    }
+
+    // Compute XT . x
+    for (int j = pad_left; j < cells_j; j++)
+    {
+      XTx[0][j] =  4*x[0][j] + -5*x[2][j] +  1*x[4][j];
+      XTx[1][j] = -4*x[1][j] + -4*x[2][j] +  1*x[3][j] +  1*x[4][j];
+      XTx[2][j] =  4*x[1][j] + -4*x[2][j] + -1*x[3][j] +  1*x[4][j];
+      XTx[3][j] = -2*x[1][j] + -1*x[2][j] +  2*x[3][j] +  1*x[4][j];
+      XTx[4][j] =  2*x[1][j] + -1*x[2][j] + -2*x[3][j] +  1*x[4][j];
+      XTx[5][j] =  4*x[1][j] + -5*x[3][j] +  1*x[5][j];
+    }
+
+    // Compute U = XT . x . X
+    for (int i = 0; i < 6; i++)
+    {
+      U[i][0] =  4*XTx[i][0] + -5*XTx[i][2] +  1*XTx[i][4];
+      U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] +  1*XTx[i][3] +  1*XTx[i][4];
+      U[i][2] =  4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] +  1*XTx[i][4];
+      U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] +  2*XTx[i][3] +  1*XTx[i][4];
+      U[i][4] =  2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] +  1*XTx[i][4];
+      U[i][5] =  4*XTx[i][1] + -5*XTx[i][3] +  1*XTx[i][5];
+    }
+
+    // Store the transformed matrix
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        *(outptr + m*matrix_stride) = U[i][j];
+      }
+    }
+    outptr++;
+  }
+}
+
+/* In the below, unusual or especially small tiles are routed via the slow
+ * path whereas common or large tiles are routed through a faster path.
+ */
+template <>
+template <>
+const Transform::TileFn Transform::tile_fns[2][2][max_pad_bottom][max_pad_right] =
+{
+  {
+    {
+      {
+        Transform::template process_tile<0, 0, 0, 0>,  // No padding
+        Transform::template process_tile<0, 0, 0, 1>,  // Right
+        Transform::template process_tile<0, 0, 0, 2>,  // "   "
+        Transform::template process_tile<0, 0, 0, 3>,  // "   "
+        Transform::template process_tile<0, 0, 0, 4>,  // "   "
+      },
+      {
+        Transform::template process_tile<0, 0, 1, 0>,  // Bottom
+        Transform::template process_tile<0, 0, 1, 1>,  // Bottom right
+        Transform::template process_tile<0, 0, 1, 2>,  // "          "
+        Transform::template process_tile<0, 0, 1, 3>,  // "          "
+        Transform::template process_tile<0, 0, 1, 4>,  // "          "
+      },
+      {
+        Transform::template process_tile<0, 0, 2, 0>,  // Bottom
+        Transform::template process_tile<0, 0, 2, 1>,  // Bottom right
+        Transform::template process_tile<0, 0, 2, 2>,  // "          "
+        Transform::template process_tile<0, 0, 2, 3>,  // "          "
+        Transform::template process_tile<0, 0, 2, 4>,  // "          "
+      },
+      {
+        Transform::template process_tile<0, 0, 3, 0>,  // Bottom
+        Transform::template process_tile<0, 0, 3, 1>,  // Bottom right
+        Transform::template process_tile<0, 0, 3, 2>,  // "          "
+        Transform::template process_tile<0, 0, 3, 3>,  // "          "
+        Transform::template process_tile<0, 0, 3, 4>,  // "          "
+      },
+      {
+        Transform::template process_tile<0, 0, 4, 0>,  // Bottom
+        Transform::template process_tile<0, 0, 4, 1>,  // Bottom right
+        Transform::template process_tile<0, 0, 4, 2>,  // "          "
+        Transform::template process_tile<0, 0, 4, 3>,  // "          "
+        Transform::template process_tile<0, 0, 4, 4>,  // "          "
+      }
+    },
+    {
+      {
+        Transform::template process_tile<0, 1, 0, 0>,  // Left
+        Transform::template process_tile<0, 1, 0, 1>,
+        Transform::template process_tile<0, 1, 0, 2>,
+        Transform::template process_tile<0, 1, 0, 3>,
+        Transform::template process_tile<0, 1, 0, 4>,
+      },
+      {
+        Transform::template process_tile<0, 1, 1, 0>,  // Bottom left
+        Transform::template process_tile<0, 1, 1, 1>,
+        Transform::template process_tile<0, 1, 1, 2>,
+        Transform::template process_tile<0, 1, 1, 3>,
+        Transform::template process_tile<0, 1, 1, 4>,
+      },
+      {
+        Transform::template process_tile<0, 1, 2, 0>,  // "          "
+        Transform::template process_tile<0, 1, 2, 1>,
+        Transform::template process_tile<0, 1, 2, 2>,
+        Transform::template process_tile<0, 1, 2, 3>,
+        Transform::template process_tile<0, 1, 2, 4>,
+      },
+      {
+        Transform::template process_tile<0, 1, 3, 0>,  // "          "
+        Transform::template process_tile<0, 1, 3, 1>,
+        Transform::template process_tile<0, 1, 3, 2>,
+        Transform::template process_tile<0, 1, 3, 3>,
+        Transform::template process_tile<0, 1, 3, 4>,
+      },
+      {
+        Transform::template process_tile<0, 1, 4, 0>,  // "          "
+        Transform::template process_tile<0, 1, 4, 1>,
+        Transform::template process_tile<0, 1, 4, 2>,
+        Transform::template process_tile<0, 1, 4, 3>,
+        Transform::template process_tile<0, 1, 4, 4>,
+      }
+    }
+  },
+  {
+    {
+      {
+        Transform::template process_tile<1, 0, 0, 0>,  // Top
+        Transform::template process_tile<1, 0, 0, 1>,  // Top right
+        Transform::template process_tile<1, 0, 0, 2>,  // "       "
+        Transform::template process_tile<1, 0, 0, 3>,  // "       "
+        Transform::template process_tile<1, 0, 0, 4>,  // "       "
+      },
+      {
+        Transform::template process_tile<1, 0, 1, 0>,
+        Transform::template process_tile<1, 0, 1, 1>,
+        Transform::template process_tile<1, 0, 1, 2>,
+        Transform::template process_tile<1, 0, 1, 3>,
+        Transform::template process_tile<1, 0, 1, 4>,
+      },
+      {
+        Transform::template process_tile<1, 0, 2, 0>,
+        Transform::template process_tile<1, 0, 2, 1>,
+        Transform::template process_tile<1, 0, 2, 2>,
+        Transform::template process_tile<1, 0, 2, 3>,
+        Transform::template process_tile<1, 0, 2, 4>,
+      },
+      {
+        Transform::template process_tile<1, 0, 3, 0>,
+        Transform::template process_tile<1, 0, 3, 1>,
+        Transform::template process_tile<1, 0, 3, 2>,
+        Transform::template process_tile<1, 0, 3, 3>,
+        Transform::template process_tile<1, 0, 3, 4>,
+      },
+      {
+        Transform::template process_tile<1, 0, 4, 0>,
+        Transform::template process_tile<1, 0, 4, 1>,
+        Transform::template process_tile<1, 0, 4, 2>,
+        Transform::template process_tile<1, 0, 4, 3>,
+        Transform::template process_tile<1, 0, 4, 4>,
+      },
+    },
+    {
+      {
+        Transform::template process_tile<1, 1, 0, 0>,  // Top left
+        Transform::template process_tile<1, 1, 0, 1>,
+        Transform::template process_tile<1, 1, 0, 2>,
+        Transform::template process_tile<1, 1, 0, 3>,
+        Transform::template process_tile<1, 1, 0, 4>,
+      },
+      {
+        Transform::template process_tile<1, 1, 1, 0>,
+        Transform::template process_tile<1, 1, 1, 1>,
+        Transform::template process_tile<1, 1, 1, 2>,
+        Transform::template process_tile<1, 1, 1, 3>,
+        Transform::template process_tile<1, 1, 1, 4>,
+      },
+      {
+        Transform::template process_tile<1, 1, 2, 0>,
+        Transform::template process_tile<1, 1, 2, 1>,
+        Transform::template process_tile<1, 1, 2, 2>,
+        Transform::template process_tile<1, 1, 2, 3>,
+        Transform::template process_tile<1, 1, 2, 4>,
+      },
+      {
+        Transform::template process_tile<1, 1, 3, 0>,
+        Transform::template process_tile<1, 1, 3, 1>,
+        Transform::template process_tile<1, 1, 3, 2>,
+        Transform::template process_tile<1, 1, 3, 3>,
+        Transform::template process_tile<1, 1, 3, 4>,
+      },
+      {
+        Transform::template process_tile<1, 1, 4, 0>,
+        Transform::template process_tile<1, 1, 4, 1>,
+        Transform::template process_tile<1, 1, 4, 2>,
+        Transform::template process_tile<1, 1, 4, 3>,
+        Transform::template process_tile<1, 1, 4, 4>,
+      }
+    }
+  }
+};
+
+template struct WinogradGEMM<4, 4, 3, 3>::InputTransform<float>;
+}  // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp
deleted file mode 100644
index 033442a..0000000
--- a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-namespace winograd {
-  /* Transform a kernel into the Winograd domain.
-   *
-   * NOTE: It is assumed that the kernel is in the form [height x width x
-   * input_channels x output_channel].
-   */
-  template <typename T>
-  struct winograd2x2_3x3_gemm_kernel_transform_impl{
-    static void execute(
-      const KernelShape &shape,
-      const T* const kernel,
-      T* const matrix_base,
-      const int matrix_stride,
-      const int matrix_row_stride
-    );
-
-    protected:
-    template <const int output_channel_tail>
-    static void transform_kernel(
-      const T* const kernel,
-      const int n_input_channels,
-      const int n_output_channels,
-      T* const matrix_base,
-      const int matrix_stride,
-      const int matrix_row_stride
-    );
-  };
-}
-
-/*****************************************************************************/
-/* Transform a fp32 kernel into the Winograd domain.
- */
-#include "kernel_2x2_3x3/a64_float.hpp"  // AArch64 specialisations
-
-namespace winograd
-{
-template <>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::execute(
-  const KernelShape &shape,
-  const float* const kernel,
-  float* const matrix_base,
-  const int matrix_stride,
-  const int matrix_row_stride
-) {
-  // Delegate based on tail size
-  const int n_input_channels = shape.n_input_channels;
-  const int n_output_channels = shape.n_output_channels;
-
-  switch (n_output_channels % 4) {
-    case 0:
-      transform_kernel<0>(
-        kernel, n_input_channels, n_output_channels,
-        matrix_base, matrix_stride, matrix_row_stride
-      );
-      break;
-    case 1:
-      transform_kernel<1>(
-        kernel, n_input_channels, n_output_channels,
-        matrix_base, matrix_stride, matrix_row_stride
-      );
-      break;
-    case 2:
-      transform_kernel<2>(
-        kernel, n_input_channels, n_output_channels,
-        matrix_base, matrix_stride, matrix_row_stride
-      );
-      break;
-    case 3:
-      transform_kernel<3>(
-        kernel, n_input_channels, n_output_channels,
-        matrix_base, matrix_stride, matrix_row_stride
-      );
-      break;
-    default:
-        ARM_COMPUTE_ERROR("Cannot happen");
-        break;
-  }
-}
-
-template <>
-template<const int output_channel_tail>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel(
-    const float* const kernel,
-    const int n_input_channels,
-    const int n_output_channels,
-    float* const matrix_base,
-    const int mstride,
-    const int matrix_row_stride
-) {
-  // Use one input pointer for each row of the kernel, use two additional
-  // offsets to extract columns.
-  const int kernel_col_stride = n_input_channels * n_output_channels;
-  const int kernel_row_stride = 3 * kernel_col_stride;
-  const float *inptr0 = kernel;
-  const float *inptr1 = kernel + kernel_row_stride;
-  const float *inptr2 = kernel + kernel_row_stride*2;
-
-  // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
-  // offsets to extract further matrices.
-  float  *outptr0 = matrix_base;
-  float  *outptr4 = matrix_base + mstride * 4;
-  float  *outptr8 = matrix_base + mstride * 8;
-  float *outptr12 = matrix_base + mstride * 12;
-
-  // For every input channel
-  for (int in_c = 0; in_c < n_input_channels; in_c++) {
-    // For every output channel
-    for (int c = 0; c < n_output_channels; c++) {
-      // Read in the kernel
-      float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2];
-      float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2];
-      float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2];
-
-      // Progress input pointers
-      inptr0++;
-      inptr1++;
-      inptr2++;
-
-      // Compute the kernel W w, note we need only compute the middle two rows
-      // (2 and 3) because the first and last rows are merely copies of values
-      // from the matrix w.
-      float Ww11 = w11, Ww12 = w12, Ww13 = w13;
-      float Ww21 = 0.5*(w11 + w21 + w31), Ww22 = 0.5*(w12 + w22 + w32), Ww23 = 0.5*(w13 + w23 + w33);
-      float Ww31 = 0.5*(w11 - w21 + w31), Ww32 = 0.5*(w12 - w22 + w32), Ww33 = 0.5*(w13 - w23 + w33);
-      float Ww41 = w31, Ww42 = w32, Ww43 = w33;
-
-      // Hence compute W w W.T; again note we need compute only the middle two
-      // columns since the first and last columns are copies of the first and
-      // last columns of the previous matrix.
-      float WwWT11 = Ww11, WwWT12 = 0.5*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13;
-      float WwWT21 = Ww21, WwWT22 = 0.5*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23;
-      float WwWT31 = Ww31, WwWT32 = 0.5*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33;
-      float WwWT41 = Ww41, WwWT42 = 0.5*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43;
-
-      // Store the computed weights
-      outptr0[0 * mstride] = WwWT11;
-      outptr0[1 * mstride] = WwWT12;
-      outptr0[2 * mstride] = WwWT13;
-      outptr0[3 * mstride] = WwWT14;
-
-      outptr4[0 * mstride] = WwWT21;
-      outptr4[1 * mstride] = WwWT22;
-      outptr4[2 * mstride] = WwWT23;
-      outptr4[3 * mstride] = WwWT24;
-
-      outptr8[0 * mstride] = WwWT31;
-      outptr8[1 * mstride] = WwWT32;
-      outptr8[2 * mstride] = WwWT33;
-      outptr8[3 * mstride] = WwWT34;
-
-      outptr12[0 * mstride] = WwWT41;
-      outptr12[1 * mstride] = WwWT42;
-      outptr12[2 * mstride] = WwWT43;
-      outptr12[3 * mstride] = WwWT44;
-
-      // Progress output pointers
-      outptr0++;
-      outptr4++;
-      outptr8++;
-      outptr12++;
-    }
-
-    // Progression to complete stride
-    outptr0 += matrix_row_stride - n_output_channels;
-    outptr4 += matrix_row_stride - n_output_channels;
-    outptr8 += matrix_row_stride - n_output_channels;
-    outptr12 += matrix_row_stride - n_output_channels;
-  }
-}
-}
diff --git a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp
deleted file mode 100644
index 3dd62d1..0000000
--- a/src/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3/a64_float.hpp
+++ /dev/null
@@ -1,822 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-#ifdef __aarch64__
-namespace winograd {
-template <>
-template <>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<0>(
-    const float* const kernel,
-    const int n_input_channels,
-    const int n_output_channels,
-    float* const matrix_base,
-    const int mstride,
-    const int matrix_row_stride
-) {
-  // Use one input pointer for each row of the kernel, use two additional
-  // offsets to extract columns.
-  const int kernel_col_stride = n_input_channels * n_output_channels;
-  const int kernel_row_stride = 3 * kernel_col_stride;
-  const float *inptr0 = kernel;
-  const float *inptr1 = kernel + kernel_row_stride;
-  const float *inptr2 = kernel + kernel_row_stride*2;
-
-  // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
-  // offsets to extract further matrices.
-  float  *outptr0 = matrix_base;
-  float  *outptr4 = matrix_base + mstride * 4;
-  float  *outptr8 = matrix_base + mstride * 8;
-  float *outptr12 = matrix_base + mstride * 12;
-
-  // For every input channel
-  for (int in_c = 0; in_c < n_input_channels; in_c++) {
-    int n_remaining_channels = n_output_channels;
-
-    asm volatile (
-        // Registers into which to read the kernel
-        "w_11 .req v0\n"  "qw_11 .req q0\n"
-        "w_12 .req v1\n"  "qw_12 .req q1\n"
-        "w_13 .req v2\n"  "qw_13 .req q2\n"
-        "w_21 .req v3\n"  "qw_21 .req q3\n"
-        "w_22 .req v4\n"  "qw_22 .req q4\n"
-        "w_23 .req v5\n"  "qw_23 .req q5\n"
-        "w_31 .req v6\n"  "qw_31 .req q6\n"
-        "w_32 .req v7\n"  "qw_32 .req q7\n"
-        "w_33 .req v8\n"  "qw_33 .req q8\n"
-
-        // Transformed matrix Ww
-        "Ww11 .req w_11\n"  "Ww12 .req w_12\n"  "Ww13 .req w_13\n"
-        "Ww21 .req  v9\n"   "Ww22 .req v10\n"   "Ww23 .req v11\n"
-        "Ww31 .req v12\n"   "Ww32 .req v13\n"   "Ww33 .req v14\n"
-        "Ww41 .req w_31\n"  "Ww42 .req w_32\n"  "Ww43 .req w_33\n"
-
-        // Output matrix U = WwWT
-        "U11 .req Ww11\n"   "U12 .req v15\n"  "U13 .req v16\n"  "U14 .req Ww13\n"
-        "U21 .req Ww21\n"   "U22 .req v17\n"  "U23 .req v18\n"  "U24 .req Ww23\n"
-        "U31 .req Ww31\n"   "U32 .req v19\n"  "U33 .req v20\n"  "U34 .req Ww33\n"
-        "U41 .req Ww41\n"   "U42 .req v21\n"  "U43 .req v22\n"  "U44 .req Ww43\n"
-
-        // Storage view of output matrices
-        "qU11 .req   q0\n"   "qU12 .req q15\n"  "qU13 .req q16\n"  "qU14 .req   q2\n"
-        "qU21 .req   q9\n"   "qU22 .req q17\n"  "qU23 .req q18\n"  "qU24 .req  q11\n"
-        "qU31 .req  q12\n"   "qU32 .req q19\n"  "qU33 .req q20\n"  "qU34 .req  q14\n"
-        "qU41 .req   q6\n"   "qU42 .req q21\n"  "qU43 .req q22\n"  "qU44 .req   q8\n"
-
-        "half .req v23\n"  // {0.5, ..., 0.5}
-        "dup half.4s, %w[one_half]\n"
-        "scratch .req v24\n"
-
-        "1:"
-          // Load tile of the kernel
-          "ldr qw_11, [%x[inptr0]]\n"
-          "str qU11, [%x[outptr0]]\n"
-          "ldr qw_12, [%x[inptr0], %x[colstride1]]\n"
-          "ldr qw_13, [%x[inptr0], %x[colstride2]]\n"
-          "str qU14, [%x[outptr0], %x[mstride3]]\n"
-          "add %x[inptr0], %x[inptr0], #0x10\n"
-
-          "ldr qw_21, [%x[inptr1]]\n"
-          "ldr qw_22, [%x[inptr1], %x[colstride1]]\n"
-          "ldr qw_23, [%x[inptr1], %x[colstride2]]\n"
-          "add %x[inptr1], %x[inptr1], #0x10\n"
-
-          "ldr qw_31, [%x[inptr2]]\n"
-          "str qU41, [%x[outptr12]]\n"
-          "ldr qw_32, [%x[inptr2], %x[colstride1]]\n"
-          "ldr qw_33, [%x[inptr2], %x[colstride2]]\n"
-          "str qU44, [%x[outptr12], %x[mstride3]]\n"
-          "add %x[inptr2], %x[inptr2], #0x10\n"
-
-          // Compute 2nd and 3rd rows of Ww
-          "fadd scratch.4s, w_11.4s, w_31.4s\n"
-          "fmul Ww21.4s, scratch.4s, half.4s\n"
-          "fmla Ww21.4s, w_21.4s, half.4s\n"
-          "str qU21, [%x[outptr4]]\n"
-          "fmul Ww31.4s, scratch.4s, half.4s\n"
-          "fmls Ww31.4s, w_21.4s, half.4s\n"
-          "str qU31, [%x[outptr8]]\n"
-
-          "fadd scratch.4s, w_12.4s, w_32.4s\n"
-          "fmul Ww22.4s, scratch.4s, half.4s\n"
-          "fmla Ww22.4s, w_22.4s, half.4s\n"
-          "fmul Ww32.4s, scratch.4s, half.4s\n"
-          "fmls Ww32.4s, w_22.4s, half.4s\n"
-
-          "fadd scratch.4s, w_13.4s, w_33.4s\n"
-          "fmul Ww23.4s, scratch.4s, half.4s\n"
-          "fmla Ww23.4s, w_23.4s, half.4s\n"
-          "str qU24, [%x[outptr4], %x[mstride3]]\n"
-          "fmul Ww33.4s, scratch.4s, half.4s\n"
-          "fmls Ww33.4s, w_23.4s, half.4s\n"
-          "str qU34, [%x[outptr8], %x[mstride3]]\n"
-
-          // Compute and store U, only need to compute the 2nd and 3rd columns
-          // of U and update output pointers
-          "fadd scratch.4s, Ww11.4s, Ww13.4s\n"
-          "fmul U12.4s, scratch.4s, half.4s\n"
-          "fmla U12.4s, Ww12.4s, half.4s\n"
-          "str qU12, [%x[outptr0], %x[mstride1]]\n"
-          "fmul U13.4s, scratch.4s, half.4s\n"
-          "fmls U13.4s, Ww12.4s, half.4s\n"
-          "str qU13, [%x[outptr0], %x[mstride2]]\n"
-          "add  %x[outptr0],  %x[outptr0], #0x10\n"
-
-          "fadd scratch.4s, Ww21.4s, Ww23.4s\n"
-          "fmul U22.4s, scratch.4s, half.4s\n"
-          "fmla U22.4s, Ww22.4s, half.4s\n"
-          "str qU22, [%x[outptr4], %x[mstride1]]\n"
-          "fmul U23.4s, scratch.4s, half.4s\n"
-          "fmls U23.4s, Ww22.4s, half.4s\n"
-          "str qU23, [%x[outptr4], %x[mstride2]]\n"
-          "add  %x[outptr4],  %x[outptr4], #0x10\n"
-
-          "fadd scratch.4s, Ww31.4s, Ww33.4s\n"
-          "fmul U32.4s, scratch.4s, half.4s\n"
-          "fmla U32.4s, Ww32.4s, half.4s\n"
-          "str qU32, [%x[outptr8], %x[mstride1]]\n"
-          "fmul U33.4s, scratch.4s, half.4s\n"
-          "fmls U33.4s, Ww32.4s, half.4s\n"
-          "str qU33, [%x[outptr8], %x[mstride2]]\n"
-          "add  %x[outptr8],  %x[outptr8], #0x10\n"
-
-          "fadd scratch.4s, Ww41.4s, Ww43.4s\n"
-          "fmul U42.4s, scratch.4s, half.4s\n"
-          "fmla U42.4s, Ww42.4s, half.4s\n"
-          "str qU42, [%x[outptr12], %x[mstride1]]\n"
-          "fmul U43.4s, scratch.4s, half.4s\n"
-          "fmls U43.4s, Ww42.4s, half.4s\n"
-          "str qU43, [%x[outptr12], %x[mstride2]]\n"
-          "add %x[outptr12], %x[outptr12], #0x10\n"
-
-          "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n"
-          "bne 1b\n"
-
-        // Clear aliases
-        ".unreq half\n"
-        ".unreq scratch\n"
-        ".unreq w_11\n"  ".unreq qw_11\n"
-        ".unreq w_12\n"  ".unreq qw_12\n"
-        ".unreq w_13\n"  ".unreq qw_13\n"
-        ".unreq w_21\n"  ".unreq qw_21\n"
-        ".unreq w_22\n"  ".unreq qw_22\n"
-        ".unreq w_23\n"  ".unreq qw_23\n"
-        ".unreq w_31\n"  ".unreq qw_31\n"
-        ".unreq w_32\n"  ".unreq qw_32\n"
-        ".unreq w_33\n"  ".unreq qw_33\n"
-        ".unreq Ww11\n"  ".unreq Ww12\n"  ".unreq Ww13\n"
-        ".unreq Ww21\n"  ".unreq Ww22\n"  ".unreq Ww23\n"
-        ".unreq Ww31\n"  ".unreq Ww32\n"  ".unreq Ww33\n"
-        ".unreq Ww41\n"  ".unreq Ww42\n"  ".unreq Ww43\n"
-        ".unreq U11\n"   ".unreq U12\n"   ".unreq U13\n"   ".unreq U14\n"
-        ".unreq U21\n"   ".unreq U22\n"   ".unreq U23\n"   ".unreq U24\n"
-        ".unreq U31\n"   ".unreq U32\n"   ".unreq U33\n"   ".unreq U34\n"
-        ".unreq U41\n"   ".unreq U42\n"   ".unreq U43\n"   ".unreq U44\n"
-        ".unreq qU11\n"  ".unreq qU12\n"  ".unreq qU13\n"  ".unreq qU14\n"
-        ".unreq qU21\n"  ".unreq qU22\n"  ".unreq qU23\n"  ".unreq qU24\n"
-        ".unreq qU31\n"  ".unreq qU32\n"  ".unreq qU33\n"  ".unreq qU34\n"
-        ".unreq qU41\n"  ".unreq qU42\n"  ".unreq qU43\n"  ".unreq qU44\n"
-
-      : [inptr0] "+r" (inptr0),
-        [inptr1] "+r" (inptr1),
-        [inptr2] "+r" (inptr2),
-        [outptr0] "+r" (outptr0),
-        [outptr4] "+r" (outptr4),
-        [outptr8] "+r" (outptr8),
-        [outptr12] "+r" (outptr12),
-        [n_remaining_channels] "+r" (n_remaining_channels)
-      : [mstride1] "r" (sizeof(float) * mstride),
-        [mstride2] "r" (sizeof(float) * mstride * 2),
-        [mstride3] "r" (sizeof(float) * mstride * 3),
-        [colstride1] "r" (sizeof(float) * kernel_col_stride),
-        [colstride2] "r" (sizeof(float) * kernel_col_stride * 2),
-        [one_half] "r" (0.5f)
-      : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
-        "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
-        "v20", "v21", "v22", "v23", "v24"
-    );
-
-    // Progression to complete stride
-    outptr0 += matrix_row_stride - n_output_channels;
-    outptr4 += matrix_row_stride - n_output_channels;
-    outptr8 += matrix_row_stride - n_output_channels;
-    outptr12 += matrix_row_stride - n_output_channels;
-  }
-}
-
-template <>
-template <>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<2>(
-    const float* const kernel,
-    const int n_input_channels,
-    const int n_output_channels,
-    float* const matrix_base,
-    const int mstride,
-    const int matrix_row_stride
-) {
-  // Use one input pointer for each row of the kernel, use two additional
-  // offsets to extract columns.
-  const int kernel_col_stride = n_input_channels * n_output_channels;
-  const int kernel_row_stride = 3 * kernel_col_stride;
-  const float *inptr0 = kernel;
-  const float *inptr1 = kernel + kernel_row_stride;
-  const float *inptr2 = kernel + kernel_row_stride*2;
-
-  // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
-  // offsets to extract further matrices.
-  float  *outptr0 = matrix_base;
-  float  *outptr4 = matrix_base + mstride * 4;
-  float  *outptr8 = matrix_base + mstride * 8;
-  float *outptr12 = matrix_base + mstride * 12;
-
-  // For every input channel
-  for (int in_c = 0; in_c < n_input_channels; in_c++) {
-    int n_remaining_channels = n_output_channels;
-
-    asm volatile (
-        // Registers into which to read the kernel
-        "w_11 .req v0\n"  "qw_11 .req q0\n"  "dw_11 .req d0\n"
-        "w_12 .req v1\n"  "qw_12 .req q1\n"  "dw_12 .req d1\n"
-        "w_13 .req v2\n"  "qw_13 .req q2\n"  "dw_13 .req d2\n"
-        "w_21 .req v3\n"  "qw_21 .req q3\n"  "dw_21 .req d3\n"
-        "w_22 .req v4\n"  "qw_22 .req q4\n"  "dw_22 .req d4\n"
-        "w_23 .req v5\n"  "qw_23 .req q5\n"  "dw_23 .req d5\n"
-        "w_31 .req v6\n"  "qw_31 .req q6\n"  "dw_31 .req d6\n"
-        "w_32 .req v7\n"  "qw_32 .req q7\n"  "dw_32 .req d7\n"
-        "w_33 .req v8\n"  "qw_33 .req q8\n"  "dw_33 .req d8\n"
-
-        // Transformed matrix Ww
-        "Ww11 .req w_11\n"  "Ww12 .req w_12\n"  "Ww13 .req w_13\n"
-        "Ww21 .req  v9\n"   "Ww22 .req v10\n"   "Ww23 .req v11\n"
-        "Ww31 .req v12\n"   "Ww32 .req v13\n"   "Ww33 .req v14\n"
-        "Ww41 .req w_31\n"  "Ww42 .req w_32\n"  "Ww43 .req w_33\n"
-
-        // Output matrix U = WwWT
-        "U11 .req Ww11\n"   "U12 .req v15\n"  "U13 .req v16\n"  "U14 .req Ww13\n"
-        "U21 .req Ww21\n"   "U22 .req v17\n"  "U23 .req v18\n"  "U24 .req Ww23\n"
-        "U31 .req Ww31\n"   "U32 .req v19\n"  "U33 .req v20\n"  "U34 .req Ww33\n"
-        "U41 .req Ww41\n"   "U42 .req v21\n"  "U43 .req v22\n"  "U44 .req Ww43\n"
-
-        // Storage view of output matrices
-        "qU11 .req   q0\n"   "qU12 .req q15\n"  "qU13 .req q16\n"  "qU14 .req   q2\n"
-        "qU21 .req   q9\n"   "qU22 .req q17\n"  "qU23 .req q18\n"  "qU24 .req  q11\n"
-        "qU31 .req  q12\n"   "qU32 .req q19\n"  "qU33 .req q20\n"  "qU34 .req  q14\n"
-        "qU41 .req   q6\n"   "qU42 .req q21\n"  "qU43 .req q22\n"  "qU44 .req   q8\n"
-
-        "dU11 .req   d0\n"   "dU12 .req d15\n"  "dU13 .req d16\n"  "dU14 .req   d2\n"
-        "dU21 .req   d9\n"   "dU22 .req d17\n"  "dU23 .req d18\n"  "dU24 .req  d11\n"
-        "dU31 .req  d12\n"   "dU32 .req d19\n"  "dU33 .req d20\n"  "dU34 .req  d14\n"
-        "dU41 .req   d6\n"   "dU42 .req d21\n"  "dU43 .req d22\n"  "dU44 .req   d8\n"
-
-        "half .req v23\n"  // {0.5, ..., 0.5}
-        "dup half.4s, %w[one_half]\n"
-        "scratch .req v24\n"
-        
-        // Subtract the tail from the number of remaining channels and jump to
-        // the tail if necessary.
-        "subs %x[n_remaining_channels], %x[n_remaining_channels], #2\n"
-        "beq 2f\n"
-
-        "1:"
-          // Load tile of the kernel
-          "ldr qw_11, [%x[inptr0]]\n"
-          "str qU11, [%x[outptr0]]\n"
-          "ldr qw_12, [%x[inptr0], %x[colstride1]]\n"
-          "ldr qw_13, [%x[inptr0], %x[colstride2]]\n"
-          "str qU14, [%x[outptr0], %x[mstride3]]\n"
-          "add %x[inptr0], %x[inptr0], #0x10\n"
-
-          "ldr qw_21, [%x[inptr1]]\n"
-          "ldr qw_22, [%x[inptr1], %x[colstride1]]\n"
-          "ldr qw_23, [%x[inptr1], %x[colstride2]]\n"
-          "add %x[inptr1], %x[inptr1], #0x10\n"
-
-          "ldr qw_31, [%x[inptr2]]\n"
-          "str qU41, [%x[outptr12]]\n"
-          "ldr qw_32, [%x[inptr2], %x[colstride1]]\n"
-          "ldr qw_33, [%x[inptr2], %x[colstride2]]\n"
-          "str qU44, [%x[outptr12], %x[mstride3]]\n"
-          "add %x[inptr2], %x[inptr2], #0x10\n"
-
-          // Compute 2nd and 3rd rows of Ww
-          "fadd scratch.4s, w_11.4s, w_31.4s\n"
-          "fmul Ww21.4s, scratch.4s, half.4s\n"
-          "fmla Ww21.4s, w_21.4s, half.4s\n"
-          "str qU21, [%x[outptr4]]\n"
-          "fmul Ww31.4s, scratch.4s, half.4s\n"
-          "fmls Ww31.4s, w_21.4s, half.4s\n"
-          "str qU31, [%x[outptr8]]\n"
-
-          "fadd scratch.4s, w_12.4s, w_32.4s\n"
-          "fmul Ww22.4s, scratch.4s, half.4s\n"
-          "fmla Ww22.4s, w_22.4s, half.4s\n"
-          "fmul Ww32.4s, scratch.4s, half.4s\n"
-          "fmls Ww32.4s, w_22.4s, half.4s\n"
-
-          "fadd scratch.4s, w_13.4s, w_33.4s\n"
-          "fmul Ww23.4s, scratch.4s, half.4s\n"
-          "fmla Ww23.4s, w_23.4s, half.4s\n"
-          "str qU24, [%x[outptr4], %x[mstride3]]\n"
-          "fmul Ww33.4s, scratch.4s, half.4s\n"
-          "fmls Ww33.4s, w_23.4s, half.4s\n"
-          "str qU34, [%x[outptr8], %x[mstride3]]\n"
-
-          // Compute and store U, only need to compute the 2nd and 3rd columns
-          // of U and update output pointers
-          "fadd scratch.4s, Ww11.4s, Ww13.4s\n"
-          "fmul U12.4s, scratch.4s, half.4s\n"
-          "fmla U12.4s, Ww12.4s, half.4s\n"
-          "str qU12, [%x[outptr0], %x[mstride1]]\n"
-          "fmul U13.4s, scratch.4s, half.4s\n"
-          "fmls U13.4s, Ww12.4s, half.4s\n"
-          "str qU13, [%x[outptr0], %x[mstride2]]\n"
-          "add  %x[outptr0],  %x[outptr0], #0x10\n"
-
-          "fadd scratch.4s, Ww21.4s, Ww23.4s\n"
-          "fmul U22.4s, scratch.4s, half.4s\n"
-          "fmla U22.4s, Ww22.4s, half.4s\n"
-          "str qU22, [%x[outptr4], %x[mstride1]]\n"
-          "fmul U23.4s, scratch.4s, half.4s\n"
-          "fmls U23.4s, Ww22.4s, half.4s\n"
-          "str qU23, [%x[outptr4], %x[mstride2]]\n"
-          "add  %x[outptr4],  %x[outptr4], #0x10\n"
-
-          "fadd scratch.4s, Ww31.4s, Ww33.4s\n"
-          "fmul U32.4s, scratch.4s, half.4s\n"
-          "fmla U32.4s, Ww32.4s, half.4s\n"
-          "str qU32, [%x[outptr8], %x[mstride1]]\n"
-          "fmul U33.4s, scratch.4s, half.4s\n"
-          "fmls U33.4s, Ww32.4s, half.4s\n"
-          "str qU33, [%x[outptr8], %x[mstride2]]\n"
-          "add  %x[outptr8],  %x[outptr8], #0x10\n"
-
-          "fadd scratch.4s, Ww41.4s, Ww43.4s\n"
-          "fmul U42.4s, scratch.4s, half.4s\n"
-          "fmla U42.4s, Ww42.4s, half.4s\n"
-          "str qU42, [%x[outptr12], %x[mstride1]]\n"
-          "fmul U43.4s, scratch.4s, half.4s\n"
-          "fmls U43.4s, Ww42.4s, half.4s\n"
-          "str qU43, [%x[outptr12], %x[mstride2]]\n"
-          "add %x[outptr12], %x[outptr12], #0x10\n"
-
-          "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n"
-          "bne 1b\n"
-
-        // Tail size 2
-        "2:"
-          // Load tile of the kernel
-          "ldr dw_11, [%x[inptr0]]\n"
-          "str dU11, [%x[outptr0]]\n"
-          "ldr dw_12, [%x[inptr0], %x[colstride1]]\n"
-          "ldr dw_13, [%x[inptr0], %x[colstride2]]\n"
-          "str dU14, [%x[outptr0], %x[mstride3]]\n"
-          "add %x[inptr0], %x[inptr0], #0x08\n"
-
-          "ldr dw_21, [%x[inptr1]]\n"
-          "ldr dw_22, [%x[inptr1], %x[colstride1]]\n"
-          "ldr dw_23, [%x[inptr1], %x[colstride2]]\n"
-          "add %x[inptr1], %x[inptr1], #0x08\n"
-
-          "ldr dw_31, [%x[inptr2]]\n"
-          "str dU41, [%x[outptr12]]\n"
-          "ldr dw_32, [%x[inptr2], %x[colstride1]]\n"
-          "ldr dw_33, [%x[inptr2], %x[colstride2]]\n"
-          "str dU44, [%x[outptr12], %x[mstride3]]\n"
-          "add %x[inptr2], %x[inptr2], #0x08\n"
-
-          // Compute 2nd and 3rd rows of Ww
-          "fadd scratch.2s, w_11.2s, w_31.2s\n"
-          "fmul Ww21.2s, scratch.2s, half.2s\n"
-          "fmla Ww21.2s, w_21.2s, half.2s\n"
-          "str dU21, [%x[outptr4]]\n"
-          "fmul Ww31.2s, scratch.2s, half.2s\n"
-          "fmls Ww31.2s, w_21.2s, half.2s\n"
-          "str dU31, [%x[outptr8]]\n"
-
-          "fadd scratch.2s, w_12.2s, w_32.2s\n"
-          "fmul Ww22.2s, scratch.2s, half.2s\n"
-          "fmla Ww22.2s, w_22.2s, half.2s\n"
-          "fmul Ww32.2s, scratch.2s, half.2s\n"
-          "fmls Ww32.2s, w_22.2s, half.2s\n"
-
-          "fadd scratch.2s, w_13.2s, w_33.2s\n"
-          "fmul Ww23.2s, scratch.2s, half.2s\n"
-          "fmla Ww23.2s, w_23.2s, half.2s\n"
-          "str dU24, [%x[outptr4], %x[mstride3]]\n"
-          "fmul Ww33.2s, scratch.2s, half.2s\n"
-          "fmls Ww33.2s, w_23.2s, half.2s\n"
-          "str dU34, [%x[outptr8], %x[mstride3]]\n"
-
-          // Compute and store U, only need to compute the 2nd and 3rd columns of
-          // U and update output pointers
-          "fadd scratch.2s, Ww11.2s, Ww13.2s\n"
-          "fmul U12.2s, scratch.2s, half.2s\n"
-          "fmla U12.2s, Ww12.2s, half.2s\n"
-          "str dU12, [%x[outptr0], %x[mstride1]]\n"
-          "fmul U13.2s, scratch.2s, half.2s\n"
-          "fmls U13.2s, Ww12.2s, half.2s\n"
-          "str dU13, [%x[outptr0], %x[mstride2]]\n"
-          "add  %x[outptr0],  %x[outptr0], #0x08\n"
-
-          "fadd scratch.2s, Ww21.2s, Ww23.2s\n"
-          "fmul U22.2s, scratch.2s, half.2s\n"
-          "fmla U22.2s, Ww22.2s, half.2s\n"
-          "str dU22, [%x[outptr4], %x[mstride1]]\n"
-          "fmul U23.2s, scratch.2s, half.2s\n"
-          "fmls U23.2s, Ww22.2s, half.2s\n"
-          "str dU23, [%x[outptr4], %x[mstride2]]\n"
-          "add  %x[outptr4],  %x[outptr4], #0x08\n"
-
-          "fadd scratch.2s, Ww31.2s, Ww33.2s\n"
-          "fmul U32.2s, scratch.2s, half.2s\n"
-          "fmla U32.2s, Ww32.2s, half.2s\n"
-          "str dU32, [%x[outptr8], %x[mstride1]]\n"
-          "fmul U33.2s, scratch.2s, half.2s\n"
-          "fmls U33.2s, Ww32.2s, half.2s\n"
-          "str dU33, [%x[outptr8], %x[mstride2]]\n"
-          "add  %x[outptr8],  %x[outptr8], #0x08\n"
-
-          "fadd scratch.2s, Ww41.2s, Ww43.2s\n"
-          "fmul U42.2s, scratch.2s, half.2s\n"
-          "fmla U42.2s, Ww42.2s, half.2s\n"
-          "str dU42, [%x[outptr12], %x[mstride1]]\n"
-          "fmul U43.2s, scratch.2s, half.2s\n"
-          "fmls U43.2s, Ww42.2s, half.2s\n"
-          "str dU43, [%x[outptr12], %x[mstride2]]\n"
-          "add %x[outptr12], %x[outptr12], #0x08\n"
-
-        // Clear aliases
-        ".unreq half\n"
-        ".unreq scratch\n"
-        ".unreq w_11\n"  ".unreq qw_11\n" ".unreq dw_11\n"
-        ".unreq w_12\n"  ".unreq qw_12\n" ".unreq dw_12\n"
-        ".unreq w_13\n"  ".unreq qw_13\n" ".unreq dw_13\n"
-        ".unreq w_21\n"  ".unreq qw_21\n" ".unreq dw_21\n"
-        ".unreq w_22\n"  ".unreq qw_22\n" ".unreq dw_22\n"
-        ".unreq w_23\n"  ".unreq qw_23\n" ".unreq dw_23\n"
-        ".unreq w_31\n"  ".unreq qw_31\n" ".unreq dw_31\n"
-        ".unreq w_32\n"  ".unreq qw_32\n" ".unreq dw_32\n"
-        ".unreq w_33\n"  ".unreq qw_33\n" ".unreq dw_33\n"
-        ".unreq Ww11\n"  ".unreq Ww12\n"  ".unreq Ww13\n"
-        ".unreq Ww21\n"  ".unreq Ww22\n"  ".unreq Ww23\n"
-        ".unreq Ww31\n"  ".unreq Ww32\n"  ".unreq Ww33\n"
-        ".unreq Ww41\n"  ".unreq Ww42\n"  ".unreq Ww43\n"
-        ".unreq U11\n"   ".unreq U12\n"   ".unreq U13\n"   ".unreq U14\n"
-        ".unreq U21\n"   ".unreq U22\n"   ".unreq U23\n"   ".unreq U24\n"
-        ".unreq U31\n"   ".unreq U32\n"   ".unreq U33\n"   ".unreq U34\n"
-        ".unreq U41\n"   ".unreq U42\n"   ".unreq U43\n"   ".unreq U44\n"
-        ".unreq qU11\n"  ".unreq qU12\n"  ".unreq qU13\n"  ".unreq qU14\n"
-        ".unreq qU21\n"  ".unreq qU22\n"  ".unreq qU23\n"  ".unreq qU24\n"
-        ".unreq qU31\n"  ".unreq qU32\n"  ".unreq qU33\n"  ".unreq qU34\n"
-        ".unreq qU41\n"  ".unreq qU42\n"  ".unreq qU43\n"  ".unreq qU44\n"
-        ".unreq dU11\n"  ".unreq dU12\n"  ".unreq dU13\n"  ".unreq dU14\n"
-        ".unreq dU21\n"  ".unreq dU22\n"  ".unreq dU23\n"  ".unreq dU24\n"
-        ".unreq dU31\n"  ".unreq dU32\n"  ".unreq dU33\n"  ".unreq dU34\n"
-        ".unreq dU41\n"  ".unreq dU42\n"  ".unreq dU43\n"  ".unreq dU44\n"
-
-      : [inptr0] "+r" (inptr0),
-        [inptr1] "+r" (inptr1),
-        [inptr2] "+r" (inptr2),
-        [outptr0] "+r" (outptr0),
-        [outptr4] "+r" (outptr4),
-        [outptr8] "+r" (outptr8),
-        [outptr12] "+r" (outptr12),
-        [n_remaining_channels] "+r" (n_remaining_channels)
-      : [mstride1] "r" (sizeof(float) * mstride),
-        [mstride2] "r" (sizeof(float) * mstride * 2),
-        [mstride3] "r" (sizeof(float) * mstride * 3),
-        [colstride1] "r" (sizeof(float) * kernel_col_stride),
-        [colstride2] "r" (sizeof(float) * kernel_col_stride * 2),
-        [one_half] "r" (0.5f)
-      : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
-        "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
-        "v20", "v21", "v22", "v23", "v24"
-    );
-
-    // Progression to complete stride
-    outptr0 += matrix_row_stride - n_output_channels;
-    outptr4 += matrix_row_stride - n_output_channels;
-    outptr8 += matrix_row_stride - n_output_channels;
-    outptr12 += matrix_row_stride - n_output_channels;
-  }
-}
-
-template <>
-template <>
-inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel<1>(
-    const float* const kernel,
-    const int n_input_channels,
-    const int n_output_channels,
-    float* const matrix_base,
-    const int mstride,
-    const int matrix_row_stride
-) {
-  // Use one input pointer for each row of the kernel, use two additional
-  // offsets to extract columns.
-  const int kernel_col_stride = n_input_channels * n_output_channels;
-  const int kernel_row_stride = 3 * kernel_col_stride;
-  const float *inptr0 = kernel;
-  const float *inptr1 = kernel + kernel_row_stride;
-  const float *inptr2 = kernel + kernel_row_stride*2;
-
-  // Use four output pointers, for output matrices 0, 4, 8 and 12. Use three
-  // offsets to extract further matrices.
-  float  *outptr0 = matrix_base;
-  float  *outptr4 = matrix_base + mstride * 4;
-  float  *outptr8 = matrix_base + mstride * 8;
-  float *outptr12 = matrix_base + mstride * 12;
-
-  // For every input channel
-  for (int in_c = 0; in_c < n_input_channels; in_c++) {
-    int n_remaining_channels = n_output_channels;
-
-    asm volatile (
-        // Registers into which to read the kernel
-        "w_11 .req v0\n"  "qw_11 .req q0\n"  "sw_11 .req s0\n"
-        "w_12 .req v1\n"  "qw_12 .req q1\n"  "sw_12 .req s1\n"
-        "w_13 .req v2\n"  "qw_13 .req q2\n"  "sw_13 .req s2\n"
-        "w_21 .req v3\n"  "qw_21 .req q3\n"  "sw_21 .req s3\n"
-        "w_22 .req v4\n"  "qw_22 .req q4\n"  "sw_22 .req s4\n"
-        "w_23 .req v5\n"  "qw_23 .req q5\n"  "sw_23 .req s5\n"
-        "w_31 .req v6\n"  "qw_31 .req q6\n"  "sw_31 .req s6\n"
-        "w_32 .req v7\n"  "qw_32 .req q7\n"  "sw_32 .req s7\n"
-        "w_33 .req v8\n"  "qw_33 .req q8\n"  "sw_33 .req s8\n"
-
-        // Transformed matrix Ww
-        "Ww11 .req w_11\n"  "Ww12 .req w_12\n"  "Ww13 .req w_13\n"
-        "Ww21 .req  v9\n"   "Ww22 .req v10\n"   "Ww23 .req v11\n"
-        "Ww31 .req v12\n"   "Ww32 .req v13\n"   "Ww33 .req v14\n"
-        "Ww41 .req w_31\n"  "Ww42 .req w_32\n"  "Ww43 .req w_33\n"
-
-        // Output matrix U = WwWT
-        "U11 .req Ww11\n"   "U12 .req v15\n"  "U13 .req v16\n"  "U14 .req Ww13\n"
-        "U21 .req Ww21\n"   "U22 .req v17\n"  "U23 .req v18\n"  "U24 .req Ww23\n"
-        "U31 .req Ww31\n"   "U32 .req v19\n"  "U33 .req v20\n"  "U34 .req Ww33\n"
-        "U41 .req Ww41\n"   "U42 .req v21\n"  "U43 .req v22\n"  "U44 .req Ww43\n"
-
-        // Storage view of output matrices
-        "qU11 .req   q0\n"   "qU12 .req q15\n"  "qU13 .req q16\n"  "qU14 .req   q2\n"
-        "qU21 .req   q9\n"   "qU22 .req q17\n"  "qU23 .req q18\n"  "qU24 .req  q11\n"
-        "qU31 .req  q12\n"   "qU32 .req q19\n"  "qU33 .req q20\n"  "qU34 .req  q14\n"
-        "qU41 .req   q6\n"   "qU42 .req q21\n"  "qU43 .req q22\n"  "qU44 .req   q8\n"
-
-        "sU11 .req   s0\n"   "sU12 .req s15\n"  "sU13 .req s16\n"  "sU14 .req   s2\n"
-        "sU21 .req   s9\n"   "sU22 .req s17\n"  "sU23 .req s18\n"  "sU24 .req  s11\n"
-        "sU31 .req  s12\n"   "sU32 .req s19\n"  "sU33 .req s20\n"  "sU34 .req  s14\n"
-        "sU41 .req   s6\n"   "sU42 .req s21\n"  "sU43 .req s22\n"  "sU44 .req   s8\n"
-
-        "half .req v23\n"  // {0.5, ..., 0.5}
-        "dup half.4s, %w[one_half]\n"
-        "scratch .req v24\n"
-        
-        // Subtract the tail from the number of remaining channels and jump to
-        // the tail if necessary.
-        "subs %x[n_remaining_channels], %x[n_remaining_channels], #1\n"
-        "beq 2f\n"
-
-        "1:"
-          // Load tile of the kernel
-          "ldr qw_11, [%x[inptr0]]\n"
-          "str qU11, [%x[outptr0]]\n"
-          "ldr qw_12, [%x[inptr0], %x[colstride1]]\n"
-          "ldr qw_13, [%x[inptr0], %x[colstride2]]\n"
-          "str qU14, [%x[outptr0], %x[mstride3]]\n"
-          "add %x[inptr0], %x[inptr0], #0x10\n"
-
-          "ldr qw_21, [%x[inptr1]]\n"
-          "ldr qw_22, [%x[inptr1], %x[colstride1]]\n"
-          "ldr qw_23, [%x[inptr1], %x[colstride2]]\n"
-          "add %x[inptr1], %x[inptr1], #0x10\n"
-
-          "ldr qw_31, [%x[inptr2]]\n"
-          "str qU41, [%x[outptr12]]\n"
-          "ldr qw_32, [%x[inptr2], %x[colstride1]]\n"
-          "ldr qw_33, [%x[inptr2], %x[colstride2]]\n"
-          "str qU44, [%x[outptr12], %x[mstride3]]\n"
-          "add %x[inptr2], %x[inptr2], #0x10\n"
-
-          // Compute 2nd and 3rd rows of Ww
-          "fadd scratch.4s, w_11.4s, w_31.4s\n"
-          "fmul Ww21.4s, scratch.4s, half.4s\n"
-          "fmla Ww21.4s, w_21.4s, half.4s\n"
-          "str qU21, [%x[outptr4]]\n"
-          "fmul Ww31.4s, scratch.4s, half.4s\n"
-          "fmls Ww31.4s, w_21.4s, half.4s\n"
-          "str qU31, [%x[outptr8]]\n"
-
-          "fadd scratch.4s, w_12.4s, w_32.4s\n"
-          "fmul Ww22.4s, scratch.4s, half.4s\n"
-          "fmla Ww22.4s, w_22.4s, half.4s\n"
-          "fmul Ww32.4s, scratch.4s, half.4s\n"
-          "fmls Ww32.4s, w_22.4s, half.4s\n"
-
-          "fadd scratch.4s, w_13.4s, w_33.4s\n"
-          "fmul Ww23.4s, scratch.4s, half.4s\n"
-          "fmla Ww23.4s, w_23.4s, half.4s\n"
-          "str qU24, [%x[outptr4], %x[mstride3]]\n"
-          "fmul Ww33.4s, scratch.4s, half.4s\n"
-          "fmls Ww33.4s, w_23.4s, half.4s\n"
-          "str qU34, [%x[outptr8], %x[mstride3]]\n"
-
-          // Compute and store U, only need to compute the 2nd and 3rd columns
-          // of U and update output pointers
-          "fadd scratch.4s, Ww11.4s, Ww13.4s\n"
-          "fmul U12.4s, scratch.4s, half.4s\n"
-          "fmla U12.4s, Ww12.4s, half.4s\n"
-          "str qU12, [%x[outptr0], %x[mstride1]]\n"
-          "fmul U13.4s, scratch.4s, half.4s\n"
-          "fmls U13.4s, Ww12.4s, half.4s\n"
-          "str qU13, [%x[outptr0], %x[mstride2]]\n"
-          "add  %x[outptr0],  %x[outptr0], #0x10\n"
-
-          "fadd scratch.4s, Ww21.4s, Ww23.4s\n"
-          "fmul U22.4s, scratch.4s, half.4s\n"
-          "fmla U22.4s, Ww22.4s, half.4s\n"
-          "str qU22, [%x[outptr4], %x[mstride1]]\n"
-          "fmul U23.4s, scratch.4s, half.4s\n"
-          "fmls U23.4s, Ww22.4s, half.4s\n"
-          "str qU23, [%x[outptr4], %x[mstride2]]\n"
-          "add  %x[outptr4],  %x[outptr4], #0x10\n"
-
-          "fadd scratch.4s, Ww31.4s, Ww33.4s\n"
-          "fmul U32.4s, scratch.4s, half.4s\n"
-          "fmla U32.4s, Ww32.4s, half.4s\n"
-          "str qU32, [%x[outptr8], %x[mstride1]]\n"
-          "fmul U33.4s, scratch.4s, half.4s\n"
-          "fmls U33.4s, Ww32.4s, half.4s\n"
-          "str qU33, [%x[outptr8], %x[mstride2]]\n"
-          "add  %x[outptr8],  %x[outptr8], #0x10\n"
-
-          "fadd scratch.4s, Ww41.4s, Ww43.4s\n"
-          "fmul U42.4s, scratch.4s, half.4s\n"
-          "fmla U42.4s, Ww42.4s, half.4s\n"
-          "str qU42, [%x[outptr12], %x[mstride1]]\n"
-          "fmul U43.4s, scratch.4s, half.4s\n"
-          "fmls U43.4s, Ww42.4s, half.4s\n"
-          "str qU43, [%x[outptr12], %x[mstride2]]\n"
-          "add %x[outptr12], %x[outptr12], #0x10\n"
-
-          "subs %x[n_remaining_channels], %x[n_remaining_channels], #4\n"
-          "bne 1b\n"
-
-        // Tail size 1
-        "2:"
-          // Load tile of the kernel
-          "ldr sw_11, [%x[inptr0]]\n"
-          "str sU11, [%x[outptr0]]\n"
-          "ldr sw_12, [%x[inptr0], %x[colstride1]]\n"
-          "ldr sw_13, [%x[inptr0], %x[colstride2]]\n"
-          "str sU14, [%x[outptr0], %x[mstride3]]\n"
-          "add %x[inptr0], %x[inptr0], #0x04\n"
-
-          "ldr sw_21, [%x[inptr1]]\n"
-          "ldr sw_22, [%x[inptr1], %x[colstride1]]\n"
-          "ldr sw_23, [%x[inptr1], %x[colstride2]]\n"
-          "add %x[inptr1], %x[inptr1], #0x04\n"
-
-          "ldr sw_31, [%x[inptr2]]\n"
-          "str sU41, [%x[outptr12]]\n"
-          "ldr sw_32, [%x[inptr2], %x[colstride1]]\n"
-          "ldr sw_33, [%x[inptr2], %x[colstride2]]\n"
-          "str sU44, [%x[outptr12], %x[mstride3]]\n"
-          "add %x[inptr2], %x[inptr2], #0x04\n"
-
-          // Compute 2nd and 3rd rows of Ww
-          "fadd scratch.2s, w_11.2s, w_31.2s\n"
-          "fmul Ww21.2s, scratch.2s, half.2s\n"
-          "fmla Ww21.2s, w_21.2s, half.2s\n"
-          "str sU21, [%x[outptr4]]\n"
-          "fmul Ww31.2s, scratch.2s, half.2s\n"
-          "fmls Ww31.2s, w_21.2s, half.2s\n"
-          "str sU31, [%x[outptr8]]\n"
-
-          "fadd scratch.2s, w_12.2s, w_32.2s\n"
-          "fmul Ww22.2s, scratch.2s, half.2s\n"
-          "fmla Ww22.2s, w_22.2s, half.2s\n"
-          "fmul Ww32.2s, scratch.2s, half.2s\n"
-          "fmls Ww32.2s, w_22.2s, half.2s\n"
-
-          "fadd scratch.2s, w_13.2s, w_33.2s\n"
-          "fmul Ww23.2s, scratch.2s, half.2s\n"
-          "fmla Ww23.2s, w_23.2s, half.2s\n"
-          "str sU24, [%x[outptr4], %x[mstride3]]\n"
-          "fmul Ww33.2s, scratch.2s, half.2s\n"
-          "fmls Ww33.2s, w_23.2s, half.2s\n"
-          "str sU34, [%x[outptr8], %x[mstride3]]\n"
-
-          // Compute and store U, only need to compute the 2nd and 3rd columns of
-          // U and update output pointers
-          "fadd scratch.2s, Ww11.2s, Ww13.2s\n"
-          "fmul U12.2s, scratch.2s, half.2s\n"
-          "fmla U12.2s, Ww12.2s, half.2s\n"
-          "str sU12, [%x[outptr0], %x[mstride1]]\n"
-          "fmul U13.2s, scratch.2s, half.2s\n"
-          "fmls U13.2s, Ww12.2s, half.2s\n"
-          "str sU13, [%x[outptr0], %x[mstride2]]\n"
-          "add  %x[outptr0],  %x[outptr0], #0x04\n"
-
-          "fadd scratch.2s, Ww21.2s, Ww23.2s\n"
-          "fmul U22.2s, scratch.2s, half.2s\n"
-          "fmla U22.2s, Ww22.2s, half.2s\n"
-          "str sU22, [%x[outptr4], %x[mstride1]]\n"
-          "fmul U23.2s, scratch.2s, half.2s\n"
-          "fmls U23.2s, Ww22.2s, half.2s\n"
-          "str sU23, [%x[outptr4], %x[mstride2]]\n"
-          "add  %x[outptr4],  %x[outptr4], #0x04\n"
-
-          "fadd scratch.2s, Ww31.2s, Ww33.2s\n"
-          "fmul U32.2s, scratch.2s, half.2s\n"
-          "fmla U32.2s, Ww32.2s, half.2s\n"
-          "str sU32, [%x[outptr8], %x[mstride1]]\n"
-          "fmul U33.2s, scratch.2s, half.2s\n"
-          "fmls U33.2s, Ww32.2s, half.2s\n"
-          "str sU33, [%x[outptr8], %x[mstride2]]\n"
-          "add  %x[outptr8],  %x[outptr8], #0x04\n"
-
-          "fadd scratch.2s, Ww41.2s, Ww43.2s\n"
-          "fmul U42.2s, scratch.2s, half.2s\n"
-          "fmla U42.2s, Ww42.2s, half.2s\n"
-          "str sU42, [%x[outptr12], %x[mstride1]]\n"
-          "fmul U43.2s, scratch.2s, half.2s\n"
-          "fmls U43.2s, Ww42.2s, half.2s\n"
-          "str sU43, [%x[outptr12], %x[mstride2]]\n"
-          "add %x[outptr12], %x[outptr12], #0x04\n"
-
-        // Clear aliases
-        ".unreq half\n"
-        ".unreq scratch\n"
-        ".unreq w_11\n"  ".unreq qw_11\n" ".unreq sw_11\n"
-        ".unreq w_12\n"  ".unreq qw_12\n" ".unreq sw_12\n"
-        ".unreq w_13\n"  ".unreq qw_13\n" ".unreq sw_13\n"
-        ".unreq w_21\n"  ".unreq qw_21\n" ".unreq sw_21\n"
-        ".unreq w_22\n"  ".unreq qw_22\n" ".unreq sw_22\n"
-        ".unreq w_23\n"  ".unreq qw_23\n" ".unreq sw_23\n"
-        ".unreq w_31\n"  ".unreq qw_31\n" ".unreq sw_31\n"
-        ".unreq w_32\n"  ".unreq qw_32\n" ".unreq sw_32\n"
-        ".unreq w_33\n"  ".unreq qw_33\n" ".unreq sw_33\n"
-        ".unreq Ww11\n"  ".unreq Ww12\n"  ".unreq Ww13\n"
-        ".unreq Ww21\n"  ".unreq Ww22\n"  ".unreq Ww23\n"
-        ".unreq Ww31\n"  ".unreq Ww32\n"  ".unreq Ww33\n"
-        ".unreq Ww41\n"  ".unreq Ww42\n"  ".unreq Ww43\n"
-        ".unreq U11\n"   ".unreq U12\n"   ".unreq U13\n"   ".unreq U14\n"
-        ".unreq U21\n"   ".unreq U22\n"   ".unreq U23\n"   ".unreq U24\n"
-        ".unreq U31\n"   ".unreq U32\n"   ".unreq U33\n"   ".unreq U34\n"
-        ".unreq U41\n"   ".unreq U42\n"   ".unreq U43\n"   ".unreq U44\n"
-        ".unreq qU11\n"  ".unreq qU12\n"  ".unreq qU13\n"  ".unreq qU14\n"
-        ".unreq qU21\n"  ".unreq qU22\n"  ".unreq qU23\n"  ".unreq qU24\n"
-        ".unreq qU31\n"  ".unreq qU32\n"  ".unreq qU33\n"  ".unreq qU34\n"
-        ".unreq qU41\n"  ".unreq qU42\n"  ".unreq qU43\n"  ".unreq qU44\n"
-        ".unreq sU11\n"  ".unreq sU12\n"  ".unreq sU13\n"  ".unreq sU14\n"
-        ".unreq sU21\n"  ".unreq sU22\n"  ".unreq sU23\n"  ".unreq sU24\n"
-        ".unreq sU31\n"  ".unreq sU32\n"  ".unreq sU33\n"  ".unreq sU34\n"
-        ".unreq sU41\n"  ".unreq sU42\n"  ".unreq sU43\n"  ".unreq sU44\n"
-
-      : [inptr0] "+r" (inptr0),
-        [inptr1] "+r" (inptr1),
-        [inptr2] "+r" (inptr2),
-        [outptr0] "+r" (outptr0),
-        [outptr4] "+r" (outptr4),
-        [outptr8] "+r" (outptr8),
-        [outptr12] "+r" (outptr12),
-        [n_remaining_channels] "+r" (n_remaining_channels)
-      : [mstride1] "r" (sizeof(float) * mstride),
-        [mstride2] "r" (sizeof(float) * mstride * 2),
-        [mstride3] "r" (sizeof(float) * mstride * 3),
-        [colstride1] "r" (sizeof(float) * kernel_col_stride),
-        [colstride2] "r" (sizeof(float) * kernel_col_stride * 2),
-        [one_half] "r" (0.5f)
-      : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
-        "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
-        "v20", "v21", "v22", "v23", "v24"
-    );
-
-    // Progression to complete stride
-    outptr0 += matrix_row_stride - n_output_channels;
-    outptr4 += matrix_row_stride - n_output_channels;
-    outptr8 += matrix_row_stride - n_output_channels;
-    outptr12 += matrix_row_stride - n_output_channels;
-  }
-}
-}
-#endif  // __aarch64__
diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp
deleted file mode 100644
index 0992c0b..0000000
--- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3.hpp
+++ /dev/null
@@ -1,356 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-namespace winograd {
-  /* Transform from the Winograd domain back to the spatial domain.
-   */
-  template <typename T>
-  struct Winograd2x2_3x3GemmOutput {
-    static void execute(
-      const Tensor4DShape &output_shape,
-      T* const matrix_base,
-      const int matrix_stride,
-      const int matrix_row_stride,
-      T* const output
-    );
-
-    protected:
-    /* Specialised implementation method. */
-    template <bool tail_M, bool tail_N, int channel_tail>
-    static void _execute(
-      const Tensor4DShape &output_shape,
-      T *output,
-      const T *input,
-      const int matrix_stride,
-      const int matrix_row_stride
-    );
-  };
-
-  /* Two-stage implementation of the transformation from the Winograd domain.
-   *
-   * First computes Z.F and then computes (Z.F).Z^T.
-   */
-  template <typename T>
-  struct Winograd2x2_3x3GemmOutput_TwoStage {
-    static void execute(
-      const Tensor4DShape &output_shape,
-      T* const matrix_base,
-      const int matrix_stride,
-      const int matrix_row_stride,
-      T* const output
-    );
-
-    protected:
-    template <int channel_tail>
-    static void compute_zf(
-      const int n_rows, const int n_channels,
-      T* const zf, const T* const input[16]
-    );
-
-    template <bool tail_M, bool tail_N, int channel_tail>
-    static void compute_zfzT(
-      const Tensor4DShape &output_shape,
-      T* const output, const T* const zf
-    );
-  };
-}
-
-#include "output_2x2_3x3/a64_float.hpp"
-// #include "output_2x2_3x3/a64_float_two_stage.hpp"
-
-/*****************************************************************************/
-/*
-template <typename T>
-void winograd::Winograd2x2_3x3GemmOutput<T>::execute(
-    const Tensor4DShape &output_shape,
-    const int tile_M,
-    const int tile_N,
-    T* const matrix_base,
-    const int matrix_stride,
-    const int matrix_row_stride,
-    T* const output
-) {
-  T* const antipadding = reinterpret_cast<T *>(malloc(sizeof(T) * output_shape.n_channels));
-
-  // Get input pointers
-  const T* inptrs[16];
-  for (int i = 0; i < 16; i++) {
-    inptrs[i] = matrices[i];
-  }
-
-  for (int batch = 0; batch < output_shape.n_batches; batch++) {
-    for (int tile_i = 0; tile_i < tile_M; tile_i++) {
-      for (int tile_j = 0; tile_j < tile_N; tile_j++) {
-        // Get pointers for each of the 4 output cells required for this computation
-        T* outptrs[4];
-        for (int cell_i = 0, c = 0; cell_i < 2; cell_i++) {
-          for (int cell_j = 0; cell_j < 2; cell_j++, c++) {
-            const int i = tile_i*2 + cell_i;
-            const int j = tile_j*2 + cell_j;
-
-            if (i < output_shape.n_rows && j < output_shape.n_cols) {
-              outptrs[c] = output + (
-                  (batch*output_shape.n_rows + i) * output_shape.n_cols +
-                j) * output_shape.n_channels;
-            } else {
-              outptrs[c] = antipadding;
-            }
-          }  // cell_j
-        }  // cell_i
-
-        for (int n = 0; n < output_shape.n_channels; n++) {
-          // Read 16 values and progress pointers
-          T v[16];
-          for (int i = 0; i < 16; i++) {
-            v[i] = *(inptrs[i]++);
-          }
-
-          // Compute output for 4 pixels
-          *(outptrs[0]++) = v[ 0] + v[ 1] + v[ 2] +
-                            v[ 4] + v[ 5] + v[ 6] +
-                            v[ 8] + v[ 9] + v[10];
-          *(outptrs[1]++) = v[ 1] - v[ 2] - v[ 3] +
-                            v[ 5] - v[ 6] - v[ 7] +
-                            v[ 9] - v[10] - v[11];
-          *(outptrs[2]++) = v[ 4] + v[ 5] + v[ 6] -
-                            v[ 8] - v[ 9] - v[10] -
-                            v[12] - v[13] - v[14];
-          *(outptrs[3]++) = v[ 5] - v[ 6] - v[ 7] -
-                            v[ 9] + v[10] + v[11] -
-                            v[13] + v[14] + v[15];
-        }  // output_channel
-      }  // tile_j
-    }  // tile_i
-  }  // batch
-
-  free(antipadding);
-}
-*/
-
-/*****************************************************************************/
-/*
-template <typename T>
-void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::execute(
-    const Tensor4DShape &output_shape,
-    T* const matrices[16], T* const output
-) {
-  // Allocate memory for the intermediate matrices
-  const int tile_M = iceildiv(output_shape.n_rows, 2);
-  const int tile_N = iceildiv(output_shape.n_cols, 2);
-  const int n_rows = output_shape.n_batches * tile_M * tile_N;
-  const int n_channels = output_shape.n_channels;
-  T* matrices_zf = reinterpret_cast<T*>(
-    calloc(8 * n_rows * n_channels, sizeof(T))
-  );
-  
-  // Perform the first stage transform, computing ZF.
-  // Specializations should dispatch to different methods based on tail size.
-  compute_zf<0>(n_rows, n_channels, matrices_zf, matrices);
-  
-  // Perform the second stage transform, finishing Z F Z^T - variable dispatch
-  // based on size of the output. Specialisations can also dispatch based on
-  // the tail-size of the channel.
-  if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
-    compute_zfzT<true, true, 0>(output_shape, output, matrices_zf);
-  } else if (output_shape.n_rows % 2) {
-    compute_zfzT<true, false, 0>(output_shape, output, matrices_zf);
-  } else if (output_shape.n_cols % 2) {
-    compute_zfzT<false, true, 0>(output_shape, output, matrices_zf);
-  } else {
-    compute_zfzT<false, false, 0>(output_shape, output, matrices_zf);
-  }
-
-  free(reinterpret_cast<void*>(matrices_zf));
-}
-
-template <typename T>
-template <int channel_tail>
-void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::compute_zf(
-    const int n_rows, const int n_channels,
-    T* output, const T* const input[16]
-) {
-  // Extract 8 output pointers
-  T* outptr[8];
-  for (int i = 0; i < 8; i++) {
-    outptr[i] = output + i*n_rows*n_channels;
-  }
-
-  // Copy the 16 input pointers
-  const T* inptr[16];
-  for (int i = 0; i < 16; i++) {
-    inptr[i] = input[i];
-  }
-
-  // For every row of the matrices
-  for (int i = 0; i < n_rows; i++) {
-    // For every channel
-    for (int j = 0; j < n_channels; j++) {
-      // Extract values from the input matrices
-      T val[16];
-      for (int n = 0; n < 16; n++) {
-        val[n] = *(inptr[n]++);
-      }
-
-      // Compute output values
-      *(outptr[0]++) = val[0] + val[1] + val[2];
-      *(outptr[1]++) = val[1] - val[2] - val[3];
-      *(outptr[2]++) = val[4] + val[5] + val[6];
-      *(outptr[3]++) = val[5] - val[6] - val[7];
-      *(outptr[4]++) = val[8] + val[9] + val[10];
-      *(outptr[5]++) = val[9] - val[10] - val[11];
-      *(outptr[6]++) = val[12] + val[13] + val[14];
-      *(outptr[7]++) = val[13] - val[14] - val[15];
-    }
-  }
-}
-
-template <typename T>
-template <bool tail_M, bool tail_N, int channel_tail>
-void winograd::Winograd2x2_3x3GemmOutput_TwoStage<T>::compute_zfzT(
-    const Tensor4DShape &output_shape,
-    T* const output, const T* const input
-) {
-  // Sizing information
-  const int tile_M = output_shape.n_rows / 2;
-  const int tile_N = output_shape.n_cols / 2;
-
-  const int n_rows = (output_shape.n_batches *
-                      (tile_M + (tail_M ? 1 : 0)) *
-                      (tile_N + (tail_N ? 1 : 0)));
-  const int n_channels = output_shape.n_channels;
-
-  // Extract 8 input pointers
-  const T* inptr[8];
-  for (int i = 0; i < 8; i++) {
-    inptr[i] = input + i*n_rows*n_channels;
-  }
-
-  // Extract 4 output pointers
-  T* outptr00 = output;
-  T* outptr01 = outptr00 + n_channels;
-  T* outptr10 = outptr00 + output_shape.n_cols * n_channels;
-  T* outptr11 = outptr10 + n_channels;
-
-  // Progress over the output tiles, generating output values.
-  for (int batch = 0; batch < output_shape.n_batches; batch++) {
-    for (int tile_i = 0; tile_i < tile_M; tile_i++) {
-      for (int tile_j = 0; tile_j < tile_N; tile_j++) {
-        for (int channel = 0; channel < n_channels; channel++) {
-          // Read values from the input pointers
-          T v[8];
-          for (int i = 0; i < 8; i++) {
-            v[i] = *(inptr[i]++);
-          }
-
-          // Compute the output values and progress the output pointers.
-          *(outptr00++) = v[0] + v[2] + v[4];
-          *(outptr01++) = v[1] + v[3] + v[5];
-          *(outptr10++) = v[2] - v[4] - v[6];
-          *(outptr11++) = v[3] - v[5] - v[7];
-        }
-
-        // Progress the output pointers to the next column
-        outptr00 += n_channels;
-        outptr01 += n_channels;
-        outptr10 += n_channels;
-        outptr11 += n_channels;
-      }
-
-      if (tail_N) {
-        // Only evaluate the left-most columns of the output
-        for (int channel = 0; channel < n_channels; channel++) {
-          // Read values from the input pointers
-          T v[8];
-          for (int i = 0; i < 4; i++) {
-            v[i * 2] = *inptr[i * 2];
-          }
-          for (int i = 0; i < 8; i++) {
-            inptr[i]++;
-          }
-
-          // Compute the output values and progress the output pointers.
-          *(outptr00++) = v[0] + v[2] + v[4];
-          *(outptr10++) = v[2] - v[4] - v[6];
-        }
-
-        // Progress the output pointers to the next column
-        outptr01 += n_channels;  // Account for being skipped above
-        outptr11 += n_channels;  // Account for being skipped above
-      }
-
-      // Progress the output pointers to the next row
-      outptr00 += output_shape.n_cols * n_channels;
-      outptr01 += output_shape.n_cols * n_channels;
-      outptr10 += output_shape.n_cols * n_channels;
-      outptr11 += output_shape.n_cols * n_channels;
-    }
-
-    if (tail_M) {
-      // Only work on the upper row of the output
-      for (int tile_j = 0; tile_j < tile_N; tile_j++) {
-        for (int channel = 0; channel < n_channels; channel++) {
-          // Read values from the input pointers
-          T v[8];
-          for (int i = 0; i < 8; i++) {
-            v[i] = *(inptr[i]++);
-          }
-
-          // Compute the output values and progress the output pointers.
-          *(outptr00++) = v[0] + v[2] + v[4];
-          *(outptr01++) = v[1] + v[3] + v[5];
-        }
-
-        // Progress the output pointers to the next column
-        outptr00 += n_channels;
-        outptr01 += n_channels;
-        outptr10 += 2 * n_channels;  // Account for being skipped above
-        outptr11 += 2 * n_channels;  // Account for being skipped above
-      }
-
-      if (tail_N) {
-        // Only evaluate the upper-left cell of the output
-        for (int channel = 0; channel < n_channels; channel++) {
-          // Read values from the input pointers
-          T v[8];
-          for (int i = 0; i < 3; i++) {
-            v[i * 2] = *inptr[i * 2];
-          }
-          for (int i = 0; i < 8; i++) {
-            inptr[i]++;
-          }
-
-          // Compute the output values and progress the output pointers.
-          *(outptr00++) = v[0] + v[2] + v[4];
-        }
-
-        // Progress the output pointers to the next column
-        outptr01 += n_channels;  // Account for being skipped above
-        outptr10 += n_channels;  // Account for being skipped above
-        outptr11 += n_channels;  // Account for being skipped above
-      }
-    }
-  }
-}
-*/
diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp
deleted file mode 100644
index bf6ba90..0000000
--- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float.hpp
+++ /dev/null
@@ -1,650 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-/* Float implementation for AArch64.
- */
-#ifdef __aarch64__
-namespace winograd {
-
-
-template <>
-template <>
-inline void Winograd2x2_3x3GemmOutput<float>::_execute<false, false, 0>(
-    const Tensor4DShape &output_shape,
-    float *output,
-    const float *input,
-    const int mstride,
-    const int matrix_row_stride
-) {
-  const int tile_M = output_shape.n_rows / 2;
-  const int tile_N = output_shape.n_cols / 2;
-  int batch = output_shape.n_batches;
-  float *outptr = output;
-
-  const float *inptr0 = input;
-  const float *inptr4 = input + 4 * mstride;
-  const float *inptr8 = input + 8 * mstride;
-  const float *inptr12 = input + 12 * mstride;
-
-  const size_t col_stride = sizeof(float) * output_shape.n_channels;
-  const size_t row_stride = col_stride * tile_N * 2;
-
-  asm volatile (
-      // Aliases for elements of the input matrix `F`
-      // V-register      Q-register
-      "F11 .req  v0\n" "qF11 .req  q0\n"
-      "F12 .req  v1\n" "qF12 .req  q1\n"
-      "F13 .req  v2\n" "qF13 .req  q2\n"
-      "F14 .req  v3\n" "qF14 .req  q3\n"
-      "F21 .req  v4\n" "qF21 .req  q4\n"
-      "F22 .req  v5\n" "qF22 .req  q5\n"
-      "F23 .req  v6\n" "qF23 .req  q6\n"
-      "F24 .req  v7\n" "qF24 .req  q7\n"
-      "F31 .req  v8\n" "qF31 .req  q8\n"
-      "F32 .req  v9\n" "qF32 .req  q9\n"
-      "F33 .req v10\n" "qF33 .req q10\n"
-      "F34 .req v11\n" "qF34 .req q11\n"
-      "F41 .req v12\n" "qF41 .req q12\n"
-      "F42 .req v13\n" "qF42 .req q13\n"
-      "F43 .req v14\n" "qF43 .req q14\n"
-      "F44 .req v15\n" "qF44 .req q15\n"
-
-      // Aliases for elements of the intermediate matrix `FZ`
-      "FZ11 .req v16\n"
-      "FZ12 .req v17\n"
-      "FZ21 .req v18\n"
-      "FZ22 .req v19\n"
-      "FZ31 .req v20\n"
-      "FZ32 .req v21\n"
-      "FZ41 .req v22\n"
-      "FZ42 .req v23\n"
-
-      // Aliases for elements of the output matrix `f` (called `g` due to case
-      // insensitivity of aliases).
-      " g11 .req v24\n"
-      "qg11 .req q24\n"
-      " g12 .req v25\n"
-      "qg12 .req q25\n"
-      " g21 .req v26\n"
-      "qg21 .req q26\n"
-      " g22 .req v27\n"
-      "qg22 .req q27\n"
-
-      // Prepare the various strides
-      "col_stride .req %x[col_stride]\n"
-      "row_stride .req %x[row_stride]\n"
-      "row_plus_col_stride .req %x[row_plus_col_stride]\n"
-
-      "mstride1 .req %x[mstride1]\n"
-      "mstride2 .req %x[mstride2]\n"
-      "mstride3 .req %x[mstride3]\n"
-
-      "tile_i  .req x19\n"  // Tile row counter
-      "tile_j  .req x20\n"  // Tile column counter
-      "channel .req x21\n"  // Channel counter
-
-      "1:"  // Loop over batches
-        "mov tile_i, %x[tile_M]\n"  // Reset tile row counter
-
-        "2:"  // Loop over rows of tiles
-          "mov tile_j, %x[tile_N]\n"  // Reset tile column counter
-
-          "3:"  // Loop over columns of tiles
-            // Perform initial loads of the matrix `F`
-            "ldr qF11, [%x[inptr0]]\n"
-            "ldr qF12, [%x[inptr0], mstride1]\n"
-            "ldr qF13, [%x[inptr0], mstride2]\n"
-            "ldr qF14, [%x[inptr0], mstride3]\n"
-            "add %x[inptr0], %x[inptr0], #0x10\n"
-            "ldr qF21, [%x[inptr4]]\n"
-            "ldr qF22, [%x[inptr4], mstride1]\n"
-            "subs channel, %x[n_channels], #4\n"  // Reset channel counter
-
-            "ldr qF23, [%x[inptr4], mstride2]\n"
-            "ldr qF24, [%x[inptr4], mstride3]\n"
-            "add %x[inptr4], %x[inptr4], #0x10\n"
-            "beq 5f\n"  // Jump straight to tail if necessary
-
-            "4:"  // Loop over channels
-              "ldr qF31, [%x[inptr8]]\n"
-              "fadd FZ11.4s,  F11.4s, F12.4s\n"
-
-              "ldr qF32, [%x[inptr8], mstride1]\n"
-              "fsub FZ12.4s,  F12.4s, F13.4s\n"
-
-              "ldr qF33, [%x[inptr8], mstride2]\n"
-              "fadd FZ11.4s, FZ11.4s, F13.4s\n"
-
-              "ldr qF34, [%x[inptr8], mstride3]\n"
-              "fsub FZ12.4s, FZ12.4s, F14.4s\n"
-
-              "ldr qF41, [%x[inptr12]]\n"
-              "fadd FZ21.4s,  F21.4s, F22.4s\n"
-
-              "ldr qF42, [%x[inptr12], mstride1]\n"
-              "fsub FZ22.4s,  F22.4s, F23.4s\n"
-
-              "ldr qF43, [%x[inptr12], mstride2]\n"
-              "fadd FZ21.4s, FZ21.4s, F23.4s\n"
-
-              "ldr qF44, [%x[inptr12], mstride3]\n"
-              "fsub FZ22.4s, FZ22.4s, F24.4s\n"
-
-              "fadd FZ31.4s,  F31.4s, F32.4s\n"
-              "add %x[inptr8], %x[inptr8], #0x10\n"
-
-              "fsub FZ32.4s,  F32.4s, F33.4s\n"
-              "add %x[inptr12], %x[inptr12], #0x10\n"
-
-              "fadd FZ31.4s, FZ31.4s, F33.4s\n"
-
-              "fsub FZ32.4s, FZ32.4s, F34.4s\n"
-
-              "fadd g11.4s, FZ11.4s, FZ21.4s\n"
-
-              "fadd g12.4s, FZ12.4s, FZ22.4s\n"
-
-              "fadd g11.4s,  g11.4s, FZ31.4s\n"
-
-              "fadd g12.4s,  g12.4s, FZ32.4s\n"
-
-              "ldr qF11, [%x[inptr0]]\n"
-              "fadd FZ41.4s,  F41.4s, F42.4s\n"
-
-              "ldr qF12, [%x[inptr0], mstride1]\n"
-              "fsub g21.4s, FZ21.4s, FZ31.4s\n"
-
-              "ldr qF13, [%x[inptr0], mstride2]\n"
-              "fsub FZ42.4s,  F42.4s, F43.4s\n"
-
-              "ldr qF14, [%x[inptr0], mstride3]\n"
-              "str qg11, [%x[outptr]]\n"
-
-              "ldr qF21, [%x[inptr4]]\n"
-              "fadd FZ41.4s, FZ41.4s, F43.4s\n"
-
-              "ldr qF22, [%x[inptr4], mstride1]\n"
-              "str qg12, [%x[outptr], col_stride]\n"
-
-              "ldr qF23, [%x[inptr4], mstride2]\n"
-              "fsub FZ42.4s, FZ42.4s, F44.4s\n"
-
-              "ldr qF24, [%x[inptr4], mstride3]\n"
-              "fsub g22.4s, FZ22.4s, FZ32.4s\n"
-
-              "fsub g21.4s,  g21.4s, FZ41.4s\n"
-              "add %x[inptr0], %x[inptr0], #0x10\n"
-
-              "fsub g22.4s,  g22.4s, FZ42.4s\n"
-              "add %x[inptr4], %x[inptr4], #0x10\n"
-
-              "subs channel, channel, #4\n"
-
-              "str qg21, [%x[outptr], row_stride]\n"
-
-              "str qg22, [%x[outptr], row_plus_col_stride]\n"
-
-              "add %x[outptr], %x[outptr], #0x10\n"
-
-              "bne 4b\n"
-
-            "5:"  // Channel tail
-              "ldr qF31, [%x[inptr8]]\n"
-              "fadd FZ11.4s,  F11.4s, F12.4s\n"
-
-              "ldr qF32, [%x[inptr8], mstride1]\n"
-              "fsub FZ12.4s,  F12.4s, F13.4s\n"
-
-              "ldr qF33, [%x[inptr8], mstride2]\n"
-              "fadd FZ11.4s, FZ11.4s, F13.4s\n"
-
-              "ldr qF34, [%x[inptr8], mstride3]\n"
-              "fsub FZ12.4s, FZ12.4s, F14.4s\n"
-
-              "ldr qF41, [%x[inptr12]]\n"
-              "fadd FZ21.4s,  F21.4s, F22.4s\n"
-
-              "ldr qF42, [%x[inptr12], mstride1]\n"
-              "fsub FZ22.4s,  F22.4s, F23.4s\n"
-
-              "ldr qF43, [%x[inptr12], mstride2]\n"
-              "fadd FZ21.4s, FZ21.4s, F23.4s\n"
-
-              "ldr qF44, [%x[inptr12], mstride3]\n"
-              "fsub FZ22.4s, FZ22.4s, F24.4s\n"
-
-              "fadd FZ31.4s,  F31.4s, F32.4s\n"
-              "add %x[inptr8], %x[inptr8], #0x10\n"
-
-              "fsub FZ32.4s,  F32.4s, F33.4s\n"
-              "add %x[inptr12], %x[inptr12], #0x10\n"
-
-              "fadd FZ31.4s, FZ31.4s, F33.4s\n"
-
-              "fsub FZ32.4s, FZ32.4s, F34.4s\n"
-
-              "fadd g11.4s, FZ11.4s, FZ21.4s\n"
-
-              "fadd g12.4s, FZ12.4s, FZ22.4s\n"
-
-              "fadd g11.4s,  g11.4s, FZ31.4s\n"
-
-              "fadd g12.4s,  g12.4s, FZ32.4s\n"
-
-              "fadd FZ41.4s,  F41.4s, F42.4s\n"
-
-              "fsub g21.4s, FZ21.4s, FZ31.4s\n"
-
-              "fsub FZ42.4s,  F42.4s, F43.4s\n"
-
-              "str qg11, [%x[outptr]]\n"
-
-              "fadd FZ41.4s, FZ41.4s, F43.4s\n"
-
-              "str qg12, [%x[outptr], col_stride]\n"
-
-              "fsub FZ42.4s, FZ42.4s, F44.4s\n"
-
-              "fsub g22.4s, FZ22.4s, FZ32.4s\n"
-
-              "fsub g21.4s,  g21.4s, FZ41.4s\n"
-
-              "fsub g22.4s,  g22.4s, FZ42.4s\n"
-
-              "subs channel, channel, #4\n"
-
-              "str qg21, [%x[outptr], row_stride]\n"
-
-              // Progress input pointers to the next row of the matrix
-              "add  %x[inptr0],  %x[inptr0], %x[mrowpad]\n"
-              "add  %x[inptr4],  %x[inptr4], %x[mrowpad]\n"
-              "add  %x[inptr8],  %x[inptr8], %x[mrowpad]\n"
-              "add %x[inptr12], %x[inptr12], %x[mrowpad]\n"
-
-              "str qg22, [%x[outptr], row_plus_col_stride]\n"
-
-              "add %x[outptr], %x[outptr], #0x10\n"
-
-
-            "add %x[outptr], %x[outptr], col_stride\n"
-            "subs tile_j, tile_j, #1\n"
-            "bne 3b\n"
-
-          "add %x[outptr], %x[outptr], row_stride\n"
-          "subs tile_i, tile_i, #1\n"
-          "bne 2b\n"
-
-        "subs %w[batch], %w[batch], #1\n"
-        "bne 1b\n"
-
-      ".unreq  F11\n" ".unreq qF11\n"
-      ".unreq  F12\n" ".unreq qF12\n"
-      ".unreq  F13\n" ".unreq qF13\n"
-      ".unreq  F14\n" ".unreq qF14\n"
-      ".unreq  F21\n" ".unreq qF21\n"
-      ".unreq  F22\n" ".unreq qF22\n"
-      ".unreq  F23\n" ".unreq qF23\n"
-      ".unreq  F24\n" ".unreq qF24\n"
-      ".unreq  F31\n" ".unreq qF31\n"
-      ".unreq  F32\n" ".unreq qF32\n"
-      ".unreq  F33\n" ".unreq qF33\n"
-      ".unreq  F34\n" ".unreq qF34\n"
-      ".unreq  F41\n" ".unreq qF41\n"
-      ".unreq  F42\n" ".unreq qF42\n"
-      ".unreq  F43\n" ".unreq qF43\n"
-      ".unreq  F44\n" ".unreq qF44\n"
-
-      ".unreq FZ11\n" ".unreq FZ12\n"
-      ".unreq FZ21\n" ".unreq FZ22\n"
-      ".unreq FZ31\n" ".unreq FZ32\n"
-      ".unreq FZ41\n" ".unreq FZ42\n"
-
-      ".unreq  g11\n" ".unreq qg11\n"
-      ".unreq  g12\n" ".unreq qg12\n"
-      ".unreq  g21\n" ".unreq qg21\n"
-      ".unreq  g22\n" ".unreq qg22\n"
-
-      ".unreq col_stride\n"
-      ".unreq row_stride\n"
-      ".unreq row_plus_col_stride\n"
-
-      ".unreq mstride1\n"
-      ".unreq mstride2\n"
-      ".unreq mstride3\n"
-
-      ".unreq tile_i \n"
-      ".unreq tile_j \n"
-      ".unreq channel\n"
-
-    : [batch] "+r" (batch),
-      [outptr] "+r" (outptr),
-      [inptr0] "+r" (inptr0),
-      [inptr4] "+r" (inptr4),
-      [inptr8] "+r" (inptr8),
-      [inptr12] "+r" (inptr12)
-    : [tile_M] "r" (tile_M),
-      [tile_N] "r" (tile_N),
-      [n_channels] "r" (output_shape.n_channels),
-      [col_stride] "r" (col_stride),
-      [row_stride] "r" (row_stride),
-      [row_plus_col_stride] "r" (row_stride + col_stride),
-      [mstride1] "r" (mstride * sizeof(float)),
-      [mstride2] "r" (2 * mstride * sizeof(float)),
-      [mstride3] "r" (3 * mstride * sizeof(float)),
-      [mrowpad] "r" ((matrix_row_stride - output_shape.n_channels) * sizeof(float))
-    : "x19", "x20", "x21",
-      "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
-      "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
-      "v22", "v23", "v24", "v25", "v26", "v27",
-      "cc", "memory"
-  );
-}
-
-template <>
-template <bool tail_M, bool tail_N, const int channel_tail>
-inline void Winograd2x2_3x3GemmOutput<float>::_execute(
-    const Tensor4DShape &output_shape,
-    float *output,
-    const float *input,
-    const int mstride,
-    const int matrix_row_stride
-) {
-  // Compute basic information about the shape of the matrices
-  const int tile_M = output_shape.n_rows / 2;
-  const int tile_N = output_shape.n_cols / 2;
-  const int n_channels = output_shape.n_channels;
-
-  // Extract 16 input pointers
-  const float* inptr[16];
-  for (int i = 0; i < 16; i++) {
-    inptr[i] = input + i*mstride;
-  }
-
-  // Extract 4 output pointers
-  float *outptr00 = output;
-  float *outptr01 = outptr00 + n_channels;
-  float *outptr10 = outptr00 + output_shape.n_cols * n_channels;
-  float *outptr11 = outptr10 + n_channels;
-
-  // Progress over the output tiles, generating output values.
-  for (int batch = 0; batch < output_shape.n_batches; batch++) {
-    for (int tile_i = 0; tile_i < tile_M; tile_i++) {
-      for (int tile_j = 0; tile_j < tile_N; tile_j++) {
-        for (int channel = 0; channel < n_channels; channel++) {
-          // Read values from the input pointers
-          float F[4][4];
-          for (int i = 0; i < 4; i++) {
-            for (int j = 0; j < 4; j++) {
-              F[i][j] = *(inptr[i*4 + j]++);
-            }
-          }
-
-          // Compute the matrix F.Z
-          float ZF[4][2];
-          ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
-          ZF[0][1] = F[0][1] - F[0][2] - F[0][3];
-          ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
-          ZF[1][1] = F[1][1] - F[1][2] - F[1][3];
-          ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
-          ZF[2][1] = F[2][1] - F[2][2] - F[2][3];
-          ZF[3][0] = F[3][0] + F[3][1] + F[3][2];
-          ZF[3][1] = F[3][1] - F[3][2] - F[3][3];
-
-          // Hence compute the output matrix Z^T . (F.Z)
-          *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
-          *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1];
-          *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0];
-          *(outptr11++) = ZF[1][1] - ZF[2][1] - ZF[3][1];
-        }
-
-        // Progress the input pointers to the next row
-        for (int i = 0; i < 16; i++) {
-          inptr[i] += matrix_row_stride - n_channels;
-        }
-
-        // Progress the output pointers to the next column
-        outptr00 += n_channels;
-        outptr01 += n_channels;
-        outptr10 += n_channels;
-        outptr11 += n_channels;
-      }
-
-      if (tail_N) {
-        // Only evaluate the left-most columns of the output
-        for (int channel = 0; channel < n_channels; channel++) {
-          // Read values from the input pointers
-          float F[4][3];
-          for (int i = 0; i < 4; i++) {
-            for (int j = 0; j < 3; j++) {
-              F[i][j] = *(inptr[i*4 + j]++);
-            }
-          }
-          for (int i = 0; i < 4; i++) {
-            inptr[i*4 + 3]++;
-          }
-
-          // Compute the matrix F.Z
-          float ZF[4][1];
-          ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
-          ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
-          ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
-          ZF[3][0] = F[3][0] + F[3][1] + F[3][2];
-
-          // Hence compute the output matrix Z^T . (F.Z)
-          *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
-          *(outptr10++) = ZF[1][0] - ZF[2][0] - ZF[3][0];
-        }
-
-        // Progress the input pointers to the next row
-        for (int i = 0; i < 16; i++) {
-          inptr[i] += matrix_row_stride - n_channels;
-        }
-
-        // Progress the output pointers to the next column
-        outptr01 += n_channels;  // Account for being skipped above
-        outptr11 += n_channels;  // Account for being skipped above
-      }
-
-      // Progress the output pointers to the next row
-      outptr00 += output_shape.n_cols * n_channels;
-      outptr01 += output_shape.n_cols * n_channels;
-      outptr10 += output_shape.n_cols * n_channels;
-      outptr11 += output_shape.n_cols * n_channels;
-    }
-
-    if (tail_M) {
-      // Only work on the upper row of the output
-      for (int tile_j = 0; tile_j < tile_N; tile_j++) {
-        for (int channel = 0; channel < n_channels; channel++) {
-          // Read values from the input pointers
-          float F[3][4];
-          for (int i = 0; i < 3; i++) {
-            for (int j = 0; j < 4; j++) {
-              F[i][j] = *(inptr[i*4 + j]++);
-            }
-          }
-          for (int j = 0; j < 4; j++) {
-            inptr[12 + j]++;
-          }
-
-          // Compute the matrix F.Z
-          float ZF[3][2];
-          ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
-          ZF[0][1] = F[0][1] - F[0][2] - F[0][3];
-          ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
-          ZF[1][1] = F[1][1] - F[1][2] - F[1][3];
-          ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
-          ZF[2][1] = F[2][1] - F[2][2] - F[2][3];
-
-          // Hence compute the output matrix Z^T . (F.Z)
-          *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
-          *(outptr01++) = ZF[0][1] + ZF[1][1] + ZF[2][1];
-        }
-
-        // Progress the input pointers to the next row
-        for (int i = 0; i < 16; i++) {
-          inptr[i] += matrix_row_stride - n_channels;
-        }
-
-        // Progress the output pointers to the next column
-        outptr00 += n_channels;
-        outptr01 += n_channels;
-        outptr10 += 2 * n_channels;  // Account for being skipped above
-        outptr11 += 2 * n_channels;  // Account for being skipped above
-      }
-
-      if (tail_N) {
-        // Only evaluate the upper-left cell of the output
-        for (int channel = 0; channel < n_channels; channel++) {
-          // Read values from the input pointers
-          float F[3][3];
-          for (int i = 0; i < 3; i++) {
-            for (int j = 0; j < 3; j++) {
-              F[i][j] = *(inptr[i*4 + j]);
-            }
-          }
-          for (int i = 0; i < 16; i++) {
-            inptr[i]++;
-          }
-
-          // Compute the matrix F.Z
-          float ZF[3][1];
-          ZF[0][0] = F[0][0] + F[0][1] + F[0][2];
-          ZF[1][0] = F[1][0] + F[1][1] + F[1][2];
-          ZF[2][0] = F[2][0] + F[2][1] + F[2][2];
-
-          // Hence compute the output matrix Z^T . (F.Z)
-          *(outptr00++) = ZF[0][0] + ZF[1][0] + ZF[2][0];
-        }
-
-        // Progress the input pointers to the next row
-        for (int i = 0; i < 16; i++) {
-          inptr[i] += matrix_row_stride - n_channels;
-        }
-
-        // Progress the output pointers to the next column
-        outptr01 += n_channels;  // Account for being skipped above
-        outptr10 += n_channels;  // Account for being skipped above
-        outptr11 += n_channels;  // Account for being skipped above
-      }
-    }
-  }
-}
-
-/*****************************************************************************/
-template <>
-inline void Winograd2x2_3x3GemmOutput<float>::execute(
-    const Tensor4DShape &output_shape,
-    float* const matrix_base,
-    const int matrix_stride,
-    const int matrix_row_stride,
-    float* const output
-) {
-  // Dispatch to an appropriate implementation based on the shape of the output
-  // tensor.
-  if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
-    constexpr bool tail_M = true, tail_N = true;
-    switch (output_shape.n_channels % 4) {
-      case 0:
-        _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 1:
-        _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 2:
-        _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 3:
-        _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      default:
-        assert(0);
-        break;
-    }
-  } else if (output_shape.n_rows % 2) {
-    constexpr bool tail_M = true, tail_N = false;
-    switch (output_shape.n_channels % 4) {
-      case 0:
-        _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 1:
-        _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 2:
-        _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 3:
-        _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      default:
-        assert(0);
-        break;
-    }
-  } else if (output_shape.n_cols % 2) {
-    constexpr bool tail_M = false, tail_N = true;
-    switch (output_shape.n_channels % 4) {
-      case 0:
-        _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 1:
-        _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 2:
-        _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 3:
-        _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      default:
-        assert(0);
-        break;
-
-    }
-  } else {
-    constexpr bool tail_M = false, tail_N = false;
-    switch (output_shape.n_channels % 4) {
-      case 0:
-        _execute<tail_M, tail_N, 0>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 1:
-        _execute<tail_M, tail_N, 1>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 2:
-        _execute<tail_M, tail_N, 2>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      case 3:
-        _execute<tail_M, tail_N, 3>(output_shape, output, matrix_base, matrix_stride, matrix_row_stride);
-        break;
-      default:
-        assert(0);
-        break;
-
-    }
-  }
-}
-/*****************************************************************************/
-
-}  // namespace winograd
-#endif  // __aarch64__
diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp
deleted file mode 100644
index f551b12..0000000
--- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3/a64_float_two_stage.hpp
+++ /dev/null
@@ -1,655 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-
-#ifdef __aarch64__
-
-/*****************************************************************************/
-// Compute ZF specializations
-
-template <>
-template <>
-inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zf<0>(
-    const int n_rows, const int n_channels,
-    float* output, const float* const input[16]
-) {
-  // Make copies of some variables
-  int row = n_rows;
-  float* outptr = output;
-  const float* inptr = input[0];
-
-  // Perform the transformation
-  asm volatile (
-    // "inptr0 .req %x[inptr]\n"
-    "inptr1 .req x0\n"
-    "inptr2 .req x1\n"
-    "inptr3 .req x2\n"
-    "inptr4 .req x3\n"
-    "inptr5 .req x4\n"
-    "inptr6 .req x5\n"
-    "inptr7 .req x6\n"
-    "inptr8 .req x7\n"
-    "inptr9 .req x8\n"
-    "inptr10 .req x9\n"
-    "inptr11 .req x10\n"
-    "inptr12 .req x11\n"
-    "inptr13 .req x12\n"
-    "inptr14 .req x13\n"
-    "inptr15 .req x14\n"
-
-    // "outptr0 .req %x[outptr]\n"
-    "outptr1 .req x15\n"
-    "outptr2 .req x16\n"
-    "outptr3 .req x17\n"
-    "outptr4 .req x18\n"
-    "outptr5 .req x19\n"
-    "outptr6 .req x20\n"
-    "outptr7 .req x21\n"
-
-    // Compute additional pointers into the input and output matrices.
-    "mstride .req x22\n"  // Matrix stride
-    "mul mstride, %x[row], %x[n_channels]\n"
-    "lsl mstride, mstride, #2\n"  // * sizeof(float)
-
-    "add inptr1, %x[inptr], mstride\n"
-    "add inptr2, %x[inptr], mstride, LSL #1\n"
-    "add inptr3, inptr2, mstride\n"
-    "add inptr4, inptr3, mstride\n"
-    "add inptr5, inptr4, mstride\n"
-    "add inptr6, inptr5, mstride\n"
-    "add inptr7, inptr6, mstride\n"
-    "add inptr8, inptr7, mstride\n"
-    "add inptr9, inptr8, mstride\n"
-    "add inptr10, inptr9, mstride\n"
-    "add inptr11, inptr10, mstride\n"
-    "add inptr12, inptr11, mstride\n"
-    "add inptr13, inptr12, mstride\n"
-    "add inptr14, inptr13, mstride\n"
-    "add inptr15, inptr14, mstride\n"
-
-    "add outptr1, %[outptr], mstride\n"
-    "add outptr2, outptr1, mstride\n"
-    "add outptr3, outptr2, mstride\n"
-    "add outptr4, outptr3, mstride\n"
-    "add outptr5, outptr4, mstride\n"
-    "add outptr6, outptr5, mstride\n"
-    "add outptr7, outptr6, mstride\n"
-
-    ".unreq mstride\n"
-
-    "column .req x22\n"  // Column loop counter
-
-    "1:"  // Loop over rows
-      "ldr q0, [%x[inptr]], #0x10\n"
-      "ldr q1, [inptr1], #0x10\n"
-      "ldr q2, [inptr2], #0x10\n"
-      "ldr q3, [inptr3], #0x10\n"
-      "ldr q4, [inptr4], #0x10\n"
-      "ldr q5, [inptr5], #0x10\n"
-      "ldr q6, [inptr6], #0x10\n"
-      "ldr q7, [inptr7], #0x10\n"
-      "subs column, %x[n_channels], #0x4\n"
-      "beq 3f\n"
-
-      "2:"  // Loop over columns
-        "ldr q8, [inptr8], #0x10\n"
-        "prfm pldl1keep, [%x[inptr], #196]\n"
-        "fadd v16.4s, v0.4s, v1.4s\n"
-
-        "ldr q9, [inptr9], #0x10\n"
-        "prfm pldl1keep, [inptr1, #196]\n"
-        "fsub v17.4s, v1.4s, v2.4s\n"
-
-        "ldr q10, [inptr10], #0x10\n"
-        "prfm pldl1keep, [inptr2, #196]\n"
-        "fadd v16.4s, v16.4s, v2.4s\n"
-
-        "ldr q11, [inptr11], #0x10\n"
-        "prfm pldl1keep, [inptr3, #196]\n"
-        "fsub v17.4s, v17.4s, v3.4s\n"
-
-        "ldr q12, [inptr12], #0x10\n"
-        "prfm pldl1keep, [inptr4, #196]\n"
-        "str q16, [%x[outptr]], #0x10\n"
-
-        "ldr q13, [inptr13], #0x10\n"
-        "prfm pldl1keep, [inptr5, #196]\n"
-        "str q17, [outptr1], #0x10\n"
-
-        "ldr q14, [inptr14], #0x10\n"
-        "prfm pldl1keep, [inptr6, #196]\n"
-        "fadd v16.4s, v4.4s, v5.4s\n"
-
-        "ldr q15, [inptr15], #0x10\n"
-        "prfm pldl1keep, [inptr7, #196]\n"
-        "fsub v17.4s, v5.4s, v6.4s\n"
-
-        "ldr q0, [%x[inptr]], #0x10\n"
-        "prfm pldl1keep, [inptr8, #196]\n"
-        "fadd v16.4s, v16.4s, v6.4s\n"
-
-        "ldr q1, [inptr1], #0x10\n"
-        "prfm pldl1keep, [inptr9, #196]\n"
-        "fsub v17.4s, v17.4s, v7.4s\n"
-
-        "ldr q2, [inptr2], #0x10\n"
-        "prfm pldl1keep, [inptr10, #196]\n"
-        "str q16, [outptr2], #0x10\n"
-
-        "ldr q3, [inptr3], #0x10\n"
-        "prfm pldl1keep, [inptr11, #196]\n"
-        "str q17, [outptr3], #0x10\n"
-
-        "ldr q4, [inptr4], #0x10\n"
-        "prfm pldl1keep, [inptr12, #196]\n"
-        "fadd v16.4s, v8.4s, v9.4s\n"
-
-        "ldr q5, [inptr5], #0x10\n"
-        "prfm pldl1keep, [inptr13, #196]\n"
-        "fsub v17.4s, v9.4s, v10.4s\n"
-
-        "ldr q6, [inptr6], #0x10\n"
-        "prfm pldl1keep, [inptr14, #196]\n"
-        "fadd v16.4s, v16.4s, v10.4s\n"
-
-        "ldr q7, [inptr7], #0x10\n"
-        "prfm pldl1keep, [inptr15, #196]\n"
-        "fsub v17.4s, v17.4s, v11.4s\n"
-
-        "str q16, [outptr4], #0x10\n"
-        "fadd v16.4s, v12.4s, v13.4s\n"
-        "fsub v18.4s, v13.4s, v14.4s\n"
-
-        "str q17, [outptr5], #0x10\n"
-        "fadd v16.4s, v16.4s, v14.4s\n"
-        "fsub v18.4s, v18.4s, v15.4s\n"
-
-        "str q16, [outptr6], #0x10\n"
-        "subs column, column, #0x4\n"
-
-        "str q18, [outptr7], #0x10\n"
-        "bne 2b\n"
-
-      "3:"  // Tail
-        "ldr q8, [inptr8], #0x10\n"
-        "prfm pldl1keep, [%x[inptr], #196]\n"
-        "fadd v16.4s, v0.4s, v1.4s\n"
-
-        "ldr q9, [inptr9], #0x10\n"
-        "prfm pldl1keep, [inptr1, #196]\n"
-        "fsub v17.4s, v1.4s, v2.4s\n"
-
-        "ldr q10, [inptr10], #0x10\n"
-        "prfm pldl1keep, [inptr2, #196]\n"
-        "fadd v16.4s, v16.4s, v2.4s\n"
-
-        "ldr q11, [inptr11], #0x10\n"
-        "prfm pldl1keep, [inptr3, #196]\n"
-        "fsub v17.4s, v17.4s, v3.4s\n"
-
-        "ldr q12, [inptr12], #0x10\n"
-        "prfm pldl1keep, [inptr4, #196]\n"
-        "str q16, [%x[outptr]], #0x10\n"
-
-        "ldr q13, [inptr13], #0x10\n"
-        "prfm pldl1keep, [inptr5, #196]\n"
-        "str q17, [outptr1], #0x10\n"
-
-        "ldr q14, [inptr14], #0x10\n"
-        "prfm pldl1keep, [inptr6, #196]\n"
-        "fadd v16.4s, v4.4s, v5.4s\n"
-
-        "ldr q15, [inptr15], #0x10\n"
-        "prfm pldl1keep, [inptr7, #196]\n"
-        "fsub v17.4s, v5.4s, v6.4s\n"
-
-        "prfm pldl1keep, [inptr8, #196]\n"
-        "prfm pldl1keep, [inptr9, #196]\n"
-        "fadd v16.4s, v16.4s, v6.4s\n"
-
-        "prfm pldl1keep, [inptr10, #196]\n"
-        "prfm pldl1keep, [inptr11, #196]\n"
-        "fsub v17.4s, v17.4s, v7.4s\n"
-
-        "prfm pldl1keep, [inptr12, #196]\n"
-        "prfm pldl1keep, [inptr13, #196]\n"
-        "str q16, [outptr2], #0x10\n"
-
-        "prfm pldl1keep, [inptr14, #196]\n"
-        "prfm pldl1keep, [inptr15, #196]\n"
-        "str q17, [outptr3], #0x10\n"
-
-        "fadd v16.4s, v8.4s, v9.4s\n"
-        "fsub v17.4s, v9.4s, v10.4s\n"
-
-        "fadd v16.4s, v16.4s, v10.4s\n"
-        "fsub v17.4s, v17.4s, v11.4s\n"
-
-        "str q16, [outptr4], #0x10\n"
-        "fadd v16.4s, v12.4s, v13.4s\n"
-        "fsub v18.4s, v13.4s, v14.4s\n"
-
-        "str q17, [outptr5], #0x10\n"
-        "fadd v16.4s, v16.4s, v14.4s\n"
-        "fsub v18.4s, v18.4s, v15.4s\n"
-
-        "str q16, [outptr6], #0x10\n"
-        "str q18, [outptr7], #0x10\n"
-
-      "subs %x[row], %x[row], #0x1\n"
-      "bne 1b\n"
-
-    ".unreq inptr1\n"
-    ".unreq inptr2\n"
-    ".unreq inptr3\n"
-    ".unreq inptr4\n"
-    ".unreq inptr5\n"
-    ".unreq inptr6\n"
-    ".unreq inptr7\n"
-    ".unreq inptr8\n"
-    ".unreq inptr9\n"
-    ".unreq inptr10\n"
-    ".unreq inptr11\n"
-    ".unreq inptr12\n"
-    ".unreq inptr13\n"
-    ".unreq inptr14\n"
-    ".unreq inptr15\n"
-    ".unreq outptr1\n"
-    ".unreq outptr2\n"
-    ".unreq outptr3\n"
-    ".unreq outptr4\n"
-    ".unreq outptr5\n"
-    ".unreq outptr6\n"
-    ".unreq outptr7\n"
-
-    : [row] "+r" (row),
-      [inptr] "+r" (inptr),
-      [outptr] "+r" (outptr)
-    : [n_channels] "r" (n_channels),
-      [sizeof_float] "i" (sizeof(float))
-    : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
-      "q12", "q13", "q14", "q15", "q16", "q17", "x0", "x1", "x2", "x3", "x4",
-      "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15",
-      "x16", "x17", "x18", "x19", "x20", "x21", "x22", "cc", "memory"
-  );
-}
-
-/*****************************************************************************/
-// Compute ZFZ^T specializations
-
-template <>
-template <>
-inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::compute_zfzT<false, false, 0>(
-    const Tensor4DShape &output_shape,
-    float* const output, const float* const input
-) {
-  const int tile_M = output_shape.n_rows / 2;
-  const int tile_N = output_shape.n_cols / 2;
-  int batch = output_shape.n_batches;
-  float *outptr = output;
-  const float *inptr = input;
-
-  asm volatile (
-    // Compute input pointers
-    "inptr1 .req x0\n"
-    "inptr2 .req x1\n"
-    "inptr3 .req x2\n"
-    "inptr4 .req x3\n"
-    "inptr5 .req x4\n"
-    "inptr6 .req x5\n"
-    "inptr7 .req x6\n"
-    "inptr8 .req x7\n"
-
-    "mstride .req x8\n"
-    "mul mstride, %x[tile_M], %x[tile_N]\n"
-    "mul mstride, mstride, %x[n_channels]\n"
-    "lsl mstride, mstride, #2\n"  // * sizeof(float)
-
-    "add inptr1, %[inptr], mstride\n"
-    "add inptr2, inptr1, mstride\n"
-    "add inptr3, inptr2, mstride\n"
-    "add inptr4, inptr3, mstride\n"
-    "add inptr5, inptr4, mstride\n"
-    "add inptr6, inptr5, mstride\n"
-    "add inptr7, inptr6, mstride\n"
-    "add inptr8, inptr7, mstride\n"
-
-    ".unreq mstride\n"
-
-    // Compute initial output pointers
-    "outptr01 .req  x8\n"
-    "outptr10 .req  x9\n"
-    "outptr11 .req x10\n"
-
-    "add outptr01, %x[outptr], %x[n_channels], LSL #2\n"
-    "add outptr10, %x[outptr], %x[row_stride], LSL #2\n"
-    "add outptr11,   outptr10, %x[n_channels], LSL #2\n"
-
-    "tile_i  .req x11\n"
-    "tile_j  .req x12\n"
-    "channel .req x13\n"
-
-    "1:"  // Loop over batches
-      "mov tile_i, %x[tile_M]\n"
-
-      "2:"  // Loop over rows of output tiles
-        "mov tile_j, %x[tile_N]\n"
-
-        "3:"  // Loop over columns of output tiles
-          "ldr q0, [%x[inptr]], #0x10\n"
-          "ldr q2, [inptr2], #0x10\n"
-          "subs channel, %x[n_channels], #0x4\n"
-
-          "ldr q1, [inptr1], #0x10\n"
-          "ldr q3, [inptr3], #0x10\n"
-          "beq 6f\n"
-
-          "4:"
-            "ldr q4, [inptr4], #0x10\n"
-            "ldr q5, [inptr5], #0x10\n"
-            "fadd v16.4s, v0.4s, v2.4s\n"
-
-            "ldr q6, [inptr6], #0x10\n"
-            "ldr q7, [inptr7], #0x10\n"
-            "fadd v17.4s, v1.4s, v3.4s\n"
-
-            "ldr q8, [%x[inptr]], #0x10\n"
-            "ldr q10, [inptr2], #0x10\n"
-            "fadd v16.4s, v16.4s, v4.4s\n"
-
-            "ldr q9, [inptr1], #0x10\n"
-            "ldr q11, [inptr3], #0x10\n"
-            "fadd v17.4s, v17.4s, v5.4s\n"
-
-            "str q16, [%x[outptr]], #0x10\n"
-            "prfm pldl1strm, [%x[inptr], #196]\n"
-            "fsub v18.4s, v2.4s, v4.4s\n"
-
-            "str q17, [outptr01], #0x10\n"
-            "prfm pldl1strm, [inptr2, #196]\n"
-            "fsub v19.4s, v3.4s, v5.4s\n"
-
-            "prfm pldl1strm, [inptr1, #196]\n"
-            "prfm pldl1strm, [inptr3, #196]\n"
-            "fsub v18.4s, v18.4s, v6.4s\n"
-
-            "prfm pldl1strm, [inptr4, #196]\n"
-            "prfm pldl1strm, [inptr5, #196]\n"
-            "fsub v19.4s, v19.4s, v7.4s\n"
-
-            "str q18, [outptr10], #0x10\n"
-            "prfm pldl1strm, [inptr6, #196]\n"
-            "prfm pldl1strm, [inptr7, #196]\n"
-
-            "subs channel, channel, #0x4\n"
-
-            "str q19, [outptr11], #0x10\n"
-            "beq 6f\n"  // Branch to tail
-
-            "ldr q12, [inptr4], #0x10\n"
-            "ldr q13, [inptr5], #0x10\n"
-            "fadd v16.4s, v8.4s, v10.4s\n"
-
-            "ldr q14, [inptr6], #0x10\n"
-            "ldr q15, [inptr7], #0x10\n"
-            "fadd v17.4s, v9.4s, v11.4s\n"
-
-            "ldr q0, [%x[inptr]], #0x10\n"
-            "ldr q2, [inptr2], #0x10\n"
-            "fadd v16.4s, v16.4s, v12.4s\n"
-
-            "ldr q1, [inptr1], #0x10\n"
-            "ldr q3, [inptr3], #0x10\n"
-            "fadd v17.4s, v17.4s, v13.4s\n"
-
-            "str q16, [%x[outptr]], #0x10\n"
-            "prfm pldl1strm, [%x[inptr], #196]\n"
-            "fsub v18.4s, v10.4s, v12.4s\n"
-
-            "str q17, [outptr01], #0x10\n"
-            "prfm pldl1strm, [inptr2, #196]\n"
-            "fsub v19.4s, v11.4s, v13.4s\n"
-
-            "prfm pldl1strm, [inptr1, #196]\n"
-            "prfm pldl1strm, [inptr3, #196]\n"
-            "fsub v18.4s, v18.4s, v14.4s\n"
-
-            "prfm pldl1strm, [inptr4, #196]\n"
-            "prfm pldl1strm, [inptr5, #196]\n"
-            "fsub v19.4s, v19.4s, v15.4s\n"
-
-            "str q18, [outptr10], #0x10\n"
-            "prfm pldl1strm, [inptr6, #196]\n"
-            "prfm pldl1strm, [inptr7, #196]\n"
-
-            "subs channel, channel, #0x4\n"
-
-            "str q19, [outptr11], #0x10\n"
-            "bne 4b\n"  // Continue loop
-
-          "5:"  // Tail
-            "ldr q12, [inptr4], #0x10\n"
-            "ldr q13, [inptr5], #0x10\n"
-            "fadd v16.4s, v8.4s, v10.4s\n"
-
-            "ldr q14, [inptr6], #0x10\n"
-            "ldr q15, [inptr7], #0x10\n"
-            "fadd v17.4s, v9.4s, v11.4s\n"
-
-            "fadd v16.4s, v16.4s, v12.4s\n"
-
-            "fadd v17.4s, v17.4s, v13.4s\n"
-
-            "str q16, [%x[outptr]], #0x10\n"
-            "fsub v18.4s, v10.4s, v12.4s\n"
-            "fsub v19.4s, v11.4s, v13.4s\n"
-
-            "str q17, [outptr01], #0x10\n"
-            "fsub v18.4s, v18.4s, v14.4s\n"
-            "fsub v19.4s, v19.4s, v15.4s\n"
-
-            "str q18, [outptr10], #0x10\n"
-            "str q19, [outptr11], #0x10\n"
-            "b 7f\n"
-
-          "6:"  // Tail
-            "ldr q4, [inptr4], #0x10\n"
-            "ldr q5, [inptr5], #0x10\n"
-            "fadd v16.4s, v0.4s, v2.4s\n"
-
-            "ldr q6, [inptr6], #0x10\n"
-            "ldr q7, [inptr7], #0x10\n"
-            "fadd v17.4s, v1.4s, v3.4s\n"
-
-            "fadd v16.4s, v16.4s, v4.4s\n"
-
-            "fadd v17.4s, v17.4s, v5.4s\n"
-
-            "str q16, [%x[outptr]], #0x10\n"
-            "fsub v18.4s, v2.4s, v4.4s\n"
-            "fsub v19.4s, v3.4s, v5.4s\n"
-
-            "str q17, [outptr01], #0x10\n"
-            "fsub v18.4s, v18.4s, v6.4s\n"
-            "fsub v19.4s, v19.4s, v7.4s\n"
-
-            "str q18, [outptr10], #0x10\n"
-            "str q19, [outptr11], #0x10\n"
-
-          "7:"
-            "add %x[outptr], %x[outptr], %x[n_channels], LSL #2\n"
-            "add outptr01, outptr01, %x[n_channels], LSL #2\n"
-            "add outptr10, outptr10, %x[n_channels], LSL #2\n"
-            "add outptr11, outptr11, %x[n_channels], LSL #2\n"
-
-            "subs tile_j, tile_j, #1\n"
-            "bne 3b\n"
-
-        // Progress the output pointers to the new row
-        "add %x[outptr], %x[outptr], %x[row_stride], LSL #2\n"
-        "add   outptr01,   outptr01, %x[row_stride], LSL #2\n"
-        "add   outptr10,   outptr10, %x[row_stride], LSL #2\n"
-        "add   outptr11,   outptr11, %x[row_stride], LSL #2\n"
-
-        "subs tile_i, tile_i, #1\n"
-        "bne 2b\n"
-
-      "subs %[batch], %[batch], #1\n"
-      "bne 1b\n"
-      "5:"
-
-    ".unreq inptr1\n"
-    ".unreq inptr2\n"
-    ".unreq inptr3\n"
-    ".unreq inptr4\n"
-    ".unreq inptr5\n"
-    ".unreq inptr6\n"
-    ".unreq inptr7\n"
-    ".unreq inptr8\n"
-    ".unreq outptr01\n"
-    ".unreq outptr10\n"
-    ".unreq outptr11\n"
-    : [batch] "+r" (batch),
-      [outptr] "+r" (outptr),
-      [inptr] "+r" (inptr)
-    : [tile_M] "r" (tile_M),
-      [tile_N] "r" (tile_N),
-      [n_channels] "r" (output_shape.n_channels),
-      [row_stride] "r" (output_shape.n_cols * output_shape.n_channels)
-    : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11",
-      "x12", "x13", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
-      "cc", "memory"
-  );
-}
-/*****************************************************************************/
-
-/*****************************************************************************/
-template <>
-inline void winograd::Winograd2x2_3x3GemmOutput_TwoStage<float>::execute(
-    const Tensor4DShape &output_shape,
-    float* const matrices[16], float* const output
-) {
-  // profiler prof;
-
-  // Allocate memory for the intermediate matrices
-  const int tile_M = iceildiv(output_shape.n_rows, 2);
-  const int tile_N = iceildiv(output_shape.n_cols, 2);
-  const int n_rows = output_shape.n_batches * tile_M * tile_N;
-  const int n_channels = output_shape.n_channels;
-  float* matrices_zf = reinterpret_cast<float*>(
-    calloc(8 * n_rows * n_channels, sizeof(float))
-  );
-  
-  // Perform the first stage transform, computing ZF.
-  const auto f_compute_zf = [&] () {
-    switch (n_channels % 4) {
-      case 0:
-        compute_zf<0>(n_rows, n_channels, matrices_zf, matrices);
-        break;
-      case 1:
-        compute_zf<1>(n_rows, n_channels, matrices_zf, matrices);
-        break;
-      case 2:
-        compute_zf<2>(n_rows, n_channels, matrices_zf, matrices);
-        break;
-      case 3:
-        compute_zf<3>(n_rows, n_channels, matrices_zf, matrices);
-    };
-  };
-  // prof("Compute ZF", f_compute_zf, 16 * n_rows * n_channels * sizeof(float), 0, 8 * n_rows * n_channels * sizeof(float));
-  f_compute_zf();
-  
-  // Perform the second stage transform, finishing Z F Z^T - variable dispatch
-  // based on size of the output and the channel tail.
-  const auto f_compute_zfzT = [&] () {
-    if (output_shape.n_rows % 2 && output_shape.n_cols % 2) {
-      constexpr bool tail_M = true, tail_N = true;
-      switch (n_channels % 4) {
-        case 0:
-          compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
-          break;
-        case 1:
-          compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
-          break;
-        case 2:
-          compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
-          break;
-        case 3:
-          compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
-      }
-    } else if (output_shape.n_rows % 2) {
-      constexpr bool tail_M = true, tail_N = false;
-      switch (n_channels % 4) {
-        case 0:
-          compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
-          break;
-        case 1:
-          compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
-          break;
-        case 2:
-          compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
-          break;
-        case 3:
-          compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
-      }
-    } else if (output_shape.n_cols % 2) {
-      constexpr bool tail_M = false, tail_N = true;
-      switch (n_channels % 4) {
-        case 0:
-          compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
-          break;
-        case 1:
-          compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
-          break;
-        case 2:
-          compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
-          break;
-        case 3:
-          compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
-      }
-    } else {
-      constexpr bool tail_M = false, tail_N = false;
-      switch (n_channels % 4) {
-        case 0:
-          compute_zfzT<tail_M, tail_N, 0>(output_shape, output, matrices_zf);
-          break;
-        case 1:
-          compute_zfzT<tail_M, tail_N, 1>(output_shape, output, matrices_zf);
-          break;
-        case 2:
-          compute_zfzT<tail_M, tail_N, 2>(output_shape, output, matrices_zf);
-          break;
-        case 3:
-          compute_zfzT<tail_M, tail_N, 3>(output_shape, output, matrices_zf);
-      }
-    }
-  };
-  // prof("Compute ZFZT", f_compute_zfzT, 8 * n_rows * n_channels * sizeof(float), 0, 4 * n_rows * n_channels * sizeof(float));
-  f_compute_zfzT();
-
-  free(reinterpret_cast<void*>(matrices_zf));
-}
-/*****************************************************************************/
-
-#endif  // __aarch64__
diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp
new file mode 100644
index 0000000..e7907d1
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp
@@ -0,0 +1,238 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "transforms/output.hpp"
+#include "winograd_gemm.hpp"
+#include "arm.hpp"
+
+namespace winograd
+{
+
+using Transform = WinogradGEMM<2, 2, 3, 3>::OutputTransform<float>;
+
+template <>
+template <>
+int Transform::ops_performed(const Tensor4DShape &shape)
+{
+  // NOTE: Cost in FLOPs rather than instructions or uops.
+  const int tile_M = iceildiv(shape.n_rows, 2);
+  const int tile_N = iceildiv(shape.n_cols, 2);
+  return 24 * tile_M * tile_N * shape.n_channels;
+}
+
+/* F(2x2, 3x3) constructs 2x2 output tiles from a 3x3 convolution. Since we use
+ * enough tiles to cover the output space each output tile may contain 0 or 1
+ * padded values to the right and bottom columns or rows of the tile, e.g.:
+ *
+ *      ___     ___
+ *     |   |   |  X|
+ *     |___|   |__X|
+ *
+ *      ___     ___
+ *     |   |   |  X|
+ *     |X_X|   |X_X|
+ *
+ *
+ * We provide a specialised output transform for each of these instances.
+ * Consequently we below construct an array of the various padding options, the
+ * array contains pointers to the specific implementations.
+ */
+template <>
+template <>
+template <int pad_bottom, int pad_right>
+void Transform::process_tile(
+  const int n_channels,
+  const float* const matrix_base,
+  const int matrix_stride,
+  float* const output,
+  const int output_row_stride,
+  const int output_col_stride
+)
+{
+  constexpr int cells_i = 2 - pad_bottom;
+  constexpr int cells_j = 2 - pad_right;
+
+  // Construct a map to the output cells
+  float *outptrs[cells_i][cells_j];
+  for (int i = 0; i < cells_i; i++)
+  {
+    for (int j = 0; j < cells_j; j++)
+    {
+      outptrs[i][j] = output + i*output_row_stride + j*output_col_stride;
+    }
+  }
+  const float *inptr = matrix_base;
+
+  // For each channel of the output
+  int channels_remaining = n_channels;
+#ifdef __aarch64__
+  for (; channels_remaining >= 4; channels_remaining -= 4)
+  {
+    // Matrices used and computed during this transform
+    float32x4_t F[4][4], FZ[4][2], f[2][2];
+
+    // Read a 4x4 tile in the Winograd domain
+    for (int i = 0, m = 0; i < 4; i++)
+    {
+      for (int j = 0; j < 4; j++, m++)
+      {
+        F[i][j] = vld1q_f32(inptr + m*matrix_stride);
+      }
+    }
+    inptr += 4;
+
+    // Compute the matrix F Z
+    for (int i = 0; i < 4; i++)
+    {
+      // FZ[i][0] =  F[i][0] + F[i][1] + F[i][2];
+      FZ[i][0] = vaddq_f32(vaddq_f32(F[i][0], F[i][1]), F[i][2]);
+
+      // FZ[i][1] =  F[i][1] - F[i][2] - F[i][3];
+      FZ[i][1] = vsubq_f32(vsubq_f32(F[i][1], F[i][2]), F[i][3]);
+    }
+
+    // Compute the output tile f = ZT F Z
+    for (int j = 0; j < 2; j++)
+    {
+      // f[0][j] =  FZ[0][j] + FZ[1][j] + FZ[2][j];
+      f[0][j] = vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), FZ[2][j]);
+
+      // f[1][j] =  FZ[1][j] - FZ[2][j] - FZ[3][j];
+      f[1][j] = vsubq_f32(vsubq_f32(FZ[1][j], FZ[2][j]), FZ[3][j]);
+    }
+
+    // Write out the output tile
+    for (int i = 0; i < cells_i; i++)
+    {
+      for (int j = 0; j < cells_j; j++)
+      {
+        vst1q_f32(outptrs[i][j], f[i][j]);
+        outptrs[i][j] += 4;
+      }
+    }
+  }
+#endif  // __aarch64__
+#ifdef __arm_any__
+  for (; channels_remaining >= 2; channels_remaining -= 2)
+  {
+    // Matrices used and computed during this transform
+    float32x2_t F[4][4], FZ[4][2], f[2][2];
+
+    // Read a 4x4 tile in the Winograd domain
+    for (int i = 0, m = 0; i < 4; i++)
+    {
+      for (int j = 0; j < 4; j++, m++)
+      {
+        F[i][j] = vld1_f32(inptr + m*matrix_stride);
+      }
+    }
+    inptr += 2;
+
+    // Compute the matrix F Z
+    for (int i = 0; i < 4; i++)
+    {
+      // FZ[i][0] =  F[i][0] + F[i][1] + F[i][2];
+      FZ[i][0] = vadd_f32(vadd_f32(F[i][0], F[i][1]), F[i][2]);
+
+      // FZ[i][1] =  F[i][1] - F[i][2] - F[i][3];
+      FZ[i][1] = vsub_f32(vsub_f32(F[i][1], F[i][2]), F[i][3]);
+    }
+
+    // Compute the output tile f = ZT F Z
+    for (int j = 0; j < 2; j++)
+    {
+      // f[0][j] =  FZ[0][j] + FZ[1][j] + FZ[2][j];
+      f[0][j] = vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), FZ[2][j]);
+
+      // f[1][j] =  FZ[1][j] - FZ[2][j] - FZ[3][j];
+      f[1][j] = vsub_f32(vsub_f32(FZ[1][j], FZ[2][j]), FZ[3][j]);
+    }
+
+    // Write out the output tile
+    for (int i = 0; i < cells_i; i++)
+    {
+      for (int j = 0; j < cells_j; j++)
+      {
+        vst1_f32(outptrs[i][j], f[i][j]);
+        outptrs[i][j] += 2;
+      }
+    }
+  }
+#endif  // __arm_any__
+  for (; channels_remaining; channels_remaining--)
+  {
+    // Matrices used and computed during this transform
+    float F[4][4], FZ[4][2], f[2][2];
+
+    // Read a 4x4 tile in the Winograd domain
+    for (int i = 0, m = 0; i < 4; i++)
+    {
+      for (int j = 0; j < 4; j++, m++)
+      {
+        F[i][j] = *(inptr + m*matrix_stride);
+      }
+    }
+    inptr++;
+
+    // Compute the matrix F Z
+    for (int i = 0; i < 4; i++)
+    {
+      FZ[i][0] =  F[i][0] + F[i][1] + F[i][2];
+      FZ[i][1] =  F[i][1] - F[i][2] - F[i][3];
+    }
+
+    // Compute the output tile f = ZT F Z
+    for (int j = 0; j < 2; j++)
+    {
+      f[0][j] =  FZ[0][j] + FZ[1][j] + FZ[2][j];
+      f[1][j] =  FZ[1][j] - FZ[2][j] - FZ[3][j];
+    }
+
+    // Write out the output tile
+    for (int i = 0; i < cells_i; i++)
+    {
+      for (int j = 0; j < cells_j; j++)
+      {
+        *(outptrs[i][j]++) = f[i][j];
+      }
+    }
+  }
+}
+
+template <>
+template <>
+const Transform::TileFn Transform::tile_fns[max_pad_bottom][max_pad_right] =
+{
+  {
+    Transform::template process_tile<0, 0>,  // No padding
+    Transform::template process_tile<0, 1>,  // Right padding
+  },
+  {
+    Transform::template process_tile<1, 0>,  // Bottom padding
+    Transform::template process_tile<1, 1>,  // Bottom and right padding
+  }
+};
+
+template struct WinogradGEMM<2, 2, 3, 3>::OutputTransform<float>;
+}  // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp
new file mode 100644
index 0000000..483e5c1
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp
@@ -0,0 +1,299 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "transforms/output.hpp"
+#include "winograd_gemm.hpp"
+#include "arm.hpp"
+
+namespace winograd
+{
+
+using Transform = WinogradGEMM<4, 4, 3, 3>::OutputTransform<float>;
+
+template <>
+template <>
+int Transform::ops_performed(const Tensor4DShape &shape)
+{
+  // NOTE: Cost in FLOPs rather than instructions or uops.
+  const int tile_M = iceildiv(shape.n_rows, 4);
+  const int tile_N = iceildiv(shape.n_cols, 4);
+  return 170 * tile_M * tile_N * shape.n_channels;
+}
+
+// Instantiate cost methods
+template int Transform::ops_performed(const Tensor4DShape&);
+
+/* F(4x4, 3x3) constructs 4x4 output tiles from a 3x3 convolution. Since we use
+ * enough tiles to cover the output space each output tile may contain up to 3
+ * padded values to the right and bottom columns or rows of the tile, e.g.:
+*
+*      ________    ________   ________   ________
+*     |       |   |      X|  |    X X|  |  X X X|
+*     |       |   |      X|  |    X X|  |  X X X|
+*     |       |   |      X|  |    X X|  |  X X X|
+*     |_______|   |______X|  |____X_X|  |__X_X_X|
+*
+*      ________    ________   ________   ________
+*     |       |   |      X|  |    X X|  |  X X X|
+*     |       |   |      X|  |    X X|  |  X X X|
+*     |       |   |      X|  |    X X|  |  X X X|
+*     |X_X_X_X|   |X_X_X_X|  |X_X_X_X|  |X_X_X_X|
+*
+*      ________    ________   ________   ________
+*     |       |   |      X|  |    X X|  |  X X X|
+*     |       |   |      X|  |    X X|  |  X X X|
+*     |X X X X|   |X X X X|  |X X X X|  |X X X X|
+*     |X_X_X_X|   |X_X_X_X|  |X_X_X_X|  |X_X_X_X|
+*
+*      ________    ________   ________   ________
+*     |       |   |      X|  |    X X|  |  X X X|
+*     |X X X X|   |X X X X|  |X X X X|  |X X X X|
+*     |X X X X|   |X X X X|  |X X X X|  |X X X X|
+*     |X_X_X_X|   |X_X_X_X|  |X_X_X_X|  |X_X_X_X|
+*
+*
+* We provide a specialised output transform for each of these instances.
+*/
+template <>
+template <>
+template <int pad_bottom, int pad_right>
+void Transform::process_tile(
+  const int n_channels,
+  const float* const matrix_base,
+  const int matrix_stride,
+  float* const output,
+  const int output_row_stride,
+  const int output_col_stride
+)
+{
+  constexpr int cells_i = 4 - pad_bottom;
+  constexpr int cells_j = 4 - pad_right;
+
+  // Construct a map to the output cells
+  float *outptrs[cells_i][cells_j];
+  for (int i = 0; i < cells_i; i++)
+  {
+    for (int j = 0; j < cells_j; j++)
+    {
+      outptrs[i][j] = output + i*output_row_stride + j*output_col_stride;
+    }
+  }
+  const float *inptr = matrix_base;
+
+  // For each channel of the output
+  int channels_remaining = n_channels;
+#ifdef __aarch64__
+  for (; channels_remaining >= 4; channels_remaining -= 4)
+  {
+    // Matrices used and computed during this transform
+    float32x4_t F[6][6], FZ[6][4], f[4][4];
+
+    // Read a 6x6 tile in the Winograd domain
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        F[i][j] = vld1q_f32(inptr + m*matrix_stride);
+      }
+    }
+    inptr += 4;
+
+    // Compute the matrix F Z
+    for (int i = 0; i < 6; i++)
+    {
+      // FZ[i][0] =  1*F[i][0] +  1*F[i][1] +  1*F[i][2] +  1*F[i][3] +  1*F[i][4];
+      FZ[i][0] = vaddq_f32(vaddq_f32(vaddq_f32(F[i][0], F[i][1]), vaddq_f32(F[i][2], F[i][3])), F[i][4]);
+
+      // FZ[i][1] =  1*F[i][1] + -1*F[i][2] +  2*F[i][3] + -2*F[i][4];
+      FZ[i][1] = vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 2.0f);
+
+      // FZ[i][2] =  1*F[i][1] +  1*F[i][2] +  4*F[i][3] +  4*F[i][4];
+      FZ[i][2] = vmlaq_n_f32(vaddq_f32(F[i][1], F[i][2]), vaddq_f32(F[i][3], F[i][4]), 4.0f);
+
+      // FZ[i][3] =  1*F[i][1] + -1*F[i][2] +  8*F[i][3] + -8*F[i][4] +  1*F[i][5];
+      FZ[i][3] = vaddq_f32(vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 8.0f), F[i][5]);
+    }
+
+    // Compute the output tile f = ZT F Z
+    for (int j = 0; j < 4; j++)
+    {
+      // f[0][j] =  1*FZ[0][j] +  1*FZ[1][j] +  1*FZ[2][j] +  1*FZ[3][j] +  1*FZ[4][j];
+      f[0][j] = vaddq_f32(vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), vaddq_f32(FZ[2][j], FZ[3][j])), FZ[4][j]);
+
+      // f[1][j] =  1*FZ[1][j] + -1*FZ[2][j] +  2*FZ[3][j] + -2*FZ[4][j];
+      f[1][j] = vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 2.0f);
+
+      // f[2][j] =  1*FZ[1][j] +  1*FZ[2][j] +  4*FZ[3][j] +  4*FZ[4][j];
+      f[2][j] = vmlaq_n_f32(vaddq_f32(FZ[1][j], FZ[2][j]), vaddq_f32(FZ[3][j], FZ[4][j]), 4.0f);
+
+      // f[3][j] =  1*FZ[1][j] + -1*FZ[2][j] +  8*FZ[3][j] + -8*FZ[4][j] +  1*FZ[5][j];
+      f[3][j] = vaddq_f32(vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 8.0f), FZ[5][j]);
+    }
+
+    // Write out the output tile
+    for (int i = 0; i < cells_i; i++)
+    {
+      for (int j = 0; j < cells_j; j++)
+      {
+        vst1q_f32(outptrs[i][j], f[i][j]);
+        outptrs[i][j] += 4;
+      }
+    }
+  }
+#endif  // __aarch64__
+#ifdef __arm_any__
+  for (; channels_remaining >= 2; channels_remaining -= 2)
+  {
+    // Matrices used and computed during this transform
+    float32x2_t F[6][6], FZ[6][4], f[4][4];
+
+    // Read a 6x6 tile in the Winograd domain
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        F[i][j] = vld1_f32(inptr + m*matrix_stride);
+      }
+    }
+    inptr += 2;
+
+    // Compute the matrix F Z
+    for (int i = 0; i < 6; i++)
+    {
+      // FZ[i][0] =  1*F[i][0] +  1*F[i][1] +  1*F[i][2] +  1*F[i][3] +  1*F[i][4];
+      FZ[i][0] = vadd_f32(vadd_f32(vadd_f32(F[i][0], F[i][1]), vadd_f32(F[i][2], F[i][3])), F[i][4]);
+
+      // FZ[i][1] =  1*F[i][1] + -1*F[i][2] +  2*F[i][3] + -2*F[i][4];
+      FZ[i][1] = vmla_n_f32(vsub_f32(F[i][1], F[i][2]), vsub_f32(F[i][3], F[i][4]), 2.0f);
+
+      // FZ[i][2] =  1*F[i][1] +  1*F[i][2] +  4*F[i][3] +  4*F[i][4];
+      FZ[i][2] = vmla_n_f32(vadd_f32(F[i][1], F[i][2]), vadd_f32(F[i][3], F[i][4]), 4.0f);
+
+      // FZ[i][3] =  1*F[i][1] + -1*F[i][2] +  8*F[i][3] + -8*F[i][4] +  1*F[i][5];
+      FZ[i][3] = vadd_f32(vmla_n_f32(vsub_f32(F[i][1], F[i][2]), vsub_f32(F[i][3], F[i][4]), 8.0f), F[i][5]);
+    }
+
+    // Compute the output tile f = ZT F Z
+    for (int j = 0; j < 4; j++)
+    {
+      // f[0][j] =  1*FZ[0][j] +  1*FZ[1][j] +  1*FZ[2][j] +  1*FZ[3][j] +  1*FZ[4][j];
+      f[0][j] = vadd_f32(vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), vadd_f32(FZ[2][j], FZ[3][j])), FZ[4][j]);
+
+      // f[1][j] =  1*FZ[1][j] + -1*FZ[2][j] +  2*FZ[3][j] + -2*FZ[4][j];
+      f[1][j] = vmla_n_f32(vsub_f32(FZ[1][j], FZ[2][j]), vsub_f32(FZ[3][j], FZ[4][j]), 2.0f);
+
+      // f[2][j] =  1*FZ[1][j] +  1*FZ[2][j] +  4*FZ[3][j] +  4*FZ[4][j];
+      f[2][j] = vmla_n_f32(vadd_f32(FZ[1][j], FZ[2][j]), vadd_f32(FZ[3][j], FZ[4][j]), 4.0f);
+
+      // f[3][j] =  1*FZ[1][j] + -1*FZ[2][j] +  8*FZ[3][j] + -8*FZ[4][j] +  1*FZ[5][j];
+      f[3][j] = vadd_f32(vmla_n_f32(vsub_f32(FZ[1][j], FZ[2][j]), vsub_f32(FZ[3][j], FZ[4][j]), 8.0f), FZ[5][j]);
+    }
+
+    // Write out the output tile
+    for (int i = 0; i < cells_i; i++)
+    {
+      for (int j = 0; j < cells_j; j++)
+      {
+        vst1_f32(outptrs[i][j], f[i][j]);
+        outptrs[i][j] += 2;
+      }
+    }
+  }
+#endif
+  for (; channels_remaining; channels_remaining--)
+  {
+    // Matrices used and computed during this transform
+    float F[6][6], FZ[6][4], f[4][4];
+
+    // Read a 6x6 tile in the Winograd domain
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        F[i][j] = *(inptr + m*matrix_stride);
+      }
+    }
+    inptr++;
+
+    // Compute the matrix F Z
+    for (int i = 0; i < 6; i++)
+    {
+      FZ[i][0] =  1*F[i][0] +  1*F[i][1] +  1*F[i][2] +  1*F[i][3] +  1*F[i][4];
+      FZ[i][1] =  1*F[i][1] + -1*F[i][2] +  2*F[i][3] + -2*F[i][4];
+      FZ[i][2] =  1*F[i][1] +  1*F[i][2] +  4*F[i][3] +  4*F[i][4];
+      FZ[i][3] =  1*F[i][1] + -1*F[i][2] +  8*F[i][3] + -8*F[i][4] +  1*F[i][5];
+    }
+
+    // Compute the output tile f = ZT F Z
+    for (int j = 0; j < 4; j++)
+    {
+      f[0][j] =  1*FZ[0][j] +  1*FZ[1][j] +  1*FZ[2][j] +  1*FZ[3][j] +  1*FZ[4][j];
+      f[1][j] =  1*FZ[1][j] + -1*FZ[2][j] +  2*FZ[3][j] + -2*FZ[4][j];
+      f[2][j] =  1*FZ[1][j] +  1*FZ[2][j] +  4*FZ[3][j] +  4*FZ[4][j];
+      f[3][j] =  1*FZ[1][j] + -1*FZ[2][j] +  8*FZ[3][j] + -8*FZ[4][j] +  1*FZ[5][j];
+    }
+
+    // Write out the output tile
+    for (int i = 0; i < cells_i; i++)
+    {
+      for (int j = 0; j < cells_j; j++)
+      {
+        *(outptrs[i][j]++) = f[i][j];
+      }
+    }
+  }
+}
+
+template <>
+template <>
+const Transform::TileFn Transform::tile_fns[max_pad_bottom][max_pad_right] =
+{
+  {
+    Transform::template process_tile<0, 0>,
+    Transform::template process_tile<0, 1>,
+    Transform::template process_tile<0, 2>,
+    Transform::template process_tile<0, 3>,
+  },
+  {
+    Transform::template process_tile<1, 0>,
+    Transform::template process_tile<1, 1>,
+    Transform::template process_tile<1, 2>,
+    Transform::template process_tile<1, 3>,
+  },
+  {
+    Transform::template process_tile<2, 0>,
+    Transform::template process_tile<2, 1>,
+    Transform::template process_tile<2, 2>,
+    Transform::template process_tile<2, 3>,
+  },
+  {
+    Transform::template process_tile<3, 0>,
+    Transform::template process_tile<3, 1>,
+    Transform::template process_tile<3, 2>,
+    Transform::template process_tile<3, 3>,
+  }
+};
+
+template struct WinogradGEMM<4, 4, 3, 3>::OutputTransform<float>;
+}  // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp
new file mode 100644
index 0000000..c0b2824
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/transforms/weights_2x2_3x3_fp32.cpp
@@ -0,0 +1,228 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm.hpp"
+#include "winograd_gemm.hpp"
+#include "transforms/kernel.hpp"
+
+namespace winograd
+{
+  template <>
+  template <>
+  void WinogradGEMM<2, 2, 3, 3>::WeightsTransform<float>::execute(
+    const int n_output_channels,
+    const int n_input_channels,
+    const float* const input,
+    float* const output,
+    const int matrix_stride,
+    const int matrix_row_stride
+  )
+  {
+    constexpr int inner_tile_i = 4;
+    constexpr int inner_tile_j = 4;
+
+    // Get pointers to each cell of the weight tensor
+    const auto weight_col_stride = n_input_channels * n_output_channels;
+    const auto weight_row_stride = 3 * weight_col_stride;
+    const float *inptrs[3][3];
+    for (int i = 0; i < 3; i++)
+    {
+      for (int j = 0; j < 3; j++)
+      {
+        inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride;
+      }
+    }
+
+    // For each input channel
+    for (int ic = 0; ic < n_input_channels; ic++)
+    {
+      float *outptr = output + ic * matrix_row_stride;
+
+      // For each output channel
+      int channels_remaining = n_output_channels;
+#ifdef __aarch64__
+      for (; channels_remaining >= 4; channels_remaining -= 4)
+      {
+        // Matrices used and computed in this kernel
+        float32x4_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j];
+
+        // Read weights
+        for (int i = 0; i < 3; i++)
+        {
+          for (int j = 0; j < 3; j++)
+          {
+            w[i][j] = vld1q_f32(inptrs[i][j]);
+            inptrs[i][j] += 4;
+          }
+        }
+
+        // Compute the matrix W w
+        for (int j = 0; j < 3; j++)
+        {
+          Ww[0][j] = w[0][j];
+
+          // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]);
+          Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
+
+          // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]);
+          Ww[2][j] = vmulq_n_f32(vaddq_f32(vsubq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
+
+          Ww[3][j] = w[2][j];
+        }
+
+        // Compute V = W w WT
+        for (int i = 0; i < inner_tile_i; i++)
+        {
+          V[i][0] = Ww[i][0];
+
+          // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]);
+          V[i][1] = vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
+
+          // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]);
+          V[i][2] = vmulq_n_f32(vaddq_f32(vsubq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
+
+          V[i][3] = Ww[i][2];
+        }
+
+        // Store the transformed weights
+        for (int i = 0, m = 0; i < inner_tile_i; i++)
+        {
+          for (int j = 0; j < inner_tile_j; j++, m++)
+          {
+            vst1q_f32(outptr + m*matrix_stride, V[i][j]);
+          }
+        }
+        outptr += 4;
+      }
+#endif  // __aarch64__
+#ifdef __arm_any__
+      for (; channels_remaining >= 2; channels_remaining -= 2)
+      {
+        // Matrices used and computed in this kernel
+        float32x2_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j];
+
+        // Read weights
+        for (int i = 0; i < 3; i++)
+        {
+          for (int j = 0; j < 3; j++)
+          {
+            w[i][j] = vld1_f32(inptrs[i][j]);
+            inptrs[i][j] += 2;
+          }
+        }
+
+        // Compute the matrix W w
+        for (int j = 0; j < 3; j++)
+        {
+          Ww[0][j] = w[0][j];
+
+          // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]);
+          Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
+
+          // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]);
+          Ww[2][j] = vmul_n_f32(vadd_f32(vsub_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
+
+          Ww[3][j] = w[2][j];
+        }
+
+        // Compute V = W w WT
+        for (int i = 0; i < inner_tile_i; i++)
+        {
+          V[i][0] = Ww[i][0];
+
+          // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]);
+          V[i][1] = vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
+
+          // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]);
+          V[i][2] = vmul_n_f32(vadd_f32(vsub_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
+
+          V[i][3] = Ww[i][2];
+        }
+
+        // Store the transformed weights
+        for (int i = 0, m = 0; i < inner_tile_i; i++)
+        {
+          for (int j = 0; j < inner_tile_j; j++, m++)
+          {
+            vst1_f32(outptr + m*matrix_stride, V[i][j]);
+          }
+        }
+        outptr += 2;
+      }
+#endif  // __arm_any__
+      for (; channels_remaining; channels_remaining--)
+      {
+        // Matrices used and computed in this kernel
+        float w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j];
+
+        // Read weights
+        for (int i = 0; i < 3; i++)
+        {
+          for (int j = 0; j < 3; j++)
+          {
+            w[i][j] = *(inptrs[i][j]++);
+          }
+        }
+
+        // Compute the matrix W w
+        for (int j = 0; j < 3; j++)
+        {
+          Ww[0][j] = w[0][j];
+          Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]);
+          Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]);
+          Ww[3][j] = w[2][j];
+        }
+
+        // Compute V = W w WT
+        for (int i = 0; i < inner_tile_i; i++)
+        {
+          V[i][0] = Ww[i][0];
+          V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]);
+          V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]);
+          V[i][3] = Ww[i][2];
+        }
+
+        // Store the transformed weights
+        for (int i = 0, m = 0; i < inner_tile_i; i++)
+        {
+          for (int j = 0; j < inner_tile_j; j++, m++)
+          {
+            *(outptr + m*matrix_stride) = V[i][j];
+          }
+        }
+        outptr++;
+      }
+    }
+  }
+
+  template <>
+  template <>
+  int WinogradGEMM<2, 2, 3, 3>::WeightsTransform<float>::ops_performed(const KernelShape &shape)
+  {
+    const int channel_prod = shape.n_input_channels * shape.n_output_channels;
+    return 2 * 18 * channel_prod;
+  }
+
+  template struct WinogradGEMM<2, 2, 3, 3>::WeightsTransform<float>;
+}  // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp
new file mode 100644
index 0000000..de659c3
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/transforms/weights_4x4_3x3_fp32.cpp
@@ -0,0 +1,266 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm.hpp"
+#include "winograd_gemm.hpp"
+#include "transforms/kernel.hpp"
+
+namespace winograd
+{
+  /* Float implementation for kernel transform F(4x4, 3x3) */
+  template <>
+  template <>
+  void WinogradGEMM<4, 4, 3, 3>::WeightsTransform<float>::execute(
+    const int n_output_channels,
+    const int n_input_channels,
+    const float* const input,  // NOTE: Data in HWIO order
+    float* const output,
+    const int matrix_stride,
+    const int matrix_row_stride
+  )
+  {
+    // Get pointers to each cell of the weight tensor
+    const auto weight_col_stride = n_input_channels * n_output_channels;
+    const auto weight_row_stride = 3 * weight_col_stride;
+    const float *inptrs[3][3];
+    for (int i = 0; i < 3; i++)
+    {
+      for (int j = 0; j < 3; j++)
+      {
+        inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride;
+      }
+    }
+
+    // For each input channel
+    for (int ic = 0; ic < n_input_channels; ic++)
+    {
+      float *outptr = output + ic * matrix_row_stride;
+
+      // For each output channel
+      int channels_remaining = n_output_channels;
+#ifdef __aarch64__
+      for (; channels_remaining >= 4; channels_remaining -= 4)
+      {
+        // Matrices used and computed in this kernel
+        float32x4_t w[3][3], Ww[6][3], V[6][6];
+
+        // Read weights
+        for (int i = 0; i < 3; i++)
+        {
+          for (int j = 0; j < 3; j++)
+          {
+            w[i][j] = vld1q_f32(inptrs[i][j]);
+            inptrs[i][j] += 4;
+          }
+        }
+
+        // Compute the matrix W w
+        for (int j = 0; j < 3; j++)
+        {
+          // Ww[0][j] =  6*w[0][j];
+          Ww[0][j] = vmulq_n_f32(w[0][j], 6.0);
+
+          // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
+          Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), -4.0);
+
+          // Ww[2][j] = -4*w[0][j] +  4*w[1][j] + -4*w[2][j];
+          Ww[2][j] = vmulq_n_f32(vsubq_f32(vsubq_f32(w[1][j], w[0][j]), w[2][j]), 4.0);
+
+          // Ww[3][j] =  1*w[0][j] +  2*w[1][j] +  4*w[2][j];
+          Ww[3][j] = vmlaq_n_f32(vmlaq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);
+
+          // Ww[4][j] =  1*w[0][j] + -2*w[1][j] +  4*w[2][j];
+          Ww[4][j] = vmlaq_n_f32(vmlsq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);
+
+          // Ww[5][j] = 24*w[2][j];
+          Ww[5][j] = vmulq_n_f32(w[2][j], 24.0f);
+        }
+
+        // Compute V = W w WT
+        for (int i = 0; i < 6; i++)
+        {
+          const float recip576 = 1.0f / 576.0f;
+
+          // V[i][0] =  6*Ww[i][0];
+          V[i][0] = vmulq_n_f32(vmulq_n_f32(Ww[i][0], 6.0), recip576);
+
+          // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
+          V[i][1] = vmulq_n_f32(vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);
+
+          // V[i][2] = -4*Ww[i][0] +  4*Ww[i][1] + -4*Ww[i][2];
+          V[i][2] = vmulq_n_f32(vmulq_n_f32(vsubq_f32(vsubq_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);
+
+          // V[i][3] =  1*Ww[i][0] +  2*Ww[i][1] +  4*Ww[i][2];
+          V[i][3] = vmulq_n_f32(vmlaq_n_f32(vmlaq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);
+
+          // V[i][4] =  1*Ww[i][0] + -2*Ww[i][1] +  4*Ww[i][2];
+          V[i][4] = vmulq_n_f32(vmlaq_n_f32(vmlsq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);
+
+          // V[i][5] = 24*Ww[i][2];
+          V[i][5] = vmulq_n_f32(vmulq_n_f32(Ww[i][2], 24.0f), recip576);
+        }
+
+        // Store the transformed weights
+        for (int i = 0, m = 0; i < 6; i++)
+        {
+          for (int j = 0; j < 6; j++, m++)
+          {
+            vst1q_f32(outptr + m*matrix_stride, V[i][j]);
+          }
+        }
+        outptr += 4;
+      }
+#endif  // __aarch64__
+#ifdef __arm_any__
+      for (; channels_remaining >= 2; channels_remaining -= 2)
+      {
+        // Matrices used and computed in this kernel
+        float32x2_t w[3][3], Ww[6][3], V[6][6];
+
+        // Read weights
+        for (int i = 0; i < 3; i++)
+        {
+          for (int j = 0; j < 3; j++)
+          {
+            w[i][j] = vld1_f32(inptrs[i][j]);
+            inptrs[i][j] += 2;
+          }
+        }
+
+        // Compute the matrix W w
+        for (int j = 0; j < 3; j++)
+        {
+          // Ww[0][j] =  6*w[0][j];
+          Ww[0][j] = vmul_n_f32(w[0][j], 6.0);
+
+          // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
+          Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), -4.0);
+
+          // Ww[2][j] = -4*w[0][j] +  4*w[1][j] + -4*w[2][j];
+          Ww[2][j] = vmul_n_f32(vsub_f32(vsub_f32(w[1][j], w[0][j]), w[2][j]), 4.0);
+
+          // Ww[3][j] =  1*w[0][j] +  2*w[1][j] +  4*w[2][j];
+          Ww[3][j] = vmla_n_f32(vmla_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);
+
+          // Ww[4][j] =  1*w[0][j] + -2*w[1][j] +  4*w[2][j];
+          Ww[4][j] = vmla_n_f32(vmls_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);
+
+          // Ww[5][j] = 24*w[2][j];
+          Ww[5][j] = vmul_n_f32(w[2][j], 24.0f);
+        }
+
+        // Compute V = W w WT
+        for (int i = 0; i < 6; i++)
+        {
+          const float recip576 = 1.0f / 576.0f;
+
+          // V[i][0] =  6*Ww[i][0];
+          V[i][0] = vmul_n_f32(vmul_n_f32(Ww[i][0], 6.0), recip576);
+
+          // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
+          V[i][1] = vmul_n_f32(vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);
+
+          // V[i][2] = -4*Ww[i][0] +  4*Ww[i][1] + -4*Ww[i][2];
+          V[i][2] = vmul_n_f32(vmul_n_f32(vsub_f32(vsub_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);
+
+          // V[i][3] =  1*Ww[i][0] +  2*Ww[i][1] +  4*Ww[i][2];
+          V[i][3] = vmul_n_f32(vmla_n_f32(vmla_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);
+
+          // V[i][4] =  1*Ww[i][0] + -2*Ww[i][1] +  4*Ww[i][2];
+          V[i][4] = vmul_n_f32(vmla_n_f32(vmls_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);
+
+          // V[i][5] = 24*Ww[i][2];
+          V[i][5] = vmul_n_f32(vmul_n_f32(Ww[i][2], 24.0f), recip576);
+        }
+
+        // Store the transformed weights
+        for (int i = 0, m = 0; i < 6; i++)
+        {
+          for (int j = 0; j < 6; j++, m++)
+          {
+            vst1_f32(outptr + m*matrix_stride, V[i][j]);
+          }
+        }
+        outptr += 2;
+      }
+#endif  // __arm_any__
+      for (; channels_remaining; channels_remaining--)
+      {
+        // Matrices used and computed in this kernel
+        float w[3][3], Ww[6][3], V[6][6];
+
+        // Read weights
+        for (int i = 0; i < 3; i++)
+        {
+          for (int j = 0; j < 3; j++)
+          {
+            w[i][j] = *(inptrs[i][j]++);
+          }
+        }
+
+        // Compute the matrix W w
+        for (int j = 0; j < 3; j++)
+        {
+          Ww[0][j] =  6*w[0][j];
+          Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
+          Ww[2][j] = -4*w[0][j] +  4*w[1][j] + -4*w[2][j];
+          Ww[3][j] =  1*w[0][j] +  2*w[1][j] +  4*w[2][j];
+          Ww[4][j] =  1*w[0][j] + -2*w[1][j] +  4*w[2][j];
+          Ww[5][j] = 24*w[2][j];
+        }
+
+        // Compute V = W w WT
+        for (int i = 0; i < 6; i++)
+        {
+          V[i][0] = ( 6*Ww[i][0]) / 576.0;
+          V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
+          V[i][2] = (-4*Ww[i][0] +  4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
+          V[i][3] = ( 1*Ww[i][0] +  2*Ww[i][1] +  4*Ww[i][2]) / 576.0;
+          V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] +  4*Ww[i][2]) / 576.0;
+          V[i][5] = (24*Ww[i][2]) / 576.0;
+        }
+
+        // Store the transformed weights
+        for (int i = 0, m = 0; i < 6; i++)
+        {
+          for (int j = 0; j < 6; j++, m++)
+          {
+            *(outptr + m*matrix_stride) = V[i][j];
+          }
+        }
+        outptr++;
+      }
+    }
+  }
+
+  template <>
+  template <>
+  int WinogradGEMM<4, 4, 3, 3>::WeightsTransform<float>::ops_performed(const KernelShape &shape)
+  {
+    const int channel_prod = shape.n_input_channels * shape.n_output_channels;
+    return 9 * 16 * channel_prod;
+  }
+
+  template struct WinogradGEMM<4, 4, 3, 3>::WeightsTransform<float>;
+}
diff --git a/src/core/NEON/kernels/winograd/utils.hpp b/src/core/NEON/kernels/winograd/utils.cpp
similarity index 76%
rename from src/core/NEON/kernels/winograd/utils.hpp
rename to src/core/NEON/kernels/winograd/utils.cpp
index 14e709f..24d0386 100644
--- a/src/core/NEON/kernels/winograd/utils.hpp
+++ b/src/core/NEON/kernels/winograd/utils.cpp
@@ -1,4 +1,3 @@
-
 /*
  * Copyright (c) 2017 ARM Limited.
  *
@@ -22,31 +21,27 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#pragma once
+
+#include <cstdio>
 #include <ctime>
 
-inline double TimeInUs(void) {
+double TimeInUs(void)
+{
 #ifdef CYCLE_PROFILING
   timespec t;
-  clock_gettime(CLOCK_THREAD_CPUTIME_ID, &t);
+  clock_gettime(CLOCK_REALTIME, &t);
   return 1e6*t.tv_sec + 1e-3*t.tv_nsec;
 #else
   return 0;
 #endif
 }
 
-inline int iceildiv(const int a, const int b) {
-  return (a + b - 1) / b;
-}
-
-template <typename T>
-inline T roundup(const T a, const T b) {
-  return a + b - (a % b);
-}
-
-inline void PrintMatrix(const float* const m, const int M, const int N, const int row_stride) {
-  for (int i = 0; i < M; i++) {
-    for (int j = 0; j < N; j++) {
+void PrintMatrix(const float* const m, const int M, const int N, const int row_stride)
+{
+  for (int i = 0; i < M; i++)
+  {
+    for (int j = 0; j < N; j++)
+    {
       printf("%.3f ", m[i*row_stride + j]);
     }
     printf("\n");
diff --git a/src/core/NEON/kernels/winograd/winograd_gemm.cpp b/src/core/NEON/kernels/winograd/winograd_gemm.cpp
new file mode 100644
index 0000000..b44a453
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/winograd_gemm.cpp
@@ -0,0 +1,560 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "winograd_gemm.hpp"
+#include "batched_blocked_gemm.hpp"
+using namespace winograd;
+
+/** Get the output shape of a convolution. */
+template <int kr, int kc, int itr, int itc>
+template <typename TOut, typename TIn>
+Tensor4DShape WinogradGEMM<kr, kc, itr, itc>::Convolution<TOut, TIn>::get_output_shape(
+  const KernelShape &kernel_shape,
+  const Tensor4DShape &in_shape,
+  const PaddingType padding
+)
+{
+  // TODO Accept different kernel sizes
+  return Tensor4DShape {
+    in_shape.n_batches,
+    (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - 2,
+    (padding == PADDING_SAME) ? in_shape.n_cols : in_shape.n_cols - 2,
+    kernel_shape.n_output_channels,
+    in_shape.ordering
+  };
+}
+
+/* Get the memory required to transform the kernel.
+ */
+template <int kernel_rows, int kernel_cols,
+          int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_kernel_transform_working_size(const KernelShape &shape)
+{
+  if (shape.ordering == HWIO)
+  {
+    // Kernel is already in the correct order, so no additional memory is
+    // required.
+    return 0;
+  }
+  else
+  {
+    // Need to re-order the kernel into HWIO form, require enough space to
+    // represent the tensor.
+    return sizeof(TIn) * shape.size();
+  }
+}
+
+/** Get the memory required to store the kernel transformed into the
+ * Winograd domain.
+ */
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_kernel_storage_size(const KernelShape &shape)
+{
+  return N_GEMMS * get_kernel_matrix_size(shape);
+}
+
+
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_input_storage_size(
+  const KernelShape &kernel_shape,
+  const Tensor4DShape &input_shape,
+  const PaddingType padding
+)
+{
+  return N_GEMMS * get_input_matrix_size(kernel_shape, input_shape, padding);
+}
+
+
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_output_storage_size(
+  const KernelShape &kernel_shape,
+  const Tensor4DShape &input_shape,
+  const PaddingType padding
+)
+{
+  return N_GEMMS * get_output_matrix_size(kernel_shape, input_shape, padding);
+}
+
+
+/** Get the memory required to apply a Winograd operator to some input.
+ */
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_working_space_size(
+  const KernelShape &kernel_shape,
+  const Tensor4DShape &input_shape,
+  const PaddingType padding_type
+)
+{
+  const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type);
+
+  // Get the memory required to store the matrices
+  const size_t matrix_sizes = N_GEMMS * (
+    get_input_matrix_size(kernel_shape, input_shape, padding_type) +
+    get_output_matrix_size(kernel_shape, input_shape, padding_type)
+  );
+
+  // Add additional space to re-order the input and output if the input tensor
+  // is not in NHWC format.
+  if (input_shape.ordering == NHWC)
+  {
+    return matrix_sizes;  // No extra spacing required
+  }
+  else  // NCHW, must reorder the input and output tensors
+  {
+    // We only need to re-order the input or output at any one time, so request
+    // enough memory to do the largest of these.
+    const size_t extra_memory = std::max(
+      sizeof(TIn) * input_shape.size(),
+      sizeof(TOut) * output_shape.size()
+    );
+    return matrix_sizes + extra_memory;
+  }
+}
+
+
+/* Get the memory required by a single "input" matrix.
+ */
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_input_matrix_size(
+  const KernelShape &kernel_shape,
+  const Tensor4DShape &input_shape,
+  const PaddingType padding_type
+)
+{
+  return get_input_matrix_stride(kernel_shape, input_shape, padding_type) * sizeof(TIn);
+}
+
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+int WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_input_matrix_stride(
+  const KernelShape &kernel_shape,
+  const Tensor4DShape &input_shape,
+  const PaddingType padding_type
+)
+{
+  // Compute shape for the GEMM
+  const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type);
+  const int tile_rows = iceildiv(output_shape.n_rows, output_tile_rows);
+  const int tile_cols = iceildiv(output_shape.n_cols, output_tile_cols);
+  const int M = roundup(input_shape.n_batches * tile_rows * tile_cols, M_BLOCK);
+  const int K = kernel_shape.n_input_channels;
+
+  return M * K;
+}
+
+
+/* Get the memory required by a single "output" matrix.
+ */
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_output_matrix_size(
+    const KernelShape &kernel_shape,
+    const Tensor4DShape &input_shape,
+    const PaddingType padding_type
+)
+{
+  return get_output_matrix_stride(kernel_shape, input_shape, padding_type) * sizeof(TOut);
+}
+
+
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+int WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_output_matrix_stride(
+    const KernelShape &kernel_shape,
+    const Tensor4DShape &input_shape,
+    const PaddingType padding_type
+)
+{
+  // Compute shape for the GEMM
+  const auto output_shape = get_output_shape(kernel_shape, input_shape, padding_type);
+  const int tile_rows = iceildiv(output_shape.n_rows, output_tile_rows);
+  const int tile_cols = iceildiv(output_shape.n_cols, output_tile_cols);
+  const int M = roundup(tile_rows * tile_cols, M_BLOCK);
+  const int N = roundup(kernel_shape.n_output_channels, N_BLOCK);
+
+  return input_shape.n_batches * M * N;
+}
+
+
+/* Get the memory required by a single "kernel" matrix.
+ */
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+size_t WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_kernel_matrix_size(const KernelShape &shape)
+{
+  return sizeof(TIn) * get_kernel_matrix_stride(shape);
+}
+
+template <int kernel_rows, int kernel_cols, int output_tile_rows, int output_tile_cols>
+template <typename TOut, typename TIn>
+int WinogradGEMM<kernel_rows, kernel_cols, output_tile_rows, output_tile_cols>::Convolution<TOut, TIn>::get_kernel_matrix_stride(const KernelShape &shape)
+{
+  const int K = shape.n_input_channels;
+  const int N = roundup(shape.n_output_channels, N_BLOCK);
+  return K * N;
+}
+
+
+/** Create a new Winograd operator. */
+template <int output_tile_rows, int output_tile_cols,
+          int kernel_rows, int kernel_cols>
+template <typename TOut, typename TIn>
+WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::Convolution<TOut, TIn>::Convolution(
+  const KernelShape &kernel_shape,
+  const Tensor4DShape &input_shape,
+  const PaddingType padding,
+  void *kernel_storage
+) : kernel_shape(kernel_shape),  // Store the kernel shape
+    kernel_matrix_row_stride(roundup(kernel_shape.n_output_channels, N_BLOCK)),
+    manage_kernel_storage(kernel_storage == NULL),
+    _kernel_storage(manage_kernel_storage ?
+                      ALLOCATE(get_kernel_storage_size(kernel_shape)) :
+                      kernel_storage),
+    input_shape(input_shape),
+    padding(padding),
+    output_shape(get_output_shape(kernel_shape, input_shape, padding)),
+    tile_rows(iceildiv(output_shape.n_rows, output_tile_rows)),
+    tile_cols(iceildiv(output_shape.n_cols, output_tile_cols)),
+    M(input_shape.n_batches * tile_rows * tile_cols),
+    K(kernel_shape.n_input_channels),
+    N(kernel_shape.n_output_channels),
+    prof()
+{
+  // Create pointers to the kernel matrices
+  const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape);
+  int8_t* const ks_bytes = reinterpret_cast<int8_t *>(_kernel_storage);
+  for (int i = 0; i < N_GEMMS; i++) {
+    kernel_matrices[i] = reinterpret_cast<TIn *>(
+      ks_bytes + i*kernel_matrix_size_bytes);
+  }
+}
+
+
+/** Create a new Winograd operator and initialise the weights. */
+template <int output_tile_rows, int output_tile_cols,
+          int kernel_rows, int kernel_cols>
+template <typename TOut, typename TIn>
+WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::Convolution<TOut, TIn>::Convolution(
+  const KernelShape &kernel_shape,
+  const Tensor4DShape &input_shape,
+  const PaddingType padding,
+  const TIn* const kernel,
+  void *kernel_storage,
+  void *transform_working_space
+) : Convolution(kernel_shape, input_shape, padding, kernel_storage)
+{
+  transform_weights(kernel, transform_working_space);
+}
+
+
+/** Clean up a convolution engine. */
+template <int output_tile_rows, int output_tile_cols, int kernel_rows, int kernel_cols>
+template <typename TOut, typename TIn>
+WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::
+Convolution<TOut, TIn>::~Convolution()
+{
+  // If we were responsible for managing kernel storage ensure that it is
+  // freed.
+  if (manage_kernel_storage)
+  {
+    free(_kernel_storage);
+  }
+}
+
+
+/** Transform weights into the Winograd domain and store them for later use/reuse. */
+template <int output_tile_rows, int output_tile_cols, int kernel_rows, int kernel_cols>
+template <typename TOut, typename TIn>
+template <typename WeightsTransformT>
+void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::
+Convolution<TOut, TIn>::transform_weights(
+  const TIn* const kernel,
+  void *transform_working_space
+)
+{
+  // Allocate working space if it is required
+  bool allocated_working_space = false;
+  if (transform_working_space == NULL &&  // If no memory has been provided
+      get_kernel_transform_working_size(kernel_shape) != 0)  // And we need the space
+  {
+    allocated_working_space = true;
+    transform_working_space = ALLOCATE(
+      get_kernel_transform_working_size(kernel_shape)
+    );
+  }
+
+  // The transformation methods only work on weights laid out in HWIO form, if
+  // the weights are not in this form then we need to re-order them.
+  const TIn *kernel_hwio = kernel;
+  if (kernel_shape.ordering != HWIO)
+  {
+    kernel_hwio = reinterpret_cast<TIn *>(transform_working_space);
+
+    // Re-order the weights from OIHW to HWIO
+    this->prof(
+      "Weight reorder",
+      [&kernel, &kernel_hwio, this] () {
+        reorder::ofm_ifm_h_w_to_h_w_ifm_ofm(
+          kernel, const_cast<TIn *>(kernel_hwio),
+          kernel_shape.n_output_channels,
+          kernel_shape.n_input_channels,
+          kernel_shape.n_rows,
+          kernel_shape.n_cols
+        );
+      },
+      kernel_shape.size() * sizeof(TIn),
+      0,
+      kernel_shape.size() * sizeof(TIn)
+    );
+  }
+
+  const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape);
+  WeightsTransformT weights_transform(
+    kernel_hwio, kernel_matrices[0],
+    kernel_matrix_size_bytes / sizeof(TIn),
+    kernel_matrix_row_stride,
+    kernel_shape.n_output_channels,
+    kernel_shape.n_input_channels
+  );
+
+  // Transform the weights into the Winograd domain
+  auto kernel_prep = [&] ()
+  {
+    weights_transform.run(0, weights_transform.get_window());
+  };
+
+  prof(
+    "Kernel Prep", kernel_prep,
+    WeightsTransformT::bytes_read(kernel_shape),
+    WeightsTransformT::ops_performed(kernel_shape),
+    WeightsTransformT::bytes_written(kernel_shape)
+  );
+
+  // Free memory if we allocated it
+  if (allocated_working_space)
+  {
+    free(transform_working_space);
+  }
+}
+
+
+/** Perform a convolution. */
+template <int output_tile_rows, int output_tile_cols,
+          int kernel_rows, int kernel_cols>
+template <typename TOut, typename TIn>
+void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::
+Convolution<TOut, TIn>::execute(
+  TOut* const output,
+  const TIn* const input,
+  void *working_space,
+  const int n_threads
+)
+{
+  const auto padding_type = padding;
+  const auto input_shape = this->input_shape;
+
+  // Allocate working space if none has been provided
+  const bool manage_working_space = (working_space == NULL);
+  if (manage_working_space)
+  {
+    const size_t ws_size = get_working_space_size(
+      kernel_shape, input_shape, padding_type
+    );
+    working_space = ALLOCATE(ws_size * sizeof(int8_t));
+    memset(working_space, 0x00, ws_size);
+  }
+  int8_t* const ws_bytes = reinterpret_cast<int8_t *>(working_space);
+
+  // Split the working space into that required for 16 input matrices and
+  // output matrices.
+  TIn *input_matrices[N_GEMMS];
+  TOut *output_matrices[N_GEMMS];
+  const int in_matrix_stride_bytes = get_input_matrix_size(kernel_shape, input_shape, padding_type);
+  const int out_matrix_stride_bytes = get_output_matrix_size(kernel_shape, input_shape, padding_type);
+
+  for (int i = 0; i < N_GEMMS; i++)
+  {
+    input_matrices[i] = reinterpret_cast<TIn *>(
+        ws_bytes + i*in_matrix_stride_bytes);
+    output_matrices[i] = reinterpret_cast<TIn *>(
+        ws_bytes + N_GEMMS*in_matrix_stride_bytes + i*out_matrix_stride_bytes);
+  }
+
+  // If we need to re-order the input and output tensors then the final chunk
+  // of the working space can be used for this purpose.
+  // TODO  - Overlay the input reorder on top of the output matrices
+  //       - Overlay the output reorder on top of the input matrices
+  // Reorder the input input form if it was not provided in this ordering.
+  const TIn* input_nhwc = input;
+  if (input_shape.ordering == NCHW)
+  {
+    input_nhwc = reinterpret_cast<TIn *>(
+      ws_bytes + N_GEMMS*(in_matrix_stride_bytes + out_matrix_stride_bytes)
+    );
+
+    this->prof(
+      "NCHW -> NHWC",
+      [input, input_shape, input_nhwc] () {
+        reorder::nchw_to_nhwc(
+          input, const_cast<TIn *>(input_nhwc),
+          input_shape.n_batches,
+          input_shape.n_channels,
+          input_shape.n_rows,
+          input_shape.n_cols
+        );
+      },
+      input_shape.size(), 0, input_shape.size()
+    );
+  }
+
+  // Compute shape for the GEMM
+  const auto output_shape = this->output_shape;
+  int M = this->M;
+  int K = this->K;
+  int N = this->N;
+
+  const int in_matrix_row_stride = K;
+  const int out_matrix_row_stride = kernel_matrix_row_stride;
+
+  InputTransform<TIn> input_transform(
+    input_nhwc,
+    input_shape.n_batches,
+    input_shape.n_rows,
+    input_shape.n_cols,
+    input_shape.n_channels,
+    padding_type,
+    input_matrices[0],
+    in_matrix_stride_bytes / sizeof(TIn),
+    in_matrix_row_stride
+  );
+
+  // Transform the input into the Winograd domain
+  auto input_prep = [&] () {
+    input_transform.run(0, input_transform.get_window());
+  };
+  prof(
+    "Input Prep", input_prep,
+    InputTransform<TIn>::bytes_read(input_shape),
+    InputTransform<TIn>::ops_performed(input_shape),
+    InputTransform<TIn>::bytes_written(input_shape)
+  );
+
+  // Perform the GEMMs
+  const int kernel_matrix_stride_bytes = get_kernel_matrix_size(kernel_shape);
+  BatchedBlockedGemm<M_BLOCK, N_BLOCK, TOut, TIn> gemms(
+    N_GEMMS, M, K, N,
+    in_matrix_stride_bytes / sizeof(TIn),
+    in_matrix_row_stride,
+    kernel_matrix_stride_bytes / sizeof(TIn),
+    kernel_matrix_row_stride,
+    out_matrix_stride_bytes / sizeof(TOut),
+    out_matrix_row_stride,
+    input_matrices[0],
+    kernel_matrices[0],
+    output_matrices[0]
+  );
+  gemms.run(0, gemms.get_window());
+
+  // If the output tensor needs to be in NCHW form then store the NHWC output
+  // tensor in temporary storage and then reorder. If the output tensor needs
+  // to be in NHWC then just write straight to the output tensor.
+  TOut *output_nhwc = output;
+  if (input_shape.ordering == NCHW)
+  {
+    output_nhwc = reinterpret_cast<TOut *>(
+      ws_bytes + N_GEMMS*(in_matrix_stride_bytes + out_matrix_stride_bytes)
+    );
+  }
+
+  // Transform the output tensor from the Winograd domain to the spatial
+  // domain.
+  OutputTransform<TOut> output_transform(
+    output_matrices[0],
+    out_matrix_stride_bytes / sizeof(TOut),
+    out_matrix_row_stride,
+    output_nhwc,
+    output_shape.n_batches,
+    output_shape.n_rows,
+    output_shape.n_cols,
+    output_shape.n_channels
+  );
+  auto output_prep = [&] () {
+    output_transform.run(0, output_transform.get_window());
+  };
+  prof(
+    "Output Comp", output_prep,
+    OutputTransform<TOut>::bytes_read(output_shape),
+    OutputTransform<TOut>::ops_performed(output_shape),
+    OutputTransform<TOut>::bytes_written(output_shape)
+  );
+
+  // Reorder the output tensor if it is required to be in NCHW form.
+  if (input_shape.ordering == NCHW)
+  {
+    prof(
+      "NHWC -> NCHW",
+      [output_nhwc, output_shape, output] () {
+        reorder::nhwc_to_nchw(
+          output_nhwc, output,
+          output_shape.n_batches,
+          output_shape.n_rows,
+          output_shape.n_cols,
+          output_shape.n_channels
+        );
+      },
+      output_shape.size(), 0, output_shape.size()
+    );
+  }
+
+  // Free working space if we were responsible for allocating it
+  if (manage_working_space)
+  {
+    free(working_space);
+  }
+}
+
+
+/** Perform a convolution. */
+template <int output_tile_rows, int output_tile_cols,
+          int kernel_rows, int kernel_cols>
+template <typename TOut, typename TIn>
+void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::
+Convolution<TOut, TIn>::execute(
+  TOut* const output,
+  const TIn* const input,
+  const int n_threads
+)
+{
+  execute(output, input, NULL, n_threads);
+}
+
+
+// Instantiate required implementations
+template class WinogradGEMM<2, 2, 3, 3>::Convolution<float, float>;
+template class WinogradGEMM<4, 4, 3, 3>::Convolution<float, float>;
diff --git a/src/core/NEON/kernels/winograd/winograd_gemm.hpp b/src/core/NEON/kernels/winograd/winograd_gemm.hpp
deleted file mode 100644
index 59afa2f..0000000
--- a/src/core/NEON/kernels/winograd/winograd_gemm.hpp
+++ /dev/null
@@ -1,345 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#pragma once
-#include <cstdint>
-#include <cstdlib>
-#include <cassert>
-
-#include "gemm.hpp"
-#include "profiler.hpp"
-#include "utils.hpp"
-#include "shims.hpp"
-
-#include "transforms.hpp"
-
-namespace winograd {
-  /***************************************************************************/
-  /* Implementation of the Winograd F(2x2, 3x3, 4x4) algorithm using GEMM
-   * internally.
-   */
-  template <typename TOut, typename TIn>
-  class Winograd2x2_3x3GEMM {
-    public:
-      /* Instantiate a new Winograd operator.
-       */
-      Winograd2x2_3x3GEMM(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage);
-      virtual ~Winograd2x2_3x3GEMM();
-
-      /** Transform the weights into the Winograd domain.
-       */
-      template <typename KernelTransform=winograd2x2_3x3_gemm_kernel_transform_impl<TIn>>
-      void transform_weights(const TIn* const kernel, void *transform_working_space);
-
-      /* Initializes matrices pointers, to be called once before execute()
-       */
-      template <typename InputTransform=Winograd2x2_3x3GemmInputChannelwise<TIn>>
-      void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const TIn* const input, void* working_space);
-
-      /* Apply the Winograd operator to some input.
-       */
-      template <typename OutputTransform=Winograd2x2_3x3GemmOutput<TOut>>
-      void reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output);
-
-
-      /* Apply the Winograd operator to some input.
-       */
-      void execute(size_t first, size_t last);
-
-      /* Get the memory required to transform the kernel.
-       */
-      static inline size_t get_kernel_transform_working_size(const KernelShape &shape);
-
-      /* Get the output shape of a convolution.
-       */
-      static Tensor4DShape get_output_shape(const Tensor4DShape &input_shape, const KernelShape &k_shape,
-                                     const PaddingType padding_type);
-
-      /* Get the memory required to instantiate a new Winograd operator.
-       */
-      static size_t get_kernel_storage_size(const KernelShape &shape);
-
-      /* Get the memory required to apply a Winograd operator to some input.
-       */
-      static size_t get_working_space_size(const Tensor4DShape &input_shape,const KernelShape &k_shape,
-                                    const PaddingType padding);
-
-
-      Winograd2x2_3x3GEMM(const Winograd2x2_3x3GEMM &) = delete;
-      /** Prevent instances of this class from being copied (As this class contains pointers) */
-      Winograd2x2_3x3GEMM &operator=(const Winograd2x2_3x3GEMM &) = delete;
-      /** Allow instances of this class to be moved */
-      Winograd2x2_3x3GEMM(Winograd2x2_3x3GEMM &&) = default;
-      /** Allow instances of this class to be moved */
-      Winograd2x2_3x3GEMM &operator=(Winograd2x2_3x3GEMM &&) = default;
-
-    protected:
-      /* Get the memory required by a single "input" matrix.
-       */
-      static size_t get_input_matrix_size(const Tensor4DShape &input_shape,const KernelShape &k_shape,
-                                   const PaddingType padding);
-
-      /* Get the memory required by a single "output" matrix.
-       */
-      static size_t get_output_matrix_size(const Tensor4DShape &input_shape, const KernelShape &k_shape,
-                                    const PaddingType padding);
-
-      /* Get the memory required by a single "kernel" matrix.
-       */
-      static size_t get_kernel_matrix_size(const KernelShape &shape);
-
-      const KernelShape kernel_shape;  // Shape of applied kernel
-      const Tensor4DShape in_shape;
-      const PaddingType padding;
-
-      const int kernel_matrix_row_stride;  // Stride within kernel matrix
-
-      const bool manage_kernel_storage;  // Free kernel storage when done
-      void* const _kernel_storage;  // Base pointer for kernel matrices
-
-      profiler prof;  // Profiler
-
-      TIn *kernel_matrices[16];  // Prepared form of kernel
-      TIn *input_matrices[16];
-      TOut *output_matrices[16];
-
-
-      static const int M_BLOCK = 4;
-      static const int N_BLOCK = 16;
-  };
-} // namespace winograd
-
-template <typename TOut, typename TIn>
-size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_kernel_transform_working_size(
-    const KernelShape &shape
-)
-{
-    // Need to re-order the kernel into HWIO form, require enough space to
-    // represent the tensor.
-    return sizeof(TIn) * shape.size();
-}
-
-
-template <typename TOut, typename TIn>
-template <typename KernelTransform>
-void winograd::Winograd2x2_3x3GEMM<TOut, TIn>::transform_weights(
-  const TIn* const kernel,
-  void *transform_working_space
-)
-{
-    const int kernel_matrix_size_bytes = get_kernel_matrix_size(kernel_shape);
-    int8_t* const ks_bytes = reinterpret_cast<int8_t *>(_kernel_storage);
-    for (int i = 0; i < 16; i++) {
-        kernel_matrices[i] = reinterpret_cast<TIn *>(
-        ks_bytes + i*kernel_matrix_size_bytes);
-    }
-
-    const TIn *kernel_hwio = kernel;
-    if( transform_working_space)
-    {
-            kernel_hwio = reinterpret_cast<TIn *>(transform_working_space);
-            ofm_ifm_h_w_to_h_w_ifm_ofm(
-                  kernel, const_cast<TIn *>(kernel_hwio),
-                  kernel_shape.n_output_channels,
-                  kernel_shape.n_input_channels,
-                  kernel_shape.n_rows,
-                  kernel_shape.n_cols
-                );
-    }
-    KernelTransform::execute(
-      kernel_shape, kernel_hwio, kernel_matrices[0],
-      kernel_matrix_size_bytes / sizeof(TIn),
-      kernel_matrix_row_stride
-    );
-}
-
-template <typename TOut, typename TIn>
-winograd::Winograd2x2_3x3GEMM<TOut, TIn>::Winograd2x2_3x3GEMM( const KernelShape &kernel_shape, const Tensor4DShape input_shape,
-        const PaddingType padding_type, void *kernel_storage)
-    : kernel_shape(kernel_shape), in_shape(input_shape), padding(padding_type),kernel_matrix_row_stride(roundup(kernel_shape.n_output_channels, N_BLOCK)), manage_kernel_storage(false),
-        _kernel_storage(kernel_storage), prof() {
-     memset(kernel_matrices, 0x00, sizeof(TIn)*16);
-     memset(input_matrices, 0x00, sizeof(TIn)*16);
-     memset(output_matrices, 0x00, sizeof(TOut)*16);
-}
-
-/*****************************************************************************/
-template <typename TOut, typename TIn>
-winograd::Winograd2x2_3x3GEMM<TOut, TIn>::~Winograd2x2_3x3GEMM() {}
-
-/*****************************************************************************/
-template <typename TOut, typename TIn>
-template <typename InputTransform>
-void winograd::Winograd2x2_3x3GEMM<TOut, TIn>::reshape_input(
-    const Tensor4DShape& input_shape,
-    const PaddingType padding_type,
-    const TIn* const input,
-    void *working_space
-) {
-  assert(working_space);
-  int8_t* const ws_bytes = reinterpret_cast<int8_t *>(working_space);
-  // Split the working space into that required for 16 input matrices and
-  // output matrices.
-  const int in_matrix_stride_bytes = get_input_matrix_size(input_shape, kernel_shape, padding_type);
-  const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type);
-
-  for (int i = 0; i < 16; i++) {
-    input_matrices[i] = reinterpret_cast<TIn *>(
-        ws_bytes + i*in_matrix_stride_bytes);
-    output_matrices[i] = reinterpret_cast<TIn *>(
-        ws_bytes + 16*in_matrix_stride_bytes + i*out_matrix_stride_bytes);
-  }
-
-  // Compute shape for the GEMM
-  const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type);
-  const int tile_rows = iceildiv(output_shape.n_rows, 2);
-  const int tile_cols = iceildiv(output_shape.n_cols, 2);
-  const int K = kernel_shape.n_input_channels;
-
-  const int in_matrix_row_stride = K;
-  const int in_matrix_batch_stride = tile_rows*tile_cols*in_matrix_row_stride;
-
-  // Transform the input tensor into an appropriate form
-  auto input_prep = [&] () {
-    InputTransform::execute(
-      input, input_shape, padding_type, tile_rows, tile_cols,
-      input_matrices[0], in_matrix_stride_bytes / sizeof(TIn),
-      in_matrix_batch_stride, in_matrix_row_stride
-    );
-  };
-  prof(
-    "Input Prep", input_prep,
-    InputTransform::bytes_read(input_shape, output_shape),
-    InputTransform::flops_performed(input_shape, output_shape),
-    InputTransform::bytes_written(input_shape, output_shape)
-  );
-
-}
-
-/*****************************************************************************/
-template <typename TOut, typename TIn>
-template <typename OutputTransform>
-void winograd::Winograd2x2_3x3GEMM<TOut, TIn>::reshape_output(const Tensor4DShape& input_shape, const PaddingType padding_type, TOut* const output) {
-  assert(output_matrices[0]);
-  const int out_matrix_stride_bytes = get_output_matrix_size(input_shape, kernel_shape, padding_type);
-  const auto output_shape = get_output_shape(input_shape,kernel_shape, padding_type);
-  const int out_matrix_row_stride = kernel_matrix_row_stride;
-
-  // Transform the output tensor into an appropriate form
-    OutputTransform::execute(
-      output_shape,
-      output_matrices[0],
-      out_matrix_stride_bytes / sizeof(TOut),
-      out_matrix_row_stride,
-      output
-    );
-}
-
-
-/*****************************************************************************/
-template <typename TOut, typename TIn>
-void winograd::Winograd2x2_3x3GEMM<TOut, TIn>::execute( size_t first, size_t last ) {
-  assert(input_matrices[0] && kernel_matrices[0] && output_matrices[0]);
-  assert(first < 16 && last < 16 && first < last);
-  // Compute shape for the GEMM
-  const auto output_shape = get_output_shape(in_shape,kernel_shape, padding);
-  const int tile_rows = iceildiv(output_shape.n_rows, 2);
-  const int tile_cols = iceildiv(output_shape.n_cols, 2);
-  const int M = in_shape.n_batches * tile_rows * tile_cols;
-  const int K = kernel_shape.n_input_channels;
-  const int N = kernel_shape.n_output_channels;
-
-  const int in_matrix_row_stride = K;
-  const int out_matrix_row_stride = kernel_matrix_row_stride;
-  // Perform the GEMMs
-  for (size_t i = first; i <= last; i++) {
-      BlockedGemm<M_BLOCK, N_BLOCK>(
-        input_matrices[i], kernel_matrices[i], output_matrices[i], M, K, N,
-        in_matrix_row_stride, kernel_matrix_row_stride, out_matrix_row_stride
-      );
-//    prof("GEMM", perform_gemm, 0, 2*M*K*N, 0);  // TODO Memory
-  }
-
-}
-
-/*****************************************************************************/
-template <typename TOut, typename TIn>
-Tensor4DShape winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_output_shape(
-    const Tensor4DShape &in_shape, const KernelShape &k_shape, const PaddingType padding)  {
-  return Tensor4DShape {
-    in_shape.n_batches,
-    (padding == PADDING_SAME) ? in_shape.n_rows : in_shape.n_rows - 2,
-    (padding == PADDING_SAME) ? in_shape.n_cols : in_shape.n_cols - 2,
-    k_shape.n_output_channels
-  };
-}
-
-template <typename TOut, typename TIn>
-size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_kernel_storage_size(
-    const KernelShape &shape) {
-  return 16 * get_kernel_matrix_size(shape);
-}
-
-template <typename TOut, typename TIn>
-size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_kernel_matrix_size(
-    const KernelShape &shape) {
-  const int K = shape.n_input_channels;
-  const int N = roundup(shape.n_output_channels, N_BLOCK);
-  return sizeof(TIn) * K * N;
-}
-
-template <typename TOut, typename TIn>
-size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_working_space_size(
-    const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type
-)  {
-  return 16 * get_input_matrix_size(input_shape, k_shape, padding_type) +
-         16 * get_output_matrix_size(input_shape, k_shape, padding_type);
-}
-
-template <typename TOut, typename TIn>
-size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_input_matrix_size(
-    const Tensor4DShape& input_shape, const KernelShape &k_shape, const PaddingType padding_type
-)  {
-  // Compute shape for the GEMM
-  const auto output_shape = get_output_shape(input_shape, k_shape, padding_type);
-  const int tile_rows = iceildiv(output_shape.n_rows, 2);
-  const int tile_cols = iceildiv(output_shape.n_cols, 2);
-  const int M = roundup(tile_rows * tile_cols, M_BLOCK);
-  const int K = k_shape.n_input_channels;
-
-  return input_shape.n_batches * M * K * sizeof(TIn);
-}
-
-template <typename TOut, typename TIn>
-size_t winograd::Winograd2x2_3x3GEMM<TOut, TIn>::get_output_matrix_size(
-    const Tensor4DShape& input_shape, const KernelShape &k_shape,const PaddingType padding_type
-)  {
-  // Compute shape for the GEMM
-  const auto output_shape = get_output_shape(input_shape, k_shape, padding_type);
-  const int tile_rows = iceildiv(output_shape.n_rows, 2);
-  const int tile_cols = iceildiv(output_shape.n_cols, 2);
-  const int M = roundup(tile_rows * tile_cols, M_BLOCK);
-  const int N = roundup(k_shape.n_output_channels, N_BLOCK);
-
-  return input_shape.n_batches * M * N * sizeof(TOut);
-}
diff --git a/src/core/NEON/kernels/winograd/winograd_layer.cpp b/src/core/NEON/kernels/winograd/winograd_layer.cpp
new file mode 100644
index 0000000..689ecba
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/winograd_layer.cpp
@@ -0,0 +1,204 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "convolution.hpp"
+#include "winograd_layer.hpp"
+#include "tensor.hpp"
+
+
+/** Determine how much memory (in units of TIn) to allocate for the transformed
+ * weights.
+ */
+template <
+  int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols,
+  typename TIn, typename TOut
+>
+unsigned int WinogradConvolutionLayer<
+  OutputTileRows, OutputTileCols, KernelRows, KernelCols, TIn, TOut
+>::get_weight_storage_size(
+  const int n_output_channels,  /** Number of output feature maps. */
+  const int n_input_channels    /** Number of input feature maps. */
+)
+{
+  const KernelShape shape(
+    n_output_channels, KernelRows, KernelCols, n_input_channels
+  );
+  return static_cast<unsigned int>(
+    // WinogradConv returns the size in bytes, we divide by `sizeof(TIn)` to
+    // express that in units of TIn.
+    WinogradConv::get_kernel_storage_size(shape) / sizeof(TIn)
+  );
+}
+
+
+/** Determine how much memory (in units of TIn) to allocate for the transformed
+ * input.
+ */
+template <
+  int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols,
+  typename TIn, typename TOut
+>
+unsigned int WinogradConvolutionLayer<
+  OutputTileRows, OutputTileCols, KernelRows, KernelCols, TIn, TOut
+>::get_input_storage_size(
+  const int n_batches,     /** Number of batches in the input tensor. */
+  const int n_channels,    /** Number of feature maps in the input tensor. */
+  const int n_rows,        /** Number of rows in each feature map. */
+  const int n_cols,        /** Number of columns in each feature map. */
+  const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+)
+{
+  // Construct shapes for the input and kernel tensors.
+  const Tensor4DShape input_shape(n_batches, n_rows, n_cols, n_channels);
+  const KernelShape kern_shape(1, KernelRows, KernelCols, n_channels);
+  const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID;
+
+  // Return the size, converted into units of TIn
+  return static_cast<unsigned int>(
+    WinogradConv::get_input_storage_size(kern_shape, input_shape, padding) /
+    sizeof(TIn)
+  );
+}
+
+
+/** Determine how much memory (in units of TOut) to allocate for the (Winograd
+ * domain) output.
+ */
+template <
+  int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols,
+  typename TIn, typename TOut
+>
+unsigned int WinogradConvolutionLayer<
+  OutputTileRows, OutputTileCols, KernelRows, KernelCols, TIn, TOut
+>::get_output_storage_size(
+  const int n_batches,          /** Number of batches in the output tensor. */
+  const int n_rows,             /** Number of rows in each feature map of the input tensor. */
+  const int n_cols,             /** Number of columns in each feature map of the input tensor. */
+  const int n_output_channels,  /** Number of feature maps in the output tensor. */
+  const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
+)
+{
+  // Construct shapes for the input and kernel tensors.
+  const Tensor4DShape input_shape(n_batches, n_rows, n_cols, 1);
+  const KernelShape kern_shape(n_output_channels, KernelRows, KernelCols, 1);
+  const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID;
+
+  // Return the size, converted into units of TOut
+  return static_cast<unsigned int>(
+    WinogradConv::get_output_storage_size(kern_shape, input_shape, padding) /
+    sizeof(TOut)
+  );
+}
+
+
+/** Get the shape (rows, cols) of a feature map of the output tensor. */
+template <
+  int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols,
+  typename TIn, typename TOut
+>
+std::pair<int, int> WinogradConvolutionLayer<
+  OutputTileRows, OutputTileCols, KernelRows, KernelCols, TIn, TOut
+>::get_output_feature_map_shape(
+  const int n_input_rows,  /** Number of rows in the input feature map. */
+  const int n_input_cols,  /** Number of columns in the input feature map. */
+  const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
+)
+{
+  // Construct shapes for the input and kernel tensors.
+  const Tensor4DShape input_shape(1, n_input_rows, n_input_cols, 1);
+  const KernelShape kern_shape(1, KernelRows, KernelCols, 1);
+  const PaddingType padding = (same_padding) ? PADDING_SAME : PADDING_VALID;
+
+  // Compute the new shape
+  const auto output_shape = WinogradConv::get_output_shape(
+    kern_shape, input_shape, padding
+  );
+
+  return std::make_pair(output_shape.n_rows, output_shape.n_cols);
+}
+
+
+/** Create a new Winograd convolution layer.
+ */
+template <
+  int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols,
+  typename TIn, typename TOut
+>
+WinogradConvolutionLayer<OutputTileRows, OutputTileCols, KernelRows, KernelCols, TIn, TOut>::
+WinogradConvolutionLayer(
+  const int n_batches,          /** Number of batches in the input and output tensors. */
+  const int n_input_channels,   /** Number of feature maps in a batch of the input tensor. */
+  const int n_input_rows,       /** Number of rows in a feature map of the input tensor. */
+  const int n_input_cols,       /** Number of columns in a feature map of the input tensor. */
+  const int n_output_channels,  /** Number of feature maps in the output tensor. */
+  const bool same_padding,      /** Use "SAME" padding, otherwise use "VALID". */
+  const TIn* const weights,     /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */
+  TIn* const winograd_weights,  /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
+  const TIn* const input,       /** Pointer to NHWC ordered input tensor, in the spatial domain. */
+  TIn* const winograd_input,    /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */
+  TOut* const output,           /** Pointer to NHWC ordered output tensor, in the spatial domain. */
+  TOut* const winograd_output   /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
+) : _kernel_shape(n_output_channels, KernelRows, KernelCols, n_input_channels),
+    _input_shape(n_batches, n_input_rows, n_input_cols, n_input_channels),
+    _padding(same_padding ? PADDING_SAME : PADDING_VALID),
+    _output_shape(WinogradConv::get_output_shape(_kernel_shape, _input_shape, _padding)),
+    _n_output_rows(_output_shape.n_rows),
+    _n_output_cols(_output_shape.n_cols),
+    _kernel_matrix_stride(WinogradConv::get_kernel_matrix_stride(_kernel_shape)),
+    _kernel_matrix_row_stride(roundup(n_output_channels, WinogradConv::N_BLOCK)),
+    _input_matrix_stride(WinogradConv::get_input_matrix_stride(_kernel_shape, _input_shape, _padding)),
+    _input_matrix_row_stride(n_input_channels),
+    _output_matrix_stride(WinogradConv::get_output_matrix_stride(_kernel_shape, _input_shape, _padding)),
+    _output_matrix_row_stride(_kernel_matrix_row_stride),
+    _tile_rows(iceildiv(_n_output_rows, OutputTileRows)),
+    _tile_cols(iceildiv(_n_output_cols, OutputTileCols)),
+    _m(n_batches * _tile_rows * _tile_cols),
+    _k(n_input_channels),
+    _n(n_output_channels),
+    weights_transform(
+      weights, winograd_weights,
+      _kernel_matrix_stride, _kernel_matrix_row_stride,
+      n_output_channels, n_input_channels
+    ),
+    input_transform(
+      input, n_batches, n_input_rows, n_input_cols, n_input_channels, _padding,
+      winograd_input, _input_matrix_stride, _input_matrix_row_stride
+    ),
+    gemms(
+      WinogradBase::N_GEMMS, _m, _k, _n,
+      _input_matrix_stride, _input_matrix_row_stride,
+      _kernel_matrix_stride, _kernel_matrix_row_stride,
+      _output_matrix_stride, _output_matrix_row_stride,
+      winograd_input, winograd_weights, winograd_output
+    ),
+    output_transform(
+      winograd_output, _output_matrix_stride, _output_matrix_row_stride,
+      output, n_batches, _n_output_rows, _n_output_cols, n_output_channels
+    )
+{
+}
+
+// Instantiate valid implementations.
+template class WinogradConvolutionLayer<2, 2, 3, 3, float, float>;
+template class WinogradConvolutionLayer<4, 4, 3, 3, float, float>;
diff --git a/src/runtime/NEON/functions/NEWinogradLayer.cpp b/src/runtime/NEON/functions/NEWinogradLayer.cpp
index 21f298c..da46f87 100644
--- a/src/runtime/NEON/functions/NEWinogradLayer.cpp
+++ b/src/runtime/NEON/functions/NEWinogradLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -43,8 +43,8 @@
 namespace arm_compute
 {
 NEWinogradLayer::NEWinogradLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _winograd_kernel(), _permute_input(), _permute_weights(), _permute_output(), _workspace(), _kernel_storage(), _input_nhwc(), _output_nhwc(),
-      _weights_hwio(), _input(), _weights(), _output(), _reshaped_kernel(false), _conv()
+    : _memory_group(std::move(memory_manager)), _winograd_kernel(), _permute_input(), _permute_weights(), _permute_output(), _input_workspace(), _output_workspace(), _kernel_storage(), _input_nhwc(),
+      _output_nhwc(), _weights_hwio(), _input(), _weights(), _output(), _reshaped_kernel(false), _conv()
 {
 } /* arm_compute */
 
@@ -72,36 +72,37 @@
     ARM_COMPUTE_ERROR_ON_MSG(stride_y != 1 || stride_x != 1, "Winograd layer only supports unit strides.");
 
     // Get convolved dimensions
-    auto      padding        = PADDING_VALID;
-    const int in_channels    = input->info()->dimension(2);
-    const int out_channels   = output->info()->dimension(2);
-    const int weights_width  = weights->info()->dimension(0);
-    const int weights_height = weights->info()->dimension(1);
+    const int in_channels  = input->info()->dimension(2);
+    const int out_channels = output->info()->dimension(2);
 
-    const KernelShape   kernel_shape({ out_channels, weights_height, weights_width, in_channels });
     const Tensor4DShape in_shape(internal_get_input_shape(input));
 
     // Get the memory required to instantiate a new Winograd operator.
-    constexpr size_t kstore_alignment          = 64;
-    const size_t     kernel_storage_per_thread = NEWinogradLayerKernel::get_kernel_storage_size(kernel_shape);
-    _kernel_storage.allocator()->init(TensorInfo(TensorShape{ (kernel_storage_per_thread + kstore_alignment - 1) }, 1, DataType::U8));
+    constexpr size_t storage_alignment   = 64;
+    const size_t     kernel_storage_size = NEWinogradLayerKernel::get_weight_storage_size(out_channels, in_channels) * sizeof(float);
+    _kernel_storage.allocator()->init(TensorInfo(TensorShape{ (kernel_storage_size + storage_alignment - 1) }, 1, DataType::U8));
     _memory_group.manage(&_kernel_storage);
-
-    // Get workbench size and allocate memory
-
-    constexpr size_t wspace_alignment = 64;
-    const size_t     ws_size          = NEWinogradLayerKernel::get_working_space_size(in_shape, kernel_shape, padding);
-    _workspace.allocator()->init(TensorInfo(TensorShape{ (ws_size + wspace_alignment - 1) }, 1, DataType::U8));
-    _memory_group.manage(&_workspace);
     _memory_group.manage(&_input_nhwc);
     _kernel_storage.allocator()->allocate();
-    _workspace.allocator()->allocate();
+    // Input storage
+    const size_t input_storage_size = NEWinogradLayerKernel::get_input_storage_size(in_shape.n_batches, in_shape.n_channels, in_shape.n_rows, in_shape.n_cols, false) * sizeof(float);
+    _input_workspace.allocator()->init(TensorInfo(TensorShape{ (input_storage_size + storage_alignment - 1) }, 1, DataType::U8));
+    _memory_group.manage(&_input_workspace);
+    _input_workspace.allocator()->allocate();
 
-    // Create Winograd operator object
-    _conv = support::cpp14::make_unique<Winograd3x3F32>(kernel_shape, in_shape, padding, _kernel_storage.buffer());
+    // Output storage
+    const size_t output_storage_size = NEWinogradLayerKernel::get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels, false) * sizeof(float);
+    _output_workspace.allocator()->init(TensorInfo(TensorShape{ (output_storage_size + storage_alignment - 1) }, 1, DataType::U8));
+    _memory_group.manage(&_output_workspace);
+    _output_workspace.allocator()->allocate();
 
-    // Configure the kernel, padding not needed so it's safe to call configure after allocare
-    _winograd_kernel.configure(_conv.get());
+    // configure and allocate dst tensor to be used to convert from winograd domain to spatial domain when calling to reshape_output()
+    TensorInfo info(TensorShape(_output->info()->dimension(2), _output->info()->dimension(0),
+                                _output->info()->dimension(1), _output->info()->dimension(3)),
+                    1, _output->info()->data_type());
+    _output_nhwc.allocator()->init(info);
+
+    _output_nhwc.allocator()->allocate();
 
     // Re-order a weight tensor from [Output feature map x Input feature map x Height x Width] to [Height x Width x Input feature map x Output feature map]
     switch(weights->info()->num_dimensions())
@@ -122,60 +123,56 @@
             break;
         }
     }
+
+    _weights_hwio.allocator()->allocate();
+
     // configure the kernel to transform the input tensor from NCHW -> NHWC
     _permute_input.configure(input, &_input_nhwc, PermutationVector(2U, 0U, 1U));
 
-    // configure and allocate dst tensor to be used to convert from winograd domain to spatial domain when calling to reshape_output()
-    TensorInfo info(TensorShape(_output->info()->dimension(2), _output->info()->dimension(0),
-                                _output->info()->dimension(1), _output->info()->dimension(3)),
-                    1, _output->info()->data_type());
-    _output_nhwc.allocator()->init(info);
-
-    _output_nhwc.allocator()->allocate();
-    _weights_hwio.allocator()->allocate();
     _input_nhwc.allocator()->allocate();
+
+    // Create Winograd operator object
+    _conv = support::cpp14::make_unique<Winograd3x3F32>(
+                in_shape.n_batches,
+                in_shape.n_channels,
+                in_shape.n_rows,
+                in_shape.n_cols,
+                out_channels,
+                false,
+                reinterpret_cast<const float *>(_weights_hwio.buffer()),
+                reinterpret_cast<float *>(_kernel_storage.buffer()),
+                reinterpret_cast<float *>(_input_nhwc.buffer()),
+                reinterpret_cast<float *>(_input_workspace.buffer()),
+                reinterpret_cast<float *>(_output_nhwc.buffer()),
+                reinterpret_cast<float *>(_output_workspace.buffer()));
+
+    // Configure the kernel, padding not needed so it's safe to call configure after allocare
+    _winograd_kernel.configure(_conv.get());
+
+    // Reorder the convoluted output to ACL's ordering NCHW
+    _permute_output.configure(&_output_nhwc, _output, PermutationVector(1U, 2U, 0U));
+
 }
 
 void NEWinogradLayer::run()
 {
-#if defined(__aarch64__)
     _memory_group.acquire();
     if(!_reshaped_kernel)
     {
         _reshaped_kernel = true;
         _permute_weights.run();
-        _conv->transform_weights(reinterpret_cast<const float *>(_weights_hwio.buffer()), nullptr);
+        _conv->transform_weights();
     }
-    const Tensor4DShape in_shape(internal_get_input_shape(_input));
-    auto                padding = PADDING_VALID;
-
     //Bring channels to the front as Winograd code expects the tensor to be in the format NHWC
     _permute_input.run();
-
-    //Setup matrices ptrs and transfor the input tensor to the appropriate form before running GEMM.
-    _conv->reshape_input(in_shape, padding, reinterpret_cast<float *>(_input_nhwc.buffer()), _workspace.buffer());
-
+    // Transform input tensor to the winograd domain
+    _conv->transform_input();
     //Run 16 GEMMs in multiple threads, each kernel runs one or more GEMMs
     NEScheduler::get().schedule(&_winograd_kernel, Window::DimX);
-
-    //Transform the output to the appropriate form
-    _conv->reshape_output(in_shape, padding, reinterpret_cast<float *>(_output_nhwc.buffer()));
-
+    // Transform output tensor to the spatial domain
+    _conv->transform_output();
     // Reorder the convoluted output to ACL's ordering NCHW
-    _permute_output.configure(&_output_nhwc, _output, PermutationVector(1U, 2U, 0U));
     _permute_output.run();
-
     _memory_group.release();
-#else  /* __aarch64__ */
-    ARM_COMPUTE_UNUSED(_winograd_kernel);
-    ARM_COMPUTE_UNUSED(_workspace);
-    ARM_COMPUTE_UNUSED(_kernel_storage);
-    ARM_COMPUTE_UNUSED(_input);
-    ARM_COMPUTE_UNUSED(_weights);
-    ARM_COMPUTE_UNUSED(_output);
-    ARM_COMPUTE_UNUSED(_reshaped_kernel);
-    ARM_COMPUTE_UNUSED(_conv);
-    ARM_COMPUTE_ERROR("Winograd only supported for aarch64, recompile with arch=arm64-v8a.");
-#endif /* __aarch64__ */
 }
 } // namespace arm_compute