blob: dd78d10d164c1a01eb3dc03925dd8e97465798ef [file] [log] [blame]
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001/*
Pablo Marquez Tellobfd10a12022-03-21 14:56:38 +00002 * Copyright (c) 2020-2022 Arm Limited.
Michele Di Giorgio47a89902020-03-09 19:32:33 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"
25
Manuel Bottinicfac51c2021-06-18 15:47:28 +010026#include "arm_compute/core/ITensorPack.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000027#include "arm_compute/core/KernelDescriptors.h"
28#include "arm_compute/core/QuantizationInfo.h"
29#include "arm_compute/core/Utils.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000030#include "arm_compute/core/utils/misc/InfoHelpers.h"
31#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010032#include "arm_compute/core/Validate.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000033#include "arm_compute/runtime/NEON/NEScheduler.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010034
ramelg01cbbb0382021-09-17 17:36:57 +010035#include "src/common/utils/Log.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010036#include "src/core/helpers/WindowHelpers.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010037#include "src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
Georgios Pinitas7891a732021-08-20 21:39:25 +010038#include "src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000039
40namespace arm_compute
41{
42using namespace arm_compute::utils::info_helpers;
43namespace
44{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010045Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info,
46 const ITensorInfo *mm_input,
47 const ITensorInfo *mm_weights,
48 const ITensorInfo *bias,
49 float gemmlowp_scale,
50 const TensorInfo *mm_res_info,
51 const TensorInfo *outstage_tensor_info)
Michele Di Giorgio47a89902020-03-09 19:32:33 +000052{
53 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010054 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(
55 gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
56 ARM_COMPUTE_RETURN_ON_ERROR(
57 NEGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +000058 return Status{};
59}
60} // namespace
61
Michalis Spyrouebcebf12020-10-21 00:04:14 +010062Status NEQLSTMLayer::validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
63{
64 // Output quantization scale will be different, but ignored here
65 // since it will be configured at configure() stage.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010066 const TensorInfo out{in};
Michalis Spyrouebcebf12020-10-21 00:04:14 +010067 return NEQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
68}
69
70void NEQLSTMLayer::configure_layer_norm(NEQLSTMLayer::LayerNormGate g, const ITensor *in)
71{
72 ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
73
74 Tensor &out = get_layer_norm_output(g);
75 _memory_group.manage(&out);
76 out.allocator()->init(*(in->info()));
77
Georgios Pinitas40f51a62020-11-21 03:04:18 +000078 get_layer_norm(g) = std::make_unique<NEQLSTMLayerNormalizationKernel>();
Michalis Spyrouebcebf12020-10-21 00:04:14 +010079 get_layer_norm(g)->configure(in, &out, get_layer_norm_weight(g), get_layer_norm_bias(g));
80}
81
82NEQLSTMLayer::TensorCopyKernel::~TensorCopyKernel() = default;
83
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +010084Status NEQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
85{
86 ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
87 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
88 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
89 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
90 return Status{};
91}
92
93void NEQLSTMLayer::TensorCopyKernel::configure(ITensor &src, ITensor &dst)
94{
95 ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
ramelg01cbbb0382021-09-17 17:36:57 +010096 ARM_COMPUTE_LOG_PARAMS(src, dst);
97
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +010098 _src = &src;
99 _dst = &dst;
100 _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
101 _window = calculate_max_window(*_src->info(), Steps());
102}
103
104void NEQLSTMLayer::TensorCopyKernel::run()
105{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100106 Iterator input_iter{_src, _window};
107 Iterator output_iter{_dst, _window};
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100108
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100109 execute_window_loop(
110 _window, [&](const Coordinates &) { memcpy(output_iter.ptr(), input_iter.ptr(), _row_size); }, input_iter,
111 output_iter);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100112}
113
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100114NEQLSTMLayer::~NEQLSTMLayer() = default;
115
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000116NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Pablo Marquez Telloa5d61bf2022-03-17 12:52:02 +0000117 : _memory_group(),
118 _dequantize_input_to_forget_weights(),
119 _quantize_input_to_forget_weights(),
120 _transpose_input_to_forget_weights(),
121 _transpose_input_to_cell_weights(),
122 _transpose_input_to_output_weights(),
123 _transpose_input_to_input_weights(),
124 _transpose_recurrent_to_forget_weights(),
125 _transpose_recurrent_to_cell_weights(),
126 _transpose_recurrent_to_output_weights(),
127 _transpose_recurrent_to_input_weights(),
128 _transpose_projection_weights(),
129 _input_to_input_reduction(),
130 _recurrent_to_input_reduction(),
131 _input_to_forget_reduction(),
132 _recurrent_to_forget_reduction(),
133 _input_to_cell_reduction(),
134 _recurrent_to_cell_reduction(),
135 _input_to_output_reduction(),
136 _recurrent_to_output_reduction(),
137 _projection_reduction(),
138 _projection_bias_add(),
139 _mm_input_to_forget(),
140 _mm_recurrent_to_forget(),
141 _pixelwise_mul_cell_to_forget(),
142 _input_to_forget_outstage(),
143 _recurrent_to_forget_outstage(),
144 _cell_to_forget_outstage(),
145 _accumulate_input_recurrent_forget(),
146 _accumulate_cell_forget(),
147 _forget_gate_sigmoid(),
148 _mm_input_to_cell(),
149 _input_to_cell_outstage(),
150 _mm_recurrent_to_cell(),
151 _recurrent_to_cell_outstage(),
152 _accumulate_input_recurrent_modulation(),
153 _cell_gate_tanh(),
154 _input_gate_sub(),
155 _mm_input_to_input(),
156 _input_to_input_outstage(),
157 _mm_recurrent_to_input(),
158 _recurrent_to_input_outstage(),
159 _accumulate_input_recurrent_input(),
160 _pixelwise_mul_cell_to_input(),
161 _cell_to_input_outstage(),
162 _accumulate_cell_input(),
163 _input_gate_sigmoid(),
164 _pixelwise_mul_forget_cell(),
165 _pixelwise_mul_input_cell(),
166 _add_forget_cell(),
167 _cell_clip(),
168 _mm_input_to_output(),
169 _input_to_output_outstage(),
170 _mm_recurrent_to_output(),
171 _recurrent_to_output_outstage(),
172 _accumulate_input_recurrent_output(),
173 _pixelwise_mul_cell_to_output(),
174 _cell_to_output_outstage(),
175 _accumulate_cell_to_output(),
176 _output_gate_sigmoid(),
177 _hidden_tanh(),
178 _pixelwise_mul_hidden(),
179 _hidden_outstage(),
180 _mm_projection(),
181 _projection_outstage(),
182 _accumulate_projection(),
183 _projection_clip(),
184 _projection_bias_copy(),
185 _projection_output_to_accumulate_copy(),
186 _projection_accumulate_to_output_copy(),
187 _hidden_to_output_copy(),
188 _layer_norms(),
189 _copy_output(),
190 _layer_norm_weights(),
191 _layer_norm_bias(),
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100192 _layer_norm_output()
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000193{
194 _memory_group = MemoryGroup(std::move(memory_manager));
195}
196
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100197void NEQLSTMLayer::configure_mm(NEGEMMLowpMatrixMultiplyCore &mm,
198 NEGEMMLowpOutputStage &outstage,
199 GEMMLowpOutputStageInfo &gemmlowp_info,
200 const ITensor *mm_input,
201 const ITensor *mm_weights,
202 const ITensor *bias,
203 Tensor *mm_res,
204 Tensor *outstage_res,
205 float gemmlowp_scale,
206 const TensorInfo &mm_res_info,
207 const TensorInfo &outstage_tensor_info)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000208{
209 _memory_group.manage(mm_res);
210 _memory_group.manage(outstage_res);
211
212 mm_res->allocator()->init(mm_res_info);
213 outstage_res->allocator()->init(outstage_tensor_info);
214
215 // Configure matrix-multiplication
216 mm.configure(mm_input, mm_weights, nullptr, mm_res);
217
218 // Configure output stage
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100219 quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier,
220 &gemmlowp_info.gemmlowp_shift);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000221 outstage.configure(mm_res, bias, outstage_res, gemmlowp_info);
222 mm_res->allocator()->allocate();
223}
224
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100225void NEQLSTMLayer::configure(const ITensor *input,
226 const ITensor *input_to_forget_weights,
227 const ITensor *input_to_cell_weights,
228 const ITensor *input_to_output_weights,
229 const ITensor *recurrent_to_forget_weights,
230 const ITensor *recurrent_to_cell_weights,
231 const ITensor *recurrent_to_output_weights,
232 const ITensor *forget_gate_bias,
233 const ITensor *cell_bias,
234 const ITensor *output_gate_bias,
235 const ITensor *cell_state_in,
236 ITensor *output_state_in,
237 ITensor *cell_state_out,
238 ITensor *output_state_out,
239 ITensor *output,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000240 const LSTMParams<ITensor> &lstm_params)
241{
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000242 ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
243 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100244 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
245 cell_state_out, output_state_out);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000246
ramelg01cbbb0382021-09-17 17:36:57 +0100247 ARM_COMPUTE_LOG_PARAMS(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
248 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100249 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
250 cell_state_out, output_state_out);
ramelg01cbbb0382021-09-17 17:36:57 +0100251
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000252 // Set lstm parameters
253 LSTMParams<ITensorInfo> lstm_params_info{};
254 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
255
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100256 _input_to_forget_weights_transposed.info()->set_quantization_info(
257 input_to_forget_weights->info()->quantization_info());
Pablo Marquez Tellobfd10a12022-03-21 14:56:38 +0000258 _input_to_cell_weights_transposed.info()->set_quantization_info(input_to_cell_weights->info()->quantization_info());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100259 _input_to_output_weights_transposed.info()->set_quantization_info(
260 input_to_output_weights->info()->quantization_info());
261 _recurrent_to_forget_weights_transposed.info()->set_quantization_info(
262 recurrent_to_forget_weights->info()->quantization_info());
263 _recurrent_to_cell_weights_transposed.info()->set_quantization_info(
264 recurrent_to_cell_weights->info()->quantization_info());
265 _recurrent_to_output_weights_transposed.info()->set_quantization_info(
266 recurrent_to_output_weights->info()->quantization_info());
Pablo Marquez Tellobfd10a12022-03-21 14:56:38 +0000267
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100268 if (input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED)
Pablo Marquez Telloa5d61bf2022-03-17 12:52:02 +0000269 {
270 _convert_input_to_forget_weights_to_qsymm8 = true;
271 // Setup dequantize output tensor to go from QASYMM8_SIGNED -> F32
272
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100273 _input_to_forget_weights_f32.allocator()->init(
274 TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::F32)
275 .set_data_layout(input_to_forget_weights->info()->data_layout()));
Pablo Marquez Telloa5d61bf2022-03-17 12:52:02 +0000276 // Setup the quantize output tensor to go from F32 -> QSYMM8
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100277 _input_to_forget_weights_symm8.allocator()->init(
278 (TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::QSYMM8)
279 .set_data_layout(input_to_forget_weights->info()->data_layout())
280 .set_quantization_info(input_to_forget_weights->info()->quantization_info())));
Pablo Marquez Telloa5d61bf2022-03-17 12:52:02 +0000281
282 _dequantize_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_f32);
283 _quantize_input_to_forget_weights.configure(&_input_to_forget_weights_f32, &_input_to_forget_weights_symm8);
Pablo Marquez Telloa5d61bf2022-03-17 12:52:02 +0000284
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100285 ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(
286 input->info(), _input_to_forget_weights_symm8.info(), input_to_cell_weights->info(),
287 input_to_output_weights->info(), recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(),
288 recurrent_to_output_weights->info(), forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
289 cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(),
290 output->info(), lstm_params_info));
Pablo Marquez Telloa5d61bf2022-03-17 12:52:02 +0000291 }
292 else
293 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100294 ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(
295 input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(),
296 input_to_output_weights->info(), recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(),
297 recurrent_to_output_weights->info(), forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
298 cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(),
299 output->info(), lstm_params_info));
Pablo Marquez Telloa5d61bf2022-03-17 12:52:02 +0000300 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000301
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100302 const int batch_size = input->info()->dimension(1);
303 const int num_units = input_to_output_weights->info()->dimension(1);
304 const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000305
306 const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
307 const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
308 const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
309
310 _projection_bias = lstm_params.projection_bias();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100311 _input_to_forget_weights = (input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED)
312 ? &_input_to_forget_weights_symm8
313 : input_to_forget_weights;
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000314 _input_to_cell_weights = input_to_cell_weights;
315 _input_to_output_weights = input_to_output_weights;
316 _recurrent_to_forget_weights = recurrent_to_forget_weights;
317 _recurrent_to_cell_weights = recurrent_to_cell_weights;
318 _recurrent_to_output_weights = recurrent_to_output_weights;
319 _projection_weights = lstm_params.projection_weights();
320
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100321 // Layer normalization
322 _has_layer_norm = lstm_params.use_layer_norm();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100323 if (_has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100324 {
325 set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
326 set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
327 set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
328 set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
329
330 set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
331 set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
332 set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
333 set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
334 }
335
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000336 _has_cifg = lstm_params.has_cifg_opt();
337 _has_projection = lstm_params.has_projection();
338 _has_peephole = lstm_params.has_peephole_opt();
339
340 // Calculate and decompose effective scales for optimizing matmul calculation
341 const int32_t cell_shift = log2(qcell_state_in.scale);
342
343 // Calculate quantized parameters for clipping.
344 int16_t quantized_cell_clip = 0;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100345 if (lstm_params.cell_clip() > 0.0f)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000346 {
347 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
348 }
349 _has_cell_clipping = quantized_cell_clip > 0;
350
351 // Precompute effective bias for optimizing the matmul computations.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100352 if (!_has_cifg)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000353 {
354 _input_to_input_weights = lstm_params.input_to_input_weights();
355 _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
356
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100357 _input_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
358 _recurrent_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100359 _input_to_input_reduction->configure(_input_to_input_weights->info(), _input_to_input_eff_bias.info(),
360 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
361 _recurrent_to_input_reduction->configure(
362 _recurrent_to_input_weights->info(), _recurrent_to_input_eff_bias.info(),
363 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000364 }
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100365
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100366 _input_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
367 _recurrent_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
368 _input_to_cell_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
369 _recurrent_to_cell_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
370 _input_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
371 _recurrent_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100372
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100373 _input_to_forget_reduction->configure(input_to_forget_weights->info(), _input_to_forget_eff_bias.info(),
374 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
375 _recurrent_to_forget_reduction->configure(
376 recurrent_to_forget_weights->info(), _recurrent_to_forget_eff_bias.info(),
377 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
378 _input_to_cell_reduction->configure(input_to_cell_weights->info(), _input_to_cell_eff_bias.info(),
379 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
380 _recurrent_to_cell_reduction->configure(
381 recurrent_to_cell_weights->info(), _recurrent_to_cell_eff_bias.info(),
382 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
383 _input_to_output_reduction->configure(input_to_output_weights->info(), _input_to_output_eff_bias.info(),
384 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
385 _recurrent_to_output_reduction->configure(
386 recurrent_to_output_weights->info(), _recurrent_to_output_eff_bias.info(),
387 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
388 if (_has_projection)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000389 {
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100390 _projection_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100391 _projection_reduction->configure(
392 _projection_weights->info(), _projection_eff_bias.info(),
393 GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
394 if (_projection_bias != nullptr)
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100395 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100396 _projection_bias_add.configure(_projection_bias, &_projection_eff_bias, &_projection_eff_bias,
397 ConvertPolicy::SATURATE);
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100398 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000399 }
400
401 // Pre-transpose weights to be used in GEMM.
402 _transpose_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_transposed);
403 _transpose_input_to_cell_weights.configure(input_to_cell_weights, &_input_to_cell_weights_transposed);
404 _transpose_input_to_output_weights.configure(input_to_output_weights, &_input_to_output_weights_transposed);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100405 _transpose_recurrent_to_forget_weights.configure(recurrent_to_forget_weights,
406 &_recurrent_to_forget_weights_transposed);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000407 _transpose_recurrent_to_cell_weights.configure(recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100408 _transpose_recurrent_to_output_weights.configure(recurrent_to_output_weights,
409 &_recurrent_to_output_weights_transposed);
410 if (!_has_cifg)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000411 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100412 _transpose_input_to_input_weights.configure(lstm_params.input_to_input_weights(),
413 &_input_to_input_weights_transposed);
414 _transpose_recurrent_to_input_weights.configure(lstm_params.recurrent_to_input_weights(),
415 &_recurrent_to_input_weights_transposed);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000416 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100417 if (_has_projection)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000418 {
419 _transpose_projection_weights.configure(_projection_weights, &_projection_weights_transposed);
420 }
421
422 GEMMLowpOutputStageInfo gemmlowp_info;
423 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
424 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
425 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
426 gemmlowp_info.output_data_type = DataType::QSYMM16;
427
428 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
429 // Forget gate.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100430 const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16,
431 QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
432 const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale *
433 qinput.scale / lstm_params.forget_intermediate_scale();
434 configure_mm(_mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info, input,
435 &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias, &_mm_input_to_forget_res,
436 &_input_to_forget_outstage_res, input_to_forget_scale, mm_out_info, forget_gate_outstage_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000437
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100438 const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale *
439 qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
440 configure_mm(_mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info, output_state_in,
441 &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias, &_mm_recurrent_to_forget_res,
442 &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale, mm_out_info, forget_gate_outstage_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000443
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100444 _accumulate_input_recurrent_forget.configure(&_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
445 &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000446 _input_to_forget_outstage_res.allocator()->allocate();
447
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100448 if (_has_peephole)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000449 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100450 _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000451 _memory_group.manage(&_mul_cell_to_forget_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100452 _pixelwise_mul_cell_to_forget.configure(cell_state_in, lstm_params.cell_to_forget_weights(),
453 &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE,
454 RoundingPolicy::TO_ZERO);
455 _cell_to_forget_outstage_res.allocator()->init(
456 TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16,
457 QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000458 _memory_group.manage(&_cell_to_forget_outstage_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100459 const float cell_to_forget_scale =
460 std::pow(2, cell_shift) *
461 lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale /
462 lstm_params.forget_intermediate_scale();
463 quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier,
464 &gemmlowp_info.gemmlowp_shift);
465 _cell_to_forget_outstage.configure(&_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res,
466 gemmlowp_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000467 _mul_cell_to_forget_res.allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100468 _accumulate_cell_forget.configure(&_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res,
469 &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000470 _cell_to_forget_outstage_res.allocator()->allocate();
471 }
472
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100473 Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
474
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100475 if (_has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100476 {
477 configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
478 forget_activation_input->allocator()->allocate();
479 forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
480 }
481
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000482 // Output quantization info of Sigmoid and Tanh activations
483 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100484 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000485
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000486 _memory_group.manage(&_forget_gate);
487 _forget_gate.allocator()->init(forget_gate_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100488 _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate,
489 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100490 forget_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000491
492 // Modulation gate.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100493 const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16,
494 QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
495 const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale *
496 qinput.scale / lstm_params.cell_intermediate_scale();
497 configure_mm(_mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info, input, &_input_to_cell_weights_transposed,
498 &_input_to_cell_eff_bias, &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000499 mm_out_info, cell_outstage_info);
500
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100501 const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale *
502 qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
503 configure_mm(_mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info, output_state_in,
504 &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias, &_mm_recurrent_to_cell_res,
505 &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale, mm_out_info, cell_outstage_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000506
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100507 _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
508 &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000509 _input_to_cell_outstage_res.allocator()->allocate();
510
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100511 Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
512
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100513 if (_has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100514 {
515 configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
516 cell_activation_input->allocator()->allocate();
517 cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
518 }
519
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000520 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100521
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000522 _memory_group.manage(&_cell_gate);
523 _cell_gate.allocator()->init(cell_gate_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100524 _cell_gate_tanh.configure(cell_activation_input, &_cell_gate,
525 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100526 cell_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000527
528 // Input gate.
529 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
530 _input_gate.allocator()->init(input_gate_info);
531 _memory_group.manage(&_input_gate);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100532 if (_has_cifg)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000533 {
534 _ones.allocator()->init(*_forget_gate.info());
535 _input_gate_sub.configure(&_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
536 _ones.allocator()->allocate();
537 }
538 else
539 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100540 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
541 QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
542 const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale *
543 qinput.scale / lstm_params.input_intermediate_scale();
544 configure_mm(_mm_input_to_input, _input_to_input_outstage, gemmlowp_info, input,
545 &_input_to_input_weights_transposed, &_input_to_input_eff_bias, &_mm_input_to_input_res,
546 &_input_to_input_outstage_res, input_to_input_scale, mm_out_info, input_outstage_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000547
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100548 const float recurrent_to_input_scale =
549 _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale /
550 lstm_params.input_intermediate_scale();
551 configure_mm(_mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info, output_state_in,
552 &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000553 &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
554 mm_out_info, input_outstage_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100555 _accumulate_input_recurrent_input.configure(&_input_to_input_outstage_res, &_recurrent_to_input_outstage_res,
556 &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000557 _input_to_input_outstage_res.allocator()->allocate();
558
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100559 if (_has_peephole)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000560 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100561 _mul_cell_to_input_res.allocator()->init(
562 TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000563 _memory_group.manage(&_mul_cell_to_input_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100564 _pixelwise_mul_cell_to_input.configure(cell_state_in, lstm_params.cell_to_input_weights(),
565 &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE,
566 RoundingPolicy::TO_ZERO);
567 const float cell_to_input_scale =
568 std::pow(2, cell_shift) *
569 lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale /
570 lstm_params.input_intermediate_scale();
571 quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier,
572 &gemmlowp_info.gemmlowp_shift);
573 _cell_to_input_outstage_res.allocator()->init(
574 TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16,
575 QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000576 _memory_group.manage(&_cell_to_input_outstage_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100577 _cell_to_input_outstage.configure(&_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res,
578 gemmlowp_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000579 _mul_cell_to_input_res.allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100580 _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res,
581 &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000582 _cell_to_input_outstage_res.allocator()->allocate();
583 }
584
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100585 Tensor *input_activation_input = &_recurrent_to_input_outstage_res;
586
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100587 if (_has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100588 {
589 configure_layer_norm(LayerNormGate::Input, input_activation_input);
590 input_activation_input->allocator()->allocate();
591 input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
592 }
593
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100594 _input_gate_sigmoid.configure(input_activation_input, &_input_gate,
595 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100596 input_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000597 }
598 // Cell.
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100599 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100600 _pixelwise_mul_forget_cell.configure(&_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE,
601 RoundingPolicy::TO_ZERO);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000602 const float cell_gate_scale = _cell_gate.info()->quantization_info().uniform().scale;
603 const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100604 const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
605 QuantizationInfo(mul_input_cell_scale, 0));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000606 _memory_group.manage(&_mul_input_cell_res);
607 _mul_input_cell_res.allocator()->init(mul_input_cell_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100608 _pixelwise_mul_input_cell.configure(&_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE,
609 RoundingPolicy::TO_ZERO);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000610 _cell_gate.allocator()->allocate();
611 _add_forget_cell.configure(&_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
612 _mul_input_cell_res.allocator()->allocate();
613 _forget_gate.allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100614 if (_has_cell_clipping)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000615 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100616 _cell_clip.configure(cell_state_out, nullptr,
617 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
618 -quantized_cell_clip, quantized_cell_clip));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000619 }
620 // Output gate.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100621 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
622 QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
623 const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale *
624 qinput.scale / lstm_params.output_intermediate_scale();
625 configure_mm(_mm_input_to_output, _input_to_output_outstage, gemmlowp_info, input,
626 &_input_to_output_weights_transposed, &_input_to_output_eff_bias, &_mm_input_to_output_res,
627 &_input_to_output_outstage_res, input_to_output_scale, mm_out_info, output_outstage_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000628
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100629 const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale *
630 qoutput_state_in.scale / lstm_params.output_intermediate_scale();
631 configure_mm(_mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info, output_state_in,
632 &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias, &_mm_recurrent_to_output_res,
633 &_recurrent_to_output_outstage_res, recurrent_to_output_scale, mm_out_info, output_outstage_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000634
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100635 _accumulate_input_recurrent_output.configure(&_recurrent_to_output_outstage_res, &_input_to_output_outstage_res,
636 &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000637 _input_to_output_outstage_res.allocator()->allocate();
638
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100639 if (_has_peephole)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000640 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100641 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000642 // Here we are not using the output stage because all operations are done in float
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100643 _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000644 _memory_group.manage(&_mul_cell_to_output_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100645 _pixelwise_mul_cell_to_output.configure(cell_state_out, lstm_params.cell_to_output_weights(),
646 &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE,
647 RoundingPolicy::TO_ZERO);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100648
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100649 const float cell_to_output_scale =
650 std::pow(2, cell_shift) *
651 lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale /
652 lstm_params.output_intermediate_scale();
653 quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier,
654 &gemmlowp_info.gemmlowp_shift);
655 _cell_to_output_outstage_res.allocator()->init(
656 TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16,
657 QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100658 _memory_group.manage(&_cell_to_output_outstage_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100659 _cell_to_output_outstage.configure(&_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res,
660 gemmlowp_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000661 _mul_cell_to_output_res.allocator()->allocate();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100662
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100663 _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res,
664 &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100665 _cell_to_output_outstage_res.allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000666 }
667
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100668 Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
669
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100670 if (_has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100671 {
672 configure_layer_norm(LayerNormGate::Output, output_activation_input);
673 output_activation_input->allocator()->allocate();
674 output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
675 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000676 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100677
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000678 _memory_group.manage(&_output_gate);
679 _output_gate.allocator()->init(output_gate_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100680 _output_gate_sigmoid.configure(output_activation_input, &_output_gate,
681 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100682 output_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000683
684 // Hidden.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100685 _hidden_tanh.configure(cell_state_out, &_input_gate,
686 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100687 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000688 _memory_group.manage(&_hidden_mul_res);
689 const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
690 _hidden_mul_res.allocator()->init(hidden_mul_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100691 _pixelwise_mul_hidden.configure(&_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE,
692 RoundingPolicy::TO_ZERO);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000693 _output_gate.allocator()->allocate();
694 _input_gate.allocator()->allocate();
695 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100696 quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier,
697 &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000698 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
699 gemmlowp_info.output_data_type = output_state_in->info()->data_type();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100700
701 _projection_tensor_copy_required = (num_units != output_size);
702 ITensor *hidden_gate_result = output_state_out;
703
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100704 _memory_group.manage(&_hidden_gate);
705
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100706 if (_projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100707 {
708 _hidden_gate.allocator()->init(*output_state_out->info());
709 _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
710 hidden_gate_result = &_hidden_gate;
711 }
712
713 _hidden_outstage.configure(&_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000714 _hidden_mul_res.allocator()->allocate();
715
716 // Projection.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100717 if (_has_projection)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000718 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100719 const TensorInfo projection_outstage_info(*output_state_out->info());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100720 const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
721 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
722 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
723 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
724 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
725 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000726
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100727 TensorInfo projection_mm_out_info{mm_out_info};
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100728 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100729
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100730 configure_mm(_mm_projection, _projection_outstage, gemmlowp_info, hidden_gate_result,
731 &_projection_weights_transposed, &_projection_eff_bias, &_mm_projection_res,
732 &_projection_outstage_res, projection_scale, projection_mm_out_info, projection_outstage_info);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100733
734 ITensor *accumulate_destination = output_state_out;
735
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100736 if (_projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100737 {
738 _hidden_gate.allocator()->allocate();
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100739 _projection_accumulate_res.allocator()->init(*output_state_in->info());
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100740 _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100741 _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100742 accumulate_destination = &_projection_accumulate_res;
743 }
744
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100745 _accumulate_projection.configure(&_projection_outstage_res, accumulate_destination, accumulate_destination,
746 ConvertPolicy::SATURATE);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000747 _projection_outstage_res.allocator()->allocate();
748
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100749 if (_projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100750 {
751 _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
752 _projection_accumulate_res.allocator()->allocate();
753 }
754
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100755 int8_t quantized_projection_clip{0};
756 if (lstm_params.projection_clip() > 0.0f)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000757 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100758 quantized_projection_clip =
759 utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000760 }
761
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100762 if (quantized_projection_clip > 0)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000763 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100764 _projection_clip.configure(output_state_out, nullptr,
765 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
766 -quantized_projection_clip, quantized_projection_clip));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000767 _has_projection_clipping = true;
768 }
769 }
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100770 else
771 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100772 if (_projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100773 {
774 _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
775 _hidden_gate.allocator()->allocate();
776 }
777 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100778
779 // Copy output_state_out to output
780 _copy_output.configure(output_state_out, output);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000781}
782
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100783Status NEQLSTMLayer::validate(const ITensorInfo *input,
784 const ITensorInfo *input_to_forget_weights,
785 const ITensorInfo *input_to_cell_weights,
786 const ITensorInfo *input_to_output_weights,
787 const ITensorInfo *recurrent_to_forget_weights,
788 const ITensorInfo *recurrent_to_cell_weights,
789 const ITensorInfo *recurrent_to_output_weights,
790 const ITensorInfo *forget_gate_bias,
791 const ITensorInfo *cell_bias,
792 const ITensorInfo *output_gate_bias,
793 const ITensorInfo *cell_state_in,
794 const ITensorInfo *output_state_in,
795 const ITensorInfo *cell_state_out,
796 const ITensorInfo *output_state_out,
797 const ITensorInfo *output,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000798 const LSTMParams<ITensorInfo> &lstm_params)
799{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100800 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
801 recurrent_to_forget_weights, recurrent_to_cell_weights,
802 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
803 cell_state_in, output_state_in, cell_state_out, output_state_out, output);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000804
805 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
806 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
807
808 const unsigned int input_size = input->dimension(0);
809 const unsigned int batch_size = input->dimension(1);
810 const unsigned int num_units = input_to_output_weights->dimension(1);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100811 const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000812
813 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
814 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100815 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights,
816 input_to_cell_weights);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000817 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
818 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100819 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights,
820 recurrent_to_cell_weights);
821 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QASYMM8_SIGNED,
822 DataType::QSYMM8);
Mike Kelly58733582022-05-05 20:19:00 +0100823
824 // If the input_to_forget_weights data type is DataType::QSYMM8 then it can never match the other weights as they are all DataType::QASYMM8_SIGNED
825 if (input_to_forget_weights->data_type() == DataType::QSYMM8)
826 {
827 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_cell_weights, input_to_output_weights,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100828 recurrent_to_forget_weights, recurrent_to_cell_weights,
829 recurrent_to_output_weights);
Mike Kelly58733582022-05-05 20:19:00 +0100830 }
831 else
832 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100833 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights,
834 input_to_output_weights, recurrent_to_forget_weights,
835 recurrent_to_cell_weights, recurrent_to_output_weights);
Mike Kelly58733582022-05-05 20:19:00 +0100836 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000837 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
838 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
839 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
840 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
841 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);
842
843 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
844 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
845 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
846 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);
847
848 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
849 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
850 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
851 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);
852
853 // Check whether peephole weights are all there or none
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100854 if (lstm_params.has_peephole_opt())
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000855 {
856 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100857 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1,
858 DataType::QSYMM16);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000859 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
860 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100861 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(),
862 lstm_params.cell_to_output_weights());
863 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(),
864 lstm_params.cell_to_output_weights());
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000865
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100866 if (!lstm_params.has_cifg_opt())
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000867 {
868 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100869 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(),
870 lstm_params.cell_to_input_weights());
871 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(),
872 lstm_params.cell_to_input_weights());
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000873 }
874 }
875
876 const UniformQuantizationInfo qinput = input->quantization_info().uniform();
877 const UniformQuantizationInfo qcell_state_in = cell_state_in->quantization_info().uniform();
878 const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
879
880 // Calculate and decompose effective scales for optimizing matmul calculation
881 const int32_t cell_shift = log2(qcell_state_in.scale);
882 ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
883
884 // Calculate quantized parameters for clipping.
885 int16_t quantized_cell_clip = 0;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100886 if (lstm_params.cell_clip() > 0.0f)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000887 {
888 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
889 }
890
891 // Precompute effective bias for optimizing the matmul computations.
892 const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100893 const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100894 if (!lstm_params.has_cifg_opt())
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000895 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100896 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(
897 lstm_params.input_to_input_weights(), &eff_bias_info,
898 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
899 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(
900 lstm_params.recurrent_to_input_weights(), &eff_bias_info,
901 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000902 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100903 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(
904 input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
905 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(
906 recurrent_to_forget_weights, &eff_bias_info,
907 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
908 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(
909 input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
910 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(
911 recurrent_to_cell_weights, &eff_bias_info,
912 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
913 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(
914 input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
915 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(
916 recurrent_to_output_weights, &eff_bias_info,
917 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
918 if (lstm_params.has_projection())
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000919 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100920 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(
921 lstm_params.projection_weights(), &projection_eff_bias_info,
922 GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true)));
923 if (lstm_params.projection_bias() != nullptr)
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100924 {
925 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100926 ARM_COMPUTE_RETURN_ON_ERROR(
927 NEArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info,
928 &projection_eff_bias_info, ConvertPolicy::SATURATE));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100929 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000930 }
931
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100932 const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_cell_weights->data_type(),
933 input_to_cell_weights->quantization_info());
934 const TensorInfo input_to_output_weights_transposed(TensorShape(num_units, input_size), 1,
935 input_to_output_weights->data_type(),
936 input_to_output_weights->quantization_info());
937 const TensorInfo recurrent_to_forget_weights_transposed(TensorShape(num_units, output_size), 1,
938 recurrent_to_forget_weights->data_type(),
939 recurrent_to_forget_weights->quantization_info());
940 const TensorInfo recurrent_to_cell_weights_transposed(TensorShape(num_units, output_size), 1,
941 recurrent_to_cell_weights->data_type(),
942 recurrent_to_cell_weights->quantization_info());
943 const TensorInfo recurrent_to_output_weights_transposed(TensorShape(num_units, output_size), 1,
944 recurrent_to_output_weights->data_type(),
945 recurrent_to_output_weights->quantization_info());
946 const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1,
947 recurrent_to_forget_weights->data_type(),
948 recurrent_to_forget_weights->quantization_info());
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000949
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000950 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_cell_weights, &input_weights_transposed));
Pablo Marquez Tellobfd10a12022-03-21 14:56:38 +0000951 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_to_output_weights_transposed));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100952 ARM_COMPUTE_RETURN_ON_ERROR(
953 NETranspose::validate(recurrent_to_forget_weights, &recurrent_to_forget_weights_transposed));
954 ARM_COMPUTE_RETURN_ON_ERROR(
955 NETranspose::validate(recurrent_to_cell_weights, &recurrent_to_cell_weights_transposed));
956 ARM_COMPUTE_RETURN_ON_ERROR(
957 NETranspose::validate(recurrent_to_output_weights, &recurrent_to_output_weights_transposed));
958 if (!lstm_params.has_cifg_opt())
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000959 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100960 const TensorInfo recurrent_to_input_weights_transposed(
961 TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(),
962 lstm_params.recurrent_to_input_weights()->quantization_info());
Pablo Marquez Tellobfd10a12022-03-21 14:56:38 +0000963 const TensorInfo input_to_input_weights_transposed(TensorShape(num_units, input_size), 1,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100964 lstm_params.input_to_input_weights()->data_type(),
965 lstm_params.input_to_input_weights()->quantization_info());
966 ARM_COMPUTE_RETURN_ON_ERROR(
967 NETranspose::validate(lstm_params.input_to_input_weights(), &input_to_input_weights_transposed));
968 ARM_COMPUTE_RETURN_ON_ERROR(
969 NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_to_input_weights_transposed));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000970 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100971 if (lstm_params.has_projection())
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000972 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100973 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1,
974 lstm_params.projection_weights()->data_type(),
975 lstm_params.projection_weights()->quantization_info());
976 ARM_COMPUTE_RETURN_ON_ERROR(
977 NETranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000978 }
979
980 GEMMLowpOutputStageInfo gemmlowp_info;
981 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
982 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
983 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
984 gemmlowp_info.output_data_type = DataType::QSYMM16;
985
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100986 const bool has_layer_norm = lstm_params.use_layer_norm();
987
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000988 // Forget gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100989 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_intermediate_scale() == 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100990 const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
991 QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000992 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100993 const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale /
994 lstm_params.forget_intermediate_scale();
995 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info,
996 input_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000997
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100998 const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale *
999 qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
1000 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed,
1001 &eff_bias_info, recurrent_to_forget_scale, &mm_out_info,
1002 &forget_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001003
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001004 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info,
1005 &forget_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001006
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001007 if (lstm_params.has_peephole_opt())
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001008 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001009 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1,
1010 DataType::QSYMM16);
1011 ARM_COMPUTE_RETURN_ON_ERROR(
1012 NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f,
1013 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
1014 const float cell_to_forget_scale = std::pow(2, cell_shift) *
1015 lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale /
1016 lstm_params.forget_intermediate_scale();
1017 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(
1018 cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
1019 ARM_COMPUTE_RETURN_ON_ERROR(
1020 NEGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
1021 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info,
1022 &forget_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001023 }
1024
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001025 if (has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001026 {
1027 const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
1028 const ITensorInfo *b_info = forget_gate_bias;
1029 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
1030 }
1031
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001032 // Output quantization info of Sigmoid and Tanh activations
1033 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001034 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001035
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001036 ARM_COMPUTE_RETURN_ON_ERROR(
1037 NEActivationLayer::validate(&forget_outstage_info, &forget_gate_info,
1038 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001039
1040 // Modulation gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +01001041 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_intermediate_scale() == 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001042 const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
1043 QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
1044 const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale /
1045 lstm_params.cell_intermediate_scale();
1046 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info,
1047 input_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001048
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001049 const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale *
1050 qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
1051 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed,
1052 &eff_bias_info, recurrent_to_cell_scale, &mm_out_info,
1053 &cell_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001054
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001055 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info,
1056 &cell_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001057
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001058 if (has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001059 {
1060 const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
1061 const ITensorInfo *b_info = cell_bias;
1062 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
1063 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001064 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001065
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001066 ARM_COMPUTE_RETURN_ON_ERROR(
1067 NEActivationLayer::validate(&cell_outstage_info, &cell_gate_info,
1068 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001069
1070 // Input gate.
1071 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001072 if (lstm_params.has_cifg_opt())
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001073 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001074 ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr,
1075 "Input gate bias must not be present when CIFG is used");
1076 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info,
1077 &forget_gate_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001078 }
1079 else
1080 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001081 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
1082 lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
Mike Kelly58733582022-05-05 20:19:00 +01001083
1084 // If the input_to_forget_weights data type is DataType::QSYMM8 then it can never match the other weights as they are all DataType::QASYMM8_SIGNED
1085 if (input_to_forget_weights->data_type() == DataType::QSYMM8)
1086 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001087 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.input_to_input_weights(),
1088 lstm_params.recurrent_to_input_weights());
Mike Kelly58733582022-05-05 20:19:00 +01001089 }
1090 else
1091 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001092 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights,
1093 lstm_params.input_to_input_weights(),
1094 lstm_params.recurrent_to_input_weights());
Mike Kelly58733582022-05-05 20:19:00 +01001095 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001096 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001097 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights,
1098 lstm_params.recurrent_to_input_weights());
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001099 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
1100 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
1101
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +01001102 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_intermediate_scale() == 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001103 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
1104 QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
1105 const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale *
1106 qinput.scale / lstm_params.input_intermediate_scale();
1107 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info,
1108 input_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001109
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001110 const float recurrent_to_input_scale =
1111 lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale /
1112 lstm_params.input_intermediate_scale();
1113 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed,
1114 &eff_bias_info, recurrent_to_input_scale, &mm_out_info,
1115 &input_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001116
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001117 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info,
1118 &input_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001119
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001120 if (lstm_params.has_peephole_opt())
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001121 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001122 ARM_COMPUTE_RETURN_ON_ERROR(
1123 NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info,
1124 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
1125 const float cell_to_input_scale = std::pow(2, cell_shift) *
1126 lstm_params.cell_to_input_weights()->quantization_info().uniform().scale /
1127 lstm_params.input_intermediate_scale();
1128 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(
1129 cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
1130 ARM_COMPUTE_RETURN_ON_ERROR(
1131 NEGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
1132 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info,
1133 &input_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001134 }
1135
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001136 if (has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001137 {
1138 const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
1139 const ITensorInfo *b_info = lstm_params.input_gate_bias();
1140 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
1141 }
1142
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001143 ARM_COMPUTE_RETURN_ON_ERROR(
1144 NEActivationLayer::validate(&input_outstage_info, &input_gate_info,
1145 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001146 }
1147 // Cell.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001148 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(
1149 &forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
1150 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(
1151 &input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
1152 ARM_COMPUTE_RETURN_ON_ERROR(
1153 NEArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
1154 if (quantized_cell_clip > 0)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001155 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001156 ARM_COMPUTE_RETURN_ON_ERROR(
1157 NEActivationLayer::validate(cell_state_out, nullptr,
1158 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1159 -quantized_cell_clip, quantized_cell_clip)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001160 }
1161 // Output gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +01001162 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_intermediate_scale() == 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001163 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
1164 QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
1165 const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale /
1166 lstm_params.output_intermediate_scale();
1167 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info,
1168 input_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001169
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001170 const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale *
1171 qoutput_state_in.scale / lstm_params.output_intermediate_scale();
1172 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed,
1173 &eff_bias_info, recurrent_to_output_scale, &mm_out_info,
1174 &output_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001175
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001176 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info,
1177 &output_outstage_info, ConvertPolicy::SATURATE));
1178 if (lstm_params.has_peephole_opt())
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001179 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001180 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1,
1181 DataType::QSYMM16);
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001182 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001183 // Here we are not using the output stage because all operations are done in float
1184 // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
1185 // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001186 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(
1187 cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
1188 RoundingPolicy::TO_ZERO));
1189 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info,
1190 &output_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001191 }
1192
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001193 if (has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001194 {
1195 const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
1196 const ITensorInfo *b_info = output_gate_bias;
1197 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
1198 }
1199
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001200 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001201 ARM_COMPUTE_RETURN_ON_ERROR(
1202 NEActivationLayer::validate(&output_outstage_info, &output_gate_info,
1203 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001204
1205 // Hidden.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001206 ARM_COMPUTE_RETURN_ON_ERROR(
1207 NEActivationLayer::validate(cell_state_out, &input_gate_info,
1208 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001209 const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001210 const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001211 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(
1212 &output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +01001213
1214 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.hidden_state_scale() == 0);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001215 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001216 ARM_COMPUTE_RETURN_ON_ERROR(
1217 quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier,
1218 &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
Sang-Hoon Park9f893752020-10-20 15:33:31 +01001219 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
1220 gemmlowp_info.output_data_type = hidden_out_info.data_type();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001221 ARM_COMPUTE_RETURN_ON_ERROR(
1222 NEGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001223
1224 const bool projection_tensor_copy_required = num_units != output_size;
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001225
1226 // Projection.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001227 if (lstm_params.has_projection())
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001228 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001229 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights,
1230 lstm_params.projection_weights());
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +01001231 ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001232
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001233 const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
1234 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
1235 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(
1236 projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001237 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
1238 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
1239 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
1240 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
1241
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001242 const TensorInfo projection_outstage_info(*output_state_out);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001243 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1,
1244 lstm_params.projection_weights()->data_type(),
1245 lstm_params.projection_weights()->quantization_info());
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001246
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001247 TensorInfo projection_mm_out_info{mm_out_info};
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001248 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001249
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001250 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed,
1251 &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001252 &projection_outstage_info));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001253
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001254 if (projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001255 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001256 ARM_COMPUTE_RETURN_ON_ERROR(
1257 NEQLSTMLayer::TensorCopyKernel::validate(*output_state_in, projection_outstage_info));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001258 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001259
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001260 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(output_state_out, output_state_out, output_state_out,
1261 ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001262
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001263 if (projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001264 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001265 ARM_COMPUTE_RETURN_ON_ERROR(
1266 NEQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001267 }
1268
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001269 int8_t quantized_projection_clip{0};
1270 if (lstm_params.projection_clip() > 0.0f)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001271 {
1272 quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
1273 }
1274
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001275 if (quantized_projection_clip > 0)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001276 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001277 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(
1278 output_state_out, nullptr,
1279 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1280 -quantized_projection_clip, quantized_projection_clip)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001281 }
1282 }
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001283 else
1284 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001285 if (projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001286 {
1287 ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
1288 }
1289 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001290
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001291 if (cell_state_out->total_size() > 0)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001292 {
1293 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
1294 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
1295 }
1296
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001297 if (output_state_out->total_size() > 0)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001298 {
1299 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
1300 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
1301 }
1302
Michalis Spyrouebcebf12020-10-21 00:04:14 +01001303 ARM_COMPUTE_RETURN_ON_ERROR(NECopy::validate(output_state_out, output));
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001304 return Status{};
1305}
1306
1307void NEQLSTMLayer::run()
1308{
1309 prepare();
1310
1311 // Acquire all the temporaries
1312 MemoryGroupResourceScope scope_mg(_memory_group);
1313
1314 // Forget gate.
1315 _mm_input_to_forget.run();
1316 _input_to_forget_outstage.run();
1317
1318 _mm_recurrent_to_forget.run();
1319 _recurrent_to_forget_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001320 _accumulate_input_recurrent_forget.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001321
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001322 if (_has_peephole)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001323 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001324 _pixelwise_mul_cell_to_forget.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001325 _cell_to_forget_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001326 _accumulate_cell_forget.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001327 }
1328
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001329 if (_has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001330 {
Michalis Spyrouebcebf12020-10-21 00:04:14 +01001331 NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Forget).get(), Window::DimY);
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001332 }
1333
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001334 _forget_gate_sigmoid.run();
1335
1336 // Modulation gate.
1337 _mm_input_to_cell.run();
1338 _input_to_cell_outstage.run();
1339
1340 _mm_recurrent_to_cell.run();
1341 _recurrent_to_cell_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001342 _accumulate_input_recurrent_modulation.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001343
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001344 if (_has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001345 {
Michalis Spyrouebcebf12020-10-21 00:04:14 +01001346 NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Cell).get(), Window::DimY);
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001347 }
1348
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001349 _cell_gate_tanh.run();
1350
1351 // Input gate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001352 if (_has_cifg)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001353 {
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001354 _input_gate_sub.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001355 }
1356 else
1357 {
1358 _mm_input_to_input.run();
1359 _input_to_input_outstage.run();
1360 _mm_recurrent_to_input.run();
1361 _recurrent_to_input_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001362 _accumulate_input_recurrent_input.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001363
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001364 if (_has_peephole)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001365 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001366 _pixelwise_mul_cell_to_input.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001367 _cell_to_input_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001368 _accumulate_cell_input.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001369 }
1370
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001371 if (_has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001372 {
Michalis Spyrouebcebf12020-10-21 00:04:14 +01001373 NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Input).get(), Window::DimY);
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001374 }
1375
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001376 _input_gate_sigmoid.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001377 }
1378
1379 // Cell.
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001380 _pixelwise_mul_forget_cell.run();
1381 _pixelwise_mul_input_cell.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001382 _add_forget_cell.run();
1383
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001384 if (_has_cell_clipping)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001385 {
1386 _cell_clip.run();
1387 }
1388
1389 // Output gate.
1390 _mm_input_to_output.run();
1391 _input_to_output_outstage.run();
1392 _mm_recurrent_to_output.run();
1393 _recurrent_to_output_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001394 _accumulate_input_recurrent_output.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001395 if (_has_peephole)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001396 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001397 _pixelwise_mul_cell_to_output.run();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001398 _cell_to_output_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001399 _accumulate_cell_to_output.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001400 }
1401
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001402 if (_has_layer_norm)
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001403 {
Michalis Spyrouebcebf12020-10-21 00:04:14 +01001404 NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Output).get(), Window::DimY);
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001405 }
1406
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001407 _output_gate_sigmoid.run();
1408
1409 // Hidden.
1410 _hidden_tanh.run();
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001411 _pixelwise_mul_hidden.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001412 _hidden_outstage.run();
1413
1414 // Projection.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001415 if (_has_projection)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001416 {
1417 _mm_projection.run();
1418 _projection_outstage.run();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001419
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001420 if (_projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001421 {
1422 _projection_output_to_accumulate_copy.run();
1423 }
1424
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001425 _accumulate_projection.run();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001426
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001427 if (_projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001428 {
1429 _projection_accumulate_to_output_copy.run();
1430 }
1431
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001432 if (_has_projection_clipping)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001433 {
1434 _projection_clip.run();
1435 }
1436 }
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001437 else
1438 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001439 if (_projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001440 {
1441 _hidden_to_output_copy.run();
1442 }
1443 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +01001444
1445 // Copy output_state_out to output
Michalis Spyrouebcebf12020-10-21 00:04:14 +01001446 _copy_output.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001447}
1448
1449void NEQLSTMLayer::prepare()
1450{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001451 if (!_is_prepared)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001452 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001453 if (_convert_input_to_forget_weights_to_qsymm8)
Pablo Marquez Tello1bae0f92022-04-20 18:19:19 +01001454 {
1455 _input_to_forget_weights_f32.allocator()->allocate();
1456 _input_to_forget_weights_symm8.allocator()->allocate();
1457 _dequantize_input_to_forget_weights.run();
1458 _quantize_input_to_forget_weights.run();
1459 }
1460
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001461 // Pre-transpose weights to be used in GEMM.
1462 _input_to_forget_weights_transposed.allocator()->allocate();
1463 _input_to_cell_weights_transposed.allocator()->allocate();
1464 _input_to_output_weights_transposed.allocator()->allocate();
1465 _recurrent_to_forget_weights_transposed.allocator()->allocate();
1466 _recurrent_to_cell_weights_transposed.allocator()->allocate();
1467 _recurrent_to_output_weights_transposed.allocator()->allocate();
1468 _transpose_input_to_forget_weights.run();
1469 _transpose_input_to_cell_weights.run();
1470 _transpose_input_to_output_weights.run();
1471 _transpose_recurrent_to_forget_weights.run();
1472 _transpose_recurrent_to_cell_weights.run();
1473 _transpose_recurrent_to_output_weights.run();
1474
1475 // Precompute effective biases
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001476 if (_has_cifg)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001477 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001478 std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()),
1479 _ones.info()->total_size() / _ones.info()->element_size(), 32767);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001480 }
1481 else
1482 {
1483 _input_to_input_eff_bias.allocator()->allocate();
1484 _recurrent_to_input_eff_bias.allocator()->allocate();
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001485
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001486 ITensorPack packII = {{TensorType::ACL_SRC, _input_to_input_weights},
1487 {TensorType::ACL_DST, &_input_to_input_eff_bias}};
1488 NEScheduler::get().schedule_op(_input_to_input_reduction.get(), Window::DimY,
1489 _input_to_input_reduction->window(), packII);
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001490
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001491 ITensorPack packRI = {{TensorType::ACL_SRC, _recurrent_to_input_weights},
1492 {TensorType::ACL_DST, &_recurrent_to_input_eff_bias}};
1493 NEScheduler::get().schedule_op(_recurrent_to_input_reduction.get(), Window::DimY,
1494 _recurrent_to_input_reduction->window(), packRI);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001495
1496 _input_to_input_weights_transposed.allocator()->allocate();
1497 _recurrent_to_input_weights_transposed.allocator()->allocate();
1498 _transpose_input_to_input_weights.run();
1499 _transpose_recurrent_to_input_weights.run();
1500 _input_to_input_weights->mark_as_unused();
1501 _recurrent_to_input_weights->mark_as_unused();
1502 }
1503 _input_to_forget_eff_bias.allocator()->allocate();
1504 _recurrent_to_forget_eff_bias.allocator()->allocate();
1505 _input_to_cell_eff_bias.allocator()->allocate();
1506 _recurrent_to_cell_eff_bias.allocator()->allocate();
1507 _input_to_output_eff_bias.allocator()->allocate();
1508 _recurrent_to_output_eff_bias.allocator()->allocate();
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001509
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001510 ITensorPack packIF = {{TensorType::ACL_SRC, _input_to_forget_weights},
1511 {TensorType::ACL_DST, &_input_to_forget_eff_bias}};
1512 NEScheduler::get().schedule_op(_input_to_forget_reduction.get(), Window::DimY,
1513 _input_to_forget_reduction->window(), packIF);
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001514
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001515 ITensorPack packRF = {{TensorType::ACL_SRC, _recurrent_to_forget_weights},
1516 {TensorType::ACL_DST, &_recurrent_to_forget_eff_bias}};
1517 NEScheduler::get().schedule_op(_recurrent_to_forget_reduction.get(), Window::DimY,
1518 _recurrent_to_forget_reduction->window(), packRF);
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001519
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001520 ITensorPack packIC = {{TensorType::ACL_SRC, _input_to_cell_weights},
1521 {TensorType::ACL_DST, &_input_to_cell_eff_bias}};
1522 NEScheduler::get().schedule_op(_input_to_cell_reduction.get(), Window::DimY, _input_to_cell_reduction->window(),
1523 packIC);
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001524
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001525 ITensorPack packRC = {{TensorType::ACL_SRC, _recurrent_to_cell_weights},
1526 {TensorType::ACL_DST, &_recurrent_to_cell_eff_bias}};
1527 NEScheduler::get().schedule_op(_recurrent_to_cell_reduction.get(), Window::DimY,
1528 _recurrent_to_cell_reduction->window(), packRC);
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001529
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001530 ITensorPack packIO = {{TensorType::ACL_SRC, _input_to_output_weights},
1531 {TensorType::ACL_DST, &_input_to_output_eff_bias}};
1532 NEScheduler::get().schedule_op(_input_to_output_reduction.get(), Window::DimY,
1533 _input_to_output_reduction->window(), packIO);
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001534
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001535 ITensorPack packRO = {{TensorType::ACL_SRC, _recurrent_to_output_weights},
1536 {TensorType::ACL_DST, &_recurrent_to_output_eff_bias}};
1537 NEScheduler::get().schedule_op(_recurrent_to_output_reduction.get(), Window::DimY,
1538 _recurrent_to_output_reduction->window(), packRO);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001539
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001540 if (_has_projection)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001541 {
Michele Di Giorgio11c562c2020-06-10 16:34:50 +01001542 _projection_eff_bias.allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001543 ITensorPack pack = {{TensorType::ACL_SRC, _projection_weights},
1544 {TensorType::ACL_DST, &_projection_eff_bias}};
1545 NEScheduler::get().schedule_op(_projection_reduction.get(), Window::DimY, _projection_reduction->window(),
1546 pack);
1547 if (_projection_bias != nullptr)
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001548 {
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001549 _projection_bias_add.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001550 _projection_bias->mark_as_unused();
1551 }
1552
1553 _projection_weights_transposed.allocator()->allocate();
1554 _transpose_projection_weights.run();
1555 _projection_weights->mark_as_unused();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001556
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001557 if (!_projection_tensor_copy_required)
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001558 {
1559 _hidden_gate.mark_as_unused();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001560 _projection_accumulate_res.mark_as_unused();
1561 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001562 }
1563
1564 // Mark weights as unused
1565 _input_to_forget_weights->mark_as_unused();
1566 _input_to_cell_weights->mark_as_unused();
1567 _input_to_output_weights->mark_as_unused();
1568 _recurrent_to_forget_weights->mark_as_unused();
1569 _recurrent_to_cell_weights->mark_as_unused();
1570 _recurrent_to_output_weights->mark_as_unused();
1571
1572 _is_prepared = true;
1573 }
1574}
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001575} // namespace arm_compute