blob: aedb9c0d46d96f6ea0b77dfc245f4fb20b6ef6eb [file] [log] [blame]
Michalis Spyrou25f45a42018-08-08 12:53:05 +01001/*
Georgios Pinitasc6f95102021-03-30 10:03:01 +01002 * Copyright (c) 2018-2021 Arm Limited.
Michalis Spyrou25f45a42018-08-08 12:53:05 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouf4643372019-11-29 16:17:13 +000024#ifndef ARM_COMPUTE_LSTMPARAMS_H
25#define ARM_COMPUTE_LSTMPARAMS_H
Michalis Spyrou25f45a42018-08-08 12:53:05 +010026
Michalis Spyrou25f45a42018-08-08 12:53:05 +010027#include "arm_compute/core/Types.h"
28#include "arm_compute/runtime/Tensor.h"
29
30#include <cstddef>
31#include <memory>
32
namespace arm_compute
{
/** Container for the optional tensors and scalar settings of an LSTM layer.
 *
 * Only non-owning pointers are stored: nothing is allocated here and the defaulted destructor does not
 * release the referenced tensors, so ownership stays with the caller. Every setter returns a reference
 * to this object so that calls can be chained.
 *
 * @tparam T Tensor type the stored pointers refer to.
 */
template <typename T>
class LSTMParams
{
public:
    /** Default constructor.
     *
     * No optional tensor is set (every pointer is nullptr), has_cifg_opt() is true, projection, peephole and
     * layer normalization are disabled, and every scalar parameter (clips, intermediate scales, hidden state
     * zero point and scale) is zero.
     */
    LSTMParams() = default;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams(const LSTMParams &) = delete;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams &operator=(const LSTMParams &) = delete;
    /** Default destructor */
    ~LSTMParams() = default;
    /** Set the input gate tensor parameters, which are needed when CIFG is not used.
     *
     * Calling this function turns the CIFG optimization off: has_cifg_opt() returns false afterwards.
     *
     * @param[in] input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
     * @param[in] cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
     * @param[in] input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
    {
        _input_to_input_weights     = input_to_input_weights;
        _recurrent_to_input_weights = recurrent_to_input_weights;
        _cell_to_input_weights      = cell_to_input_weights;
        _input_gate_bias            = input_gate_bias;
        _has_cifg_opt               = false;
        return *this;
    }
    /** Set projection tensor parameters and enable projection (has_projection() returns true afterwards).
     *
     * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] projection_bias    1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when @p projection_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
    {
        _projection_weights = projection_weights;
        _projection_bias    = projection_bias;
        _has_projection     = true;
        return *this;
    }
    /** Set peephole tensor parameters and enable peephole connections (has_peephole_opt() returns true afterwards).
     *
     * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
    {
        _cell_to_forget_weights = cell_to_forget_weights;
        _cell_to_output_weights = cell_to_output_weights;
        _has_peephole_opt       = true;
        return *this;
    }
    /** Set layer normalization tensor parameters and enable layer normalization (use_layer_norm() returns true afterwards).
     *
     * @param[in] input_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] cell_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
    {
        _input_layer_norm_weights  = input_layer_norm_weights;
        _forget_layer_norm_weights = forget_layer_norm_weights;
        _cell_layer_norm_weights   = cell_layer_norm_weights;
        _output_layer_norm_weights = output_layer_norm_weights;
        _use_layer_norm            = true;
        return *this;
    }

    /** Set cell clip value.
     *
     * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cell_clip_params(float cell_clip)
    {
        _cell_clip = cell_clip;
        return *this;
    }

    /** Set projection clip value.
     *
     * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_clip_params(float projection_clip)
    {
        _projection_clip = projection_clip;
        return *this;
    }

    /** Set the scales of the intermediate matmul results of each gate.
     *
     * @param[in] input_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     * @param[in] forget_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     * @param[in] cell_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     * @param[in] output_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
    {
        _input_intermediate_scale  = input_intermediate_scale;
        _forget_intermediate_scale = forget_intermediate_scale;
        _cell_intermediate_scale   = cell_intermediate_scale;
        _output_intermediate_scale = output_intermediate_scale;
        return *this;
    }

    /** Set hidden state zero and scale parameters.
     *
     * @param[in] hidden_state_zero  The zero point of the hidden state.
     * @param[in] hidden_state_scale The scale of the hidden state.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
    {
        _hidden_state_zero  = hidden_state_zero;
        _hidden_state_scale = hidden_state_scale;
        return *this;
    }

    /** @return Input-to-input weights passed to set_cifg_params(), or nullptr if it was never called. */
    const T *input_to_input_weights() const
    {
        return _input_to_input_weights;
    }

    /** @return Recurrent-to-input weights passed to set_cifg_params(), or nullptr if it was never called. */
    const T *recurrent_to_input_weights() const
    {
        return _recurrent_to_input_weights;
    }

    /** @return Cell-to-input weights passed to set_cifg_params() (may itself be nullptr), or nullptr if it was never called. */
    T *cell_to_input_weights() const
    {
        return _cell_to_input_weights;
    }

    /** @return Input gate bias passed to set_cifg_params(), or nullptr if it was never called. */
    const T *input_gate_bias() const
    {
        return _input_gate_bias;
    }

    /** @return Cell-to-forget weights passed to set_peephole_params(), or nullptr if it was never called. */
    T *cell_to_forget_weights() const
    {
        return _cell_to_forget_weights;
    }

    /** @return Cell-to-output weights passed to set_peephole_params(), or nullptr if it was never called. */
    T *cell_to_output_weights() const
    {
        return _cell_to_output_weights;
    }

    /** @return Projection weights passed to set_projection_params(), or nullptr if it was never called. */
    const T *projection_weights() const
    {
        return _projection_weights;
    }

    /** @return Projection bias passed to set_projection_params(), or nullptr if it was never called. */
    const T *projection_bias() const
    {
        return _projection_bias;
    }

    /** @return Input gate layer normalization weights passed to set_layer_normalization_params(), or nullptr if it was never called. */
    T *input_layer_norm_weights() const
    {
        return _input_layer_norm_weights;
    }

    /** @return Forget gate layer normalization weights passed to set_layer_normalization_params(), or nullptr if it was never called. */
    T *forget_layer_norm_weights() const
    {
        return _forget_layer_norm_weights;
    }

    /** @return Cell gate layer normalization weights passed to set_layer_normalization_params(), or nullptr if it was never called. */
    T *cell_layer_norm_weights() const
    {
        return _cell_layer_norm_weights;
    }

    /** @return Output gate layer normalization weights passed to set_layer_normalization_params(), or nullptr if it was never called. */
    T *output_layer_norm_weights() const
    {
        return _output_layer_norm_weights;
    }

    /** @return Cell clip value (0 unless set with set_cell_clip_params()). */
    float cell_clip() const
    {
        return _cell_clip;
    }

    /** @return Projection clip value (0 unless set with set_projection_clip_params()). */
    float projection_clip() const
    {
        return _projection_clip;
    }

    /** @return Input gate intermediate matmul scale (0 unless set with set_matmul_scale_params()). */
    float input_intermediate_scale() const
    {
        return _input_intermediate_scale;
    }

    /** @return Forget gate intermediate matmul scale (0 unless set with set_matmul_scale_params()). */
    float forget_intermediate_scale() const
    {
        return _forget_intermediate_scale;
    }

    /** @return Cell gate intermediate matmul scale (0 unless set with set_matmul_scale_params()). */
    float cell_intermediate_scale() const
    {
        return _cell_intermediate_scale;
    }

    /** @return Output gate intermediate matmul scale (0 unless set with set_matmul_scale_params()). */
    float output_intermediate_scale() const
    {
        return _output_intermediate_scale;
    }

    /** @return Hidden state zero point (0 unless set with set_hidden_state_params()). */
    int32_t hidden_state_zero() const
    {
        return _hidden_state_zero;
    }

    /** @return Hidden state scale (0 unless set with set_hidden_state_params()). */
    float hidden_state_scale() const
    {
        return _hidden_state_scale;
    }

    /** @return True once set_peephole_params() has been called. */
    bool has_peephole_opt() const
    {
        return _has_peephole_opt;
    }

    /** @return True once set_projection_params() has been called. */
    bool has_projection() const
    {
        return _has_projection;
    }

    /** @return True until set_cifg_params() is called, false afterwards. */
    bool has_cifg_opt() const
    {
        return _has_cifg_opt;
    }

    /** @return True once set_layer_normalization_params() has been called. */
    bool use_layer_norm() const
    {
        return _use_layer_norm;
    }

private:
    // Non-owning tensor pointers; nullptr until the matching setter is called.
    const T *_input_to_input_weights{ nullptr };
    const T *_recurrent_to_input_weights{ nullptr };
    T       *_cell_to_input_weights{ nullptr };
    const T *_input_gate_bias{ nullptr };
    T       *_cell_to_forget_weights{ nullptr };
    T       *_cell_to_output_weights{ nullptr };
    const T *_projection_weights{ nullptr };
    const T *_projection_bias{ nullptr };
    T       *_input_layer_norm_weights{ nullptr };
    T       *_forget_layer_norm_weights{ nullptr };
    T       *_cell_layer_norm_weights{ nullptr };
    T       *_output_layer_norm_weights{ nullptr };
    // Scalar parameters; all zero until explicitly set.
    float   _cell_clip{ 0.f };
    float   _projection_clip{ 0.f };
    float   _input_intermediate_scale{ 0.f };
    float   _forget_intermediate_scale{ 0.f };
    float   _cell_intermediate_scale{ 0.f };
    float   _output_intermediate_scale{ 0.f };
    int32_t _hidden_state_zero{ 0 };
    float   _hidden_state_scale{ 0.f };
    // Feature flags; CIFG starts enabled and is switched off by set_cifg_params().
    bool _has_peephole_opt{ false };
    bool _has_projection{ false };
    bool _has_cifg_opt{ true };
    bool _use_layer_norm{ false };
};
}
Michalis Spyrouf4643372019-11-29 16:17:13 +0000342#endif /*ARM_COMPUTE_LSTMPARAMS_H */