COMPMID-3108: Add Winograd 3x3,4x4 FP16 support for NEON

Change-Id: I20680dc74a3d709297539e2132417308a7aecc9d
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3159
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
index 68064ee..3100bf7 100644
--- a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
@@ -41,23 +41,34 @@
 
 namespace
 {
-inline bool is_kernel_size_supported(Size2D size)
+inline bool is_kernel_size_supported(DataType data_type, Size2D size) // True if a Winograd kernel exists for this data type and kernel size
 {
-    const std::array<Size2D, 8> supported_input_sizes = { { Size2D(1, 3), Size2D(3, 1), Size2D(5, 5), Size2D(3, 3), Size2D(1, 5), Size2D(5, 1), Size2D(7, 1), Size2D(1, 7) } };
-    return std::end(supported_input_sizes) != std::find(std::begin(supported_input_sizes), std::end(supported_input_sizes), size);
+    const std::array<Size2D, 8> f32_support = { { Size2D(1, 3), Size2D(3, 1), Size2D(5, 5), Size2D(3, 3), Size2D(1, 5), Size2D(5, 1), Size2D(7, 1), Size2D(1, 7) } };
+    const std::array<Size2D, 1> f16_support = { { Size2D(3, 3) } }; // Sized exactly: unused slots would be value-initialised Size2D entries that find() could match
+
+    switch(data_type)
+    {
+        case DataType::F16:
+            return std::end(f16_support) != std::find(std::begin(f16_support), std::end(f16_support), size);
+        case DataType::F32:
+            return std::end(f32_support) != std::find(std::begin(f32_support), std::end(f32_support), size);
+        default:
+            return false; // No Winograd kernels for other data types
+    }
 }
 
 Status validate_arguments_winograd_weight_trans(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
 
     const size_t idx_width    = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
     const size_t idx_height   = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
     const auto   input_width  = input->dimension(idx_width);
     const auto   input_height = input->dimension(idx_height);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(Size2D(input_width, input_height)), "Only 1x3, 3x1, 1x5, 5x1, 7x1, 1x7, 3x3 and 5x5 kernels are supported");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(input_width, input_height)),
+                                    "Only 1x3, 3x1, 1x5, 5x1, 7x1, 1x7, 3x3 and 5x5 kernels are supported (only 3x3 for F16)");
     ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
     const Size2D &output_tile = winograd_info.output_tile_size;
     const std::array<Size2D, 8> supported_tile_sizes = { { Size2D(2U, 2U), Size2D(4U, 4U), Size2D(1U, 6U), Size2D(6U, 1U), Size2D(4, 1), Size2D(1, 4), Size2D(2, 1), Size2D(1, 2) } };
@@ -89,9 +100,9 @@
     const PadStrideInfo &conv_info   = winograd_info.convolution_info;
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd input transform only supports unit strides");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(Size2D(kernel_dims.width, kernel_dims.height)),
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(kernel_dims.width, kernel_dims.height)),
                                     "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported");
 
     // Validate configured output
@@ -128,9 +139,9 @@
 
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != num_tiles.area());
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(Size2D(kernel_dims.width, kernel_dims.height)),
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(kernel_dims.width, kernel_dims.height)),
                                     "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported");
 
     const std::array<unsigned int, 3> supported_gemm_sizes = { { 8U, 16U, 36U } };
@@ -162,22 +173,19 @@
 }
 } // namespace
 
-template <typename T>
-Status INEWinogradLayerTransformWeightsKernel<T>::validate(const ITensorInfo *input, const ITensorInfo *weights)
+Status INEWinogradLayerTransformWeightsKernel::validate(const ITensorInfo *input, const ITensorInfo *weights) // Non-template now: one validator shared by F16 and F32
 {
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
     const DataLayout   data_layout = input->data_layout();
     const unsigned int width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
     const unsigned int height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(Size2D(weights->dimension(width_idx), weights->dimension(height_idx))),
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(weights->dimension(width_idx), weights->dimension(height_idx))), // input/weights types already checked equal
                                      "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported");
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
     return Status{};
 }
 
-template class INEWinogradLayerTransformWeightsKernel<float>;
-
 template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
 unsigned int NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_weight_storage_size(int num_output_channels, int num_input_channels) const
 {
@@ -262,6 +270,11 @@
 template class NEWinogradLayerTransformWeightsKernel<float, 4, 1, 5, 1>;
 template class NEWinogradLayerTransformWeightsKernel<float, 1, 2, 1, 7>;
 template class NEWinogradLayerTransformWeightsKernel<float, 2, 1, 7, 1>;
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class NEWinogradLayerTransformWeightsKernel<__fp16, 4, 4, 3, 3>; // OutputTileRows=4, OutputTileCols=4, KernelRows=3, KernelCols=3
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
 // Input transform
 
 template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
@@ -396,6 +409,10 @@
 template class NEWinogradLayerTransformInputKernel<float, 1, 2, 1, 7>;
 template class NEWinogradLayerTransformInputKernel<float, 2, 1, 7, 1>;
 
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class NEWinogradLayerTransformInputKernel<__fp16, 4, 4, 3, 3>; // OutputTileRows=4, OutputTileCols=4, KernelRows=3, KernelCols=3
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
 // Output transform
 
 template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
@@ -524,4 +541,7 @@
 template class NEWinogradLayerTransformOutputKernel<float, 1, 2, 1, 7>;
 template class NEWinogradLayerTransformOutputKernel<float, 2, 1, 7, 1>;
 
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class NEWinogradLayerTransformOutputKernel<__fp16, 4, 4, 3, 3>; // OutputTileRows=4, OutputTileCols=4, KernelRows=3, KernelCols=3
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 } // namespace arm_compute
diff --git a/src/core/NEON/kernels/convolution/winograd/padding.cpp b/src/core/NEON/kernels/convolution/winograd/padding.cpp
index 46fe57c..04aa472 100644
--- a/src/core/NEON/kernels/convolution/winograd/padding.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/padding.cpp
@@ -85,6 +85,15 @@
   unsigned int, unsigned int, unsigned int, unsigned int, float
 );
 
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template void copy_and_pad_tile( // __fp16 counterpart of the float instantiation above
+    unsigned int, unsigned int, unsigned int,
+    const __fp16 *, unsigned int, unsigned int,
+    __fp16 *, unsigned int, unsigned int,
+    unsigned int, unsigned int, unsigned int, unsigned int, __fp16
+);
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
 template <unsigned int TileRows, unsigned int TileCols>
 void CopyCropped<TileRows, TileCols>::execute(
   const size_t size,
@@ -163,4 +172,21 @@
   unsigned int crop_right
 );
 
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template void crop_and_copy_tile( // __fp16 instantiation
+    unsigned int tile_rows,
+    unsigned int tile_cols,
+    unsigned int n_channels,
+    const __fp16 *inptr,
+    unsigned int in_row_stride,
+    unsigned int in_col_stride,
+    __fp16 *outptr,
+    unsigned int out_row_stride,
+    unsigned int out_col_stride,
+    unsigned int crop_top,
+    unsigned int crop_left,
+    unsigned int crop_bottom,
+    unsigned int crop_right
+);
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 }  // namespace padding
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd.cpp b/src/core/NEON/kernels/convolution/winograd/winograd.cpp
index a4eb9fc..867bb3c 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd.cpp
@@ -176,3 +176,7 @@
 
 template class WinogradGEMM<1, 2, 1, 7, WinogradRoots::Integers>::Convolution<float, float, float, float>;
 template class WinogradGEMM<2, 1, 7, 1, WinogradRoots::Integers>::Convolution<float, float, float, float>;
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class WinogradGEMM<4, 4, 3, 3, WinogradRoots::Integers>::Convolution<__fp16, __fp16, __fp16, __fp16>; // same 4, 4, 3, 3 configuration as the __fp16 NEWinograd* kernels
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp16_fp16_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp16_fp16_integers.cpp
new file mode 100644
index 0000000..1ea68b5
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp16_fp16_integers.cpp
@@ -0,0 +1,257 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+#include "input.hpp"
+#include "arm.hpp"
+
+namespace winograd
+{
+
+template <>
+void InputTransform<4, 4, __fp16, __fp16, WinogradRoots::Integers>::transform_tile(
+    const int n_channels,
+    const __fp16* const input_base,
+    const int input_row_stride,
+    const int input_col_stride,
+    __fp16* outptr,
+    const int matrix_stride
+)
+{
+    constexpr int inner_tile_rows = 4, inner_tile_cols = 4;
+
+    // Pointer to each cell of the 4x4 input tile; consecutive channels of a cell are adjacent
+    const __fp16 *x_ptrs[inner_tile_rows][inner_tile_cols];
+    for (int i = 0, xi = 0; i < inner_tile_rows; i++, xi++)
+    {
+        // Get a pointer into the row
+        const __fp16* const row_ptr = input_base + xi*input_row_stride;
+
+        for (int j = 0, xj = 0; j < inner_tile_cols; j++, xj++)
+        {
+            x_ptrs[i][j] = row_ptr + xj*input_col_stride;
+        }
+    }
+
+    // Scalar work matrices, used only by the one-channel-at-a-time tail loop below.
+    __fp16 x[inner_tile_rows][inner_tile_cols];
+    __fp16 XTx[inner_tile_rows][inner_tile_cols];
+    __fp16 U[inner_tile_rows][inner_tile_cols];
+
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+        for (int j = 0; j < inner_tile_cols; j++)
+        {
+            x[i][j] = XTx[i][j] = 0.0f;
+        }
+    }
+
+    // Transform channels 8 at a time (AArch64 only), then 4 at a time, then singly;
+    // transformed cell m of each channel is written to outptr + m*matrix_stride.
+    int channels_remaining = n_channels;
+#ifdef __aarch64__
+    for (; channels_remaining >= 8; channels_remaining -= 8)
+  {
+    // Vector work matrices: each of the 8 lanes holds one channel.
+    float16x8_t x[inner_tile_rows][inner_tile_cols];
+    float16x8_t XTx[inner_tile_rows][inner_tile_cols];
+    float16x8_t U[inner_tile_rows][inner_tile_cols];
+
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      for (int j = 0; j < inner_tile_cols; j++)
+      {
+        x[i][j] = vdupq_n_f16(0.0f);
+        XTx[i][j] = vdupq_n_f16(0.0f);
+      }
+    }
+
+    // Load x
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      for (int j = 0; j < inner_tile_cols; j++)
+      {
+        x[i][j] = vld1q_f16(x_ptrs[i][j]);
+        x_ptrs[i][j] += 8;
+      }
+    }
+
+    // Compute XT . x
+    for (int j = 0; j < inner_tile_cols; j++)
+    {
+      // XTx[0][j] = x[0][j] - x[2][j];
+      XTx[0][j] = vsubq_f16(x[0][j], x[2][j]);
+
+      // XTx[1][j] = x[1][j] + x[2][j];
+      XTx[1][j] = vaddq_f16(x[1][j], x[2][j]);
+
+      // XTx[2][j] = x[2][j] - x[1][j];
+      XTx[2][j] = vsubq_f16(x[2][j], x[1][j]);
+
+      // XTx[3][j] = x[1][j] - x[3][j];
+      XTx[3][j] = vsubq_f16(x[1][j], x[3][j]);
+    }
+
+    // Compute U = XT . x . X
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      // U[i][0] = XTx[i][0] - XTx[i][2];
+      U[i][0] = vsubq_f16(XTx[i][0], XTx[i][2]);
+
+      // U[i][1] = XTx[i][1] + XTx[i][2];
+      U[i][1] = vaddq_f16(XTx[i][1], XTx[i][2]);
+
+      // U[i][2] = XTx[i][2] - XTx[i][1];
+      U[i][2] = vsubq_f16(XTx[i][2], XTx[i][1]);
+
+      // U[i][3] = XTx[i][1] - XTx[i][3];
+      U[i][3] = vsubq_f16(XTx[i][1], XTx[i][3]);
+    }
+
+    // Store the transformed matrix
+    for (int i = 0, m = 0; i < inner_tile_rows; i++)
+    {
+      for (int j = 0; j < inner_tile_cols; j++, m++)
+      {
+        vst1q_f16(outptr + m*matrix_stride, U[i][j]);
+      }
+    }
+    outptr += 8;
+  }
+#endif  // __aarch64__
+#ifdef __arm_any__
+    for (; channels_remaining >= 4; channels_remaining -= 4)
+  {
+    // Vector work matrices: each of the 4 lanes holds one channel.
+    float16x4_t x[inner_tile_rows][inner_tile_cols];
+    float16x4_t XTx[inner_tile_rows][inner_tile_cols];
+    float16x4_t U[inner_tile_rows][inner_tile_cols];
+
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      for (int j = 0; j < inner_tile_cols; j++)
+      {
+        x[i][j] = vdup_n_f16(0.0f);
+        XTx[i][j] = vdup_n_f16(0.0f);
+      }
+    }
+
+    // Load x
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      for (int j = 0; j < inner_tile_cols; j++)
+      {
+        x[i][j] = vld1_f16(x_ptrs[i][j]);
+        x_ptrs[i][j] += 4;
+      }
+    }
+
+    // Compute XT . x
+    for (int j = 0; j < inner_tile_cols; j++)
+    {
+      // XTx[0][j] = x[0][j] - x[2][j];
+      XTx[0][j] = vsub_f16(x[0][j], x[2][j]);
+
+      // XTx[1][j] = x[1][j] + x[2][j];
+      XTx[1][j] = vadd_f16(x[1][j], x[2][j]);
+
+      // XTx[2][j] = x[2][j] - x[1][j];
+      XTx[2][j] = vsub_f16(x[2][j], x[1][j]);
+
+      // XTx[3][j] = x[1][j] - x[3][j];
+      XTx[3][j] = vsub_f16(x[1][j], x[3][j]);
+    }
+
+    // Compute U = XT . x . X
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      // U[i][0] = XTx[i][0] - XTx[i][2];
+      U[i][0] = vsub_f16(XTx[i][0], XTx[i][2]);
+
+      // U[i][1] = XTx[i][1] + XTx[i][2];
+      U[i][1] = vadd_f16(XTx[i][1], XTx[i][2]);
+
+      // U[i][2] = XTx[i][2] - XTx[i][1];
+      U[i][2] = vsub_f16(XTx[i][2], XTx[i][1]);
+
+      // U[i][3] = XTx[i][1] - XTx[i][3];
+      U[i][3] = vsub_f16(XTx[i][1], XTx[i][3]);
+    }
+
+    // Store the transformed matrix
+    for (int i = 0, m = 0; i < inner_tile_rows; i++)
+    {
+      for (int j = 0; j < inner_tile_cols; j++, m++)
+      {
+        vst1_f16(outptr + m*matrix_stride, U[i][j]);
+      }
+    }
+    outptr += 4;
+  }
+#endif  // __arm_any__
+    for (; channels_remaining; channels_remaining--)
+    {
+        // Load x
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++)
+            {
+                x[i][j] = *(x_ptrs[i][j]++);
+            }
+        }
+
+        // Compute XT . x
+        for (int j = 0; j < inner_tile_cols; j++)
+        {
+            XTx[0][j] = x[0][j] - x[2][j];
+            XTx[1][j] = x[1][j] + x[2][j];
+            XTx[2][j] = x[2][j] - x[1][j];
+            XTx[3][j] = x[1][j] - x[3][j];
+        }
+
+        // Compute U = XT . x . X
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            U[i][0] = XTx[i][0] - XTx[i][2];
+            U[i][1] = XTx[i][1] + XTx[i][2];
+            U[i][2] = XTx[i][2] - XTx[i][1];
+            U[i][3] = XTx[i][1] - XTx[i][3];
+        }
+
+        // Store the transformed matrix
+        for (int i = 0, m = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++, m++)
+            {
+                *(outptr + m*matrix_stride) = U[i][j];
+            }
+        }
+        outptr++;
+    }
+}
+
+template class InputTransform<4, 4, __fp16, __fp16, WinogradRoots::Integers>;
+
+}  // namespace winograd
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp16_fp16_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp16_fp16_integers.cpp
new file mode 100644
index 0000000..3eaf977
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp16_fp16_integers.cpp
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#include "arm.hpp"
+#include "input.hpp"
+
+namespace winograd
+{
+template <>
+void InputTransform<6, 6, __fp16, __fp16, WinogradRoots::Integers>::transform_tile(
+    const int n_channels,
+    const __fp16* const input_base,
+    const int input_row_stride,
+    const int input_col_stride,
+    __fp16* outptr,
+    const int matrix_stride
+)
+{
+    constexpr int inner_tile_rows = 6;
+    constexpr int inner_tile_cols = 6;
+
+    // Pointer to each cell of the 6x6 input tile; consecutive channels of a cell are adjacent
+    const __fp16 *x_ptrs[inner_tile_rows][inner_tile_cols];
+    for (int i = 0, xi = 0; i < inner_tile_rows; i++, xi++)
+    {
+        // Get a pointer into the row
+        const __fp16* const row_ptr = input_base + xi*input_row_stride;
+
+        for (int j = 0, xj = 0; j < inner_tile_cols; j++, xj++)
+        {
+            x_ptrs[i][j] = row_ptr + xj*input_col_stride;
+        }
+    }
+
+    // Scalar work matrices, used only by the one-channel-at-a-time tail loop below.
+    __fp16 x[inner_tile_rows][inner_tile_cols];
+    __fp16 XTx[inner_tile_rows][inner_tile_cols];
+    __fp16 U[inner_tile_rows][inner_tile_cols];
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+        for (int j = 0; j < inner_tile_cols; j++)
+        {
+            x[i][j] = XTx[i][j] = 0.0f;
+        }
+    }
+
+    // Transform channels 8 at a time, then 4 at a time, then singly;
+    // transformed cell m of each channel is written to outptr + m*matrix_stride.
+    int channels_remaining = n_channels;
+    for (; channels_remaining >= 8; channels_remaining -= 8)
+    {
+        // Vector work matrices: each of the 8 lanes holds one channel
+        float16x8_t x[inner_tile_rows][inner_tile_cols];
+        float16x8_t XTx[inner_tile_rows][inner_tile_cols];
+        float16x8_t U[inner_tile_rows][inner_tile_cols];
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++)
+            {
+                x[i][j] = vdupq_n_f16(0.0f);
+                XTx[i][j] = vdupq_n_f16(0.0f);
+            }
+        }
+
+        // Read a 6x6 tile of the untransformed input (8 channels per cell)
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++)
+            {
+                x[i][j] = vld1q_f16(x_ptrs[i][j]);
+                x_ptrs[i][j] += 8;
+            }
+        }
+
+        // Compute XT . x
+        for (int j = 0; j < inner_tile_cols; j++)
+        {
+            // XTx[0][j] =  4*x[0][j] + -5*x[2][j] +  1*x[4][j];
+            XTx[0][j] = vsubq_f16(vaddq_f16(x[4][j], vmulq_f16(x[0][j], vdupq_n_f16(4.0f))), vmulq_f16(x[2][j], vdupq_n_f16(5.0f)));
+
+            // XTx[1][j] = -4*x[1][j] + -4*x[2][j] +  1*x[3][j] +  1*x[4][j];
+            XTx[1][j] = vsubq_f16(vaddq_f16(x[3][j], x[4][j]), vmulq_f16(vaddq_f16(x[1][j], x[2][j]),  vdupq_n_f16(4.0f)));
+
+            // XTx[2][j] =  4*x[1][j] + -4*x[2][j] + -1*x[3][j] +  1*x[4][j];
+            XTx[2][j] = vaddq_f16(vsubq_f16(x[4][j], x[3][j]), vmulq_f16(vsubq_f16(x[1][j], x[2][j]), vdupq_n_f16(4.0f)));
+
+            // XTx[3][j] = -2*x[1][j] + -1*x[2][j] +  2*x[3][j] +  1*x[4][j];
+            XTx[3][j] = vaddq_f16(vsubq_f16(x[4][j], x[2][j]), vmulq_f16(vsubq_f16(x[3][j], x[1][j]), vdupq_n_f16(2.0f)));
+
+            // XTx[4][j] =  2*x[1][j] + -1*x[2][j] + -2*x[3][j] +  1*x[4][j];
+            XTx[4][j] = vaddq_f16(vsubq_f16(x[4][j], x[2][j]), vmulq_f16(vsubq_f16(x[1][j], x[3][j]), vdupq_n_f16(2.0f)));
+
+            // XTx[5][j] =  4*x[1][j] + -5*x[3][j] +  1*x[5][j];
+            XTx[5][j] = vsubq_f16(vaddq_f16(x[5][j], vmulq_f16(x[1][j], vdupq_n_f16(4.0f))), vmulq_f16(x[3][j], vdupq_n_f16(5.0f)));
+        }
+
+        // Compute U = XT . x . X
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            // U[i][0] =  4*XTx[i][0] + -5*XTx[i][2] +  1*XTx[i][4];
+            U[i][0] = vsubq_f16(vaddq_f16(XTx[i][4], vmulq_f16(XTx[i][0], vdupq_n_f16(4.0f))), vmulq_f16(XTx[i][2], vdupq_n_f16(5.0f)));
+
+            // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] +  1*XTx[i][3] +  1*XTx[i][4];
+            U[i][1] = vsubq_f16(vaddq_f16(XTx[i][3], XTx[i][4]), vmulq_f16(vaddq_f16(XTx[i][1], XTx[i][2]), vdupq_n_f16(4.0f)));
+
+            // U[i][2] =  4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] +  1*XTx[i][4];
+            U[i][2] = vaddq_f16(vsubq_f16(XTx[i][4], XTx[i][3]), vmulq_f16(vsubq_f16(XTx[i][1], XTx[i][2]), vdupq_n_f16(4.0f)));
+
+            // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] +  2*XTx[i][3] +  1*XTx[i][4];
+            U[i][3] = vaddq_f16(vsubq_f16(XTx[i][4], XTx[i][2]), vmulq_f16(vsubq_f16(XTx[i][3], XTx[i][1]), vdupq_n_f16(2.0f)));
+
+            // U[i][4] =  2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] +  1*XTx[i][4];
+            U[i][4] = vaddq_f16(vsubq_f16(XTx[i][4], XTx[i][2]), vmulq_f16(vsubq_f16(XTx[i][1], XTx[i][3]), vdupq_n_f16(2.0f)));
+
+            // U[i][5] =  4*XTx[i][1] + -5*XTx[i][3] +  1*XTx[i][5];
+            U[i][5] = vsubq_f16(vaddq_f16(XTx[i][5], vmulq_f16(XTx[i][1], vdupq_n_f16(4.0f))), vmulq_f16(XTx[i][3], vdupq_n_f16(5.0f)));
+        }
+
+        // Store the transformed matrix
+        for (int i = 0, m = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++, m++)
+            {
+                vst1q_f16(outptr + m*matrix_stride, U[i][j]);
+            }
+        }
+        outptr += 8;
+    }
+    for (; channels_remaining >= 4; channels_remaining -= 4)
+    {
+        // Vector work matrices: each of the 4 lanes holds one channel
+        float16x4_t x[inner_tile_rows][inner_tile_cols];
+        float16x4_t XTx[inner_tile_rows][inner_tile_cols];
+        float16x4_t U[inner_tile_rows][inner_tile_cols];
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++)
+            {
+                x[i][j] = vdup_n_f16(0.0f);
+                XTx[i][j] = vdup_n_f16(0.0f);
+            }
+        }
+
+        // Read a 6x6 tile of the untransformed input (4 channels per cell)
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++)
+            {
+                x[i][j] = vld1_f16(x_ptrs[i][j]);
+                x_ptrs[i][j] += 4;
+            }
+        }
+
+        // Compute XT . x
+        for (int j = 0; j < inner_tile_cols; j++)
+        {
+            // XTx[0][j] =  4*x[0][j] + -5*x[2][j] +  1*x[4][j];
+            XTx[0][j] = vsub_f16(vadd_f16(x[4][j], vmul_f16(x[0][j], vdup_n_f16(4.0f))), vmul_f16(x[2][j], vdup_n_f16(5.0f)));
+
+            // XTx[1][j] = -4*x[1][j] + -4*x[2][j] +  1*x[3][j] +  1*x[4][j];
+            XTx[1][j] = vsub_f16(vadd_f16(x[3][j], x[4][j]), vmul_f16(vadd_f16(x[1][j], x[2][j]),  vdup_n_f16(4.0f)));
+
+            // XTx[2][j] =  4*x[1][j] + -4*x[2][j] + -1*x[3][j] +  1*x[4][j];
+            XTx[2][j] = vadd_f16(vsub_f16(x[4][j], x[3][j]), vmul_f16(vsub_f16(x[1][j], x[2][j]), vdup_n_f16(4.0f)));
+
+            // XTx[3][j] = -2*x[1][j] + -1*x[2][j] +  2*x[3][j] +  1*x[4][j];
+            XTx[3][j] = vadd_f16(vsub_f16(x[4][j], x[2][j]), vmul_f16(vsub_f16(x[3][j], x[1][j]), vdup_n_f16(2.0f)));
+
+            // XTx[4][j] =  2*x[1][j] + -1*x[2][j] + -2*x[3][j] +  1*x[4][j];
+            XTx[4][j] = vadd_f16(vsub_f16(x[4][j], x[2][j]), vmul_f16(vsub_f16(x[1][j], x[3][j]), vdup_n_f16(2.0f)));
+
+            // XTx[5][j] =  4*x[1][j] + -5*x[3][j] +  1*x[5][j];
+            XTx[5][j] = vsub_f16(vadd_f16(x[5][j], vmul_f16(x[1][j], vdup_n_f16(4.0f))), vmul_f16(x[3][j], vdup_n_f16(5.0f)));
+        }
+
+        // Compute U = XT . x . X
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            // U[i][0] =  4*XTx[i][0] + -5*XTx[i][2] +  1*XTx[i][4];
+            U[i][0] = vsub_f16(vadd_f16(XTx[i][4], vmul_f16(XTx[i][0], vdup_n_f16(4.0f))), vmul_f16(XTx[i][2], vdup_n_f16(5.0f)));
+
+            // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] +  1*XTx[i][3] +  1*XTx[i][4];
+            U[i][1] = vsub_f16(vadd_f16(XTx[i][3], XTx[i][4]), vmul_f16(vadd_f16(XTx[i][1], XTx[i][2]), vdup_n_f16(4.0f)));
+
+            // U[i][2] =  4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] +  1*XTx[i][4];
+            U[i][2] = vadd_f16(vsub_f16(XTx[i][4], XTx[i][3]), vmul_f16(vsub_f16(XTx[i][1], XTx[i][2]), vdup_n_f16(4.0f)));
+
+            // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] +  2*XTx[i][3] +  1*XTx[i][4];
+            U[i][3] = vadd_f16(vsub_f16(XTx[i][4], XTx[i][2]), vmul_f16(vsub_f16(XTx[i][3], XTx[i][1]), vdup_n_f16(2.0f)));
+
+            // U[i][4] =  2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] +  1*XTx[i][4];
+            U[i][4] = vadd_f16(vsub_f16(XTx[i][4], XTx[i][2]), vmul_f16(vsub_f16(XTx[i][1], XTx[i][3]), vdup_n_f16(2.0f)));
+
+            // U[i][5] =  4*XTx[i][1] + -5*XTx[i][3] +  1*XTx[i][5];
+            U[i][5] = vsub_f16(vadd_f16(XTx[i][5], vmul_f16(XTx[i][1], vdup_n_f16(4.0f))), vmul_f16(XTx[i][3], vdup_n_f16(5.0f)));
+        }
+
+        // Store the transformed matrix
+        for (int i = 0, m = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++, m++)
+            {
+                vst1_f16(outptr + m*matrix_stride, U[i][j]);
+            }
+        }
+        outptr += 4;
+    }
+    for (; channels_remaining; channels_remaining--)
+    {
+        // Load x
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++)
+            {
+                x[i][j] = *(x_ptrs[i][j]++);
+            }
+        }
+
+        // Compute XT . x
+        for (int j = 0; j < inner_tile_cols; j++)
+        {
+            XTx[0][j] =  4*x[0][j] + -5*x[2][j] +  1*x[4][j];
+            XTx[1][j] = -4*x[1][j] + -4*x[2][j] +  1*x[3][j] +  1*x[4][j];
+            XTx[2][j] =  4*x[1][j] + -4*x[2][j] + -1*x[3][j] +  1*x[4][j];
+            XTx[3][j] = -2*x[1][j] + -1*x[2][j] +  2*x[3][j] +  1*x[4][j];
+            XTx[4][j] =  2*x[1][j] + -1*x[2][j] + -2*x[3][j] +  1*x[4][j];
+            XTx[5][j] =  4*x[1][j] + -5*x[3][j] +  1*x[5][j];
+        }
+
+        // Compute U = XT . x . X
+        for (int i = 0; i < inner_tile_rows; i++)
+        {
+            U[i][0] =  4*XTx[i][0] + -5*XTx[i][2] +  1*XTx[i][4];
+            U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] +  1*XTx[i][3] +  1*XTx[i][4];
+            U[i][2] =  4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] +  1*XTx[i][4];
+            U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] +  2*XTx[i][3] +  1*XTx[i][4];
+            U[i][4] =  2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] +  1*XTx[i][4];
+            U[i][5] =  4*XTx[i][1] + -5*XTx[i][3] +  1*XTx[i][5];
+        }
+
+        // Store the transformed matrix
+        for (int i = 0, m = 0; i < inner_tile_rows; i++)
+        {
+            for (int j = 0; j < inner_tile_cols; j++, m++)
+            {
+                *(outptr + m*matrix_stride) = U[i][j];
+            }
+        }
+        outptr++;
+    }
+}
+
+template class InputTransform<6, 6, __fp16, __fp16, WinogradRoots::Integers>; // instantiates the 6x6 inner-tile specialization above
+
+}  // namespace winograd
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
index fe47ccb..ed88098 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
@@ -48,15 +48,9 @@
       _n_channels(n_channels),
       _output_min((activation.type == arm_gemm::Activation::Type::ReLU ||
                    activation.type == arm_gemm::Activation::Type::BoundedReLU)
-                      ? static_cast<TOut>(0.0f)
-                      : (std::numeric_limits<TOut>::has_infinity)
-                            ? -std::numeric_limits<TOut>::infinity()
-                            : std::numeric_limits<TOut>::lowest()),
+                      ? static_cast<TOut>(0.0f) : TypeBounds<TOut>::lower()),
       _output_max((activation.type == arm_gemm::Activation::Type::BoundedReLU)
-                      ? static_cast<TOut>(activation.param1)
-                      : (std::numeric_limits<TOut>::has_infinity)
-                            ? std::numeric_limits<TOut>::infinity()
-                            : std::numeric_limits<TOut>::max()),
+                      ? static_cast<TOut>(activation.param1) : TypeBounds<TOut>::upper()),
       _matrix_base(nullptr), _biases(nullptr), _matrix_stride(0),
       _matrix_row_stride(0), _matrix_batch_stride(0), _outptr(nullptr),
       _tiles_M(iceildiv(n_rows, output_tile_rows)),
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp16_fp16_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp16_fp16_integers.cpp
new file mode 100644
index 0000000..37b890d
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp16_fp16_integers.cpp
@@ -0,0 +1,255 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#include "arm.hpp"
+#include "output.hpp"
+
+namespace winograd
+{
+
+template <>  // FP16 specialisation: 3x3 kernel, 6x6 Winograd tile, 4x4 output tile (f[4][4] below)
+void winograd::OutputTransform<3, 3, 6, 6, __fp16, __fp16, winograd::WinogradRoots::Integers>::transform_tile(
+    const int n_channels,  // number of channels to transform
+    const __fp16* inptr,  // first of the 36 Winograd-domain matrices; successive matrices are matrix_stride elements apart
+    const int matrix_stride,
+    const __fp16* bptr,  // per-channel bias, or nullptr for no bias
+    __fp16* const output,  // top-left cell of the output tile
+    const int output_row_stride,  // element offset between output rows
+    const int output_col_stride,  // element offset between output columns
+    const __fp16 output_min,  // lower clamp applied to every output value
+    const __fp16 output_max  // upper clamp applied to every output value
+)
+{
+    // Construct a map to the output cells: one pointer per cell of the output tile, each advanced as channels are written
+    __fp16 *outptrs[output_tile_rows][output_tile_cols];
+    for (int i = 0; i < output_tile_rows; i++)
+    {
+        for (int j = 0; j < output_tile_cols; j++)
+        {
+            outptrs[i][j] = output + i*output_row_stride + j*output_col_stride;
+        }
+    }
+
+    // For each channel of the output: 8 at a time (__aarch64__), then 4 at a time (__arm_any__), then one at a time
+    int channels_remaining = n_channels;
+
+#ifdef __aarch64__
+    for (; channels_remaining >= 8; channels_remaining -= 8)
+  {
+    // Matrices used and computed during this transform (each float16x8_t holds 8 channels)
+    float16x8_t F[6][6], FZ[6][4], f[4][4], b;
+
+    // Read a 6x6 tile in the Winograd domain; consecutive cells are matrix_stride elements apart
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        F[i][j] = vld1q_f16(inptr + m*matrix_stride);
+      }
+    }
+    inptr += 8;  // next group of 8 channels
+
+    // Compute the matrix F Z (6x4); the comment above each line gives its coefficients
+    for (int i = 0; i < 6; i++)
+    {
+      // FZ[i][0] =  1*F[i][0] +  1*F[i][1] +  1*F[i][2] +  1*F[i][3] +  1*F[i][4];
+      FZ[i][0] = vaddq_f16(vaddq_f16(vaddq_f16(F[i][0], F[i][1]), vaddq_f16(F[i][2], F[i][3])), F[i][4]);
+
+      // FZ[i][1] =  1*F[i][1] + -1*F[i][2] +  2*F[i][3] + -2*F[i][4];
+      FZ[i][1] = vaddq_f16(vsubq_f16(F[i][1], F[i][2]), vmulq_f16(vsubq_f16(F[i][3], F[i][4]), vdupq_n_f16(2.0f)));
+
+      // FZ[i][2] =  1*F[i][1] +  1*F[i][2] +  4*F[i][3] +  4*F[i][4];
+      FZ[i][2] = vaddq_f16(vaddq_f16(F[i][1], F[i][2]), vmulq_f16(vaddq_f16(F[i][3], F[i][4]), vdupq_n_f16(4.0f)));
+
+      // FZ[i][3] =  1*F[i][1] + -1*F[i][2] +  8*F[i][3] + -8*F[i][4] +  1*F[i][5];
+      FZ[i][3] = vaddq_f16(vaddq_f16(vsubq_f16(F[i][1], F[i][2]), vmulq_f16(vsubq_f16(F[i][3], F[i][4]), vdupq_n_f16(8.0f))), F[i][5]);
+    }
+
+    // Compute the output tile f = ZT F Z (4x4), applying the same coefficients down the columns of FZ
+    for (int j = 0; j < 4; j++)
+    {
+      // f[0][j] =  1*FZ[0][j] +  1*FZ[1][j] +  1*FZ[2][j] +  1*FZ[3][j] +  1*FZ[4][j];
+      f[0][j] = vaddq_f16(vaddq_f16(vaddq_f16(FZ[0][j], FZ[1][j]), vaddq_f16(FZ[2][j], FZ[3][j])), FZ[4][j]);
+
+      // f[1][j] =  1*FZ[1][j] + -1*FZ[2][j] +  2*FZ[3][j] + -2*FZ[4][j];
+      f[1][j] = vaddq_f16(vsubq_f16(FZ[1][j], FZ[2][j]), vmulq_f16(vsubq_f16(FZ[3][j], FZ[4][j]), vdupq_n_f16(2.0f)));
+
+      // f[2][j] =  1*FZ[1][j] +  1*FZ[2][j] +  4*FZ[3][j] +  4*FZ[4][j];
+      f[2][j] = vaddq_f16(vaddq_f16(FZ[1][j], FZ[2][j]), vmulq_f16(vaddq_f16(FZ[3][j], FZ[4][j]), vdupq_n_f16(4.0f)));
+
+      // f[3][j] =  1*FZ[1][j] + -1*FZ[2][j] +  8*FZ[3][j] + -8*FZ[4][j] +  1*FZ[5][j];
+      f[3][j] = vaddq_f16(vaddq_f16(vsubq_f16(FZ[1][j], FZ[2][j]), vmulq_f16(vsubq_f16(FZ[3][j], FZ[4][j]), vdupq_n_f16(8.0f))), FZ[5][j]);
+    }
+
+    // Write out the output tile: add the bias (zero when bptr is null), then clamp to [output_min, output_max]
+    if (bptr != nullptr)
+    {
+      b = vld1q_f16(bptr);
+      bptr += 8;
+    }
+    else
+    {
+      b = vdupq_n_f16(0.0f);
+    }
+    for (int i = 0; i < output_tile_rows; i++)  // NOTE(review): relies on output_tile_rows/cols <= 4 (size of f); presumably 4 here -- confirm
+    {
+      for (int j = 0; j < output_tile_cols; j++)
+      {
+        const auto y =
+            vmaxq_f16(vminq_f16(vaddq_f16(f[i][j], b), vdupq_n_f16(output_max)),
+                     vdupq_n_f16(output_min));
+        vst1q_f16(outptrs[i][j], y);
+        outptrs[i][j] += 8;
+      }
+    }
+  }
+#endif  // __aarch64__
+#ifdef __arm_any__
+    for (; channels_remaining >= 4; channels_remaining -= 4)
+  {
+    // Matrices used and computed during this transform (each float16x4_t holds 4 channels)
+    float16x4_t F[6][6], FZ[6][4], f[4][4], b;
+
+    // Read a 6x6 tile in the Winograd domain; consecutive cells are matrix_stride elements apart
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        F[i][j] = vld1_f16(inptr + m*matrix_stride);
+      }
+    }
+    inptr += 4;  // next group of 4 channels
+
+    // Compute the matrix F Z (6x4); the comment above each line gives its coefficients
+    for (int i = 0; i < 6; i++)
+    {
+      // FZ[i][0] =  1*F[i][0] +  1*F[i][1] +  1*F[i][2] +  1*F[i][3] +  1*F[i][4];
+      FZ[i][0] = vadd_f16(vadd_f16(vadd_f16(F[i][0], F[i][1]), vadd_f16(F[i][2], F[i][3])), F[i][4]);
+
+      // FZ[i][1] =  1*F[i][1] + -1*F[i][2] +  2*F[i][3] + -2*F[i][4];
+      FZ[i][1] = vadd_f16(vsub_f16(F[i][1], F[i][2]), vmul_f16(vsub_f16(F[i][3], F[i][4]), vdup_n_f16(2.0f)));
+
+      // FZ[i][2] =  1*F[i][1] +  1*F[i][2] +  4*F[i][3] +  4*F[i][4];
+      FZ[i][2] = vadd_f16(vadd_f16(F[i][1], F[i][2]), vmul_f16(vadd_f16(F[i][3], F[i][4]), vdup_n_f16(4.0f)));
+
+      // FZ[i][3] =  1*F[i][1] + -1*F[i][2] +  8*F[i][3] + -8*F[i][4] +  1*F[i][5];
+      FZ[i][3] = vadd_f16(vadd_f16(vsub_f16(F[i][1], F[i][2]), vmul_f16(vsub_f16(F[i][3], F[i][4]), vdup_n_f16(8.0f))), F[i][5]);
+    }
+
+    // Compute the output tile f = ZT F Z (4x4), applying the same coefficients down the columns of FZ
+    for (int j = 0; j < 4; j++)
+    {
+      // f[0][j] =  1*FZ[0][j] +  1*FZ[1][j] +  1*FZ[2][j] +  1*FZ[3][j] +  1*FZ[4][j];
+      f[0][j] = vadd_f16(vadd_f16(vadd_f16(FZ[0][j], FZ[1][j]), vadd_f16(FZ[2][j], FZ[3][j])), FZ[4][j]);
+
+      // f[1][j] =  1*FZ[1][j] + -1*FZ[2][j] +  2*FZ[3][j] + -2*FZ[4][j];
+      f[1][j] = vadd_f16(vsub_f16(FZ[1][j], FZ[2][j]), vmul_f16(vsub_f16(FZ[3][j], FZ[4][j]), vdup_n_f16(2.0f)));
+
+      // f[2][j] =  1*FZ[1][j] +  1*FZ[2][j] +  4*FZ[3][j] +  4*FZ[4][j];
+      f[2][j] = vadd_f16(vadd_f16(FZ[1][j], FZ[2][j]), vmul_f16(vadd_f16(FZ[3][j], FZ[4][j]), vdup_n_f16(4.0f)));
+
+      // f[3][j] =  1*FZ[1][j] + -1*FZ[2][j] +  8*FZ[3][j] + -8*FZ[4][j] +  1*FZ[5][j];
+      f[3][j] = vadd_f16(vadd_f16(vsub_f16(FZ[1][j], FZ[2][j]), vmul_f16(vsub_f16(FZ[3][j], FZ[4][j]), vdup_n_f16(8.0f))), FZ[5][j]);
+    }
+
+    // Write out the output tile: add the bias (zero when bptr is null), then clamp to [output_min, output_max]
+    if (bptr != nullptr)
+    {
+      b = vld1_f16(bptr);
+      bptr += 4;
+    }
+    else
+    {
+      b = vdup_n_f16(0.0f);
+    }
+    for (int i = 0; i < output_tile_rows; i++)
+    {
+      for (int j = 0; j < output_tile_cols; j++)
+      {
+        const auto y =
+            vmax_f16(vmin_f16(vadd_f16(f[i][j], b), vdup_n_f16(output_max)),
+                     vdup_n_f16(output_min));
+        vst1_f16(outptrs[i][j], y);
+        outptrs[i][j] += 4;
+      }
+    }
+  }
+#endif  // __arm_any__
+    for (; channels_remaining; channels_remaining--)
+    {
+        // Matrices used and computed during this transform (scalar tail: one channel per iteration)
+        __fp16 F[6][6], FZ[6][4], f[4][4], b;
+
+        // Read a 6x6 tile in the Winograd domain; consecutive cells are matrix_stride elements apart
+        for (int i = 0, m = 0; i < 6; i++)
+        {
+            for (int j = 0; j < 6; j++, m++)
+            {
+                F[i][j] = *(inptr + m*matrix_stride);
+            }
+        }
+        inptr++;
+
+        // Compute the matrix F Z (scalar __fp16 arithmetic is presumably promoted to float, so results may differ slightly from the vector paths -- confirm)
+        for (int i = 0; i < 6; i++)
+        {
+            FZ[i][0] =  1*F[i][0] +  1*F[i][1] +  1*F[i][2] +  1*F[i][3] +  1*F[i][4];
+            FZ[i][1] =  1*F[i][1] + -1*F[i][2] +  2*F[i][3] + -2*F[i][4];
+            FZ[i][2] =  1*F[i][1] +  1*F[i][2] +  4*F[i][3] +  4*F[i][4];
+            FZ[i][3] =  1*F[i][1] + -1*F[i][2] +  8*F[i][3] + -8*F[i][4] +  1*F[i][5];
+        }
+
+        // Compute the output tile f = ZT F Z (4x4)
+        for (int j = 0; j < 4; j++)
+        {
+            f[0][j] =  1*FZ[0][j] +  1*FZ[1][j] +  1*FZ[2][j] +  1*FZ[3][j] +  1*FZ[4][j];
+            f[1][j] =  1*FZ[1][j] + -1*FZ[2][j] +  2*FZ[3][j] + -2*FZ[4][j];
+            f[2][j] =  1*FZ[1][j] +  1*FZ[2][j] +  4*FZ[3][j] +  4*FZ[4][j];
+            f[3][j] =  1*FZ[1][j] + -1*FZ[2][j] +  8*FZ[3][j] + -8*FZ[4][j] +  1*FZ[5][j];
+        }
+
+        // Write out the output tile: add the bias (zero when bptr is null), then clamp to [output_min, output_max]
+        if (bptr != nullptr)
+        {
+            b = *(bptr++);
+        }
+        else
+        {
+            b = 0.0f;
+        }
+        for (int i = 0; i < output_tile_rows; i++)
+        {
+            for (int j = 0; j < output_tile_cols; j++)
+            {
+                const auto y = std::max(std::min<__fp16>(f[i][j] + b, output_max), output_min);
+                *(outptrs[i][j]++) = y;
+            }
+        }
+    }
+}
+
+template class OutputTransform<3, 3, 6, 6, __fp16, __fp16, winograd::WinogradRoots::Integers>;  // explicit instantiation of the FP16 specialisation above
+
+}  // namespace winograd
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp16_fp16_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp16_fp16_integers.cpp
new file mode 100644
index 0000000..3c4f8b4
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp16_fp16_integers.cpp
@@ -0,0 +1,259 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+#include "arm.hpp"
+#include "kernel.hpp"
+
+namespace winograd
+{
+
+template <>  // FP16 specialisation: 3x3 kernel transformed to 6x6 Winograd-domain weights
+void WeightTransform<3, 3, 6, 6, __fp16, __fp16, WinogradRoots::Integers>::execute(
+    const int n_output_channels,
+    const int n_input_channels,
+    const __fp16* const input,  // NOTE: Data in HWIO order (3 x 3 x n_input_channels x n_output_channels)
+    __fp16* const output,  // 36 transformed matrices, matrix_stride elements apart
+    const int matrix_stride,
+    const int matrix_row_stride  // element offset between input-channel rows within each matrix
+)
+{
+    // Get pointers to each of the 3x3 cells of the weight tensor (each cell holds n_input_channels * n_output_channels values)
+    const auto weight_col_stride = n_input_channels * n_output_channels;
+    const auto weight_row_stride = 3 * weight_col_stride;
+    const __fp16 *inptrs[3][3];
+    for (int i = 0; i < 3; i++)
+    {
+        for (int j = 0; j < 3; j++)
+        {
+            inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride;
+        }
+    }
+
+    // For each input channel (inptrs are never reset: with HWIO data each pointer advances n_output_channels per input channel, landing on the next one)
+    for (int ic = 0; ic < n_input_channels; ic++)
+    {
+        __fp16 *outptr = output + ic * matrix_row_stride;
+
+        // For each output channel: 8 at a time (__aarch64__), then 4 at a time (__arm_any__), then one at a time
+        int channels_remaining = n_output_channels;
+#ifdef __aarch64__
+    for (; channels_remaining >= 8; channels_remaining -= 8)
+    {
+      // Matrices used and computed in this kernel (each float16x8_t holds 8 output channels)
+      float16x8_t w[3][3], Ww[6][3], V[6][6];
+
+      // Read weights
+      for (int i = 0; i < 3; i++)
+      {
+        for (int j = 0; j < 3; j++)
+        {
+          w[i][j] = vld1q_f16(inptrs[i][j]);
+          inptrs[i][j] += 8;
+        }
+      }
+
+      // Compute the matrix W w (6x3); W is the integer-valued 6x3 transform spelled out in the per-line comments
+      for (int j = 0; j < 3; j++)
+      {
+        // Ww[0][j] =  6*w[0][j];
+        Ww[0][j] = vmulq_n_f16(w[0][j], 6.0);
+
+        // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
+        Ww[1][j] = vmulq_n_f16(vaddq_f16(vaddq_f16(w[0][j], w[1][j]), w[2][j]), -4.0);
+
+        // Ww[2][j] = -4*w[0][j] +  4*w[1][j] + -4*w[2][j];
+        Ww[2][j] = vmulq_n_f16(vsubq_f16(vsubq_f16(w[1][j], w[0][j]), w[2][j]), 4.0);
+
+        // Ww[3][j] =  1*w[0][j] +  2*w[1][j] +  4*w[2][j];
+        Ww[3][j] = vaddq_f16(vaddq_f16(w[0][j], vmulq_f16(w[1][j], vdupq_n_f16(2.0f))), vmulq_f16(w[2][j], vdupq_n_f16(4.0f)));
+
+        // Ww[4][j] =  1*w[0][j] + -2*w[1][j] +  4*w[2][j];
+        Ww[4][j] = vaddq_f16(vsubq_f16(w[0][j], vmulq_f16(w[1][j], vdupq_n_f16(2.0f))), vmulq_f16(w[2][j], vdupq_n_f16(4.0f)));
+
+        // Ww[5][j] = 24*w[2][j];
+        Ww[5][j] = vmulq_n_f16(w[2][j], 24.0f);
+      }
+
+      // Compute V = W w WT / 576 (576 = 24^2). NOTE(review): the unscaled fp16 product reaches 576*w (e.g. V[5][5]), so |w| above ~113 would overflow -- confirm weights stay in range
+      for (int i = 0; i < 6; i++)
+      {
+        const float recip576 = 1.0f / 576.0f;
+
+        // V[i][0] =  6*Ww[i][0];
+        V[i][0] = vmulq_n_f16(vmulq_n_f16(Ww[i][0], 6.0), recip576);
+
+        // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
+        V[i][1] = vmulq_n_f16(vmulq_n_f16(vaddq_f16(vaddq_f16(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);
+
+        // V[i][2] = -4*Ww[i][0] +  4*Ww[i][1] + -4*Ww[i][2];
+        V[i][2] = vmulq_n_f16(vmulq_n_f16(vsubq_f16(vsubq_f16(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);
+
+        // V[i][3] =  1*Ww[i][0] +  2*Ww[i][1] +  4*Ww[i][2];
+        V[i][3] = vmulq_n_f16(vaddq_f16(vaddq_f16(Ww[i][0], vmulq_f16(Ww[i][1], vdupq_n_f16(2.0f))), vmulq_f16(Ww[i][2], vdupq_n_f16(4.0f))), recip576);
+
+        // V[i][4] =  1*Ww[i][0] + -2*Ww[i][1] +  4*Ww[i][2];
+        V[i][4] = vmulq_n_f16(vaddq_f16(vsubq_f16(Ww[i][0], vmulq_f16(Ww[i][1], vdupq_n_f16(2.0f))), vmulq_f16(Ww[i][2], vdupq_n_f16(4.0f))), recip576);
+
+        // V[i][5] = 24*Ww[i][2];
+        V[i][5] = vmulq_n_f16(vmulq_n_f16(Ww[i][2], 24.0f), recip576);
+      }
+
+      // Store the transformed weights: one value per 6x6 cell, matrix_stride elements apart
+      for (int i = 0, m = 0; i < 6; i++)
+      {
+        for (int j = 0; j < 6; j++, m++)
+        {
+          vst1q_f16(outptr + m*matrix_stride, V[i][j]);
+        }
+      }
+      outptr += 8;
+    }
+#endif  // __aarch64__
+#ifdef __arm_any__
+        for (; channels_remaining >= 4; channels_remaining -= 4)
+    {
+      // Matrices used and computed in this kernel (each float16x4_t holds 4 output channels)
+      float16x4_t w[3][3], Ww[6][3], V[6][6];
+
+      // Read weights
+      for (int i = 0; i < 3; i++)
+      {
+        for (int j = 0; j < 3; j++)
+        {
+          w[i][j] = vld1_f16(inptrs[i][j]);
+          inptrs[i][j] += 4;
+        }
+      }
+
+      // Compute the matrix W w (6x3); W is the integer-valued 6x3 transform spelled out in the per-line comments
+      for (int j = 0; j < 3; j++)
+      {
+        // Ww[0][j] =  6*w[0][j];
+        Ww[0][j] = vmul_n_f16(w[0][j], 6.0);
+
+        // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
+        Ww[1][j] = vmul_n_f16(vadd_f16(vadd_f16(w[0][j], w[1][j]), w[2][j]), -4.0);
+
+        // Ww[2][j] = -4*w[0][j] +  4*w[1][j] + -4*w[2][j];
+        Ww[2][j] = vmul_n_f16(vsub_f16(vsub_f16(w[1][j], w[0][j]), w[2][j]), 4.0);
+
+        // Ww[3][j] =  1*w[0][j] +  2*w[1][j] +  4*w[2][j];
+        Ww[3][j] = vadd_f16(vadd_f16(w[0][j], vmul_f16(w[1][j], vdup_n_f16(2.0f))), vmul_f16(w[2][j], vdup_n_f16(4.0f)));
+
+        // Ww[4][j] =  1*w[0][j] + -2*w[1][j] +  4*w[2][j];
+        Ww[4][j] = vadd_f16(vsub_f16(w[0][j], vmul_f16(w[1][j], vdup_n_f16(2.0f))), vmul_f16(w[2][j], vdup_n_f16(4.0f)));
+
+        // Ww[5][j] = 24*w[2][j];
+        Ww[5][j] = vmul_n_f16(w[2][j], 24.0f);
+      }
+
+      // Compute V = W w WT / 576 (576 = 24^2); same fp16 overflow caveat as the 8-wide path
+      for (int i = 0; i < 6; i++)
+      {
+        const float recip576 = 1.0f / 576.0f;
+
+        // V[i][0] =  6*Ww[i][0];
+        V[i][0] = vmul_n_f16(vmul_n_f16(Ww[i][0], 6.0), recip576);
+
+        // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
+        V[i][1] = vmul_n_f16(vmul_n_f16(vadd_f16(vadd_f16(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);
+
+        // V[i][2] = -4*Ww[i][0] +  4*Ww[i][1] + -4*Ww[i][2];
+        V[i][2] = vmul_n_f16(vmul_n_f16(vsub_f16(vsub_f16(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);
+
+        // V[i][3] =  1*Ww[i][0] +  2*Ww[i][1] +  4*Ww[i][2];
+        V[i][3] = vmul_n_f16(vadd_f16(vadd_f16(Ww[i][0], vmul_f16(Ww[i][1], vdup_n_f16(2.0f))), vmul_f16(Ww[i][2], vdup_n_f16(4.0f))), recip576);
+
+        // V[i][4] =  1*Ww[i][0] + -2*Ww[i][1] +  4*Ww[i][2];
+        V[i][4] = vmul_n_f16(vadd_f16(vsub_f16(Ww[i][0], vmul_f16(Ww[i][1], vdup_n_f16(2.0f))), vmul_f16(Ww[i][2], vdup_n_f16(4.0f))), recip576);
+
+        // V[i][5] = 24*Ww[i][2];
+        V[i][5] = vmul_n_f16(vmul_n_f16(Ww[i][2], 24.0f), recip576);
+      }
+
+      // Store the transformed weights: one value per 6x6 cell, matrix_stride elements apart
+      for (int i = 0, m = 0; i < 6; i++)
+      {
+        for (int j = 0; j < 6; j++, m++)
+        {
+          vst1_f16(outptr + m*matrix_stride, V[i][j]);
+        }
+      }
+      outptr += 4;
+    }
+#endif  // __arm_any__
+        for (; channels_remaining; channels_remaining--)
+        {
+            // Matrices used and computed in this kernel (scalar tail: one output channel per iteration)
+            __fp16 w[3][3], Ww[6][3], V[6][6];
+
+            // Read weights
+            for (int i = 0; i < 3; i++)
+            {
+                for (int j = 0; j < 3; j++)
+                {
+                    w[i][j] = *(inptrs[i][j]++);
+                }
+            }
+
+            // Compute the matrix W w (6x3)
+            for (int j = 0; j < 3; j++)
+            {
+                Ww[0][j] =  6*w[0][j];
+                Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
+                Ww[2][j] = -4*w[0][j] +  4*w[1][j] + -4*w[2][j];
+                Ww[3][j] =  1*w[0][j] +  2*w[1][j] +  4*w[2][j];
+                Ww[4][j] =  1*w[0][j] + -2*w[1][j] +  4*w[2][j];
+                Ww[5][j] = 24*w[2][j];
+            }
+
+            // Compute V = W w WT / 576 (scalar __fp16 arithmetic is presumably promoted to wider precision here, unlike the fp16 vector paths -- confirm)
+            for (int i = 0; i < 6; i++)
+            {
+                V[i][0] = ( 6*Ww[i][0]) / 576.0;
+                V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
+                V[i][2] = (-4*Ww[i][0] +  4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
+                V[i][3] = ( 1*Ww[i][0] +  2*Ww[i][1] +  4*Ww[i][2]) / 576.0;
+                V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] +  4*Ww[i][2]) / 576.0;
+                V[i][5] = (24*Ww[i][2]) / 576.0;
+            }
+
+            // Store the transformed weights: one value per 6x6 cell, matrix_stride elements apart
+            for (int i = 0, m = 0; i < 6; i++)
+            {
+                for (int j = 0; j < 6; j++, m++)
+                {
+                    *(outptr + m*matrix_stride) = V[i][j];
+                }
+            }
+            outptr++;
+        }
+    }
+}
+
+template class WeightTransform<3, 3, 6, 6, __fp16, __fp16, WinogradRoots::Integers>;  // explicit instantiation of the FP16 specialisation above
+
+}  // namespace winograd
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
index 81190fb..d567a18 100644
--- a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
@@ -23,11 +23,11 @@
  */
 #include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h"
 
+#include "arm_compute/core/CPP/Validate.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
-#include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
@@ -43,18 +43,32 @@
 inline Status validate_kernel_3x3(const Size2D input_dims, const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
-    if(input_dims.width > 4 && input_dims.height > 4)
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+    // F32 picks a 4x4 or 2x2 output tile from the input size; F16 only has the 4x4 tile (as in winograd_output_tile)
+    if(input->data_type() == DataType::F32)
     {
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 4, 4, 3, 3>::validate(input, input0, winograd_info)));
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 4, 4, 3, 3>::validate(weights, input1, winograd_info)));
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 4, 4, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
+        if(input_dims.width > 4 && input_dims.height > 4)
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 4, 4, 3, 3>::validate(input, input0, winograd_info)));
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 4, 4, 3, 3>::validate(weights, input1, winograd_info)));
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 4, 4, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 2, 2, 3, 3>::validate(input, input0, winograd_info)));
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 2, 2, 3, 3>::validate(weights, input1, winograd_info)));
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 2, 2, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
+        }
     }
-    else
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    else if(input->data_type() == DataType::F16) // FP16 kernels only exist when FP16 vector arithmetic is compiled in
     {
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 2, 2, 3, 3>::validate(input, input0, winograd_info)));
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 2, 2, 3, 3>::validate(weights, input1, winograd_info)));
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 2, 2, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
+        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<__fp16, 4, 4, 3, 3>::validate(input, input0, winograd_info)));
+        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<__fp16, 4, 4, 3, 3>::validate(weights, input1, winograd_info)));
+        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<__fp16, 4, 4, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
     }
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
     if(act_info.enabled())
     {
@@ -79,6 +93,7 @@
 inline Status validate_kernel_3x1(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); // only F32 kernels are provided for this kernel shape
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 1, 6, 1, 3>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 1, 6, 1, 3>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 1, 6, 1, 3>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -92,6 +107,7 @@
 inline Status validate_kernel_1x3(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); // only F32 kernels are provided for this kernel shape
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 6, 1, 3, 1>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 6, 1, 3, 1>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 6, 1, 3, 1>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -106,6 +122,7 @@
 inline Status validate_kernel_5x1(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); // only F32 kernels are provided for this kernel shape
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 1, 4, 1, 5>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 1, 4, 1, 5>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 1, 4, 1, 5>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -118,6 +135,7 @@
 inline Status validate_kernel_1x5(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); // only F32 kernels are provided for this kernel shape
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 4, 1, 5, 1>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 4, 1, 5, 1>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 4, 1, 5, 1>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -131,6 +149,7 @@
 inline Status validate_kernel_7x1(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); // only F32 kernels are provided for this kernel shape
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 1, 2, 1, 7>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 1, 2, 1, 7>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 1, 2, 1, 7>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -144,6 +163,7 @@
 inline Status validate_kernel_1x7(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); // only F32 kernels are provided for this kernel shape
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 2, 1, 7, 1>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 2, 1, 7, 1>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 2, 1, 7, 1>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -169,21 +189,27 @@
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info)
 {
     ARM_COMPUTE_UNUSED(output);
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); // presumably rejects F16 when FP16 vector arithmetic is unavailable (macro from CPP/Validate.h) -- confirm
+    // Winograd is only implemented for unit strides
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd layer only supports unit strides.");
     if(biases != nullptr)
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
         ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
     }
-    return INEWinogradLayerTransformWeightsKernel<float>::validate(input, weights);
+    return INEWinogradLayerTransformWeightsKernel::validate(input, weights); // base-class check, no longer templated on float
 }
 
-Size2D winograd_output_tile(const Size2D &input_dims, const Size2D &kernel_dims)
+Size2D winograd_output_tile(const Size2D &input_dims, const Size2D &kernel_dims, DataType data_type) // data_type only affects the 3x3 choice below
 {
     Size2D output_tile = Size2D{};
     if(kernel_dims == Size2D(3U, 3U))
     {
         output_tile = (input_dims.width <= 4 || input_dims.height <= 4) ? Size2D(2U, 2U) : Size2D(4U, 4U);
+        if(data_type == DataType::F16) // F16 only has the 4x4-output 3x3 kernel (see validate_kernel_3x3), so the small-input 2x2 choice is overridden
+        {
+            output_tile = Size2D(4U, 4U);
+        }
     }
     else if(kernel_dims == Size2D(5U, 5U))
     {
@@ -216,12 +242,17 @@
     return output_tile;
 }
 
-bool check_support_fast_math(const Size2D &output_tile, const Size2D &kernel_size)
+bool check_support_fast_math(const Size2D &output_tile, const Size2D &kernel_size, DataType data_type) // data_type selects which fast-math-only table applies
 {
     // Check if we want to configure a Winograd configuration which requires fast math
     using WinogradConfiguration = std::pair<std::pair<int, int>, std::pair<int, int>>;
 
-    const std::vector<WinogradConfiguration> fast_math_winograd =
+    const std::vector<WinogradConfiguration> fast_math_winograd_f16 = // F16: F(4x4, 3x3) requires fast math
+    {
+        WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(3, 3))
+    };
+    // F32 configurations that require fast math
+    const std::vector<WinogradConfiguration> fast_math_winograd_f32 =
     {
         WinogradConfiguration(std::pair<int, int>(2, 2), std::pair<int, int>(5, 5)),
         WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(5, 5))
@@ -230,7 +261,15 @@
     auto p = std::make_pair(std::pair<int, int>(output_tile.width, output_tile.height),
                             std::pair<int, int>(kernel_size.width, kernel_size.height));
 
-    return std::find(fast_math_winograd.begin(), fast_math_winograd.end(), p) != fast_math_winograd.end();
+    switch(data_type)
+    {
+        case DataType::F16:
+            return std::find(fast_math_winograd_f16.begin(), fast_math_winograd_f16.end(), p) != fast_math_winograd_f16.end();
+        case DataType::F32:
+            return std::find(fast_math_winograd_f32.begin(), fast_math_winograd_f32.end(), p) != fast_math_winograd_f32.end();
+        default: // other data types have no fast-math-only configurations
+            return false;
+    }
 }
 
 inline bool fuse_function_supported(const ActivationLayerInfo &act_info)
@@ -256,7 +295,6 @@
         }
     }
 }
-
 } //namespace
 
 NEWinogradConvolutionLayer::NEWinogradConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager)
@@ -278,14 +316,16 @@
     const unsigned int height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
     const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
 
-    const Size2D input_dims  = Size2D(input->info()->dimension(width_idx), input->info()->dimension(height_idx));
-    const Size2D kernel_size = Size2D(weights->info()->dimension(width_idx), weights->info()->dimension(height_idx));
-    const Size2D output_tile = winograd_output_tile(input_dims, kernel_size);
+    const Size2D   input_dims  = Size2D(input->info()->dimension(width_idx), input->info()->dimension(height_idx));
+    const Size2D   kernel_size = Size2D(weights->info()->dimension(width_idx), weights->info()->dimension(height_idx));
+    const DataType data_type   = input->info()->data_type();
+    const Size2D   output_tile = winograd_output_tile(input_dims, kernel_size, data_type);
 
     // Check if the Winograd configuration requires fast math
     if(!enable_fast_math)
     {
-        ARM_COMPUTE_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size), "This Winograd configuration requires enable_fast_math=true");
+        ARM_COMPUTE_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size, data_type),
+                                 "This Winograd configuration requires enable_fast_math=true");
     }
 
     _weights     = weights;
@@ -293,18 +333,93 @@
     _output      = output;
     _is_prepared = false;
 
-    std::unique_ptr<INEWinogradLayerTransformInputKernel<float>>   transform_input_kernel;
-    std::unique_ptr<INEWinogradLayerTransformWeightsKernel<float>> transform_weights_kernel;
-    std::unique_ptr<INEWinogradLayerTransformOutputKernel<float>>  transform_output_kernel;
-
     int n_gemms = 0;
     int N_BLOCK = 0; // Size of block used by GEMM.
 
-    if(kernel_size == Size2D(3, 3))
+    std::unique_ptr<INEWinogradLayerTransformInputKernel>   transform_input_kernel;
+    std::unique_ptr<INEWinogradLayerTransformWeightsKernel> transform_weights_kernel;
+    std::unique_ptr<INEWinogradLayerTransformOutputKernel>  transform_output_kernel;
+
+    if(data_type == DataType::F32)
     {
-        if(input->info()->dimension(width_idx) > 4 && input->info()->dimension(height_idx) > 4)
+        if(kernel_size == Size2D(3, 3))
         {
-            using config             = NEWinogradLayerConfiguration<float, float, 4, 4, 3, 3>;
+            if(input->info()->dimension(width_idx) > 4 && input->info()->dimension(height_idx) > 4)
+            {
+                using config             = NEWinogradLayerConfiguration<float, float, 4, 4, 3, 3>;
+                transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+                transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+                transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+                n_gemms                  = config::WinogradBase::N_GEMMS;
+                N_BLOCK                  = config::WinogradConv::N_BLOCK;
+            }
+            else
+            {
+                using config             = NEWinogradLayerConfiguration<float, float, 2, 2, 3, 3>;
+                transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+                transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+                transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+                n_gemms                  = config::WinogradBase::N_GEMMS;
+                N_BLOCK                  = config::WinogradConv::N_BLOCK;
+            }
+        }
+        else if(kernel_size == Size2D(5, 5))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 2, 2, 5, 5>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(1, 3))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 6, 1, 3, 1>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(3, 1))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 1, 6, 1, 3>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(1, 5))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 4, 1, 5, 1>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(5, 1))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 1, 4, 1, 5>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(1, 7))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 2, 1, 7, 1>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(7, 1))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 1, 2, 1, 7>;
             transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
             transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
             transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
@@ -313,81 +428,27 @@
         }
         else
         {
-            using config             = NEWinogradLayerConfiguration<float, float, 2, 2, 3, 3>;
+            ARM_COMPUTE_ERROR("Not supported.");
+        }
+    }
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    else if(data_type == DataType::F16)
+    {
+        if(kernel_size == Size2D(3, 3))
+        {
+            using config             = NEWinogradLayerConfiguration<__fp16, __fp16, 4, 4, 3, 3>;
             transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
             transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
             transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
             n_gemms                  = config::WinogradBase::N_GEMMS;
             N_BLOCK                  = config::WinogradConv::N_BLOCK;
         }
+        else
+        {
+            ARM_COMPUTE_ERROR("Not supported.");
+        }
     }
-    else if(kernel_size == Size2D(5, 5))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 2, 2, 5, 5>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(1, 3))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 6, 1, 3, 1>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(3, 1))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 1, 6, 1, 3>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(1, 5))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 4, 1, 5, 1>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(5, 1))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 1, 4, 1, 5>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(1, 7))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 2, 1, 7, 1>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(7, 1))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 1, 2, 1, 7>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else
-    {
-        ARM_COMPUTE_ERROR("Not supported.");
-    }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
     const PaddingType use_padding_type = (conv_info.pad_top() != 0u || conv_info.pad_left() != 0) ? PADDING_SAME : PADDING_VALID;
     const bool        use_same_padding = use_padding_type == PADDING_SAME;
@@ -397,7 +458,6 @@
     const int out_channels = output->info()->dimension(channel_idx);
 
     const Tensor4DShape in_shape(internal_get_input_shape(input));
-    const DataType      data_type      = input->info()->data_type();
     const size_t        data_type_size = input->info()->element_size();
     // Get the memory required to instantiate a new Winograd operator.
     constexpr size_t storage_alignment = 64;
@@ -592,14 +652,16 @@
     const size_t idx_height = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
 
     // Input shape, kernel size and output tile
-    const Size2D input_dims  = Size2D(input->dimension(idx_width), input->dimension(idx_height));
-    const Size2D kernel_size = Size2D(weights->dimension(idx_width), weights->dimension(idx_height));
-    const Size2D output_tile = winograd_output_tile(input_dims, kernel_size);
+    const Size2D   input_dims  = Size2D(input->dimension(idx_width), input->dimension(idx_height));
+    const Size2D   kernel_size = Size2D(weights->dimension(idx_width), weights->dimension(idx_height));
+    const DataType data_type   = input->data_type();
+    const Size2D   output_tile = winograd_output_tile(input_dims, kernel_size, data_type);
 
     // Check if the Winograd configuration requires fast math
     if(!enable_fast_math)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size), "This Winograd configuration requires enable_fast_math=true");
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size, data_type),
+                                        "This Winograd configuration requires enable_fast_math=true");
     }
 
     const WinogradInfo winograd_info = WinogradInfo(output_tile,
@@ -706,5 +768,4 @@
         _is_prepared = true;
     }
 }
-
 } // namespace arm_compute