/*
 * Copyright (c) 2016-2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H
#define ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H

#include "arm_compute/core/CoreTypes.h"
#include "arm_compute/function_info/ActivationLayerInfo.h"

#include <limits>
#include <vector>

namespace arm_compute
{
class ITensorInfo;
/** GEMMLowp output stage type */
enum class GEMMLowpOutputStageType
{
    NONE,                     /**< No quantization */
    QUANTIZE_DOWN,            /**< Quantize using an integer multiplication */
    QUANTIZE_DOWN_FIXEDPOINT, /**< Quantize using a fixed point multiplication */
    QUANTIZE_DOWN_FLOAT       /**< Quantize using a floating point multiplication */
};

/** GEMMLowp output stage info */
struct GEMMLowpOutputStageInfo
{
    GEMMLowpOutputStageType type{GEMMLowpOutputStageType::NONE}; /**< GEMMLowp output stage type */
    int32_t gemmlowp_offset{0};     /**< GEMMLowp output stage offset used for quantizing to QASYMM8 */
    int32_t gemmlowp_multiplier{0}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
    int32_t gemmlowp_shift{0};      /**< GEMMLowp output stage shift used for quantizing to uint8 */
    int32_t gemmlowp_min_bound{
        std::numeric_limits<int32_t>::
            lowest()}; /**< GEMMLowp min value used to saturate down the output result before converting back to QASYMM8 */
    int32_t gemmlowp_max_bound{
        std::numeric_limits<int32_t>::
            max()}; /**< GEMMLowp max value used to saturate down the output result before converting back to QASYMM8 */
    std::vector<int32_t> gemmlowp_multipliers{}; /**< GEMMLowp output stage multipliers used for quantizing to QASYMM8 */
    std::vector<int32_t> gemmlowp_shifts{};      /**< GEMMLowp output stage shifts used for quantizing to QASYMM8 */
    float    gemmlowp_real_multiplier{0};        /**< GEMMLowp output stage real multiplier used for quantizing to QASYMM8 */
    bool     is_quantized_per_channel{false};    /**< GEMMLowp quantized per-channel flag */
    DataType output_data_type{DataType::UNKNOWN}; /**< Output tensor data type to use if the output is not initialized */
};
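
/* Illustrative sketch (not part of this header's API): one possible way to fill a
 * GEMMLowpOutputStageInfo for a fixed-point quantize-down to QASYMM8. The multiplier,
 * shift and offset values below are placeholders for the example; real values come from
 * the output quantization parameters of the GEMM.
 *
 *   GEMMLowpOutputStageInfo stage{};
 *   stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
 *   stage.gemmlowp_offset     = 10;         // output zero point (placeholder)
 *   stage.gemmlowp_multiplier = 1073741824; // fixed-point multiplier (placeholder)
 *   stage.gemmlowp_shift      = -3;         // quantize-down shift (placeholder)
 *   stage.gemmlowp_min_bound  = 0;          // clamp to the QASYMM8 range
 *   stage.gemmlowp_max_bound  = 255;
 *   stage.output_data_type    = DataType::QASYMM8;
 */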
/** GEMM information class. This class stores the necessary information to compute GEMM functions
 *
 * This object also contains the information about how matrix A and matrix B have been reshaped
 *
 */
class GEMMInfo
{
public:
    /** Default constructor */
    GEMMInfo() noexcept
        : _is_a_reshaped(false),
          _is_b_reshaped(false),
          _reshape_b_only_on_first_run(true),
          _depth_output_gemm3d(0),
          _reinterpret_input_as_3d(false),
          _retain_internal_weights(false),
          _gemmlowp_output_stage(),
          _fast_math(false),
          _fp_mixed_precision(false),
          _broadcast_bias(false),
          _pretranspose_A(false),
          _pretranspose_B(false),
          _activation_info(),
          _fixed_format(false),
          _weight_format(arm_compute::WeightFormat::UNSPECIFIED),
          _accumulate(false)
    {
    }
    /** Constructor
     *
     * @param[in] is_a_reshaped               True if the matrix A has been reshaped
     * @param[in] is_b_reshaped               True if the matrix B has been reshaped
     * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run
     * @param[in] depth_output_gemm3d         (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
     *                                        If 0 the output will not be reinterpreted as 3D. Default 0
     * @param[in] reinterpret_input_as_3d     (Optional) Reinterpret the input as a 3D tensor (i.e. this flag should be set to true when GEMM is used
     *                                        to perform 1x1 convolutions with the NHWC data layout)
     * @param[in] retain_internal_weights     (Optional) Retain the weights tensor from the previous run
     * @param[in] gemmlowp_output_stage       (Optional) GEMMLowp output stage info
     * @param[in] fp_mixed_precision          (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
     * @param[in] fast_math                   (Optional) Use a data type of shorter width to improve performance
     * @param[in] broadcast_bias              (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
     * @param[in] activation_info             (Optional) Activation to apply after the matrix multiplication
     * @param[in] fixed_format                (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in a memory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
     * @param[in] weight_format               (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
     * @param[in] pretranspose_B              (Optional) Pretranspose matrix B (transposition of its lowest 2 dimensions), in addition to and before any further transformations of B
     * @param[in] accumulate                  (Optional) Whether to accumulate in the destination or not
     */
    GEMMInfo(bool                       is_a_reshaped,
             bool                       is_b_reshaped,
             bool                       reshape_b_only_on_first_run,
             int                        depth_output_gemm3d     = 0,
             bool                       reinterpret_input_as_3d = false,
             bool                       retain_internal_weights = false,
             GEMMLowpOutputStageInfo    gemmlowp_output_stage   = GEMMLowpOutputStageInfo(),
             bool                       fp_mixed_precision      = false,
             bool                       fast_math               = false,
             bool                       broadcast_bias          = false,
             const ActivationLayerInfo &activation_info         = ActivationLayerInfo(),
             bool                       fixed_format            = false,
             arm_compute::WeightFormat  weight_format           = arm_compute::WeightFormat::UNSPECIFIED,
             bool                       pretranspose_B          = false,
             bool                       accumulate              = false) noexcept
        : _is_a_reshaped(is_a_reshaped),
          _is_b_reshaped(is_b_reshaped),
          _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
          _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d),
          _retain_internal_weights(retain_internal_weights),
          _gemmlowp_output_stage(gemmlowp_output_stage),
          _fast_math(fast_math),
          _fp_mixed_precision(fp_mixed_precision),
          _broadcast_bias(broadcast_bias),
          _pretranspose_A(false),
          _pretranspose_B(pretranspose_B),
          _activation_info(activation_info),
          _fixed_format(fixed_format),
          _weight_format(weight_format),
          _accumulate(accumulate)
    {
    }
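    /* Illustrative sketch (not part of this header's API): constructing a GEMMInfo as it might be
     * used for a convolution-style GEMM, where matrix B is reshaped only on the first run, the bias
     * is broadcast and a ReLU activation is fused. The ActivationFunction enumerator is assumed to
     * be available from ActivationLayerInfo.h.
     *
     *   GEMMInfo gemm_info(false,                     // is_a_reshaped
     *                      false,                     // is_b_reshaped
     *                      true,                      // reshape_b_only_on_first_run
     *                      0,                         // depth_output_gemm3d
     *                      false,                     // reinterpret_input_as_3d
     *                      false,                     // retain_internal_weights
     *                      GEMMLowpOutputStageInfo(), // no quantized output stage
     *                      false,                     // fp_mixed_precision
     *                      false,                     // fast_math
     *                      true,                      // broadcast_bias
     *                      ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
     */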
    /** Flag which specifies if the matrix A has been reshaped
     *
     * @return True if the matrix A has been reshaped
     */
    bool is_a_reshaped() const
    {
        return _is_a_reshaped;
    };
    /** Flag which specifies if the matrix B has been reshaped
     *
     * @return True if the matrix B has been reshaped
     */
    bool is_b_reshaped() const
    {
        return _is_b_reshaped;
    };
    /** Flag which specifies if the reshape of matrix B should be executed only for the first run
     *
     * @note This flag could be set to TRUE when GEMM is used to accelerate a convolution layer
     *
     * @return True if the reshape of matrix B happens only for the first run
     */
    bool reshape_b_only_on_first_run() const
    {
        return _reshape_b_only_on_first_run;
    };
    /** Depth of the output when GEMM output is reinterpreted as 3D tensor
     *
     * @return the depth of the output tensor
     */
    int depth_output_gemm3d() const
    {
        return _depth_output_gemm3d;
    };
    /** Flag which specifies if the input tensor has to be reinterpreted as 3D
     *
     * @return True if the input tensor has to be reinterpreted as 3D tensor
     */
    bool reinterpret_input_as_3d() const
    {
        return _reinterpret_input_as_3d;
    };
    /** Flag which specifies if the weights tensor has to be retained from previous run
     *
     * @return True if the weights tensor has to be retained
     */
    bool retain_internal_weights() const
    {
        return _retain_internal_weights;
    };
    /** GEMMLowp output stage
     *
     * @return the GEMMLowp output stage info
     */
    GEMMLowpOutputStageInfo gemmlowp_output_stage() const
    {
        return _gemmlowp_output_stage;
    };
    /** Sets GEMMLowp output stage
     *
     * @param[in] output_stage Output stage to set
     */
    void set_gemmlowp_output_stage(GEMMLowpOutputStageInfo &output_stage)
    {
        _gemmlowp_output_stage = output_stage;
    };
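    /* Illustrative sketch (not part of this header's API): attaching a previously configured
     * GEMMLowpOutputStageInfo (see the struct above) to a GEMMInfo instance. The setter takes a
     * non-const reference, so a named lvalue is passed.
     *
     *   GEMMLowpOutputStageInfo stage{};
     *   stage.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
     *   GEMMInfo gemm_info{};
     *   gemm_info.set_gemmlowp_output_stage(stage);
     */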
    /** Flag which specifies if a wider accumulator should be used.
     *
     * @return True if a wider accumulator has to be used
     */
    bool fp_mixed_precision() const
    {
        return _fp_mixed_precision;
    };
    /** Flag which specifies if a shorter accumulator should be used.
     *
     * @return True if a shorter accumulator has to be used
     */
    bool fast_math() const
    {
        return _fast_math;
    };
    /** Set fast math flag
     *
     * @param[in] fast_math Flag to set
     */
    void set_fast_math(bool fast_math)
    {
        _fast_math = fast_math;
    }
    /** Flag which specifies whether to broadcast the shape of the bias tensor.
     *
     * @return True if the shape of the bias tensor is to be broadcasted.
     */
    bool broadcast_bias() const
    {
        return _broadcast_bias;
    };
    /** Flag which specifies whether A should be pre-transposed if supported.
     *
     * @return True if A should be pre-transposed else false.
     */
    bool pretranspose_A() const
    {
        return _pretranspose_A;
    };
    /** Set pre-transpose A flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_A(bool flag)
    {
        _pretranspose_A = flag;
    }
    /** Flag which specifies whether B should be pre-transposed if supported.
     *  More concretely, the "pre-transpose" is the transposition of the B tensor's lowest 2 dimensions.
     *  If set to true, this pre-transpose will occur in addition to, and before, any further transformations of the B matrix.
     *
     * @return True if B should be pre-transposed else false.
     */
    bool pretranspose_B() const
    {
        return _pretranspose_B;
    };
    /** Set pre-transpose B flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_B(bool flag)
    {
        _pretranspose_B = flag;
    }
    /** Activation layer to apply after the matrix multiplication
     *
     * @return ActivationLayerInfo object
     */
    ActivationLayerInfo activation_info() const
    {
        return _activation_info;
    }
    /** Set activation layer info
     *
     * @param[in] activation_info ActivationLayerInfo object to set
     */
    void set_activation_info(const ActivationLayerInfo &activation_info)
    {
        _activation_info = activation_info;
    }
    /** Flag which specifies if the GEMM operation is running fixed-format kernels.
     *
     * @return True if the GEMM operation is running fixed-format kernels, else false.
     */
    bool fixed_format() const
    {
        return _fixed_format;
    }
    /** Flag which specifies if GEMM should accumulate the result in the destination or not.
     *
     * @return True if GEMM is accumulating the result.
     */
    bool accumulate() const
    {
        return _accumulate;
    }
    /** Set fixed-format flag
     *
     * @param[in] fixed_format Sets whether or not to use fixed-format kernels
     */
    void set_fixed_format(bool fixed_format)
    {
        _fixed_format = fixed_format;
    }
    /** Set accumulate flag
     *
     * @param[in] accumulate Sets whether or not to use accumulation
     */
    void set_accumulate(bool accumulate)
    {
        _accumulate = accumulate;
    }

    /** Weight format to be used
     *
     * @return The arm_compute::WeightFormat currently set
     */
    arm_compute::WeightFormat weight_format() const
    {
        return _weight_format;
    }
    /** Set weight format to be used
     *
     * @param[in] weight_format arm_compute::WeightFormat enumeration
     */
    void set_weight_format(arm_compute::WeightFormat weight_format)
    {
        _weight_format = weight_format;
    }
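    /* Illustrative sketch (not part of this header's API): adjusting an existing GEMMInfo through its
     * setters, e.g. to request fixed-format kernels with an explicit weight format and accumulation
     * into the destination tensor. WeightFormat::OHWIo4 is used here purely as an assumed example
     * enumerator; pick whichever format the selected kernel reports as supported.
     *
     *   GEMMInfo gemm_info{};
     *   gemm_info.set_fixed_format(true);
     *   gemm_info.set_weight_format(arm_compute::WeightFormat::OHWIo4);
     *   gemm_info.set_accumulate(true);
     */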

private:
    bool                      _is_a_reshaped;
    bool                      _is_b_reshaped;
    bool                      _reshape_b_only_on_first_run;
    int                       _depth_output_gemm3d;
    bool                      _reinterpret_input_as_3d;
    bool                      _retain_internal_weights;
    GEMMLowpOutputStageInfo   _gemmlowp_output_stage;
    bool                      _fast_math;
    bool                      _fp_mixed_precision;
    bool                      _broadcast_bias;
    bool                      _pretranspose_A;
    bool                      _pretranspose_B;
    ActivationLayerInfo       _activation_info;
    bool                      _fixed_format;
    arm_compute::WeightFormat _weight_format;
    bool                      _accumulate;
};
} // namespace arm_compute
#endif // ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H