blob: 41f9c3d70071db3fd9e63d317a84ea88a9d696f1 [file] [log] [blame]
/*
 * Copyright (c) 2019-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/NEON/functions/NELSTMLayerQuantized.h"
25
26#include "arm_compute/core/Utils.h"
Michalis Spyrouba27e442019-05-28 10:04:57 +010027#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010028#include "arm_compute/core/Validate.h"
29
ramelg01cbbb0382021-09-17 17:36:57 +010030#include "src/common/utils/Log.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010031#include "src/core/helpers/AutoConfiguration.h"
Michalis Spyrouba27e442019-05-28 10:04:57 +010032
33#include <cmath>
34#include <memory>
35#include <tuple>
36
37namespace arm_compute
38{
namespace
{
// Fixed quantization schemes used throughout the quantized LSTM layer.
// The qsymm16 infos differ only in how many integer bits the 16-bit
// symmetric fixed-point format reserves (scale = 2^int_bits / 32768).
const QuantizationInfo qasymm(1.f / 128.f, 128);   // QASYMM8 input/output state: scale 1/128, zero-point 128
const QuantizationInfo qsymm_3(8.f / 32768.f, 0);  // qsymm16 with 3 integer bits (scale 2^-12) - gate pre-activations
const QuantizationInfo qsymm_4(16.f / 32768.f, 0); // qsymm16 with 4 integer bits (scale 2^-11) - cell state
const QuantizationInfo qsymm_0(1.f / 32768.f, 0);  // qsymm16 with 0 integer bits (scale 2^-15) - activation outputs in [-1, 1)
} // namespace
Michalis Spyrouebcebf12020-10-21 00:04:14 +010047NELSTMLayerQuantized::~NELSTMLayerQuantized() = default;
Michalis Spyrouba27e442019-05-28 10:04:57 +010048
// Constructor: takes (shared) ownership of the memory manager for the
// internal memory group and default-constructs every sub-function,
// cached weight/bias pointer (null until configure()) and intermediate
// tensor. No wiring happens here - all of it is done in configure().
NELSTMLayerQuantized::NELSTMLayerQuantized(std::shared_ptr<IMemoryManager> memory_manager)
    : _memory_group(std::move(memory_manager)),
      _gemmlowp(),
      _output_stage(),
      _transpose_weights(),
      _concat_input_weights(),
      _concat_recurrent_weights(),
      _concat_weights(),
      _concat_inputs(),
      _concat_bias(),
      _sigmoid_forget_gate(),
      _sigmoid_input_gate(),
      _sigmoid_output_gate(),
      _tanh_modulation_gate(),
      _tanh_output_state(),
      _add1(),
      _add2(),
      _mul1(),
      _mul2(),
      _mul3(),
      _slice_input_tensor(),
      _slice_forget_tensor(),
      _slice_cell_tensor(),
      _slice_output_tensor(),
      _dequantize(),
      _quantize(),
      _input_to_input_weights(nullptr),
      _input_to_forget_weights(nullptr),
      _input_to_cell_weights(nullptr),
      _input_to_output_weights(nullptr),
      _recurrent_to_input_weights(nullptr),
      _recurrent_to_forget_weights(nullptr),
      _recurrent_to_cell_weights(nullptr),
      _recurrent_to_output_weights(nullptr),
      _input_gate_bias(nullptr),
      _forget_gate_bias(nullptr),
      _cell_bias(nullptr),
      _output_gate_bias(nullptr),
      _recurrent_weights(),
      _input_weights(),
      _weights(),
      _input(),
      _weights_transposed(),
      _output_highp(),
      _output_lowp(),
      _bias(),
      _forget_gate_input(),
      _input_gate_input(),
      _output_gate_input(),
      _input_modulation_gate_input(),
      _forget_gate_output(),
      _input_gate_output(),
      _output_gate_output(),
      _input_modulation_gate_output(),
      _cell_state1(),
      _cell_state2(),
      _output_state_tmp(),
      _output_state_out_symm(),
      _output_state_out_f32(),
      _is_prepared(false)
{
}
111
// Configure the quantized LSTM cell pipeline.
//
// Sizes as read below: input is (input_size, batch_size); output_size is
// taken from input_to_input_weights->dimension(1); the states produced are
// (batch_size, output_size) in shape terms used by auto_init_if_empty.
//
// Pipeline built here:
//   1. Concatenate the four per-gate weight matrices (and the input with the
//      previous output state) so a single GEMMLowp run computes the
//      pre-activations of all four gates at once.
//   2. Requantize the S32 GEMM accumulators to QSYMM16 (output stage).
//   3. Slice the result into the four gate tensors and apply sigmoid/tanh.
//   4. Combine gates into the new cell state (QSYMM16, qsymm_4) and the new
//      output state, which is requantized back to QASYMM8 via an F32
//      dequantize/quantize round-trip.
void NELSTMLayerQuantized::configure(const ITensor *input,
                                     const ITensor *input_to_input_weights,
                                     const ITensor *input_to_forget_weights,
                                     const ITensor *input_to_cell_weights,
                                     const ITensor *input_to_output_weights,
                                     const ITensor *recurrent_to_input_weights,
                                     const ITensor *recurrent_to_forget_weights,
                                     const ITensor *recurrent_to_cell_weights,
                                     const ITensor *recurrent_to_output_weights,
                                     const ITensor *input_gate_bias,
                                     const ITensor *forget_gate_bias,
                                     const ITensor *cell_bias,
                                     const ITensor *output_gate_bias,
                                     ITensor       *cell_state_in,
                                     const ITensor *output_state_in,
                                     ITensor       *cell_state_out,
                                     ITensor       *output_state_out)
{
    // Fail fast on null pointers, then run the full static validation on the
    // tensor infos before any sub-function is configured.
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights,
                                 input_to_output_weights, recurrent_to_input_weights, recurrent_to_forget_weights,
                                 recurrent_to_cell_weights, recurrent_to_output_weights, input_gate_bias,
                                 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
                                 cell_state_out, output_state_out);

    ARM_COMPUTE_ERROR_THROW_ON(NELSTMLayerQuantized::validate(
        input->info(), input_to_input_weights->info(), input_to_forget_weights->info(), input_to_cell_weights->info(),
        input_to_output_weights->info(), recurrent_to_input_weights->info(), recurrent_to_forget_weights->info(),
        recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(), input_gate_bias->info(),
        forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(), cell_state_in->info(),
        output_state_in->info(), cell_state_out->info(), output_state_out->info()));

    ARM_COMPUTE_LOG_PARAMS(input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights,
                           input_to_output_weights, recurrent_to_input_weights, recurrent_to_forget_weights,
                           recurrent_to_cell_weights, recurrent_to_output_weights, input_gate_bias, forget_gate_bias,
                           cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out,
                           output_state_out);

    // Problem sizes, read from the input tensor and the input-to-input weights.
    const int input_size  = input->info()->dimension(0);
    const int batch_size  = input->info()->dimension(1);
    const int output_size = input_to_input_weights->info()->dimension(1);

    const QuantizationInfo qweights = input_to_input_weights->info()->quantization_info(); // Weights quantization

    // Initialize the output tensors' infos if the caller left them empty.
    auto_init_if_empty(*cell_state_out->info(),
                       TensorInfo(TensorShape(batch_size, output_size), 1, DataType::QSYMM16, qsymm_4));
    auto_init_if_empty(*output_state_out->info(),
                       TensorInfo(TensorShape(batch_size, output_size), 1, DataType::QASYMM8, qasymm));

    // Cache the caller's weight/bias pointers.
    // NOTE(review): presumably re-read when the function is prepared/run -
    // confirm against prepare()/run(), which are outside this view.
    _input_to_input_weights      = input_to_input_weights;
    _input_to_forget_weights     = input_to_forget_weights;
    _input_to_cell_weights       = input_to_cell_weights;
    _input_to_output_weights     = input_to_output_weights;
    _recurrent_to_input_weights  = recurrent_to_input_weights;
    _recurrent_to_forget_weights = recurrent_to_forget_weights;
    _recurrent_to_cell_weights   = recurrent_to_cell_weights;
    _recurrent_to_output_weights = recurrent_to_output_weights;
    _input_gate_bias             = input_gate_bias;
    _forget_gate_bias            = forget_gate_bias;
    _cell_bias                   = cell_bias;
    _output_gate_bias            = output_gate_bias;

    // Weights concatenation: stack the four per-gate matrices along Y so one
    // GEMM produces all four gates' pre-activations (4 * output_size rows).
    std::vector<const ITensor *> inputs_weights_vector{input_to_input_weights, input_to_forget_weights,
                                                       input_to_cell_weights, input_to_output_weights};
    std::vector<const ITensor *> recurrent_weights_vector{recurrent_to_input_weights, recurrent_to_forget_weights,
                                                          recurrent_to_cell_weights, recurrent_to_output_weights};

    _input_weights.allocator()->init(
        TensorInfo(TensorShape(input_size, 4 * output_size), 1, DataType::QASYMM8, qweights));
    _concat_input_weights.configure(inputs_weights_vector, &_input_weights, Window::DimY);

    _recurrent_weights.allocator()->init(
        TensorInfo(TensorShape(output_size, 4 * output_size), 1, DataType::QASYMM8, qweights));
    _concat_recurrent_weights.configure(recurrent_weights_vector, &_recurrent_weights, Window::DimY);

    // Join recurrent and input weights along X, then transpose the combined
    // matrix into the layout the GEMM expects.
    std::vector<const ITensor *> weights_vector{&_recurrent_weights, &_input_weights};
    _weights.allocator()->init(
        TensorInfo(TensorShape(output_size + input_size, 4 * output_size), 1, DataType::QASYMM8, qweights));
    _concat_weights.configure(weights_vector, &_weights, Window::DimX);
    _transpose_weights.configure(&_weights, &_weights_transposed);

    // Input concatenation: current input and previous output state along X.
    // _input is memory-managed: manage() before configure, allocate() after
    // the last consumer is configured.
    std::vector<const ITensor *> input_vector{input, output_state_in};
    _memory_group.manage(&_input);
    _input.allocator()->init(
        TensorInfo(TensorShape(output_size + input_size, batch_size), 1, DataType::QASYMM8, qasymm));
    _concat_inputs.configure(input_vector, &_input, Window::DimX);

    // Bias concatenation: the four S32 gate biases back to back (4 * output_size).
    std::vector<const ITensor *> bias_vector{input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias};
    _bias.allocator()->init(TensorInfo(TensorShape(4 * output_size), 1, DataType::S32));
    _concat_bias.configure(bias_vector, &_bias, Window::DimX);

    // Invert the offset for gemmlowp (flipped back once the GEMM is configured).
    _input.info()->set_quantization_info(QuantizationInfo(qasymm.uniform().scale, -qasymm.uniform().offset));
    _weights_transposed.info()->set_quantization_info(
        QuantizationInfo(qweights.uniform().scale, -qweights.uniform().offset));

    // Run gemmlowp: QASYMM8 x QASYMM8 -> S32 accumulators, one row block per gate.
    _memory_group.manage(&_output_highp);
    _output_highp.allocator()->init(TensorInfo(TensorShape(4 * output_size, batch_size), 1, DataType::S32));
    _gemmlowp.configure(&_input, &_weights_transposed, nullptr, &_output_highp);
    _input.allocator()->allocate();

    // Set the offset back
    _input.info()->set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
    _weights_transposed.info()->set_quantization_info(
        QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));

    // Requantize the S32 accumulators to QSYMM16 with 3 integer bits:
    // multiplier = (input_scale * weights_scale) / output_scale, where
    // output_scale = 2^-12 (qsymm_3) - hence the 4096 factor below.
    _output_lowp.allocator()->init(TensorInfo(_output_highp.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_3));

    const float multiplier        = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
    int32_t     output_multiplier = 0;
    int32_t     output_shift      = 0;
    quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);

    _memory_group.manage(&_output_lowp);

    // Fixed-point requantization stage, adding the concatenated bias.
    GEMMLowpOutputStageInfo info;
    info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    info.gemmlowp_multiplier = output_multiplier;
    info.gemmlowp_shift      = output_shift;
    info.output_data_type    = DataType::QSYMM16;
    _output_stage.configure(&_output_highp, &_bias, &_output_lowp, info);
    _output_highp.allocator()->allocate();
    _bias.allocator()->allocate();

    // Get the gate tensors: slice the (4 * output_size)-wide requantized
    // result into the four per-gate pre-activation tensors. With a single
    // batch the slices are 1-D; otherwise 2-D.
    if (batch_size > 1)
    {
        _memory_group.manage(&_input_gate_input);
        _slice_input_tensor.configure(&_output_lowp, &_input_gate_input, {0, 0}, {output_size, batch_size});
        _memory_group.manage(&_forget_gate_input);
        _slice_forget_tensor.configure(&_output_lowp, &_forget_gate_input, {output_size, 0},
                                       {2 * output_size, batch_size});
        _memory_group.manage(&_input_modulation_gate_input);
        _slice_cell_tensor.configure(&_output_lowp, &_input_modulation_gate_input, {2 * output_size, 0},
                                     {3 * output_size, batch_size});
        _memory_group.manage(&_output_gate_input);
        _slice_output_tensor.configure(&_output_lowp, &_output_gate_input, {3 * output_size, 0},
                                       {4 * output_size, batch_size});
        _output_lowp.allocator()->allocate();
    }
    else
    {
        _memory_group.manage(&_input_gate_input);
        _slice_input_tensor.configure(&_output_lowp, &_input_gate_input, {0}, {output_size});
        _memory_group.manage(&_forget_gate_input);
        _slice_forget_tensor.configure(&_output_lowp, &_forget_gate_input, {output_size}, {2 * output_size});
        _memory_group.manage(&_input_modulation_gate_input);
        _slice_cell_tensor.configure(&_output_lowp, &_input_modulation_gate_input, {2 * output_size},
                                     {3 * output_size});
        _memory_group.manage(&_output_gate_input);
        _slice_output_tensor.configure(&_output_lowp, &_output_gate_input, {3 * output_size}, {4 * output_size});
        _output_lowp.allocator()->allocate();
    }

    // Forget gate: f = sigmoid(slice), output quantized to qsymm_0 ([-1, 1)).
    _memory_group.manage(&_forget_gate_output);
    _forget_gate_output.allocator()->init(
        TensorInfo(_forget_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _sigmoid_forget_gate.configure(&_forget_gate_input, &_forget_gate_output,
                                   ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    _forget_gate_input.allocator()->allocate();

    // Input gate: i = sigmoid(slice).
    _memory_group.manage(&_input_gate_output);
    _input_gate_output.allocator()->init(
        TensorInfo(_input_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _sigmoid_input_gate.configure(&_input_gate_input, &_input_gate_output,
                                  ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    _input_gate_input.allocator()->allocate();

    // Input modulation gate equation: g = tanh(slice).
    _memory_group.manage(&_input_modulation_gate_output);
    _input_modulation_gate_output.allocator()->init(
        TensorInfo(_input_modulation_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _tanh_modulation_gate.configure(&_input_modulation_gate_input, &_input_modulation_gate_output,
                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f));
    _input_modulation_gate_input.allocator()->allocate();

    // Output gate: o = sigmoid(slice).
    _memory_group.manage(&_output_gate_output);
    _output_gate_output.allocator()->init(
        TensorInfo(_output_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _sigmoid_output_gate.configure(&_output_gate_input, &_output_gate_output,
                                   ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    _output_gate_input.allocator()->allocate();

    // Long term memory: c_out = f * c_in + i * g (saturating, qsymm_4).
    _memory_group.manage(&_cell_state1);
    _cell_state1.allocator()->init(
        TensorInfo(_forget_gate_output.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_4));
    _mul1.configure(&_forget_gate_output, cell_state_in, &_cell_state1, 1, ConvertPolicy::SATURATE,
                    RoundingPolicy::TO_ZERO);
    _forget_gate_output.allocator()->allocate();

    _memory_group.manage(&_cell_state2);
    _cell_state2.allocator()->init(
        TensorInfo(_input_gate_output.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_4));
    _mul2.configure(&_input_gate_output, &_input_modulation_gate_output, &_cell_state2, 1, ConvertPolicy::SATURATE,
                    RoundingPolicy::TO_ZERO);
    _input_modulation_gate_output.allocator()->allocate();
    _input_gate_output.allocator()->allocate();

    _add1.configure(&_cell_state1, &_cell_state2, cell_state_out, ConvertPolicy::SATURATE);
    _cell_state1.allocator()->allocate();
    _cell_state2.allocator()->allocate();

    // Short term memory: h_symm = tanh(c_out) * o, in qsymm_0.
    _memory_group.manage(&_output_state_tmp);
    _output_state_tmp.allocator()->init(
        TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _tanh_output_state.configure(cell_state_out, &_output_state_tmp,
                                 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f));

    _memory_group.manage(&_output_state_out_symm);
    _output_state_out_symm.allocator()->init(
        TensorInfo(_output_gate_output.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _mul3.configure(&_output_state_tmp, &_output_gate_output, &_output_state_out_symm, 1, ConvertPolicy::SATURATE,
                    RoundingPolicy::TO_ZERO);
    _output_gate_output.allocator()->allocate();
    _output_state_tmp.allocator()->allocate();

    // Requantize the output state from QSYMM16 to QASYMM8, going through F32
    // because the scales/offsets of the two schemes differ.
    _memory_group.manage(&_output_state_out_f32);
    _output_state_out_f32.allocator()->init(
        TensorInfo(_output_state_out_symm.info()->tensor_shape(), 1, DataType::F32));
    _dequantize.configure(&_output_state_out_symm, &_output_state_out_f32);
    _output_state_out_symm.allocator()->allocate();

    _quantize.configure(&_output_state_out_f32, output_state_out);
    _output_state_out_f32.allocator()->allocate();
}
347
348Status NELSTMLayerQuantized::validate(const ITensorInfo *input,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100349 const ITensorInfo *input_to_input_weights,
350 const ITensorInfo *input_to_forget_weights,
351 const ITensorInfo *input_to_cell_weights,
352 const ITensorInfo *input_to_output_weights,
353 const ITensorInfo *recurrent_to_input_weights,
354 const ITensorInfo *recurrent_to_forget_weights,
355 const ITensorInfo *recurrent_to_cell_weights,
356 const ITensorInfo *recurrent_to_output_weights,
357 const ITensorInfo *input_gate_bias,
358 const ITensorInfo *forget_gate_bias,
359 const ITensorInfo *cell_bias,
360 const ITensorInfo *output_gate_bias,
361 const ITensorInfo *cell_state_in,
362 const ITensorInfo *output_state_in,
363 const ITensorInfo *cell_state_out,
364 const ITensorInfo *output_state_out)
Michalis Spyrouba27e442019-05-28 10:04:57 +0100365{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100366 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(
367 input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
368 recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
369 input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out,
370 output_state_out);
Michalis Spyrouba27e442019-05-28 10:04:57 +0100371
372 const int input_size = input->dimension(0);
373 const int batch_size = input->dimension(1);
374 const int output_size = input_to_input_weights->dimension(1);
375
376 // Dimensionality checks
377 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
378 ARM_COMPUTE_RETURN_ERROR_ON(input_to_input_weights->num_dimensions() > 2);
379 ARM_COMPUTE_RETURN_ERROR_ON(input_gate_bias->num_dimensions() > 1);
380 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
381
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100382 TensorInfo input_weights_info(input_to_input_weights->clone()
383 ->set_tensor_shape(TensorShape(input_size, output_size))
384 .set_data_type(DataType::QASYMM8));
385 TensorInfo recurrent_weights_info(input_to_input_weights->clone()
386 ->set_tensor_shape(TensorShape(output_size, output_size))
387 .set_data_type(DataType::QASYMM8));
388 TensorInfo bias_info(
389 input_gate_bias->clone()->set_tensor_shape(TensorShape(output_size)).set_data_type(DataType::S32));
390 TensorInfo output_state_info(cell_state_in->clone()
391 ->set_tensor_shape(TensorShape(output_size, batch_size))
392 .set_data_type(DataType::QASYMM8)
393 .set_quantization_info(qasymm));
394 TensorInfo cell_state_info(cell_state_in->clone()
395 ->set_tensor_shape(TensorShape(output_size, batch_size))
396 .set_data_type(DataType::QSYMM16)
397 .set_quantization_info(qsymm_4));
Michalis Spyrouba27e442019-05-28 10:04:57 +0100398
399 // Shape checks
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100400 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input_weights_info, input_to_input_weights, input_to_forget_weights,
401 input_to_cell_weights, input_to_output_weights);
402 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&recurrent_weights_info, recurrent_to_input_weights,
403 recurrent_to_forget_weights, recurrent_to_cell_weights,
404 recurrent_to_output_weights);
405 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&bias_info, input_gate_bias, forget_gate_bias, cell_bias,
406 output_gate_bias);
Michalis Spyrouba27e442019-05-28 10:04:57 +0100407 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&cell_state_info, cell_state_in);
408 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&output_state_info, output_state_in);
409
410 // Data type checks
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100411 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input_weights_info, input, input_to_input_weights,
412 input_to_forget_weights, input_to_cell_weights,
413 input_to_output_weights);
414 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_input_weights, recurrent_to_forget_weights,
415 recurrent_to_cell_weights, recurrent_to_output_weights);
416 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&bias_info, input_gate_bias, forget_gate_bias, cell_bias,
417 output_gate_bias);
Michalis Spyrouba27e442019-05-28 10:04:57 +0100418 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&cell_state_info, cell_state_in);
419 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&output_state_info, output_state_in);
420
421 // Quantization checks
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100422 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input_weights_info, input_to_input_weights,
423 input_to_forget_weights, input_to_cell_weights,
424 input_to_output_weights);
425 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(recurrent_to_input_weights, recurrent_to_forget_weights,
426 recurrent_to_cell_weights, recurrent_to_output_weights);
Michalis Spyrouba27e442019-05-28 10:04:57 +0100427 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&cell_state_info, cell_state_in);
428 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&output_state_info, output_state_in);
429
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100430 // Validate internal functions
431 // _concat_input_weights
432 std::vector<const ITensorInfo *> inputs_weights_vector;
433 inputs_weights_vector.emplace_back(input_to_input_weights);
434 inputs_weights_vector.emplace_back(input_to_forget_weights);
435 inputs_weights_vector.emplace_back(input_to_cell_weights);
436 inputs_weights_vector.emplace_back(input_to_output_weights);
437 const QuantizationInfo qweights = input_to_input_weights->quantization_info(); // Weights quantization
438 const TensorInfo input_weights(TensorShape(input_size, 4 * output_size), 1, DataType::QASYMM8, qweights);
439 ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(inputs_weights_vector, &input_weights, Window::DimY));
440
441 // _concat_recurrent_weights
442 std::vector<const ITensorInfo *> recurrent_weights_vector;
443 recurrent_weights_vector.emplace_back(recurrent_to_input_weights);
444 recurrent_weights_vector.emplace_back(recurrent_to_forget_weights);
445 recurrent_weights_vector.emplace_back(recurrent_to_cell_weights);
446 recurrent_weights_vector.emplace_back(recurrent_to_output_weights);
447 const TensorInfo recurrent_weights(TensorShape(output_size, 4 * output_size), 1, DataType::QASYMM8, qweights);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100448 ARM_COMPUTE_RETURN_ON_ERROR(
449 NEConcatenateLayer::validate(recurrent_weights_vector, &recurrent_weights, Window::DimY));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100450
451 // _concat_weights
452 std::vector<const ITensorInfo *> weights_vector;
453 weights_vector.emplace_back(&recurrent_weights);
454 weights_vector.emplace_back(&input_weights);
455 const TensorInfo weights(TensorShape(input_size + output_size, 4 * output_size), 1, DataType::QASYMM8, qweights);
456 ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(weights_vector, &weights, Window::DimX));
457 // _transpose_weights
458 const TensorShape weights_transposed_shape(weights.tensor_shape()[1], weights.tensor_shape()[0]);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100459 TensorInfo weights_transposed = weights.clone()->set_is_resizable(true).set_tensor_shape(weights_transposed_shape);
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100460 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(&weights, &weights_transposed));
461
462 // _concat_inputs
463 std::vector<const ITensorInfo *> input_vector;
464 input_vector.emplace_back(input);
465 input_vector.emplace_back(output_state_in);
466 TensorInfo input_concatenated(TensorShape(output_size + input_size, batch_size), 1, DataType::QASYMM8, qasymm);
467 ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(input_vector, &input_concatenated, Window::DimX));
468
469 // _concat_bias
470 std::vector<const ITensorInfo *> bias_vector;
471 bias_vector.emplace_back(input_gate_bias);
472 bias_vector.emplace_back(forget_gate_bias);
473 bias_vector.emplace_back(cell_bias);
474 bias_vector.emplace_back(output_gate_bias);
475
476 const TensorInfo bias_concatenated(TensorShape(4 * output_size), 1, DataType::S32);
477 ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(bias_vector, &bias_concatenated, Window::DimX));
478
479 // Invert the offset for gemmlowp
480 input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, -qasymm.uniform().offset));
481 weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, -qweights.uniform().offset));
482
483 // _gemmlowp
484 const TensorInfo output_highp(TensorShape(4 * output_size, batch_size), 1, DataType::S32);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100485 ARM_COMPUTE_RETURN_ON_ERROR(
486 NEGEMMLowpMatrixMultiplyCore::validate(&input_concatenated, &weights_transposed, nullptr, &output_highp));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100487
488 // Set the offset back
489 input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
490 weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));
491
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100492 const TensorInfo output_lowp(output_highp.tensor_shape(), 1, DataType::QSYMM16, qsymm_3);
493
Manuel Bottini07263982019-10-17 18:37:26 +0100494 const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
Michalis Spyroue7be8a02019-12-12 16:16:09 +0000495 int32_t output_multiplier = 0;
496 int32_t output_shift = 0;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100497 ARM_COMPUTE_RETURN_ON_ERROR(
498 quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
Manuel Bottini07263982019-10-17 18:37:26 +0100499
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100500 // _output_stage
Manuel Bottiniae58bdf2021-06-17 17:18:45 +0100501 GEMMLowpOutputStageInfo info;
502 info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
503 info.gemmlowp_multiplier = output_multiplier;
504 info.gemmlowp_shift = output_shift;
505 info.output_data_type = DataType::QSYMM16;
506 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&output_highp, &bias_concatenated, &output_lowp, info));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100507
508 TensorInfo input_gate_input;
509 TensorInfo forget_gate_input;
510 TensorInfo input_modulation_gate_input;
511 TensorInfo output_gate_input;
512
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100513 if (batch_size > 1)
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100514 {
515 // _slice_input_tensor
516 input_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100517 ARM_COMPUTE_RETURN_ON_ERROR(
518 NESlice::validate(&output_lowp, &input_gate_input, {0, 0}, {output_size, batch_size}));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100519 // _slice_forget_tensor
520 forget_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100521 ARM_COMPUTE_RETURN_ON_ERROR(
522 NESlice::validate(&output_lowp, &forget_gate_input, {output_size, 0}, {2 * output_size, batch_size}));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100523 // _slice_cell_tensor
524 input_modulation_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100525 ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &input_modulation_gate_input, {2 * output_size, 0},
526 {3 * output_size, batch_size}));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100527 // _slice_output_tensor
528 output_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100529 ARM_COMPUTE_RETURN_ON_ERROR(
530 NESlice::validate(&output_lowp, &output_gate_input, {3 * output_size, 0}, {4 * output_size, batch_size}));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100531 }
532 else
533 {
534 // _slice_input_tensor
535 input_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100536 ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &input_gate_input, {0}, {output_size}));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100537 // _slice_forget_tensor
538 forget_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100539 ARM_COMPUTE_RETURN_ON_ERROR(
540 NESlice::validate(&output_lowp, &forget_gate_input, {output_size}, {2 * output_size}));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100541 // _slice_cell_tensor
542 input_modulation_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100543 ARM_COMPUTE_RETURN_ON_ERROR(
544 NESlice::validate(&output_lowp, &input_modulation_gate_input, {2 * output_size}, {3 * output_size}));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100545 // _slice_output_tensor
546 output_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100547 ARM_COMPUTE_RETURN_ON_ERROR(
548 NESlice::validate(&output_lowp, &output_gate_input, {3 * output_size}, {4 * output_size}));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100549 }
550
551 // _sigmoid_forget_gate
552 const TensorInfo forget_gate_output(forget_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100553 ARM_COMPUTE_RETURN_ON_ERROR(
554 NEActivationLayer::validate(&forget_gate_input, &forget_gate_output,
555 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100556 // _sigmoid_input_gate
557 const TensorInfo input_gate_output(input_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100558 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(
559 &input_gate_input, &input_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100560 // _tanh_modulation_gate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100561 const TensorInfo input_modulation_gate_output(input_modulation_gate_input.tensor_shape(), 1, DataType::QSYMM16,
562 qsymm_0);
563 ARM_COMPUTE_RETURN_ON_ERROR(
564 NEActivationLayer::validate(&input_modulation_gate_input, &input_modulation_gate_output,
565 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f)));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100566 // _sigmoid_output_gate
567 const TensorInfo output_gate_output(output_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100568 ARM_COMPUTE_RETURN_ON_ERROR(
569 NEActivationLayer::validate(&output_gate_input, &output_gate_output,
570 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100571
572 // _mul_forget_gate_cell_state
573 const TensorInfo cell_state_tmp1(forget_gate_output.tensor_shape(), 1, DataType::QSYMM16, qsymm_4);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100574 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(
575 &forget_gate_output, cell_state_in, &cell_state_tmp1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100576
577 // _mul_input_gate_input_mod_gate
578 const TensorInfo cell_state_tmp2(input_gate_output.tensor_shape(), 1, DataType::QSYMM16, qsymm_4);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100579 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&input_gate_output, &input_modulation_gate_output,
580 &cell_state_tmp2, 1, ConvertPolicy::SATURATE,
581 RoundingPolicy::TO_ZERO));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100582
583 // _add_cell_state_tmps
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100584 ARM_COMPUTE_RETURN_ON_ERROR(
585 NEArithmeticAddition::validate(&cell_state_tmp1, &cell_state_tmp2, cell_state_out, ConvertPolicy::SATURATE));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100586
    // _tanh_output_state
588 const TensorInfo output_state_tmp(cell_state_out->tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100589 ARM_COMPUTE_RETURN_ON_ERROR(
590 NEActivationLayer::validate(cell_state_out, &output_state_tmp,
591 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f)));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100592
593 // _mul_output_state_tmp_output_gate
594 const TensorInfo output_state_out_symm(output_gate_output.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100595 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_state_tmp, &output_gate_output,
596 &output_state_out_symm, 1, ConvertPolicy::SATURATE,
597 RoundingPolicy::TO_ZERO));
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100598
599 // _dequantize
600 const TensorInfo output_state_out_f32(output_state_out_symm.tensor_shape(), 1, DataType::F32);
601 ARM_COMPUTE_RETURN_ON_ERROR(NEDequantizationLayer::validate(&output_state_out_symm, &output_state_out_f32));
602
603 // _quantize
604 ARM_COMPUTE_RETURN_ON_ERROR(NEQuantizationLayer::validate(&output_state_out_f32, output_state_out));
605
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100606 if (cell_state_out->total_size() != 0)
Michalis Spyrouba27e442019-05-28 10:04:57 +0100607 {
608 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&cell_state_info, cell_state_out);
609 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&cell_state_info, cell_state_out);
610 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&cell_state_info, cell_state_out);
611 }
612
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100613 if (output_state_out->total_size() != 0)
Michalis Spyrouba27e442019-05-28 10:04:57 +0100614 {
615 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&output_state_info, output_state_out);
616 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&output_state_info, output_state_out);
617 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&output_state_info, output_state_out);
618 }
619
620 return Status{};
621}
622
// Executes one step of the quantized LSTM cell by dispatching the
// sub-functions configured earlier, strictly in data-flow order: each stage
// consumes tensors produced by the previous one, so the statement order below
// must not be changed.
void NELSTMLayerQuantized::run()
{
    // One-time weight/bias concatenation and transposition (no-op after the first call)
    prepare();

    // Acquire all the temporaries; they are released when scope_mg is destroyed
    MemoryGroupResourceScope scope_mg(_memory_group);

    // Concat and transpose the input
    _concat_inputs.run();

    // Run gemmlowp: one matrix multiply produces all four gate pre-activations,
    // then the output stage requantizes the high-precision accumulator down
    _gemmlowp.run();
    _output_stage.run();

    // Slice the results: split the fused GEMM output into the four per-gate tensors
    _slice_input_tensor.run();
    _slice_forget_tensor.run();
    _slice_cell_tensor.run();
    _slice_output_tensor.run();

    // Gates
    // Forget gate
    _sigmoid_forget_gate.run();

    // Input gate
    _sigmoid_input_gate.run();

    // Input modulation gate
    _tanh_modulation_gate.run();

    // Output gate
    _sigmoid_output_gate.run();

    // Cell state (long term memory): two element-wise products combined by an
    // addition (_add1 consumes the outputs of _mul1 and _mul2)
    _mul1.run();
    _mul2.run();
    _add1.run();

    // Output state (short term memory): tanh of the new cell state scaled by the output gate
    _tanh_output_state.run();
    _mul3.run();

    // Requantize output state from QSYMM16 to QASYMM8
    _dequantize.run();
    _quantize.run();
}
669
// One-time preparation, guarded by _is_prepared (idempotent; run() calls it
// on every invocation). Builds the fused weight matrix and bias used by the
// single GEMM in run():
//   1. concatenate the four input-to-gate weight tensors,
//   2. concatenate the four recurrent-to-gate weight tensors,
//   3. concatenate the two results into one matrix, then transpose it,
//   4. concatenate the four gate biases.
// Each intermediate buffer is freed as soon as the next stage has consumed
// it, and the caller-provided per-gate tensors are marked unused so the
// memory manager can reclaim them.
void NELSTMLayerQuantized::prepare()
{
    if (!_is_prepared)
    {
        // Stage 1: fuse the input-to-{input,forget,cell,output} weights
        _input_weights.allocator()->allocate();
        _concat_input_weights.run();

        // Original per-gate input weights are no longer needed
        _input_to_input_weights->mark_as_unused();
        _input_to_forget_weights->mark_as_unused();
        _input_to_cell_weights->mark_as_unused();
        _input_to_output_weights->mark_as_unused();

        // Stage 2: fuse the recurrent-to-{input,forget,cell,output} weights
        _recurrent_weights.allocator()->allocate();
        _concat_recurrent_weights.run();
        _recurrent_to_input_weights->mark_as_unused();
        _recurrent_to_forget_weights->mark_as_unused();
        _recurrent_to_cell_weights->mark_as_unused();
        _recurrent_to_output_weights->mark_as_unused();

        // Stage 3a: concatenate input and recurrent halves into one matrix
        _weights.allocator()->allocate();
        _concat_weights.run();

        // The two halves have been consumed; release their storage eagerly
        _input_weights.mark_as_unused();
        _input_weights.allocator()->free();
        _recurrent_weights.mark_as_unused();
        _recurrent_weights.allocator()->free();

        // Stage 3b: transpose the fused matrix into the layout the GEMM expects
        _weights_transposed.allocator()->allocate();
        _transpose_weights.run();

        // The untransposed matrix is no longer needed
        _weights.mark_as_unused();
        _weights.allocator()->free();

        // Stage 4: fuse the four gate biases
        _bias.allocator()->allocate();
        _concat_bias.run();
        _input_gate_bias->mark_as_unused();
        _forget_gate_bias->mark_as_unused();
        _cell_bias->mark_as_unused();
        _output_gate_bias->mark_as_unused();

        _is_prepared = true;
    }
}
713
714} // namespace arm_compute