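COMPMID-1534: Fix 2x2 NEPoolingLayer for FP16

Rework the FP16 pooling kernel touched here so that each iteration loads
four half-precision values per row with vld1_f16, combines the two rows,
reduces the result with a pairwise add (average/L2) or pairwise max, and
stores a single output element, instead of de-interleaving sixteen values
with vld2q_f16 and storing eight results.

To match this, the 2x2 case of the window-configuration switch now falls
through to case 3, the F16 window X increment shares the F32 case, and the
NHWC elements-per-iteration count is derived from the element size for
non-QASYMM8 types. The pool-size restriction is removed from
validate_arguments() together with its pool_size_x parameter.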

Change-Id: Icaf45cad826bb0966a6c663ecb7e828f5fe5e5db
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/145336
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index 2ca6090..ad4b8f7 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -129,7 +129,7 @@
     v = vsetq_lane_u16(elems[7], v, 7);
 }
 
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info, unsigned int &pooled_w, unsigned int pooled_h, int pool_size_x)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info, unsigned int &pooled_w, unsigned int pooled_h)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
 
@@ -138,15 +138,11 @@
     PoolingType         pool_type       = pool_info.pool_type();
     const PadStrideInfo pad_stride_info = pool_info.pad_stride_info();
     std::tie(pool_stride_x, pool_stride_y) = pad_stride_info.stride();
-    static const std::set<int> supported_pool_sizes = { 2, 3 };
 
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON(pool_type == PoolingType::L2 && is_data_type_quantized(input->data_type()));
 
-    ARM_COMPUTE_RETURN_ERROR_ON((supported_pool_sizes.find(pool_size_x) == supported_pool_sizes.end()) && ((input->data_type() != DataType::F32) && (input->data_type() != DataType::QASYMM8))
-                                && (pool_type != PoolingType::MAX));
-
     if(output->total_size() != 0)
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
@@ -239,10 +235,6 @@
                 switch(pool_size_x)
                 {
                     case 2:
-                        num_elems_read_per_iteration      = 16;
-                        num_elems_processed_per_iteration = 8;
-                        num_elems_horizontal_window       = 8;
-                        break;
                     case 3:
                         num_elems_read_per_iteration      = 4;
                         num_elems_processed_per_iteration = 1;
@@ -285,14 +277,8 @@
     {
         if(is_nhwc)
         {
-            if(DataType::QASYMM8 == input->data_type())
-            {
-                num_elems_processed_per_iteration = 8;
-            }
-            else
-            {
-                num_elems_processed_per_iteration = 4;
-            }
+            const unsigned int vector_size    = 16 / input->element_size();
+            num_elems_processed_per_iteration = (input->data_type() == DataType::QASYMM8) ? 8 : vector_size;
         }
     }
 
@@ -389,7 +375,7 @@
     auto_init(input->info(), output->info(), pooled_w, pooled_h);
 
     // Perform validation step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), pool_info, pooled_w, pooled_h, pool_size_x));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), pool_info, pooled_w, pooled_h));
 
     // Set instance variables
     _input     = input;
@@ -1053,38 +1039,39 @@
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        auto        top_data    = vld2q_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
-        auto        bottom_data = vld2q_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
-        float16x8_t res         = {};
+        float16x4_t top_data    = vld1_f16(reinterpret_cast<const float16_t *>(input_top_ptr + input.offset()));
+        float16x4_t bottom_data = vld1_f16(reinterpret_cast<const float16_t *>(input_bottom_ptr + input.offset()));
+        float16x4_t res         = {};
 
         // Get power of 2 in case of l2 pooling
         if(pooling_type == PoolingType::L2)
         {
-            top_data.val[0]    = vmulq_f16(top_data.val[0], top_data.val[0]);
-            top_data.val[1]    = vmulq_f16(top_data.val[1], top_data.val[1]);
-            bottom_data.val[0] = vmulq_f16(bottom_data.val[0], bottom_data.val[0]);
-            bottom_data.val[1] = vmulq_f16(bottom_data.val[1], bottom_data.val[1]);
+            top_data    = vmul_f16(top_data, top_data);
+            bottom_data = vmul_f16(bottom_data, bottom_data);
         }
 
         if(pooling_type != PoolingType::MAX)
         {
             const float       scale   = calculate_avg_scale<exclude_padding, DataLayout::NCHW>(id, pool_size, pool_size, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x, pool_stride_y);
-            const float16x8_t scale_v = vdupq_n_f16(scale);
-            res                       = vmulq_f16(scale_v, vaddq_f16(bottom_data.val[1], vaddq_f16(bottom_data.val[0], vaddq_f16(top_data.val[0], top_data.val[1]))));
+            const float16x4_t scale_v = vdup_n_f16(scale);
+
+            const float16x4_t sum_data = vadd_f16(top_data, bottom_data);
+            res                        = vmul_f16(vpadd_f16(sum_data, sum_data), scale_v);
         }
         else
         {
-            res = vmaxq_f16(bottom_data.val[1], vmaxq_f16(bottom_data.val[0], vmaxq_f16(top_data.val[0], top_data.val[1])));
+            const float16x4_t max_data = vmax_f16(top_data, bottom_data);
+            res                        = vpmax_f16(max_data, max_data);
         }
 
         // Calculate square-root in case of l2 pooling
         if(pooling_type == PoolingType::L2)
         {
-            res = vinvq_f16(vinvsqrtq_f16(res));
+            res = vinv_f16(vinvsqrt_f16(res));
         }
 
         // Store result
-        vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), res);
+        *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(res, 0);
     },
     input, output);
 #else  /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -2107,7 +2094,7 @@
                                                      pool_size_y,
                                                      pool_info.pad_stride_info());
 
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, pool_info, pooled_w, pooled_h, pool_size_x));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, pool_info, pooled_w, pooled_h));
     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), pool_info, num_elems_processed_per_iteration, border_size, pooled_w, pooled_h,
                                                               pool_size_x, pool_size_y)
                                 .first);
@@ -2133,11 +2120,6 @@
         unsigned int window_x_inc = 0;
         switch(_input->info()->data_type())
         {
-            case DataType::F16:
-            {
-                window_x_inc = (pool_stride_x == 2) ? _num_elems_processed_per_iteration * 2 : _num_elems_processed_per_iteration;
-                break;
-            }
             case DataType::QASYMM8:
             {
                 window_x_inc = pool_stride_x;
@@ -2147,6 +2129,7 @@
                 }
                 break;
             }
+            case DataType::F16:
             case DataType::F32:
             {
                 window_x_inc = pool_stride_x;