blob: 8b0972f99afd27aeb1bc0e747c9081c24aaac639 [file] [log] [blame]
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +00001/*
2 * Copyright (c) 2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ACL_TESTS_DATASETS_SCATTERDATASET_H
25#define ACL_TESTS_DATASETS_SCATTERDATASET_H
26
27#include "arm_compute/core/TensorShape.h"
28#include "utils/TypePrinter.h"
29
30namespace arm_compute
31{
32namespace test
33{
34namespace datasets
35{
36
37class ScatterDataset
38{
39public:
40 using type = std::tuple<TensorShape, TensorShape, TensorShape, TensorShape>;
41
42 struct iterator
43 {
44 iterator(std::vector<TensorShape>::const_iterator src_it,
45 std::vector<TensorShape>::const_iterator updates_it,
46 std::vector<TensorShape>::const_iterator indices_it,
47 std::vector<TensorShape>::const_iterator dst_it)
48 : _src_it{ std::move(src_it) },
49 _updates_it{ std::move(updates_it) },
50 _indices_it{std::move(indices_it)},
51 _dst_it{ std::move(dst_it) }
52 {
53 }
54
55 std::string description() const
56 {
57 std::stringstream description;
58 description << "A=" << *_src_it << ":";
59 description << "B=" << *_updates_it << ":";
60 description << "C=" << *_indices_it << ":";
61 description << "Out=" << *_dst_it << ":";
62 return description.str();
63 }
64
65 ScatterDataset::type operator*() const
66 {
67 return std::make_tuple(*_src_it, *_updates_it, *_indices_it, *_dst_it);
68 }
69
70 iterator &operator++()
71 {
72 ++_src_it;
73 ++_updates_it;
74 ++_indices_it;
75 ++_dst_it;
76
77 return *this;
78 }
79
80 private:
81 std::vector<TensorShape>::const_iterator _src_it;
82 std::vector<TensorShape>::const_iterator _updates_it;
83 std::vector<TensorShape>::const_iterator _indices_it;
84 std::vector<TensorShape>::const_iterator _dst_it;
85 };
86
87 iterator begin() const
88 {
89 return iterator(_src_shapes.begin(), _update_shapes.begin(), _indices_shapes.begin(), _dst_shapes.begin());
90 }
91
92 int size() const
93 {
94 return std::min(_src_shapes.size(), std::min(_indices_shapes.size(), std::min(_update_shapes.size(), _dst_shapes.size())));
95 }
96
97 void add_config(TensorShape a, TensorShape b, TensorShape c, TensorShape dst)
98 {
99 _src_shapes.emplace_back(std::move(a));
100 _update_shapes.emplace_back(std::move(b));
101 _indices_shapes.emplace_back(std::move(c));
102 _dst_shapes.emplace_back(std::move(dst));
103 }
104
105protected:
106 ScatterDataset() = default;
107 ScatterDataset(ScatterDataset &&) = default;
108
109private:
110 std::vector<TensorShape> _src_shapes{};
111 std::vector<TensorShape> _update_shapes{};
112 std::vector<TensorShape> _indices_shapes{};
113 std::vector<TensorShape> _dst_shapes{};
114};
115
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +0100116
117// 1D dataset for simple scatter tests.
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000118class Small1DScatterDataset final : public ScatterDataset
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000119{
120public:
Mohammed Suhail Munshi473b8292024-03-18 12:13:30 +0000121 Small1DScatterDataset()
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000122 {
Mohammed Suhail Munshi73771072024-03-25 15:55:42 +0000123 add_config(TensorShape(6U), TensorShape(6U), TensorShape(1U, 6U), TensorShape(6U));
124 add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000125 }
126};
Mohammed Suhail Munshi0e212362024-04-08 14:38:31 +0100127
128// This dataset represents the (m+1)-D updates/dst case.
129class SmallScatterMultiDimDataset final : public ScatterDataset
130{
131public:
132 SmallScatterMultiDimDataset()
133 {
134 // NOTE: Config is src, updates, indices, output.
135 // - In this config, the dim replaced is the final number (largest tensor dimension)
136 // - Largest "updates" dim should match y-dim of indices.
137 // - src/updates/dst should all have same number of dims. Indices should be 2D.
138 add_config(TensorShape(6U, 5U), TensorShape(6U, 2U), TensorShape(1U, 2U), TensorShape(6U, 5U));
139 add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
140 add_config(TensorShape(3U, 2U, 4U, 2U), TensorShape(3U, 2U, 4U, 2U), TensorShape(1U, 2U), TensorShape(3U, 2U, 4U, 2U));
141 }
142};
143
144// This dataset represents the (m+1)-D updates tensor, (m+n)-d output tensor cases
145class SmallScatterMultiIndicesDataset final : public ScatterDataset
146{
147public:
148 SmallScatterMultiIndicesDataset()
149 {
150 // NOTE: Config is src, updates, indices, output.
151 // NOTE: indices.shape.x = src.num_dimensions - updates.num_dimensions + 1
152
153 // index length is 2
154 add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 4U), TensorShape(2U, 4U), TensorShape(6U, 5U, 2U));
155 add_config(TensorShape(17U, 3U, 3U, 2U), TensorShape(17U, 3U, 2U), TensorShape(2U, 2U), TensorShape(17U, 3U, 3U, 2U));
156 add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U));
157 add_config(TensorShape(5U, 4U, 3U, 3U, 2U, 4U), TensorShape(5U, 4U, 3U, 3U, 5U), TensorShape(2U, 5U), TensorShape(5U, 4U, 3U, 3U, 2U, 4U));
158
159 // index length is 3
160 add_config(TensorShape(4U, 3U, 2U, 2U), TensorShape(4U, 2U), TensorShape(3U, 2U), TensorShape(4U, 3U, 2U, 2U));
161 add_config(TensorShape(17U, 4U, 3U, 2U, 2U), TensorShape(17U, 4U, 4U), TensorShape(3U, 4U), TensorShape(17U, 4U, 3U, 2U, 2U));
162 add_config(TensorShape(10U, 4U, 5U, 3U, 2U, 2U), TensorShape(10U, 4U, 5U, 3U), TensorShape(3U, 3U), TensorShape(10U, 4U, 5U, 3U, 2U, 2U));
163
164 // index length is 4
165 add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U));
166 add_config(TensorShape(10U, 4U, 5U, 3U, 2U, 2U), TensorShape(10U, 4U, 3U), TensorShape(4U, 3U), TensorShape(10U, 4U, 5U, 3U, 2U, 2U));
167
168 // index length is 5
169 add_config(TensorShape(10U, 4U, 5U, 3U, 2U, 2U), TensorShape(10U, 3U), TensorShape(5U, 3U), TensorShape(10U, 4U, 5U, 3U, 2U, 2U));
170 }
171};
172
173// This dataset represents the (m+k)-D updates tensor, (k+1)-d indices tensor and (m+n)-d output tensor cases
174class SmallScatterBatchedDataset final : public ScatterDataset
175{
176public:
177 SmallScatterBatchedDataset()
178 {
179 // NOTE: Config is src, updates, indices, output.
180 // NOTE: Updates/Indices tensors are now batched.
181 // NOTE: indices.shape.x = (updates_batched) ? (src.num_dimensions - updates.num_dimensions) + 2 : (src.num_dimensions - updates.num_dimensions) + 1
182 add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U), TensorShape(1U, 2U, 2U), TensorShape(6U, 5U));
183 add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U));
184 add_config(TensorShape(6U, 5U, 2U, 2U), TensorShape(3U, 2U), TensorShape(4U, 3U, 2U), TensorShape(6U, 5U, 2U, 2U));
185 add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(6U, 2U), TensorShape(5U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
186 }
187};
Mohammed Suhail Munshi8609ca02024-02-29 17:00:07 +0000188} // namespace datasets
189} // namespace test
190} // namespace arm_compute
191#endif // ACL_TESTS_DATASETS_SCATTERDATASET_H