/*
 * Copyright (c) 2016-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H
#define ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H

#include "arm_compute/core/CoreTypes.h"
#include "arm_compute/function_info/ActivationLayerInfo.h"

#include <cstdint>
#include <limits>
#include <vector>

namespace arm_compute
{
class ITensorInfo;
/** GEMMLowp output stage type */
enum class GEMMLowpOutputStageType
{
    NONE,                     /**< No quantization */
    QUANTIZE_DOWN,            /**< Quantize using an integer multiplication */
    QUANTIZE_DOWN_FIXEDPOINT, /**< Quantize using a fixed point multiplication */
    QUANTIZE_DOWN_FLOAT       /**< Quantize using a floating point multiplication */
};

/** GEMMLowp output stage info */
struct GEMMLowpOutputStageInfo
{
    GEMMLowpOutputStageType type{ GEMMLowpOutputStageType::NONE };                        /**< GEMMLowp output stage type */
    int32_t                 gemmlowp_offset{ 0 };                                          /**< GEMMLowp output stage offset used for quantizing to QASYMM8 */
    int32_t                 gemmlowp_multiplier{ 0 };                                      /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
    int32_t                 gemmlowp_shift{ 0 };                                           /**< GEMMLowp output stage shift used for quantizing to QASYMM8 */
    int32_t                 gemmlowp_min_bound{ std::numeric_limits<int32_t>::lowest() };  /**< GEMMLowp min value used to saturate down the output result before converting back to QASYMM8 */
    int32_t                 gemmlowp_max_bound{ std::numeric_limits<int32_t>::max() };     /**< GEMMLowp max value used to saturate down the output result before converting back to QASYMM8 */
    std::vector<int32_t>    gemmlowp_multipliers{};                                        /**< GEMMLowp output stage multipliers used for per-channel quantizing to QASYMM8 */
    std::vector<int32_t>    gemmlowp_shifts{};                                             /**< GEMMLowp output stage shifts used for per-channel quantizing to QASYMM8 */
    float                   gemmlowp_real_multiplier{ 0 };                                 /**< GEMMLowp output stage real multiplier used for quantizing to QASYMM8 */
    bool                    is_quantized_per_channel{ false };                             /**< GEMMLowp quantized per-channel flag */
    DataType                output_data_type{ DataType::UNKNOWN };                         /**< Output tensor data type to use if the output is not initialized */
};
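
// Illustrative sketch (not part of the API surface): one plausible way to fill in an output
// stage for fixed-point requantization to QASYMM8. The numeric values below are placeholders;
// in practice the multiplier/shift/offset are derived from the quantization info of the input,
// weight and output tensors.
//
//   GEMMLowpOutputStageInfo output_stage{};
//   output_stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
//   output_stage.gemmlowp_offset     = 10;         // output stage offset (placeholder)
//   output_stage.gemmlowp_multiplier = 1073741824; // fixed-point multiplier (placeholder)
//   output_stage.gemmlowp_shift      = 5;          // result shift (placeholder)
//   output_stage.gemmlowp_min_bound  = 0;          // saturate to the QASYMM8 range
//   output_stage.gemmlowp_max_bound  = 255;
//   output_stage.output_data_type    = DataType::QASYMM8;
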
/** GEMM information class. This class stores the necessary information to compute GEMM functions.
 *
 * This object also contains the information about how matrix A and matrix B have been reshaped.
 */
class GEMMInfo
{
public:
    /** Default constructor */
    GEMMInfo() noexcept
        : _is_a_reshaped(false),
          _is_b_reshaped(false),
          _reshape_b_only_on_first_run(true),
          _depth_output_gemm3d(0),
          _reinterpret_input_as_3d(false),
          _retain_internal_weights(false),
          _gemmlowp_output_stage(),
          _fast_math(false),
          _fp_mixed_precision(false),
          _broadcast_bias(false),
          _pretranspose_A(false),
          _pretranspose_B(false),
          _activation_info(),
          _fixed_format(false),
          _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
    {
    }
    /** Constructor
     *
     * @param[in] is_a_reshaped               True if the matrix A has been reshaped
     * @param[in] is_b_reshaped               True if the matrix B has been reshaped
     * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run
     * @param[in] depth_output_gemm3d         (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel
     *                                        If 0 the output will not be reinterpreted as 3D. Default 0
     * @param[in] reinterpret_input_as_3d     (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used
     *                                        to perform 1x1 convolutions with the NHWC data layout)
     * @param[in] retain_internal_weights     (Optional) Retain the weights tensor from previous run
     * @param[in] gemmlowp_output_stage       (Optional) GEMMLowp Output stage info
     * @param[in] fp_mixed_precision          (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
     * @param[in] fast_math                   (Optional) Use a data type of shorter width to improve performance
     * @param[in] broadcast_bias              (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
     * @param[in] activation_info             (Optional) Activation to apply after the matrix multiplication
     * @param[in] fixed_format                (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor
     *                                        to be in a memory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
     * @param[in] weight_format               (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
     */
    GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false,
             const ActivationLayerInfo &activation_info = ActivationLayerInfo(), bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) noexcept
        : _is_a_reshaped(is_a_reshaped),
          _is_b_reshaped(is_b_reshaped),
          _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
          _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d),
          _retain_internal_weights(retain_internal_weights),
          _gemmlowp_output_stage(gemmlowp_output_stage),
          _fast_math(fast_math),
          _fp_mixed_precision(fp_mixed_precision),
          _broadcast_bias(broadcast_bias),
          _pretranspose_A(false),
          _pretranspose_B(false),
          _activation_info(activation_info),
          _fixed_format(fixed_format),
          _weight_format(weight_format)
    {
    }
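
    // Illustrative sketch (not part of the class definition): constructing a GEMMInfo whose B matrix
    // is reshaped only on the first run, with a fused activation and fast math enabled. The RELU
    // enumerator is assumed to be declared by ActivationLayerInfo.h, outside this header.
    //
    //   GEMMInfo info(false /* is_a_reshaped */, false /* is_b_reshaped */, true /* reshape_b_only_on_first_run */);
    //   info.set_activation_info(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
    //   info.set_fast_math(true); // allow data types of shorter width where supported
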
    /** Flag which specifies if the matrix A has been reshaped
     *
     * @return True if the matrix A has been reshaped
     */
    bool is_a_reshaped() const
    {
        return _is_a_reshaped;
    };
    /** Flag which specifies if the matrix B has been reshaped
     *
     * @return True if the matrix B has been reshaped
     */
    bool is_b_reshaped() const
    {
        return _is_b_reshaped;
    };
    /** Flag which specifies if the reshape of matrix B should be executed only for the first run
     *
     * @note This flag could be set to TRUE when GEMM is used to accelerate convolution layer
     *
     * @return True if the reshape of matrix B happens only for the first run
     */
    bool reshape_b_only_on_first_run() const
    {
        return _reshape_b_only_on_first_run;
    };
    /** Depth of the output when GEMM output is reinterpreted as 3D tensor
     *
     * @return the depth of the output tensor
     */
    int depth_output_gemm3d() const
    {
        return _depth_output_gemm3d;
    };
    /** Flag which specifies if the input tensor has to be reinterpreted as 3D
     *
     * @return True if the input tensor has to be reinterpreted as 3D tensor
     */
    bool reinterpret_input_as_3d() const
    {
        return _reinterpret_input_as_3d;
    };
    /** Flag which specifies if the weights tensor has to be retained from previous run
     *
     * @return True if the weights tensor has to be retained
     */
    bool retain_internal_weights() const
    {
        return _retain_internal_weights;
    };
    /** GEMMLowp output stage
     *
     * @return the GEMMLowp output stage info
     */
    GEMMLowpOutputStageInfo gemmlowp_output_stage() const
    {
        return _gemmlowp_output_stage;
    };
    /** Sets GEMMLowp output stage
     *
     * @param[in] output_stage Output stage to set
     */
    void set_gemmlowp_output_stage(GEMMLowpOutputStageInfo &output_stage)
    {
        _gemmlowp_output_stage = output_stage;
    };
    /** Flag which specifies if a wider accumulator should be used.
     *
     * @return True if a wider accumulator has to be used
     */
    bool fp_mixed_precision() const
    {
        return _fp_mixed_precision;
    };
    /** Flag which specifies if a data type of shorter width is to be used to improve performance.
     *
     * @return True if a data type of shorter width has to be used
     */
    bool fast_math() const
    {
        return _fast_math;
    };
    /** Set fast math flag
     *
     * @param[in] fast_math Flag to set
     */
    void set_fast_math(bool fast_math)
    {
        _fast_math = fast_math;
    }
    /** Flag which specifies whether to broadcast the shape of the bias tensor.
     *
     * @return True if the shape of the bias tensor is to be broadcast.
     */
    bool broadcast_bias() const
    {
        return _broadcast_bias;
    };
    /** Flag which specifies whether A should be pre-transposed if supported.
     *
     * @return True if A should be pre-transposed else false.
     */
    bool pretranspose_A() const
    {
        return _pretranspose_A;
    };
    /** Set pre-transpose A flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_A(bool flag)
    {
        _pretranspose_A = flag;
    }
    /** Flag which specifies whether B should be pre-transposed if supported.
     *
     * @return True if B should be pre-transposed else false.
     */
    bool pretranspose_B() const
    {
        return _pretranspose_B;
    };
    /** Set pre-transpose B flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_B(bool flag)
    {
        _pretranspose_B = flag;
    }
    /** Activation layer to apply after the matrix multiplication
     *
     * @return ActivationLayerInfo object
     */
    ActivationLayerInfo activation_info() const
    {
        return _activation_info;
    }
    /** Set activation layer info
     *
     * @param[in] activation_info ActivationLayerInfo object to set
     */
    void set_activation_info(const ActivationLayerInfo &activation_info)
    {
        _activation_info = activation_info;
    }
    /** Flag which specifies if the GEMM operation is running fixed-format kernels.
     *
     * @return True if the GEMM operation is running fixed-format kernels, else false.
     */
    bool fixed_format() const
    {
        return _fixed_format;
    }

    /** Set fixed-format flag
     *
     * @param[in] fixed_format sets whether or not to use fixed-format kernels
     */
    void set_fixed_format(bool fixed_format)
    {
        _fixed_format = fixed_format;
    }

    /** Weight format to be used
     *
     * @return the arm_compute::WeightFormat requested for this GEMM
     */
    arm_compute::WeightFormat weight_format() const
    {
        return _weight_format;
    }

    /** Set weight format to be used
     *
     * @param[in] weight_format arm_compute::WeightFormat enumeration
     */
    void set_weight_format(arm_compute::WeightFormat weight_format)
    {
        _weight_format = weight_format;
    }
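
    // Illustrative sketch (not part of the class definition): opting into fixed-format kernels.
    // The concrete arm_compute::WeightFormat value is normally obtained from a backend query
    // outside this header; UNSPECIFIED is used here only as a placeholder for that value.
    //
    //   GEMMInfo info{};
    //   info.set_fixed_format(true);
    //   info.set_weight_format(arm_compute::WeightFormat::UNSPECIFIED); // placeholder for the queried format
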

private:
    bool                      _is_a_reshaped;
    bool                      _is_b_reshaped;
    bool                      _reshape_b_only_on_first_run;
    int                       _depth_output_gemm3d;
    bool                      _reinterpret_input_as_3d;
    bool                      _retain_internal_weights;
    GEMMLowpOutputStageInfo   _gemmlowp_output_stage;
    bool                      _fast_math;
    bool                      _fp_mixed_precision;
    bool                      _broadcast_bias;
    bool                      _pretranspose_A;
    bool                      _pretranspose_B;
    ActivationLayerInfo       _activation_info;
    bool                      _fixed_format;
    arm_compute::WeightFormat _weight_format;
};
} // namespace arm_compute
#endif // ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H