Support multi-dimensional indices in the CL Gather layer for output tensors of up to four dimensions
Resolves [COMPMID-5775]
Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I6f6c12ac08f0b0ad070ca5d715c531c2c3762c30
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9498
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/datasets/GatherDataset.h b/tests/datasets/GatherDataset.h
index 487ce19..74ea3b4 100644
--- a/tests/datasets/GatherDataset.h
+++ b/tests/datasets/GatherDataset.h
@@ -126,6 +126,44 @@
}
};
+class CLSmallGatherMultiDimIndicesDataset final : public GatherDataset
+{
+public:
+    CLSmallGatherMultiDimIndicesDataset()
+    {
+        // Entries are (input shape, indices shape, axis). The output rank is rank(input) - 1 + rank(indices), kept <= 4.
+        // Every input has rank > axis: a larger axis is wrapped modulo the rank and would just repeat a lower-axis case.
+        // Axis 0.
+        add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 0);
+        add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 0);
+        add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 0);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 0);
+        add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 0);
+        add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 0);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 0);
+        add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U), 0);
+
+        // Axis 1 (all inputs have rank >= 2).
+        add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 1);
+        add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 1);
+        add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 1);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 1);
+        add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 1);
+        add_config(TensorShape(9U, 5U), TensorShape(3U, 2U, 4U), 1);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 1);
+        add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U), 1);
+
+        // Axis 2: inputs must be rank 3, so indices are limited to rank 2
+        // to keep the output at rank 4.
+        add_config(TensorShape(2U, 6U, 3U), TensorShape(4U, 9U), 2);
+        add_config(TensorShape(15U, 15U, 4U), TensorShape(2U, 11U), 2);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 2);
+        add_config(TensorShape(3U, 5U, 6U), TensorShape(2U, 3U), 2);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 2);
+        add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U), 2);
+    }
+};
+
class SmallGatherDataset final : public GatherDataset
{
public: