blob: ffb4ddd9d3ed5bb9f316676bfd6b946cd2390437 [file] [log] [blame]
Michalis Spyrou25f45a42018-08-08 12:53:05 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2018-2020 Arm Limited.
Michalis Spyrou25f45a42018-08-08 12:53:05 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouf4643372019-11-29 16:17:13 +000024#ifndef ARM_COMPUTE_LSTMPARAMS_H
25#define ARM_COMPUTE_LSTMPARAMS_H
Michalis Spyrou25f45a42018-08-08 12:53:05 +010026
27#include "arm_compute/core/IPyramid.h"
28#include "arm_compute/core/PyramidInfo.h"
29#include "arm_compute/core/Types.h"
30#include "arm_compute/runtime/Tensor.h"
31
32#include <cstddef>
33#include <memory>
34
35namespace arm_compute
36{
/** Container for the optional parameters of an LSTM layer.
 *
 * Groups the tensors that enable the non-CIFG, projection, peephole and layer-normalization
 * variants, together with the cell/projection clipping values, the intermediate matmul scales
 * and the hidden-state zero point/scale.
 *
 * All setters return a reference to this object so that calls can be chained.
 *
 * The stored tensor pointers are non-owning: the destructor releases nothing, so the
 * referenced tensors must outlive every use of this object.
 *
 * @note T is presumably a tensor or tensor-info type (e.g. ITensor / ITensorInfo) — confirm against callers.
 */
template <typename T>
class LSTMParams
{
public:
    /** Constructor
     *
     * All tensor pointers start as nullptr and all scalar parameters start at zero.
     * CIFG is enabled by default, while peephole, projection and layer normalization are disabled.
     */
    LSTMParams()
        : _input_to_input_weights(nullptr),
          _recurrent_to_input_weights(nullptr),
          _cell_to_input_weights(nullptr),
          _input_gate_bias(nullptr),
          _cell_to_forget_weights(nullptr),
          _cell_to_output_weights(nullptr),
          _projection_weights(nullptr),
          _projection_bias(nullptr),
          _input_layer_norm_weights(nullptr),
          _forget_layer_norm_weights(nullptr),
          _cell_layer_norm_weights(nullptr),
          _output_layer_norm_weights(nullptr),
          _cell_clip(0.f),
          _projection_clip(0.0f),
          _input_intermediate_scale(0.0f),
          _forget_intermediate_scale(0.0f),
          _cell_intermediate_scale(0.0f),
          _output_intermediate_scale(0.0f),
          _hidden_state_zero(0),
          _hidden_state_scale(0.0f),
          _has_peephole_opt(false),
          _has_projection(false),
          _has_cifg_opt(true),
          _use_layer_norm(false)
    {
    }
    /** Prevent instances of this class from being copied (the object only aliases externally-owned tensors) */
    LSTMParams(const LSTMParams &) = delete;
    /** Prevent instances of this class from being copied (the object only aliases externally-owned tensors) */
    LSTMParams &operator=(const LSTMParams &) = delete;
    /** Default destructor (does not release any of the referenced tensors) */
    ~LSTMParams() = default;
    /** Set CIFG tensor parameters.
     *
     * Calling this disables the CIFG optimization (has_cifg_opt() returns false afterwards).
     *
     * @param[in] input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
     * @param[in] cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
     * @param[in] input_gate_bias            1D bias tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
    {
        _input_to_input_weights     = input_to_input_weights;
        _recurrent_to_input_weights = recurrent_to_input_weights;
        _cell_to_input_weights      = cell_to_input_weights;
        _input_gate_bias            = input_gate_bias;
        _has_cifg_opt               = false;
        return *this;
    }
    /** Set projection tensor parameters.
     *
     * Calling this enables the projection (has_projection() returns true afterwards).
     *
     * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] projection_bias    1D bias tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when @p projection_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
    {
        _projection_weights = projection_weights;
        _projection_bias    = projection_bias;
        _has_projection     = true;
        return *this;
    }
    /** Set peephole tensor parameters.
     *
     * Calling this enables the peephole connections (has_peephole_opt() returns true afterwards).
     *
     * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
    {
        _cell_to_forget_weights = cell_to_forget_weights;
        _cell_to_output_weights = cell_to_output_weights;
        _has_peephole_opt       = true;
        return *this;
    }
    /** Set layer normalization tensor parameters.
     *
     * Calling this enables layer normalization (use_layer_norm() returns true afterwards).
     *
     * @param[in] input_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] cell_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
    {
        _input_layer_norm_weights  = input_layer_norm_weights;
        _forget_layer_norm_weights = forget_layer_norm_weights;
        _cell_layer_norm_weights   = cell_layer_norm_weights;
        _output_layer_norm_weights = output_layer_norm_weights;
        _use_layer_norm            = true;
        return *this;
    }

    /** Set cell clip value.
     *
     * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cell_clip_params(float cell_clip)
    {
        _cell_clip = cell_clip;
        return *this;
    }

    /** Set projection clip value.
     *
     * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_clip_params(float projection_clip)
    {
        _projection_clip = projection_clip;
        return *this;
    }

    /** Set the scales of the intermediate matmul results of each gate.
     *
     * @param[in] input_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     * @param[in] forget_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     * @param[in] cell_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     * @param[in] output_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
    {
        _input_intermediate_scale  = input_intermediate_scale;
        _forget_intermediate_scale = forget_intermediate_scale;
        _cell_intermediate_scale   = cell_intermediate_scale;
        _output_intermediate_scale = output_intermediate_scale;
        return *this;
    }

    /** Set hidden state zero and scale parameters.
     *
     * @param[in] hidden_state_zero  The zero point of the hidden state.
     * @param[in] hidden_state_scale The scale of the hidden state.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
    {
        _hidden_state_zero  = hidden_state_zero;
        _hidden_state_scale = hidden_state_scale;
        return *this;
    }

    /** @return Input-to-input weights, or nullptr if set_cifg_params() has not been called. */
    const T *input_to_input_weights() const
    {
        return _input_to_input_weights;
    }

    /** @return Recurrent-to-input weights, or nullptr if set_cifg_params() has not been called. */
    const T *recurrent_to_input_weights() const
    {
        return _recurrent_to_input_weights;
    }

    /** @return Cell-to-input weights; nullptr if unset or if set_cifg_params() was given nullptr. */
    T *cell_to_input_weights() const
    {
        return _cell_to_input_weights;
    }

    /** @return Input gate bias, or nullptr if set_cifg_params() has not been called. */
    const T *input_gate_bias() const
    {
        return _input_gate_bias;
    }

    /** @return Cell-to-forget (peephole) weights, or nullptr if set_peephole_params() has not been called. */
    T *cell_to_forget_weights() const
    {
        return _cell_to_forget_weights;
    }

    /** @return Cell-to-output (peephole) weights, or nullptr if set_peephole_params() has not been called. */
    T *cell_to_output_weights() const
    {
        return _cell_to_output_weights;
    }

    /** @return Projection weights, or nullptr if set_projection_params() has not been called. */
    const T *projection_weights() const
    {
        return _projection_weights;
    }

    /** @return Projection bias, or nullptr if set_projection_params() has not been called. */
    const T *projection_bias() const
    {
        return _projection_bias;
    }

    /** @return Input gate layer-normalization weights, or nullptr if set_layer_normalization_params() has not been called. */
    T *input_layer_norm_weights() const
    {
        return _input_layer_norm_weights;
    }

    /** @return Forget gate layer-normalization weights, or nullptr if set_layer_normalization_params() has not been called. */
    T *forget_layer_norm_weights() const
    {
        return _forget_layer_norm_weights;
    }

    /** @return Cell gate layer-normalization weights, or nullptr if set_layer_normalization_params() has not been called. */
    T *cell_layer_norm_weights() const
    {
        return _cell_layer_norm_weights;
    }

    /** @return Output gate layer-normalization weights, or nullptr if set_layer_normalization_params() has not been called. */
    T *output_layer_norm_weights() const
    {
        return _output_layer_norm_weights;
    }

    /** @return Cell clip value (0 unless set via set_cell_clip_params()). */
    float cell_clip() const
    {
        return _cell_clip;
    }

    /** @return Projection clip value (0 unless set via set_projection_clip_params()). */
    float projection_clip() const
    {
        return _projection_clip;
    }

    /** @return Scale of the intermediate matmul result at the input gate (0 unless set via set_matmul_scale_params()). */
    float input_intermediate_scale() const
    {
        return _input_intermediate_scale;
    }

    /** @return Scale of the intermediate matmul result at the forget gate (0 unless set via set_matmul_scale_params()). */
    float forget_intermediate_scale() const
    {
        return _forget_intermediate_scale;
    }

    /** @return Scale of the intermediate matmul result at the cell gate (0 unless set via set_matmul_scale_params()). */
    float cell_intermediate_scale() const
    {
        return _cell_intermediate_scale;
    }

    /** @return Scale of the intermediate matmul result at the output gate (0 unless set via set_matmul_scale_params()). */
    float output_intermediate_scale() const
    {
        return _output_intermediate_scale;
    }

    /** @return Zero point of the hidden state (0 unless set via set_hidden_state_params()). */
    int32_t hidden_state_zero() const
    {
        return _hidden_state_zero;
    }

    /** @return Scale of the hidden state (0 unless set via set_hidden_state_params()). */
    float hidden_state_scale() const
    {
        return _hidden_state_scale;
    }

    /** @return True once set_peephole_params() has been called. */
    bool has_peephole_opt() const
    {
        return _has_peephole_opt;
    }

    /** @return True once set_projection_params() has been called. */
    bool has_projection() const
    {
        return _has_projection;
    }

    /** @return True by default; false once set_cifg_params() has been called. */
    bool has_cifg_opt() const
    {
        return _has_cifg_opt;
    }

    /** @return True once set_layer_normalization_params() has been called. */
    bool use_layer_norm() const
    {
        return _use_layer_norm;
    }

private:
    const T *_input_to_input_weights;
    const T *_recurrent_to_input_weights;
    T       *_cell_to_input_weights;
    const T *_input_gate_bias;
    T       *_cell_to_forget_weights;
    T       *_cell_to_output_weights;
    const T *_projection_weights;
    const T *_projection_bias;
    T       *_input_layer_norm_weights;
    T       *_forget_layer_norm_weights;
    T       *_cell_layer_norm_weights;
    T       *_output_layer_norm_weights;
    float    _cell_clip;
    float    _projection_clip;
    float    _input_intermediate_scale;
    float    _forget_intermediate_scale;
    float    _cell_intermediate_scale;
    float    _output_intermediate_scale;
    int32_t  _hidden_state_zero;
    float    _hidden_state_scale;
    bool     _has_peephole_opt;
    bool     _has_projection;
    bool     _has_cifg_opt;
    bool     _use_layer_norm;
};
343}
Michalis Spyrouf4643372019-11-29 16:17:13 +0000344#endif /*ARM_COMPUTE_LSTMPARAMS_H */