/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "tests/CL/CLAccessor.h"
#include "tests/CL/Helper.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
#include "tests/validation/fixtures/GEMMFixture.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
using namespace arm_compute::misc::shape_calculator;

// Create function for CLGEMMReshapeLHSMatrixKernel
using CLGEMMReshapeLHSMatrix = CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>;

// Create function for CLGEMMReshapeRHSMatrixKernel
using CLGEMMReshapeRHSMatrix = CLSynthetizeFunction<CLGEMMReshapeRHSMatrixKernel>;

// Create function for CLGEMMMatrixMultiplyReshapedKernel
using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>;

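// Note: CLSynthetizeFunction (tests/CL/Helper.h) wraps a single kernel in a simple function
// interface so the fixtures below can configure and run each kernel on its own.
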
// Fixture for CLGEMMMatrixMultiplyReshaped
template <typename T>
using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;

// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision
template <typename T>
using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture =
    GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;

// Fixture for CLGEMMMatrixMultiplyReshaped3D
template <typename T>
using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;

// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision
template <typename T>
using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture =
    GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;

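// Each fixture reshapes the LHS matrix (blocks of M0 x K0, interleaved by V0) and the RHS matrix
// (blocks of K0 x N0, interleaved by H0 and optionally transposed) before invoking the reshaped
// matrix multiply kernel, and compares the result against a plain CPU reference GEMM.
// The trailing 'true' template argument selects the mixed-precision path (FP16 data with wider
// accumulation). This summary reflects the fixture/kernel documentation; see GEMMFixture.h for details.
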
namespace
{
// *INDENT-OFF*
// clang-format off
RelativeTolerance<float> rel_tolerance_f32(0.001f);
constexpr float          abs_tolerance_f32(0.0001f);

RelativeTolerance<float> rel_tolerance_f16_mixed_precision(0.001f);
constexpr float          abs_tolerance_f16_mixed_precision(0.01f);

RelativeTolerance<float> rel_tolerance_f16(0.001f);
constexpr float          abs_tolerance_f16(0.01f);
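// Note: the FP16 paths are compared with a larger absolute tolerance than FP32 (0.01 vs 0.0001),
// consistent with half precision carrying roughly three decimal digits; this is a general
// observation about the tolerances above, not a statement about the kernels themselves.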

/** M values to test */
const auto m_values = framework::dataset::make("M", 17);

/** M_W values to test */
const auto m_w_values = framework::dataset::make("M_W", 5);

/** M_H values to test */
const auto m_h_values = framework::dataset::make("M_H", 7);

/** N values to test */
const auto n_values = framework::dataset::make("N", 21);

/** K values to test */
const auto k_values = framework::dataset::make("K", 13);

/** Batch size values to test */
const auto b_values = framework::dataset::make("batch_size", 2, 3);

/** Activation values to test */
const auto act_values = framework::dataset::make("Activation",
{
    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
});
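// LU_BOUNDED_RELU with a = 8 and b = 2 clamps the GEMM output to the range [2, 8]
// (f(x) = min(a, max(b, x))), so the fused activation path is exercised as well.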

/** Alpha values to test - Precommit */
const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} );

/** Beta values to test - Precommit */
const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} );

/** M0 values to test - Precommit */
const auto m0_values_precommit = framework::dataset::make("M0", { 4 });

/** N0 values to test - Precommit */
const auto n0_values_precommit = framework::dataset::make("N0", { 4 });

/** K0 values to test - Precommit */
const auto k0_values_precommit = framework::dataset::make("K0", { 4 });

/** V0 values to test - Precommit */
const auto v0_values_precommit = framework::dataset::make("V0", 1, 3);

/** H0 values to test - Precommit */
const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);

/** Alpha values to test - Nightly */
const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} );

/** Beta values to test - Nightly */
const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} );

/** M0 values to test - Nightly */
const auto m0_values_nightly = framework::dataset::make("M0", { 8 });

/** N0 values to test - Nightly */
const auto n0_values_nightly = framework::dataset::make("N0", { 8 });

/** K0 values to test - Nightly */
const auto k0_values_nightly = framework::dataset::make("K0", { 4 });

/** N0 values to test with export to OpenCL image object - Nightly */
const auto n0_export_to_cl_image_values_nightly = framework::dataset::make("N0", { 4, 8, 16 });

/** K0 values to test with export to OpenCL image object - Nightly */
const auto k0_export_to_cl_image_values_nightly = framework::dataset::make("K0", { 4, 8, 16 });

/** V0 values to test - Nightly */
const auto v0_values_nightly = framework::dataset::make("V0", 1, 3);

/** H0 values to test - Nightly */
const auto h0_values_nightly = framework::dataset::make("H0", 1, 3);

/** Interleave values to test with LHS matrix */
const auto i_values_lhs = framework::dataset::make("interleave_lhs", { true, false });

/** Interleave values to test with RHS matrix */
const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false });

/** Broadcast bias from vector to matrix */
const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );

/** LHS transposed values */
const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } );

/** Zero padding test */
bool validate_zero_padding(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value,
                           unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int h0_value,
                           bool i_value_rhs, bool t_value_rhs, bool export_to_cl_image, bool broadcast_bias, unsigned int depth_output_gemm3d, const ActivationLayerInfo &act_info,
                           DataType dt_input0, DataType dt_input1, DataType dt_input2, DataType dt_output, float alpha, float beta)
{
    const unsigned int M = m_value;
    const unsigned int N = n_value;
    const unsigned int K = k_value;

    GEMMLHSMatrixInfo lhs_info;
    lhs_info.m0 = m0_value;
    lhs_info.k0 = k0_value;

    GEMMRHSMatrixInfo rhs_info;
    rhs_info.n0                 = n0_value;
    rhs_info.k0                 = k0_value;
    rhs_info.h0                 = h0_value;
    rhs_info.interleave         = i_value_rhs;
    rhs_info.transpose          = t_value_rhs;
    rhs_info.export_to_cl_image = export_to_cl_image;

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = M;
    kernel_info.n                       = N;
    kernel_info.k                       = K;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = act_info;

    const TensorShape lhs_shape(K, M, b_value);
    const TensorShape rhs_shape(N, K, b_value);
    const TensorShape lhs_shape_reshaped = compute_lhs_reshaped_shape(TensorInfo(lhs_shape, 1, dt_input0), lhs_info);
    const TensorShape rhs_shape_reshaped = compute_rhs_reshaped_shape(TensorInfo(rhs_shape, 1, dt_input1), rhs_info);

    const TensorShape dst_shape = compute_mm_shape(TensorInfo(lhs_shape_reshaped, 1, dt_input0),
                                                   TensorInfo(rhs_shape_reshaped, 1, dt_input1),
                                                   kernel_info);

    const TensorShape bias_shape(N,
                                 M, // The correct value would be broadcast_bias ? 1 : M; it is wrong here on purpose, just for this validation test
                                 broadcast_bias ? 1 : b_value);

    // Create tensors
    CLTensor lhs_reshaped = create_tensor<CLTensor>(lhs_shape_reshaped, dt_input0);
    CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, dt_input1);
    CLTensor bias         = create_tensor<CLTensor>(bias_shape, dt_input2);
    CLTensor dst          = create_tensor<CLTensor>(dst_shape, dt_output);

    ARM_COMPUTE_EXPECT(lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
    ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
    ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
    ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);

    // Validate zero-padding
    CLGEMMMatrixMultiplyReshaped gemm;

    gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Padding can be added along the X/Y dimensions of the RHS matrix and the bias; only the
    // destination and the reshaped LHS matrix are required to stay unpadded.
    return dst.info()->padding().empty() && lhs_reshaped.info()->padding().empty();
}
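
// Example call (hypothetical values): validate_zero_padding(24, 48, 23, 1, 4, 4, 4, 1, false, false, false, false, 0,
// ActivationLayerInfo(), DataType::F32, DataType::F32, DataType::F32, DataType::F32, 1.0f, 1.0f) configures the
// kernel for a 24x48 output with 4x4 blocks and returns true only if neither dst nor lhs_reshaped needed padding.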
} // namespace

TEST_SUITE(CL)
TEST_SUITE(GEMMMatrixMultiplyReshaped)

/** Validate zero padding tests
 *
 * A series of validation tests to check the zero padding requirement
 *
 * Checks performed in order:
 * - No partial blocks in both x and y dimensions
 * - Partial blocks in x dimension
 * - Partial blocks in y dimension
 * - Partial blocks in both x and y dimensions
 * - Special case: partial_n0 == 9 (vstore1 should be invoked instead of vstore_partial_1)
 */
DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(zip(zip(
framework::dataset::make("M",  { 24, 64, 101, 1, 103 }),
framework::dataset::make("N",  { 48, 29, 16, 121, 41 })),
framework::dataset::make("M0", { 4, 8, 4, 2, 4 })),
framework::dataset::make("N0", { 4, 4, 16, 2, 16 })),
m_value, n_value, m0_value, n0_value)
{
    constexpr DataType dt = DataType::F32;

    bool status = validate_zero_padding(m_value, n_value, 23, 1, m0_value, n0_value, 4, 1, false, false, false, 0, 0, ActivationLayerInfo(), dt, dt, dt, dt, 1.0f, 1.0f);
    ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
}

// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
    framework::dataset::make("Input0Info", {
        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),     // OK
        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),     // OK
        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8), // Data type not supported
        TensorInfo(TensorShape(10U, 5U, 2U), 1, DataType::F32),     // Incorrect dimension bias
        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),     // Mismatching shapes
        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),     // OK, do not broadcast bias
        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),     // OK, wider accumulation
        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),     // OK, RHS 4,4,2
    }),
    framework::dataset::make("Input1Info", {
        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8),
        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(128U, 3U, 2U), 1, DataType::F16),
    })),
    framework::dataset::make("Input2Info", {
        TensorInfo(TensorShape(21U), 1, DataType::F32),
        TensorInfo(TensorShape(21U), 1, DataType::F16),
        TensorInfo(TensorShape(21U), 1, DataType::QASYMM8),
        TensorInfo(TensorShape(21U), 1, DataType::F32),
        TensorInfo(TensorShape(21U), 1, DataType::F32),
        TensorInfo(TensorShape(21U, 17U), 1, DataType::F16),
        TensorInfo(TensorShape(21U, 17U), 1, DataType::F16),
        TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
    })),
    framework::dataset::make("OutputInfo", {
        TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::QASYMM8),
        TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(21U, 17U, 2U), 1, DataType::F16),
    })),
    framework::dataset::make("LHSMInfo", {
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
        GEMMLHSMatrixInfo(4, 2, 4, false, false),
        GEMMLHSMatrixInfo(4, 2, 4, false, false),
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
    })),
    framework::dataset::make("RHSMInfo", {
        GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
        GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
        GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
        GEMMRHSMatrixInfo(2, 2, 1, true, false, false),
        GEMMRHSMatrixInfo(2, 2, 1, true, false, false),
        GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
        GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
        GEMMRHSMatrixInfo(4, 4, 2, true, false, false),
    })),
    framework::dataset::make("GEMMInfo", {
        GEMMKernelInfo(17 /**< M Number of LHS rows */,
                       21 /**< N Number of RHS columns */,
                       13 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(4, 4, 1, false, true),
                       GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(17 /**< M Number of LHS rows */,
                       21 /**< N Number of RHS columns */,
                       13 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(4, 4, 1, false, true),
                       GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(),
        GEMMKernelInfo(),
        GEMMKernelInfo(),
        GEMMKernelInfo(17 /**< M Number of LHS rows */,
                       21 /**< N Number of RHS columns */,
                       13 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       false /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(4, 4, 1, false, true),
                       GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(17 /**< M Number of LHS rows */,
                       21 /**< N Number of RHS columns */,
                       13 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       false /**< Flag used to broadcast the bias addition */,
                       true  /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(4, 4, 1, false, true),
                       GEMMRHSMatrixInfo(4, 4, 1, true, true, false),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(17 /**< M Number of LHS rows */,
                       21 /**< N Number of RHS columns */,
                       13 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       false /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(4, 4, 1, false, true),
                       GEMMRHSMatrixInfo(4, 4, 2, true, false, false),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
    })),
    framework::dataset::make("Expected", { true, true, false, false, false, true, true, true })),
    input0_info, input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
{
    ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
                                                                         &input1_info.clone()->set_is_resizable(true),
                                                                         &input2_info.clone()->set_is_resizable(true),
                                                                         &output_info.clone()->set_is_resizable(true), 1.f, 1.f,
                                                                         lhs_info,
                                                                         rhs_info,
                                                                         gemm_info)) == expected, framework::LogLevel::ERRORS);
}
TEST_SUITE(Float)
TEST_SUITE(FP32)

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F32)),
        a_values_precommit),
        beta_values_precommit),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::DISABLED,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_values_nightly),
        k0_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F32)),
        a_values_nightly),
        beta_values_nightly),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F32)),
        a_values_precommit),
        beta_values_precommit),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::DISABLED,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_values_nightly),
        k0_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F32)),
        a_values_nightly),
        beta_values_nightly),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
TEST_SUITE(ExportToCLImage)
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
    framework::dataset::make("Input0Info", {
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // OK or incorrect if cl_khr_image2d_from_buffer not supported
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // OK or incorrect if cl_khr_image2d_from_buffer not supported
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // OK or incorrect if cl_khr_image2d_from_buffer not supported
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // Incorrect k0
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // Incorrect n0
    }),
    framework::dataset::make("Input1Info", {
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F32),
    })),
    framework::dataset::make("Input2Info", {
        TensorInfo(TensorShape(64U), 1, DataType::F32),
        TensorInfo(TensorShape(64U), 1, DataType::F32),
        TensorInfo(TensorShape(64U), 1, DataType::F32),
        TensorInfo(TensorShape(64U), 1, DataType::F32),
        TensorInfo(TensorShape(64U), 1, DataType::F32),
    })),
    framework::dataset::make("OutputInfo", {
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
    })),
    framework::dataset::make("LHSMInfo", {
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
        GEMMLHSMatrixInfo(4, 8, 1, false, true),
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
        GEMMLHSMatrixInfo(4, 2, 1, false, false),
        GEMMLHSMatrixInfo(4, 4, 1, false, false),
    })),
    framework::dataset::make("RHSMInfo", {
        GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
        GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
        GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
        GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
        GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
    })),
    framework::dataset::make("GEMMInfo", {
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
    })),
    framework::dataset::make("Expected", { true, true, true, false, false })),
    input0_info, input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
{
    ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
                                                                         &input1_info.clone()->set_is_resizable(true),
                                                                         &input2_info.clone()->set_is_resizable(true),
                                                                         &output_info.clone()->set_is_resizable(true), 1.f, 1.f,
                                                                         lhs_info,
                                                                         rhs_info,
                                                                         gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
}
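
// Note: the expected result above is ANDed with image2d_from_buffer_supported(), so on devices without
// the cl_khr_image2d_from_buffer extension validate() is expected to reject every configuration.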

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", true)),
        framework::dataset::make("DataType", DataType::F32)),
        a_values_precommit),
        beta_values_precommit),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
    if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_export_to_cl_image_values_nightly),
        k0_export_to_cl_image_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", true)),
        framework::dataset::make("DataType", DataType::F32)),
        a_values_nightly),
        beta_values_nightly),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
    if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", true)),
        framework::dataset::make("DataType", DataType::F32)),
        a_values_precommit),
        beta_values_precommit),
        lhs_transpose_values),
        act_values))
{
    // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
    if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_export_to_cl_image_values_nightly),
        k0_export_to_cl_image_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", true)),
        framework::dataset::make("DataType", DataType::F32)),
        a_values_nightly),
        beta_values_nightly),
        lhs_transpose_values),
        act_values))
{
    // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
    if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}
TEST_SUITE_END() // ExportToCLImage
TEST_SUITE_END() // FP32

TEST_SUITE(FP16)

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_precommit),
        beta_values_precommit),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::DISABLED,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_values_nightly),
        k0_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_nightly),
        beta_values_nightly),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_precommit),
        beta_values_precommit),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::DISABLED,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_values_nightly),
        k0_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_nightly),
        beta_values_nightly),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}

TEST_SUITE(ExportToCLImage)
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
    framework::dataset::make("Input0Info", {
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // Incorrect k0
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // Incorrect n0
    }),
    framework::dataset::make("Input1Info", {
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F16),
    })),
    framework::dataset::make("Input2Info", {
        TensorInfo(TensorShape(64U), 1, DataType::F16),
        TensorInfo(TensorShape(64U), 1, DataType::F16),
        TensorInfo(TensorShape(64U), 1, DataType::F16),
        TensorInfo(TensorShape(64U), 1, DataType::F16),
        TensorInfo(TensorShape(64U), 1, DataType::F16),
    })),
    framework::dataset::make("OutputInfo", {
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
    })),
    framework::dataset::make("LHSMInfo", {
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
        GEMMLHSMatrixInfo(4, 8, 1, false, true),
        GEMMLHSMatrixInfo(4, 4, 1, false, true),
        GEMMLHSMatrixInfo(4, 2, 1, false, false),
        GEMMLHSMatrixInfo(4, 4, 1, false, false),
    })),
    framework::dataset::make("RHSMInfo", {
        GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
        GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
        GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
        GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
        GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
    })),
    framework::dataset::make("GEMMInfo", {
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
        GEMMKernelInfo(64 /**< M Number of LHS rows */,
                       64 /**< N Number of RHS columns */,
                       64 /**< K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
                       false /**< reinterpret the input as 3D */,
                       true  /**< Flag used to broadcast the bias addition */,
                       false /**< wider accumm */,
                       ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                       1 /**< Multiplication factor for the width of the 1xW transposed block */,
                       1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
                       GEMMLHSMatrixInfo(),
                       GEMMRHSMatrixInfo(),
                       0 /**< Offset to be added to each element of the matrix A */,
                       0 /**< Offset to be added to each element of the matrix B */),
    })),
    framework::dataset::make("Expected", { true, true, true, false, false })),
    input0_info, input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
{
    ARM_COMPUTE_EXPECT(bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
                                                                         &input1_info.clone()->set_is_resizable(true),
                                                                         &input2_info.clone()->set_is_resizable(true),
                                                                         &output_info.clone()->set_is_resizable(true), 1.f, 1.f,
                                                                         lhs_info,
                                                                         rhs_info,
                                                                         gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
}

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", true)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_precommit),
        beta_values_precommit),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
    if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_export_to_cl_image_values_nightly),
        k0_export_to_cl_image_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", true)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_nightly),
        beta_values_nightly),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
    if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", true)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_precommit),
        beta_values_precommit),
        lhs_transpose_values),
        act_values))
{
    // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
    if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_export_to_cl_image_values_nightly),
        k0_export_to_cl_image_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", true)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_nightly),
        beta_values_nightly),
        lhs_transpose_values),
        act_values))
{
    // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
    if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}
TEST_SUITE_END() // ExportToCLImage
TEST_SUITE_END() // FP16

TEST_SUITE(MixedPrecision)

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_precommit),
        beta_values_precommit),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_values,
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_values_nightly),
        k0_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_nightly),
        beta_values_nightly),
        broadcast_bias_values),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_precommit),
        n0_values_precommit),
        k0_values_precommit),
        v0_values_precommit),
        h0_values_precommit),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_precommit),
        beta_values_precommit),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
    combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
        m_w_values,
        m_h_values),
        n_values),
        k_values),
        b_values),
        m0_values_nightly),
        n0_values_nightly),
        k0_values_nightly),
        v0_values_nightly),
        h0_values_nightly),
        i_values_lhs),
        i_values_rhs),
        framework::dataset::make("export_to_cl_image_rhs", false)),
        framework::dataset::make("DataType", DataType::F16)),
        a_values_nightly),
        beta_values_nightly),
        lhs_transpose_values),
        act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}
TEST_SUITE_END() // MixedPrecision
TEST_SUITE_END() // Float
TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute