blob: e891419e9bb385a91b03d81516599ddc4ecf515d [file] [log] [blame]
Georgios Pinitas77589b52018-08-21 14:41:35 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_STRIDED_SLICE_DATASET
25#define ARM_COMPUTE_TEST_STRIDED_SLICE_DATASET
26
27#include "utils/TypePrinter.h"
28
29#include "arm_compute/core/Types.h"
30
31namespace arm_compute
32{
33namespace test
34{
35namespace datasets
36{
Georgios Pinitasc1a72452018-08-24 11:25:32 +010037class SliceDataset
38{
39public:
40 using type = std::tuple<TensorShape, Coordinates, Coordinates>;
41
42 struct iterator
43 {
44 iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it,
45 std::vector<Coordinates>::const_iterator starts_values_it,
46 std::vector<Coordinates>::const_iterator ends_values_it)
47 : _tensor_shapes_it{ std::move(tensor_shapes_it) },
48 _starts_values_it{ std::move(starts_values_it) },
49 _ends_values_it{ std::move(ends_values_it) }
50 {
51 }
52
53 std::string description() const
54 {
55 std::stringstream description;
56 description << "Shape=" << *_tensor_shapes_it << ":";
57 description << "Starts=" << *_starts_values_it << ":";
58 description << "Ends=" << *_ends_values_it << ":";
59 return description.str();
60 }
61
62 SliceDataset::type operator*() const
63 {
64 return std::make_tuple(*_tensor_shapes_it, *_starts_values_it, *_ends_values_it);
65 }
66
67 iterator &operator++()
68 {
69 ++_tensor_shapes_it;
70 ++_starts_values_it;
71 ++_ends_values_it;
72 return *this;
73 }
74
75 private:
76 std::vector<TensorShape>::const_iterator _tensor_shapes_it;
77 std::vector<Coordinates>::const_iterator _starts_values_it;
78 std::vector<Coordinates>::const_iterator _ends_values_it;
79 };
80
81 iterator begin() const
82 {
83 return iterator(_tensor_shapes.begin(), _starts_values.begin(), _ends_values.begin());
84 }
85
86 int size() const
87 {
88 return std::min(_tensor_shapes.size(), std::min(_starts_values.size(), _ends_values.size()));
89 }
90
91 void add_config(TensorShape shape, Coordinates starts, Coordinates ends)
92 {
93 _tensor_shapes.emplace_back(std::move(shape));
94 _starts_values.emplace_back(std::move(starts));
95 _ends_values.emplace_back(std::move(ends));
96 }
97
98protected:
99 SliceDataset() = default;
100 SliceDataset(SliceDataset &&) = default;
101
102private:
103 std::vector<TensorShape> _tensor_shapes{};
104 std::vector<Coordinates> _starts_values{};
105 std::vector<Coordinates> _ends_values{};
106};
107
Georgios Pinitas77589b52018-08-21 14:41:35 +0100108class StridedSliceDataset
109{
110public:
111 using type = std::tuple<TensorShape, Coordinates, Coordinates, BiStrides, int32_t, int32_t, int32_t>;
112
113 struct iterator
114 {
115 iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it,
116 std::vector<Coordinates>::const_iterator starts_values_it,
117 std::vector<Coordinates>::const_iterator ends_values_it,
118 std::vector<BiStrides>::const_iterator strides_values_it,
119 std::vector<int32_t>::const_iterator begin_mask_values_it,
120 std::vector<int32_t>::const_iterator end_mask_values_it,
121 std::vector<int32_t>::const_iterator shrink_mask_values_it)
122 : _tensor_shapes_it{ std::move(tensor_shapes_it) },
123 _starts_values_it{ std::move(starts_values_it) },
124 _ends_values_it{ std::move(ends_values_it) },
125 _strides_values_it{ std::move(strides_values_it) },
126 _begin_mask_values_it{ std::move(begin_mask_values_it) },
127 _end_mask_values_it{ std::move(end_mask_values_it) },
128 _shrink_mask_values_it{ std::move(shrink_mask_values_it) }
129 {
130 }
131
132 std::string description() const
133 {
134 std::stringstream description;
135 description << "Shape=" << *_tensor_shapes_it << ":";
136 description << "Starts=" << *_starts_values_it << ":";
137 description << "Ends=" << *_ends_values_it << ":";
138 description << "Strides=" << *_strides_values_it << ":";
139 description << "BeginMask=" << *_begin_mask_values_it << ":";
140 description << "EndMask=" << *_end_mask_values_it << ":";
141 description << "ShrinkMask=" << *_shrink_mask_values_it << ":";
142 return description.str();
143 }
144
145 StridedSliceDataset::type operator*() const
146 {
147 return std::make_tuple(*_tensor_shapes_it,
148 *_starts_values_it, *_ends_values_it, *_strides_values_it,
149 *_begin_mask_values_it, *_end_mask_values_it, *_shrink_mask_values_it);
150 }
151
152 iterator &operator++()
153 {
154 ++_tensor_shapes_it;
155 ++_starts_values_it;
156 ++_ends_values_it;
157 ++_strides_values_it;
158 ++_begin_mask_values_it;
159 ++_end_mask_values_it;
160 ++_shrink_mask_values_it;
161
162 return *this;
163 }
164
165 private:
166 std::vector<TensorShape>::const_iterator _tensor_shapes_it;
167 std::vector<Coordinates>::const_iterator _starts_values_it;
168 std::vector<Coordinates>::const_iterator _ends_values_it;
169 std::vector<BiStrides>::const_iterator _strides_values_it;
170 std::vector<int32_t>::const_iterator _begin_mask_values_it;
171 std::vector<int32_t>::const_iterator _end_mask_values_it;
172 std::vector<int32_t>::const_iterator _shrink_mask_values_it;
173 };
174
175 iterator begin() const
176 {
177 return iterator(_tensor_shapes.begin(),
178 _starts_values.begin(), _ends_values.begin(), _strides_values.begin(),
179 _begin_mask_values.begin(), _end_mask_values.begin(), _shrink_mask_values.begin());
180 }
181
182 int size() const
183 {
184 return std::min(_tensor_shapes.size(), std::min(_starts_values.size(), std::min(_ends_values.size(), _strides_values.size())));
185 }
186
187 void add_config(TensorShape shape,
188 Coordinates starts, Coordinates ends, BiStrides strides,
189 int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_mask = 0)
190 {
191 _tensor_shapes.emplace_back(std::move(shape));
192 _starts_values.emplace_back(std::move(starts));
193 _ends_values.emplace_back(std::move(ends));
194 _strides_values.emplace_back(std::move(strides));
195 _begin_mask_values.emplace_back(std::move(begin_mask));
196 _end_mask_values.emplace_back(std::move(end_mask));
197 _shrink_mask_values.emplace_back(std::move(shrink_mask));
198 }
199
200protected:
201 StridedSliceDataset() = default;
202 StridedSliceDataset(StridedSliceDataset &&) = default;
203
204private:
205 std::vector<TensorShape> _tensor_shapes{};
206 std::vector<Coordinates> _starts_values{};
207 std::vector<Coordinates> _ends_values{};
208 std::vector<BiStrides> _strides_values{};
209 std::vector<int32_t> _begin_mask_values{};
210 std::vector<int32_t> _end_mask_values{};
211 std::vector<int32_t> _shrink_mask_values{};
212};
213
Georgios Pinitasc1a72452018-08-24 11:25:32 +0100214class SmallSliceDataset final : public SliceDataset
215{
216public:
217 SmallSliceDataset()
218 {
219 // 1D
220 add_config(TensorShape(15U), Coordinates(4), Coordinates(9));
221 add_config(TensorShape(15U), Coordinates(0), Coordinates(-1));
222 // 2D
223 add_config(TensorShape(15U, 16U), Coordinates(0, 1), Coordinates(5, -1));
224 add_config(TensorShape(15U, 16U), Coordinates(4, 1), Coordinates(12, -1));
225 // 3D
226 add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4));
227 add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4));
228 // 4D
229 add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5));
230 }
231};
232
233class LargeSliceDataset final : public SliceDataset
234{
235public:
236 LargeSliceDataset()
237 {
238 // 1D
239 add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100));
240 // 2D
241 add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -1));
242 // 3D
243 add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, 2), Coordinates(368, -1, 4));
244 // 4D
245 add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, 17, 5));
246 }
247};
248
Georgios Pinitas77589b52018-08-21 14:41:35 +0100249class SmallStridedSliceDataset final : public StridedSliceDataset
250{
251public:
252 SmallStridedSliceDataset()
253 {
254 // 1D
255 add_config(TensorShape(15U), Coordinates(0), Coordinates(5), BiStrides(2));
256 add_config(TensorShape(15U), Coordinates(-1), Coordinates(-8), BiStrides(-2));
257 // 2D
258 add_config(TensorShape(15U, 16U), Coordinates(0, 1), Coordinates(5, -1), BiStrides(2, 1));
259 add_config(TensorShape(15U, 16U), Coordinates(4, 1), Coordinates(12, -1), BiStrides(2, 1), 1);
260 // 3D
261 add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4), BiStrides(2, 1, 2));
262 add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4), BiStrides(2, 1, 2), 0, 1);
263 // 4D
264 add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5), BiStrides(2, 1, 2, 3));
Georgios Pinitasb4af2c62018-12-10 18:45:35 +0000265
266 // Shrink axis
267 add_config(TensorShape(1U, 3U, 2U, 3U), Coordinates(0, 1, 0, 0), Coordinates(1, 1, 1, 1), BiStrides(1, 1, 1, 1), 0, 15, 6);
268 add_config(TensorShape(3U, 2U), Coordinates(0, 0), Coordinates(3U, 1U), BiStrides(1, 1), 0, 0, 2);
269 add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 0, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 6, 1);
270 add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 1, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 5, 3);
Georgios Pinitas77589b52018-08-21 14:41:35 +0100271 }
272};
273
274class LargeStridedSliceDataset final : public StridedSliceDataset
275{
276public:
277 LargeStridedSliceDataset()
278 {
279 // 1D
280 add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100), BiStrides(20));
281 // 2D
Georgios Pinitasc1a72452018-08-24 11:25:32 +0100282 add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -30), BiStrides(10, 7));
Georgios Pinitas77589b52018-08-21 14:41:35 +0100283 // 3D
Georgios Pinitasc1a72452018-08-24 11:25:32 +0100284 add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, -1), Coordinates(368, -30, -5), BiStrides(14, 7, -2));
Georgios Pinitas77589b52018-08-21 14:41:35 +0100285 // 4D
Georgios Pinitasc1a72452018-08-24 11:25:32 +0100286 add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, -30, 5), BiStrides(20, 7, 2), 1, 1);
Georgios Pinitas77589b52018-08-21 14:41:35 +0100287 }
288};
Georgios Pinitas77589b52018-08-21 14:41:35 +0100289} // namespace datasets
290} // namespace test
291} // namespace arm_compute
292#endif /* ARM_COMPUTE_TEST_STRIDED_SLICE_DATASET */