Support multi-dimensional indices in the CL Gather Layer for output tensors of up to four dimensions
Resolves [COMPMID-5775]
Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I6f6c12ac08f0b0ad070ca5d715c531c2c3762c30
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9498
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/kernels/CLGatherKernel.cpp b/src/core/CL/kernels/CLGatherKernel.cpp
index 31a9a3b..5495023 100644
--- a/src/core/CL/kernels/CLGatherKernel.cpp
+++ b/src/core/CL/kernels/CLGatherKernel.cpp
@@ -38,8 +38,8 @@
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output);
const uint32_t actual_axis = wrap_around(axis, static_cast<int>(input->num_dimensions()));
- ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 1);
- ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
+ ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() + indices->num_dimensions() - 1) > 4);
+
ARM_COMPUTE_RETURN_ERROR_ON(actual_axis >= input->num_dimensions());
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
@@ -102,7 +102,9 @@
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(data_size_from_type(input->info()->data_type())));
build_opts.add_option("-DOUTPUT_DIM_Z=" + support::cpp11::to_string(output->info()->dimension(2)));
+ build_opts.add_option("-DINDICES_DIM_Z=" + support::cpp11::to_string(indices->info()->dimension(2)));
build_opts.add_option("-DINPUT_DIM_Z=" + support::cpp11::to_string(input->info()->dimension(2)));
+ build_opts.add_option("-DINDICES_DIMS=" + support::cpp11::to_string(indices->info()->num_dimensions()));
build_opts.add_option("-DAXIS=" + support::cpp11::to_string(_axis));
build_opts.add_option("-DINDEX_LIMIT=" + support::cpp11::to_string(input->info()->tensor_shape()[_axis]));
@@ -127,7 +129,7 @@
Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
unsigned int idx = 0;
add_4D_tensor_argument(idx, _input, window_collapsed);
- add_1D_tensor_argument(idx, _indices, window_collapsed);
+ add_4D_tensor_argument(idx, _indices, window_collapsed);
add_4D_tensor_argument(idx, _output, window_collapsed);
enqueue(queue, *this, window_collapsed, lws_hint());
}