blob: 946791a1045019ee6fb63d11d25c924738af698e [file] [log] [blame]
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001/*
Michele Di Giorgio93b75e02021-06-21 12:00:43 +01002 * Copyright (c) 2020-2021 Arm Limited.
Michele Di Giorgio47a89902020-03-09 19:32:33 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"
25
Manuel Bottinicfac51c2021-06-18 15:47:28 +010026#include "arm_compute/core/ITensorPack.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000027#include "arm_compute/core/KernelDescriptors.h"
28#include "arm_compute/core/QuantizationInfo.h"
29#include "arm_compute/core/Utils.h"
30#include "arm_compute/core/Validate.h"
31#include "arm_compute/core/utils/misc/InfoHelpers.h"
32#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
33#include "arm_compute/runtime/NEON/NEScheduler.h"
Michalis Spyrouebcebf12020-10-21 00:04:14 +010034#include "src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
Manuel Bottinicfac51c2021-06-18 15:47:28 +010035#include "src/core/cpu/kernels/CpuGemmLowpMatrixReductionKernel.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010036#include "src/core/helpers/WindowHelpers.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000037
38namespace arm_compute
39{
40using namespace arm_compute::utils::info_helpers;
41namespace
42{
43Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
44 float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
45{
46 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
47 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
48 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
49 return Status{};
50}
51} // namespace
52
Michalis Spyrouebcebf12020-10-21 00:04:14 +010053Status NEQLSTMLayer::validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
54{
55 // Output quantization scale will be different, but ignored here
56 // since it will be configured at configure() stage.
57 const TensorInfo out
58 {
59 in
60 };
61 return NEQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
62}
63
64void NEQLSTMLayer::configure_layer_norm(NEQLSTMLayer::LayerNormGate g, const ITensor *in)
65{
66 ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
67
68 Tensor &out = get_layer_norm_output(g);
69 _memory_group.manage(&out);
70 out.allocator()->init(*(in->info()));
71
Georgios Pinitas40f51a62020-11-21 03:04:18 +000072 get_layer_norm(g) = std::make_unique<NEQLSTMLayerNormalizationKernel>();
Michalis Spyrouebcebf12020-10-21 00:04:14 +010073 get_layer_norm(g)->configure(in, &out, get_layer_norm_weight(g), get_layer_norm_bias(g));
74}
75
76NEQLSTMLayer::TensorCopyKernel::~TensorCopyKernel() = default;
77
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +010078Status NEQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
79{
80 ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
81 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
82 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
83 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
84 return Status{};
85}
86
87void NEQLSTMLayer::TensorCopyKernel::configure(ITensor &src, ITensor &dst)
88{
89 ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
90 _src = &src;
91 _dst = &dst;
92 _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
93 _window = calculate_max_window(*_src->info(), Steps());
94}
95
96void NEQLSTMLayer::TensorCopyKernel::run()
97{
98 Iterator input_iter{ _src, _window };
99 Iterator output_iter{ _dst, _window };
100
101 execute_window_loop(_window, [&](const Coordinates &)
102 {
103 memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
104 },
105 input_iter, output_iter);
106}
107
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100108NEQLSTMLayer::~NEQLSTMLayer() = default;
109
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000110NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100111 : _memory_group(), _transpose_input_to_forget_weights(), _transpose_input_to_cell_weights(), _transpose_input_to_output_weights(), _transpose_input_to_input_weights(),
112 _transpose_recurrent_to_forget_weights(), _transpose_recurrent_to_cell_weights(), _transpose_recurrent_to_output_weights(), _transpose_recurrent_to_input_weights(), _transpose_projection_weights(),
113 _input_to_input_reduction(), _recurrent_to_input_reduction(), _input_to_forget_reduction(), _recurrent_to_forget_reduction(), _input_to_cell_reduction(), _recurrent_to_cell_reduction(),
114 _input_to_output_reduction(), _recurrent_to_output_reduction(), _projection_reduction(), _projection_bias_add(), _mm_input_to_forget(), _mm_recurrent_to_forget(), _pixelwise_mul_cell_to_forget(),
115 _input_to_forget_outstage(), _recurrent_to_forget_outstage(), _cell_to_forget_outstage(), _accumulate_input_recurrent_forget(), _accumulate_cell_forget(), _forget_gate_sigmoid(), _mm_input_to_cell(),
116 _input_to_cell_outstage(), _mm_recurrent_to_cell(), _recurrent_to_cell_outstage(), _accumulate_input_recurrent_modulation(), _cell_gate_tanh(), _input_gate_sub(), _mm_input_to_input(),
117 _input_to_input_outstage(), _mm_recurrent_to_input(), _recurrent_to_input_outstage(), _accumulate_input_recurrent_input(), _pixelwise_mul_cell_to_input(), _cell_to_input_outstage(),
118 _accumulate_cell_input(), _input_gate_sigmoid(), _pixelwise_mul_forget_cell(), _pixelwise_mul_input_cell(), _add_forget_cell(), _cell_clip(), _mm_input_to_output(), _input_to_output_outstage(),
119 _mm_recurrent_to_output(), _recurrent_to_output_outstage(), _accumulate_input_recurrent_output(), _pixelwise_mul_cell_to_output(), _cell_to_output_outstage(), _accumulate_cell_to_output(),
120 _output_gate_sigmoid(), _hidden_tanh(), _pixelwise_mul_hidden(), _hidden_outstage(), _mm_projection(), _projection_outstage(), _accumulate_projection(), _projection_clip(), _projection_bias_copy(),
121 _projection_output_to_accumulate_copy(), _projection_accumulate_to_output_copy(), _hidden_to_output_copy(), _layer_norms(), _copy_output(), _layer_norm_weights(), _layer_norm_bias(),
122 _layer_norm_output()
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000123{
124 _memory_group = MemoryGroup(std::move(memory_manager));
125}
126
127void NEQLSTMLayer::configure_mm(NEGEMMLowpMatrixMultiplyCore &mm, NEGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
128 const ITensor *mm_input, const ITensor *mm_weights, const ITensor *bias,
129 Tensor *mm_res, Tensor *outstage_res, float gemmlowp_scale,
130 const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
131{
132 _memory_group.manage(mm_res);
133 _memory_group.manage(outstage_res);
134
135 mm_res->allocator()->init(mm_res_info);
136 outstage_res->allocator()->init(outstage_tensor_info);
137
138 // Configure matrix-multiplication
139 mm.configure(mm_input, mm_weights, nullptr, mm_res);
140
141 // Configure output stage
142 quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
143 outstage.configure(mm_res, bias, outstage_res, gemmlowp_info);
144 mm_res->allocator()->allocate();
145}
146
147void NEQLSTMLayer::configure(const ITensor *input,
148 const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
149 const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
150 const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100151 const ITensor *cell_state_in, ITensor *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100152 ITensor *cell_state_out, ITensor *output_state_out, ITensor *output,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000153 const LSTMParams<ITensor> &lstm_params)
154{
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000155 ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
156 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
157 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
158
159 // Set lstm parameters
160 LSTMParams<ITensorInfo> lstm_params_info{};
161 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
162
163 // Validate
164 ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
165 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
166 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100167 cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
168 lstm_params_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000169
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100170 const int batch_size = input->info()->dimension(1);
171 const int num_units = input_to_output_weights->info()->dimension(1);
172 const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000173
174 const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
175 const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
176 const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
177
178 _projection_bias = lstm_params.projection_bias();
179 _input_to_forget_weights = input_to_forget_weights;
180 _input_to_cell_weights = input_to_cell_weights;
181 _input_to_output_weights = input_to_output_weights;
182 _recurrent_to_forget_weights = recurrent_to_forget_weights;
183 _recurrent_to_cell_weights = recurrent_to_cell_weights;
184 _recurrent_to_output_weights = recurrent_to_output_weights;
185 _projection_weights = lstm_params.projection_weights();
186
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100187 // Layer normalization
188 _has_layer_norm = lstm_params.use_layer_norm();
189 if(_has_layer_norm)
190 {
191 set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
192 set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
193 set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
194 set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
195
196 set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
197 set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
198 set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
199 set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
200 }
201
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000202 _has_cifg = lstm_params.has_cifg_opt();
203 _has_projection = lstm_params.has_projection();
204 _has_peephole = lstm_params.has_peephole_opt();
205
206 // Calculate and decompose effective scales for optimizing matmul calculation
207 const int32_t cell_shift = log2(qcell_state_in.scale);
208
209 // Calculate quantized parameters for clipping.
210 int16_t quantized_cell_clip = 0;
211 if(lstm_params.cell_clip() > 0.0f)
212 {
213 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
214 }
215 _has_cell_clipping = quantized_cell_clip > 0;
216
217 // Precompute effective bias for optimizing the matmul computations.
218 if(!_has_cifg)
219 {
220 _input_to_input_weights = lstm_params.input_to_input_weights();
221 _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
222
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100223 _input_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
224 _recurrent_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
225 _input_to_input_reduction->configure(_input_to_input_weights->info(), _input_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
226 _recurrent_to_input_reduction->configure(_recurrent_to_input_weights->info(), _recurrent_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000227 }
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100228
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100229 _input_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
230 _recurrent_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
231 _input_to_cell_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
232 _recurrent_to_cell_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
233 _input_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
234 _recurrent_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100235
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100236 _input_to_forget_reduction->configure(input_to_forget_weights->info(), _input_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
237 _recurrent_to_forget_reduction->configure(recurrent_to_forget_weights->info(), _recurrent_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
238 _input_to_cell_reduction->configure(input_to_cell_weights->info(), _input_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
239 _recurrent_to_cell_reduction->configure(recurrent_to_cell_weights->info(), _recurrent_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
240 _input_to_output_reduction->configure(input_to_output_weights->info(), _input_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
241 _recurrent_to_output_reduction->configure(recurrent_to_output_weights->info(), _recurrent_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100242 if(_has_projection)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000243 {
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100244 _projection_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
245 _projection_reduction->configure(_projection_weights->info(), _projection_eff_bias.info(), GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100246 if(_projection_bias != nullptr)
247 {
Michele Di Giorgio19023832020-06-17 16:08:10 +0000248 _projection_bias_add.configure(_projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100249 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000250 }
251
252 // Pre-transpose weights to be used in GEMM.
253 _transpose_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_transposed);
254 _transpose_input_to_cell_weights.configure(input_to_cell_weights, &_input_to_cell_weights_transposed);
255 _transpose_input_to_output_weights.configure(input_to_output_weights, &_input_to_output_weights_transposed);
256 _transpose_recurrent_to_forget_weights.configure(recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
257 _transpose_recurrent_to_cell_weights.configure(recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
258 _transpose_recurrent_to_output_weights.configure(recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
259 if(!_has_cifg)
260 {
261 _transpose_input_to_input_weights.configure(lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
262 _transpose_recurrent_to_input_weights.configure(lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
263 }
264 if(_has_projection)
265 {
266 _transpose_projection_weights.configure(_projection_weights, &_projection_weights_transposed);
267 }
268
269 GEMMLowpOutputStageInfo gemmlowp_info;
270 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
271 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
272 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
273 gemmlowp_info.output_data_type = DataType::QSYMM16;
274
275 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
276 // Forget gate.
277 const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
278 const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
279 configure_mm(_mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
280 input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
281 &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
282 mm_out_info, forget_gate_outstage_info);
283
284 const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
285 configure_mm(_mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
286 output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
287 &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
288 mm_out_info, forget_gate_outstage_info);
289
290 _accumulate_input_recurrent_forget.configure(&_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
291 _input_to_forget_outstage_res.allocator()->allocate();
292
293 if(_has_peephole)
294 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100295 _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000296 _memory_group.manage(&_mul_cell_to_forget_res);
297 _pixelwise_mul_cell_to_forget.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
298 _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
299 _memory_group.manage(&_cell_to_forget_outstage_res);
300 const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
301 quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
302 _cell_to_forget_outstage.configure(&_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
303 _mul_cell_to_forget_res.allocator()->allocate();
304 _accumulate_cell_forget.configure(&_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
305 _cell_to_forget_outstage_res.allocator()->allocate();
306 }
307
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100308 Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
309
310 if(_has_layer_norm)
311 {
312 configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
313 forget_activation_input->allocator()->allocate();
314 forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
315 }
316
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000317 // Output quantization info of Sigmoid and Tanh activations
318 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100319 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000320
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000321 _memory_group.manage(&_forget_gate);
322 _forget_gate.allocator()->init(forget_gate_info);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100323 _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
324 forget_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000325
326 // Modulation gate.
327 const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
328 const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
329 configure_mm(_mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
330 input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
331 &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
332 mm_out_info, cell_outstage_info);
333
334 const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
335 configure_mm(_mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
336 output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
337 &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
338 mm_out_info, cell_outstage_info);
339
340 _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
341 _input_to_cell_outstage_res.allocator()->allocate();
342
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100343 Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
344
345 if(_has_layer_norm)
346 {
347 configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
348 cell_activation_input->allocator()->allocate();
349 cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
350 }
351
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000352 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100353
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000354 _memory_group.manage(&_cell_gate);
355 _cell_gate.allocator()->init(cell_gate_info);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100356 _cell_gate_tanh.configure(cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
357 cell_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000358
359 // Input gate.
360 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
361 _input_gate.allocator()->init(input_gate_info);
362 _memory_group.manage(&_input_gate);
363 if(_has_cifg)
364 {
365 _ones.allocator()->init(*_forget_gate.info());
366 _input_gate_sub.configure(&_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
367 _ones.allocator()->allocate();
368 }
369 else
370 {
371 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
372 const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
373 configure_mm(_mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
374 input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
375 &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
376 mm_out_info, input_outstage_info);
377
378 const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
379 configure_mm(_mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100380 output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000381 &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
382 mm_out_info, input_outstage_info);
383 _accumulate_input_recurrent_input.configure(&_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
384 _input_to_input_outstage_res.allocator()->allocate();
385
386 if(_has_peephole)
387 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100388 _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000389 _memory_group.manage(&_mul_cell_to_input_res);
390 _pixelwise_mul_cell_to_input.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
391 const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
392 quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
393 _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
394 _memory_group.manage(&_cell_to_input_outstage_res);
395 _cell_to_input_outstage.configure(&_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
396 _mul_cell_to_input_res.allocator()->allocate();
397 _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
398 _cell_to_input_outstage_res.allocator()->allocate();
399 }
400
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100401 Tensor *input_activation_input = &_recurrent_to_input_outstage_res;
402
403 if(_has_layer_norm)
404 {
405 configure_layer_norm(LayerNormGate::Input, input_activation_input);
406 input_activation_input->allocator()->allocate();
407 input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
408 }
409
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100410 _input_gate_sigmoid.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100411 input_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000412 }
413 // Cell.
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100414 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000415 _pixelwise_mul_forget_cell.configure(&_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
416 const float cell_gate_scale = _cell_gate.info()->quantization_info().uniform().scale;
417 const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
418 const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
419 _memory_group.manage(&_mul_input_cell_res);
420 _mul_input_cell_res.allocator()->init(mul_input_cell_info);
421 _pixelwise_mul_input_cell.configure(&_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
422 _cell_gate.allocator()->allocate();
423 _add_forget_cell.configure(&_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
424 _mul_input_cell_res.allocator()->allocate();
425 _forget_gate.allocator()->allocate();
426 if(_has_cell_clipping)
427 {
428 _cell_clip.configure(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
429 }
430 // Output gate.
431 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
432 const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
433 configure_mm(_mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
434 input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
435 &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
436 mm_out_info, output_outstage_info);
437
438 const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
439 configure_mm(_mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
440 output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
441 &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
442 mm_out_info, output_outstage_info);
443
444 _accumulate_input_recurrent_output.configure(&_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
445 _input_to_output_outstage_res.allocator()->allocate();
446
447 if(_has_peephole)
448 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100449 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000450 // Here we are not using the output stage because all operations are done in float
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100451 _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000452 _memory_group.manage(&_mul_cell_to_output_res);
453 _pixelwise_mul_cell_to_output.configure(cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100454
455 const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
456 quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
457 _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
458 _memory_group.manage(&_cell_to_output_outstage_res);
459 _cell_to_output_outstage.configure(&_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000460 _mul_cell_to_output_res.allocator()->allocate();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100461
462 _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
463 _cell_to_output_outstage_res.allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000464 }
465
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100466 Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
467
468 if(_has_layer_norm)
469 {
470 configure_layer_norm(LayerNormGate::Output, output_activation_input);
471 output_activation_input->allocator()->allocate();
472 output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
473 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000474 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100475
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000476 _memory_group.manage(&_output_gate);
477 _output_gate.allocator()->init(output_gate_info);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100478 _output_gate_sigmoid.configure(output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
479 output_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000480
481 // Hidden.
482 _hidden_tanh.configure(cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100483 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000484 _memory_group.manage(&_hidden_mul_res);
485 const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
486 _hidden_mul_res.allocator()->init(hidden_mul_res);
487 _pixelwise_mul_hidden.configure(&_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
488 _output_gate.allocator()->allocate();
489 _input_gate.allocator()->allocate();
490 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
Sang-Hoon Park30b46a62020-04-18 01:40:57 +0100491 quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000492 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
493 gemmlowp_info.output_data_type = output_state_in->info()->data_type();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100494
495 _projection_tensor_copy_required = (num_units != output_size);
496 ITensor *hidden_gate_result = output_state_out;
497
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100498 _memory_group.manage(&_hidden_gate);
499
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100500 if(_projection_tensor_copy_required)
501 {
502 _hidden_gate.allocator()->init(*output_state_out->info());
503 _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
504 hidden_gate_result = &_hidden_gate;
505 }
506
507 _hidden_outstage.configure(&_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000508 _hidden_mul_res.allocator()->allocate();
509
510 // Projection.
511 if(_has_projection)
512 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100513 const TensorInfo projection_outstage_info(*output_state_out->info());
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000514 const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
515 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
516 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
517 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
518 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
519 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
520
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100521 TensorInfo projection_mm_out_info{ mm_out_info };
522 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100523
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000524 configure_mm(_mm_projection, _projection_outstage, gemmlowp_info,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100525 hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000526 &_mm_projection_res, &_projection_outstage_res, projection_scale,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100527 projection_mm_out_info, projection_outstage_info);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100528
529 ITensor *accumulate_destination = output_state_out;
530
531 if(_projection_tensor_copy_required)
532 {
533 _hidden_gate.allocator()->allocate();
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100534 _projection_accumulate_res.allocator()->init(*output_state_in->info());
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100535 _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100536 _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100537 accumulate_destination = &_projection_accumulate_res;
538 }
539
540 _accumulate_projection.configure(&_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000541 _projection_outstage_res.allocator()->allocate();
542
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100543 if(_projection_tensor_copy_required)
544 {
545 _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
546 _projection_accumulate_res.allocator()->allocate();
547 }
548
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000549 int8_t quantized_projection_clip{ 0 };
550 if(lstm_params.projection_clip() > 0.0f)
551 {
552 quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
553 }
554
555 if(quantized_projection_clip > 0)
556 {
557 _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip, quantized_projection_clip));
558 _has_projection_clipping = true;
559 }
560 }
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100561 else
562 {
563 if(_projection_tensor_copy_required)
564 {
565 _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
566 _hidden_gate.allocator()->allocate();
567 }
568 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100569
570 // Copy output_state_out to output
571 _copy_output.configure(output_state_out, output);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000572}
573
574Status NEQLSTMLayer::validate(const ITensorInfo *input,
575 const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
576 const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
577 const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
578 const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100579 const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000580 const LSTMParams<ITensorInfo> &lstm_params)
581{
582 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100583 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
584 cell_state_out, output_state_out, output);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000585
586 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
587 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
588
589 const unsigned int input_size = input->dimension(0);
590 const unsigned int batch_size = input->dimension(1);
591 const unsigned int num_units = input_to_output_weights->dimension(1);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100592 const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000593
594 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
595 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
596 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights, input_to_cell_weights);
597 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
598 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
599 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
600 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QSYMM8);
601 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
602 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
603
604 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
605 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
606 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
607 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
608 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);
609
610 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
611 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
612 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
613 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);
614
615 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
616 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
617 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
618 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);
619
620 // Check whether peephole weights are all there or none
621 if(lstm_params.has_peephole_opt())
622 {
623 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
624 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
625 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
626 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
627 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
628 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
629
630 if(!lstm_params.has_cifg_opt())
631 {
632 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
633 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
634 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
635 }
636 }
637
638 const UniformQuantizationInfo qinput = input->quantization_info().uniform();
639 const UniformQuantizationInfo qcell_state_in = cell_state_in->quantization_info().uniform();
640 const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
641
642 // Calculate and decompose effective scales for optimizing matmul calculation
643 const int32_t cell_shift = log2(qcell_state_in.scale);
644 ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
645
646 // Calculate quantized parameters for clipping.
647 int16_t quantized_cell_clip = 0;
648 if(lstm_params.cell_clip() > 0.0f)
649 {
650 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
651 }
652
653 // Precompute effective bias for optimizing the matmul computations.
654 const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100655 const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000656 if(!lstm_params.has_cifg_opt())
657 {
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100658 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
659 -qinput.offset, true)));
660 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.recurrent_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
661 -qoutput_state_in.offset,
662 true)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000663 }
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100664 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
665 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
666 -qoutput_state_in.offset, true)));
667 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
668 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
669 true)));
670 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
671 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
672 -qoutput_state_in.offset, true)));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100673 if(lstm_params.has_projection())
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000674 {
Manuel Bottinicfac51c2021-06-18 15:47:28 +0100675 ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
676 lstm_params.hidden_state_zero(),
677 true)));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100678 if(lstm_params.projection_bias() != nullptr)
679 {
680 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100681 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info, &projection_eff_bias_info, ConvertPolicy::SATURATE));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100682 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000683 }
684
685 const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info());
686 const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
687
688 // Validate weights transpose
689 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_forget_weights, &input_weights_transposed));
690 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_cell_weights, &input_weights_transposed));
691 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_weights_transposed));
692 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_forget_weights, &recurrent_weights_transposed));
693 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_cell_weights, &recurrent_weights_transposed));
694 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_output_weights, &recurrent_weights_transposed));
695 if(!lstm_params.has_cifg_opt())
696 {
697 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed));
698 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed));
699 }
700 if(lstm_params.has_projection())
701 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100702 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
703 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000704 }
705
706 GEMMLowpOutputStageInfo gemmlowp_info;
707 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
708 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
709 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
710 gemmlowp_info.output_data_type = DataType::QSYMM16;
711
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100712 const bool has_layer_norm = lstm_params.use_layer_norm();
713
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000714 // Forget gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100715 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_intermediate_scale() == 0);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000716 const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
717 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
718 const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100719 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000720
721 const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100722 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000723
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100724 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000725
726 if(lstm_params.has_peephole_opt())
727 {
728 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100729 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
730 RoundingPolicy::TO_ZERO));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000731 const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
732 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
733 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100734 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000735 }
736
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100737 if(has_layer_norm)
738 {
739 const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
740 const ITensorInfo *b_info = forget_gate_bias;
741 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
742 }
743
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000744 // Output quantization info of Sigmoid and Tanh activations
745 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100746 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000747
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000748 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
749
750 // Modulation gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100751 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_intermediate_scale() == 0);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000752 const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
753 const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100754 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000755
756 const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100757 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000758
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100759 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000760
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100761 if(has_layer_norm)
762 {
763 const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
764 const ITensorInfo *b_info = cell_bias;
765 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
766 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000767 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100768
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000769 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
770
771 // Input gate.
772 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
773 if(lstm_params.has_cifg_opt())
774 {
775 ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100776 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000777 }
778 else
779 {
780 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
781 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
782 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
783 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights, lstm_params.recurrent_to_input_weights());
784 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
785 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
786
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100787 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_intermediate_scale() == 0);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000788 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
789 const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100790 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000791
792 const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100793 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000794
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100795 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000796
797 if(lstm_params.has_peephole_opt())
798 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100799 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
800 RoundingPolicy::TO_ZERO));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000801 const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
802 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100803 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100804 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000805 }
806
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100807 if(has_layer_norm)
808 {
809 const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
810 const ITensorInfo *b_info = lstm_params.input_gate_bias();
811 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
812 }
813
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100814 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000815 }
816 // Cell.
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100817 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
818 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100819 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000820 if(quantized_cell_clip > 0)
821 {
822 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
823 quantized_cell_clip)));
824 }
825 // Output gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100826 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_intermediate_scale() == 0);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000827 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
828 const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100829 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000830
831 const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100832 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000833
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100834 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000835 if(lstm_params.has_peephole_opt())
836 {
837 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100838 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000839 // Here we are not using the output stage because all operations are done in float
840 // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
841 // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100842 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
843 RoundingPolicy::TO_ZERO));
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100844 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000845 }
846
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100847 if(has_layer_norm)
848 {
849 const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
850 const ITensorInfo *b_info = output_gate_bias;
851 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
852 }
853
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000854 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
855 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
856
857 // Hidden.
858 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
859 const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100860 const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100861 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100862
863 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.hidden_state_scale() == 0);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000864 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
Sang-Hoon Park30b46a62020-04-18 01:40:57 +0100865 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
Sang-Hoon Park9f893752020-10-20 15:33:31 +0100866 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
867 gemmlowp_info.output_data_type = hidden_out_info.data_type();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100868 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
869
870 const bool projection_tensor_copy_required = num_units != output_size;
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000871
872 // Projection.
873 if(lstm_params.has_projection())
874 {
875 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights, lstm_params.projection_weights());
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100876 ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000877
878 const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
879 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
880 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
881 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
882 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
883 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
884 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
885
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100886 const TensorInfo projection_outstage_info(*output_state_out);
887 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100888
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100889 TensorInfo projection_mm_out_info{ mm_out_info };
890 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100891
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100892 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
893 &projection_outstage_info));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100894
895 if(projection_tensor_copy_required)
896 {
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100897 ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(*output_state_in, projection_outstage_info));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100898 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000899
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100900 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000901
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100902 if(projection_tensor_copy_required)
903 {
904 ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
905 }
906
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000907 int8_t quantized_projection_clip{ 0 };
908 if(lstm_params.projection_clip() > 0.0f)
909 {
910 quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
911 }
912
913 if(quantized_projection_clip > 0)
914 {
915 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
916 quantized_projection_clip)));
917 }
918 }
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100919 else
920 {
921 if(projection_tensor_copy_required)
922 {
923 ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
924 }
925 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000926
927 if(cell_state_out->total_size() > 0)
928 {
929 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
930 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
931 }
932
933 if(output_state_out->total_size() > 0)
934 {
935 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
936 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
937 }
938
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100939 ARM_COMPUTE_RETURN_ON_ERROR(NECopy::validate(output_state_out, output));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000940 return Status{};
941}
942
943void NEQLSTMLayer::run()
944{
945 prepare();
946
947 // Acquire all the temporaries
948 MemoryGroupResourceScope scope_mg(_memory_group);
949
950 // Forget gate.
951 _mm_input_to_forget.run();
952 _input_to_forget_outstage.run();
953
954 _mm_recurrent_to_forget.run();
955 _recurrent_to_forget_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100956 _accumulate_input_recurrent_forget.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000957
958 if(_has_peephole)
959 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100960 _pixelwise_mul_cell_to_forget.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000961 _cell_to_forget_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100962 _accumulate_cell_forget.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000963 }
964
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100965 if(_has_layer_norm)
966 {
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100967 NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Forget).get(), Window::DimY);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100968 }
969
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000970 _forget_gate_sigmoid.run();
971
972 // Modulation gate.
973 _mm_input_to_cell.run();
974 _input_to_cell_outstage.run();
975
976 _mm_recurrent_to_cell.run();
977 _recurrent_to_cell_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100978 _accumulate_input_recurrent_modulation.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000979
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100980 if(_has_layer_norm)
981 {
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100982 NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Cell).get(), Window::DimY);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100983 }
984
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000985 _cell_gate_tanh.run();
986
987 // Input gate
988 if(_has_cifg)
989 {
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100990 _input_gate_sub.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000991 }
992 else
993 {
994 _mm_input_to_input.run();
995 _input_to_input_outstage.run();
996 _mm_recurrent_to_input.run();
997 _recurrent_to_input_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100998 _accumulate_input_recurrent_input.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000999
1000 if(_has_peephole)
1001 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001002 _pixelwise_mul_cell_to_input.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001003 _cell_to_input_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001004 _accumulate_cell_input.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001005 }
1006
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001007 if(_has_layer_norm)
1008 {
Michalis Spyrouebcebf12020-10-21 00:04:14 +01001009 NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Input).get(), Window::DimY);
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001010 }
1011
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001012 _input_gate_sigmoid.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001013 }
1014
1015 // Cell.
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001016 _pixelwise_mul_forget_cell.run();
1017 _pixelwise_mul_input_cell.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001018 _add_forget_cell.run();
1019
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001020 if(_has_cell_clipping)
1021 {
1022 _cell_clip.run();
1023 }
1024
1025 // Output gate.
1026 _mm_input_to_output.run();
1027 _input_to_output_outstage.run();
1028 _mm_recurrent_to_output.run();
1029 _recurrent_to_output_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001030 _accumulate_input_recurrent_output.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001031 if(_has_peephole)
1032 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001033 _pixelwise_mul_cell_to_output.run();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001034 _cell_to_output_outstage.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001035 _accumulate_cell_to_output.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001036 }
1037
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001038 if(_has_layer_norm)
1039 {
Michalis Spyrouebcebf12020-10-21 00:04:14 +01001040 NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Output).get(), Window::DimY);
Sang-Hoon Park9230e272020-04-18 00:46:34 +01001041 }
1042
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001043 _output_gate_sigmoid.run();
1044
1045 // Hidden.
1046 _hidden_tanh.run();
Michalis Spyrou6eb73452020-07-02 17:39:25 +01001047 _pixelwise_mul_hidden.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001048 _hidden_outstage.run();
1049
1050 // Projection.
1051 if(_has_projection)
1052 {
1053 _mm_projection.run();
1054 _projection_outstage.run();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001055
1056 if(_projection_tensor_copy_required)
1057 {
1058 _projection_output_to_accumulate_copy.run();
1059 }
1060
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001061 _accumulate_projection.run();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001062
1063 if(_projection_tensor_copy_required)
1064 {
1065 _projection_accumulate_to_output_copy.run();
1066 }
1067
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001068 if(_has_projection_clipping)
1069 {
1070 _projection_clip.run();
1071 }
1072 }
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001073 else
1074 {
1075 if(_projection_tensor_copy_required)
1076 {
1077 _hidden_to_output_copy.run();
1078 }
1079 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +01001080
1081 // Copy output_state_out to output
Michalis Spyrouebcebf12020-10-21 00:04:14 +01001082 _copy_output.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001083}
1084
1085void NEQLSTMLayer::prepare()
1086{
1087 if(!_is_prepared)
1088 {
1089 // Pre-transpose weights to be used in GEMM.
1090 _input_to_forget_weights_transposed.allocator()->allocate();
1091 _input_to_cell_weights_transposed.allocator()->allocate();
1092 _input_to_output_weights_transposed.allocator()->allocate();
1093 _recurrent_to_forget_weights_transposed.allocator()->allocate();
1094 _recurrent_to_cell_weights_transposed.allocator()->allocate();
1095 _recurrent_to_output_weights_transposed.allocator()->allocate();
1096 _transpose_input_to_forget_weights.run();
1097 _transpose_input_to_cell_weights.run();
1098 _transpose_input_to_output_weights.run();
1099 _transpose_recurrent_to_forget_weights.run();
1100 _transpose_recurrent_to_cell_weights.run();
1101 _transpose_recurrent_to_output_weights.run();
1102
1103 // Precompute effective biases
1104 if(_has_cifg)
1105 {
1106 std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
1107 }
1108 else
1109 {
1110 _input_to_input_eff_bias.allocator()->allocate();
1111 _recurrent_to_input_eff_bias.allocator()->allocate();
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001112
1113 ITensorPack packII =
1114 {
1115 { TensorType::ACL_SRC, _input_to_input_weights },
1116 { TensorType::ACL_DST, &_input_to_input_eff_bias }
1117 };
1118 NEScheduler::get().schedule_op(_input_to_input_reduction.get(), Window::DimY, _input_to_input_reduction->window(), packII);
1119
1120 ITensorPack packRI =
1121 {
1122 { TensorType::ACL_SRC, _recurrent_to_input_weights },
1123 { TensorType::ACL_DST, &_recurrent_to_input_eff_bias }
1124 };
1125 NEScheduler::get().schedule_op(_recurrent_to_input_reduction.get(), Window::DimY, _recurrent_to_input_reduction->window(), packRI);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001126
1127 _input_to_input_weights_transposed.allocator()->allocate();
1128 _recurrent_to_input_weights_transposed.allocator()->allocate();
1129 _transpose_input_to_input_weights.run();
1130 _transpose_recurrent_to_input_weights.run();
1131 _input_to_input_weights->mark_as_unused();
1132 _recurrent_to_input_weights->mark_as_unused();
1133 }
1134 _input_to_forget_eff_bias.allocator()->allocate();
1135 _recurrent_to_forget_eff_bias.allocator()->allocate();
1136 _input_to_cell_eff_bias.allocator()->allocate();
1137 _recurrent_to_cell_eff_bias.allocator()->allocate();
1138 _input_to_output_eff_bias.allocator()->allocate();
1139 _recurrent_to_output_eff_bias.allocator()->allocate();
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001140
1141 ITensorPack packIF =
1142 {
1143 { TensorType::ACL_SRC, _input_to_forget_weights },
1144 { TensorType::ACL_DST, &_input_to_forget_eff_bias }
1145 };
1146 NEScheduler::get().schedule_op(_input_to_forget_reduction.get(), Window::DimY, _input_to_forget_reduction->window(), packIF);
1147
1148 ITensorPack packRF =
1149 {
1150 { TensorType::ACL_SRC, _recurrent_to_forget_weights },
1151 { TensorType::ACL_DST, &_recurrent_to_forget_eff_bias }
1152 };
1153 NEScheduler::get().schedule_op(_recurrent_to_forget_reduction.get(), Window::DimY, _recurrent_to_forget_reduction->window(), packRF);
1154
1155 ITensorPack packIC =
1156 {
1157 { TensorType::ACL_SRC, _input_to_cell_weights },
1158 { TensorType::ACL_DST, &_input_to_cell_eff_bias }
1159 };
1160 NEScheduler::get().schedule_op(_input_to_cell_reduction.get(), Window::DimY, _input_to_cell_reduction->window(), packIC);
1161
1162 ITensorPack packRC =
1163 {
1164 { TensorType::ACL_SRC, _recurrent_to_cell_weights },
1165 { TensorType::ACL_DST, &_recurrent_to_cell_eff_bias }
1166 };
1167 NEScheduler::get().schedule_op(_recurrent_to_cell_reduction.get(), Window::DimY, _recurrent_to_cell_reduction->window(), packRC);
1168
1169 ITensorPack packIO =
1170 {
1171 { TensorType::ACL_SRC, _input_to_output_weights },
1172 { TensorType::ACL_DST, &_input_to_output_eff_bias }
1173 };
1174 NEScheduler::get().schedule_op(_input_to_output_reduction.get(), Window::DimY, _input_to_output_reduction->window(), packIO);
1175
1176 ITensorPack packRO =
1177 {
1178 { TensorType::ACL_SRC, _recurrent_to_output_weights },
1179 { TensorType::ACL_DST, &_recurrent_to_output_eff_bias }
1180 };
1181 NEScheduler::get().schedule_op(_recurrent_to_output_reduction.get(), Window::DimY, _recurrent_to_output_reduction->window(), packRO);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001182
1183 if(_has_projection)
1184 {
Michele Di Giorgio11c562c2020-06-10 16:34:50 +01001185 _projection_eff_bias.allocator()->allocate();
Manuel Bottinicfac51c2021-06-18 15:47:28 +01001186 ITensorPack pack =
1187 {
1188 { TensorType::ACL_SRC, _projection_weights },
1189 { TensorType::ACL_DST, &_projection_eff_bias }
1190 };
1191 NEScheduler::get().schedule_op(_projection_reduction.get(), Window::DimY, _projection_reduction->window(), pack);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001192 if(_projection_bias != nullptr)
1193 {
Michalis Spyrou173ba9b2020-06-23 17:25:43 +01001194 _projection_bias_add.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001195 _projection_bias->mark_as_unused();
1196 }
1197
1198 _projection_weights_transposed.allocator()->allocate();
1199 _transpose_projection_weights.run();
1200 _projection_weights->mark_as_unused();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001201
1202 if(!_projection_tensor_copy_required)
1203 {
1204 _hidden_gate.mark_as_unused();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001205 _projection_accumulate_res.mark_as_unused();
1206 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001207 }
1208
1209 // Mark weights as unused
1210 _input_to_forget_weights->mark_as_unused();
1211 _input_to_cell_weights->mark_as_unused();
1212 _input_to_output_weights->mark_as_unused();
1213 _recurrent_to_forget_weights->mark_as_unused();
1214 _recurrent_to_cell_weights->mark_as_unused();
1215 _recurrent_to_output_weights->mark_as_unused();
1216
1217 _is_prepared = true;
1218 }
1219}
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001220} // namespace arm_compute