MLCE-166: Add support for extracting indices in NEPoolingLayer 2x2 NHWC

     * Added support for extracting pooling indices in NHWC for pool size 2x2
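
     The validate step no longer rejects NHWC inputs when indices are requested.
     A dedicated path, pooling2_f32_nhwc_maxpool_indices, computes the 2x2 max
     pooling result and, for each output element, stores the offset of the input
     element that produced the maximum. The candidates are compared lane-wise
     with vcgtq_f32 and the winning offset is picked with vbslq_u32. Below is a
     minimal standalone sketch of that selection step (requires an Arm target
     with NEON; the values and offsets are illustrative only, not taken from the
     kernel):

         // Per-lane selection of the offset of the larger of two candidates,
         // mirroring the selection done for the 2x2 window in the kernel.
         #include <arm_neon.h>
         #include <cstdint>
         #include <cstdio>

         int main()
         {
             const float32x4_t v_a = { 1.f, 5.f, 3.f, 7.f };
             const float32x4_t v_b = { 2.f, 4.f, 6.f, 0.f };

             // Hypothetical element offsets of the two candidates in each lane
             const uint32x4_t off_a = { 0, 1, 2, 3 };
             const uint32x4_t off_b = { 8, 9, 10, 11 };

             // Where v_a > v_b keep off_a, otherwise off_b
             const uint32x4_t mask    = vcgtq_f32(v_a, v_b);
             const uint32x4_t indices = vbslq_u32(mask, off_a, off_b);

             uint32_t out[4];
             vst1q_u32(out, indices);
             std::printf("%u %u %u %u\n", out[0], out[1], out[2], out[3]); // prints: 8 1 10 3
             return 0;
         }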

Change-Id: Ib2a3468e794f58bbf2c03aba9f6b184b9d76b183
Signed-off-by: morgolock <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2997
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index fdbba81..6d61f51 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -156,7 +156,7 @@
         if(indices)
         {
             ARM_COMPUTE_RETURN_ERROR_ON_MSG((pool_size != Size2D(2, 2)), "Pooling indices only supported for pool size 2x2");
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC, "Pool indices only supported in NCHW");
+
             ARM_COMPUTE_RETURN_ERROR_ON((indices->dimension(get_data_layout_dimension_index(indices->data_layout(), DataLayoutDimension::WIDTH)) != pooled_w)
                                         || (indices->dimension(get_data_layout_dimension_index(indices->data_layout(), DataLayoutDimension::HEIGHT)) != pooled_h));
         }
@@ -183,7 +183,9 @@
     if(indices)
     {
         // Indices auto initialization if not yet initialized
-        auto_init_if_empty(*indices, (input->clone()->set_tensor_shape(compute_pool_shape(*input, pool_info))).set_data_type(DataType::U32) /* we store the offset to the element */);
+        auto_init_if_empty(*indices, (input->clone()->set_tensor_shape(compute_pool_shape(*input,
+                                                                                          pool_info)))
+                           .set_data_type(DataType::U32) /* we store the offset to the element */);
     }
     const auto          data_layout                  = pool_info.data_layout == DataLayout::UNKNOWN ? input->data_layout() : pool_info.data_layout;
     unsigned int        num_elems_read_per_iteration = 0;
@@ -1751,23 +1753,125 @@
 
 void NEPoolingLayerKernel::poolingMxN_f32_nhwc(const Window &window_input, const Window &window, PoolingType pooling_type, bool exclude_padding)
 {
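+    // Route 2x2 max pooling with indices to the dedicated path below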
+    if(_pool_info.pool_size == Size2D(2, 2) && pooling_type == PoolingType::MAX && _indices)
+    {
+        pooling2_f32_nhwc_maxpool_indices(window_input, window);
+    }
+    else
+    {
+        Iterator input(_input, window_input);
+        Iterator output(_output, window);
+
+        const int pool_size_x     = _pool_info.is_global_pooling ? _input->info()->tensor_shape().y() : _pool_info.pool_size.width;
+        const int pool_size_y     = _pool_info.is_global_pooling ? _input->info()->tensor_shape().z() : _pool_info.pool_size.height;
+        const int pool_pad_right  = _pool_info.pad_stride_info.pad_right();
+        const int pool_pad_top    = _pool_info.pad_stride_info.pad_top();
+        const int pool_pad_left   = _pool_info.pad_stride_info.pad_left();
+        const int pool_pad_bottom = _pool_info.pad_stride_info.pad_bottom();
+        int       pool_stride_x   = 0;
+        int       pool_stride_y   = 0;
+        std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info.stride();
+        const int upper_bound_w = _input->info()->dimension(1) + (exclude_padding ? 0 : pool_pad_right);
+        const int upper_bound_h = _input->info()->dimension(2) + (exclude_padding ? 0 : pool_pad_bottom);
+
+        float32x4_t vres;
+
+        execute_window_loop(window, [&](const Coordinates & id)
+        {
+            const int idx_width    = id.y() * pool_stride_x;
+            const int idx_height   = id.z() * pool_stride_y;
+            const int pool_limit_y = pool_pad_top - idx_height;
+            const int pool_limit_x = pool_pad_left - idx_width;
+
+            const int pool_start_y = std::max(0, window_input.z().start() + pool_limit_y);
+            const int pool_end_y   = std::min(pool_size_y, window_input.z().end() + pool_limit_y);
+            const int pool_start_x = std::max(0, window_input.y().start() + pool_limit_x);
+            const int pool_end_x   = std::min(pool_size_x, window_input.y().end() + pool_limit_x);
+
+            if(pooling_type != PoolingType::MAX)
+            {
+                // Calculate scale
+                const float scale = calculate_avg_scale(exclude_padding, DataLayout::NHWC, id, pool_size_x, pool_size_y, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x,
+                                                        pool_stride_y);
+                const float32x4_t scale_v = vdupq_n_f32(scale);
+
+                // Perform pooling
+                vres = vdupq_n_f32(0.0f);
+
+                for(int y = pool_start_y; y < pool_end_y; ++y)
+                {
+                    for(int x = pool_start_x; x < pool_end_x; ++x)
+                    {
+                        const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast<int>
+                                                                                           (_input->info()->strides_in_bytes().z())));
+
+                        // Get power of 2 in case of l2 pooling and accumulate
+                        if(pooling_type == PoolingType::L2)
+                        {
+                            vres = vmlaq_f32(vres, data, data);
+                        }
+                        else
+                        {
+                            vres = vaddq_f32(vres, data);
+                        }
+                    }
+                }
+                // Divide by scale
+                vres = vmulq_f32(vres, scale_v);
+            }
+            else
+            {
+                vres = vdupq_n_f32(std::numeric_limits<float>::lowest());
+                for(int y = pool_start_y; y < pool_end_y; ++y)
+                {
+                    for(int x = pool_start_x; x < pool_end_x; ++x)
+                    {
+                        const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast<int>
+                                                                                           (_input->info()->strides_in_bytes().z())));
+                        vres                   = vmaxq_f32(vres, data);
+                    }
+                }
+            }
+
+            // Calculate square-root in case of l2 pooling
+            if(pooling_type == PoolingType::L2)
+            {
+                float32x4_t l2_res = { static_cast<float>(sqrt(vgetq_lane_f32(vres, 0))),
+                                       static_cast<float>(sqrt(vgetq_lane_f32(vres, 1))),
+                                       static_cast<float>(sqrt(vgetq_lane_f32(vres, 2))),
+                                       static_cast<float>(sqrt(vgetq_lane_f32(vres, 3)))
+                                     };
+                vres = l2_res;
+            }
+
+            // Store result
+            vst1q_f32(reinterpret_cast<float *>(output.ptr()), vres);
+        },
+        input, output);
+    }
+}
+
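+// 2x2 FP32 NHWC max pooling that also stores, for each output element, the offset of
+// the input element holding the maximum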
+void NEPoolingLayerKernel::pooling2_f32_nhwc_maxpool_indices(const Window &window_input, const Window &window)
+{
     Iterator input(_input, window_input);
     Iterator output(_output, window);
+    Iterator indices(_indices, window);
 
-    const int pool_size_x     = _pool_info.is_global_pooling ? _input->info()->tensor_shape().y() : _pool_info.pool_size.width;
-    const int pool_size_y     = _pool_info.is_global_pooling ? _input->info()->tensor_shape().z() : _pool_info.pool_size.height;
-    const int pool_pad_right  = _pool_info.pad_stride_info.pad_right();
-    const int pool_pad_top    = _pool_info.pad_stride_info.pad_top();
-    const int pool_pad_left   = _pool_info.pad_stride_info.pad_left();
-    const int pool_pad_bottom = _pool_info.pad_stride_info.pad_bottom();
-    int       pool_stride_x   = 0;
-    int       pool_stride_y   = 0;
+    const int pool_pad_top  = _pool_info.pad_stride_info.pad_top();
+    const int pool_pad_left = _pool_info.pad_stride_info.pad_left();
+
+    int pool_stride_x = 0;
+    int pool_stride_y = 0;
     std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info.stride();
-    const int upper_bound_w = _input->info()->dimension(1) + (exclude_padding ? 0 : pool_pad_right);
-    const int upper_bound_h = _input->info()->dimension(2) + (exclude_padding ? 0 : pool_pad_bottom);
 
     float32x4_t vres;
 
+    const int pad_right   = _input->info()->padding().right;
+    const int pad_top     = _input->info()->padding().top;
+    const int in_stride_y = static_cast<int>(_input->info()->strides_in_bytes().y());
+    const int in_stride_z = static_cast<int>(_input->info()->strides_in_bytes().z());
+    const int in_stride_w = static_cast<int>(_input->info()->strides_in_bytes()[3]);
+
     execute_window_loop(window, [&](const Coordinates & id)
     {
         const int idx_width    = id.y() * pool_stride_x;
@@ -1776,70 +1880,53 @@
         const int pool_limit_x = pool_pad_left - idx_width;
 
         const int pool_start_y = std::max(0, window_input.z().start() + pool_limit_y);
-        const int pool_end_y   = std::min(pool_size_y, window_input.z().end() + pool_limit_y);
         const int pool_start_x = std::max(0, window_input.y().start() + pool_limit_x);
-        const int pool_end_x   = std::min(pool_size_x, window_input.y().end() + pool_limit_x);
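+        // Byte offsets (relative to the iterator) of the four elements of the 2x2 window:
+        // x0 = (start_x, start_y), x1 = (start_x + 1, start_y),
+        // x2 = (start_x, start_y + 1), x3 = (start_x + 1, start_y + 1)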
+        const int in_x0_offset = (pool_start_x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (pool_start_y - pool_pad_top) * static_cast<int>
+                                 (_input->info()->strides_in_bytes().z());
+        const int in_x1_offset = (pool_start_x + 1 - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (pool_start_y - pool_pad_top) * static_cast<int>
+                                 (_input->info()->strides_in_bytes().z());
 
-        if(pooling_type != PoolingType::MAX)
-        {
-            // Calculate scale
-            const float scale = calculate_avg_scale(exclude_padding, DataLayout::NHWC, id, pool_size_x, pool_size_y, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x,
-                                                    pool_stride_y);
-            const float32x4_t scale_v = vdupq_n_f32(scale);
+        const int in_x2_offset = (pool_start_x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (pool_start_y + 1 - pool_pad_top) * static_cast<int>
+                                 (_input->info()->strides_in_bytes().z());
 
-            // Perform pooling
-            vres = vdupq_n_f32(0.0f);
+        const int in_x3_offset = (pool_start_x + 1 - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (pool_start_y + 1 - pool_pad_top) * static_cast<int>
+                                 (_input->info()->strides_in_bytes().z());
 
-            for(int y = pool_start_y; y < pool_end_y; ++y)
-            {
-                for(int x = pool_start_x; x < pool_end_x; ++x)
-                {
-                    const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast<int>
-                                                                                       (_input->info()->strides_in_bytes().z())));
-
-                    // Get power of 2 in case of l2 pooling and accumulate
-                    if(pooling_type == PoolingType::L2)
-                    {
-                        vres = vmlaq_f32(vres, data, data);
-                    }
-                    else
-                    {
-                        vres = vaddq_f32(vres, data);
-                    }
-                }
-            }
-            // Divide by scale
-            vres = vmulq_f32(vres, scale_v);
-        }
-        else
-        {
-            vres = vdupq_n_f32(std::numeric_limits<float>::lowest());
-            for(int y = pool_start_y; y < pool_end_y; ++y)
-            {
-                for(int x = pool_start_x; x < pool_end_x; ++x)
-                {
-                    const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast<int>
-                                                                                       (_input->info()->strides_in_bytes().z())));
-                    vres                   = vmaxq_f32(vres, data);
-                }
-            }
-        }
-
-        // Calculate square-root in case of l2 pooling
-        if(pooling_type == PoolingType::L2)
-        {
-            float32x4_t l2_res = { static_cast<float>(sqrt(vgetq_lane_f32(vres, 0))),
-                                   static_cast<float>(sqrt(vgetq_lane_f32(vres, 1))),
-                                   static_cast<float>(sqrt(vgetq_lane_f32(vres, 2))),
-                                   static_cast<float>(sqrt(vgetq_lane_f32(vres, 3)))
-                                 };
-            vres = l2_res;
-        }
-
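+        // Load the four 4-channel vectors of the window and reduce them with vmaxq_f32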
+        const auto in_x0_ptr = reinterpret_cast<const float *>(input.ptr() + in_x0_offset);
+        const auto in_x1_ptr = reinterpret_cast<const float *>(input.ptr() + in_x1_offset);
+        const auto in_x2_ptr = reinterpret_cast<const float *>(input.ptr() + in_x2_offset);
+        const auto in_x3_ptr = reinterpret_cast<const float *>(input.ptr() + in_x3_offset);
+        const auto v_x0      = vld1q_f32(in_x0_ptr);
+        const auto v_x1      = vld1q_f32(in_x1_ptr);
+        const auto v_x2      = vld1q_f32(in_x2_ptr);
+        const auto v_x3      = vld1q_f32(in_x3_ptr);
+        vres                 = vmaxq_f32(vmaxq_f32(v_x2, v_x3), vmaxq_f32(v_x0, v_x1));
         // Store result
         vst1q_f32(reinterpret_cast<float *>(output.ptr()), vres);
+
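+        // Strip the padding contributions from the iterator's byte offset so that, once
+        // divided by sizeof(float), it becomes an element index into the unpadded input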
+        const uint32_t offset_base = input.offset()
+                                     - sizeof(float) * pad_right * id.y() * pool_stride_x                                     /* subtract padding elems per row */
+                                     - pad_top * sizeof(float)                                                                /* top padding */
+                                     - sizeof(float) * pad_right * _input->info()->tensor_shape()[1] * id.z() * pool_stride_y /* for each Z plane there are width*pad_right padding elems */
+                                     - in_stride_w * id[3] + _input->info()->tensor_shape()[0] * sizeof(float) * id[3];
+
+        const uint32_t offset_x0 = (uint32_t)offset_base / sizeof(float);
+        const uint32_t offset_x1 = (uint32_t)offset_x0 + in_stride_y / sizeof(float) - pad_right;
+        const uint32_t offset_x2 = (uint32_t)offset_x0 + in_stride_z / sizeof(float) - pad_right * _input->info()->tensor_shape()[1];
+        const uint32_t offset_x3 = (uint32_t)offset_x2 + in_stride_y / sizeof(float) - pad_right;
+
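+        // Element indices of the four candidates in each lane; select the index of the
+        // larger value at each step, mirroring the vmaxq_f32 reduction above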
+        const uint32x4_t voffset_x0   = { offset_x0, offset_x0 + 1, offset_x0 + 2, offset_x0 + 3 };
+        const uint32x4_t voffset_x1   = { offset_x1, offset_x1 + 1, offset_x1 + 2, offset_x1 + 3 };
+        const uint32x4_t voffset_x2   = { offset_x2, offset_x2 + 1, offset_x2 + 2, offset_x2 + 3 };
+        const uint32x4_t voffset_x3   = { offset_x3, offset_x3 + 1, offset_x3 + 2, offset_x3 + 3 };
+        const uint32x4_t tmp_indices0 = vbslq_u32(vcgtq_f32(v_x0, v_x1), voffset_x0, voffset_x1);
+        const uint32x4_t tmp_indices1 = vbslq_u32(vcgtq_f32(v_x2, v_x3), voffset_x2, voffset_x3);
+        const uint32x4_t tmp_indices2 = vbslq_u32(vcgtq_f32(vmaxq_f32(v_x0, v_x1), vmaxq_f32(v_x2, v_x3)), tmp_indices0, tmp_indices1);
+
+        vst1q_u32(reinterpret_cast<uint32_t *>(indices.ptr()), tmp_indices2);
+
     },
-    input, output);
+    input, output, indices);
 }
 
 template <typename T>