blob: e3159caa93d97ea8356fcd81d32dac1c31d8a246 [file] [log] [blame]
/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#include "arm_compute/core/Types.h"
26#include "arm_compute/runtime/CL/CLTensor.h"
27#include "arm_compute/runtime/CL/CLTensorAllocator.h"
28#include "arm_compute/runtime/CL/functions/CLSelect.h"
29#include "tests/CL/CLAccessor.h"
30#include "tests/PaddingCalculator.h"
31#include "tests/datasets/ShapeDatasets.h"
32#include "tests/framework/Asserts.h"
33#include "tests/framework/Macros.h"
34#include "tests/framework/datasets/Datasets.h"
35#include "tests/validation/Validation.h"
36#include "tests/validation/fixtures/SelectFixture.h"
37
38namespace arm_compute
39{
40namespace test
41{
42namespace validation
43{
44namespace
45{
46auto configuration_dataset = combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()),
47 framework::dataset::make("has_same_rank", { false, true }));
48auto run_small_dataset = combine(datasets::SmallShapes(), framework::dataset::make("has_same_rank", { false, true }));
49auto run_large_dataset = combine(datasets::LargeShapes(), framework::dataset::make("has_same_rank", { false, true }));
50
51} // namespace
52TEST_SUITE(CL)
53TEST_SUITE(Select)
54
55// *INDENT-OFF*
56// clang-format off
57DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
58 framework::dataset::make("CInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8), // Invalid condition datatype
59 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid output datatype
60 TensorInfo(TensorShape(13U), 1, DataType::U8), // Invalid c shape
61 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Mismatching shapes
62 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
63 TensorInfo(TensorShape(2U), 1, DataType::U8),
64 }),
65 framework::dataset::make("XInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
66 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
67 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
68 TensorInfo(TensorShape(32U, 10U, 2U), 1, DataType::F32),
69 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
70 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
71 })),
72 framework::dataset::make("YInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
73 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
74 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
75 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
76 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
77 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
78 })),
79 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
80 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S8),
81 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
82 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
83 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
84 TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
85 })),
86 framework::dataset::make("Expected", { false, false, false, false, true, true})),
87 c_info, x_info, y_info, output_info, expected)
88{
89 Status s = CLSelect::validate(&c_info.clone()->set_is_resizable(false),
90 &x_info.clone()->set_is_resizable(false),
91 &y_info.clone()->set_is_resizable(false),
92 &output_info.clone()->set_is_resizable(false));
93 ARM_COMPUTE_EXPECT(bool(s) == expected, framework::LogLevel::ERRORS);
94}
95// clang-format on
96// *INDENT-ON*
97
98template <typename T>
99using CLSelectFixture = SelectValidationFixture<CLTensor, CLAccessor, CLSelect, T>;
100
101TEST_SUITE(Float)
102TEST_SUITE(F16)
103DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, configuration_dataset,
104 shape, same_rank)
105{
106 const DataType dt = DataType::F16;
107
108 // Create tensors
109 CLTensor ref_c = create_tensor<CLTensor>(detail::select_condition_shape(shape, same_rank), DataType::U8);
110 CLTensor ref_x = create_tensor<CLTensor>(shape, dt);
111 CLTensor ref_y = create_tensor<CLTensor>(shape, dt);
112 CLTensor dst = create_tensor<CLTensor>(shape, dt);
113
114 // Create and Configure function
115 CLSelect select;
116 select.configure(&ref_c, &ref_x, &ref_y, &dst);
117
118 // Validate valid region
119 const ValidRegion valid_region = shape_to_valid_region(shape);
120 validate(dst.info()->valid_region(), valid_region);
121
122 // Validate padding
123 const int step = 16 / arm_compute::data_size_from_type(dt);
124 const PaddingSize padding = PaddingCalculator(shape.x(), step).required_padding();
125 if(same_rank)
126 {
127 validate(ref_c.info()->padding(), padding);
128 }
129 validate(ref_x.info()->padding(), padding);
130 validate(ref_y.info()->padding(), padding);
131 validate(dst.info()->padding(), padding);
132}
133
134FIXTURE_DATA_TEST_CASE(RunSmall,
135 CLSelectFixture<half>,
136 framework::DatasetMode::PRECOMMIT,
137 combine(run_small_dataset, framework::dataset::make("DataType", DataType::F16)))
138{
139 // Validate output
140 validate(CLAccessor(_target), _reference);
141}
142
143FIXTURE_DATA_TEST_CASE(RunLarge,
144 CLSelectFixture<half>,
145 framework::DatasetMode::NIGHTLY,
146 combine(run_large_dataset, framework::dataset::make("DataType", DataType::F16)))
147{
148 // Validate output
149 validate(CLAccessor(_target), _reference);
150}
151TEST_SUITE_END() // F16
152
153TEST_SUITE(FP32)
154DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, configuration_dataset,
155 shape, same_rank)
156{
157 const DataType dt = DataType::F32;
158
159 // Create tensors
160 CLTensor ref_c = create_tensor<CLTensor>(detail::select_condition_shape(shape, same_rank), DataType::U8);
161 CLTensor ref_x = create_tensor<CLTensor>(shape, dt);
162 CLTensor ref_y = create_tensor<CLTensor>(shape, dt);
163 CLTensor dst = create_tensor<CLTensor>(shape, dt);
164
165 // Create and Configure function
166 CLSelect select;
167 select.configure(&ref_c, &ref_x, &ref_y, &dst);
168
169 // Validate valid region
170 const ValidRegion valid_region = shape_to_valid_region(shape);
171 validate(dst.info()->valid_region(), valid_region);
172
173 // Validate padding
174 const int step = 16 / arm_compute::data_size_from_type(dt);
175 const PaddingSize padding = PaddingCalculator(shape.x(), step).required_padding();
176 if(same_rank)
177 {
178 validate(ref_c.info()->padding(), padding);
179 }
180 validate(ref_x.info()->padding(), padding);
181 validate(ref_y.info()->padding(), padding);
182 validate(dst.info()->padding(), padding);
183}
184
185FIXTURE_DATA_TEST_CASE(RunSmall,
186 CLSelectFixture<float>,
187 framework::DatasetMode::PRECOMMIT,
188 combine(run_small_dataset, framework::dataset::make("DataType", DataType::F32)))
189{
190 // Validate output
191 validate(CLAccessor(_target), _reference);
192}
193
194FIXTURE_DATA_TEST_CASE(RunLarge,
195 CLSelectFixture<float>,
196 framework::DatasetMode::NIGHTLY,
197 combine(run_large_dataset, framework::dataset::make("DataType", DataType::F32)))
198{
199 // Validate output
200 validate(CLAccessor(_target), _reference);
201}
202TEST_SUITE_END() // F32
203TEST_SUITE_END() // Float
204
205TEST_SUITE(Quantized)
206TEST_SUITE(QASYMM8)
207DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, configuration_dataset,
208 shape, same_rank)
209{
210 const DataType dt = DataType::QASYMM8;
211
212 // Create tensors
213 CLTensor ref_c = create_tensor<CLTensor>(detail::select_condition_shape(shape, same_rank), DataType::U8);
214 CLTensor ref_x = create_tensor<CLTensor>(shape, dt);
215 CLTensor ref_y = create_tensor<CLTensor>(shape, dt);
216 CLTensor dst = create_tensor<CLTensor>(shape, dt);
217
218 // Create and Configure function
219 CLSelect select;
220 select.configure(&ref_c, &ref_x, &ref_y, &dst);
221
222 // Validate valid region
223 const ValidRegion valid_region = shape_to_valid_region(shape);
224 validate(dst.info()->valid_region(), valid_region);
225
226 // Validate padding
227 const int step = 16 / arm_compute::data_size_from_type(dt);
228 const PaddingSize padding = PaddingCalculator(shape.x(), step).required_padding();
229 if(same_rank)
230 {
231 validate(ref_c.info()->padding(), padding);
232 }
233 validate(ref_x.info()->padding(), padding);
234 validate(ref_y.info()->padding(), padding);
235 validate(dst.info()->padding(), padding);
236}
237
238FIXTURE_DATA_TEST_CASE(RunSmall,
239 CLSelectFixture<uint8_t>,
240 framework::DatasetMode::PRECOMMIT,
241 combine(run_small_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
242{
243 // Validate output
244 validate(CLAccessor(_target), _reference);
245}
246
247FIXTURE_DATA_TEST_CASE(RunLarge,
248 CLSelectFixture<uint8_t>,
249 framework::DatasetMode::NIGHTLY,
250 combine(run_large_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
251{
252 // Validate output
253 validate(CLAccessor(_target), _reference);
254}
255TEST_SUITE_END() // QASYMM8
256TEST_SUITE_END() // Quantized
257
258TEST_SUITE_END() // Select
259TEST_SUITE_END() // CL
260} // namespace validation
261} // namespace test
262} // namespace arm_compute