COMPMID-1532: Add DepthwiseConvolution3x3 FP16 on NEON

Change-Id: I780970f317b979b3230e2b471ac01df7fda9ee14
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/148168
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h b/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h
index 3ffafd8..64f10b4 100644
--- a/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h
+++ b/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h
@@ -55,7 +55,7 @@
      *
      * @note Supported data layouts: NCHW and NHWC
      *
-     * @param[in]  input            Source tensor. DataType supported: QASYMM8, F32.
+     * @param[in]  input            Source tensor. Data type supported: QASYMM8/F16/F32.
      * @param[in]  weights          Weights tensor. This is a 3D tensor with dimensions [3, 3, IFM] for NCHW or [IFM, 3, 3] if NHWC data layout. Data type supported: Same as @p input.
      * @param[out] output           Destination tensor. Data type supported: Same as @p input.
      * @param[in]  conv_info        Padding and stride information to use for the convolution.
@@ -81,7 +81,7 @@
      *
      * @note Supported data layouts: NCHW and NHWC
      *
-     * @param[in] input            Source tensor. DataType supported: QASYMM8, F32.
+     * @param[in] input            Source tensor. Data type supported: QASYMM8/F16/F32.
      * @param[in] weights          Weights tensor. This is a 3D tensor with dimensions [3, 3, IFM] for NCHW or [IFM, 3, 3] if NHWC data layout. Data type supported: Same as @p input.
      * @param[in] output           Destination tensor. Data type supported: Same as @p input.
      * @param[in] conv_info        Padding and stride information to use for the convolution.
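For orientation, a minimal sketch of the weight tensors the updated doc comments describe, now also valid in FP16 (IFM = 16 is an arbitrary illustrative value):

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"

using namespace arm_compute;

// NCHW layout: [width, height, IFM]
const TensorInfo weights_nchw(TensorShape(3U, 3U, 16U), 1, DataType::F16);
// NHWC layout: [IFM, width, height]
const TensorInfo weights_nhwc(TensorShape(16U, 3U, 3U), 1, DataType::F16);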
diff --git a/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h b/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
index b245505..e6dc43a 100644
--- a/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
+++ b/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
@@ -374,8 +374,9 @@
  *
  * @return The loaded matrix.
  */
-inline float16x8x3_t load_matrix_row(const float16_t *ptr)
+inline float16x8x3_t load_matrix_row(const float16_t *ptr, int weights_offset = 0)
 {
+    ARM_COMPUTE_UNUSED(weights_offset);
     /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
        r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
     const float16x8x3_t r =
@@ -400,11 +401,16 @@
  *
  */
 template <unsigned int stridex>
-float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2);
+float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
+                           const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+                           int input_offset = 0);
 
 template <>
-inline float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2)
+inline float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
+                                     const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+                                     int input_offset)
 {
+    ARM_COMPUTE_UNUSED(input_offset);
     const float16x8x3_t vtop =
     {
         {
@@ -456,8 +462,11 @@
 }
 
 template <>
-inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2)
+inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
+                                     const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+                                     int input_offset)
 {
+    ARM_COMPUTE_UNUSED(input_offset);
     float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
     out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
     out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
@@ -470,8 +479,11 @@
 }
 
 template <>
-inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2)
+inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
+                                     const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+                                     int input_offset)
 {
+    ARM_COMPUTE_UNUSED(input_offset);
     float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
     out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
     out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
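These specialisations mirror the existing F32 helpers; the new weights_offset/input_offset parameters exist only so the FP16 and QASYMM8 call sites share one shape, and both default to 0. A hedged sketch of how the helpers compose, assuming the usual arm_compute::detail namespace and input rows with at least 24 readable elements (matching the kernel's read window):

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#include "arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"

// Compute 16 contiguous FP16 outputs of a 3x3, stride-1 convolution.
void convolve_one_row_fp16(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                           const float16_t *w_row0, const float16_t *w_row1, const float16_t *w_row2,
                           float16_t *out_ptr)
{
    using namespace arm_compute::detail;

    // Each weight row becomes three vectors, one broadcast value per vector.
    const float16x8x3_t m0 = load_matrix_row(w_row0);
    const float16x8x3_t m1 = load_matrix_row(w_row1);
    const float16x8x3_t m2 = load_matrix_row(w_row2);

    const float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
    vst1q_f16(out_ptr, out.val[0]);
    vst1q_f16(out_ptr + 8, out.val[1]);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC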
diff --git a/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h
index 2f000fe..b7398f6 100644
--- a/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h
@@ -55,7 +55,7 @@
     NEDepthwiseConvolutionLayer3x3();
     /** Initialize the function's source, destination, kernels and border_size.
      *
-     * @param[in, out] input            Source tensor. Data type supported: QASYMM8/F32. (Written to only for border filling).
+     * @param[in, out] input            Source tensor. Data type supported: QASYMM8/F16/F32. (Written to only for border filling).
      * @param[in]      weights          Weights tensor. These are 3D tensors with shape [3, 3, IFM]. Data type supported: Same as @p input.
      * @param[in]      biases           (Optional) Biases tensor. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
      *                                  Data type supported: Same as @p input.
@@ -67,7 +67,7 @@
 
     /** Static function to check if given info will lead to a valid configuration of @ref NEDepthwiseConvolutionLayer3x3
      *
-     * @param[in] input            Source tensor. Data type supported: QASYMM8/F32. (Written to only for border filling).
+     * @param[in] input            Source tensor. Data type supported: QASYMM8/F16/F32. (Written to only for border filling).
      * @param[in] weights          Weights tensor. These are 3D tensors with shape [3, 3, IFM]. Data type supported: Same as @p input.
      * @param[in] biases           (Optional) Biases tensor. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
      *                             Data type supported: Same as @p input.
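With the restriction lifted, configuring the runtime function for FP16 looks the same as for FP32. A hedged end-to-end sketch (shapes and padding are illustrative; the depth multiplier is passed explicitly as 1):

#include "arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void run_depthwise_fp16()
{
    Tensor src, weights, biases, dst;
    src.allocator()->init(TensorInfo(TensorShape(32U, 32U, 16U), 1, DataType::F16));
    weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 16U), 1, DataType::F16));
    biases.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F16));
    dst.allocator()->init(TensorInfo(TensorShape(32U, 32U, 16U), 1, DataType::F16));

    NEDepthwiseConvolutionLayer3x3 depthwise;
    depthwise.configure(&src, &weights, &biases, &dst, PadStrideInfo(1, 1, 1, 1), 1);

    src.allocator()->allocate();
    weights.allocator()->allocate();
    biases.allocator()->allocate();
    dst.allocator()->allocate();

    depthwise.run();
}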
diff --git a/examples/graph_mobilenet.cpp b/examples/graph_mobilenet.cpp
index 1aee241..35ab224 100644
--- a/examples/graph_mobilenet.cpp
+++ b/examples/graph_mobilenet.cpp
@@ -67,9 +67,6 @@
             return false;
         }
 
-        // Checks
-        ARM_COMPUTE_EXIT_ON_MSG(common_params.data_type == DataType::F16 && common_params.target == Target::NEON, "F16 NEON not supported for this graph");
-
         // Print parameter values
         std::cout << common_params << std::endl;
 
diff --git a/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp b/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp
index 88758b5..7029b06 100644
--- a/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.cpp
@@ -146,7 +146,7 @@
 
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier, bool is_optimized)
 {
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
 
     const DataLayout   data_layout = input->data_layout();
@@ -165,8 +165,14 @@
         const TensorShape output_shape = compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
 
-        //ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(input->data_type()) && (output->data_type() != DataType::S32));
-        ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_float(input->data_type()) && (output->data_type() != DataType::F32));
+        if(is_data_type_quantized_asymmetric(input->data_type()))
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != DataType::S32);
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+        }
     }
 
     return Status{};
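In effect the rewritten check accepts these input/output pairings (a summary of the branches above, not an exhaustive statement of the kernel's contract):

// QASYMM8 -> S32   (the quantized path accumulates into 32-bit integers)
// F16     -> F16   (new: float inputs simply require a matching output type)
// F32     -> F32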
@@ -229,6 +235,11 @@
             case DataType::QASYMM8:
                 num_elems_read_per_iteration = 16;
                 break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+            case DataType::F16:
+                num_elems_read_per_iteration = 24;
+                break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
             case DataType::F32:
                 num_elems_read_per_iteration = 12;
                 break;
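The F16 read window is exactly double the F32 one because a 128-bit q-register holds twice as many FP16 lanes; a sanity-check sketch of that arithmetic:

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
static_assert(sizeof(float16x8_t) / sizeof(float16_t) == 8, "8 FP16 lanes per q-register");
static_assert(sizeof(float32x4_t) / sizeof(float) == 4, "4 FP32 lanes per q-register");
// Three vector loads per row: 3 * 8 = 24 FP16 elements vs 3 * 4 = 12 FP32 elements.
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC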
@@ -313,7 +324,7 @@
     }
 
     // Check supported data type
-    bool supported_datatype = (dt == DataType::F32);
+    bool supported_datatype = is_data_type_float(dt);
 
     // Check for supported strides
     const auto &strides           = conv_info.stride();
@@ -334,7 +345,7 @@
 
 void NEDepthwiseConvolutionLayer3x3Kernel::generate_convolver()
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(_input, 1, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(_input, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(_input, _weights);
     ARM_COMPUTE_ERROR_ON(_weights->info()->dimension(1) != 3 || _weights->info()->dimension(2) != 3);
 
@@ -371,6 +382,11 @@
 
     switch(_input->info()->data_type())
     {
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+            convolve_3x3<float16_t, float16_t>(window, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info, _depth_multiplier);
+            break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F32:
             convolve_3x3<float, float>(window, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info, _depth_multiplier);
             break;
@@ -398,6 +414,7 @@
                                                                                                                 ITensor       *out,
                                                                                                                 bool           setup_strides)
 {
+    const DataType    dt                  = in->info()->data_type();
     const TensorShape shape               = in->info()->tensor_shape();
     const int         in_rows             = shape.z();
     const int         in_cols             = shape.y();
@@ -414,34 +431,60 @@
     const int         output_batch_stride = (setup_strides) ? out->info()->strides_in_bytes()[3] / out->info()->element_size() : 0;
 
     const auto stride_x = conv_info.stride().first;
-    switch(stride_x)
+    switch(dt)
     {
-        case 1:
-            return arm_compute::support::cpp14::make_unique<DepthwiseConvolution<4, 4, 3, 3, 1, 1, float, float>>(
-                       n_batches,
-                       in_rows,
-                       in_cols,
-                       n_channels,
-                       padding_same,
-                       reinterpret_cast<const float *>(w->ptr_to_element(Coordinates())),
-                       reinterpret_cast<float *>(in->ptr_to_element(Coordinates())),
-                       reinterpret_cast<float *>(out->ptr_to_element(Coordinates())),
-                       weight_col_stride, weight_row_stride,
-                       input_col_stride, input_row_stride, input_batch_stride,
-                       output_col_stride, output_row_stride, output_batch_stride);
-        case 2:
-            return arm_compute::support::cpp14::make_unique<DepthwiseConvolution<3, 3, 3, 3, 2, 2, float, float>>(
-                       n_batches,
-                       in_rows,
-                       in_cols,
-                       n_channels,
-                       padding_same,
-                       reinterpret_cast<const float *>(w->ptr_to_element(Coordinates())),
-                       reinterpret_cast<float *>(in->ptr_to_element(Coordinates())),
-                       reinterpret_cast<float *>(out->ptr_to_element(Coordinates())),
-                       weight_col_stride, weight_row_stride,
-                       input_col_stride, input_row_stride, input_batch_stride,
-                       output_col_stride, output_row_stride, output_batch_stride);
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return arm_compute::support::cpp14::make_unique<DepthwiseConvolution<4, 4, 3, 3, 1, 1, float16_t, float16_t>>(
+                               n_batches, in_rows, in_cols, n_channels, padding_same,
+                               reinterpret_cast<const float16_t *>(w->ptr_to_element(Coordinates())),
+                               reinterpret_cast<float16_t *>(in->ptr_to_element(Coordinates())),
+                               reinterpret_cast<float16_t *>(out->ptr_to_element(Coordinates())), weight_col_stride,
+                               weight_row_stride, input_col_stride, input_row_stride, input_batch_stride,
+                               output_col_stride, output_row_stride, output_batch_stride);
+                case 2:
+                    return arm_compute::support::cpp14::make_unique<DepthwiseConvolution<4, 4, 3, 3, 2, 2, float16_t, float16_t>>(
+                               n_batches, in_rows, in_cols, n_channels, padding_same,
+                               reinterpret_cast<const float16_t *>(w->ptr_to_element(Coordinates())),
+                               reinterpret_cast<float16_t *>(in->ptr_to_element(Coordinates())),
+                               reinterpret_cast<float16_t *>(out->ptr_to_element(Coordinates())), weight_col_stride,
+                               weight_row_stride, input_col_stride, input_row_stride, input_batch_stride,
+                               output_col_stride, output_row_stride, output_batch_stride);
+                default:
+                    return nullptr;
+            }
+            break;
+        }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F32:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return arm_compute::support::cpp14::make_unique<DepthwiseConvolution<4, 4, 3, 3, 1, 1, float, float>>(
+                               n_batches, in_rows, in_cols, n_channels, padding_same,
+                               reinterpret_cast<const float *>(w->ptr_to_element(Coordinates())),
+                               reinterpret_cast<float *>(in->ptr_to_element(Coordinates())),
+                               reinterpret_cast<float *>(out->ptr_to_element(Coordinates())), weight_col_stride,
+                               weight_row_stride, input_col_stride, input_row_stride, input_batch_stride,
+                               output_col_stride, output_row_stride, output_batch_stride);
+                case 2:
+                    return arm_compute::support::cpp14::make_unique<DepthwiseConvolution<3, 3, 3, 3, 2, 2, float, float>>(
+                               n_batches, in_rows, in_cols, n_channels, padding_same,
+                               reinterpret_cast<const float *>(w->ptr_to_element(Coordinates())),
+                               reinterpret_cast<float *>(in->ptr_to_element(Coordinates())),
+                               reinterpret_cast<float *>(out->ptr_to_element(Coordinates())), weight_col_stride,
+                               weight_row_stride, input_col_stride, input_row_stride, input_batch_stride,
+                               output_col_stride, output_row_stride, output_batch_stride);
+                default:
+                    return nullptr;
+            }
+            break;
+        }
         default:
             return nullptr;
     }
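The dispatch above reduces to a small table; anything outside it yields nullptr, which callers must treat as "no optimised convolver available". A summary of the switch, using the DepthwiseConvolution<OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols, TIn, TOut> parameter order:

// F16, stride 1 -> DepthwiseConvolution<4, 4, 3, 3, 1, 1, float16_t, float16_t>
// F16, stride 2 -> DepthwiseConvolution<4, 4, 3, 3, 2, 2, float16_t, float16_t>
// F32, stride 1 -> DepthwiseConvolution<4, 4, 3, 3, 1, 1, float, float>
// F32, stride 2 -> DepthwiseConvolution<3, 3, 3, 3, 2, 2, float, float>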
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index eefbd98..864c63f 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -451,6 +451,13 @@
     {
         switch(input->info()->data_type())
         {
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+            case DataType::F16:
+            {
+                _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>;
+                break;
+            }
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
             case DataType::F32:
             {
                 _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
diff --git a/src/core/NEON/kernels/convolution/depthwise/depthwise_2x2_3x3_1x1_fp32_fp32.cpp b/src/core/NEON/kernels/convolution/depthwise/depthwise_2x2_3x3_1x1_fp32_fp32.cpp
index 9b3a60d..c5a0565 100644
--- a/src/core/NEON/kernels/convolution/depthwise/depthwise_2x2_3x3_1x1_fp32_fp32.cpp
+++ b/src/core/NEON/kernels/convolution/depthwise/depthwise_2x2_3x3_1x1_fp32_fp32.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include "arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp"
+#include "impl_fp32_fp32.hpp"
 
 namespace depthwise
 {
diff --git a/src/core/NEON/kernels/convolution/depthwise/depthwise_2x2_3x3_2x2_fp32_fp32.cpp b/src/core/NEON/kernels/convolution/depthwise/depthwise_2x2_3x3_2x2_fp32_fp32.cpp
index dba2330..9ce43f9 100644
--- a/src/core/NEON/kernels/convolution/depthwise/depthwise_2x2_3x3_2x2_fp32_fp32.cpp
+++ b/src/core/NEON/kernels/convolution/depthwise/depthwise_2x2_3x3_2x2_fp32_fp32.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include "arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp"
+#include "impl_fp32_fp32.hpp"
 
 namespace depthwise
 {
diff --git a/src/core/NEON/kernels/convolution/depthwise/depthwise_3x3_3x3_1x1_fp32_fp32.cpp b/src/core/NEON/kernels/convolution/depthwise/depthwise_3x3_3x3_1x1_fp32_fp32.cpp
index b946e5d..0c96beb 100644
--- a/src/core/NEON/kernels/convolution/depthwise/depthwise_3x3_3x3_1x1_fp32_fp32.cpp
+++ b/src/core/NEON/kernels/convolution/depthwise/depthwise_3x3_3x3_1x1_fp32_fp32.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include "arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp"
+#include "impl_fp32_fp32.hpp"
 
 namespace depthwise
 {
diff --git a/src/core/NEON/kernels/convolution/depthwise/depthwise_3x3_3x3_2x2_fp32_fp32.cpp b/src/core/NEON/kernels/convolution/depthwise/depthwise_3x3_3x3_2x2_fp32_fp32.cpp
index 2510941..941c8e9 100644
--- a/src/core/NEON/kernels/convolution/depthwise/depthwise_3x3_3x3_2x2_fp32_fp32.cpp
+++ b/src/core/NEON/kernels/convolution/depthwise/depthwise_3x3_3x3_2x2_fp32_fp32.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include "arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp"
+#include "impl_fp32_fp32.hpp"
 
 namespace depthwise
 {
diff --git a/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_1x1_fp16_fp16.cpp b/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_1x1_fp16_fp16.cpp
new file mode 100644
index 0000000..33b55df
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_1x1_fp16_fp16.cpp
@@ -0,0 +1,130 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "impl_fp16_fp16.hpp"
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+namespace depthwise
+{
+using Conv = DepthwiseConvolution<4, 4, 3, 3, 1, 1, float16_t, float16_t>;
+using ConvImpl = DepthwiseConvolutionImpl<4, 4, 3, 3, 1, 1, float16_t, float16_t>;
+
+template <>
+const Conv::TileFn Conv::tilefn_unpadded = ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 0>;
+
+template <>
+const Conv::TileFn Conv::tilefn_top[n_in_pad_top_fns] = {
+        ConvImpl::template process_tile<true, 1, 0, 0, 0, 0, 0>,
+};
+
+template <>
+const Conv::TileFn Conv::tilefn_left[n_in_pad_left_fns] = {
+        ConvImpl::template process_tile<true, 0, 1, 0, 0, 0, 0>,
+};
+
+template <>
+const Conv::TileFn Conv::tilefn_bottom[n_in_pad_bottom_fns][n_out_pad_bottom_fns] = {
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 1, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 1, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 1, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 1, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 2, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 2, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 2, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 2, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 3, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 3, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 3, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 3, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 4, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 4, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 4, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 4, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 5, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 5, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 5, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 5, 0, 3, 0>,
+        },
+};
+
+template <>
+const Conv::TileFn Conv::tilefn_right[n_in_pad_right_fns][n_out_pad_right_fns] = {
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 1, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 1, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 1, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 1, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 2, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 2, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 2, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 2, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 3, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 3, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 3, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 3, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 4, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 4, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 4, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 4, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 5, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 5, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 5, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 5, 0, 3>,
+        },
+};
+
+template <>
+const Conv::TileFn Conv::tilefn_generic = ConvImpl::template process_tile<false>;
+
+template class DepthwiseConvolution<4, 4, 3, 3, 1, 1, float16_t, float16_t>;
+}  // namespace depthwise
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
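The table shapes follow from the inner-tile geometry rather than anything FP16-specific; a short sketch of the arithmetic behind the extents used above and in the stride-2 file that follows:

// Inner input tile covering an OutputTileRows x OutputTileCols output tile:
//   inner = (output_tile - 1) * stride + kernel
constexpr int inner_s1 = (4 - 1) * 1 + 3; // == 6: bottom/right input padding spans 0..5,
                                          // giving the 6 x 4 tables above (output padding 0..3)
constexpr int inner_s2 = (4 - 1) * 2 + 3; // == 9: the stride-2 variant below needs 9 x 4 tables
static_assert(inner_s1 == 6 && inner_s2 == 9, "table extents match the tile geometry");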
diff --git a/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_1x1_fp32_fp32.cpp b/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_1x1_fp32_fp32.cpp
index 44b93a1..1cbd6d5 100644
--- a/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_1x1_fp32_fp32.cpp
+++ b/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_1x1_fp32_fp32.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include "arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp"
+#include "impl_fp32_fp32.hpp"
 
 namespace depthwise
 {
diff --git a/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_2x2_fp16_fp16.cpp b/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_2x2_fp16_fp16.cpp
new file mode 100644
index 0000000..09722d0
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_2x2_fp16_fp16.cpp
@@ -0,0 +1,168 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "impl_fp16_fp16.hpp"
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+namespace depthwise
+{
+using Conv = DepthwiseConvolution<4, 4, 3, 3, 2, 2, float16_t, float16_t>;
+using ConvImpl = DepthwiseConvolutionImpl<4, 4, 3, 3, 2, 2, float16_t, float16_t>;
+
+template <>
+const Conv::TileFn Conv::tilefn_unpadded = ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 0>;
+
+template <>
+const Conv::TileFn Conv::tilefn_top[n_in_pad_top_fns] = {
+        ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 0>,
+        ConvImpl::template process_tile<true, 1, 0, 0, 0, 0, 0>,
+};
+
+template <>
+const Conv::TileFn Conv::tilefn_left[n_in_pad_left_fns] = {
+        ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 0>,
+        ConvImpl::template process_tile<true, 0, 1, 0, 0, 0, 0>,
+};
+
+template <>
+const Conv::TileFn Conv::tilefn_bottom[n_in_pad_bottom_fns][n_out_pad_bottom_fns] = {
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 1, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 1, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 1, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 1, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 2, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 2, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 2, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 2, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 3, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 3, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 3, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 3, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 4, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 4, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 4, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 4, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 5, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 5, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 5, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 5, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 6, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 6, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 6, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 6, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 7, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 7, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 7, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 7, 0, 3, 0>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 8, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 8, 0, 1, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 8, 0, 2, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 8, 0, 3, 0>,
+        },
+};
+
+template <>
+const Conv::TileFn Conv::tilefn_right[n_in_pad_right_fns][n_out_pad_right_fns] = {
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 0, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 1, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 1, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 1, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 1, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 2, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 2, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 2, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 2, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 3, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 3, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 3, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 3, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 4, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 4, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 4, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 4, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 5, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 5, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 5, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 5, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 6, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 6, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 6, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 6, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 7, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 7, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 7, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 7, 0, 3>,
+        },
+        {
+                ConvImpl::template process_tile<true, 0, 0, 0, 8, 0, 0>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 8, 0, 1>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 8, 0, 2>,
+                ConvImpl::template process_tile<true, 0, 0, 0, 8, 0, 3>,
+        },
+};
+
+template <>
+const Conv::TileFn Conv::tilefn_generic = ConvImpl::template process_tile<false>;
+
+template class DepthwiseConvolution<4, 4, 3, 3, 2, 2, float16_t, float16_t>;
+}  // namespace depthwise
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_2x2_fp32_fp32.cpp b/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_2x2_fp32_fp32.cpp
index 8eb53a6..05315ee 100644
--- a/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_2x2_fp32_fp32.cpp
+++ b/src/core/NEON/kernels/convolution/depthwise/depthwise_4x4_3x3_2x2_fp32_fp32.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include "arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp"
+#include "impl_fp32_fp32.hpp"
 
 namespace depthwise
 {
diff --git a/src/core/NEON/kernels/convolution/depthwise/impl_fp16_fp16.hpp b/src/core/NEON/kernels/convolution/depthwise/impl_fp16_fp16.hpp
new file mode 100644
index 0000000..ed4cfb8
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/depthwise/impl_fp16_fp16.hpp
@@ -0,0 +1,290 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/*
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ *
+ *          NOTE: Header to be included by implementation files only.
+ *
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#include "arm_compute/core/NEON/kernels/convolution/common/arm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp"
+
+#pragma once
+
+namespace depthwise
+{
+// Partial specialisation for FP16 to FP16
+template <int OutputTileRows, int OutputTileCols,
+          int KernelRows, int KernelCols,
+          int StrideRows, int StrideCols>
+struct DepthwiseConvolutionImpl<OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols, float16_t, float16_t>
+{
+  typedef DepthwiseConvolution<
+    OutputTileRows, OutputTileCols,
+    KernelRows, KernelCols,
+    StrideRows, StrideCols,
+    float16_t, float16_t
+  > DWC;
+
+  template <
+    bool Specialize=false,  // Specialize (or not) the method
+    int InPadTop=0,         // If specialized, top padding
+    int InPadLeft=0,        // If specialized, left padding
+    int InPadBottom=0,      // If specialized, bottom padding
+    int InPadRight=0,       // If specialized, right padding
+    int OutPadBottom=0,     // If specialized, bottom output padding
+    int OutPadRight=0       // If specialized, right output padding
+  >
+  static void process_tile(
+    const int n_channels,
+    const float16_t* const weights,
+    const int weight_row_stride,
+    const int weight_col_stride,
+    const float16_t* const inptr,
+    const int in_row_stride,
+    const int in_col_stride,
+    float16_t* const outptr,
+    const int out_row_stride,
+    const int out_col_stride,
+    const int in_pad_top=0,
+    const int in_pad_left=0,
+    const int in_pad_bottom=0,
+    const int in_pad_right=0,
+    const int out_pad_bottom=0,
+    const int out_pad_right=0
+  );
+};
+
+
+template <int OTR, int OTC, int KR, int KC, int SR, int SC>
+template <
+  bool Specialize,
+  int InPadTop, int InPadLeft, int InPadBottom, int InPadRight,
+  int OutPadBottom, int OutPadRight
+>
+void DepthwiseConvolutionImpl<OTR, OTC, KR, KC, SR, SC, float16_t, float16_t>::process_tile(
+  const int n_channels,
+  const float16_t *__restrict__ const weights,
+  const int weight_row_stride,
+  const int weight_col_stride,
+  const float16_t *__restrict__ const inptr,
+  const int in_row_stride,
+  const int in_col_stride,
+  float16_t *__restrict__ const outptr,
+  const int out_row_stride,
+  const int out_col_stride,
+  const int _in_pad_top,
+  const int _in_pad_left,
+  const int _in_pad_bottom,
+  const int _in_pad_right,
+  const int _out_pad_bottom,
+  const int _out_pad_right
+)
+{
+  constexpr auto inner_tile_rows = DWC::inner_tile_rows;
+  constexpr auto inner_tile_cols = DWC::inner_tile_cols;
+  constexpr auto kernel_rows = DWC::kernel_rows;
+  constexpr auto kernel_cols = DWC::kernel_cols;
+  constexpr auto output_tile_rows = DWC::output_tile_rows;
+  constexpr auto output_tile_cols = DWC::output_tile_cols;
+  constexpr auto stride_rows = DWC::stride_rows;
+  constexpr auto stride_cols = DWC::stride_cols;
+
+  // Extract parameters
+  const int in_pad_top = Specialize ? InPadTop : _in_pad_top;
+  const int in_pad_left = Specialize ? InPadLeft : _in_pad_left;
+  const int in_pad_bottom = Specialize ? InPadBottom : _in_pad_bottom;
+  const int in_pad_right = Specialize ? InPadRight : _in_pad_right;
+  const int out_pad_bottom = Specialize ? OutPadBottom : _out_pad_bottom;
+  const int out_pad_right = Specialize ? OutPadRight : _out_pad_right;
+
+  // Compute valid ranges of the tile
+  const int in_cells_i = inner_tile_rows - in_pad_bottom;
+  const int in_cells_j = inner_tile_cols - in_pad_right;
+  const int out_cells_i = output_tile_rows - out_pad_bottom;
+  const int out_cells_j = output_tile_cols - out_pad_right;
+
+  // Instantiate pointers
+  const float16_t* __restrict__ inptr_base = inptr;
+  const float16_t* __restrict__ wptr_base = weights;
+  float16_t* __restrict__ outptr_base = outptr;
+
+  // Perform the depthwise convolution
+  int channels_remaining = n_channels;
+#ifdef __aarch64__
+  for (; channels_remaining >= 8; channels_remaining -= 8)
+  {
+    // Load input tile
+    float16x8_t u[inner_tile_rows][inner_tile_cols];
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      const float16_t* const inptr_row = inptr_base + (i - in_pad_top)*in_row_stride;
+      for (int j = 0; j < inner_tile_cols; j++)
+      {
+        if (i < in_pad_top || in_cells_i <= i ||
+            j < in_pad_left || in_cells_j <= j)
+        {
+          u[i][j] = vdupq_n_f16(0.0f);
+        }
+        else
+        {
+          u[i][j] = vld1q_f16(inptr_row + (j - in_pad_left)*in_col_stride);
+        }
+      }
+    }
+    inptr_base += 8;
+
+    // Load weights tile
+    float16x8_t w[kernel_rows][kernel_cols];
+    for (int i = 0; i < kernel_rows; i++)
+    {
+      const float16_t* const wptr_row = wptr_base + i*weight_row_stride;
+      for (int j = 0; j < kernel_cols; j++)
+      {
+        w[i][j] = vld1q_f16(wptr_row + j*weight_col_stride);
+      }
+    }
+    wptr_base += 8;
+
+    // Perform the convolution
+    float16x8_t v[output_tile_rows][output_tile_cols];
+    for (int out_i = 0; out_i < out_cells_i; out_i++)
+    {
+      for (int out_j = 0; out_j < out_cells_j; out_j++)
+      {
+        // Base co-ordinate
+        const int base_i = out_i * stride_rows;
+        const int base_j = out_j * stride_cols;
+
+        // Fill the accumulator
+        for (int in_i = 0; in_i < kernel_rows; in_i++)
+        {
+          const int i = base_i + in_i;
+          for (int in_j = 0; in_j < kernel_cols; in_j++)
+          {
+            const int j = base_j + in_j;
+            if (in_i == 0 && in_j == 0)
+            {
+              // v[out_i][out_j] = w[in_i][in_j] * u[i][j];
+              v[out_i][out_j] = vmulq_f16(w[in_i][in_j], u[i][j]);
+            }
+            else
+            {
+              // v[out_i][out_j] += w[in_i][in_j] * u[i][j];
+              v[out_i][out_j] = vaddq_f16(v[out_i][out_j], vmulq_f16(w[in_i][in_j], u[i][j]));
+            }
+          }
+        }
+      }
+    }
+
+    // Store the output tile
+    for (int i = 0; i < out_cells_i; i++)
+    {
+      float16_t* const outptr_row = outptr_base + i*out_row_stride;
+      for (int j = 0; j < out_cells_j; j++)
+      {
+        vst1q_f16(outptr_row + j*out_col_stride, v[i][j]);
+      }
+    }
+    outptr_base += 8;
+  }
+#endif  // __aarch64__
+  for (; channels_remaining; channels_remaining--)
+  {
+    // Load input tile
+    float16_t u[inner_tile_rows][inner_tile_cols];
+    for (int i = 0; i < inner_tile_rows; i++)
+    {
+      const float16_t* const inptr_row = inptr_base + (i - in_pad_top)*in_row_stride;
+      for (int j = 0; j < inner_tile_cols; j++)
+      {
+        if (i < in_pad_top || in_cells_i <= i ||
+            j < in_pad_left || in_cells_j <= j)
+        {
+          u[i][j] = static_cast<float16_t>(0);
+        }
+        else
+        {
+          u[i][j] = *(inptr_row + (j - in_pad_left)*in_col_stride);
+        }
+      }
+    }
+    inptr_base++;
+
+    // Load weights tile
+    float16_t w[kernel_rows][kernel_cols];
+    for (int i = 0; i < kernel_rows; i++)
+    {
+      const float16_t* const wptr_row = wptr_base + i*weight_row_stride;
+      for (int j = 0; j < kernel_cols; j++)
+      {
+        w[i][j] = *(wptr_row + j*weight_col_stride);
+      }
+    }
+    wptr_base++;
+
+    // Perform the convolution
+    float16_t v[output_tile_rows][output_tile_cols];
+    for (int out_i = 0; out_i < out_cells_i; out_i++)
+    {
+      for (int out_j = 0; out_j < out_cells_j; out_j++)
+      {
+        // Clear the accumulator
+        v[out_i][out_j] = static_cast<float16_t>(0);
+
+        // Base co-ordinate
+        const int base_i = out_i * stride_rows;
+        const int base_j = out_j * stride_cols;
+
+        // Fill the accumulator
+        for (int in_i = 0; in_i < kernel_rows; in_i++)
+        {
+          const int i = base_i + in_i;
+          for (int in_j = 0; in_j < kernel_cols; in_j++)
+          {
+            const int j = base_j + in_j;
+            v[out_i][out_j] += w[in_i][in_j] * u[i][j];
+          }
+        }
+      }
+    }
+
+    // Store the output tile
+    for (int i = 0; i < out_cells_i; i++)
+    {
+      float16_t* const outptr_row = outptr_base + i*out_row_stride;
+      for (int j = 0; j < out_cells_j; j++)
+      {
+        *(outptr_row + j*out_col_stride) = v[i][j];
+      }
+    }
+    outptr_base++;
+  }
+}
+}  // namespace depthwise
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
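The channel loop above is the classic NEON shape: an 8-lane vectorised main loop (AArch64 only, accumulating with separate vmulq_f16/vaddq_f16 rather than a fused multiply-add) followed by a scalar tail for the remaining channels. A stripped-down sketch of just that structure, on a toy operation:

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>

// Scale a contiguous FP16 buffer: 8 channels per iteration, scalar remainder.
void scale_channels_fp16(float16_t *data, int n_channels, float16_t scale)
{
  int channels_remaining = n_channels;
#ifdef __aarch64__
  const float16x8_t vscale = vdupq_n_f16(scale);
  for (; channels_remaining >= 8; channels_remaining -= 8, data += 8)
  {
    vst1q_f16(data, vmulq_f16(vld1q_f16(data), vscale));
  }
#endif // __aarch64__
  for (; channels_remaining; channels_remaining--, data++)
  {
    *data = *data * scale;  // scalar tail, one channel at a time
  }
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC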
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp b/src/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp
similarity index 100%
rename from arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp
rename to src/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp
diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
index ccbd01e..a46be2e 100644
--- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
@@ -43,7 +43,7 @@
 
 void NEDepthwiseConvolutionLayer3x3::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
 
     PixelValue zero_value(0.f);
diff --git a/tests/validation/NEON/DepthwiseConvolutionLayer.cpp b/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
index fe7bba3..54bce02 100644
--- a/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
+++ b/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
@@ -45,8 +45,9 @@
 
 namespace
 {
-constexpr RelativeTolerance<float>   tolerance_f32(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
-constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QASYMM8 */
+RelativeTolerance<half_float::half>  tolerance_f16(half_float::half(0.001)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr RelativeTolerance<float>   tolerance_f32(0.01f);                   /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1);                   /**< Tolerance value for comparing reference's output against implementation's output for DataType::QASYMM8 */
 
 const auto depth_multipliers = framework::dataset::make("DepthMultiplier", { 1, 2, 3 });
 } // namespace
@@ -209,7 +210,7 @@
 {
     validate(Accessor(_target), _reference, tolerance_f32);
 }
-TEST_SUITE_END()
+TEST_SUITE_END() // Generic
 
 TEST_SUITE(W3x3)
 template <typename T>
@@ -238,10 +239,43 @@
 {
     validate(Accessor(_target), _reference, tolerance_f32);
 }
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE_END() // W3x3
+TEST_SUITE_END() // F32
 
-TEST_SUITE_END()
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+TEST_SUITE(F16)
+TEST_SUITE(W3x3)
+template <typename T>
+using NEDepthwiseConvolutionLayerFixture3x3 = DepthwiseConvolutionLayerValidationFixture<Tensor, Accessor, NEDepthwiseConvolutionLayer3x3, T>;
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthwiseConvolutionLayerFixture3x3<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(),
+                                                                                                                   depth_multipliers),
+                                                                                                                   framework::dataset::make("DataType",
+                                                                                                                           DataType::F16)),
+                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+{
+    validate(Accessor(_target), _reference, tolerance_f16);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEDepthwiseConvolutionLayerFixture3x3<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
+                                                                                                                       depth_multipliers),
+                                                                                                                       framework::dataset::make("DataType",
+                                                                                                                               DataType::F16)),
+                                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+{
+    validate(Accessor(_target), _reference, tolerance_f16);
+}
+FIXTURE_DATA_TEST_CASE(RunOptimized, NEDepthwiseConvolutionLayerFixture3x3<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::OptimizedDepthwiseConvolutionLayerDataset3x3(),
+                                                                                                                       framework::dataset::make("DepthMultiplier", 1)),
+                                                                                                                       framework::dataset::make("DataType",
+                                                                                                                               DataType::F16)),
+                                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+{
+    validate(Accessor(_target), _reference, tolerance_f16);
+}
+TEST_SUITE_END() // W3x3
+TEST_SUITE_END() // F16
+#endif           // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+TEST_SUITE_END() // Float
 
 template <typename T>
 using NEDepthwiseConvolutionLayerQuantizedFixture3x3 = DepthwiseConvolutionLayerValidationQuantizedFixture<Tensor, Accessor, NEDepthwiseConvolutionLayer3x3, T>;