Add support for 2d and 3d indices for axis 1

* Resolves COMPMID-5055

Change-Id: I2d14de29d3ec913d20c971bc8bbc9ad71e2d998f
Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7547
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index df907c1..9f9f53e 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1494,15 +1494,53 @@
     return output_shape;
 }
 
+/** Calculate the gather output shape of a tensor
+ *
+ * @param[in] input_shape   Input tensor shape
+ * @param[in] indices_shape Indices tensor shape. Only supports for 2d and 3d indices
+ * @param[in] actual_axis   Axis to be used in the computation
+ *
+ * @note Let input_shape be (X,Y,Z) and indices shape (W,O,P) and axis 1
+ *       the new shape is computed by replacing the axis in the input shape with
+ *       the indice shape so the output shape will be (X,W,O,P,Z)
+ *
+ * @return the calculated shape
+ */
 inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)
 {
-    ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 1);
     ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4);
     ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions());
-
-    TensorShape output_shape  = input_shape;
-    output_shape[actual_axis] = indices_shape[0];
-
+    ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3);
+    TensorShape output_shape = input_shape;
+    if(indices_shape.num_dimensions() == 1u)
+    {
+        output_shape[actual_axis] = indices_shape[0];
+    }
+    else
+    {
+        const auto ind_num_dims
+        {
+            indices_shape.num_dimensions()
+        };
+        output_shape.shift_right(ind_num_dims - 1);
+        switch(actual_axis)
+        {
+            case 1:
+            {
+                output_shape[0] = input_shape[0];
+                for(size_t idx = 0; idx < ind_num_dims; ++idx)
+                {
+                    output_shape.set(actual_axis + idx, indices_shape[idx], false);
+                }
+                break;
+            }
+            default:
+            {
+                // 2d and 3d indices are only supported for axis == 1
+                ARM_COMPUTE_ERROR_ON(actual_axis != 1 && indices_shape.num_dimensions() > 1);
+            }
+        }
+    }
     return output_shape;
 }
 } // namespace shape_calculator
diff --git a/arm_compute/runtime/NEON/functions/NEGather.h b/arm_compute/runtime/NEON/functions/NEGather.h
index 393a38e..8253e98 100644
--- a/arm_compute/runtime/NEON/functions/NEGather.h
+++ b/arm_compute/runtime/NEON/functions/NEGather.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -49,18 +49,17 @@
      * |All            |All            |
      *
      * @param[in]  input   Source tensor. Supported tensor rank: up to 4. Data type supported: All
-     * @param[in]  indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
+     * @param[in]  indices Indices tensor. Supported tensor rank: up to 3. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
+     *                     @note The "axis" must be in the range [0, input.rank -1] when indices is a vector, and must be 1 when indices is a 2D or 3D tensor.
      * @param[out] output  Destination tensor. Data type supported: Same as @p input
      * @param[in]  axis    (Optional) The axis in @p input to gather @p indices from. Defaults to 0
+     *
      */
     void configure(const ITensor *input, const ITensor *indices, ITensor *output, int axis = 0);
 
-    /** Static function to check if given info will lead to a valid configuration of @ref NEGatherKernel
+    /** Static function to check if given info will lead to a valid configuration
      *
-     * @param[in] input   Source tensor info. Supported tensor rank: up to 4. Data type supported: All
-     * @param[in] indices Indices tensor info. Supported tensor rank: up to 1. Must be one of the following types: U32/S32. Each value Must be in range [0, input.shape[@p axis])
-     * @param[in] output  Destination tensor info. Data type supported: Same as @p input
-     * @param[in] axis    (Optional) The axis in @p input to gather @p indices from. Defaults to 0
+     * Similar to @ref NEGather::configure()
      *
      * @return a status
      */
diff --git a/src/core/NEON/kernels/NEGatherKernel.cpp b/src/core/NEON/kernels/NEGatherKernel.cpp
index 55c4525..085ab7c 100644
--- a/src/core/NEON/kernels/NEGatherKernel.cpp
+++ b/src/core/NEON/kernels/NEGatherKernel.cpp
@@ -44,19 +44,23 @@
  *
  * @param[in] indices Indices tensor info.
  */
+
 template <typename U>
 void validate_indices(const ITensor *indices)
 {
-    for(size_t i = 0; i < indices->info()->tensor_shape()[0]; ++i)
+    Window window;
+    window.use_tensor_dimensions(indices->info()->tensor_shape());
+    execute_window_loop(window, [&](const Coordinates & id)
     {
-        ARM_COMPUTE_ERROR_ON(*(reinterpret_cast<U *>(indices->ptr_to_element(Coordinates(i)))) < 0);
-    }
+        const auto i = *(reinterpret_cast<int32_t *>(indices->ptr_to_element(id)));
+        ARM_COMPUTE_UNUSED(i);
+        ARM_COMPUTE_ERROR_ON(i < 0);
+    });
 }
 
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output);
-    ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 1);
     ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
 
     if(axis < 0)
@@ -65,6 +69,7 @@
     }
 
     ARM_COMPUTE_RETURN_ERROR_ON(0 > axis || axis >= static_cast<int32_t>(input->num_dimensions()));
+    ARM_COMPUTE_RETURN_ERROR_ON(axis != 1 && indices->num_dimensions() > 1);
     ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
 
     if(output->total_size() != 0)
@@ -87,6 +92,37 @@
 }
 
 template <typename U>
+inline void NEGatherKernel::gather_multiindices_1_axis(const Window &window, const ThreadInfo &info)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_ERROR_ON(_indices->info()->num_dimensions() < 2 || _indices->info()->num_dimensions() > 3);
+    validate_indices<U>(_indices);
+    Window win = window;
+    win.set(Window::DimX, Window::Dimension(0, 1, 1));
+    execute_window_loop(win, [&](const Coordinates & id)
+    {
+        auto       *dst_ptr = _output->ptr_to_element(id);
+        Coordinates index_offset;
+        for(uint32_t k = 0; k < _indices->info()->num_dimensions(); ++k)
+        {
+            index_offset.set(k, id[k + 1]);
+        }
+        const uint32_t row = *(reinterpret_cast<uint32_t *>(_indices->ptr_to_element(index_offset)));
+        Coordinates    src_offset;
+        // Set up input coords to read the row specified by the current index
+        src_offset.set(0, 0);
+        src_offset.set(1, row);
+        for(uint32_t j = 2; j < _input->info()->num_dimensions(); ++j)
+        {
+            src_offset.set(j, id[1 + _indices->info()->num_dimensions() + (j - 2)]);
+        }
+        const auto in_ptr_row = _input->ptr_to_element(src_offset);
+        // Copy a row from input to output
+        memcpy(dst_ptr, in_ptr_row, _input->info()->tensor_shape()[0] * _input->info()->element_size());
+    });
+}
+
+template <typename U>
 inline void NEGatherKernel::gather_0_axis(const Window &window, const ThreadInfo &info)
 {
     ARM_COMPUTE_UNUSED(info);
@@ -147,38 +183,64 @@
     }
     ARM_COMPUTE_ERROR_ON(0 > _axis || _axis >= static_cast<int32_t>(input->info()->num_dimensions()));
 
-    if(0 == _axis)
+    if(indices->info()->num_dimensions() == 1u)
     {
-        switch(_indices->info()->data_type())
+        if(_axis == 0)
         {
-            case DataType::U32:
-                _func = &NEGatherKernel::gather_0_axis<uint32_t>;
-                break;
-            case DataType::S32:
-                _func = &NEGatherKernel::gather_0_axis<int32_t>;
-                break;
-            default:
-                ARM_COMPUTE_ERROR("Not supported");
-                break;
+            switch(_indices->info()->data_type())
+            {
+                case DataType::U32:
+                    _func = &NEGatherKernel::gather_0_axis<uint32_t>;
+                    break;
+                case DataType::S32:
+                    _func = &NEGatherKernel::gather_0_axis<int32_t>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
+        }
+        else
+        {
+            switch(_indices->info()->data_type())
+            {
+                case DataType::U32:
+                    _func = &NEGatherKernel::gather_n_axis<uint32_t>;
+                    break;
+                case DataType::S32:
+                    _func = &NEGatherKernel::gather_n_axis<int32_t>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
         }
     }
     else
     {
-        switch(_indices->info()->data_type())
+        if(_axis == 1)
         {
-            case DataType::U32:
-                _func = &NEGatherKernel::gather_n_axis<uint32_t>;
-                break;
-            case DataType::S32:
-                _func = &NEGatherKernel::gather_n_axis<int32_t>;
-                break;
-            default:
-                ARM_COMPUTE_ERROR("Not supported");
-                break;
+            switch(_indices->info()->data_type())
+            {
+                case DataType::U32:
+                    _func = &NEGatherKernel::gather_multiindices_1_axis<uint32_t>;
+                    break;
+                case DataType::S32:
+                    _func = &NEGatherKernel::gather_multiindices_1_axis<int32_t>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
+        }
+        else
+        {
+            ARM_COMPUTE_ERROR("Not supported");
         }
     }
+
     // Output auto initialization if not yet initialized
-    TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
+    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
     auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
 
     // Create window
diff --git a/src/core/NEON/kernels/NEGatherKernel.h b/src/core/NEON/kernels/NEGatherKernel.h
index 6f00ddb..3dc0cad 100644
--- a/src/core/NEON/kernels/NEGatherKernel.h
+++ b/src/core/NEON/kernels/NEGatherKernel.h
@@ -61,17 +61,17 @@
     /** Initialise the kernel's inputs and outputs
      *
      * @param[in]  input   Source tensor. Supported tensor rank: up to 4. Data type supported: All
-     * @param[in]  indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
+     * @param[in]  indices Indices tensor. Supported tensor rank: up to 3. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
+     *                     @note 2D or 3D indices are only supported for the axis 1.
      * @param[out] output  Destination tensor. Data type supported: Same as @p input
-     * @param[in]  axis    (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0
+     * @param[in]  axis    (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0.
+     *
      */
     void configure(const ITensor *input, const ITensor *indices, ITensor *output, int axis = 0);
-    /** Static function to check if given info will lead to a valid configuration of @ref NEGatherKernel
+
+    /** Static function to check if given info will lead to a valid configuration
      *
-     * @param[in] input   Source tensor info. Supported tensor rank: up to 4. Data type supported: All
-     * @param[in] indices Indices tensor info. Supported tensor rank: up to 1. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
-     * @param[in] output  Destination tensor info. Data type supported: Same as @p input
-     * @param[in] axis    (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0
+     * Similar to @ref NEGatherKernel::configure()
      *
      * @return a status
      */
@@ -85,18 +85,20 @@
      *
      * For gather on the 0 axis an element by element copy is performed.
      *
-     * @param[in] window Region on which to execute the kernel. (Must be a region of the window returned by window())
-     * @param[in] info   Info about executing thread and CPU.
+     * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window())
+     * @param[in] info   Info about running thread and CPU.
      */
     template <typename U>
     void gather_0_axis(const Window &window, const ThreadInfo &info);
 
+    template <typename U>
+    void gather_multiindices_1_axis(const Window &window, const ThreadInfo &info);
     /** Implementation of the gather operation.
      *
      * For 1<=axis a row-wise copy is taking place.
      *
-     * @param[in] window Region on which to execute the kernel. (Must be a region of the window returned by window())
-     * @param[in] info   Info about executing thread and CPU.
+     * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window())
+     * @param[in] info   Info about running thread and CPU.
      */
     template <typename U>
     void gather_n_axis(const Window &window, const ThreadInfo &info);
diff --git a/tests/datasets/GatherDataset.h b/tests/datasets/GatherDataset.h
index 444b62c..8fec544 100644
--- a/tests/datasets/GatherDataset.h
+++ b/tests/datasets/GatherDataset.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2019, 2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -106,6 +106,19 @@
     std::vector<int>         _axis{};
 };
 
+class SmallGatherMultiDimIndicesDataset final : public GatherDataset
+{
+public:
+    SmallGatherMultiDimIndicesDataset()
+    {
+        add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 1);
+        add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 1);
+        add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 1);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 1);
+        add_config(TensorShape(1U, 5U, 3U), TensorShape(1U, 7U, 3U), 1);
+    }
+};
+
 class SmallGatherDataset final : public GatherDataset
 {
 public:
diff --git a/tests/validation/NEON/Gather.cpp b/tests/validation/NEON/Gather.cpp
index 71f98ea..0aea199 100644
--- a/tests/validation/NEON/Gather.cpp
+++ b/tests/validation/NEON/Gather.cpp
@@ -100,12 +100,14 @@
 template <typename T>
 using NEGatherFixture = GatherFixture<Tensor, Accessor, NEGather, T>;
 
+const auto gather_small_shapes = arm_compute::test::framework::dataset::concat(datasets::SmallGatherDataset(), datasets::SmallGatherMultiDimIndicesDataset());
+
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        NEGatherFixture<half>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(datasets::SmallGatherDataset(), framework::dataset::make("DataType", DataType::F16)))
+                       combine(gather_small_shapes, framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
@@ -125,7 +127,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        NEGatherFixture<float>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(datasets::SmallGatherDataset(), framework::dataset::make("DataType", DataType::F32)))
+                       combine(gather_small_shapes, framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
@@ -146,7 +148,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        NEGatherFixture<uint8_t>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(datasets::SmallGatherDataset(), framework::dataset::make("DataType", DataType::U8)))
+                       combine(gather_small_shapes, framework::dataset::make("DataType", DataType::U8)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
@@ -166,7 +168,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        NEGatherFixture<uint16_t>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(datasets::SmallGatherDataset(), framework::dataset::make("DataType", DataType::U16)))
+                       combine(gather_small_shapes, framework::dataset::make("DataType", DataType::U16)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp
index c264388..8de1a47 100644
--- a/tests/validation/reference/Gather.cpp
+++ b/tests/validation/reference/Gather.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2019, 2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -45,22 +45,55 @@
 
     Window win;
     win.use_tensor_dimensions(dst_shape);
-    execute_window_loop(win, [&](const Coordinates & id)
+    if(indices.shape().num_dimensions() == 1u)
     {
-        Coordinates offset;
-        for(unsigned int dim = 0; dim < id.num_dimensions(); ++dim)
+        execute_window_loop(win, [&](const Coordinates & id)
         {
-            if(dim == actual_axis)
+            Coordinates offset;
+            for(unsigned int dim = 0; dim < id.num_dimensions(); ++dim)
             {
-                offset.set(dim, indices_ptr[id[dim]]);
+                if(dim == actual_axis)
+                {
+                    offset.set(dim, indices_ptr[id[dim]]);
+                }
+                else
+                {
+                    offset.set(dim, id[dim]);
+                }
             }
-            else
+            *reinterpret_cast<T *>(dst(id)) = *reinterpret_cast<const T *>(src(offset));
+        });
+    }
+    else
+    {
+        if(actual_axis == 1)
+        {
+            win.set(Window::DimX, Window::Dimension(0, 1, 1));
+            execute_window_loop(win, [&](const Coordinates & id)
             {
-                offset.set(dim, id[dim]);
-            }
+                auto       *dst_ptr = dst(id);
+                Coordinates index_offset;
+                for(uint32_t k = 0; k < indices.shape().num_dimensions(); ++k)
+                {
+                    index_offset.set(k, id[k + 1]);
+                }
+                const uint32_t row = *reinterpret_cast<const uint32_t *>(indices(index_offset));
+                Coordinates    src_offset;
+                src_offset.set(0, 0);
+                src_offset.set(1, row);
+                for(uint32_t j = 2; j < src.shape().num_dimensions(); ++j)
+                {
+                    src_offset.set(j, id[1 + indices.shape().num_dimensions() + (j - 2)]);
+                }
+                const auto in_ptr_row = src(src_offset);
+                memcpy(dst_ptr, in_ptr_row, src.shape()[0] * src.element_size());
+            });
         }
-        *reinterpret_cast<T *>(dst(id)) = *reinterpret_cast<const T *>(src(offset));
-    });
+        else
+        {
+            ARM_COMPUTE_ERROR("Not implemented.");
+        }
+    }
 
     return dst;
 }
@@ -72,4 +105,4 @@
 } // namespace reference
 } // namespace validation
 } // namespace test
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute