/*
 * Copyright (c) 2016-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H
#define ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H

#include "arm_compute/core/CoreTypes.h"
#include "arm_compute/function_info/ActivationLayerInfo.h"

#include <cstdint>
#include <limits>
#include <vector>

namespace arm_compute
{
class ITensorInfo;
/** GEMMLowp output stage type */
enum class GEMMLowpOutputStageType
{
    NONE,                     /**< No quantization */
    QUANTIZE_DOWN,            /**< Quantize using an integer multiplication */
    QUANTIZE_DOWN_FIXEDPOINT, /**< Quantize using a fixed point multiplication */
    QUANTIZE_DOWN_FLOAT       /**< Quantize using a floating point multiplication */
};

/** GEMMLowp output stage info */
struct GEMMLowpOutputStageInfo
{
    GEMMLowpOutputStageType type{GEMMLowpOutputStageType::NONE}; /**< GEMMLowp output stage type */
    int32_t                 gemmlowp_offset{0};                  /**< GEMMLowp output stage offset used for quantizing to QASYMM8 */
    int32_t                 gemmlowp_multiplier{0};              /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
    int32_t                 gemmlowp_shift{0};                   /**< GEMMLowp output stage shift used for quantizing to uint8 */
    int32_t gemmlowp_min_bound{
        std::numeric_limits<int32_t>::
            lowest()}; /**< GEMMLowp min value used to saturate down the output result before converting back to QASYMM8 */
    int32_t gemmlowp_max_bound{
        std::numeric_limits<int32_t>::
            max()}; /**< GEMMLowp max value used to saturate down the output result before converting back to QASYMM8 */
    std::vector<int32_t> gemmlowp_multipliers{}; /**< GEMMLowp output stage multipliers used for quantizing to QASYMM8 */
    std::vector<int32_t> gemmlowp_shifts{};      /**< GEMMLowp output stage shifts used for quantizing to QASYMM8 */
    float                gemmlowp_real_multiplier{0};     /**< GEMMLowp output stage real multiplier used for quantizing to QASYMM8 */
    bool                 is_quantized_per_channel{false}; /**< GEMMLowp quantized per-channel flag */
    DataType             output_data_type{
        DataType::UNKNOWN}; /**< Output tensor data type to use if the output is not initialized */
};
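// Illustrative sketch (not part of the API): how a caller might fill in a fixed-point
// requantization stage for a QASYMM8 output. The multiplier/shift/offset/bound values below
// are hypothetical and would normally be derived from the tensors' quantization info.
//
//   GEMMLowpOutputStageInfo stage{};
//   stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
//   stage.gemmlowp_multiplier = 1395864371; // example fixed-point multiplier
//   stage.gemmlowp_shift      = 1;          // example shift applied after the multiplication
//   stage.gemmlowp_offset     = 10;         // example output zero point
//   stage.gemmlowp_min_bound  = 0;          // saturate into the QASYMM8 range
//   stage.gemmlowp_max_bound  = 255;
//   stage.output_data_type    = DataType::QASYMM8;
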
/** GEMM information class. This class stores the necessary information to compute GEMM functions
 *
 * This object also contains the information about how matrix A and matrix B have been reshaped
 *
 */
class GEMMInfo
{
public:
    /** Default constructor */
    GEMMInfo() noexcept
        : _is_a_reshaped(false),
          _is_b_reshaped(false),
          _reshape_b_only_on_first_run(true),
          _depth_output_gemm3d(0),
          _reinterpret_input_as_3d(false),
          _retain_internal_weights(false),
          _gemmlowp_output_stage(),
          _fast_math(false),
          _fp_mixed_precision(false),
          _broadcast_bias(false),
          _pretranspose_A(false),
          _pretranspose_B(false),
          _activation_info(),
          _fixed_format(false),
          _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
    {
    }
    /** Constructor
     *
     * @param[in] is_a_reshaped               True if the matrix A has been reshaped
     * @param[in] is_b_reshaped               True if the matrix B has been reshaped
     * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run
     * @param[in] depth_output_gemm3d         (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
     *                                        If 0, the output will not be reinterpreted as 3D. Default 0
     * @param[in] reinterpret_input_as_3d     (Optional) Reinterpret the input as a 3D tensor (i.e. this flag should be set to true when GEMM is used
     *                                        to perform 1x1 convolutions with the NHWC data layout)
     * @param[in] retain_internal_weights     (Optional) Retain the weights tensor from the previous run
     * @param[in] gemmlowp_output_stage       (Optional) GEMMLowp output stage info
     * @param[in] fp_mixed_precision          (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
     * @param[in] fast_math                   (Optional) Use a data type of shorter width to improve performance
     * @param[in] broadcast_bias              (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
     * @param[in] activation_info             (Optional) Activation to apply after the matrix multiplication
     * @param[in] fixed_format                (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in a memory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
     * @param[in] weight_format               (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
     * @param[in] pretranspose_B              (Optional) Pretranspose matrix B (transposition of its lowest 2 dimensions), in addition to, and before, any further transformations of B
     */
    GEMMInfo(bool                       is_a_reshaped,
             bool                       is_b_reshaped,
             bool                       reshape_b_only_on_first_run,
             int                        depth_output_gemm3d     = 0,
             bool                       reinterpret_input_as_3d = false,
             bool                       retain_internal_weights = false,
             GEMMLowpOutputStageInfo    gemmlowp_output_stage   = GEMMLowpOutputStageInfo(),
             bool                       fp_mixed_precision      = false,
             bool                       fast_math               = false,
             bool                       broadcast_bias          = false,
             const ActivationLayerInfo &activation_info         = ActivationLayerInfo(),
             bool                       fixed_format            = false,
             arm_compute::WeightFormat  weight_format           = arm_compute::WeightFormat::UNSPECIFIED,
             bool                       pretranspose_B          = false) noexcept
        : _is_a_reshaped(is_a_reshaped),
          _is_b_reshaped(is_b_reshaped),
          _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
          _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d),
          _retain_internal_weights(retain_internal_weights),
          _gemmlowp_output_stage(gemmlowp_output_stage),
          _fast_math(fast_math),
          _fp_mixed_precision(fp_mixed_precision),
          _broadcast_bias(broadcast_bias),
          _pretranspose_A(false),
          _pretranspose_B(pretranspose_B),
          _activation_info(activation_info),
          _fixed_format(fixed_format),
          _weight_format(weight_format)
    {
    }
    /** Flag which specifies if the matrix A has been reshaped
     *
     * @return True if the matrix A has been reshaped
     */
    bool is_a_reshaped() const
    {
        return _is_a_reshaped;
    };
    /** Flag which specifies if the matrix B has been reshaped
     *
     * @return True if the matrix B has been reshaped
     */
    bool is_b_reshaped() const
    {
        return _is_b_reshaped;
    };
    /** Flag which specifies if the reshape of matrix B should be executed only for the first run
     *
     * @note This flag could be set to TRUE when GEMM is used to accelerate a convolution layer
     *
     * @return True if the reshape of matrix B happens only for the first run
     */
    bool reshape_b_only_on_first_run() const
    {
        return _reshape_b_only_on_first_run;
    };
    /** Depth of the output when GEMM output is reinterpreted as 3D tensor
     *
     * @return the depth of the output tensor
     */
    int depth_output_gemm3d() const
    {
        return _depth_output_gemm3d;
    };
    /** Flag which specifies if the input tensor has to be reinterpreted as 3D
     *
     * @return True if the input tensor has to be reinterpreted as 3D tensor
     */
    bool reinterpret_input_as_3d() const
    {
        return _reinterpret_input_as_3d;
    };
    /** Flag which specifies if the weights tensor has to be retained from previous run
     *
     * @return True if the weights tensor has to be retained
     */
    bool retain_internal_weights() const
    {
        return _retain_internal_weights;
    };
    /** GEMMLowp output stage
     *
     * @return the GEMMLowp output stage info
     */
    GEMMLowpOutputStageInfo gemmlowp_output_stage() const
    {
        return _gemmlowp_output_stage;
    };
    /** Sets GEMMLowp output stage
     *
     * @param[in] output_stage Output stage to set
     */
    void set_gemmlowp_output_stage(GEMMLowpOutputStageInfo &output_stage)
    {
        _gemmlowp_output_stage = output_stage;
    };
    /** Flag which specifies if a wider accumulator should be used.
     *
     * @return True if a wider accumulator has to be used
     */
    bool fp_mixed_precision() const
    {
        return _fp_mixed_precision;
    };
    /** Flag which specifies if a shorter accumulator should be used.
     *
     * @return True if a shorter accumulator has to be used
     */
    bool fast_math() const
    {
        return _fast_math;
    };
    /** Set fast math flag
     *
     * @param[in] fast_math Flag to set
     */
    void set_fast_math(bool fast_math)
    {
        _fast_math = fast_math;
    }
    /** Flag which specifies whether to broadcast the shape of the bias tensor.
     *
     * @return True if the shape of the bias tensor is to be broadcast.
     */
    bool broadcast_bias() const
    {
        return _broadcast_bias;
    };
    /** Flag which specifies whether A should be pre-transposed if supported.
     *
     * @return True if A should be pre-transposed else false.
     */
    bool pretranspose_A() const
    {
        return _pretranspose_A;
    };
    /** Set pre-transpose A flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_A(bool flag)
    {
        _pretranspose_A = flag;
    }
    /** Flag which specifies whether b should be pre-transposed if supported.
     * More concretely, the "pre-transpose" is the transposition of the b tensor's lowest 2 dimensions.
     * If set to true, this pre-transpose will occur in addition to, and before, any further transformations of the b matrix
     *
     * @return True if b should be pre-transposed else false.
     */
    bool pretranspose_B() const
    {
        return _pretranspose_B;
    };
    /** Set pre-transpose b flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_B(bool flag)
    {
        _pretranspose_B = flag;
    }
    /** Activation layer to apply after the matrix multiplication
     *
     * @return ActivationLayerInfo object
     */
    ActivationLayerInfo activation_info() const
    {
        return _activation_info;
    }
    /** Set activation layer info
     *
     * @param[in] activation_info ActivationLayerInfo object to set
     */
    void set_activation_info(const ActivationLayerInfo &activation_info)
    {
        _activation_info = activation_info;
    }
    /** Flag which specifies if the GEMM operation is running fixed-format kernels.
     *
     * @return True if the GEMM operation is running fixed-format kernels else false.
     */
    bool fixed_format() const
    {
        return _fixed_format;
    }

    /** Set fixed-format flag
     *
     * @param[in] fixed_format sets whether or not to use fixed-format kernels
     */
    void set_fixed_format(bool fixed_format)
    {
        _fixed_format = fixed_format;
    }

    /** Weight format to be used
     *
     * @return the arm_compute::WeightFormat enumeration currently set
     */
    arm_compute::WeightFormat weight_format() const
    {
        return _weight_format;
    }

    /** Set weight format to be used
     *
     * @param[in] weight_format arm_compute::WeightFormat enumeration
     */
    void set_weight_format(arm_compute::WeightFormat weight_format)
    {
        _weight_format = weight_format;
    }

private:
    bool                      _is_a_reshaped;
    bool                      _is_b_reshaped;
    bool                      _reshape_b_only_on_first_run;
    int                       _depth_output_gemm3d;
    bool                      _reinterpret_input_as_3d;
    bool                      _retain_internal_weights;
    GEMMLowpOutputStageInfo   _gemmlowp_output_stage;
    bool                      _fast_math;
    bool                      _fp_mixed_precision;
    bool                      _broadcast_bias;
    bool                      _pretranspose_A;
    bool                      _pretranspose_B;
    ActivationLayerInfo       _activation_info;
    bool                      _fixed_format;
    arm_compute::WeightFormat _weight_format;
};
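// Illustrative sketch (not part of this header): constructing a GEMMInfo for a GEMM that
// broadcasts its bias and fuses a bounded ReLU, then opting into fast-math kernels afterwards.
// The activation upper bound of 6.f is a hypothetical example value.
//
//   GEMMInfo gemm_info(false /* is_a_reshaped */,
//                      false /* is_b_reshaped */,
//                      true  /* reshape_b_only_on_first_run */,
//                      0     /* depth_output_gemm3d */,
//                      false /* reinterpret_input_as_3d */,
//                      false /* retain_internal_weights */,
//                      GEMMLowpOutputStageInfo(),
//                      false /* fp_mixed_precision */,
//                      false /* fast_math */,
//                      true  /* broadcast_bias */,
//                      ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f));
//   gemm_info.set_fast_math(true);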
} // namespace arm_compute
#endif // ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H