blob: c7b0752cc8a23155a63583dd7a4f767a4cd47bf6 [file] [log] [blame]
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +00001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2018-2020 Arm Limited.
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
25#include "arm_compute/core/Types.h"
26#include "arm_compute/core/utils/misc/ShapeCalculator.h"
27#include "arm_compute/runtime/CL/CLTensor.h"
28#include "arm_compute/runtime/CL/CLTensorAllocator.h"
29#include "tests/CL/CLAccessor.h"
30#include "tests/CL/Helper.h"
31#include "tests/PaddingCalculator.h"
32#include "tests/datasets/ShapeDatasets.h"
33#include "tests/framework/Asserts.h"
34#include "tests/framework/Macros.h"
35#include "tests/framework/datasets/Datasets.h"
36#include "tests/validation/Validation.h"
37#include "tests/validation/fixtures/GEMMReshapeRHSMatrixFixture.h"
38
39namespace arm_compute
40{
41namespace test
42{
43namespace validation
44{
45namespace
46{
47// *INDENT-OFF*
48// clang-format off
49/** Data types */
50const auto data_types = framework::dataset::make("DataType", { DataType::QASYMM8, DataType::F16, DataType::F32 });
51
52/** Batch size values to test */
53const auto b_values = framework::dataset::make("batchsize", 1, 3);
54
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +010055/** N0 values to test */
56const auto n0_values_nt_s32 = framework::dataset::make("N0", { 1, 2, 3 });
57const auto n0_values_nt_s16 = framework::dataset::make("N0", { 4, 8 });
58const auto n0_values_nt_s8 = framework::dataset::make("N0", { 16 });
59const auto n0_values_t_s32 = framework::dataset::make("N0", { 4, 8 });
60const auto n0_values_t_s16 = framework::dataset::make("N0", { 16 });
61const auto n0_values_t_s8 = framework::dataset::make("N0", { 2, 3 });
Gian Marco Iodice89124342018-12-19 14:17:22 +000062
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +010063/** K0 values to test */
64const auto k0_values_nt_s32 = framework::dataset::make("K0", { 1, 2 });
65const auto k0_values_nt_s16 = framework::dataset::make("K0", { 16 });
66const auto k0_values_nt_s8 = framework::dataset::make("K0", { 3,4 });
67const auto k0_values_t_s32 = framework::dataset::make("K0", { 2, 3 });
68const auto k0_values_t_s16 = framework::dataset::make("K0", { 4, 8 });
69const auto k0_values_t_s8 = framework::dataset::make("K0", { 16 });
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +000070
71/** H0 values to test */
72const auto h0_values = framework::dataset::make("H0", 1, 4);
73
74/** Interleave values to test */
75const auto i_values = framework::dataset::make("interleave", { true, false });
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +000076} // namespace
77
78using namespace arm_compute::misc::shape_calculator;
79
80// Initialize the output tensor with zero and fill the border with zero
81using CLGEMMReshapeRHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder<CLGEMMReshapeRHSMatrixKernel, 16>;
82
83template <typename T>
84using CLGEMMReshapeRHSMatrixFixture = GEMMReshapeRHSMatrixValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeRHSMatrix, T>;
85
86TEST_SUITE(CL)
87TEST_SUITE(GEMMReshapeRHSMatrix)
88
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +010089// *INDENT-OFF*
90// clang-format off
91DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
92 framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32),
93 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32), // Mismatching data types
94 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32), // Wrong n0 value
95 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32), // Wrong k0 value
96 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32), // Wrong h0 value
97 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32), // n0 > 16
98 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32), // k0 > 16
99 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32), // k0 == 1 && transpose
100 }),
101 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(64U, 2U, 2U), 1, DataType::F32),
102 TensorInfo(TensorShape(32U, 2U, 2U), 1, DataType::F16),
103 TensorInfo(TensorShape(32U, 2U, 2U), 1, DataType::F32),
104 TensorInfo(TensorShape(32U, 2U, 2U), 1, DataType::F32),
105 TensorInfo(TensorShape(32U, 2U, 2U), 1, DataType::F32),
106 TensorInfo(TensorShape(32U, 2U, 2U), 1, DataType::F32),
107 TensorInfo(TensorShape(32U, 2U, 2U), 1, DataType::F32),
108 TensorInfo(TensorShape(32U, 2U, 2U), 1, DataType::F32),
109 })),
110 framework::dataset::make("N0",{ 4, 0, 4, 4, 4, 17, 4, 4 })),
111 framework::dataset::make("K0",{ 4, 4, 0, 4, 4, 4, 17, 1 })),
112 framework::dataset::make("H0",{ 4, 4, 4, 0, 4, 4, 4, 4 })),
113 framework::dataset::make("Expected", { false, false, false, false, false, false, false})),
114 input_info, output_info, n0, k0, h0, expected)
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000115{
116 GEMMRHSMatrixInfo rhs_info;
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100117 rhs_info.n0 = n0;
118 rhs_info.k0 = k0;
119 rhs_info.h0 = h0;
120 rhs_info.transpose = true;
121 rhs_info.interleave = true;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000122
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100123 bool has_error = bool(CLGEMMReshapeRHSMatrixKernel::validate(&input_info.clone()->set_is_resizable(false), (output_info.total_size() == 0) ? nullptr : &output_info.clone()->set_is_resizable(false), rhs_info));
124 ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000125}
Gian Marco Iodicea98dee22020-06-02 12:12:35 +0100126
127DATA_TEST_CASE(ValidatePadding, framework::DatasetMode::ALL, combine(combine(combine(
128 framework::dataset::make("InputShape", { TensorShape(32U, 16U, 1U),
129 TensorShape(32U, 16U, 2U)
130 }),
131 framework::dataset::make("N0",{ 4 })),
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +0100132 framework::dataset::make("K0",{ 4, 8, 16 })),
Gian Marco Iodicea98dee22020-06-02 12:12:35 +0100133 framework::dataset::make("H0",{ 1, 2, 4 })),
134 input_shape, n0, k0, h0)
135{
136 CLTensor input;
137 CLTensor output;
138
139 input.info()->init(input_shape, 1, DataType::F32);
140
141 unsigned int padding = 0;
142
143 GEMMRHSMatrixInfo rhs_info;
144 rhs_info.n0 = n0;
145 rhs_info.k0 = k0;
146 rhs_info.h0 = h0;
147 rhs_info.transpose = true;
148 rhs_info.interleave = true;
149 rhs_info.export_to_cl_image = image2d_from_buffer_supported(CLKernelLibrary::get().get_device()) && (get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) != 0);
150
151 if(rhs_info.export_to_cl_image)
152 {
153 TensorShape output_shape = compute_rhs_reshaped_shape(*input.info(), rhs_info);
154 constexpr unsigned int num_floats_per_pixel = 4;
155
156 const unsigned int pixel_aligment = get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device());
157 const unsigned int row_pitch_alignment = pixel_aligment * num_floats_per_pixel;
158 const unsigned int round_up_width = ((output_shape[0] + row_pitch_alignment - 1) / row_pitch_alignment) * row_pitch_alignment;
159
160 padding = round_up_width - output_shape[0];
161 }
162
163 CLGEMMReshapeRHSMatrixKernel kernel;
164
165 kernel.configure(&input, &output, rhs_info);
166
167 ARM_COMPUTE_EXPECT((output.info()->padding().right == padding), framework::LogLevel::ERRORS);
168}
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100169// clang-format on
170// *INDENT-ON*
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000171
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100172// Run S32 tests only for transpose = false
173FIXTURE_DATA_TEST_CASE(S32_NT, CLGEMMReshapeRHSMatrixFixture<int>, framework::DatasetMode::ALL,
174 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
175 b_values),
176 framework::dataset::make("DataType", DataType::S32)),
177 n0_values_nt_s32),
178 k0_values_nt_s32),
179 h0_values),
180 i_values),
181 framework::dataset::make("transpose", false)))
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000182{
183 // Validate output
184 validate(CLAccessor(_target), _reference);
185}
186
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100187// Run S32 tests only for transpose = true
188FIXTURE_DATA_TEST_CASE(S32_T, CLGEMMReshapeRHSMatrixFixture<int>, framework::DatasetMode::ALL,
189 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
190 b_values),
191 framework::dataset::make("DataType", DataType::S32)),
192 n0_values_t_s32),
193 k0_values_t_s32),
194 h0_values),
195 i_values),
196 framework::dataset::make("transpose", true)))
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000197{
198 // Validate output
199 validate(CLAccessor(_target), _reference);
200}
201
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100202// Run S16 tests only for transpose = false
203FIXTURE_DATA_TEST_CASE(S16_NT, CLGEMMReshapeRHSMatrixFixture<short>, framework::DatasetMode::ALL,
204 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
205 b_values),
206 framework::dataset::make("DataType", DataType::S16)),
207 n0_values_nt_s16),
208 k0_values_nt_s16),
209 h0_values),
210 i_values),
211 framework::dataset::make("transpose", false)))
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000212{
213 // Validate output
214 validate(CLAccessor(_target), _reference);
215}
216
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100217// Run S16 tests only for transpose = true
218FIXTURE_DATA_TEST_CASE(S16_T, CLGEMMReshapeRHSMatrixFixture<short>, framework::DatasetMode::ALL,
219 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
220 b_values),
221 framework::dataset::make("DataType", DataType::S16)),
222 n0_values_t_s16),
223 k0_values_t_s16),
224 h0_values),
225 i_values),
226 framework::dataset::make("transpose", true)))
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000227{
228 // Validate output
229 validate(CLAccessor(_target), _reference);
230}
231
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100232// Run S8 tests only for transpose = false
233FIXTURE_DATA_TEST_CASE(S8_NT, CLGEMMReshapeRHSMatrixFixture<char>, framework::DatasetMode::ALL,
234 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
235 b_values),
236 framework::dataset::make("DataType", DataType::S8)),
237 n0_values_nt_s8),
238 k0_values_nt_s8),
239 h0_values),
240 i_values),
241 framework::dataset::make("transpose", false)))
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000242{
243 // Validate output
244 validate(CLAccessor(_target), _reference);
245}
246
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100247// Run S8 tests only for transpose = true
248FIXTURE_DATA_TEST_CASE(S8_T, CLGEMMReshapeRHSMatrixFixture<char>, framework::DatasetMode::ALL,
249 combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
250 b_values),
251 framework::dataset::make("DataType", DataType::S8)),
252 n0_values_t_s8),
253 k0_values_t_s8),
254 h0_values),
255 i_values),
256 framework::dataset::make("transpose", true)))
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000257{
258 // Validate output
259 validate(CLAccessor(_target), _reference);
260}
261
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000262TEST_SUITE_END() // GEMMReshapeRHSMatrix
263TEST_SUITE_END() // CL
264} // namespace validation
265} // namespace test
Michalis Spyrou5c2df4e2020-04-27 18:10:58 +0100266} // namespace arm_compute