blob: 44c096c52125fc271e194093a5a13afc7c587c2b [file] [log] [blame]
giuros01164a2722018-11-20 18:34:46 +00001/*
Giorgio Arena8b2a7d32020-02-11 17:21:31 +00002 * Copyright (c) 2018-2020 ARM Limited.
giuros01164a2722018-11-20 18:34:46 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_ELEMENTWISE_OPERATIONS_FIXTURE
25#define ARM_COMPUTE_TEST_ELEMENTWISE_OPERATIONS_FIXTURE
26
27#include "arm_compute/core/TensorShape.h"
28#include "arm_compute/core/Types.h"
29#include "tests/AssetsLibrary.h"
30#include "tests/Globals.h"
31#include "tests/IAccessor.h"
32#include "tests/framework/Asserts.h"
33#include "tests/framework/Fixture.h"
34#include "tests/validation/Helpers.h"
Giorgio Arena8b2a7d32020-02-11 17:21:31 +000035#include "tests/validation/reference/ActivationLayer.h"
giuros01164a2722018-11-20 18:34:46 +000036#include "tests/validation/reference/ElementwiseOperations.h"
37
38namespace arm_compute
39{
40namespace test
41{
42namespace validation
43{
44template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
45class ArithmeticOperationsGenericFixture : public framework::Fixture
46{
47public:
48 template <typename...>
49 void setup(ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1,
50 DataType data_type0, DataType data_type1, DataType output_data_type,
51 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
52 {
53 _op = op;
54 _target = compute_target(shape0, shape1, data_type0, data_type1, output_data_type, qinfo0, qinfo1, qinfo_out);
55 _reference = compute_reference(shape0, shape1, data_type0, data_type1, output_data_type, qinfo0, qinfo1, qinfo_out);
56 }
57
58protected:
59 template <typename U>
60 void fill(U &&tensor, int i)
61 {
Usama Arif81e671e2019-05-13 13:33:14 +010062 switch(_op)
63 {
64 case ArithmeticOperation::DIV:
65 library->fill_tensor_uniform_ranged(tensor, i, { std::pair<float, float>(-0.001f, 0.001f) });
66 break;
67 case ArithmeticOperation::POWER:
68 library->fill_tensor_uniform(tensor, i, 0.0f, 5.0f);
69 break;
70 default:
71 library->fill_tensor_uniform(tensor, i);
72 }
giuros01164a2722018-11-20 18:34:46 +000073 }
74
75 TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
76 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
77 {
78 // Create tensors
79 TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type0, 1, qinfo0);
80 TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type1, 1, qinfo1);
81 TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out);
82
83 // Create and configure function
84 FunctionType elem_op;
85 elem_op.configure(&ref_src1, &ref_src2, &dst);
86
87 ARM_COMPUTE_EXPECT(ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
88 ARM_COMPUTE_EXPECT(ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
89 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
90
91 // Allocate tensors
92 ref_src1.allocator()->allocate();
93 ref_src2.allocator()->allocate();
94 dst.allocator()->allocate();
95
96 ARM_COMPUTE_EXPECT(!ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
97 ARM_COMPUTE_EXPECT(!ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
98 ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
99
100 // Fill tensors
101 fill(AccessorType(ref_src1), 0);
102 fill(AccessorType(ref_src2), 1);
103
104 // Compute function
105 elem_op.run();
106
107 return dst;
108 }
109
110 SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1,
111 DataType data_type0, DataType data_type1, DataType output_data_type,
112 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
113 {
114 // Create reference
115 SimpleTensor<T> ref_src1{ shape0, data_type0, 1, qinfo0 };
116 SimpleTensor<T> ref_src2{ shape1, data_type1, 1, qinfo1 };
117 SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out };
118
119 // Fill reference
120 fill(ref_src1, 0);
121 fill(ref_src2, 1);
122
123 return reference::arithmetic_operation<T>(_op, ref_src1, ref_src2, ref_dst);
124 }
125
126 TensorType _target{};
127 SimpleTensor<T> _reference{};
128 ArithmeticOperation _op{ ArithmeticOperation::ADD };
129};
130
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000131// Arithmetic operation fused with activation function
132template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
133class ArithmeticOperationsFuseActivationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
134{
135public:
136 template <typename...>
137 void setup(ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1,
138 DataType data_type0, DataType data_type1, DataType output_data_type,
139 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info)
140 {
141 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(op, shape0, shape1,
142 data_type0, data_type1, output_data_type,
143 qinfo0, qinfo1, qinfo_out);
144 _act_info = act_info;
145 }
146
147protected:
148 TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
149 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
150 {
151 // Create tensors
152 TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type0, 1, qinfo0);
153 TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type1, 1, qinfo1);
154 TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out);
155
156 // Create and configure function
157 FunctionType elem_op;
158 elem_op.configure(&ref_src1, &ref_src2, &dst, _act_info);
159
160 ARM_COMPUTE_EXPECT(ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
161 ARM_COMPUTE_EXPECT(ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
162 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
163
164 // Allocate tensors
165 ref_src1.allocator()->allocate();
166 ref_src2.allocator()->allocate();
167 dst.allocator()->allocate();
168
169 ARM_COMPUTE_EXPECT(!ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
170 ARM_COMPUTE_EXPECT(!ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
171 ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
172
173 // Fill tensors
174 fill(AccessorType(ref_src1), 0);
175 fill(AccessorType(ref_src2), 1);
176
177 // Compute function
178 elem_op.run();
179
180 return dst;
181 }
182
183 SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1,
184 DataType data_type0, DataType data_type1, DataType output_data_type,
185 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
186 {
187 auto result = ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::compute_reference(shape0, shape1, data_type0,
188 data_type1, output_data_type, qinfo0, qinfo1, qinfo_out);
189 return _act_info.enabled() ? reference::activation_layer(result, _act_info, qinfo_out) : result;
190 }
191
192 ActivationLayerInfo _act_info{};
193};
194
giuros01164a2722018-11-20 18:34:46 +0000195template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
196class ArithmeticDivisionBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
197{
198public:
199 template <typename...>
200 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type)
201 {
202 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::DIV, shape0, shape1,
203 data_type0, data_type1, output_data_type,
204 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
205 }
206};
207
208template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
209class ArithmeticDivisionValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
210{
211public:
212 template <typename...>
213 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type)
214 {
215 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::DIV, shape, shape,
216 data_type0, data_type1, output_data_type,
217 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
218 }
219};
220
221template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000222class ArithmeticDivisionBroadcastValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
223{
224public:
225 template <typename...>
226 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
227 {
228 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::DIV, shape0, shape1,
229 data_type0, data_type1, output_data_type,
230 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
231 }
232};
233
234template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
235class ArithmeticDivisionValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
236{
237public:
238 template <typename...>
239 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
240 {
241 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::DIV, shape, shape,
242 data_type0, data_type1, output_data_type,
243 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
244 }
245};
246
247template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
giuros01164a2722018-11-20 18:34:46 +0000248class ArithmeticDivisionValidationQuantizedFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
249{
250public:
251 template <typename...>
252 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type,
253 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
254
255 {
256 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::DIV, shape, shape,
257 data_type0, data_type1, output_data_type,
258 qinfo0, qinfo1, qinfo_out);
259 }
260};
261
262template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
263class ElementwiseMaxBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
264{
265public:
266 template <typename...>
267 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type)
268 {
269 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MAX, shape0, shape1,
270 data_type0, data_type1, output_data_type,
271 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
272 }
273};
274
275template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
276class ElementwiseMaxValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
277{
278public:
279 template <typename...>
280 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type)
281 {
282 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MAX, shape, shape,
283 data_type0, data_type1, output_data_type,
284 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
285 }
286};
287
288template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000289class ElementwiseMaxBroadcastValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
290{
291public:
292 template <typename...>
293 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
294 {
295 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MAX, shape0, shape1,
296 data_type0, data_type1, output_data_type,
297 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
298 }
299};
300
301template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
302class ElementwiseMaxValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
303{
304public:
305 template <typename...>
306 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
307 {
308 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MAX, shape, shape,
309 data_type0, data_type1, output_data_type,
310 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
311 }
312};
313
314template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
giuros01164a2722018-11-20 18:34:46 +0000315class ElementwiseMaxValidationQuantizedFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
316{
317public:
318 template <typename...>
319 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type,
320 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
321
322 {
323 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MAX, shape, shape,
324 data_type0, data_type1, output_data_type,
325 qinfo0, qinfo1, qinfo_out);
326 }
327};
328
329template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
giuros0192fd9432018-12-03 17:30:00 +0000330class ElementwiseMaxQuantizedBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
331{
332public:
333 template <typename...>
334 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
335 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
336
337 {
338 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MAX, shape0, shape1,
339 data_type0, data_type1, output_data_type,
340 qinfo0, qinfo1, qinfo_out);
341 }
342};
343
344template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
giuros01164a2722018-11-20 18:34:46 +0000345class ElementwiseMinBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
346{
347public:
348 template <typename...>
349 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type)
350 {
351 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MIN, shape0, shape1,
352 data_type0, data_type1, output_data_type,
353 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
354 }
355};
356
357template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
358class ElementwiseMinValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
359{
360public:
361 template <typename...>
362 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type)
363 {
364 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MIN, shape, shape,
365 data_type0, data_type1, output_data_type,
366 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
367 }
368};
369
370template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000371class ElementwiseMinBroadcastValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
372{
373public:
374 template <typename...>
375 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
376 {
377 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MIN, shape0, shape1,
378 data_type0, data_type1, output_data_type,
379 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
380 }
381};
382
383template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
384class ElementwiseMinValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
385{
386public:
387 template <typename...>
388 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
389 {
390 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MIN, shape, shape,
391 data_type0, data_type1, output_data_type,
392 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
393 }
394};
395
396template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
giuros01164a2722018-11-20 18:34:46 +0000397class ElementwiseMinValidationQuantizedFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
398{
399public:
400 template <typename...>
401 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type,
402 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
403
404 {
405 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MIN, shape, shape,
406 data_type0, data_type1, output_data_type,
407 qinfo0, qinfo1, qinfo_out);
408 }
409};
410
411template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
giuros0192fd9432018-12-03 17:30:00 +0000412class ElementwiseMinQuantizedBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
413{
414public:
415 template <typename...>
416 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
417 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
418
419 {
420 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::MIN, shape0, shape1,
421 data_type0, data_type1, output_data_type,
422 qinfo0, qinfo1, qinfo_out);
423 }
424};
425
426template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
giuros01164a2722018-11-20 18:34:46 +0000427class ElementwiseSquaredDiffBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
428{
429public:
430 template <typename...>
431 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type)
432 {
433 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::SQUARED_DIFF, shape0, shape1,
434 data_type0, data_type1, output_data_type,
435 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
436 }
437};
438
439template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
440class ElementwiseSquaredDiffValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
441{
442public:
443 template <typename...>
444 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type)
445 {
446 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::SQUARED_DIFF, shape, shape,
447 data_type0, data_type1, output_data_type,
448 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
449 }
450};
451
452template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000453class ElementwiseSquaredDiffBroadcastValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
454{
455public:
456 template <typename...>
457 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
458 {
459 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::SQUARED_DIFF, shape0, shape1,
460 data_type0, data_type1, output_data_type,
461 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
462 }
463};
464
465template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
466class ElementwiseSquaredDiffValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
467{
468public:
469 template <typename...>
470 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
471 {
472 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::SQUARED_DIFF, shape, shape,
473 data_type0, data_type1, output_data_type,
474 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
475 }
476};
477
478template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
giuros01164a2722018-11-20 18:34:46 +0000479class ElementwiseSquaredDiffValidationQuantizedFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
480{
481public:
482 template <typename...>
483 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type,
484 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
485
486 {
487 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::SQUARED_DIFF, shape, shape,
488 data_type0, data_type1, output_data_type,
489 qinfo0, qinfo1, qinfo_out);
490 }
491};
giuros0192fd9432018-12-03 17:30:00 +0000492
493template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
494class ElementwiseSquaredDiffQuantizedBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
495{
496public:
497 template <typename...>
498 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
499 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
500
501 {
502 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::SQUARED_DIFF, shape0, shape1,
503 data_type0, data_type1, output_data_type,
504 qinfo0, qinfo1, qinfo_out);
505 }
506};
George Worta1e7e282019-01-15 11:00:29 +0000507
508template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
giuros011e6e1b82019-05-14 16:12:53 +0100509class PReluLayerBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
510{
511public:
512 template <typename...>
513 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type)
514 {
515 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::PRELU, shape0, shape1,
516 data_type0, data_type1, output_data_type,
517 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
518 }
519};
520
521template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
522class PReluLayerValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
523{
524public:
525 template <typename...>
526 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type)
527 {
528 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::PRELU, shape, shape,
529 data_type0, data_type1, output_data_type,
530 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
531 }
532};
533
534template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
535class PReluLayerValidationQuantizedFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
536{
537public:
538 template <typename...>
539 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type,
540 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
541
542 {
543 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::PRELU, shape, shape,
544 data_type0, data_type1, output_data_type,
545 qinfo0, qinfo1, qinfo_out);
546 }
547};
548
549template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
550class PReluLayerQuantizedBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
551{
552public:
553 template <typename...>
554 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
555 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
556
557 {
558 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::PRELU, shape0, shape1,
559 data_type0, data_type1, output_data_type,
560 qinfo0, qinfo1, qinfo_out);
561 }
562};
563
564template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Usama Arif81e671e2019-05-13 13:33:14 +0100565class ElementwisePowerBroadcastValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
566{
567public:
568 template <typename...>
569 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type)
570 {
571 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::POWER, shape0, shape1,
572 data_type0, data_type1, output_data_type,
573 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
574 }
575};
576
577template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
578class ElementwisePowerValidationFixture : public ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>
579{
580public:
581 template <typename...>
582 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type)
583 {
584 ArithmeticOperationsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::POWER, shape, shape,
585 data_type0, data_type1, output_data_type,
586 QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
587 }
588};
589
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000590template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
591class ElementwisePowerBroadcastValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
592{
593public:
594 template <typename...>
595 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
596 {
597 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::POWER, shape0, shape1,
598 data_type0, data_type1, output_data_type,
599 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
600 }
601};
602
603template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
604class ElementwisePowerValidationFloatFixture : public ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>
605{
606public:
607 template <typename...>
608 void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ActivationLayerInfo act_info)
609 {
610 ArithmeticOperationsFuseActivationFixture<TensorType, AccessorType, FunctionType, T>::setup(ArithmeticOperation::POWER, shape, shape,
611 data_type0, data_type1, output_data_type,
612 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
613 }
614};
615
giuros01164a2722018-11-20 18:34:46 +0000616} // namespace validation
617} // namespace test
618} // namespace arm_compute
#endif /* ARM_COMPUTE_TEST_ELEMENTWISE_OPERATIONS_FIXTURE */