blob: beb180fda554d749f352ce75c8f0579f0515cee3 [file] [log] [blame]
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001/*
2 * Copyright (c) 2020 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"
25
26#include "arm_compute/core/KernelDescriptors.h"
27#include "arm_compute/core/QuantizationInfo.h"
28#include "arm_compute/core/Utils.h"
29#include "arm_compute/core/Validate.h"
30#include "arm_compute/core/utils/misc/InfoHelpers.h"
31#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
32#include "arm_compute/runtime/NEON/NEScheduler.h"
33
34namespace arm_compute
35{
36using namespace arm_compute::utils::info_helpers;
37namespace
38{
39Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
40 float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
41{
42 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
43 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
44 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
45 return Status{};
46}
47} // namespace
48
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +010049Status NEQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
50{
51 ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
52 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
53 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
54 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
55 return Status{};
56}
57
58void NEQLSTMLayer::TensorCopyKernel::configure(ITensor &src, ITensor &dst)
59{
60 ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
61 _src = &src;
62 _dst = &dst;
63 _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
64 _window = calculate_max_window(*_src->info(), Steps());
65}
66
67void NEQLSTMLayer::TensorCopyKernel::run()
68{
69 Iterator input_iter{ _src, _window };
70 Iterator output_iter{ _dst, _window };
71
72 execute_window_loop(_window, [&](const Coordinates &)
73 {
74 memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
75 },
76 input_iter, output_iter);
77}
78
Michele Di Giorgio47a89902020-03-09 19:32:33 +000079NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
80{
81 _memory_group = MemoryGroup(std::move(memory_manager));
82}
83
84void NEQLSTMLayer::configure_mm(NEGEMMLowpMatrixMultiplyCore &mm, NEGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
85 const ITensor *mm_input, const ITensor *mm_weights, const ITensor *bias,
86 Tensor *mm_res, Tensor *outstage_res, float gemmlowp_scale,
87 const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
88{
89 _memory_group.manage(mm_res);
90 _memory_group.manage(outstage_res);
91
92 mm_res->allocator()->init(mm_res_info);
93 outstage_res->allocator()->init(outstage_tensor_info);
94
95 // Configure matrix-multiplication
96 mm.configure(mm_input, mm_weights, nullptr, mm_res);
97
98 // Configure output stage
99 quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
100 outstage.configure(mm_res, bias, outstage_res, gemmlowp_info);
101 mm_res->allocator()->allocate();
102}
103
104void NEQLSTMLayer::configure(const ITensor *input,
105 const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
106 const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
107 const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
108 const ITensor *cell_state_in, const ITensor *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100109 ITensor *cell_state_out, ITensor *output_state_out, ITensor *output,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000110 const LSTMParams<ITensor> &lstm_params)
111{
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000112 ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
113 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
114 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
115
116 // Set lstm parameters
117 LSTMParams<ITensorInfo> lstm_params_info{};
118 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
119
120 // Validate
121 ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
122 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
123 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100124 cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
125 lstm_params_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000126
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100127 const int batch_size = input->info()->dimension(1);
128 const int num_units = input_to_output_weights->info()->dimension(1);
129 const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000130
131 const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
132 const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
133 const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
134
135 _projection_bias = lstm_params.projection_bias();
136 _input_to_forget_weights = input_to_forget_weights;
137 _input_to_cell_weights = input_to_cell_weights;
138 _input_to_output_weights = input_to_output_weights;
139 _recurrent_to_forget_weights = recurrent_to_forget_weights;
140 _recurrent_to_cell_weights = recurrent_to_cell_weights;
141 _recurrent_to_output_weights = recurrent_to_output_weights;
142 _projection_weights = lstm_params.projection_weights();
143
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100144 // Layer normalization
145 _has_layer_norm = lstm_params.use_layer_norm();
146 if(_has_layer_norm)
147 {
148 set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
149 set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
150 set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
151 set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
152
153 set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
154 set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
155 set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
156 set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
157 }
158
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000159 _has_cifg = lstm_params.has_cifg_opt();
160 _has_projection = lstm_params.has_projection();
161 _has_peephole = lstm_params.has_peephole_opt();
162
163 // Calculate and decompose effective scales for optimizing matmul calculation
164 const int32_t cell_shift = log2(qcell_state_in.scale);
165
166 // Calculate quantized parameters for clipping.
167 int16_t quantized_cell_clip = 0;
168 if(lstm_params.cell_clip() > 0.0f)
169 {
170 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
171 }
172 _has_cell_clipping = quantized_cell_clip > 0;
173
174 // Precompute effective bias for optimizing the matmul computations.
175 if(!_has_cifg)
176 {
177 _input_to_input_weights = lstm_params.input_to_input_weights();
178 _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
179
180 _input_to_input_reduction.configure(_input_to_input_weights, &_input_to_input_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
181 _recurrent_to_input_reduction.configure(_recurrent_to_input_weights, &_recurrent_to_input_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
182 }
183 _input_to_forget_reduction.configure(input_to_forget_weights, &_input_to_forget_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
184 _recurrent_to_forget_reduction.configure(recurrent_to_forget_weights, &_recurrent_to_forget_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
185 _input_to_cell_reduction.configure(input_to_cell_weights, &_input_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
186 _recurrent_to_cell_reduction.configure(recurrent_to_cell_weights, &_recurrent_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
187 _input_to_output_reduction.configure(input_to_output_weights, &_input_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
188 _recurrent_to_output_reduction.configure(recurrent_to_output_weights, &_recurrent_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100189 if(_has_projection)
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000190 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100191 _projection_reduction.configure(_projection_weights, &_projection_eff_bias, GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000192 }
193
194 // Pre-transpose weights to be used in GEMM.
195 _transpose_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_transposed);
196 _transpose_input_to_cell_weights.configure(input_to_cell_weights, &_input_to_cell_weights_transposed);
197 _transpose_input_to_output_weights.configure(input_to_output_weights, &_input_to_output_weights_transposed);
198 _transpose_recurrent_to_forget_weights.configure(recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
199 _transpose_recurrent_to_cell_weights.configure(recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
200 _transpose_recurrent_to_output_weights.configure(recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
201 if(!_has_cifg)
202 {
203 _transpose_input_to_input_weights.configure(lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
204 _transpose_recurrent_to_input_weights.configure(lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
205 }
206 if(_has_projection)
207 {
208 _transpose_projection_weights.configure(_projection_weights, &_projection_weights_transposed);
209 }
210
211 GEMMLowpOutputStageInfo gemmlowp_info;
212 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
213 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
214 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
215 gemmlowp_info.output_data_type = DataType::QSYMM16;
216
217 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
218 // Forget gate.
219 const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
220 const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
221 configure_mm(_mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
222 input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
223 &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
224 mm_out_info, forget_gate_outstage_info);
225
226 const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
227 configure_mm(_mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
228 output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
229 &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
230 mm_out_info, forget_gate_outstage_info);
231
232 _accumulate_input_recurrent_forget.configure(&_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
233 _input_to_forget_outstage_res.allocator()->allocate();
234
235 if(_has_peephole)
236 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100237 _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000238 _memory_group.manage(&_mul_cell_to_forget_res);
239 _pixelwise_mul_cell_to_forget.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
240 _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
241 _memory_group.manage(&_cell_to_forget_outstage_res);
242 const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
243 quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
244 _cell_to_forget_outstage.configure(&_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
245 _mul_cell_to_forget_res.allocator()->allocate();
246 _accumulate_cell_forget.configure(&_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
247 _cell_to_forget_outstage_res.allocator()->allocate();
248 }
249
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100250 Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
251
252 if(_has_layer_norm)
253 {
254 configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
255 forget_activation_input->allocator()->allocate();
256 forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
257 }
258
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000259 // Output quantization info of Sigmoid and Tanh activations
260 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100261 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000262
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000263 _memory_group.manage(&_forget_gate);
264 _forget_gate.allocator()->init(forget_gate_info);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100265 _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
266 forget_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000267
268 // Modulation gate.
269 const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
270 const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
271 configure_mm(_mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
272 input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
273 &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
274 mm_out_info, cell_outstage_info);
275
276 const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
277 configure_mm(_mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
278 output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
279 &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
280 mm_out_info, cell_outstage_info);
281
282 _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
283 _input_to_cell_outstage_res.allocator()->allocate();
284
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100285 Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
286
287 if(_has_layer_norm)
288 {
289 configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
290 cell_activation_input->allocator()->allocate();
291 cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
292 }
293
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000294 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100295
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000296 _memory_group.manage(&_cell_gate);
297 _cell_gate.allocator()->init(cell_gate_info);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100298 _cell_gate_tanh.configure(cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
299 cell_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000300
301 // Input gate.
302 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
303 _input_gate.allocator()->init(input_gate_info);
304 _memory_group.manage(&_input_gate);
305 if(_has_cifg)
306 {
307 _ones.allocator()->init(*_forget_gate.info());
308 _input_gate_sub.configure(&_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
309 _ones.allocator()->allocate();
310 }
311 else
312 {
313 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
314 const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
315 configure_mm(_mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
316 input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
317 &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
318 mm_out_info, input_outstage_info);
319
320 const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
321 configure_mm(_mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100322 output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000323 &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
324 mm_out_info, input_outstage_info);
325 _accumulate_input_recurrent_input.configure(&_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
326 _input_to_input_outstage_res.allocator()->allocate();
327
328 if(_has_peephole)
329 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100330 _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000331 _memory_group.manage(&_mul_cell_to_input_res);
332 _pixelwise_mul_cell_to_input.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
333 const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
334 quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
335 _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
336 _memory_group.manage(&_cell_to_input_outstage_res);
337 _cell_to_input_outstage.configure(&_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
338 _mul_cell_to_input_res.allocator()->allocate();
339 _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
340 _cell_to_input_outstage_res.allocator()->allocate();
341 }
342
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100343 Tensor *input_activation_input = &_recurrent_to_input_outstage_res;
344
345 if(_has_layer_norm)
346 {
347 configure_layer_norm(LayerNormGate::Input, input_activation_input);
348 input_activation_input->allocator()->allocate();
349 input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
350 }
351
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100352 _input_gate_sigmoid.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100353 input_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000354 }
355 // Cell.
356 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
357 _pixelwise_mul_forget_cell.configure(&_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
358 const float cell_gate_scale = _cell_gate.info()->quantization_info().uniform().scale;
359 const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
360 const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
361 _memory_group.manage(&_mul_input_cell_res);
362 _mul_input_cell_res.allocator()->init(mul_input_cell_info);
363 _pixelwise_mul_input_cell.configure(&_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
364 _cell_gate.allocator()->allocate();
365 _add_forget_cell.configure(&_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
366 _mul_input_cell_res.allocator()->allocate();
367 _forget_gate.allocator()->allocate();
368 if(_has_cell_clipping)
369 {
370 _cell_clip.configure(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
371 }
372 // Output gate.
373 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
374 const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
375 configure_mm(_mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
376 input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
377 &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
378 mm_out_info, output_outstage_info);
379
380 const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
381 configure_mm(_mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
382 output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
383 &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
384 mm_out_info, output_outstage_info);
385
386 _accumulate_input_recurrent_output.configure(&_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
387 _input_to_output_outstage_res.allocator()->allocate();
388
389 if(_has_peephole)
390 {
391 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
392 // Here we are not using the output stage because all operations are done in float
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100393 _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000394 _memory_group.manage(&_mul_cell_to_output_res);
395 _pixelwise_mul_cell_to_output.configure(cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100396
397 const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
398 quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
399 _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
400 _memory_group.manage(&_cell_to_output_outstage_res);
401 _cell_to_output_outstage.configure(&_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000402 _mul_cell_to_output_res.allocator()->allocate();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100403
404 _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
405 _cell_to_output_outstage_res.allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000406 }
407
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100408 Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
409
410 if(_has_layer_norm)
411 {
412 configure_layer_norm(LayerNormGate::Output, output_activation_input);
413 output_activation_input->allocator()->allocate();
414 output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
415 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000416 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100417
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000418 _memory_group.manage(&_output_gate);
419 _output_gate.allocator()->init(output_gate_info);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100420 _output_gate_sigmoid.configure(output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
421 output_activation_input->allocator()->allocate();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000422
423 // Hidden.
424 _hidden_tanh.configure(cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
425 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
426 _memory_group.manage(&_hidden_mul_res);
427 const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
428 _hidden_mul_res.allocator()->init(hidden_mul_res);
429 _pixelwise_mul_hidden.configure(&_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
430 _output_gate.allocator()->allocate();
431 _input_gate.allocator()->allocate();
432 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
Sang-Hoon Park30b46a62020-04-18 01:40:57 +0100433 quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000434 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
435 gemmlowp_info.output_data_type = output_state_in->info()->data_type();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100436
437 _projection_tensor_copy_required = (num_units != output_size);
438 ITensor *hidden_gate_result = output_state_out;
439
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100440 _memory_group.manage(&_hidden_gate);
441
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100442 if(_projection_tensor_copy_required)
443 {
444 _hidden_gate.allocator()->init(*output_state_out->info());
445 _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
446 hidden_gate_result = &_hidden_gate;
447 }
448
449 _hidden_outstage.configure(&_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000450 _hidden_mul_res.allocator()->allocate();
451
452 // Projection.
453 if(_has_projection)
454 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100455 const TensorInfo projection_outstage_info(*output_state_out->info());
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000456 const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
457 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
458 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
459 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
460 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
461 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
462
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100463 TensorInfo projection_mm_out_info{ mm_out_info };
464 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100465
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000466 configure_mm(_mm_projection, _projection_outstage, gemmlowp_info,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100467 hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000468 &_mm_projection_res, &_projection_outstage_res, projection_scale,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100469 projection_mm_out_info, projection_outstage_info);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100470
471 ITensor *accumulate_destination = output_state_out;
472
473 if(_projection_tensor_copy_required)
474 {
475 _hidden_gate.allocator()->allocate();
476 _projection_accumulate_res.allocator()->init(*output_state_out->info());
477 _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
478 _projection_output_to_accumulate_copy.configure(*output_state_out, _projection_accumulate_res);
479 accumulate_destination = &_projection_accumulate_res;
480 }
481
482 _accumulate_projection.configure(&_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000483 _projection_outstage_res.allocator()->allocate();
484
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100485 if(_projection_tensor_copy_required)
486 {
487 _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
488 _projection_accumulate_res.allocator()->allocate();
489 }
490
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000491 int8_t quantized_projection_clip{ 0 };
492 if(lstm_params.projection_clip() > 0.0f)
493 {
494 quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
495 }
496
497 if(quantized_projection_clip > 0)
498 {
499 _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip, quantized_projection_clip));
500 _has_projection_clipping = true;
501 }
502 }
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100503 else
504 {
505 if(_projection_tensor_copy_required)
506 {
507 _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
508 _hidden_gate.allocator()->allocate();
509 }
510 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100511
512 // Copy output_state_out to output
513 _copy_output.configure(output_state_out, output);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000514}
515
516Status NEQLSTMLayer::validate(const ITensorInfo *input,
517 const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
518 const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
519 const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
520 const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100521 const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000522 const LSTMParams<ITensorInfo> &lstm_params)
523{
524 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100525 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
526 cell_state_out, output_state_out, output);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000527
528 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
529 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
530
531 const unsigned int input_size = input->dimension(0);
532 const unsigned int batch_size = input->dimension(1);
533 const unsigned int num_units = input_to_output_weights->dimension(1);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100534 const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000535
536 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
537 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
538 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights, input_to_cell_weights);
539 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
540 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
541 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
542 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QSYMM8);
543 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
544 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
545
546 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
547 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
548 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
549 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
550 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);
551
552 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
553 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
554 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
555 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);
556
557 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
558 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
559 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
560 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);
561
562 // Check whether peephole weights are all there or none
563 if(lstm_params.has_peephole_opt())
564 {
565 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
566 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
567 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
568 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
569 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
570 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
571
572 if(!lstm_params.has_cifg_opt())
573 {
574 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
575 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
576 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
577 }
578 }
579
580 const UniformQuantizationInfo qinput = input->quantization_info().uniform();
581 const UniformQuantizationInfo qcell_state_in = cell_state_in->quantization_info().uniform();
582 const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
583
584 // Calculate and decompose effective scales for optimizing matmul calculation
585 const int32_t cell_shift = log2(qcell_state_in.scale);
586 ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
587
588 // Calculate quantized parameters for clipping.
589 int16_t quantized_cell_clip = 0;
590 if(lstm_params.cell_clip() > 0.0f)
591 {
592 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
593 }
594
595 // Precompute effective bias for optimizing the matmul computations.
596 const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100597 const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000598 if(!lstm_params.has_cifg_opt())
599 {
600 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
601 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(lstm_params.recurrent_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
602 true)));
603 }
604 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
605 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(recurrent_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
606 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
607 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
608 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
609 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100610 if(lstm_params.has_projection())
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000611 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100612 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
613 lstm_params.hidden_state_zero(),
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000614 true)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000615 }
616
617 const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info());
618 const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
619
620 // Validate weights transpose
621 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_forget_weights, &input_weights_transposed));
622 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_cell_weights, &input_weights_transposed));
623 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_weights_transposed));
624 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_forget_weights, &recurrent_weights_transposed));
625 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_cell_weights, &recurrent_weights_transposed));
626 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_output_weights, &recurrent_weights_transposed));
627 if(!lstm_params.has_cifg_opt())
628 {
629 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed));
630 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed));
631 }
632 if(lstm_params.has_projection())
633 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100634 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
635 ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000636 }
637
638 GEMMLowpOutputStageInfo gemmlowp_info;
639 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
640 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
641 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
642 gemmlowp_info.output_data_type = DataType::QSYMM16;
643
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100644 const bool has_layer_norm = lstm_params.use_layer_norm();
645
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000646 // Forget gate.
647 const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
648 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
649 const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100650 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000651
652 const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100653 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000654
655 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
656
657 if(lstm_params.has_peephole_opt())
658 {
659 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
660 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
661 RoundingPolicy::TO_ZERO));
662 const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
663 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
664 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
665 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
666 }
667
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100668 if(has_layer_norm)
669 {
670 const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
671 const ITensorInfo *b_info = forget_gate_bias;
672 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
673 }
674
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000675 // Output quantization info of Sigmoid and Tanh activations
676 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100677 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000678
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000679 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
680
681 // Modulation gate.
682 const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
683 const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100684 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000685
686 const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100687 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000688
689 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
690
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100691 if(has_layer_norm)
692 {
693 const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
694 const ITensorInfo *b_info = cell_bias;
695 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
696 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000697 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100698
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000699 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
700
701 // Input gate.
702 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
703 if(lstm_params.has_cifg_opt())
704 {
705 ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
706 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtractionKernel::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
707 }
708 else
709 {
710 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
711 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
712 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
713 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights, lstm_params.recurrent_to_input_weights());
714 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
715 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
716
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000717 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
718 const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100719 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000720
721 const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100722 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000723
724 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
725
726 if(lstm_params.has_peephole_opt())
727 {
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100728 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000729 RoundingPolicy::TO_ZERO));
730 const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
731 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100732 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000733 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
734 }
735
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100736 if(has_layer_norm)
737 {
738 const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
739 const ITensorInfo *b_info = lstm_params.input_gate_bias();
740 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
741 }
742
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100743 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000744 }
745 // Cell.
746 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
747 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
748 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
749 if(quantized_cell_clip > 0)
750 {
751 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
752 quantized_cell_clip)));
753 }
754 // Output gate.
755 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
756 const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100757 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000758
759 const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100760 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000761
762 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
763 if(lstm_params.has_peephole_opt())
764 {
765 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
766 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
767 // Here we are not using the output stage because all operations are done in float
768 // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
769 // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
770 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
771 RoundingPolicy::TO_ZERO));
772 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
773 }
774
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100775 if(has_layer_norm)
776 {
777 const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
778 const ITensorInfo *b_info = output_gate_bias;
779 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
780 }
781
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000782 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
783 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
784
785 // Hidden.
786 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
787 const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100788 const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000789 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
790 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
Sang-Hoon Park30b46a62020-04-18 01:40:57 +0100791 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000792 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100793 ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
794
795 const bool projection_tensor_copy_required = num_units != output_size;
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000796
797 // Projection.
798 if(lstm_params.has_projection())
799 {
800 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights, lstm_params.projection_weights());
801 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.projection_bias());
802
803 const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
804 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
805 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
806 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
807 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
808 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
809 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
810
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100811 const TensorInfo projection_outstage_info(*output_state_out);
812 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100813
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100814 TensorInfo projection_mm_out_info{ mm_out_info };
815 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100816
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100817 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
818 &projection_outstage_info));
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100819
820 if(projection_tensor_copy_required)
821 {
822 ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info));
823 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000824
825 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
826
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100827 if(projection_tensor_copy_required)
828 {
829 ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
830 }
831
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000832 int8_t quantized_projection_clip{ 0 };
833 if(lstm_params.projection_clip() > 0.0f)
834 {
835 quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
836 }
837
838 if(quantized_projection_clip > 0)
839 {
840 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
841 quantized_projection_clip)));
842 }
843 }
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100844 else
845 {
846 if(projection_tensor_copy_required)
847 {
848 ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
849 }
850 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000851
852 if(cell_state_out->total_size() > 0)
853 {
854 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
855 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
856 }
857
858 if(output_state_out->total_size() > 0)
859 {
860 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
861 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
862 }
863
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100864 ARM_COMPUTE_RETURN_ON_ERROR(NECopyKernel::validate(output_state_out, output));
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000865 return Status{};
866}
867
868void NEQLSTMLayer::run()
869{
870 prepare();
871
872 // Acquire all the temporaries
873 MemoryGroupResourceScope scope_mg(_memory_group);
874
875 // Forget gate.
876 _mm_input_to_forget.run();
877 _input_to_forget_outstage.run();
878
879 _mm_recurrent_to_forget.run();
880 _recurrent_to_forget_outstage.run();
881 NEScheduler::get().schedule(&_accumulate_input_recurrent_forget, Window::DimY);
882
883 if(_has_peephole)
884 {
885 NEScheduler::get().schedule(&_pixelwise_mul_cell_to_forget, Window::DimY);
886 _cell_to_forget_outstage.run();
887 NEScheduler::get().schedule(&_accumulate_cell_forget, Window::DimY);
888 }
889
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100890 if(_has_layer_norm)
891 {
892 NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Forget), Window::DimY);
893 }
894
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000895 _forget_gate_sigmoid.run();
896
897 // Modulation gate.
898 _mm_input_to_cell.run();
899 _input_to_cell_outstage.run();
900
901 _mm_recurrent_to_cell.run();
902 _recurrent_to_cell_outstage.run();
903 NEScheduler::get().schedule(&_accumulate_input_recurrent_modulation, Window::DimY);
904
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100905 if(_has_layer_norm)
906 {
907 NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Cell), Window::DimY);
908 }
909
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000910 _cell_gate_tanh.run();
911
912 // Input gate
913 if(_has_cifg)
914 {
915 NEScheduler::get().schedule(&_input_gate_sub, Window::DimY);
916 }
917 else
918 {
919 _mm_input_to_input.run();
920 _input_to_input_outstage.run();
921 _mm_recurrent_to_input.run();
922 _recurrent_to_input_outstage.run();
923 NEScheduler::get().schedule(&_accumulate_input_recurrent_input, Window::DimY);
924
925 if(_has_peephole)
926 {
927 NEScheduler::get().schedule(&_pixelwise_mul_cell_to_input, Window::DimY);
928 _cell_to_input_outstage.run();
929 NEScheduler::get().schedule(&_accumulate_cell_input, Window::DimY);
930 }
931
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100932 if(_has_layer_norm)
933 {
934 NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Input), Window::DimY);
935 }
936
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100937 _input_gate_sigmoid.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000938 }
939
940 // Cell.
941 NEScheduler::get().schedule(&_pixelwise_mul_forget_cell, Window::DimY);
942 NEScheduler::get().schedule(&_pixelwise_mul_input_cell, Window::DimY);
943 NEScheduler::get().schedule(&_add_forget_cell, Window::DimY);
944 if(_has_cell_clipping)
945 {
946 _cell_clip.run();
947 }
948
949 // Output gate.
950 _mm_input_to_output.run();
951 _input_to_output_outstage.run();
952 _mm_recurrent_to_output.run();
953 _recurrent_to_output_outstage.run();
954 NEScheduler::get().schedule(&_accumulate_input_recurrent_output, Window::DimY);
955 if(_has_peephole)
956 {
957 NEScheduler::get().schedule(&_pixelwise_mul_cell_to_output, Window::DimY);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100958 _cell_to_output_outstage.run();
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000959 NEScheduler::get().schedule(&_accumulate_cell_to_output, Window::DimY);
960 }
961
Sang-Hoon Park9230e272020-04-18 00:46:34 +0100962 if(_has_layer_norm)
963 {
964 NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Output), Window::DimY);
965 }
966
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000967 _output_gate_sigmoid.run();
968
969 // Hidden.
970 _hidden_tanh.run();
971 NEScheduler::get().schedule(&_pixelwise_mul_hidden, Window::DimY);
972 _hidden_outstage.run();
973
974 // Projection.
975 if(_has_projection)
976 {
977 _mm_projection.run();
978 _projection_outstage.run();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100979
980 if(_projection_tensor_copy_required)
981 {
982 _projection_output_to_accumulate_copy.run();
983 }
984
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000985 NEScheduler::get().schedule(&_accumulate_projection, Window::DimY);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100986
987 if(_projection_tensor_copy_required)
988 {
989 _projection_accumulate_to_output_copy.run();
990 }
991
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000992 if(_has_projection_clipping)
993 {
994 _projection_clip.run();
995 }
996 }
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100997 else
998 {
999 if(_projection_tensor_copy_required)
1000 {
1001 _hidden_to_output_copy.run();
1002 }
1003 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +01001004
1005 // Copy output_state_out to output
1006 NEScheduler::get().schedule(&_copy_output, Window::DimY);
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001007}
1008
1009void NEQLSTMLayer::prepare()
1010{
1011 if(!_is_prepared)
1012 {
1013 // Pre-transpose weights to be used in GEMM.
1014 _input_to_forget_weights_transposed.allocator()->allocate();
1015 _input_to_cell_weights_transposed.allocator()->allocate();
1016 _input_to_output_weights_transposed.allocator()->allocate();
1017 _recurrent_to_forget_weights_transposed.allocator()->allocate();
1018 _recurrent_to_cell_weights_transposed.allocator()->allocate();
1019 _recurrent_to_output_weights_transposed.allocator()->allocate();
1020 _transpose_input_to_forget_weights.run();
1021 _transpose_input_to_cell_weights.run();
1022 _transpose_input_to_output_weights.run();
1023 _transpose_recurrent_to_forget_weights.run();
1024 _transpose_recurrent_to_cell_weights.run();
1025 _transpose_recurrent_to_output_weights.run();
1026
1027 // Precompute effective biases
1028 if(_has_cifg)
1029 {
1030 std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
1031 }
1032 else
1033 {
1034 _input_to_input_eff_bias.allocator()->allocate();
1035 _recurrent_to_input_eff_bias.allocator()->allocate();
1036 NEScheduler::get().schedule(&_input_to_input_reduction, Window::DimY);
1037 NEScheduler::get().schedule(&_recurrent_to_input_reduction, Window::DimY);
1038
1039 _input_to_input_weights_transposed.allocator()->allocate();
1040 _recurrent_to_input_weights_transposed.allocator()->allocate();
1041 _transpose_input_to_input_weights.run();
1042 _transpose_recurrent_to_input_weights.run();
1043 _input_to_input_weights->mark_as_unused();
1044 _recurrent_to_input_weights->mark_as_unused();
1045 }
1046 _input_to_forget_eff_bias.allocator()->allocate();
1047 _recurrent_to_forget_eff_bias.allocator()->allocate();
1048 _input_to_cell_eff_bias.allocator()->allocate();
1049 _recurrent_to_cell_eff_bias.allocator()->allocate();
1050 _input_to_output_eff_bias.allocator()->allocate();
1051 _recurrent_to_output_eff_bias.allocator()->allocate();
1052 NEScheduler::get().schedule(&_input_to_forget_reduction, Window::DimY);
1053 NEScheduler::get().schedule(&_recurrent_to_forget_reduction, Window::DimY);
1054 NEScheduler::get().schedule(&_input_to_cell_reduction, Window::DimY);
1055 NEScheduler::get().schedule(&_recurrent_to_cell_reduction, Window::DimY);
1056 NEScheduler::get().schedule(&_input_to_output_reduction, Window::DimY);
1057 NEScheduler::get().schedule(&_recurrent_to_output_reduction, Window::DimY);
1058
1059 if(_has_projection)
1060 {
1061 if(_projection_bias != nullptr)
1062 {
1063 _projection_eff_bias.allocator()->allocate();
1064 NEScheduler::get().schedule(&_projection_reduction, Window::DimY);
1065 _projection_bias->mark_as_unused();
1066 }
1067
1068 _projection_weights_transposed.allocator()->allocate();
1069 _transpose_projection_weights.run();
1070 _projection_weights->mark_as_unused();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001071
1072 if(!_projection_tensor_copy_required)
1073 {
1074 _hidden_gate.mark_as_unused();
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +01001075 _projection_accumulate_res.mark_as_unused();
1076 }
Michele Di Giorgio47a89902020-03-09 19:32:33 +00001077 }
1078
1079 // Mark weights as unused
1080 _input_to_forget_weights->mark_as_unused();
1081 _input_to_cell_weights->mark_as_unused();
1082 _input_to_output_weights->mark_as_unused();
1083 _recurrent_to_forget_weights->mark_as_unused();
1084 _recurrent_to_cell_weights->mark_as_unused();
1085 _recurrent_to_output_weights->mark_as_unused();
1086
1087 _is_prepared = true;
1088 }
1089}
1090
1091} // namespace arm_compute