/*
 * Copyright (c) 2016-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_GEMMINFO_H
#define ARM_COMPUTE_GEMMINFO_H

#include "arm_compute/core/ActivationLayerInfo.h"
#include "arm_compute/core/Types.h"

namespace arm_compute
{
/** GEMM information class. This class stores the information needed to compute a GEMM function.
 *
 * The object also describes how matrix A and matrix B have been reshaped.
 */
class GEMMInfo
{
public:
    /** Default constructor */
    GEMMInfo() noexcept
        : _is_a_reshaped(false),
          _is_b_reshaped(false),
          _reshape_b_only_on_first_run(true),
          _depth_output_gemm3d(0),
          _reinterpret_input_as_3d(false),
          _retain_internal_weights(false),
          _gemmlowp_output_stage(),
          _fast_math(false),
          _fp_mixed_precision(false),
          _broadcast_bias(false),
          _pretranspose_A(false),
          _pretranspose_B(false),
          _activation_info(),
          _post_ops(),
          _fixed_format(false),
          _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
    {
    }
    /** Constructor
     *
     * @param[in] is_a_reshaped               True if the matrix A has been reshaped
     * @param[in] is_b_reshaped               True if the matrix B has been reshaped
     * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run
     * @param[in] depth_output_gemm3d         (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
     *                                        If 0, the output is not reinterpreted as 3D. Default 0
     * @param[in] reinterpret_input_as_3d     (Optional) Reinterpret the input as a 3D tensor. (i.e. this flag should be set to true when GEMM is used
     *                                        to perform 1x1 convolutions with the NHWC data layout)
     * @param[in] retain_internal_weights     (Optional) Retain the weights tensor from the previous run
     * @param[in] gemmlowp_output_stage       (Optional) GEMMLowp output stage info
     * @param[in] fp_mixed_precision          (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
     * @param[in] fast_math                   (Optional) Use a data type of shorter width to improve performance
     * @param[in] broadcast_bias              (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
     * @param[in] activation_info             (Optional) Activation to apply after the matrix multiplication
     * @param[in] post_ops                    (Optional) A sequence of post operations performed after the main operation.
     * @param[in] fixed_format                (Optional) Select fixed-format kernels for variable weights support in GEMM. These kernels expect the weights tensor
     *                                        to be in a memory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
     * @param[in] weight_format               (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
     */
    GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0,
             bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false,
             bool fast_math = false, bool broadcast_bias = false,
             const ActivationLayerInfo &activation_info = ActivationLayerInfo(),
             const experimental::PostOpList<ITensorInfo *> &post_ops = experimental::PostOpList<ITensorInfo *>(),
             bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) noexcept
        : _is_a_reshaped(is_a_reshaped),
          _is_b_reshaped(is_b_reshaped),
          _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
          _depth_output_gemm3d(depth_output_gemm3d),
          _reinterpret_input_as_3d(reinterpret_input_as_3d),
          _retain_internal_weights(retain_internal_weights),
          _gemmlowp_output_stage(gemmlowp_output_stage),
          _fast_math(fast_math),
          _fp_mixed_precision(fp_mixed_precision),
          _broadcast_bias(broadcast_bias),
          _pretranspose_A(false),
          _pretranspose_B(false),
          _activation_info(activation_info),
          _post_ops(post_ops),
          _fixed_format(fixed_format),
          _weight_format(weight_format)
    {
    }
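    /* Usage sketch (illustrative only, not part of the library): one plausible way to
     * construct a GEMMInfo for a GEMM that accelerates a convolution layer. The flag
     * values chosen here are assumptions for illustration.
     *
     *   GEMMInfo info(false, false, true,          // A/B not reshaped; reshape B only on the first run
     *                 0,                           // do not reinterpret the output as 3D
     *                 false, false,                // plain 2D input; no retained weights
     *                 GEMMLowpOutputStageInfo(),   // no GEMMLowp output stage
     *                 false, false,                // no FP mixed precision, no fast math
     *                 true,                        // broadcast the bias vector
     *                 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f));
     */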
    /** Flag which specifies if the matrix A has been reshaped
     *
     * @return True if the matrix A has been reshaped
     */
    bool is_a_reshaped() const
    {
        return _is_a_reshaped;
    }
    /** Flag which specifies if the matrix B has been reshaped
     *
     * @return True if the matrix B has been reshaped
     */
    bool is_b_reshaped() const
    {
        return _is_b_reshaped;
    }
    /** Flag which specifies if the reshape of matrix B should be executed only for the first run
     *
     * @note This flag can be set to TRUE when GEMM is used to accelerate a convolution layer
     *
     * @return True if the reshape of matrix B happens only for the first run
     */
    bool reshape_b_only_on_first_run() const
    {
        return _reshape_b_only_on_first_run;
    }
    /** Depth of the output when the GEMM output is reinterpreted as a 3D tensor
     *
     * @return the depth of the output tensor
     */
    int depth_output_gemm3d() const
    {
        return _depth_output_gemm3d;
    }
    /** Flag which specifies if the input tensor has to be reinterpreted as 3D
     *
     * @return True if the input tensor has to be reinterpreted as a 3D tensor
     */
    bool reinterpret_input_as_3d() const
    {
        return _reinterpret_input_as_3d;
    }
    /** Flag which specifies if the weights tensor has to be retained from the previous run
     *
     * @return True if the weights tensor has to be retained
     */
    bool retain_internal_weights() const
    {
        return _retain_internal_weights;
    }
    /** GEMMLowp output stage
     *
     * @return the GEMMLowp output stage info
     */
    GEMMLowpOutputStageInfo gemmlowp_output_stage() const
    {
        return _gemmlowp_output_stage;
    }
    /** Sets GEMMLowp output stage
     *
     * @param[in] output_stage Output stage to set
     */
    void set_gemmlowp_output_stage(const GEMMLowpOutputStageInfo &output_stage)
    {
        _gemmlowp_output_stage = output_stage;
    }
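    /* Usage sketch (illustrative only): configuring a quantize-down output stage for a
     * GEMMLowp computation. The field names assume the GEMMLowpOutputStageInfo
     * definition in Types.h; the numeric values are placeholders.
     *
     *   GEMMLowpOutputStageInfo stage{};
     *   stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
     *   stage.gemmlowp_multiplier = 1073741824; // example requantization multiplier
     *   stage.gemmlowp_shift      = 5;          // example right shift
     *   GEMMInfo info{};
     *   info.set_gemmlowp_output_stage(stage);
     */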
    /** Flag which specifies if a wider accumulator should be used.
     *
     * @return True if a wider accumulator has to be used
     */
    bool fp_mixed_precision() const
    {
        return _fp_mixed_precision;
    }
    /** Flag which specifies if a shorter accumulator should be used.
     *
     * @return True if a shorter accumulator has to be used
     */
    bool fast_math() const
    {
        return _fast_math;
    }
    /** Set fast math flag
     *
     * @param[in] fast_math Flag to set
     */
    void set_fast_math(bool fast_math)
    {
        _fast_math = fast_math;
    }
    /** Flag which specifies whether to broadcast the shape of the bias tensor.
     *
     * @return True if the shape of the bias tensor is to be broadcast.
     */
    bool broadcast_bias() const
    {
        return _broadcast_bias;
    }
    /** Flag which specifies whether A should be pre-transposed if supported.
     *
     * @return True if A should be pre-transposed, else false.
     */
    bool pretranspose_A() const
    {
        return _pretranspose_A;
    }
    /** Set pre-transpose A flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_A(bool flag)
    {
        _pretranspose_A = flag;
    }
    /** Flag which specifies whether B should be pre-transposed if supported.
     *
     * @return True if B should be pre-transposed, else false.
     */
    bool pretranspose_B() const
    {
        return _pretranspose_B;
    }
    /** Set pre-transpose B flag
     *
     * @param[in] flag Flag to set
     */
    void set_pretranspose_B(bool flag)
    {
        _pretranspose_B = flag;
    }
    /** Activation layer to apply after the matrix multiplication
     *
     * @return ActivationLayerInfo object
     */
    ActivationLayerInfo activation_info() const
    {
        return _activation_info;
    }
    /** Set activation layer info
     *
     * @param[in] activation_info ActivationLayerInfo object to set
     */
    void set_activation_info(const ActivationLayerInfo &activation_info)
    {
        _activation_info = activation_info;
    }
    /** Post operations to apply after the matrix multiplication
     *
     * @return experimental::PostOpList object
     */
    const experimental::PostOpList<ITensorInfo *> &post_ops() const
    {
        return _post_ops;
    }
    /** Set post ops
     *
     * @param[in] post_ops experimental::PostOpList object to set
     */
    void set_post_ops(const experimental::PostOpList<ITensorInfo *> &post_ops)
    {
        _post_ops = post_ops;
    }
    /** Flag which specifies if the GEMM operation is running fixed-format kernels.
     *
     * @return True if the GEMM operation is running fixed-format kernels, else false.
     */
    bool fixed_format() const
    {
        return _fixed_format;
    }

    /** Set fixed-format flag
     *
     * @param[in] fixed_format sets whether or not to use fixed-format kernels
     */
    void set_fixed_format(bool fixed_format)
    {
        _fixed_format = fixed_format;
    }

    /** Weight format to be used
     *
     * @return the arm_compute::WeightFormat enumeration currently set
     */
    arm_compute::WeightFormat weight_format() const
    {
        return _weight_format;
    }

    /** Set weight format to be used
     *
     * @param[in] weight_format arm_compute::WeightFormat enumeration
     */
    void set_weight_format(arm_compute::WeightFormat weight_format)
    {
        _weight_format = weight_format;
    }
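    /* Usage sketch (illustrative only): requesting fixed-format kernels. Using
     * arm_compute::WeightFormat::ANY here is an assumption; it asks the backend to
     * pick a suitable fixed weight memory format.
     *
     *   GEMMInfo info{};
     *   info.set_fixed_format(true);
     *   info.set_weight_format(arm_compute::WeightFormat::ANY);
     */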

private:
    bool                                    _is_a_reshaped;
    bool                                    _is_b_reshaped;
    bool                                    _reshape_b_only_on_first_run;
    int                                     _depth_output_gemm3d;
    bool                                    _reinterpret_input_as_3d;
    bool                                    _retain_internal_weights;
    GEMMLowpOutputStageInfo                 _gemmlowp_output_stage;
    bool                                    _fast_math;
    bool                                    _fp_mixed_precision;
    bool                                    _broadcast_bias;
    bool                                    _pretranspose_A;
    bool                                    _pretranspose_B;
    ActivationLayerInfo                     _activation_info;
    experimental::PostOpList<ITensorInfo *> _post_ops;
    bool                                    _fixed_format;
    arm_compute::WeightFormat               _weight_format;
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_GEMMINFO_H */