blob: e21ddd7af1f54a6d2d7312fbede195202438081d [file] [log] [blame]
Michalis Spyrou25f45a42018-08-08 12:53:05 +01001/*
Michele Di Giorgio25d97752020-03-04 18:08:47 +00002 * Copyright (c) 2018-2020 ARM Limited.
Michalis Spyrou25f45a42018-08-08 12:53:05 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouf4643372019-11-29 16:17:13 +000024#ifndef ARM_COMPUTE_LSTMPARAMS_H
25#define ARM_COMPUTE_LSTMPARAMS_H
Michalis Spyrou25f45a42018-08-08 12:53:05 +010026
27#include "arm_compute/core/IPyramid.h"
28#include "arm_compute/core/PyramidInfo.h"
29#include "arm_compute/core/Types.h"
30#include "arm_compute/runtime/Tensor.h"
31
32#include <cstddef>
33#include <memory>
34
35namespace arm_compute
36{
37template <typename T>
38class LSTMParams
39{
40public:
41 /** Constructor */
42 LSTMParams()
Michele Di Giorgio25d97752020-03-04 18:08:47 +000043 : _input_to_input_weights(nullptr),
44 _recurrent_to_input_weights(nullptr),
45 _cell_to_input_weights(nullptr),
46 _input_gate_bias(nullptr),
47 _cell_to_forget_weights(nullptr),
48 _cell_to_output_weights(nullptr),
49 _projection_weights(nullptr),
50 _projection_bias(nullptr),
51 _input_layer_norm_weights(nullptr),
52 _forget_layer_norm_weights(nullptr),
53 _cell_layer_norm_weights(nullptr),
54 _output_layer_norm_weights(nullptr),
55 _cell_clip(0.f),
56 _projection_clip(0.0f),
Michele Di Giorgio47a89902020-03-09 19:32:33 +000057 _input_intermediate_scale(0.0f),
58 _forget_intermediate_scale(0.0f),
59 _cell_intermediate_scale(0.0f),
60 _output_intermediate_scale(0.0f),
Michele Di Giorgio25d97752020-03-04 18:08:47 +000061 _hidden_state_zero(0.0f),
62 _hidden_state_scale(0),
63 _has_peephole_opt(false),
64 _has_projection(false),
65 _has_cifg_opt(true),
66 _use_layer_norm(false)
Michalis Spyrou25f45a42018-08-08 12:53:05 +010067 {
68 }
69 /** Prevent instances of this class from being copied (As this class contains pointers) */
70 LSTMParams(const LSTMParams &) = delete;
71 /** Prevent instances of this class from being copied (As this class contains pointers) */
72 LSTMParams &operator=(const LSTMParams &) = delete;
73 /** Default destructor */
74 ~LSTMParams() = default;
75 /** Set CIFG tensor parameters.
76 *
Michele Di Giorgio47a89902020-03-09 19:32:33 +000077 * @param[in] input_to_input_weights 2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
Michalis Spyrou25f45a42018-08-08 12:53:05 +010078 * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
79 * @param[in] cell_to_input_weights 1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
Michele Di Giorgio47a89902020-03-09 19:32:33 +000080 * @param[in] input_gate_bias 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8
Michalis Spyrou25f45a42018-08-08 12:53:05 +010081 *
82 * @return Reference to this LSTMParams object
83 */
84 LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, const T *cell_to_input_weights, const T *input_gate_bias)
85 {
86 _input_to_input_weights = input_to_input_weights;
87 _recurrent_to_input_weights = recurrent_to_input_weights;
88 _cell_to_input_weights = cell_to_input_weights;
89 _input_gate_bias = input_gate_bias;
90 _has_cifg_opt = false;
91 return *this;
92 }
93 /** Set projection tensor parameters.
94 *
Michele Di Giorgio47a89902020-03-09 19:32:33 +000095 * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Data types supported: QSYMM8/F16/F32.
96 * @param[in] projection_bias 1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when @p input_to_input_weights is QSYMM8.
Michalis Spyrou25f45a42018-08-08 12:53:05 +010097 *
98 * @return Reference to this LSTMParams object
99 */
100 LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
101 {
102 _projection_weights = projection_weights;
103 _projection_bias = projection_bias;
104 _has_projection = true;
105 return *this;
106 }
107 /** Set peephole tensor parameters.
108 *
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000109 * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data type supported: Data types supported: QSYMM16/F16/F32.
110 * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100111 *
112 * @return Reference to this LSTMParams object
113 */
114 LSTMParams &set_peephole_params(const T *cell_to_forget_weights, const T *cell_to_output_weights)
115 {
116 _cell_to_forget_weights = cell_to_forget_weights;
117 _cell_to_output_weights = cell_to_output_weights;
118 _has_peephole_opt = true;
119 return *this;
120 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100121 /** Set layer normalization tensor parameters.
122 *
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000123 * @param[in] input_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Data types supported: QSYMM16/F16/F32.
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100124 * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
125 * @param[in] cell_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
126 * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
127 *
128 * @return Reference to this LSTMParams object
129 */
130 LSTMParams &set_layer_normalization_params(const T *input_layer_norm_weights, const T *forget_layer_norm_weights,
131 const T *cell_layer_norm_weights, const T *output_layer_norm_weights)
132 {
133 _input_layer_norm_weights = input_layer_norm_weights;
134 _forget_layer_norm_weights = forget_layer_norm_weights;
135 _cell_layer_norm_weights = cell_layer_norm_weights;
136 _output_layer_norm_weights = output_layer_norm_weights;
137 _use_layer_norm = true;
138 return *this;
139 }
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100140
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000141 /** Set cell clip value.
142 *
143 * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
144 *
145 * @return Reference to this LSTMParams object
146 */
147 LSTMParams &set_cell_clip_params(float cell_clip)
148 {
149 _cell_clip = cell_clip;
150 return *this;
151 }
152
153 /** Set projection clip value.
154 *
155 * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled.
156 *
157 * @return Reference to this LSTMParams object
158 */
159 LSTMParams &set_projection_clip_params(float projection_clip)
160 {
161 _projection_clip = projection_clip;
162 return *this;
163 }
164
165 /** Set scale of the intermediate results of matmul of each layer parameters.
166 *
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000167 * @param[in] input_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
168 * @param[in] forget_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
169 * @param[in] cell_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
170 * @param[in] output_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000171 *
172 * @return Reference to this LSTMParams object
173 */
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000174 LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000175 {
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000176 _input_intermediate_scale = input_intermediate_scale;
177 _forget_intermediate_scale = forget_intermediate_scale;
178 _cell_intermediate_scale = cell_intermediate_scale;
179 _output_intermediate_scale = output_intermediate_scale;
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000180 return *this;
181 }
182
183 /** Set hidden state zero and scale parameters.
184 *
185 * @param[in] hidden_state_zero The zero point of the hidden state.
186 * @param[in] hidden_state_scale The scale of the hidden state.
187 *
188 * @return Reference to this LSTMParams object
189 */
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000190 LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000191 {
192 _hidden_state_zero = hidden_state_zero;
193 _hidden_state_scale = hidden_state_scale;
194 return *this;
195 }
196
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100197 const T *input_to_input_weights() const
198 {
199 return _input_to_input_weights;
200 }
201
202 const T *recurrent_to_input_weights() const
203 {
204 return _recurrent_to_input_weights;
205 }
206
207 const T *cell_to_input_weights() const
208 {
209 return _cell_to_input_weights;
210 }
211
212 const T *input_gate_bias() const
213 {
214 return _input_gate_bias;
215 }
216
217 const T *cell_to_forget_weights() const
218 {
219 return _cell_to_forget_weights;
220 }
221
222 const T *cell_to_output_weights() const
223 {
224 return _cell_to_output_weights;
225 }
226
227 const T *projection_weights() const
228 {
229 return _projection_weights;
230 }
231
232 const T *projection_bias() const
233 {
234 return _projection_bias;
235 }
236
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100237 const T *input_layer_norm_weights() const
238 {
239 return _input_layer_norm_weights;
240 }
241
242 const T *forget_layer_norm_weights() const
243 {
244 return _forget_layer_norm_weights;
245 }
246
247 const T *cell_layer_norm_weights() const
248 {
249 return _cell_layer_norm_weights;
250 }
251
252 const T *output_layer_norm_weights() const
253 {
254 return _output_layer_norm_weights;
255 }
256
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000257 float cell_clip() const
258 {
259 return _cell_clip;
260 }
261
262 float projection_clip() const
263 {
264 return _projection_clip;
265 }
266
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000267 float input_intermediate_scale() const
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000268 {
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000269 return _input_intermediate_scale;
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000270 }
271
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000272 float forget_intermediate_scale() const
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000273 {
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000274 return _forget_intermediate_scale;
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000275 }
276
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000277 float cell_intermediate_scale() const
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000278 {
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000279 return _cell_intermediate_scale;
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000280 }
281
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000282 float output_intermediate_scale() const
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000283 {
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000284 return _output_intermediate_scale;
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000285 }
286
287 int32_t hidden_state_zero() const
288 {
289 return _hidden_state_zero;
290 }
291
292 float hidden_state_scale() const
293 {
294 return _hidden_state_scale;
295 }
296
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100297 bool has_peephole_opt() const
298 {
299 return _has_peephole_opt;
300 }
301
302 bool has_projection() const
303 {
304 return _has_projection;
305 }
306
307 bool has_cifg_opt() const
308 {
309 return _has_cifg_opt;
310 }
311
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100312 bool use_layer_norm() const
313 {
314 return _use_layer_norm;
315 }
316
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100317private:
318 const T *_input_to_input_weights;
319 const T *_recurrent_to_input_weights;
320 const T *_cell_to_input_weights;
321 const T *_input_gate_bias;
322 const T *_cell_to_forget_weights;
323 const T *_cell_to_output_weights;
324 const T *_projection_weights;
325 const T *_projection_bias;
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100326 const T *_input_layer_norm_weights;
327 const T *_forget_layer_norm_weights;
328 const T *_cell_layer_norm_weights;
329 const T *_output_layer_norm_weights;
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000330 float _cell_clip;
331 float _projection_clip;
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000332 float _input_intermediate_scale;
333 float _forget_intermediate_scale;
334 float _cell_intermediate_scale;
335 float _output_intermediate_scale;
Michele Di Giorgio25d97752020-03-04 18:08:47 +0000336 float _hidden_state_zero;
337 int32_t _hidden_state_scale;
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100338 bool _has_peephole_opt;
339 bool _has_projection;
340 bool _has_cifg_opt;
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100341 bool _use_layer_norm;
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100342};
343}
Michalis Spyrouf4643372019-11-29 16:17:13 +0000344#endif /*ARM_COMPUTE_LSTMPARAMS_H */