blob: 4ba5387a6115cb9555aa18fb5df04a86caf1a31a [file] [log] [blame]
Sanghoon Lee72898fe2017-09-01 11:42:16 +01001/*
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +00002 * Copyright (c) 2017-2018 ARM Limited.
Sanghoon Lee72898fe2017-09-01 11:42:16 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/Types.h"
25#include "arm_compute/runtime/CL/CLTensor.h"
26#include "arm_compute/runtime/CL/CLTensorAllocator.h"
27#include "arm_compute/runtime/CL/functions/CLArithmeticSubtraction.h"
28#include "tests/CL/CLAccessor.h"
29#include "tests/PaddingCalculator.h"
30#include "tests/datasets/ConvertPolicyDataset.h"
31#include "tests/datasets/ShapeDatasets.h"
32#include "tests/framework/Asserts.h"
33#include "tests/framework/Macros.h"
34#include "tests/framework/datasets/Datasets.h"
35#include "tests/validation/Validation.h"
36#include "tests/validation/fixtures/ArithmeticSubtractionFixture.h"
Sanghoon Lee72898fe2017-09-01 11:42:16 +010037
38namespace arm_compute
39{
40namespace test
41{
42namespace validation
43{
44namespace
45{
46/** Input data sets **/
47const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
48 framework::dataset::make("DataType",
49 DataType::U8));
Isabella Gottardib5908c22017-10-30 15:28:13 +000050const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::S16)),
Sanghoon Lee72898fe2017-09-01 11:42:16 +010051 framework::dataset::make("DataType", DataType::S16));
Isabella Gottardib5908c22017-10-30 15:28:13 +000052const auto ArithmeticSubtractionU8U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
53 framework::dataset::make("DataType", DataType::S16));
54const auto ArithmeticSubtractionS16U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8)),
55 framework::dataset::make("DataType", DataType::S16));
56const auto ArithmeticSubtractionU8S16S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::S16)),
57 framework::dataset::make("DataType", DataType::S16));
Sanghoon Lee72898fe2017-09-01 11:42:16 +010058const auto ArithmeticSubtractionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
59 framework::dataset::make("DataType", DataType::F16));
60const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)),
61 framework::dataset::make("DataType", DataType::F32));
62} // namespace
63
64TEST_SUITE(CL)
65TEST_SUITE(ArithmeticSubtraction)
66
Georgios Pinitasf9d3a0a2017-11-03 19:01:44 +000067// *INDENT-OFF*
68// clang-format off
69DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
Giorgio Arena70623822017-11-27 15:50:10 +000070 framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
71 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
72 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Window shrink
73 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
74 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
Georgios Pinitasf9d3a0a2017-11-03 19:01:44 +000075 }),
Giorgio Arena70623822017-11-27 15:50:10 +000076 framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
77 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
Georgios Pinitasf9d3a0a2017-11-03 19:01:44 +000078 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
Giorgio Arena70623822017-11-27 15:50:10 +000079 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
80 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
Georgios Pinitasf9d3a0a2017-11-03 19:01:44 +000081 })),
Giorgio Arena70623822017-11-27 15:50:10 +000082 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
83 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
Georgios Pinitasf9d3a0a2017-11-03 19:01:44 +000084 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
Giorgio Arena70623822017-11-27 15:50:10 +000085 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
86 TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
Georgios Pinitasf9d3a0a2017-11-03 19:01:44 +000087 })),
Georgios Pinitas631c41a2017-12-06 11:53:03 +000088 framework::dataset::make("Expected", { true, true, false, false, false, false, true })),
Georgios Pinitasf9d3a0a2017-11-03 19:01:44 +000089 input1_info, input2_info, output_info, expected)
90{
Giorgio Arena70623822017-11-27 15:50:10 +000091 ARM_COMPUTE_EXPECT(bool(CLArithmeticSubtraction::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS);
Georgios Pinitasf9d3a0a2017-11-03 19:01:44 +000092}
93// clang-format on
94// *INDENT-ON*
95
Isabella Gottardib5908c22017-10-30 15:28:13 +000096template <typename T1, typename T2 = T1, typename T3 = T1>
97using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>;
Sanghoon Lee72898fe2017-09-01 11:42:16 +010098
99TEST_SUITE(U8)
100DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
101 shape, policy)
102{
103 // Create tensors
104 CLTensor ref_src1 = create_tensor<CLTensor>(shape, DataType::U8);
105 CLTensor ref_src2 = create_tensor<CLTensor>(shape, DataType::U8);
106 CLTensor dst = create_tensor<CLTensor>(shape, DataType::U8);
107
108 // Create and Configure function
109 CLArithmeticSubtraction sub;
110 sub.configure(&ref_src1, &ref_src2, &dst, policy);
111
112 // Validate valid region
113 const ValidRegion valid_region = shape_to_valid_region(shape);
114 validate(dst.info()->valid_region(), valid_region);
115
116 // Validate padding
117 const PaddingSize padding = PaddingCalculator(shape.x(), 16).required_padding();
118 validate(ref_src1.info()->padding(), padding);
119 validate(ref_src2.info()->padding(), padding);
120 validate(dst.info()->padding(), padding);
121}
122
Isabella Gottardib5908c22017-10-30 15:28:13 +0000123FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
124 ArithmeticSubtractionU8Dataset),
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100125 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
126{
127 // Validate output
128 validate(CLAccessor(_target), _reference);
129}
130TEST_SUITE_END()
131
Isabella Gottardib5908c22017-10-30 15:28:13 +0000132template <typename T1, typename T2 = T1>
133using CLArithmeticSubtractionToS16Fixture = CLArithmeticSubtractionFixture<T1, T2, int16_t>;
134
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100135TEST_SUITE(S16)
Isabella Gottardib5908c22017-10-30 15:28:13 +0000136DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()),
137 framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
138 framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100139 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
Isabella Gottardib5908c22017-10-30 15:28:13 +0000140 shape, data_type1, data_type2, policy)
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100141{
142 // Create tensors
Isabella Gottardib5908c22017-10-30 15:28:13 +0000143 CLTensor ref_src1 = create_tensor<CLTensor>(shape, data_type1);
144 CLTensor ref_src2 = create_tensor<CLTensor>(shape, data_type2);
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100145 CLTensor dst = create_tensor<CLTensor>(shape, DataType::S16);
146
147 // Create and Configure function
148 CLArithmeticSubtraction sub;
149 sub.configure(&ref_src1, &ref_src2, &dst, policy);
150
151 // Validate valid region
152 const ValidRegion valid_region = shape_to_valid_region(shape);
153 validate(dst.info()->valid_region(), valid_region);
154
155 // Validate padding
156 const PaddingSize padding = PaddingCalculator(shape.x(), 16).required_padding();
157 validate(ref_src1.info()->padding(), padding);
158 validate(ref_src2.info()->padding(), padding);
159 validate(dst.info()->padding(), padding);
160}
Isabella Gottardib5908c22017-10-30 15:28:13 +0000161TEST_SUITE(S16_S16_S16)
162FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
163 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100164{
165 // Validate output
166 validate(CLAccessor(_target), _reference);
167}
168
Isabella Gottardib5908c22017-10-30 15:28:13 +0000169FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
170 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100171{
172 // Validate output
173 validate(CLAccessor(_target), _reference);
174}
175TEST_SUITE_END()
176
Isabella Gottardib5908c22017-10-30 15:28:13 +0000177TEST_SUITE(U8_U8_S16)
178FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
179 ArithmeticSubtractionU8U8S16Dataset),
180 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
181{
182 // Validate output
183 validate(CLAccessor(_target), _reference);
184}
185
186FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
187 ArithmeticSubtractionU8U8S16Dataset),
188 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
189{
190 // Validate output
191 validate(CLAccessor(_target), _reference);
192}
193TEST_SUITE_END()
194
195TEST_SUITE(S16_U8_S16)
196using CLAriSubS16U8ToS16Fixture = CLArithmeticSubtractionToS16Fixture<int16_t, uint8_t>;
197FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
198 ArithmeticSubtractionS16U8S16Dataset),
199 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
200{
201 // Validate output
202 validate(CLAccessor(_target), _reference);
203}
204
205FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
206 ArithmeticSubtractionS16U8S16Dataset),
207 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
208{
209 // Validate output
210 validate(CLAccessor(_target), _reference);
211}
212TEST_SUITE_END()
213
214TEST_SUITE(U8_S16_S16)
215using CLAriSubU8S16ToS16Fixture = CLArithmeticSubtractionToS16Fixture<uint8_t, int16_t>;
216FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
217 ArithmeticSubtractionU8S16S16Dataset),
218 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
219{
220 // Validate output
221 validate(CLAccessor(_target), _reference);
222}
223
224FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
225 ArithmeticSubtractionU8S16S16Dataset),
226 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
227{
228 // Validate output
229 validate(CLAccessor(_target), _reference);
230}
231TEST_SUITE_END()
232TEST_SUITE_END()
233
234template <typename T1, typename T2 = T1, typename T3 = T1>
235using CLArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>;
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100236
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100237TEST_SUITE(Float)
238TEST_SUITE(FP16)
Georgios Pinitas583137c2017-08-31 18:12:42 +0100239FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP16Dataset),
240 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
Sanghoon Lee72898fe2017-09-01 11:42:16 +0100241{
242 // Validate output
243 validate(CLAccessor(_target), _reference);
244}
245TEST_SUITE_END()
246
247TEST_SUITE(FP32)
248DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
249 shape, policy)
250{
251 // Create tensors
252 CLTensor ref_src1 = create_tensor<CLTensor>(shape, DataType::F32);
253 CLTensor ref_src2 = create_tensor<CLTensor>(shape, DataType::F32);
254 CLTensor dst = create_tensor<CLTensor>(shape, DataType::F32);
255
256 // Create and Configure function
257 CLArithmeticSubtraction sub;
258 sub.configure(&ref_src1, &ref_src2, &dst, policy);
259
260 // Validate valid region
261 const ValidRegion valid_region = shape_to_valid_region(shape);
262 validate(dst.info()->valid_region(), valid_region);
263
264 // Validate padding
265 const PaddingSize padding = PaddingCalculator(shape.x(), 16).required_padding();
266 validate(ref_src1.info()->padding(), padding);
267 validate(ref_src2.info()->padding(), padding);
268 validate(dst.info()->padding(), padding);
269}
270
271FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP32Dataset),
272 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
273{
274 // Validate output
275 validate(CLAccessor(_target), _reference);
276}
277
278FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionFP32Dataset),
279 framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
280{
281 // Validate output
282 validate(CLAccessor(_target), _reference);
283}
284TEST_SUITE_END()
285TEST_SUITE_END()
286
287TEST_SUITE_END()
288TEST_SUITE_END()
289} // namespace validation
290} // namespace test
291} // namespace arm_compute