blob: 217edf4438bf7f381c253a9259e0a1a7a87d063b [file] [log] [blame]
/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Pablo Tello1d1c0262017-12-08 16:02:38 +000024#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
Pablo Tello088cc7f2017-12-07 15:20:55 +000025#include "arm_compute/core/CL/kernels/CLGEMMTranspose1xWKernel.h"
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010026#include "arm_compute/core/Types.h"
27#include "arm_compute/runtime/CL/CLTensor.h"
28#include "arm_compute/runtime/CL/CLTensorAllocator.h"
29#include "arm_compute/runtime/CL/functions/CLGEMM.h"
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010030#include "tests/CL/CLAccessor.h"
Pablo Tello1d1c0262017-12-08 16:02:38 +000031#include "tests/CL/Helper.h"
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010032#include "tests/PaddingCalculator.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010033#include "tests/datasets/LargeGEMMDataset.h"
34#include "tests/datasets/SmallGEMMDataset.h"
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +000035#include "tests/datasets/TinyGEMMDataset.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010036#include "tests/framework/Asserts.h"
37#include "tests/framework/Macros.h"
38#include "tests/framework/datasets/Datasets.h"
39#include "tests/validation/Validation.h"
40#include "tests/validation/fixtures/GEMMFixture.h"
Pablo Tello1d1c0262017-12-08 16:02:38 +000041#include "tests/validation/fixtures/GEMMInterleave4x4Fixture.h"
Pablo Tello088cc7f2017-12-07 15:20:55 +000042#include "tests/validation/fixtures/GEMMTranspose1xWFixture.h"
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010043
44namespace arm_compute
45{
46namespace test
47{
48namespace validation
49{
50namespace
51{
Michele Di Giorgioff6c2602018-02-26 15:22:16 +000052RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
53constexpr float abs_tolerance_f32(
54 0.0001f); /**< Absolute tolerance value for comparing reference's output against implementation's output for floating point data types in case using relative tolerance fails because of small values */
steniu01f81652d2017-09-11 15:29:12 +010055RelativeTolerance<half_float::half> tolerance_f16(half(0.2)); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
56constexpr AbsoluteTolerance<float> tolerance_q(1.0f); /**< Tolerance value for comparing reference's output against implementation's output for fixed point data types */
Pablo Tello1d1c0262017-12-08 16:02:38 +000057constexpr float tolerance_num = 0.02f; /**< Tolerance number */
58const auto data_interleave = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010059
60/** CNN data types */
61const auto CNNDataTypes = framework::dataset::make("DataType",
62{
63 DataType::F16,
64 DataType::F32,
65 DataType::QS8,
66 DataType::QS16,
67});
68} // namespace
69
Pablo Tello088cc7f2017-12-07 15:20:55 +000070const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
71
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010072TEST_SUITE(CL)
73TEST_SUITE(GEMM)
74
Pablo Tello1d1c0262017-12-08 16:02:38 +000075TEST_SUITE(INTERLEAVE_4X4)
76using CLGEMMInterleave4x4 = CLSynthetizeFunctionWithZeroConstantBorder<CLGEMMInterleave4x4Kernel, 4>;
77
78TEST_SUITE(FP32)
79using CLGEMMInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<CLTensor, CLAccessor, CLGEMMInterleave4x4, float>;
80FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::F32))
81{
82 // Validate output
83 validate(CLAccessor(_target), _reference);
84}
85TEST_SUITE_END() // FP32
86
87TEST_SUITE(Quantized)
88TEST_SUITE(QS8)
89using CLGEMMInterleave4x4Fixture = GEMMInterleave4x4ValidationFixedPointFixture<CLTensor, CLAccessor, CLGEMMInterleave4x4, int8_t>;
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +000090FIXTURE_DATA_TEST_CASE(RunTiny, CLGEMMInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave *
Pablo Tello1d1c0262017-12-08 16:02:38 +000091 framework::dataset::make("DataType", DataType::QS8)
92 * framework::dataset::make("FractionalBits", 1, 7))
93{
94 // Validate output
95 validate(CLAccessor(_target), _reference);
96}
97TEST_SUITE_END()
98
99TEST_SUITE(QS16)
100using CLGEMMInterleave4x4Fixture = GEMMInterleave4x4ValidationFixedPointFixture<CLTensor, CLAccessor, CLGEMMInterleave4x4, int16_t>;
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000101FIXTURE_DATA_TEST_CASE(RunTiny, CLGEMMInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave *
Pablo Tello1d1c0262017-12-08 16:02:38 +0000102 framework::dataset::make("DataType", DataType::QS16)
103 * framework::dataset::make("FractionalBits", 1, 14))
104{
105 // Validate output
106 validate(CLAccessor(_target), _reference);
107}
108TEST_SUITE_END()
109
110TEST_SUITE_END()
111
112TEST_SUITE_END() // INTERLEAVE_4X4
113
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100114DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallGEMMDataset(), datasets::LargeGEMMDataset()), CNNDataTypes),
115 shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type)
116{
117 // Set fixed point position data type allowed
118 const int fixed_point_position = is_data_type_fixed_point(data_type) ? 3 : 0;
119
120 // Create tensors
121 CLTensor a = create_tensor<CLTensor>(shape_a, data_type, 1, fixed_point_position);
122 CLTensor b = create_tensor<CLTensor>(shape_b, data_type, 1, fixed_point_position);
123 CLTensor c = create_tensor<CLTensor>(shape_c, data_type, 1, fixed_point_position);
124 CLTensor dst = create_tensor<CLTensor>(output_shape, data_type, 1, fixed_point_position);
125
126 ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
127 ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);
128 ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS);
129 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
130
131 // Create and configure function
132 CLGEMM gemm;
133 gemm.configure(&a, &b, &c, &dst, alpha, beta);
134
135 //TODO(COMPMID-415): Validate valid region
136}
137
138template <typename T>
139using CLGEMMFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T>;
140
Pablo Tello088cc7f2017-12-07 15:20:55 +0000141TEST_SUITE(TRANSPOSE_1XW)
142using CLGEMMTranspose1xW = CLSynthetizeFunctionWithZeroConstantBorder<CLGEMMTranspose1xWKernel, 4>;
143using CLGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture<CLTensor, CLAccessor, CLGEMMTranspose1xW, float>;
144TEST_SUITE(FP32)
145FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::F32))
146{
147 // Validate output
148 validate(CLAccessor(_target), _reference);
149}
150TEST_SUITE_END() // FP32
151
152TEST_SUITE(Quantized)
153TEST_SUITE(QS8)
154using CLGEMMTranspose1xW = CLSynthetizeFunctionWithZeroConstantBorder<CLGEMMTranspose1xWKernel, 16>;
155using CLGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixedPointFixture<CLTensor, CLAccessor, CLGEMMTranspose1xW, int8_t>;
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000156FIXTURE_DATA_TEST_CASE(RunTiny, CLGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose *
Pablo Tello088cc7f2017-12-07 15:20:55 +0000157 framework::dataset::make("DataType", DataType::QS8)
158 * framework::dataset::make("FractionalBits", 1, 7))
159{
160 // Validate output
161 validate(CLAccessor(_target), _reference);
162}
163TEST_SUITE_END()
164
165TEST_SUITE(QS16)
166using CLGEMMTranspose1xW = CLSynthetizeFunctionWithZeroConstantBorder<CLGEMMTranspose1xWKernel, 8>;
167using CLGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixedPointFixture<CLTensor, CLAccessor, CLGEMMTranspose1xW, int16_t>;
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000168FIXTURE_DATA_TEST_CASE(RunTiny, CLGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose *
Pablo Tello088cc7f2017-12-07 15:20:55 +0000169 framework::dataset::make("DataType", DataType::QS16)
170 * framework::dataset::make("FractionalBits", 1, 14))
171{
172 // Validate output
173 validate(CLAccessor(_target), _reference);
174}
175TEST_SUITE_END()
176
177TEST_SUITE_END()
178
179TEST_SUITE_END() //TRANSPOSE_1XW
180
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100181TEST_SUITE(Float)
182TEST_SUITE(FP16)
Georgios Pinitas583137c2017-08-31 18:12:42 +0100183FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F16)))
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100184{
185 // Validate output
steniu01f81652d2017-09-11 15:29:12 +0100186 validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100187}
Georgios Pinitas583137c2017-08-31 18:12:42 +0100188FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType",
189 DataType::F16)))
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100190{
191 // Validate output
steniu01f81652d2017-09-11 15:29:12 +0100192 validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100193}
194TEST_SUITE_END()
195
196TEST_SUITE(FP32)
197FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
198{
199 // Validate output
200 validate(CLAccessor(_target), _reference, tolerance_f32);
201}
202FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
203{
204 // Validate output
Michele Di Giorgioff6c2602018-02-26 15:22:16 +0000205 validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100206}
207TEST_SUITE_END()
208TEST_SUITE_END()
209
210template <typename T>
211using CLGEMMFixedPointFixture = GEMMValidationFixedPointFixture<CLTensor, CLAccessor, CLGEMM, T>;
212
213TEST_SUITE(Quantized)
214TEST_SUITE(QS8)
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000215FIXTURE_DATA_TEST_CASE(RunTiny, CLGEMMFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::TinyGEMMDataset(),
216 framework::dataset::make("DataType",
217 DataType::QS8)),
218 framework::dataset::make("FractionalBits", 1, 7)))
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100219{
220 // Validate output
221 validate(CLAccessor(_target), _reference, tolerance_q);
222}
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000223FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SmallGEMMDataset(),
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100224 framework::dataset::make("DataType",
225 DataType::QS8)),
226 framework::dataset::make("FractionalBits", 1, 7)))
227{
228 // Validate output
229 validate(CLAccessor(_target), _reference, tolerance_q);
230}
231TEST_SUITE_END()
232
233TEST_SUITE(QS16)
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000234FIXTURE_DATA_TEST_CASE(RunTiny, CLGEMMFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::TinyGEMMDataset(),
235 framework::dataset::make("DataType",
236 DataType::QS16)),
237 framework::dataset::make("FractionalBits", 1, 14)))
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100238{
239 // Validate output
240 validate(CLAccessor(_target), _reference, tolerance_q);
241}
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000242FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SmallGEMMDataset(),
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100243 framework::dataset::make("DataType",
244 DataType::QS16)),
245 framework::dataset::make("FractionalBits", 1, 14)))
246{
247 // Validate output
248 validate(CLAccessor(_target), _reference, tolerance_q);
249}
250TEST_SUITE_END()
251TEST_SUITE_END()
252
253TEST_SUITE_END()
254TEST_SUITE_END()
255} // namespace validation
256} // namespace test
257} // namespace arm_compute