blob: 0785af1151c7e67393c12c1cb9bbbfa7715bdc25 [file] [log] [blame]
Georgios Pinitascbf39c62018-09-10 15:07:45 +01001/*
Matthew Bentham945b8da2023-07-12 11:54:59 +00002 * Copyright (c) 2017-2021, 2023 Arm Limited.
Georgios Pinitascbf39c62018-09-10 15:07:45 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE
25#define ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE
26
27#include "arm_compute/core/TensorShape.h"
28#include "arm_compute/core/Types.h"
29#include "tests/AssetsLibrary.h"
30#include "tests/Globals.h"
31#include "tests/IAccessor.h"
32#include "tests/framework/Asserts.h"
33#include "tests/framework/Fixture.h"
34#include "tests/validation/Helpers.h"
Giorgio Arena8b2a7d32020-02-11 17:21:31 +000035#include "tests/validation/reference/ActivationLayer.h"
Georgios Pinitascbf39c62018-09-10 15:07:45 +010036#include "tests/validation/reference/ArithmeticOperations.h"
37
38namespace arm_compute
39{
40namespace test
41{
42namespace validation
43{
44template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
45class ArithmeticOperationGenericFixture : public framework::Fixture
46{
47public:
Georgios Pinitasda816752021-07-02 09:22:14 +010048 void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +010049 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +010050 {
Sheri Zhanga387e272021-06-29 17:34:06 +010051 _op = op;
52 _act_info = act_info;
53 _is_inplace = is_inplace;
54 _target = compute_target(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
55 _reference = compute_reference(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
Georgios Pinitascbf39c62018-09-10 15:07:45 +010056 }
57
58protected:
59 template <typename U>
60 void fill(U &&tensor, int i)
61 {
62 library->fill_tensor_uniform(tensor, i);
63 }
64
Georgios Pinitasda816752021-07-02 09:22:14 +010065 TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
Georgios Pinitascbf39c62018-09-10 15:07:45 +010066 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
67 {
68 // Create tensors
Sheri Zhanga387e272021-06-29 17:34:06 +010069 const TensorShape out_shape = TensorShape::broadcast_shape(shape0, shape1);
70 TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type, 1, qinfo0);
71 TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type, 1, qinfo1);
72 TensorType dst = create_tensor<TensorType>(out_shape, data_type, 1, qinfo_out);
73
74 // Check whether do in-place computation and whether inputs are broadcast compatible
75 TensorType *actual_dst = &dst;
76 if(_is_inplace)
77 {
78 bool src1_is_inplace = !arm_compute::detail::have_different_dimensions(out_shape, shape0, 0) && (qinfo0 == qinfo_out);
79 bool src2_is_inplace = !arm_compute::detail::have_different_dimensions(out_shape, shape1, 0) && (qinfo1 == qinfo_out);
80 bool do_in_place = out_shape.total_size() != 0 && (src1_is_inplace || src2_is_inplace);
81 ARM_COMPUTE_ASSERT(do_in_place);
82
83 if(src1_is_inplace)
84 {
85 actual_dst = &ref_src1;
86 }
87 else
88 {
89 actual_dst = &ref_src2;
90 }
91 }
Georgios Pinitascbf39c62018-09-10 15:07:45 +010092
93 // Create and configure function
94 FunctionType arith_op;
Sheri Zhanga387e272021-06-29 17:34:06 +010095 arith_op.configure(&ref_src1, &ref_src2, actual_dst, convert_policy, _act_info);
Georgios Pinitascbf39c62018-09-10 15:07:45 +010096
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +010097 ARM_COMPUTE_ASSERT(ref_src1.info()->is_resizable());
98 ARM_COMPUTE_ASSERT(ref_src2.info()->is_resizable());
Georgios Pinitascbf39c62018-09-10 15:07:45 +010099
100 // Allocate tensors
101 ref_src1.allocator()->allocate();
102 ref_src2.allocator()->allocate();
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100103
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100104 ARM_COMPUTE_ASSERT(!ref_src1.info()->is_resizable());
105 ARM_COMPUTE_ASSERT(!ref_src2.info()->is_resizable());
Sheri Zhanga387e272021-06-29 17:34:06 +0100106
107 // If don't do in-place computation, still need to allocate original dst
108 if(!_is_inplace)
109 {
110 ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
111 dst.allocator()->allocate();
112 ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
113 }
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100114
115 // Fill tensors
116 fill(AccessorType(ref_src1), 0);
117 fill(AccessorType(ref_src2), 1);
118
119 // Compute function
120 arith_op.run();
121
Sheri Zhanga387e272021-06-29 17:34:06 +0100122 return std::move(*actual_dst);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100123 }
124
Georgios Pinitasda816752021-07-02 09:22:14 +0100125 SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100126 QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
127 {
128 // Create reference
Georgios Pinitasda816752021-07-02 09:22:14 +0100129 SimpleTensor<T> ref_src1{ shape0, data_type, 1, qinfo0 };
130 SimpleTensor<T> ref_src2{ shape1, data_type, 1, qinfo1 };
Sheri Zhanga387e272021-06-29 17:34:06 +0100131 SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), data_type, 1, qinfo_out };
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100132
133 // Fill reference
134 fill(ref_src1, 0);
135 fill(ref_src2, 1);
136
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000137 auto result = reference::arithmetic_operation<T>(_op, ref_src1, ref_src2, ref_dst, convert_policy);
Sheri Zhanga387e272021-06-29 17:34:06 +0100138 return _act_info.enabled() ? reference::activation_layer(result, _act_info, qinfo_out) : result;
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100139 }
140
141 TensorType _target{};
142 SimpleTensor<T> _reference{};
143 reference::ArithmeticOperation _op{ reference::ArithmeticOperation::ADD };
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000144 ActivationLayerInfo _act_info{};
Sheri Zhanga387e272021-06-29 17:34:06 +0100145 bool _is_inplace{};
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100146};
147
148template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
149class ArithmeticAdditionBroadcastValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
150{
151public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100152 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100153 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100154 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100155 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100156 }
157};
158
159template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
160class ArithmeticAdditionValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
161{
162public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100163 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100164 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100165 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100166 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000167 }
168};
169
170template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
171class ArithmeticAdditionBroadcastValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
172{
173public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100174 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool is_inplace)
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000175 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100176 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100177 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000178 }
179};
180
181template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
182class ArithmeticAdditionValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
183{
184public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100185 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool is_inplace)
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000186 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100187 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100188 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100189 }
190};
191
192template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
193class ArithmeticAdditionValidationQuantizedFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
194{
195public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100196 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100197
198 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100199 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100200 qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100201 }
202};
203
204template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Michalis Spyrou2232a202020-07-13 15:15:33 +0100205class ArithmeticAdditionValidationQuantizedBroadcastFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
206{
207public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100208 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out,
209 bool is_inplace)
Michalis Spyrou2232a202020-07-13 15:15:33 +0100210 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100211 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100212 qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
Michalis Spyrou2232a202020-07-13 15:15:33 +0100213 }
214};
215
216template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100217class ArithmeticSubtractionBroadcastValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
218{
219public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100220 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100221 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100222 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100223 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000224 }
225};
226
227template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
228class ArithmeticSubtractionBroadcastValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
229{
230public:
Georgios Pinitasda816752021-07-02 09:22:14 +0100231 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info,
Sheri Zhanga387e272021-06-29 17:34:06 +0100232 bool is_inplace)
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000233 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100234 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100235 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100236 }
237};
238
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100239template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
240class ArithmeticSubtractionValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
241{
242public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100243 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100244 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100245 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100246 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000247 }
248};
249
250template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
251class ArithmeticSubtractionValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
252{
253public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100254 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool is_inplace)
Giorgio Arena8b2a7d32020-02-11 17:21:31 +0000255 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100256 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100257 QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100258 }
259};
260
261template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
262class ArithmeticSubtractionValidationQuantizedFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
263{
264public:
Sheri Zhanga387e272021-06-29 17:34:06 +0100265 void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool is_inplace)
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100266
267 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100268 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100269 qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100270 }
271};
Michalis Spyroueae65842020-06-15 20:23:59 +0100272
273template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
274class ArithmeticSubtractionValidationQuantizedBroadcastFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
275{
276public:
Georgios Pinitasda816752021-07-02 09:22:14 +0100277 void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out,
Sheri Zhanga387e272021-06-29 17:34:06 +0100278 bool is_inplace)
Michalis Spyroueae65842020-06-15 20:23:59 +0100279 {
Georgios Pinitasda816752021-07-02 09:22:14 +0100280 ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
Sheri Zhanga387e272021-06-29 17:34:06 +0100281 qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
Michalis Spyroueae65842020-06-15 20:23:59 +0100282 }
283};
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100284} // namespace validation
285} // namespace test
286} // namespace arm_compute
287#endif /* ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE */