COMPMID-421: Added FP16 support in Pooling Layer

Change-Id: I6b6119c8770051c1656da40aa073c539c15b493e
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78985
Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index 2ef2b98..ce97714 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -48,17 +48,17 @@
 inline float calculate_avg_scale(const Coordinates &id, const int pool_size, const int upper_bound_w, const int upper_bound_h,
                                  const int pad_x, const int pad_y, const int stride_x, const int stride_y)
 {
-    int start_x = id.x() * stride_x - pad_x;
-    int start_y = id.y() * stride_y - pad_y;
-    int end_x   = std::min(start_x + pool_size, upper_bound_w);
-    int end_y   = std::min(start_y + pool_size, upper_bound_h);
+    const int start_x = id.x() * stride_x - pad_x;
+    const int start_y = id.y() * stride_y - pad_y;
+    const int end_x   = std::min(start_x + pool_size, upper_bound_w);
+    const int end_y   = std::min(start_y + pool_size, upper_bound_h);
     return 1.f / ((end_y - start_y) * (end_x - start_x));
 }
 
 inline qint8_t calculate_avg_scale_q8(const Coordinates &id, int pool_size, int upper_bound_w, int upper_bound_h,
                                       int pad_x, int pad_y, int stride_x, int stride_y, int fixed_point_position)
 {
-    static std::array<qint8_t, 10> scale_values_q8 =
+    static const std::array<qint8_t, 10> scale_values_q8 =
     { { 0x0, 0x0, 0x40, 0x2A, 0x20, 0x19, 0x15, 0x12, 0x10, 0xE } };
     const int start_x = id.x() * stride_x - pad_x;
     const int start_y = id.y() * stride_y - pad_y;
@@ -96,8 +96,11 @@
     static const std::set<int> supported_pool_sizes = { 2, 3, 7 };
     ARM_COMPUTE_UNUSED(supported_pool_sizes);
 
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
     ARM_COMPUTE_ERROR_ON(supported_pool_sizes.find(pool_size) == supported_pool_sizes.end());
     ARM_COMPUTE_ERROR_ON(7 == pool_size && input->info()->data_type() != DataType::F32);
     ARM_COMPUTE_ERROR_ON(pool_pad_x >= pool_size || pool_pad_y >= pool_size);
@@ -140,9 +143,30 @@
                     break;
                 default:
                     ARM_COMPUTE_ERROR("Pooling size not supported");
+                    break;
             }
             num_elems_horizontal_window = 8;
             break;
+#ifdef ARM_COMPUTE_ENABLE_FP16
+        case DataType::F16:
+            switch(pool_size)
+            {
+                case 2:
+                    num_elems_read_per_iteration      = 16;
+                    num_elems_processed_per_iteration = 8;
+                    num_elems_horizontal_window       = 8;
+                    break;
+                case 3:
+                    num_elems_read_per_iteration      = 4;
+                    num_elems_processed_per_iteration = 1;
+                    num_elems_horizontal_window       = 1;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Pooling size not supported");
+                    break;
+            }
+            break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
         case DataType::F32:
             switch(pool_size)
             {
@@ -157,6 +181,7 @@
                     break;
                 default:
                     ARM_COMPUTE_ERROR("Pooling size not supported");
+                    break;
             }
             num_elems_processed_per_iteration = 1;
             num_elems_horizontal_window       = 1;
@@ -188,6 +213,10 @@
             {
                 _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling2_q8<PoolingType::AVG> : &NEPoolingLayerKernel::pooling2_q8<PoolingType::MAX>;
             }
+            else if(input->info()->data_type() == DataType::F16)
+            {
+                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling2_f16<PoolingType::AVG> : &NEPoolingLayerKernel::pooling2_f16<PoolingType::MAX>;
+            }
             else if(input->info()->data_type() == DataType::F32)
             {
                 _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling2_f32<PoolingType::AVG> : &NEPoolingLayerKernel::pooling2_f32<PoolingType::MAX>;
@@ -198,6 +227,10 @@
             {
                 _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling3_q8<PoolingType::AVG> : &NEPoolingLayerKernel::pooling3_q8<PoolingType::MAX>;
             }
+            else if(input->info()->data_type() == DataType::F16)
+            {
+                _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling3_f16<PoolingType::AVG> : &NEPoolingLayerKernel::pooling3_f16<PoolingType::MAX>;
+            }
             else if(input->info()->data_type() == DataType::F32)
             {
                 _func = (PoolingType::AVG == pool_type) ? &NEPoolingLayerKernel::pooling3_f32<PoolingType::AVG> : &NEPoolingLayerKernel::pooling3_f32<PoolingType::MAX>;
@@ -266,6 +299,101 @@
 }
 
 template <PoolingType pooling_type>
+void NEPoolingLayerKernel::pooling3_f16(const Window &window_input, const Window &window)
+{
+#ifdef ARM_COMPUTE_ENABLE_FP16
+    Iterator input(_input, window_input);
+    Iterator output(_output, window);
+
+    constexpr const int pool_size     = 3;
+    int                 pool_pad_x    = 0;
+    int                 pool_pad_y    = 0;
+    int                 pool_stride_x = 0;
+    int                 pool_stride_y = 0;
+    std::tie(pool_pad_x, pool_pad_y)       = _pool_info.pad_stride_info().pad();
+    std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info().stride();
+    const int upper_bound_w = _input->info()->dimension(0) + pool_pad_x;
+    const int upper_bound_h = _input->info()->dimension(1) + pool_pad_y;
+
+    const unsigned char *const input_top_ptr    = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_x), -static_cast<int>(pool_pad_y)));
+    const unsigned char *const input_middle_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_x), -static_cast<int>(pool_pad_y) + 1));
+    const unsigned char *const input_bottom_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_x), -static_cast<int>(pool_pad_y) + 2));
+
+    execute_window_loop(window, [&](const Coordinates & id)
+    {
+        const float16x4_t top_data    = vld1_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
+        const float16x4_t middle_data = vld1_f16(reinterpret_cast<const float16_t *>(input_middle_ptr + input.offset()));
+        const float16x4_t bottom_data = vld1_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
+        float16x4_t       res         = {};
+        if(pooling_type == PoolingType::AVG)
+        {
+            // Calculate scale
+            const float       scale   = calculate_avg_scale(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_x, pool_pad_y, pool_stride_x, pool_stride_y);
+            const float16x4_t scale_v = vdup_n_f16(scale);
+            // Perform pooling
+            const float16x4_t sum_data = vadd_f16(vadd_f16(top_data, bottom_data), middle_data);
+            res                        = vpadd_f16(vset_lane_f16(0.f, sum_data, 3), sum_data);
+            res                        = vmul_f16(vpadd_f16(res, res), scale_v);
+        }
+        else
+        {
+            const float16x4_t max_data = vmax_f16(vmax_f16(top_data, bottom_data), middle_data);
+            res                        = vpmax_f16(vset_lane_f16(-std::numeric_limits<float>::max(), max_data, 3), max_data);
+            res                        = vpmax_f16(res, res);
+        }
+        *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(res, 0);
+    },
+    input, output);
+#else  /* ARM_COMPUTE_ENABLE_FP16 */
+    ARM_COMPUTE_UNUSED(window_input);
+    ARM_COMPUTE_UNUSED(window);
+    ARM_COMPUTE_ERROR("FP16 Not supported! Recompile the library with arch=arm64-v8.2-a");
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
+
+template <PoolingType pooling_type>
+void NEPoolingLayerKernel::pooling2_f16(const Window &window_input, const Window &window)
+{
+#ifdef ARM_COMPUTE_ENABLE_FP16
+    Iterator      input(_input, window_input);
+    Iterator      output(_output, window);
+    constexpr int pool_size = 2;
+    int           pool_pad_x, pool_pad_y, pool_stride_x, pool_stride_y = 0;
+    std::tie(pool_pad_x, pool_pad_y)       = _pool_info.pad_stride_info().pad();
+    std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info().stride();
+    const int upper_bound_w = _input->info()->dimension(0) + pool_pad_x;
+    const int upper_bound_h = _input->info()->dimension(1) + pool_pad_y;
+
+    const unsigned char *const input_top_ptr    = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_x), -static_cast<int>(pool_pad_y)));
+    const unsigned char *const input_bottom_ptr = _input->ptr_to_element(Coordinates(-static_cast<int>(pool_pad_x), -static_cast<int>(pool_pad_y) + 1));
+
+    execute_window_loop(window, [&](const Coordinates & id)
+    {
+        const auto  top_data    = vld2q_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
+        const auto  bottom_data = vld2q_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
+        float16x8_t res         = {};
+
+        if(pooling_type == PoolingType::AVG)
+        {
+            const float       scale   = calculate_avg_scale(id, pool_size, upper_bound_w, upper_bound_h, pool_pad_x, pool_pad_y, pool_stride_x, pool_stride_y);
+            const float16x8_t scale_v = vdupq_n_f16(scale);
+            res                       = vmulq_f16(scale_v, vaddq_f16(bottom_data.val[1], vaddq_f16(bottom_data.val[0], vaddq_f16(top_data.val[0], top_data.val[1]))));
+        }
+        else
+        {
+            res = vmaxq_f16(bottom_data.val[1], vmaxq_f16(bottom_data.val[0], vmaxq_f16(top_data.val[0], top_data.val[1])));
+        }
+        vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), res);
+    },
+    input, output);
+#else  /* ARM_COMPUTE_ENABLE_FP16 */
+    ARM_COMPUTE_UNUSED(window_input);
+    ARM_COMPUTE_UNUSED(window);
+    ARM_COMPUTE_ERROR("FP16 Not supported! Recompile the library with arch=arm64-v8.2-a");
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
+
+template <PoolingType pooling_type>
 void NEPoolingLayerKernel::pooling2_f32(const Window &window_input, const Window &window)
 {
     Iterator input(_input, window_input);
@@ -496,19 +624,29 @@
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
     ARM_COMPUTE_ERROR_ON(_func == nullptr);
 
-    unsigned int pool_stride_x, pool_stride_y = 0;
-    std::tie(pool_stride_x, pool_stride_y)    = _pool_info.pad_stride_info().stride();
+    const unsigned int pool_stride_x = _pool_info.pad_stride_info().stride().first;
+    const unsigned int pool_stride_y = _pool_info.pad_stride_info().stride().second;
 
     // Set step for input in x and y direction for the input
     Window       window_input(window);
     unsigned int window_x_inc = 0;
-    if(_input->info()->data_type() == DataType::QS8)
+    switch(_input->info()->data_type())
     {
-        window_x_inc = (pool_stride_x == 2) ? _num_elems_processed_per_iteration * 2 : _num_elems_processed_per_iteration;
-    }
-    else
-    {
-        window_x_inc = pool_stride_x;
+        case DataType::QS8:
+        case DataType::F16:
+        {
+            window_x_inc = (pool_stride_x == 2) ? _num_elems_processed_per_iteration * 2 : _num_elems_processed_per_iteration;
+            break;
+        }
+        case DataType::F32:
+        {
+            window_x_inc = pool_stride_x;
+            break;
+        }
+        default:
+        {
+            ARM_COMPUTE_ERROR("Not supported");
+        }
     }
     window_input.set(Window::DimX, Window::Dimension(window.x().start() * pool_stride_x, window.x().end() * pool_stride_x, window_x_inc));
     window_input.set(Window::DimY, Window::Dimension(window.y().start() * pool_stride_y, window.y().end() * pool_stride_y, pool_stride_y));