blob: 0f7e44e58852ffc1213f90b0f905d0ae11805c50 [file] [log] [blame]
/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE
25#define ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE
26
27#include "arm_compute/core/TensorShape.h"
28#include "arm_compute/core/Types.h"
29#include "tests/AssetsLibrary.h"
30#include "tests/Globals.h"
31#include "tests/IAccessor.h"
32#include "tests/framework/Asserts.h"
33#include "tests/framework/Fixture.h"
34#include "tests/validation/Helpers.h"
Giorgio Arena8b2a7d32020-02-11 17:21:31 +000035#include "tests/validation/reference/ActivationLayer.h"
Georgios Pinitascbf39c62018-09-10 15:07:45 +010036#include "tests/validation/reference/ArithmeticOperations.h"
37
38namespace arm_compute
39{
40namespace test
41{
42namespace validation
43{
44template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
45class ArithmeticOperationGenericFixture : public framework::Fixture
46{
47public:
48 template <typename...>
Georgios Pinitasda816752021-07-02 09:22:14 +010049 void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +010050 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +010051 {
Sheri Zhanga387e272021-06-29 17:34:06 +010052 _op = op;
53 _act_info = act_info;
54 _is_inplace = is_inplace;
55 _target = compute_target(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
56 _reference = compute_reference(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
Georgios Pinitascbf39c62018-09-10 15:07:45 +010057 }
58
59protected:
60 template <typename U>
61 void fill(U &&tensor, int i)
62 {
63 library->fill_tensor_uniform(tensor, i);
64 }
65
Georgios Pinitasda816752021-07-02 09:22:14 +010066 TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
Georgios Pinitascbf39c62018-09-10 15:07:45 +010067 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
68 {
69 // Create tensors
Sheri Zhanga387e272021-06-29 17:34:06 +010070 const TensorShape out_shape = TensorShape::broadcast_shape(shape0, shape1);
71 TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type, 1, qinfo0);
72 TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type, 1, qinfo1);
73 TensorType dst = create_tensor<TensorType>(out_shape, data_type, 1, qinfo_out);
74
75 // Check whether do in-place computation and whether inputs are broadcast compatible
76 TensorType *actual_dst = &dst;
77 if(_is_inplace)
78 {
79 bool src1_is_inplace = !arm_compute::detail::have_different_dimensions(out_shape, shape0, 0) && (qinfo0 == qinfo_out);
80 bool src2_is_inplace = !arm_compute::detail::have_different_dimensions(out_shape, shape1, 0) && (qinfo1 == qinfo_out);
81 bool do_in_place = out_shape.total_size() != 0 && (src1_is_inplace || src2_is_inplace);
82 ARM_COMPUTE_ASSERT(do_in_place);
83
84 if(src1_is_inplace)
85 {
86 actual_dst = &ref_src1;
87 }
88 else
89 {
90 actual_dst = &ref_src2;
91 }
92 }
Georgios Pinitascbf39c62018-09-10 15:07:45 +010093
94 // Create and configure function
95 FunctionType arith_op;
Sheri Zhanga387e272021-06-29 17:34:06 +010096 arith_op.configure(&ref_src1, &ref_src2, actual_dst, convert_policy, _act_info);
Georgios Pinitascbf39c62018-09-10 15:07:45 +010097
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +010098 ARM_COMPUTE_ASSERT(ref_src1.info()->is_resizable());
99 ARM_COMPUTE_ASSERT(ref_src2.info()->is_resizable());
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100100
101 // Allocate tensors
102 ref_src1.allocator()->allocate();
103 ref_src2.allocator()->allocate();
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100104
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100105 ARM_COMPUTE_ASSERT(!ref_src1.info()->is_resizable());
106 ARM_COMPUTE_ASSERT(!ref_src2.info()->is_resizable());
Sheri Zhanga387e272021-06-29 17:34:06 +0100107
108 // If don't do in-place computation, still need to allocate original dst
109 if(!_is_inplace)
110 {
111 ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
112 dst.allocator()->allocate();
113 ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
114 }
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100115
116 // Fill tensors
117 fill(AccessorType(ref_src1), 0);
118 fill(AccessorType(ref_src2), 1);
119
120 // Compute function
121 arith_op.run();
122
Sheri Zhanga387e272021-06-29 17:34:06 +0100123 return std::move(*actual_dst);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100124 }
125
Georgios Pinitasda816752021-07-02 09:22:14 +0100126 SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100127 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
128 {
129 // Create reference
Georgios Pinitasda816752021-07-02 09:22:14 +0100130 SimpleTensor<T> ref_src1{ shape0, data_type, 1, qinfo0 };
131 SimpleTensor<T> ref_src2{ shape1, data_type, 1, qinfo1 };
Sheri Zhanga387e272021-06-29 17:34:06 +0100132 SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), data_type, 1, qinfo_out };
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100133
134 // Fill reference
135 fill(ref_src1, 0);
136 fill(ref_src2, 1);
137
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000138 auto result = reference::arithmetic_operation<T>(_op, ref_src1, ref_src2, ref_dst, convert_policy);
Sheri Zhanga387e272021-06-29 17:34:06 +0100139 return _act_info.enabled() ? reference::activation_layer(result, _act_info, qinfo_out) : result;
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100140 }
141
142 TensorType _target{};
143 SimpleTensor<T> _reference{};
144 reference::ArithmeticOperation _op{ reference::ArithmeticOperation::ADD };
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000145 ActivationLayerInfo _act_info{};
Sheri Zhanga387e272021-06-29 17:34:06 +0100146 bool _is_inplace{};
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100147};
148
149template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
150class ArithmeticAdditionBroadcastValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
151{
152public:
153 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100154 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100155 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100156 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100157 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100158 }
159};
160
161template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
162class ArithmeticAdditionValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
163{
164public:
165 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100166 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100167 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100168 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100169 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000170 }
171};
172
173template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
174class ArithmeticAdditionBroadcastValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
175{
176public:
177 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100178 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool is_inplace)
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000179 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100180 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100181 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000182 }
183};
184
185template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
186class ArithmeticAdditionValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
187{
188public:
189 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100190 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool is_inplace)
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000191 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100192 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100193 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100194 }
195};
196
197template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
198class ArithmeticAdditionValidationQuantizedFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
199{
200public:
201 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100202 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100203
204 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100205 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100206 qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100207 }
208};
209
210template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Michalis Spyrou2232a202020-07-13 15:15:33 +0100211class ArithmeticAdditionValidationQuantizedBroadcastFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
212{
213public:
214 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100215 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out,
216 bool is_inplace)
Michalis Spyrou2232a202020-07-13 15:15:33 +0100217 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100218 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100219 qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
Michalis Spyrou2232a202020-07-13 15:15:33 +0100220 }
221};
222
223template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100224class ArithmeticSubtractionBroadcastValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
225{
226public:
227 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100228 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100229 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100230 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100231 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000232 }
233};
234
235template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
236class ArithmeticSubtractionBroadcastValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
237{
238public:
239 template <typename...>
Georgios Pinitasda816752021-07-02 09:22:14 +0100240 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info,
Sheri Zhanga387e272021-06-29 17:34:06 +0100241 bool is_inplace)
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000242 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100243 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100244 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100245 }
246};
247
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100248template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
249class ArithmeticSubtractionValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
250{
251public:
252 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100253 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100254 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100255 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100256 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000257 }
258};
259
260template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
261class ArithmeticSubtractionValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
262{
263public:
264 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100265 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool is_inplace)
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000266 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100267 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100268 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100269 }
270};
271
272template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
273class ArithmeticSubtractionValidationQuantizedFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
274{
275public:
276 template <typename...>
Sheri Zhanga387e272021-06-29 17:34:06 +0100277 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100278
279 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100280 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100281 qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100282 }
283};
Michalis Spyroueae65842020-06-15 20:23:59 +0100284
285template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
286class ArithmeticSubtractionValidationQuantizedBroadcastFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
287{
288public:
289 template <typename...>
Georgios Pinitasda816752021-07-02 09:22:14 +0100290 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out,
Sheri Zhanga387e272021-06-29 17:34:06 +0100291 bool is_inplace)
Michalis Spyroueae65842020-06-15 20:23:59 +0100292 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100293 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100294 qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
Michalis Spyroueae65842020-06-15 20:23:59 +0100295 }
296};
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100297} // namespace validation
298} // namespace test
299} // namespace arm_compute
300#endif /* ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE */