blob: 0172ddeb76ec4b8b2cd1acc548ea2bc13beb672d [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "Globals.h"
25#include "NEON/Helper.h"
26#include "NEON/NEAccessor.h"
27#include "TensorLibrary.h"
28#include "TypePrinter.h"
29#include "Utils.h"
30#include "dataset/GEMMDataset.h"
31#include "validation/Datasets.h"
32#include "validation/Reference.h"
33#include "validation/Validation.h"
34
35#include "arm_compute/core/Helpers.h"
36#include "arm_compute/core/Types.h"
37#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
38#include "arm_compute/runtime/Tensor.h"
39#include "arm_compute/runtime/TensorAllocator.h"
40
41#include "boost_wrapper.h"
42
43#include <random>
44#include <string>
45
46using namespace arm_compute;
47using namespace arm_compute::test;
48using namespace arm_compute::test::neon;
49using namespace arm_compute::test::validation;
50
51namespace
52{
53const float tolerance_f32 = 1e-03f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
54const float tolerance_qs8 = 1.0f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */
55
56Tensor compute_gemm(const TensorShape &src_shape1, const TensorShape &src_shape2, const TensorShape &src_shape3,
57 const TensorShape &out_shape, float alpha, float beta, DataType dt, int fixed_point_position = 0)
58{
59 // Create tensors
60 Tensor src1 = create_tensor(src_shape1, dt, 1, fixed_point_position);
61 Tensor src2 = create_tensor(src_shape2, dt, 1, fixed_point_position);
62 Tensor src3 = create_tensor(src_shape3, dt, 1, fixed_point_position);
63 Tensor dst = create_tensor(out_shape, dt, 1, fixed_point_position);
64
65 // Create and configure function
66 NEGEMM gemm;
67 gemm.configure(&src1, &src2, &src3, &dst, alpha, beta);
68
69 // Allocate tensors
70 src1.allocator()->allocate();
71 src2.allocator()->allocate();
72 src3.allocator()->allocate();
73 dst.allocator()->allocate();
74
75 BOOST_TEST(!src1.info()->is_resizable());
76 BOOST_TEST(!src2.info()->is_resizable());
77 BOOST_TEST(!src3.info()->is_resizable());
78 BOOST_TEST(!dst.info()->is_resizable());
79
80 // Fill tensors
81 if(dt == DataType::F32)
82 {
83 std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
84 library->fill(NEAccessor(src1), distribution, 0);
85 library->fill(NEAccessor(src2), distribution, 1);
86 library->fill(NEAccessor(src3), distribution, 2);
87 }
88 else
89 {
90 library->fill_tensor_uniform(NEAccessor(src1), 0);
91 library->fill_tensor_uniform(NEAccessor(src2), 1);
92 library->fill_tensor_uniform(NEAccessor(src3), 2);
93 }
94
95 // Compute function
96 gemm.run();
97
98 return dst;
99}
100} // namespace
101
102#ifndef DOXYGEN_SKIP_THIS
103BOOST_AUTO_TEST_SUITE(NEON)
104BOOST_AUTO_TEST_SUITE(GEMM)
105
106BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))
107BOOST_DATA_TEST_CASE(Configuration,
108 SmallGEMMDataset() * boost::unit_test::data::make({ DataType::F32, DataType::QS8 }),
109 gemm_set, dt)
110{
111 // Set fixed point position data type allowed
112 int fixed_point_position = (dt == DataType::F32) ? 0 : 3;
113
114 // Create tensors
115 Tensor src1 = create_tensor(gemm_set.shape_a, dt, 1, fixed_point_position);
116 Tensor src2 = create_tensor(gemm_set.shape_b, dt, 1, fixed_point_position);
117 Tensor src3 = create_tensor(gemm_set.shape_c, dt, 1, fixed_point_position);
118 Tensor dst = create_tensor(gemm_set.shape_d, dt, 1, fixed_point_position);
119
120 BOOST_TEST(src1.info()->is_resizable());
121 BOOST_TEST(src2.info()->is_resizable());
122 BOOST_TEST(src3.info()->is_resizable());
123 BOOST_TEST(dst.info()->is_resizable());
124
125 // Create and configure function
126 NEGEMM gemm;
127 gemm.configure(&src1, &src2, &src3, &dst, gemm_set.alpha, gemm_set.beta);
128
129 // Validate valid region
130 const ValidRegion src1_valid_region = shape_to_valid_region(gemm_set.shape_a);
131 const ValidRegion src2_valid_region = shape_to_valid_region(gemm_set.shape_b);
132 const ValidRegion src3_valid_region = shape_to_valid_region(gemm_set.shape_c);
133 const ValidRegion dst_valid_region = shape_to_valid_region(gemm_set.shape_d);
134
135 validate(src1.info()->valid_region(), src1_valid_region);
136 validate(src2.info()->valid_region(), src2_valid_region);
137 validate(src3.info()->valid_region(), src3_valid_region);
138 validate(dst.info()->valid_region(), dst_valid_region);
139}
140
141BOOST_AUTO_TEST_SUITE(Float)
142BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
143BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::make(DataType::F32),
144 gemm_set, dt)
145{
146 // Compute reference
147 RawTensor ref_dst = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
148
149 // Compute function
150 Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
151
152 // Validate output
153 validate(NEAccessor(dst), ref_dst, tolerance_f32);
154}
155
156BOOST_TEST_DECORATOR(*boost::unit_test::label("nightly"))
157BOOST_DATA_TEST_CASE(LargeGEMM, LargeGEMMDataset() * boost::unit_test::data::make(DataType::F32),
158 gemm_set, dt)
159{
160 // Compute reference
161 RawTensor ref_dst = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
162
163 // Compute function
164 Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt);
165
166 // Validate output
167 validate(NEAccessor(dst), ref_dst, tolerance_f32);
168}
169BOOST_AUTO_TEST_SUITE_END()
170
171BOOST_AUTO_TEST_SUITE(Quantized)
172BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
173BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::make(DataType::QS8) * boost::unit_test::data::xrange(1, 7),
174 gemm_set, dt, fixed_point_position)
175{
176 // Compute reference
177 RawTensor ref_dst = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
178
179 // Compute function
180 Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
181
182 // Validate output
183 validate(NEAccessor(dst), ref_dst, tolerance_qs8);
184}
185
186BOOST_TEST_DECORATOR(*boost::unit_test::label("nightly"))
187BOOST_DATA_TEST_CASE(LargeGEMM, LargeGEMMDataset() * boost::unit_test::data::make(DataType::QS8) * boost::unit_test::data::xrange(1, 7),
188 gemm_set, dt, fixed_point_position)
189{
190 // Compute reference
191 RawTensor ref_dst = Reference::compute_reference_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
192
193 // Compute function
194 Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
195
196 // Validate output
197 validate(NEAccessor(dst), ref_dst, tolerance_qs8);
198}
199BOOST_AUTO_TEST_SUITE_END()
200
201BOOST_AUTO_TEST_SUITE_END()
202BOOST_AUTO_TEST_SUITE_END()
203#endif