COMPMID-421: Added FP16 support in the Neon Locally Connected Layer.
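
A usage sketch (illustrative only, not part of this patch): on a build with
ARM_COMPUTE_ENABLE_FP16 (-march=armv8.2-a+fp16+simd) the layer now accepts
F16 tensors. The shape values below are placeholders; consult the function
documentation for the exact weight and bias layouts.

    #include "arm_compute/runtime/NEON/functions/NELocallyConnectedLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    // Hypothetical 9x9x1 F16 input with 3x3 kernels -> 7x7x1 output (49 patches).
    Tensor src{}, weights{}, biases{}, dst{};
    src.allocator()->init(TensorInfo(TensorShape(9U, 9U, 1U), 1, DataType::F16));
    weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 1U, 1U, 49U), 1, DataType::F16));
    biases.allocator()->init(TensorInfo(TensorShape(1U, 49U), 1, DataType::F16));
    dst.allocator()->init(TensorInfo(TensorShape(7U, 7U, 1U), 1, DataType::F16));

    NELocallyConnectedLayer lc{};
    lc.configure(&src, &weights, &biases, &dst, PadStrideInfo(1, 1, 0, 0));
    // Allocate and fill the tensors before calling lc.run().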

Change-Id: I4b52a209a5ce1a7e69494008538ed242b14b5593
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/81520
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
index 895799c..2b7b391 100644
--- a/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.cpp
@@ -49,6 +49,133 @@
 
 namespace
 {
+void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window)
+{
+#ifdef ARM_COMPUTE_ENABLE_FP16
+    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
+    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
+    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
+
+    // The implementation computes 32 elements per iteration (four float16x8_t accumulators)
+    const int window_start_x = 32 * window.thread_id();
+    const int window_step_x  = 32 * window.num_threads();
+    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
+    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
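+    // e.g. with width_matrix_b = 40 and two threads, thread 0 covers x = { 0 } and
+    // thread 1 covers x = { 32 }; the unused tail of the final 32-wide block lands
+    // in the output tensor's padding.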
+
+    Window win_out(window);
+    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+
+    Window win_a(window);
+    win_a.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    Iterator ina(input0, win_a);
+    Iterator out(output, win_out);
+
+    execute_window_loop(win_out, [&](const Coordinates & id)
+    {
+        if(id.x() > width_matrix_b)
+        {
+            return;
+        }
+
+        float16x8_t acc0 = vdupq_n_f16(0.f);
+        float16x8_t acc1 = vdupq_n_f16(0.f);
+        float16x8_t acc2 = vdupq_n_f16(0.f);
+        float16x8_t acc3 = vdupq_n_f16(0.f);
+
+        auto vec_a    = reinterpret_cast<const float16_t *>(ina.ptr());
+        auto matrix_b = reinterpret_cast<const float16_t *>(input1->ptr_to_element(Coordinates(id[0], 0, id[1])));
+
+        const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
+
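+        // Main loop: consume vector A four elements at a time; each A lane is
+        // multiplied against one 32-column row of B and accumulated into acc0-acc3.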
+        for(; vec_a <= (vec_a_end_addr - 4);)
+        {
+            const float16x4_t a0l = vld1_f16(vec_a);
+
+            float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
+            float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+            float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+            float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+
+            float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+            float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+            float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+            float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
+            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
+            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
+            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
+            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
+            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
+            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
+            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));
+
+            matrix_b += 2 * in_b_stride;
+
+            b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
+            b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+            b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+            b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+            b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
+            b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
+            b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
+            b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);
+
+            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
+            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
+            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
+            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
+            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
+            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
+            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
+            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));
+
+            vec_a += 4;
+            matrix_b += 2 * in_b_stride;
+        }
+
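+        // Leftover loop: handle the remaining (num_elems_vec_a % 4) elements of A,
+        // one row of B per element.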
+        for(; vec_a < vec_a_end_addr;)
+        {
+            const float16_t   a0  = *vec_a;
+            const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
+            const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
+            const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
+            const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
+
+            acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
+            acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
+            acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
+            acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));
+
+            vec_a += 1;
+            matrix_b += in_b_stride;
+        }
+
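+        // Store the 32 accumulated F16 results for this output position.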
+        const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());
+
+        vst1q_f16(vec_out + 0, acc0);
+        vst1q_f16(vec_out + 8, acc1);
+        vst1q_f16(vec_out + 16, acc2);
+        vst1q_f16(vec_out + 24, acc3);
+    },
+    ina, out);
+#else  /* ARM_COMPUTE_ENABLE_FP16 */
+    ARM_COMPUTE_UNUSED(input0);
+    ARM_COMPUTE_UNUSED(input1);
+    ARM_COMPUTE_UNUSED(output);
+    ARM_COMPUTE_UNUSED(window);
+    ARM_COMPUTE_ERROR("Not supported, recompile with -march=armv8.2-a+fp16+simd.");
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
+
 void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window)
 {
     const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
@@ -190,17 +317,16 @@
 
 void NELocallyConnectedMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON(input0->info()->dimension(0) != input1->info()->dimension(1));
 
     _input0 = input0;
     _input1 = input1;
     _output = output;
 
-    unsigned int num_elems_processed_per_iteration_x = 16;
+    const unsigned int num_elems_processed_per_iteration_x = 16;
 
     // Configure kernel window
     Window win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
@@ -222,5 +348,22 @@
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
 
-    vector_matrix_multiply_f32(_input0, _input1, _output, window);
+    switch(_input0->info()->data_type())
+    {
+        case DataType::F16:
+        {
+            vector_matrix_multiply_f16(_input0, _input1, _output, window);
+            break;
+        }
+        case DataType::F32:
+        {
+            vector_matrix_multiply_f32(_input0, _input1, _output, window);
+            break;
+        }
+        default:
+        {
+            ARM_COMPUTE_ERROR("Data type not supported");
+            break;
+        }
+    }
 }