blob: ea6589df22d1ae7049ff9ddf1c363306b62499bc [file] [log] [blame]
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +00001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
25#include "arm_compute/core/Types.h"
26#include "arm_compute/core/utils/misc/ShapeCalculator.h"
27#include "arm_compute/runtime/CL/CLTensor.h"
28#include "arm_compute/runtime/CL/CLTensorAllocator.h"
29#include "tests/CL/CLAccessor.h"
30#include "tests/CL/Helper.h"
31#include "tests/PaddingCalculator.h"
32#include "tests/datasets/ShapeDatasets.h"
33#include "tests/framework/Asserts.h"
34#include "tests/framework/Macros.h"
35#include "tests/framework/datasets/Datasets.h"
36#include "tests/validation/Validation.h"
37#include "tests/validation/fixtures/GEMMReshapeLHSMatrixFixture.h"
38
39namespace arm_compute
40{
41namespace test
42{
43namespace validation
44{
45namespace
46{
47// *INDENT-OFF*
48// clang-format off
49/** Data types */
50const auto data_types = framework::dataset::make("DataType", { DataType::QASYMM8, DataType::F16, DataType::F32 });
51
52/** Batch size values to test */
53const auto b_values = framework::dataset::make("batchsize", 1, 3);
54
55/** M0 values to test */
56const auto m0_values = framework::dataset::make("M0", 2, 9);
57
58/** K0 values to test */
59const auto k0_values = framework::dataset::make("K0", { 2, 4, 8, 16 });
60
61/** V0 values to test */
62const auto v0_values = framework::dataset::make("V0", 1, 4);
63
64/** Interleave values to test */
65const auto i_values = framework::dataset::make("interleave", { true, false });
66
67/** Transpose values to test */
68const auto t_values = framework::dataset::make("transpose", { false });
69} // namespace
70
71using namespace arm_compute::misc::shape_calculator;
72
73// Initialize the output tensor with zero and fill the border with zero
74using CLGEMMReshapeLHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder<CLGEMMReshapeLHSMatrixKernel, 16>;
75
76template <typename T>
77using CLGEMMReshapeLHSMatrixFixture = GEMMReshapeLHSMatrixValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeLHSMatrix, T, false>;
78
79// Fixture to use when the input has to be reinterpreted as 3D
80template <typename T>
81using CLGEMMReshapeLHSMatrix3DFixture = GEMMReshapeLHSMatrixValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeLHSMatrix, T, true>;
82
83TEST_SUITE(CL)
84TEST_SUITE(GEMMReshapeLHSMatrix)
85
86DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
87 b_values),
88 data_types),
89 m0_values),
90 k0_values),
91 v0_values),
92 i_values),
93 t_values),
94shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value)
95{
96 GEMMLHSMatrixInfo lhs_info;
97 lhs_info.m0 = m0_value;
98 lhs_info.k0 = k0_value;
99 lhs_info.v0 = v0_value;
100 lhs_info.interleave = i_value;
101 lhs_info.transpose = t_value;
102
103 const TensorShape shape_src(shape_in[0], shape_in[1], b_value);
104 const TensorShape shape_dst = compute_lhs_reshaped_shape(TensorInfo(shape_src, 1, data_type), lhs_info, false);
105
106 // Create tensors
107 CLTensor src = create_tensor<CLTensor>(shape_src, data_type);
108 CLTensor dst = create_tensor<CLTensor>(shape_dst, data_type);
109
110 ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
111 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
112
113 // Create and configure function
114 CLGEMMReshapeLHSMatrixKernel reshape_lhs;
115 reshape_lhs.configure(&src, &dst, lhs_info, false);
116}
117
118TEST_SUITE(S32)
119FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrixFixture<int>, framework::DatasetMode::ALL,
120 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
121 b_values),
122 framework::dataset::make("DataType", DataType::S32)),
123 m0_values),
124 k0_values),
125 v0_values),
126 i_values),
127 t_values))
128{
129 // Validate output
130 validate(CLAccessor(_target), _reference);
131}
132
133FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture<int>, framework::DatasetMode::NIGHTLY,
134 combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(),
135 b_values),
136 framework::dataset::make("DataType", DataType::S32)),
137 m0_values),
138 k0_values),
139 v0_values),
140 i_values),
141 t_values))
142{
143 // Validate output
144 validate(CLAccessor(_target), _reference);
145}
146TEST_SUITE_END() // S32
147
148TEST_SUITE(S16)
149FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrixFixture<short>, framework::DatasetMode::ALL,
150 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
151 b_values),
152 framework::dataset::make("DataType", DataType::S16)),
153 m0_values),
154 k0_values),
155 v0_values),
156 i_values),
157 t_values))
158{
159 // Validate output
160 validate(CLAccessor(_target), _reference);
161}
162
163FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture<short>, framework::DatasetMode::NIGHTLY,
164 combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(),
165 b_values),
166 framework::dataset::make("DataType", DataType::S16)),
167 m0_values),
168 k0_values),
169 v0_values),
170 i_values),
171 t_values))
172{
173 // Validate output
174 validate(CLAccessor(_target), _reference);
175}
176TEST_SUITE_END() // S16
177
178TEST_SUITE(S8)
179FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrixFixture<char>, framework::DatasetMode::ALL,
180 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
181 b_values),
182 framework::dataset::make("DataType", DataType::S8)),
183 m0_values),
184 k0_values),
185 v0_values),
186 i_values),
187 t_values))
188{
189 // Validate output
190 validate(CLAccessor(_target), _reference);
191}
192
193FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrixFixture<char>, framework::DatasetMode::NIGHTLY,
194 combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape2DShapes(),
195 b_values),
196 framework::dataset::make("DataType", DataType::S8)),
197 m0_values),
198 k0_values),
199 v0_values),
200 i_values),
201 t_values))
202{
203 // Validate output
204 validate(CLAccessor(_target), _reference);
205}
206TEST_SUITE_END() // S8
207
208TEST_SUITE(REINTERPRET_INPUT_AS_3D)
209DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(),
210 b_values),
211 data_types),
212 m0_values),
213 k0_values),
214 v0_values),
215 i_values),
216 t_values),
217shape_in, b_value, data_type, m0_value, k0_value, v0_value, i_value, t_value)
218{
219 GEMMLHSMatrixInfo lhs_info;
220 lhs_info.m0 = m0_value;
221 lhs_info.k0 = k0_value;
222 lhs_info.v0 = v0_value;
223 lhs_info.interleave = i_value;
224 lhs_info.transpose = t_value;
225
226 const TensorShape shape_src(shape_in[0], shape_in[1], shape_in[2], b_value);
227 const TensorShape shape_dst = compute_lhs_reshaped_shape(TensorInfo(shape_src, 1, data_type), lhs_info, true);
228
229 // Create tensors
230 CLTensor src = create_tensor<CLTensor>(shape_src, data_type);
231 CLTensor dst = create_tensor<CLTensor>(shape_dst, data_type);
232
233 ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
234 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
235
236 // Create and configure function
237 CLGEMMReshapeLHSMatrixKernel reshape_lhs;
238 reshape_lhs.configure(&src, &dst, lhs_info, true);
239}
240
241TEST_SUITE(S32)
242FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrix3DFixture<int>, framework::DatasetMode::ALL,
243 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(),
244 b_values),
245 framework::dataset::make("DataType", DataType::S32)),
246 m0_values),
247 k0_values),
248 v0_values),
249 i_values),
250 t_values))
251{
252 // Validate output
253 validate(CLAccessor(_target), _reference);
254}
255
256FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture<int>, framework::DatasetMode::NIGHTLY,
257 combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(),
258 b_values),
259 framework::dataset::make("DataType", DataType::S32)),
260 m0_values),
261 k0_values),
262 v0_values),
263 i_values),
264 t_values))
265{
266 // Validate output
267 validate(CLAccessor(_target), _reference);
268}
269TEST_SUITE_END() // S32
270
271TEST_SUITE(S16)
272FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrix3DFixture<short>, framework::DatasetMode::ALL,
273 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(),
274 b_values),
275 framework::dataset::make("DataType", DataType::S16)),
276 m0_values),
277 k0_values),
278 v0_values),
279 i_values),
280 t_values))
281{
282 // Validate output
283 validate(CLAccessor(_target), _reference);
284}
285
286FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture<short>, framework::DatasetMode::NIGHTLY,
287 combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(),
288 b_values),
289 framework::dataset::make("DataType", DataType::S16)),
290 m0_values),
291 k0_values),
292 v0_values),
293 i_values),
294 t_values))
295{
296 // Validate output
297 validate(CLAccessor(_target), _reference);
298}
299TEST_SUITE_END() // S16
300
301TEST_SUITE(S8)
302FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMReshapeLHSMatrix3DFixture<char>, framework::DatasetMode::ALL,
303 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(),
304 b_values),
305 framework::dataset::make("DataType", DataType::S8)),
306 m0_values),
307 k0_values),
308 v0_values),
309 i_values),
310 t_values))
311{
312 // Validate output
313 validate(CLAccessor(_target), _reference);
314}
315
316FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMReshapeLHSMatrix3DFixture<char>, framework::DatasetMode::NIGHTLY,
317 combine(combine(combine(combine(combine(combine(combine(datasets::LargeGEMMReshape3DShapes(),
318 b_values),
319 framework::dataset::make("DataType", DataType::S8)),
320 m0_values),
321 k0_values),
322 v0_values),
323 i_values),
324 t_values))
325{
326 // Validate output
327 validate(CLAccessor(_target), _reference);
328}
329TEST_SUITE_END() // S8
330TEST_SUITE_END() // REINTERPRET_INPUT_AS_3D
331TEST_SUITE_END() // GEMMReshapeLHSMatrix
332TEST_SUITE_END() // CL
333} // namespace validation
334} // namespace test
335} // namespace arm_compute