COMPMID-1546 Optimize PoolingLayer NHWC on NEON for all data types

Change-Id: I4920e43059a713126f15493f38fe50f07d0a8c7f
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/151087
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index 1fa8f47..fdd3410 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -200,7 +200,7 @@
             case DataType::QASYMM8:
                 if(is_nhwc)
                 {
-                    num_elems_processed_per_iteration = 8;
+                    num_elems_processed_per_iteration = 16;
                     break;
                 }
                 switch(pool_size_x)
@@ -271,8 +271,7 @@
     {
         if(is_nhwc)
         {
-            const unsigned int vector_size    = 16 / input->element_size();
-            num_elems_processed_per_iteration = (input->data_type() == DataType::QASYMM8) ? 8 : vector_size;
+            num_elems_processed_per_iteration = 16 / input->element_size();
         }
     }
 
@@ -1552,8 +1551,16 @@
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        const int idx_width  = id.y() * pool_stride_x;
-        const int idx_height = id.z() * pool_stride_y;
+        const int idx_width    = id.y() * pool_stride_x;
+        const int idx_height   = id.z() * pool_stride_y;
+        const int pool_limit_y = pool_pad_top - idx_height;
+        const int pool_limit_x = pool_pad_left - idx_width;
+
+        const int pool_start_y = std::max(0, window_input.z().start() + pool_limit_y);
+        const int pool_end_y   = std::min(pool_size_y, window_input.z().end() + pool_limit_y);
+        const int pool_start_x = std::max(0, window_input.y().start() + pool_limit_x);
+        const int pool_end_x   = std::min(pool_size_x, window_input.y().end() + pool_limit_x);
+
         if(pooling_type != PoolingType::MAX)
         {
             // Calculate scale
@@ -1563,21 +1570,10 @@
 
             // Perform pooling
             vres = vdupq_n_f16(0.0f);
-
-            for(int y = 0; y < pool_size_y; ++y)
+            for(int y = pool_start_y; y < pool_end_y; ++y)
             {
-                if(y + idx_height - pool_pad_top >= window_input.z().end() || y + idx_height - pool_pad_top < window_input.z().start())
+                for(int x = pool_start_x; x < pool_end_x; ++x)
                 {
-                    continue;
-                }
-
-                for(int x = 0; x < pool_size_x; ++x)
-                {
-                    if(x + idx_width - pool_pad_left >= window_input.y().end() || x + idx_width - pool_pad_left < window_input.y().start())
-                    {
-                        continue;
-                    }
-
                     const float16x8_t data = vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().y() +
                                                                                            (y - pool_pad_top) * _input->info()->strides_in_bytes().z()));
 
@@ -1598,20 +1594,11 @@
         else
         {
             vres = vdupq_n_f16(std::numeric_limits<float>::lowest());
-            for(int y = 0; y < pool_size_y; ++y)
+
+            for(int y = pool_start_y; y < pool_end_y; ++y)
             {
-                if(y + idx_height > window_input.z().end() || y + idx_height - pool_pad_top < window_input.z().start())
+                for(int x = pool_start_x; x < pool_end_x; ++x)
                 {
-                    continue;
-                }
-
-                for(int x = 0; x < pool_size_x; ++x)
-                {
-                    if(x + idx_width > window_input.y().end() || x + idx_width - pool_pad_left < window_input.y().start())
-                    {
-                        continue;
-                    }
-
                     const float16x8_t data = vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().y() +
                                                                                            (y - pool_pad_top) * _input->info()->strides_in_bytes().z()));
                     vres                   = vmaxq_f16(vres, data);
@@ -1783,8 +1770,16 @@
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        const int idx_width  = id.y() * pool_stride_x;
-        const int idx_height = id.z() * pool_stride_y;
+        const int idx_width    = id.y() * pool_stride_x;
+        const int idx_height   = id.z() * pool_stride_y;
+        const int pool_limit_y = pool_pad_top - idx_height;
+        const int pool_limit_x = pool_pad_left - idx_width;
+
+        const int pool_start_y = std::max(0, window_input.z().start() + pool_limit_y);
+        const int pool_end_y   = std::min(pool_size_y, window_input.z().end() + pool_limit_y);
+        const int pool_start_x = std::max(0, window_input.y().start() + pool_limit_x);
+        const int pool_end_x   = std::min(pool_size_x, window_input.y().end() + pool_limit_x);
+
         if(pooling_type != PoolingType::MAX)
         {
             // Calculate scale
@@ -1795,20 +1790,10 @@
             // Perform pooling
             vres = vdupq_n_f32(0.0f);
 
-            for(int y = 0; y < pool_size_y; ++y)
+            for(int y = pool_start_y; y < pool_end_y; ++y)
             {
-                if(y + idx_height - pool_pad_top >= window_input.z().end() || y + idx_height - pool_pad_top < window_input.z().start())
+                for(int x = pool_start_x; x < pool_end_x; ++x)
                 {
-                    continue;
-                }
-
-                for(int x = 0; x < pool_size_x; ++x)
-                {
-                    if(x + idx_width - pool_pad_left >= window_input.y().end() || x + idx_width - pool_pad_left < window_input.y().start())
-                    {
-                        continue;
-                    }
-
                     const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().y() +
                                                                                        (y - pool_pad_top) * _input->info()->strides_in_bytes().z()));
 
@@ -1829,20 +1814,10 @@
         else
         {
             vres = vdupq_n_f32(std::numeric_limits<float>::lowest());
-            for(int y = 0; y < pool_size_y; ++y)
+            for(int y = pool_start_y; y < pool_end_y; ++y)
             {
-                if(y + idx_height - pool_pad_top >= window_input.z().end() || y + idx_height - pool_pad_top < window_input.z().start())
+                for(int x = pool_start_x; x < pool_end_x; ++x)
                 {
-                    continue;
-                }
-
-                for(int x = 0; x < pool_size_x; ++x)
-                {
-                    if(x + idx_width - pool_pad_left >= window_input.y().end() || x + idx_width - pool_pad_left < window_input.y().start())
-                    {
-                        continue;
-                    }
-
                     const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().y() +
                                                                                        (y - pool_pad_top) * _input->info()->strides_in_bytes().z()));
                     vres                   = vmaxq_f32(vres, data);
@@ -1979,12 +1954,22 @@
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        const int idx_width  = id.y() * pool_stride_x;
-        const int idx_height = id.z() * pool_stride_y;
+        const int idx_width    = id.y() * pool_stride_x;
+        const int idx_height   = id.z() * pool_stride_y;
+        const int pool_limit_y = pool_pad_top - idx_height;
+        const int pool_limit_x = pool_pad_left - idx_width;
+
+        const int pool_start_y = std::max(0, window_input.z().start() + pool_limit_y);
+        const int pool_end_y   = std::min(pool_size_y, window_input.z().end() + pool_limit_y);
+        const int pool_start_x = std::max(0, window_input.y().start() + pool_limit_x);
+        const int pool_end_x   = std::min(pool_size_x, window_input.y().end() + pool_limit_x);
+
         if(pooling_type != PoolingType::MAX)
         {
             uint32x4_t vres1 = vdupq_n_u32(0);
             uint32x4_t vres2 = vdupq_n_u32(0);
+            uint32x4_t vres3 = vdupq_n_u32(0);
+            uint32x4_t vres4 = vdupq_n_u32(0);
 
             // Calculate scale
             const float scale = calculate_avg_scale<exclude_padding, DataLayout::NHWC>(id, pool_size_x, pool_size_y, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x,
@@ -1992,63 +1977,50 @@
             const float32x4_t scale_v = vdupq_n_f32(scale);
 
             // Perform pooling
-            for(int y = 0; y < pool_size_y; ++y)
+            for(int y = pool_start_y; y < pool_end_y; ++y)
             {
-                if(y + idx_height - pool_pad_top >= window_input.z().end() || y + idx_height - pool_pad_top < window_input.z().start())
+                for(int x = pool_start_x; x < pool_end_x; ++x)
                 {
-                    continue;
-                }
+                    const uint8x16_t data = vld1q_u8(reinterpret_cast<const uint8_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().y() +
+                                                                                       (y - pool_pad_top) * _input->info()->strides_in_bytes().z()));
 
-                for(int x = 0; x < pool_size_x; ++x)
-                {
-                    if(x + idx_width - pool_pad_left >= window_input.y().end() || x + idx_width - pool_pad_left < window_input.y().start())
-                    {
-                        continue;
-                    }
-
-                    const uint8x8_t data = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().y() +
-                                                                                     (y - pool_pad_top) * _input->info()->strides_in_bytes().z()));
-
-                    const uint16x8_t data_u16 = vmovl_u8(data);
-                    vres1                     = vaddq_u32(vres1, vmovl_u16(vget_low_u16(data_u16)));
-                    vres2                     = vaddq_u32(vres2, vmovl_u16(vget_high_u16(data_u16)));
+                    const uint16x8_t data_u16  = vmovl_u8(vget_low_u8(data));
+                    const uint16x8_t data2_u16 = vmovl_u8(vget_high_u8(data));
+                    vres1                      = vaddq_u32(vres1, vmovl_u16(vget_low_u16(data_u16)));
+                    vres2                      = vaddq_u32(vres2, vmovl_u16(vget_high_u16(data_u16)));
+                    vres3                      = vaddq_u32(vres3, vmovl_u16(vget_low_u16(data2_u16)));
+                    vres4                      = vaddq_u32(vres4, vmovl_u16(vget_high_u16(data2_u16)));
                 }
             }
             // Divide by scale
             vres1 = vcvtq_u32_f32(vmulq_f32(vcvtq_f32_u32(vres1), scale_v));
             vres2 = vcvtq_u32_f32(vmulq_f32(vcvtq_f32_u32(vres2), scale_v));
+            vres3 = vcvtq_u32_f32(vmulq_f32(vcvtq_f32_u32(vres3), scale_v));
+            vres4 = vcvtq_u32_f32(vmulq_f32(vcvtq_f32_u32(vres4), scale_v));
 
-            uint8x8_t res = vmovn_u16(vcombine_u16(vmovn_u32(vres1), vmovn_u32(vres2)));
+            uint8x8_t res1 = vmovn_u16(vcombine_u16(vmovn_u32(vres1), vmovn_u32(vres2)));
+            uint8x8_t res2 = vmovn_u16(vcombine_u16(vmovn_u32(vres3), vmovn_u32(vres4)));
 
             // Store result
-            vst1_u8(output.ptr(), res);
+            vst1_u8(output.ptr(), res1);
+            vst1_u8(output.ptr() + 8, res2);
         }
         else
         {
-            uint8x8_t vres = vdup_n_u8(0);
+            uint8x16_t vres = vdupq_n_u8(0);
 
-            for(int y = 0; y < pool_size_y; ++y)
+            for(int y = pool_start_y; y < pool_end_y; ++y)
             {
-                if(y + idx_height - pool_pad_top >= window_input.z().end() || y + idx_height - pool_pad_top < window_input.z().start())
+                for(int x = pool_start_x; x < pool_end_x; ++x)
                 {
-                    continue;
-                }
-
-                for(int x = 0; x < pool_size_x; ++x)
-                {
-                    if(x + idx_width - pool_pad_left >= window_input.y().end() || x + idx_width - pool_pad_left < window_input.y().start())
-                    {
-                        continue;
-                    }
-
-                    const uint8x8_t data = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().y() +
-                                                                                     (y - pool_pad_top) * _input->info()->strides_in_bytes().z()));
-                    vres                 = vmax_u8(vres, data);
+                    const uint8x16_t data = vld1q_u8(reinterpret_cast<const uint8_t *>(input.ptr() + (x - pool_pad_left) * _input->info()->strides_in_bytes().y() +
+                                                                                       (y - pool_pad_top) * _input->info()->strides_in_bytes().z()));
+                    vres                  = vmaxq_u8(vres, data);
                 }
             }
 
             // Store result
-            vst1_u8(output.ptr(), vres);
+            vst1q_u8(output.ptr(), vres);
         }
     },
     input, output);