blob: 0b608902a32c534cf308c0c2760e3dc4aab926b3 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "Globals.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025#include "NEON/NEAccessor.h"
26#include "TensorLibrary.h"
27#include "TypePrinter.h"
28#include "Utils.h"
29#include "dataset/GEMMDataset.h"
30#include "validation/Datasets.h"
31#include "validation/Reference.h"
32#include "validation/Validation.h"
33
34#include "arm_compute/core/Helpers.h"
35#include "arm_compute/core/Types.h"
36#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
37#include "arm_compute/runtime/Tensor.h"
38#include "arm_compute/runtime/TensorAllocator.h"
39
40#include "boost_wrapper.h"
41
42#include <random>
43#include <string>
44
45using namespace arm_compute;
46using namespace arm_compute::test;
47using namespace arm_compute::test::neon;
48using namespace arm_compute::test::validation;
49
50namespace
51{
52const float tolerance_f32 = 1e-03f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +010053const float tolerance_q = 1.0f; /**< Tolerance value for comparing reference's output against implementation's output for fixed point data types */
Anthony Barbier6ff3b192017-09-04 18:44:23 +010054
55Tensor compute_gemm(const TensorShape &src_shape1, const TensorShape &src_shape2, const TensorShape &src_shape3,
56 const TensorShape &out_shape, float alpha, float beta, DataType dt, int fixed_point_position = 0)
57{
58 // Create tensors
Moritz Pflanzer94450f12017-06-30 12:48:43 +010059 Tensor src1 = create_tensor<Tensor>(src_shape1, dt, 1, fixed_point_position);
60 Tensor src2 = create_tensor<Tensor>(src_shape2, dt, 1, fixed_point_position);
61 Tensor src3 = create_tensor<Tensor>(src_shape3, dt, 1, fixed_point_position);
62 Tensor dst = create_tensor<Tensor>(out_shape, dt, 1, fixed_point_position);
Anthony Barbier6ff3b192017-09-04 18:44:23 +010063
64 // Create and configure function
65 NEGEMM gemm;
66 gemm.configure(&src1, &src2, &src3, &dst, alpha, beta);
67
68 // Allocate tensors
69 src1.allocator()->allocate();
70 src2.allocator()->allocate();
71 src3.allocator()->allocate();
72 dst.allocator()->allocate();
73
74 BOOST_TEST(!src1.info()->is_resizable());
75 BOOST_TEST(!src2.info()->is_resizable());
76 BOOST_TEST(!src3.info()->is_resizable());
77 BOOST_TEST(!dst.info()->is_resizable());
78
79 // Fill tensors
Pablo Tello221f3812017-06-28 17:27:56 +010080 if(dt == DataType::F16 || dt == DataType::F32)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010081 {
82 std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
83 library->fill(NEAccessor(src1), distribution, 0);
84 library->fill(NEAccessor(src2), distribution, 1);
85 library->fill(NEAccessor(src3), distribution, 2);
86 }
87 else
88 {
89 library->fill_tensor_uniform(NEAccessor(src1), 0);
90 library->fill_tensor_uniform(NEAccessor(src2), 1);
91 library->fill_tensor_uniform(NEAccessor(src3), 2);
92 }
93
94 // Compute function
95 gemm.run();
96
97 return dst;
98}
99} // namespace
100
101#ifndef DOXYGEN_SKIP_THIS
102BOOST_AUTO_TEST_SUITE(NEON)
103BOOST_AUTO_TEST_SUITE(GEMM)
104
105BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))
106BOOST_DATA_TEST_CASE(Configuration,
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +0100107 SmallGEMMDataset() * boost::unit_test::data::make({ DataType::F32, DataType::QS8, DataType::QS16 }),
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100108 gemm_set, dt)
109{
110 // Set fixed point position data type allowed
111 int fixed_point_position = (dt == DataType::F32) ? 0 : 3;
112
113 // Create tensors
Moritz Pflanzer94450f12017-06-30 12:48:43 +0100114 Tensor src1 = create_tensor<Tensor>(gemm_set.shape_a, dt, 1, fixed_point_position);
115 Tensor src2 = create_tensor<Tensor>(gemm_set.shape_b, dt, 1, fixed_point_position);
116 Tensor src3 = create_tensor<Tensor>(gemm_set.shape_c, dt, 1, fixed_point_position);
117 Tensor dst = create_tensor<Tensor>(gemm_set.shape_d, dt, 1, fixed_point_position);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100118
119 BOOST_TEST(src1.info()->is_resizable());
120 BOOST_TEST(src2.info()->is_resizable());
121 BOOST_TEST(src3.info()->is_resizable());
122 BOOST_TEST(dst.info()->is_resizable());
123
124 // Create and configure function
125 NEGEMM gemm;
126 gemm.configure(&src1, &src2, &src3, &dst, gemm_set.alpha, gemm_set.beta);
127
128 // Validate valid region
129 const ValidRegion src1_valid_region = shape_to_valid_region(gemm_set.shape_a);
130 const ValidRegion src2_valid_region = shape_to_valid_region(gemm_set.shape_b);
131 const ValidRegion src3_valid_region = shape_to_valid_region(gemm_set.shape_c);
132 const ValidRegion dst_valid_region = shape_to_valid_region(gemm_set.shape_d);
133
134 validate(src1.info()->valid_region(), src1_valid_region);
135 validate(src2.info()->valid_region(), src2_valid_region);
136 validate(src3.info()->valid_region(), src3_valid_region);
137 validate(dst.info()->valid_region(), dst_valid_region);
138}
139
#ifdef ARM_COMPUTE_ENABLE_FP16
// Half precision tests, only built when ARM_COMPUTE_ENABLE_FP16 is defined.
BOOST_AUTO_TEST_SUITE(Float16)
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::make(DataType::F16),
                     gemm_set, dt)
{
    // Expected result from the reference implementation
    RawTensor expected = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b,
                                                           gemm_set.shape_c, gemm_set.shape_d,
                                                           gemm_set.alpha, gemm_set.beta, dt);

    // Actual result from NEGEMM
    Tensor actual = compute_gemm(gemm_set.shape_a, gemm_set.shape_b,
                                 gemm_set.shape_c, gemm_set.shape_d,
                                 gemm_set.alpha, gemm_set.beta, dt);

    // Compare both results; F16 output is checked with the same tolerance constant as F32
    validate(NEAccessor(actual), expected, tolerance_f32);
}
BOOST_AUTO_TEST_SUITE_END()
#endif /* ARM_COMPUTE_ENABLE_FP16 */
157
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100158BOOST_AUTO_TEST_SUITE(Float)
159BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
160BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::make(DataType::F32),
161 gemm_set, dt)
162{
163 // Compute reference
164 RawTensor ref_dst = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
165
166 // Compute function
167 Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
168
169 // Validate output
170 validate(NEAccessor(dst), ref_dst, tolerance_f32);
171}
172
173BOOST_TEST_DECORATOR(*boost::unit_test::label("nightly"))
174BOOST_DATA_TEST_CASE(LargeGEMM, LargeGEMMDataset() * boost::unit_test::data::make(DataType::F32),
175 gemm_set, dt)
176{
177 // Compute reference
178 RawTensor ref_dst = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
179
180 // Compute function
181 Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
182
183 // Validate output
184 validate(NEAccessor(dst), ref_dst, tolerance_f32);
185}
186BOOST_AUTO_TEST_SUITE_END()
187
188BOOST_AUTO_TEST_SUITE(Quantized)
189BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +0100190BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::make({ DataType::QS8, DataType::QS16 }) * boost::unit_test::data::xrange(1, 7),
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100191 gemm_set, dt, fixed_point_position)
192{
193 // Compute reference
194 RawTensor ref_dst = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
195
196 // Compute function
197 Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
198
199 // Validate output
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +0100200 validate(NEAccessor(dst), ref_dst, tolerance_q);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100201}
202
203BOOST_TEST_DECORATOR(*boost::unit_test::label("nightly"))
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +0100204BOOST_DATA_TEST_CASE(LargeGEMM, LargeGEMMDataset() * boost::unit_test::data::make({ DataType::QS8, DataType::QS16 }) * boost::unit_test::data::xrange(1, 7),
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100205 gemm_set, dt, fixed_point_position)
206{
207 // Compute reference
208 RawTensor ref_dst = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
209
210 // Compute function
211 Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
212
213 // Validate output
Gian Marco Iodicebdb6b0b2017-06-30 12:21:00 +0100214 validate(NEAccessor(dst), ref_dst, tolerance_q);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100215}
216BOOST_AUTO_TEST_SUITE_END()
217
218BOOST_AUTO_TEST_SUITE_END()
219BOOST_AUTO_TEST_SUITE_END()
Anthony Barbierac69aa12017-07-03 17:39:37 +0100220#endif /* DOXYGEN_SKIP_THIS */