/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
#include "arm_compute/runtime/Tensor.h"
#include "arm_compute/runtime/TensorAllocator.h"
#include "tests/NEON/Accessor.h"
#include "tests/NEON/Helper.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/LargeGEMMDataset.h"
#include "tests/datasets/SmallGEMMDataset.h"
#include "tests/datasets/TinyGEMMDataset.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
#include "tests/validation/fixtures/GEMMFixture.h"
#include "tests/validation/fixtures/GEMMInterleave4x4Fixture.h"
#include "tests/validation/fixtures/GEMMTranspose1xWFixture.h"

44namespace arm_compute
45{
46namespace test
47{
48namespace validation
49{
50namespace
51{
52constexpr AbsoluteTolerance<float> tolerance_f(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
53constexpr AbsoluteTolerance<float> tolerance_q(1.0f); /**< Tolerance value for comparing reference's output against implementation's output for fixed point data types */
54
55/** CNN data types */
56const auto CNNDataTypes = framework::dataset::make("DataType",
57{
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000058#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010059 DataType::F16,
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000060#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010061 DataType::F32,
62 DataType::QS8,
63 DataType::QS16,
64});
Pablo Tello2fdc4092017-11-23 15:50:08 +000065
66const auto data_interleave = framework::dataset::make("M", 8, 12) * framework::dataset::make("N", 8, 12);
Pablo Tello088cc7f2017-12-07 15:20:55 +000067const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
68
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +010069} // namespace
70
71TEST_SUITE(NEON)
72TEST_SUITE(GEMM)
73
Pablo Tello088cc7f2017-12-07 15:20:55 +000074TEST_SUITE(TRANSPOSE_1XW)
75using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMTranspose1xWKernel, 4>;
76using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, NEGEMMTranspose1xW, float>;
77TEST_SUITE(FP32)
78FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::F32))
79{
80 // Validate output
81 validate(Accessor(_target), _reference);
82}
83TEST_SUITE_END() // FP32
84
85TEST_SUITE(Quantized)
86TEST_SUITE(QS8)
87using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMTranspose1xWKernel, 16>;
88using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixedPointFixture<Tensor, Accessor, NEGEMMTranspose1xW, int8_t>;
89FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose *
90 framework::dataset::make("DataType", DataType::QS8)
91 * framework::dataset::make("FractionalBits", 1, 7))
92{
93 // Validate output
94 validate(Accessor(_target), _reference);
95}
96TEST_SUITE_END()
97
98TEST_SUITE(QS16)
99using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMTranspose1xWKernel, 8>;
100using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixedPointFixture<Tensor, Accessor, NEGEMMTranspose1xW, int16_t>;
101FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose *
102 framework::dataset::make("DataType", DataType::QS16)
103 * framework::dataset::make("FractionalBits", 1, 14))
104{
105 // Validate output
106 validate(Accessor(_target), _reference);
107}
108TEST_SUITE_END()
109
110TEST_SUITE_END()
111
112TEST_SUITE_END() // TRANSPOSE_1XW
113
Pablo Tello2fdc4092017-11-23 15:50:08 +0000114TEST_SUITE(INTERLEAVE_4X4)
115using NEGEMMInterleave4x4 = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMInterleave4x4Kernel, 4>;
116
117TEST_SUITE(FP32)
118using NEGEMMInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, NEGEMMInterleave4x4, float>;
119FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::F32))
120{
121 // Validate output
122 validate(Accessor(_target), _reference);
123}
124TEST_SUITE_END() // FP32
125
126TEST_SUITE(Quantized)
127TEST_SUITE(QS8)
128using NEGEMMInterleave4x4Fixture = GEMMInterleave4x4ValidationFixedPointFixture<Tensor, Accessor, NEGEMMInterleave4x4, int8_t>;
129FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave *
130 framework::dataset::make("DataType", DataType::QS8)
131 * framework::dataset::make("FractionalBits", 1, 7))
132{
133 // Validate output
134 validate(Accessor(_target), _reference);
135}
136TEST_SUITE_END()
137
138TEST_SUITE(QS16)
139using NEGEMMInterleave4x4Fixture = GEMMInterleave4x4ValidationFixedPointFixture<Tensor, Accessor, NEGEMMInterleave4x4, int16_t>;
140FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave *
141 framework::dataset::make("DataType", DataType::QS16)
142 * framework::dataset::make("FractionalBits", 1, 14))
143{
144 // Validate output
145 validate(Accessor(_target), _reference);
146}
147TEST_SUITE_END()
148
149TEST_SUITE_END()
150
151TEST_SUITE_END() // INTERLEAVE_4X4
152
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100153DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallGEMMDataset(), datasets::LargeGEMMDataset()), CNNDataTypes),
154 shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type)
155{
156 // Set fixed point position data type allowed
157 const int fixed_point_position = is_data_type_fixed_point(data_type) ? 3 : 0;
158
159 // Create tensors
160 Tensor a = create_tensor<Tensor>(shape_a, data_type, 1, fixed_point_position);
161 Tensor b = create_tensor<Tensor>(shape_b, data_type, 1, fixed_point_position);
162 Tensor c = create_tensor<Tensor>(shape_c, data_type, 1, fixed_point_position);
163 Tensor dst = create_tensor<Tensor>(output_shape, data_type, 1, fixed_point_position);
164
165 ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
166 ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);
167 ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS);
168 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
169
170 // Create and configure function
171 NEGEMM gemm;
172 gemm.configure(&a, &b, &c, &dst, alpha, beta);
173
174 //TODO(COMPMID-415): Validate valid region
175}
176
177template <typename T>
178using NEGEMMFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T>;
179
180TEST_SUITE(Float)
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000181#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100182TEST_SUITE(FP16)
Georgios Pinitas583137c2017-08-31 18:12:42 +0100183FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F16)))
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100184{
185 // Validate output
186 validate(Accessor(_target), _reference, tolerance_f);
187}
Georgios Pinitas583137c2017-08-31 18:12:42 +0100188FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType",
189 DataType::F16)))
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100190{
191 // Validate output
192 validate(Accessor(_target), _reference, tolerance_f);
193}
194TEST_SUITE_END()
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000195#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100196
197TEST_SUITE(FP32)
198FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
199{
200 // Validate output
201 validate(Accessor(_target), _reference, tolerance_f);
202}
203FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
204{
205 // Validate output
206 validate(Accessor(_target), _reference, tolerance_f);
207}
208TEST_SUITE_END()
209TEST_SUITE_END()
210
211template <typename T>
212using NEGEMMFixedPointFixture = GEMMValidationFixedPointFixture<Tensor, Accessor, NEGEMM, T>;
213
214TEST_SUITE(Quantized)
215TEST_SUITE(QS8)
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000216FIXTURE_DATA_TEST_CASE(RunTiny, NEGEMMFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::TinyGEMMDataset(),
217 framework::dataset::make("DataType",
218 DataType::QS8)),
219 framework::dataset::make("FractionalBits", 1, 7)))
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100220{
221 // Validate output
222 validate(Accessor(_target), _reference, tolerance_q);
223}
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000224FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SmallGEMMDataset(),
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100225 framework::dataset::make("DataType",
226 DataType::QS8)),
227 framework::dataset::make("FractionalBits", 1, 7)))
228{
229 // Validate output
230 validate(Accessor(_target), _reference, tolerance_q);
231}
232TEST_SUITE_END()
233
234TEST_SUITE(QS16)
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000235FIXTURE_DATA_TEST_CASE(RunTiny, NEGEMMFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::TinyGEMMDataset(),
236 framework::dataset::make("DataType",
237 DataType::QS16)),
238 framework::dataset::make("FractionalBits", 1, 14)))
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100239{
240 // Validate output
241 validate(Accessor(_target), _reference, tolerance_q);
242}
Anthony Barbier1c0d0ff2018-01-31 13:05:09 +0000243FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::SmallGEMMDataset(),
Moritz Pflanzer4dfc2352017-08-02 14:51:36 +0100244 framework::dataset::make("DataType",
245 DataType::QS16)),
246 framework::dataset::make("FractionalBits", 1, 14)))
247{
248 // Validate output
249 validate(Accessor(_target), _reference, tolerance_q);
250}
251TEST_SUITE_END()
252TEST_SUITE_END()
253
254TEST_SUITE_END()
255TEST_SUITE_END()
256} // namespace validation
257} // namespace test
258} // namespace arm_compute