COMPMID-421: Added FP16 support to the NEON Direct Convolution function.

Change-Id: I3a1aa2ce985ecf95fc5f441a6e6d43b4935306ee
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79965
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
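
For reference, a minimal usage sketch of the new FP16 path through the
runtime function (illustrative only, not part of this patch; it assumes a
build with ARM_COMPUTE_ENABLE_FP16, an FP16-capable CPU, and made-up tensor
shapes):

    #include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    Tensor src, weights, bias, dst;
    // 8x8 single-channel input, one 3x3 filter -> 6x6 output (stride 1, no padding).
    src.allocator()->init(TensorInfo(TensorShape(8U, 8U, 1U), 1, DataType::F16));
    weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 1U, 1U), 1, DataType::F16));
    bias.allocator()->init(TensorInfo(TensorShape(1U), 1, DataType::F16));
    dst.allocator()->init(TensorInfo(TensorShape(6U, 6U, 1U), 1, DataType::F16));

    NEDirectConvolutionLayer conv;
    conv.configure(&src, &weights, &bias, &dst, PadStrideInfo(1, 1, 0, 0));

    src.allocator()->allocate();
    weights.allocator()->allocate();
    bias.allocator()->allocate();
    dst.allocator()->allocate();
    // ... fill src, weights and bias with half-precision data ...
    conv.run();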
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp
index effc50e..fb16c8d 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp
@@ -100,6 +100,25 @@
     return vqaddq_qs16(x, y);
 }
 
+#ifdef ARM_COMPUTE_ENABLE_FP16
+inline float16x8_t internal_vld1q(const float16_t *in)
+{
+    return vld1q_f16(in);
+}
+inline void internal_vst1q(float16_t *p, const float16x8_t &v)
+{
+    vst1q_f16(p, v);
+}
+inline float16x8_t internal_vdupq_n(float16_t v)
+{
+    return vdupq_n_f16(v);
+}
+inline float16x8_t internal_vqaddq(const float16x8_t &x, const float16x8_t &y)
+{
+    // No saturating add exists for float vectors; a plain vaddq_f16 stands in.
+    return vaddq_f16(x, y);
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 template <typename T1, typename T2, bool in_place>
 void accumulate_bias(ITensor *input, const ITensor *bias, const Window window, ITensor *output)
 {
@@ -143,8 +162,8 @@
 
 void NEDirectConvolutionLayerBiasAccumulateKernel::configure(ITensor *input, const ITensor *bias, ITensor *output)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::QS8, DataType::QS16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON(input->info()->fixed_point_position() != bias->info()->fixed_point_position());
     if(output != nullptr)
     {
@@ -183,6 +202,12 @@
     {
         _func = (output == nullptr) ? &accumulate_bias<float, float, true> : &accumulate_bias<float, float, false>;
     }
+#ifdef ARM_COMPUTE_ENABLE_FP16
+    else if(input->info()->data_type() == DataType::F16)
+    {
+        _func = (output == nullptr) ? &accumulate_bias<float16_t, float16_t, true> : &accumulate_bias<float16_t, float16_t, false>;
+    }
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
     else if(input->info()->data_type() == DataType::QS8)
     {
         _func = (output == nullptr) ? &accumulate_bias<qint8_t, qint8_t, true> : &accumulate_bias<qint8_t, qint8_t, false>;
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
index d608898..09d8dd5 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
@@ -39,6 +39,53 @@
 
 namespace
 {
+#ifdef ARM_COMPUTE_ENABLE_FP16
+template <unsigned int stridex>
+float16x8_t internal_vld1q(const float16_t *in);
+
+template <>
+float16x8_t internal_vld1q<1>(const float16_t *in)
+{
+    return vld1q_f16(in);
+}
+
+template <>
+float16x8_t internal_vld1q<2>(const float16_t *in)
+{
+    const float16x8x2_t tmp = vld2q_f16(in);
+    return tmp.val[0];
+}
+
+template <>
+float16x8_t internal_vld1q<3>(const float16_t *in)
+{
+    const float16x8x3_t tmp = vld3q_f16(in);
+    return tmp.val[0];
+}
+
+inline float16x8_t internal_vdupq_n(float16_t v)
+{
+    return vdupq_n_f16(v);
+}
+
+inline void internal_vst1q(float16_t *p, const float16x8_t &v)
+{
+    vst1q_f16(p, v);
+}
+
+float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
+{
+    ARM_COMPUTE_UNUSED(fixed_point_position);
+    return vmulq_f16(x, y);
+}
+
+inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
+{
+    ARM_COMPUTE_UNUSED(fixed_point_position);
+    // Multiply-accumulate composed from separate multiply and add: x + (y * z)
+    return vaddq_f16(x, vmulq_f16(y, z));
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 template <unsigned int stridex>
 float32x4_t internal_vld1q(const float *in);
 
@@ -226,6 +273,148 @@
     }
 };
 
+#ifdef ARM_COMPUTE_ENABLE_FP16
+inline float16x8x3_t load_matrix_row(const float16_t *ptr)
+{
+    /* ptr points to a row of a 3x3 matrix; the function returns 3 vectors, each broadcasting one element of that row to all lanes:
+       r.val[0] holds the first element, r.val[1] the second and r.val[2] the third */
+    const float16x8x3_t r =
+    {
+        {
+            vld1q_dup_f16(ptr),
+            vld1q_dup_f16(1 + ptr),
+            vld1q_dup_f16(2 + ptr)
+        }
+    };
+    return r;
+}
+
+template <unsigned int stridex>
+float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+                           int fixed_point_position);
+
+template <>
+float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+                              int fixed_point_position)
+{
+    ARM_COMPUTE_UNUSED(fixed_point_position);
+
+    const float16x8x3_t vtop =
+    {
+        {
+            vld1q_f16(in_top),
+            vld1q_f16(in_top + 8),
+            vld1q_f16(in_top + 16)
+        }
+    };
+    const float16x8x3_t vmid =
+    {
+        {
+            vld1q_f16(in_mid),
+            vld1q_f16(in_mid + 8),
+            vld1q_f16(in_mid + 16)
+        }
+    };
+    const float16x8x3_t vlow =
+    {
+        {
+            vld1q_f16(in_low),
+            vld1q_f16(in_low + 8),
+            vld1q_f16(in_low + 16)
+        }
+    };
+    float16x8x2_t out =
+    {
+        {
+            vmulq_f16(vtop.val[0], m0.val[0]),
+            vmulq_f16(vtop.val[1], m0.val[0])
+        }
+    };
+    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
+    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
+    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
+    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
+    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
+    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
+    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
+    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
+    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
+    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
+    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
+    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
+    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
+    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
+    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
+    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
+    return out;
+}
+
+template <>
+inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+                                     int fixed_point_position)
+{
+    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
+    // Keep every second result: pack the eight even-indexed outputs of the two stride-1 vectors into out.val[0]
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
+    return out;
+}
+
+template <>
+inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+                                     int fixed_point_position)
+{
+    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
+    // Keep every third result in the low half of out.val[0] (only four outputs are stored for stride 3)
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
+    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
+    return out;
+}
+
+template <unsigned int stridex>
+void store_results(float16_t *buffer, const float16x8x2_t &values);
+
+template <>
+void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
+{
+    vst1q_f16(buffer, values.val[0]);
+    vst1q_f16(buffer + 8, values.val[1]);
+}
+
+template <>
+void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
+{
+    vst1q_f16(buffer, values.val[0]);
+}
+
+template <>
+void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
+{
+    vst1_f16(buffer, vget_low_f16(values.val[0]));
+}
+
+template <unsigned int stridex>
+void accumulate_results(float16_t *buffer, const float16x8x2_t &values);
+
+template <>
+void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
+{
+    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
+    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
+}
+
+template <>
+void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
+{
+    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
+}
+
+template <>
+void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
+{
+    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
+}
+
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 inline float32x4x3_t load_matrix_row(const float *ptr)
 {
     const float32x4x3_t r =
@@ -590,12 +779,13 @@
 
             for(int oz = 0; oz < num_planes_z; ++oz)
             {
+                const int zoffset    = id.z() + oz;
                 uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                 // Step 1
                 {
-                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
-                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
-                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
+                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
+                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
+                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                     const auto vk_r0    = load_matrix_row(ptr_k_r0);
                     const auto vk_r1    = load_matrix_row(ptr_k_r1);
                     const auto vk_r2    = load_matrix_row(ptr_k_r2);
@@ -616,9 +806,9 @@
                 // Step 2
                 for(int p = 1; p < kernel_depth; ++p)
                 {
-                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
-                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
-                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
+                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
+                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
+                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                     const auto vk_r0    = load_matrix_row(ptr_k_r0);
                     const auto vk_r1    = load_matrix_row(ptr_k_r1);
                     const auto vk_r2    = load_matrix_row(ptr_k_r2);
@@ -697,9 +887,9 @@
 
 void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
                              "Pad > 0 not supported for 1x1 weights");
     ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
@@ -723,10 +913,24 @@
     {
         case 1:
         {
-            _num_elems_written_per_iteration = (input->info()->data_type() == DataType::QS8) ? 8 : 4;
-            _num_elems_read_per_iteration    = conv_stride_x * _num_elems_written_per_iteration;
+            switch(input->info()->data_type())
+            {
+#ifdef ARM_COMPUTE_ENABLE_FP16
+                case DataType::F16:
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+                case DataType::QS8:
+                    _num_elems_written_per_iteration = 8;
+                    break;
+                case DataType::F32:
+                    _num_elems_written_per_iteration = 4;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Data type not supported.");
+                    break;
+            }
 
-            win = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
+            _num_elems_read_per_iteration = conv_stride_x * _num_elems_written_per_iteration;
+            win                           = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
             AccessWindowHorizontal input_access(input->info(), 0, _num_elems_read_per_iteration);
             AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
             update_window_and_padding(win, input_access, output_access);
@@ -786,25 +990,43 @@
     {
         case 1:
         {
-            if(_input->info()->data_type() == DataType::QS8)
+            switch(_input->info()->data_type())
             {
-                convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-            }
-            else
-            {
-                convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                case DataType::QS8:
+                    convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                    break;
+                case DataType::F32:
+                    convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                    break;
+#ifdef ARM_COMPUTE_ENABLE_FP16
+                case DataType::F16:
+                    convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                    break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+                default:
+                    ARM_COMPUTE_ERROR("Data type not supported.");
+                    break;
             }
             break;
         }
         case 3:
         {
-            if(_input->info()->data_type() == DataType::QS8)
+            switch(_input->info()->data_type())
             {
-                convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-            }
-            else
-            {
-                convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                case DataType::QS8:
+                    convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                    break;
+                case DataType::F32:
+                    convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                    break;
+#ifdef ARM_COMPUTE_ENABLE_FP16
+                case DataType::F16:
+                    convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                    break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+                default:
+                    ARM_COMPUTE_ERROR("Data type not supported.");
+                    break;
             }
             break;
         }