blob: 99f5ffe191e6888c4a65658c962bd672749d1951 [file] [log] [blame]
/*
 * Copyright (c) 2018-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
25#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
26#include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
Gian Marco Iodice7026b302019-06-26 17:18:11 +010027#include "arm_compute/core/KernelDescriptors.h"
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000028#include "arm_compute/core/Types.h"
29#include "arm_compute/core/utils/misc/ShapeCalculator.h"
30#include "arm_compute/runtime/CL/CLTensor.h"
31#include "arm_compute/runtime/CL/CLTensorAllocator.h"
32#include "tests/CL/CLAccessor.h"
33#include "tests/CL/Helper.h"
34#include "tests/PaddingCalculator.h"
35#include "tests/datasets/ShapeDatasets.h"
36#include "tests/framework/Asserts.h"
37#include "tests/framework/Macros.h"
38#include "tests/framework/datasets/Datasets.h"
39#include "tests/validation/Validation.h"
40#include "tests/validation/fixtures/GEMMFixture.h"
41
42namespace arm_compute
43{
44namespace test
45{
46namespace validation
47{
Gian Marco Iodice9382ab32018-12-17 15:12:07 +000048using namespace arm_compute::misc::shape_calculator;
49
50// Create function for CLGEMMReshapeLHSMatrixKernel
Gian Marco Iodicebacfec52019-01-11 11:30:55 +000051using CLGEMMReshapeLHSMatrix = CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>;
Gian Marco Iodice9382ab32018-12-17 15:12:07 +000052
53// Create function for CLGEMMReshapeRHSMatrixKernel
Gian Marco Iodicebacfec52019-01-11 11:30:55 +000054using CLGEMMReshapeRHSMatrix = CLSynthetizeFunction<CLGEMMReshapeRHSMatrixKernel>;
Gian Marco Iodice9382ab32018-12-17 15:12:07 +000055
56// Create function for CLGEMMMatrixMultiplyReshapedKernel
57using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>;
58
59// Fixture for CLGEMMMatrixMultiplyReshaped
60template <typename T>
61using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
62
63// Fixture for CLGEMMMatrixMultiplyReshaped3D
64template <typename T>
65using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
66
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000067namespace
68{
69// *INDENT-OFF*
70// clang-format off
Gian Marco Iodice9382ab32018-12-17 15:12:07 +000071RelativeTolerance<float> rel_tolerance_f32(0.001f);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000072constexpr float abs_tolerance_f32(0.0001f);
73
Gian Marco Iodice05639f62019-09-24 12:05:06 +010074RelativeTolerance<float> rel_tolerance_f16(0.001f);
75constexpr float abs_tolerance_f16(0.01f);
76
Gian Marco Iodice9382ab32018-12-17 15:12:07 +000077/** Alpha values to test - Precommit */
Gian Marco Iodicebacfec52019-01-11 11:30:55 +000078const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
Gian Marco Iodice9382ab32018-12-17 15:12:07 +000079
Gian Marco Iodicee16c8902019-06-14 16:11:10 +010080/** Beta values to test - Precommit */
Gian Marco Iodiced820db62019-08-05 14:23:23 +010081const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} );
Gian Marco Iodicee16c8902019-06-14 16:11:10 +010082
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000083/** M values to test */
84const auto m_values = framework::dataset::make("M", 37);
85
Gian Marco Iodice9382ab32018-12-17 15:12:07 +000086/** M_W values to test */
87const auto m_w_values = framework::dataset::make("M_W", 5);
88
89/** M_H values to test */
90const auto m_h_values = framework::dataset::make("M_H", 7);
91
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000092/** N values to test */
93const auto n_values = framework::dataset::make("N", 51);
94
95/** K values to test */
Gian Marco Iodicebacfec52019-01-11 11:30:55 +000096const auto k_values = framework::dataset::make("K", 23);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000097
98/** Batch size values to test */
99const auto b_values = framework::dataset::make("batch_size", 1, 3);
100
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +0100101/** Activation values to test */
102const auto act_values = framework::dataset::make("Activation",
103{
104 ActivationLayerInfo(),
105 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
106});
107
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000108/** M0 values to test - Precommit */
Gian Marco Iodice05639f62019-09-24 12:05:06 +0100109const auto m0_values_precommit = framework::dataset::make("M0", { 4 });
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000110
111/** N0 values to test - Precommit */
Gian Marco Iodiced820db62019-08-05 14:23:23 +0100112const auto n0_values_precommit = framework::dataset::make("N0", { 4 });
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000113
114/** K0 values to test - Precommit */
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000115const auto k0_values_precommit = framework::dataset::make("K0", { 4 });
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000116
117/** V0 values to test - Precommit */
118const auto v0_values_precommit = framework::dataset::make("V0", 1, 3);
119
120/** H0 values to test - Precommit */
121const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);
122
123/** M0 values to test - Nightly */
Michele Di Giorgio2568c6b2019-09-17 12:08:46 +0100124const auto m0_values_nightly = framework::dataset::make("M0", { 2, 3, 4, 8 });
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000125
126/** N0 values to test - Nightly */
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000127const auto n0_values_nightly = framework::dataset::make("N0", { 2, 3, 4, 8 });
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000128
129/** K0 values to test - Nightly */
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000130const auto k0_values_nightly = framework::dataset::make("K0", { 2, 3, 4, 8 });
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000131
132/** V0 values to test - Nightly */
133const auto v0_values_nightly = framework::dataset::make("V0", 1, 4);
134
135/** H0 values to test - Nightly */
136const auto h0_values_nightly = framework::dataset::make("H0", 1, 4);
137
138/** Interleave values to test with LHS matrix */
139const auto i_values_lhs = framework::dataset::make("interleave_lhs", { true, false });
140
141/** Interleave values to test with RHS matrix */
142const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false });
143
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100144/** Broadcast bias from vector to matrix */
Gian Marco Iodiced820db62019-08-05 14:23:23 +0100145const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100146
Giorgio Arenaae99b6e2019-08-01 14:22:12 +0100147/** LHS transposed values */
148const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } );
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000149} // namespace
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000150
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000151TEST_SUITE(CL)
152TEST_SUITE(GEMMMatrixMultiplyReshaped)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000153TEST_SUITE(Float)
154TEST_SUITE(FP32)
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000155
156FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
Giorgio Arenaae99b6e2019-08-01 14:22:12 +0100157 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000158 m_values,
159 n_values),
160 k_values),
161 b_values),
162 m0_values_precommit),
163 n0_values_precommit),
164 k0_values_precommit),
165 v0_values_precommit),
166 h0_values_precommit),
167 i_values_lhs),
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000168 i_values_rhs),
169 framework::dataset::make("DataType", DataType::F32)),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100170 a_values),
171 beta_values),
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +0100172 broadcast_bias_values),
Giorgio Arenaae99b6e2019-08-01 14:22:12 +0100173 lhs_transpose_values),
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +0100174 act_values))
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000175{
176 // Validate output
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000177 validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000178}
179
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000180FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY,
Giorgio Arenaae99b6e2019-08-01 14:22:12 +0100181 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000182 m_values,
183 n_values),
184 k_values),
185 b_values),
186 m0_values_nightly),
187 n0_values_nightly),
188 k0_values_nightly),
189 v0_values_nightly),
190 h0_values_nightly),
191 i_values_lhs),
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000192 i_values_rhs),
193 framework::dataset::make("DataType", DataType::F32)),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100194 a_values),
195 beta_values),
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +0100196 broadcast_bias_values),
Giorgio Arenaae99b6e2019-08-01 14:22:12 +0100197 lhs_transpose_values),
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +0100198 act_values))
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000199{
200 // Validate output
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000201 validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000202}
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000203
204FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
Giorgio Arenaae99b6e2019-08-01 14:22:12 +0100205 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000206 m_w_values,
207 m_h_values),
208 n_values),
209 k_values),
210 b_values),
211 m0_values_precommit),
212 n0_values_precommit),
213 k0_values_precommit),
214 v0_values_precommit),
215 h0_values_precommit),
216 i_values_lhs),
217 i_values_rhs),
218 framework::dataset::make("DataType", DataType::F32)),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100219 a_values),
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +0100220 beta_values),
Giorgio Arenaae99b6e2019-08-01 14:22:12 +0100221 lhs_transpose_values),
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +0100222 act_values))
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000223{
224 // Validate output
225 validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
226}
227
228FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY,
Giorgio Arenaae99b6e2019-08-01 14:22:12 +0100229 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000230 m_w_values,
231 m_h_values),
232 n_values),
233 k_values),
234 b_values),
235 m0_values_nightly),
236 n0_values_nightly),
237 k0_values_nightly),
238 v0_values_nightly),
239 h0_values_nightly),
240 i_values_lhs),
241 i_values_rhs),
242 framework::dataset::make("DataType", DataType::F32)),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100243 a_values),
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +0100244 beta_values),
Giorgio Arenaae99b6e2019-08-01 14:22:12 +0100245 lhs_transpose_values),
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +0100246 act_values))
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000247{
248 // Validate output
249 validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
250}
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000251TEST_SUITE_END() // FP32
Gian Marco Iodice05639f62019-09-24 12:05:06 +0100252
253TEST_SUITE(FP16)
254
255FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
256 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
257 m_values,
258 n_values),
259 k_values),
260 b_values),
261 m0_values_precommit),
262 n0_values_precommit),
263 k0_values_precommit),
264 v0_values_precommit),
265 h0_values_precommit),
266 i_values_lhs),
267 i_values_rhs),
268 framework::dataset::make("DataType", DataType::F16)),
269 a_values),
270 beta_values),
271 broadcast_bias_values),
272 lhs_transpose_values),
273 act_values))
274{
275 // Validate output
276 validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
277}
278
279FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
280 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
281 m_values,
282 n_values),
283 k_values),
284 b_values),
285 m0_values_nightly),
286 n0_values_nightly),
287 k0_values_nightly),
288 v0_values_nightly),
289 h0_values_nightly),
290 i_values_lhs),
291 i_values_rhs),
292 framework::dataset::make("DataType", DataType::F16)),
293 a_values),
294 beta_values),
295 broadcast_bias_values),
296 lhs_transpose_values),
297 act_values))
298{
299 // Validate output
300 validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
301}
302
303FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
304 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
305 m_w_values,
306 m_h_values),
307 n_values),
308 k_values),
309 b_values),
310 m0_values_precommit),
311 n0_values_precommit),
312 k0_values_precommit),
313 v0_values_precommit),
314 h0_values_precommit),
315 i_values_lhs),
316 i_values_rhs),
317 framework::dataset::make("DataType", DataType::F16)),
318 a_values),
319 beta_values),
320 lhs_transpose_values),
321 act_values))
322{
323 // Validate output
324 validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
325}
326
327FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
328 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
329 m_w_values,
330 m_h_values),
331 n_values),
332 k_values),
333 b_values),
334 m0_values_nightly),
335 n0_values_nightly),
336 k0_values_nightly),
337 v0_values_nightly),
338 h0_values_nightly),
339 i_values_lhs),
340 i_values_rhs),
341 framework::dataset::make("DataType", DataType::F16)),
342 a_values),
343 beta_values),
344 lhs_transpose_values),
345 act_values))
346{
347 // Validate output
348 validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
349}
350TEST_SUITE_END() // FP16
Gian Marco Iodice9382ab32018-12-17 15:12:07 +0000351TEST_SUITE_END() // Float
Gian Marco Iodiced1f54762019-07-19 09:54:47 +0100352TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000353TEST_SUITE_END() // CL
354} // namespace validation
355} // namespace test
Michele Di Giorgio2568c6b2019-09-17 12:08:46 +0100356} // namespace arm_compute