Support multi-dimensional indices in the CL Gather Layer up to four-dimensional output tensors

Resolves [COMPMID-5775]

Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I6f6c12ac08f0b0ad070ca5d715c531c2c3762c30
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9498
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/common/gather.cl b/src/core/CL/cl_kernels/common/gather.cl
index a47c8a7..5d180f3 100644
--- a/src/core/CL/cl_kernels/common/gather.cl
+++ b/src/core/CL/cl_kernels/common/gather.cl
@@ -59,34 +59,80 @@
  */
 __kernel void gather(
     TENSOR4D_DECLARATION(input),
-    VECTOR_DECLARATION(indices),
+    TENSOR4D_DECLARATION(indices),
     TENSOR4D_DECLARATION(output))
 {
     const int px = get_global_id(0);
     const int py = get_global_id(1);
     const int pz = get_global_id(2) % OUTPUT_DIM_Z;
     const int pw = get_global_id(2) / OUTPUT_DIM_Z;

     const Tensor4D input   = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(input, INPUT_DIM_Z);
-    const Vector   indices = CONVERT_TO_VECTOR_STRUCT_NO_STEP(indices);
+    const Tensor4D indices = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(indices, INDICES_DIM_Z);
     Tensor4D       output  = CONVERT_TO_TENSOR4D_STRUCT(output, OUTPUT_DIM_Z);

+    // Output shape = input[0:AXIS] ++ indices shape ++ input[AXIS+1:]. The output coordinates that
+    // span the indices tensor locate the index; the remaining ones address the other input dimensions.
 #if AXIS == 0
-    const uint index                 = *(__global const uint *)vector_offset(&indices, px);
+#if INDICES_DIMS == 1
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, px, 0, 0, 0);
     const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
     __global const uchar *input_addr = tensor4D_offset(&input, safe_index, py, pz, pw);
+#elif INDICES_DIMS == 2
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, px, py, 0, 0);
+    const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, safe_index, pz, pw, 0);
+#elif INDICES_DIMS == 3
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, px, py, pz, 0);
+    const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, safe_index, pw, 0, 0);
+#elif INDICES_DIMS == 4
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, px, py, pz, pw);
+    const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, safe_index, 0, 0, 0);
+#else
+#error "INDICES_DIMS not supported for AXIS == 0"
+#endif //INDICES_DIMS
+
 #elif AXIS == 1
-    const uint index                 = *(__global const uint *)vector_offset(&indices, py);
+#if INDICES_DIMS == 1
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, py, 0, 0, 0);
     const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
     __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, pz, pw);
+#elif INDICES_DIMS == 2
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, py, pz, 0, 0);
+    const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, pw, 0);
+#elif INDICES_DIMS == 3
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, py, pz, pw, 0);
+    const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, 0, 0);
+#else
+#error "INDICES_DIMS not supported for AXIS == 1"
+#endif //INDICES_DIMS
+
 #elif AXIS == 2
-    const uint index                 = *(__global const uint *)vector_offset(&indices, pz);
+#if INDICES_DIMS == 1
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, pz, 0, 0, 0);
     const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
     __global const uchar *input_addr = tensor4D_offset(&input, px, py, safe_index, pw);
+#elif INDICES_DIMS == 2
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, pz, pw, 0, 0);
+    const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, px, py, safe_index, 0);
+#else
+#error "INDICES_DIMS not supported for AXIS == 2"
+#endif //INDICES_DIMS
+
 #elif AXIS == 3
-    const uint index                 = *(__global const uint *)vector_offset(&indices, pw);
+#if INDICES_DIMS == 1
+    const uint index                 = *(__global const uint *)tensor4D_offset(&indices, pw, 0, 0, 0);
     const uint safe_index            = select((uint)0, index, index < INDEX_LIMIT);
     __global const uchar *input_addr = tensor4D_offset(&input, px, py, pz, safe_index);
+#else
+#error "INDICES_DIMS not supported for AXIS == 3"
+#endif //INDICES_DIMS
+
 #endif //AXIS

     *(__global DATA_TYPE *)output.ptr = select((DATA_TYPE)0, *((__global const DATA_TYPE *)input_addr), (DATA_TYPE)(index < INDEX_LIMIT));
diff --git a/src/core/CL/kernels/CLGatherKernel.cpp b/src/core/CL/kernels/CLGatherKernel.cpp
index 31a9a3b..5495023 100644
--- a/src/core/CL/kernels/CLGatherKernel.cpp
+++ b/src/core/CL/kernels/CLGatherKernel.cpp
@@ -38,8 +38,11 @@
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output);
     const uint32_t actual_axis = wrap_around(axis, static_cast<int>(input->num_dimensions()));
-    ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 1);
+    // The CL kernel handles 1D-4D indices and at most 4D input/output; the output rank is
+    // input rank + indices rank - 1, since the gathered axis is replaced by the indices shape.
+    ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() == 0);
     ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
+    ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() + indices->num_dimensions() - 1) > 4);
     ARM_COMPUTE_RETURN_ERROR_ON(actual_axis >= input->num_dimensions());
     ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
 
@@ -102,7 +105,9 @@
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(data_size_from_type(input->info()->data_type())));
     build_opts.add_option("-DOUTPUT_DIM_Z=" + support::cpp11::to_string(output->info()->dimension(2)));
+    build_opts.add_option("-DINDICES_DIM_Z=" + support::cpp11::to_string(indices->info()->dimension(2)));
     build_opts.add_option("-DINPUT_DIM_Z=" + support::cpp11::to_string(input->info()->dimension(2)));
+    build_opts.add_option("-DINDICES_DIMS=" + support::cpp11::to_string(indices->info()->num_dimensions()));
     build_opts.add_option("-DAXIS=" + support::cpp11::to_string(_axis));
     build_opts.add_option("-DINDEX_LIMIT=" + support::cpp11::to_string(input->info()->tensor_shape()[_axis]));
 
@@ -127,7 +132,7 @@
     Window       window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
     unsigned int idx              = 0;
     add_4D_tensor_argument(idx, _input, window_collapsed);
-    add_1D_tensor_argument(idx, _indices, window_collapsed);
+    add_4D_tensor_argument(idx, _indices, window_collapsed);
     add_4D_tensor_argument(idx, _output, window_collapsed);
     enqueue(queue, *this, window_collapsed, lws_hint());
 }