Support multi-dimensional indices in the CL Gather Layer up to four-dimensional output tensors

Resolves [COMPMID-5775]

Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I6f6c12ac08f0b0ad070ca5d715c531c2c3762c30
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9498
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/validation/CL/Gather.cpp b/tests/validation/CL/Gather.cpp
index f0b87d7..7619baa 100644
--- a/tests/validation/CL/Gather.cpp
+++ b/tests/validation/CL/Gather.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -48,19 +48,21 @@
         framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 27U), 1, DataType::F16),
                                                 TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),     // Invalid Indices data type
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),     // Invalid Indices dimensionality
-                                                TensorInfo(TensorShape(5U, 5U, 5U, 5U, 5U), 1, DataType::F32),    // Invalid Input dimensionality
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F16),     // Mismatching data type input/output
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),     // Invalid positive axis value
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F16),     // Invalid negative axis value
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),                // Invalid Output shape
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),                // Invalid Indices data type
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),                // Invalid Indices dimensionality
+                                                TensorInfo(TensorShape(5U, 5U, 5U, 5U, 5U), 1, DataType::F32),      // Invalid Input dimensionality
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F16),                // Mismatching data type input/output
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),                // Invalid positive axis value
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F16),                // Invalid negative axis value
         }),
         framework::dataset::make("IndicesInfo", {
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
-                                                TensorInfo(TensorShape(10U), 1, DataType::U8),
                                                 TensorInfo(TensorShape(10U, 10U), 1, DataType::U32),
+                                                TensorInfo(TensorShape(10U), 1, DataType::U8),
+                                                TensorInfo(TensorShape(10U, 10U, 10U, 10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
@@ -71,7 +73,8 @@
                                                 TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
-                                                TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
+                                                TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
+                                                TensorInfo(TensorShape(27U, 10U, 10U, 10U, 10U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(10U, 5U, 5U, 5U, 5U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
@@ -82,13 +85,14 @@
                                             1,
                                             -2,
                                             0,
+                                            0,
                                             1,
                                             0,
                                             1,
                                             2,
                                             -3,
         })),
-        framework::dataset::make("Expected", { true, true, true, false, false, false, false, false, false })),
+        framework::dataset::make("Expected", { true, true, true, false, false, false, false, false, false, false })),
         input_info, indices_info, output_info, axis, expected)
 {
     const Status status = CLGather::validate(&input_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), axis);
@@ -111,6 +115,15 @@
     validate(CLAccessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+                       CLGatherFixture<half>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::F16)))
+{
+    // F16 run: check the fixture's CL result (_target, read through CLAccessor) against its reference output (_reference)
+    validate(CLAccessor(_target), _reference);
+}
+
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        CLGatherFixture<half>,
                        framework::DatasetMode::NIGHTLY,
@@ -131,6 +144,15 @@
     validate(CLAccessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+                       CLGatherFixture<float>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::F32)))
+{
+    // F32 run: check the fixture's CL result (_target, read through CLAccessor) against its reference output (_reference)
+    validate(CLAccessor(_target), _reference);
+}
+
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        CLGatherFixture<float>,
                        framework::DatasetMode::NIGHTLY,
@@ -152,6 +174,15 @@
     validate(CLAccessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+                       CLGatherFixture<uint8_t>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::U8)))
+{
+    // U8 run: check the fixture's CL result (_target, read through CLAccessor) against its reference output (_reference)
+    validate(CLAccessor(_target), _reference);
+}
+
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        CLGatherFixture<uint8_t>,
                        framework::DatasetMode::NIGHTLY,
@@ -172,6 +203,16 @@
     validate(CLAccessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+                       CLGatherFixture<uint16_t>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::U16)))
+{
+    // U16 run: check the fixture's CL result (_target, read through CLAccessor) against its reference output (_reference)
+    validate(CLAccessor(_target), _reference);
+}
+
+
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        CLGatherFixture<uint16_t>,
                        framework::DatasetMode::NIGHTLY,