blob: cf55c6a66b65c4b311e894876367389a72351df4 [file] [log] [blame]
giuros0192fd9432018-12-03 17:30:00 +00001/*
morgolock6427c822020-01-13 11:53:20 +00002 * Copyright (c) 2018-2020 ARM Limited.
giuros0192fd9432018-12-03 17:30:00 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
Georgios Pinitas2d6cb172018-12-24 15:00:43 +000021 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
giuros0192fd9432018-12-03 17:30:00 +000022 * SOFTWARE.
23 */
24#include "arm_compute/core/Types.h"
25#include "arm_compute/runtime/NEON/functions/NEElementwiseOperations.h"
26#include "arm_compute/runtime/Tensor.h"
27#include "arm_compute/runtime/TensorAllocator.h"
28#include "tests/NEON/Accessor.h"
29#include "tests/PaddingCalculator.h"
30#include "tests/datasets/ShapeDatasets.h"
31#include "tests/framework/Asserts.h"
32#include "tests/framework/Macros.h"
33#include "tests/framework/datasets/Datasets.h"
34#include "tests/validation/Validation.h"
35#include "tests/validation/fixtures/ElementwiseOperationsFixture.h"
36
37namespace arm_compute
38{
39namespace test
40{
41namespace validation
42{
43namespace
44{
45RelativeTolerance<float> tolerance_fp32(0.000001f);
Georgios Pinitas2d6cb172018-12-24 15:00:43 +000046#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
47RelativeTolerance<float> tolerance_fp16(0.01f);
48#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
49
giuros0192fd9432018-12-03 17:30:00 +000050/** Input data sets **/
51const auto ElementwiseSquaredDiffQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8), framework::dataset::make("DataType", DataType::QASYMM8)),
52 framework::dataset::make("DataType",
53 DataType::QASYMM8));
morgolock6427c822020-01-13 11:53:20 +000054
55const auto ElementwiseSquaredDiffQASYMM8SignedDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
56 framework::dataset::make("DataType",
57 DataType::QASYMM8_SIGNED));
58
giuros0192fd9432018-12-03 17:30:00 +000059/** Input data sets **/
60const auto ElementwiseSquaredDiffS32Dataset = combine(combine(framework::dataset::make("DataType", DataType::S32), framework::dataset::make("DataType", DataType::S32)),
61 framework::dataset::make("DataType",
62 DataType::S32));
63const auto ElementwiseSquaredDiffS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
64 framework::dataset::make("DataType", DataType::S16));
65#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
66const auto ElementwiseSquaredDiffFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
67 framework::dataset::make("DataType", DataType::F16));
68#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
69const auto ElementwiseSquaredDiffFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)),
70 framework::dataset::make("DataType", DataType::F32));
71} // namespace
72
73TEST_SUITE(NEON)
74TEST_SUITE(ElementwiseSquaredDiff)
75
76template <typename T>
77using NEElementwiseSquaredDiffFixture = ElementwiseSquaredDiffValidationFixture<Tensor, Accessor, NEElementwiseSquaredDiff, T>;
78
79// *INDENT-OFF*
80// clang-format off
81DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
82 framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
83 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
84 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S32),
85 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32), // Invalid data type combination
86 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
morgolock6427c822020-01-13 11:53:20 +000087 TensorInfo(TensorShape(1U, 1U, 2U), 1, DataType::QASYMM8_SIGNED), // Mismatching types
giuros0192fd9432018-12-03 17:30:00 +000088 }),
89 framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
90 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
91 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S32),
92 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
93 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
morgolock6427c822020-01-13 11:53:20 +000094 TensorInfo(TensorShape(1U, 1U, 2U), 1, DataType::QASYMM8_SIGNED),
giuros0192fd9432018-12-03 17:30:00 +000095 })),
96 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
97 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
98 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S32),
99 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S32),
100 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
morgolock6427c822020-01-13 11:53:20 +0000101 TensorInfo(TensorShape(1U, 1U, 2U), 1, DataType::QASYMM8, QuantizationInfo(0.3f,1)),
giuros0192fd9432018-12-03 17:30:00 +0000102 })),
morgolock6427c822020-01-13 11:53:20 +0000103 framework::dataset::make("Expected", { true, true, true, false, false, false})),
giuros0192fd9432018-12-03 17:30:00 +0000104 input1_info, input2_info, output_info, expected)
105{
106 ARM_COMPUTE_EXPECT(bool(NEElementwiseSquaredDiff::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS);
107}
108// clang-format on
109// *INDENT-ON*
110
111TEST_SUITE(S32)
Michalis Spyrou5c9f0c42019-01-16 14:48:48 +0000112DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, datasets::SmallShapes(),
giuros0192fd9432018-12-03 17:30:00 +0000113 shape)
114{
115 // Create tensors
116 Tensor ref_src1 = create_tensor<Tensor>(shape, DataType::S32);
117 Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::S32);
118 Tensor dst = create_tensor<Tensor>(shape, DataType::S32);
119
120 // Create and Configure function
121 NEElementwiseSquaredDiff add;
122 add.configure(&ref_src1, &ref_src2, &dst);
123
124 // Validate valid region
125 const ValidRegion valid_region = shape_to_valid_region(shape);
126 validate(dst.info()->valid_region(), valid_region);
127}
128
129FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwiseSquaredDiffFixture<int32_t>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallShapes(), ElementwiseSquaredDiffS32Dataset))
130{
131 // Validate output
132 validate(Accessor(_target), _reference);
133}
134TEST_SUITE_END() // S32
135
136TEST_SUITE(S16)
Michalis Spyrou5c9f0c42019-01-16 14:48:48 +0000137DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), framework::dataset::make("DataType", { DataType::S16 })),
giuros0192fd9432018-12-03 17:30:00 +0000138 shape, data_type)
139{
140 // Create tensors
141 Tensor ref_src1 = create_tensor<Tensor>(shape, data_type);
142 Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::S16);
143 Tensor dst = create_tensor<Tensor>(shape, DataType::S16);
144
145 // Create and Configure function
146 NEElementwiseSquaredDiff add;
147 add.configure(&ref_src1, &ref_src2, &dst);
148
149 // Validate valid region
150 const ValidRegion valid_region = shape_to_valid_region(shape);
151 validate(dst.info()->valid_region(), valid_region);
152}
153
Michalis Spyrou5ce99a22019-01-25 14:17:49 +0000154FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwiseSquaredDiffFixture<int16_t>, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), ElementwiseSquaredDiffS16Dataset))
giuros0192fd9432018-12-03 17:30:00 +0000155{
156 // Validate output
157 validate(Accessor(_target), _reference);
158}
159TEST_SUITE_END() // S16
160
161template <typename T>
162using NEElementwiseSquaredDiffQuantizedFixture = ElementwiseSquaredDiffValidationQuantizedFixture<Tensor, Accessor, NEElementwiseSquaredDiff, T>;
163
164TEST_SUITE(Quantized)
165TEST_SUITE(QASYMM8)
Michalis Spyrou5c9f0c42019-01-16 14:48:48 +0000166DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, datasets::SmallShapes(),
giuros0192fd9432018-12-03 17:30:00 +0000167 shape)
168{
169 // Create tensors
170 Tensor ref_src1 = create_tensor<Tensor>(shape, DataType::QASYMM8);
171 Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::QASYMM8);
172 Tensor dst = create_tensor<Tensor>(shape, DataType::QASYMM8);
173
174 // Create and Configure function
175 NEElementwiseMin add;
176 add.configure(&ref_src1, &ref_src2, &dst);
177
178 // Validate valid region
179 const ValidRegion valid_region = shape_to_valid_region(shape);
180 validate(dst.info()->valid_region(), valid_region);
181}
182
183FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwiseSquaredDiffQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
184 ElementwiseSquaredDiffQASYMM8Dataset),
185 framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })),
186 framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
187 framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255.f, 5) }))
188
189 )
190{
191 // Validate output
192 validate(Accessor(_target), _reference, tolerance_fp32, 0.01);
193}
194template <typename T>
195using NEElementwiseSquaredDiffQuantizedBroadcastFixture = ElementwiseSquaredDiffQuantizedBroadcastValidationFixture<Tensor, Accessor, NEElementwiseSquaredDiff, T>;
196
197FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEElementwiseSquaredDiffQuantizedBroadcastFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
198 combine(combine(combine(combine(datasets::SmallShapesBroadcast(),
199 ElementwiseSquaredDiffQASYMM8Dataset),
200 framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })),
201 framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
202 framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255.f, 5) })))
203{
204 // Validate output
205 validate(Accessor(_target), _reference);
206}
207TEST_SUITE_END()
morgolock6427c822020-01-13 11:53:20 +0000208
209TEST_SUITE(QASYMM8_SIGNED)
210FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwiseSquaredDiffQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
211 ElementwiseSquaredDiffQASYMM8SignedDataset),
212 framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f, 5) })),
213 framework::dataset::make("QuantizationInfo", { QuantizationInfo(.5f, 5) })),
214 framework::dataset::make("QuantizationInfo", { QuantizationInfo(.2f, 5) })))
215{
216 // Validate output
217 validate(Accessor(_target), _reference);
218}
219TEST_SUITE_END()
giuros0192fd9432018-12-03 17:30:00 +0000220TEST_SUITE_END()
221
222TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(F16)
// Compiled only when the target advertises FP16 vector arithmetic
FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwiseSquaredDiffFixture<half>, framework::DatasetMode::ALL,
                       combine(datasets::SmallShapes(), ElementwiseSquaredDiffFP16Dataset))
{
    // Half precision uses the looser FP16 relative tolerance; 0.01 is presumably the
    // allowed fraction of mismatching elements (confirm against validate()'s signature)
    validate(Accessor(_target), _reference, tolerance_fp16, 0.01);
}
TEST_SUITE_END() // F16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
232
233TEST_SUITE(F32)
Michalis Spyrou5c9f0c42019-01-16 14:48:48 +0000234DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, datasets::SmallShapes(),
giuros0192fd9432018-12-03 17:30:00 +0000235 shape)
236{
237 // Create tensors
238 Tensor ref_src1 = create_tensor<Tensor>(shape, DataType::F32);
239 Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::F32);
240 Tensor dst = create_tensor<Tensor>(shape, DataType::F32);
241
242 // Create and Configure function
243 NEElementwiseSquaredDiff add;
244 add.configure(&ref_src1, &ref_src2, &dst);
245
246 // Validate valid region
247 const ValidRegion valid_region = shape_to_valid_region(shape);
248 validate(dst.info()->valid_region(), valid_region);
249}
250
Michalis Spyrou5ce99a22019-01-25 14:17:49 +0000251FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwiseSquaredDiffFixture<float>, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), ElementwiseSquaredDiffFP32Dataset))
giuros0192fd9432018-12-03 17:30:00 +0000252{
253 // Validate output
254 validate(Accessor(_target), _reference);
255}
giuros0192fd9432018-12-03 17:30:00 +0000256template <typename T>
257using NEElementwiseSquaredDiffBroadcastFixture = ElementwiseSquaredDiffBroadcastValidationFixture<Tensor, Accessor, NEElementwiseSquaredDiff, T>;
258
259FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEElementwiseSquaredDiffBroadcastFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallShapesBroadcast(),
260 ElementwiseSquaredDiffFP32Dataset))
261{
262 // Validate output
263 validate(Accessor(_target), _reference);
264}
265
266FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, NEElementwiseSquaredDiffBroadcastFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapesBroadcast(),
267 ElementwiseSquaredDiffFP32Dataset))
268{
269 // Validate output
270 validate(Accessor(_target), _reference);
271}
272TEST_SUITE_END() // F32
273TEST_SUITE_END() // Float
274
275TEST_SUITE_END() // ElementwiseSquaredDiff
276TEST_SUITE_END() // NEON
277} // namespace validation
278} // namespace test
279} // namespace arm_compute