Add support for 2D and 3D indices for axis 1

* Resolves COMPMID-5055

Change-Id: I2d14de29d3ec913d20c971bc8bbc9ad71e2d998f
Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7547
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEGatherKernel.cpp b/src/core/NEON/kernels/NEGatherKernel.cpp
index 55c4525..085ab7c 100644
--- a/src/core/NEON/kernels/NEGatherKernel.cpp
+++ b/src/core/NEON/kernels/NEGatherKernel.cpp
@@ -44,19 +44,23 @@
  *
  * @param[in] indices Indices tensor info.
  */
+
 template <typename U>
 void validate_indices(const ITensor *indices)
 {
-    for(size_t i = 0; i < indices->info()->tensor_shape()[0]; ++i)
+    Window window;
+    window.use_tensor_dimensions(indices->info()->tensor_shape());
+    execute_window_loop(window, [&](const Coordinates & id)
     {
-        ARM_COMPUTE_ERROR_ON(*(reinterpret_cast<U *>(indices->ptr_to_element(Coordinates(i)))) < 0);
-    }
+        const auto i = *(reinterpret_cast<int32_t *>(indices->ptr_to_element(id)));
+        ARM_COMPUTE_UNUSED(i);
+        ARM_COMPUTE_ERROR_ON(i < 0);
+    });
 }
 
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output);
-    ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 1);
     ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
 
     if(axis < 0)
@@ -65,6 +69,7 @@
     }
 
     ARM_COMPUTE_RETURN_ERROR_ON(0 > axis || axis >= static_cast<int32_t>(input->num_dimensions()));
+    ARM_COMPUTE_RETURN_ERROR_ON(axis != 1 && indices->num_dimensions() > 1);
     ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
 
     if(output->total_size() != 0)
@@ -87,6 +92,37 @@
 }
 
 template <typename U>
+inline void NEGatherKernel::gather_multiindices_1_axis(const Window &window, const ThreadInfo &info)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_ERROR_ON(_indices->info()->num_dimensions() < 2 || _indices->info()->num_dimensions() > 3);
+    validate_indices<U>(_indices);
+    Window win = window;
+    win.set(Window::DimX, Window::Dimension(0, 1, 1));
+    execute_window_loop(win, [&](const Coordinates & id)
+    {
+        auto       *dst_ptr = _output->ptr_to_element(id);
+        Coordinates index_offset;
+        for(uint32_t k = 0; k < _indices->info()->num_dimensions(); ++k)
+        {
+            index_offset.set(k, id[k + 1]);
+        }
+        const uint32_t row = *(reinterpret_cast<uint32_t *>(_indices->ptr_to_element(index_offset)));
+        Coordinates    src_offset;
+        // Set up input coords to read the row specified by the current index
+        src_offset.set(0, 0);
+        src_offset.set(1, row);
+        for(uint32_t j = 2; j < _input->info()->num_dimensions(); ++j)
+        {
+            src_offset.set(j, id[1 + _indices->info()->num_dimensions() + (j - 2)]);
+        }
+        const auto in_ptr_row = _input->ptr_to_element(src_offset);
+        // Copy a row from input to output
+        memcpy(dst_ptr, in_ptr_row, _input->info()->tensor_shape()[0] * _input->info()->element_size());
+    });
+}
+
+template <typename U>
 inline void NEGatherKernel::gather_0_axis(const Window &window, const ThreadInfo &info)
 {
     ARM_COMPUTE_UNUSED(info);
@@ -147,38 +183,64 @@
     }
     ARM_COMPUTE_ERROR_ON(0 > _axis || _axis >= static_cast<int32_t>(input->info()->num_dimensions()));
 
-    if(0 == _axis)
+    if(indices->info()->num_dimensions() == 1u)
     {
-        switch(_indices->info()->data_type())
+        if(_axis == 0)
         {
-            case DataType::U32:
-                _func = &NEGatherKernel::gather_0_axis<uint32_t>;
-                break;
-            case DataType::S32:
-                _func = &NEGatherKernel::gather_0_axis<int32_t>;
-                break;
-            default:
-                ARM_COMPUTE_ERROR("Not supported");
-                break;
+            switch(_indices->info()->data_type())
+            {
+                case DataType::U32:
+                    _func = &NEGatherKernel::gather_0_axis<uint32_t>;
+                    break;
+                case DataType::S32:
+                    _func = &NEGatherKernel::gather_0_axis<int32_t>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
+        }
+        else
+        {
+            switch(_indices->info()->data_type())
+            {
+                case DataType::U32:
+                    _func = &NEGatherKernel::gather_n_axis<uint32_t>;
+                    break;
+                case DataType::S32:
+                    _func = &NEGatherKernel::gather_n_axis<int32_t>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
         }
     }
     else
     {
-        switch(_indices->info()->data_type())
+        if(_axis == 1)
         {
-            case DataType::U32:
-                _func = &NEGatherKernel::gather_n_axis<uint32_t>;
-                break;
-            case DataType::S32:
-                _func = &NEGatherKernel::gather_n_axis<int32_t>;
-                break;
-            default:
-                ARM_COMPUTE_ERROR("Not supported");
-                break;
+            switch(_indices->info()->data_type())
+            {
+                case DataType::U32:
+                    _func = &NEGatherKernel::gather_multiindices_1_axis<uint32_t>;
+                    break;
+                case DataType::S32:
+                    _func = &NEGatherKernel::gather_multiindices_1_axis<int32_t>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
+        }
+        else
+        {
+            ARM_COMPUTE_ERROR("Not supported");
         }
     }
+
     // Output auto initialization if not yet initialized
-    TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
+    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
     auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
 
     // Create window
diff --git a/src/core/NEON/kernels/NEGatherKernel.h b/src/core/NEON/kernels/NEGatherKernel.h
index 6f00ddb..3dc0cad 100644
--- a/src/core/NEON/kernels/NEGatherKernel.h
+++ b/src/core/NEON/kernels/NEGatherKernel.h
@@ -61,17 +61,17 @@
     /** Initialise the kernel's inputs and outputs
      *
      * @param[in]  input   Source tensor. Supported tensor rank: up to 4. Data type supported: All
-     * @param[in]  indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
+     * @param[in]  indices Indices tensor. Supported tensor rank: up to 3. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis])
+     *                     @note 2D or 3D indices are only supported for axis 1.
      * @param[out] output  Destination tensor. Data type supported: Same as @p input
-     * @param[in]  axis    (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0
+     * @param[in]  axis    (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0.
+     *
      */
     void configure(const ITensor *input, const ITensor *indices, ITensor *output, int axis = 0);
-    /** Static function to check if given info will lead to a valid configuration of @ref NEGatherKernel
+
+    /** Static function to check if given info will lead to a valid configuration
      *
-     * @param[in] input   Source tensor info. Supported tensor rank: up to 4. Data type supported: All
-     * @param[in] indices Indices tensor info. Supported tensor rank: up to 1. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
-     * @param[in] output  Destination tensor info. Data type supported: Same as @p input
-     * @param[in] axis    (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0
+     * Similar to @ref NEGatherKernel::configure()
      *
      * @return a status
      */
@@ -85,18 +85,20 @@
      *
      * For gather on the 0 axis an element by element copy is performed.
      *
-     * @param[in] window Region on which to execute the kernel. (Must be a region of the window returned by window())
-     * @param[in] info   Info about executing thread and CPU.
+     * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window())
+     * @param[in] info   Info about running thread and CPU.
      */
     template <typename U>
     void gather_0_axis(const Window &window, const ThreadInfo &info);
 
+    template <typename U>
+    void gather_multiindices_1_axis(const Window &window, const ThreadInfo &info);
     /** Implementation of the gather operation.
      *
      * For 1<=axis a row-wise copy is taking place.
      *
-     * @param[in] window Region on which to execute the kernel. (Must be a region of the window returned by window())
-     * @param[in] info   Info about executing thread and CPU.
+     * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window())
+     * @param[in] info   Info about running thread and CPU.
      */
     template <typename U>
     void gather_n_axis(const Window &window, const ThreadInfo &info);