blob: 203925329c5e723c22246395083af28c14d7f9e1 [file] [log] [blame]
Georgios Pinitase1a352c2018-09-03 12:42:19 +01001/*
Matthew Bentham945b8da2023-07-12 11:54:59 +00002 * Copyright (c) 2018-2021, 2023 Arm Limited.
Georgios Pinitase1a352c2018-09-03 12:42:19 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_SPLIT_FIXTURE
25#define ARM_COMPUTE_TEST_SPLIT_FIXTURE
26
27#include "arm_compute/core/TensorShape.h"
28#include "arm_compute/core/Types.h"
29
30#include "tests/AssetsLibrary.h"
31#include "tests/Globals.h"
32#include "tests/IAccessor.h"
Georgios Pinitase1a352c2018-09-03 12:42:19 +010033#include "tests/framework/Asserts.h"
34#include "tests/framework/Fixture.h"
35#include "tests/validation/Helpers.h"
36#include "tests/validation/reference/SliceOperations.h"
37
38#include <algorithm>
39
40namespace arm_compute
41{
42namespace test
43{
44namespace validation
45{
46template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T>
47class SplitFixture : public framework::Fixture
48{
49public:
Georgios Pinitase1a352c2018-09-03 12:42:19 +010050 void setup(TensorShape shape, unsigned int axis, unsigned int splits, DataType data_type)
51 {
52 _target = compute_target(shape, axis, splits, data_type);
53 _reference = compute_reference(shape, axis, splits, data_type);
54 }
55
56protected:
57 template <typename U>
58 void fill(U &&tensor, int i)
59 {
60 library->fill_tensor_uniform(tensor, i);
61 }
62
63 std::vector<TensorType> compute_target(const TensorShape &shape, unsigned int axis, unsigned int splits, DataType data_type)
64 {
65 // Create tensors
66 TensorType src = create_tensor<TensorType>(shape, data_type);
67 std::vector<TensorType> dsts(splits);
68 std::vector<ITensorType *> dsts_ptr;
69 for(auto &dst : dsts)
70 {
71 dsts_ptr.emplace_back(&dst);
72 }
73
74 // Create and configure function
75 FunctionType split;
76 split.configure(&src, dsts_ptr, axis);
77
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +010078 ARM_COMPUTE_ASSERT(src.info()->is_resizable());
Georgios Pinitase1a352c2018-09-03 12:42:19 +010079 ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t)
80 {
81 return t.info()->is_resizable();
82 }),
83 framework::LogLevel::ERRORS);
84
85 // Allocate tensors
86 src.allocator()->allocate();
87 for(unsigned int i = 0; i < splits; ++i)
88 {
89 dsts[i].allocator()->allocate();
90 }
91
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +010092 ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
Georgios Pinitase1a352c2018-09-03 12:42:19 +010093 ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t)
94 {
95 return !t.info()->is_resizable();
96 }),
97 framework::LogLevel::ERRORS);
98
99 // Fill tensors
100 fill(AccessorType(src), 0);
101
102 // Compute function
103 split.run();
104
105 return dsts;
106 }
107
108 std::vector<SimpleTensor<T>> compute_reference(const TensorShape &shape, unsigned int axis, unsigned int splits, DataType data_type)
109 {
110 // Create reference
111 SimpleTensor<T> src{ shape, data_type };
112 std::vector<SimpleTensor<T>> dsts;
113
114 // Fill reference
115 fill(src, 0);
116
117 // Calculate splice for each split
118 const size_t axis_split_step = shape[axis] / splits;
119 unsigned int axis_offset = 0;
120
121 // Start/End coordinates
122 Coordinates start_coords;
123 Coordinates end_coords;
124 for(unsigned int d = 0; d < shape.num_dimensions(); ++d)
125 {
126 end_coords.set(d, -1);
127 }
128
129 for(unsigned int i = 0; i < splits; ++i)
130 {
131 // Update coordinate on axis
132 start_coords.set(axis, axis_offset);
133 end_coords.set(axis, axis_offset + axis_split_step);
134
135 dsts.emplace_back(std::move(reference::slice(src, start_coords, end_coords)));
136
137 axis_offset += axis_split_step;
138 }
139
140 return dsts;
141 }
142
143 std::vector<TensorType> _target{};
144 std::vector<SimpleTensor<T>> _reference{};
145};
Kurtis Charnockec00da12019-11-29 11:42:30 +0000146
147template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T>
148class SplitShapesFixture : public framework::Fixture
149{
150public:
Kurtis Charnockec00da12019-11-29 11:42:30 +0000151 void setup(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes, DataType data_type)
152 {
153 _target = compute_target(shape, axis, split_shapes, data_type);
154 _reference = compute_reference(shape, axis, split_shapes, data_type);
155 }
156
157protected:
158 template <typename U>
159 void fill(U &&tensor, int i)
160 {
161 library->fill_tensor_uniform(tensor, i);
162 }
163
164 std::vector<TensorType> compute_target(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes, DataType data_type)
165 {
166 // Create tensors
167 TensorType src = create_tensor<TensorType>(shape, data_type);
168 std::vector<TensorType> dsts{};
169 std::vector<ITensorType *> dsts_ptr;
170
171 for(const auto &split_shape : split_shapes)
172 {
173 TensorType dst = create_tensor<TensorType>(split_shape, data_type);
174 dsts.push_back(std::move(dst));
175 }
176
177 for(auto &dst : dsts)
178 {
179 dsts_ptr.emplace_back(&dst);
180 }
181
182 // Create and configure function
183 FunctionType split;
184 split.configure(&src, dsts_ptr, axis);
185
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100186 ARM_COMPUTE_ASSERT(src.info()->is_resizable());
Kurtis Charnockec00da12019-11-29 11:42:30 +0000187 ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t)
188 {
189 return t.info()->is_resizable();
190 }),
191 framework::LogLevel::ERRORS);
192
193 // Allocate tensors
194 src.allocator()->allocate();
195 for(unsigned int i = 0; i < dsts.size(); ++i)
196 {
197 dsts[i].allocator()->allocate();
198 }
199
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100200 ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
Kurtis Charnockec00da12019-11-29 11:42:30 +0000201 ARM_COMPUTE_EXPECT(std::all_of(dsts.cbegin(), dsts.cend(), [](const TensorType & t)
202 {
203 return !t.info()->is_resizable();
204 }),
205 framework::LogLevel::ERRORS);
206
207 // Fill tensors
208 fill(AccessorType(src), 0);
209
210 // Compute function
211 split.run();
212
213 return dsts;
214 }
215
216 std::vector<SimpleTensor<T>> compute_reference(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes, DataType data_type)
217 {
218 // Create reference
219 SimpleTensor<T> src{ shape, data_type };
220 std::vector<SimpleTensor<T>> dsts;
221
222 // Fill reference
223 fill(src, 0);
224
225 unsigned int axis_offset{ 0 };
226 for(const auto &split_shape : split_shapes)
227 {
228 // Calculate splice for each split
229 const size_t axis_split_step = split_shape[axis];
230
231 // Start/End coordinates
232 Coordinates start_coords;
233 Coordinates end_coords;
234 for(unsigned int d = 0; d < shape.num_dimensions(); ++d)
235 {
236 end_coords.set(d, -1);
237 }
238
239 // Update coordinate on axis
240 start_coords.set(axis, axis_offset);
241 end_coords.set(axis, axis_offset + axis_split_step);
242
243 dsts.emplace_back(std::move(reference::slice(src, start_coords, end_coords)));
244
245 axis_offset += axis_split_step;
246 }
247
248 return dsts;
249 }
250
251 std::vector<TensorType> _target{};
252 std::vector<SimpleTensor<T>> _reference{};
253};
Georgios Pinitase1a352c2018-09-03 12:42:19 +0100254} // namespace validation
255} // namespace test
256} // namespace arm_compute
257#endif /* ARM_COMPUTE_TEST_SPLIT_FIXTURE */