/*
 * Copyright (c) 2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/runtime/CL/functions/CLQLSTMLayer.h"

#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/QuantizationInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/InfoHelpers.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/CL/CLScheduler.h"

namespace arm_compute
{
using namespace arm_compute::utils::info_helpers;
namespace
{
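// Common validation for the matmul + output-stage pattern used by every gate:
// the GEMMLowp core accumulates in S32, and the result is requantized to QSYMM16
// with a fixed-point multiplier/shift pair decomposed from the float effective
// scale (weights_scale * input_scale / intermediate_scale).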
Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
                   float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
{
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
    ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
    return Status{};
}
} // namespace

Status CLQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
{
    ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
    ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
    ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
    return Status{};
}

void CLQLSTMLayer::TensorCopyKernel::configure(ICLTensor &src, ICLTensor &dst)
{
    ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
    _src      = &src;
    _dst      = &dst;
    _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
    _window   = calculate_max_window(*_src->info(), Steps());
}

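// Host-side copy: map both tensors and memcpy _row_size bytes per row. Since
// _row_size is the smaller of the two row lengths, copies between tensors whose
// x-dimensions differ (e.g. num_units vs output_size in the projection path)
// transfer only the overlapping columns.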
void CLQLSTMLayer::TensorCopyKernel::run()
{
    auto &q = CLScheduler::get().queue();

    _src->map(q, true);
    _dst->map(q, true);

    Iterator input_iter{ _src, _window };
    Iterator output_iter{ _dst, _window };

    execute_window_loop(_window, [&](const Coordinates &)
    {
        memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
    },
    input_iter, output_iter);

    _src->unmap(q);
    _dst->unmap(q);
}

CLQLSTMLayer::CLQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
{
    _memory_group = MemoryGroup(std::move(memory_manager));
}

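// Helper wiring up one gate contribution: the S32 matmul result and its requantized
// QSYMM16 output are managed as transient memory, and the S32 buffer is released
// for reuse as soon as the output stage (its only consumer) has been configured.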
void CLQLSTMLayer::configure_mm(const CLCompileContext &compile_context, CLGEMMLowpMatrixMultiplyCore &mm, CLGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
                                const ICLTensor *mm_input, const ICLTensor *mm_weights, const ICLTensor *bias,
                                CLTensor *mm_res, CLTensor *outstage_res, float gemmlowp_scale,
                                const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
{
    _memory_group.manage(mm_res);
    _memory_group.manage(outstage_res);

    mm_res->allocator()->init(mm_res_info);
    outstage_res->allocator()->init(outstage_tensor_info);

    // Configure matrix-multiplication
    mm.configure(compile_context, mm_input, mm_weights, nullptr, mm_res);

    // Configure output stage
    quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
    outstage.configure(compile_context, mm_res, bias, outstage_res, gemmlowp_info);
    mm_res->allocator()->allocate();
}

void CLQLSTMLayer::configure(const ICLTensor *input,
                             const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                             const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                             const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
                             const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
                             ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
                             const LSTMParams<ICLTensor> &lstm_params)
{
    configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
              recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
              cell_state_in, output_state_in, cell_state_out, output_state_out, output, lstm_params);
}

void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
                             const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                             const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                             const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
                             const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
                             ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
                             const LSTMParams<ICLTensor> &lstm_params)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
                                 cell_state_out, output_state_out, output);

    // Set lstm parameters
    LSTMParams<ITensorInfo> lstm_params_info{};
    build_lstm_params_tensor_info(lstm_params, &lstm_params_info);

    // Validate
    ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
                                                      recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
                                                      forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
                                                      cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
                                                      lstm_params_info));

    const int batch_size  = input->info()->dimension(1);
    const int num_units   = input_to_output_weights->info()->dimension(1);
    const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);

    const UniformQuantizationInfo qinput           = input->info()->quantization_info().uniform();
    const UniformQuantizationInfo qcell_state_in   = cell_state_in->info()->quantization_info().uniform();
    const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();

    _projection_bias             = lstm_params.projection_bias();
    _input_to_forget_weights     = input_to_forget_weights;
    _input_to_cell_weights       = input_to_cell_weights;
    _input_to_output_weights     = input_to_output_weights;
    _recurrent_to_forget_weights = recurrent_to_forget_weights;
    _recurrent_to_cell_weights   = recurrent_to_cell_weights;
    _recurrent_to_output_weights = recurrent_to_output_weights;
    _projection_weights          = lstm_params.projection_weights();

    // Layer normalization
    _has_layer_norm = lstm_params.use_layer_norm();
    if(_has_layer_norm)
    {
        set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
        set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
        set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
        set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);

        set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
        set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
        set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
        set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
    }

    _has_cifg       = lstm_params.has_cifg_opt();
    _has_projection = lstm_params.has_projection();
    _has_peephole   = lstm_params.has_peephole_opt();

    // Calculate and decompose effective scales for optimizing matmul calculation
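    // The cell state scale is assumed to be a power of two, so multiplications by the
    // cell state reduce to bit shifts; cell_shift below is its base-2 exponent.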
    const int32_t cell_shift = log2(qcell_state_in.scale);

    // Calculate quantized parameters for clipping.
    int16_t quantized_cell_clip = 0;
    if(lstm_params.cell_clip() > 0.0f)
    {
        quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
    }
    _has_cell_clipping = quantized_cell_clip > 0;

    // Precompute effective bias for optimizing the matmul computations.
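    // The weights are symmetric (zero offset) while input/output state are asymmetric, so
    // (x - x_offset) . w expands to x . w - x_offset * sum(w). The second term is constant
    // and is folded into a per-unit "effective bias" here: each reduction sums one row of
    // the weight matrix and multiplies it by -offset (the scalar passed in
    // GEMMLowpReductionKernelInfo), leaving only the plain x . w product for the GEMM.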
    if(!_has_cifg)
    {
        _input_to_input_weights     = lstm_params.input_to_input_weights();
        _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();

        _input_to_input_reduction.configure(compile_context, _input_to_input_weights, &_input_to_input_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
        _recurrent_to_input_reduction.configure(compile_context, _recurrent_to_input_weights, &_recurrent_to_input_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
    }
    _input_to_forget_reduction.configure(compile_context, input_to_forget_weights, &_input_to_forget_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
    _recurrent_to_forget_reduction.configure(compile_context, recurrent_to_forget_weights, &_recurrent_to_forget_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
    _input_to_cell_reduction.configure(compile_context, input_to_cell_weights, &_input_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
    _recurrent_to_cell_reduction.configure(compile_context, recurrent_to_cell_weights, &_recurrent_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
    _input_to_output_reduction.configure(compile_context, input_to_output_weights, &_input_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
    _recurrent_to_output_reduction.configure(compile_context, recurrent_to_output_weights, &_recurrent_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
    if(_has_projection)
    {
        _projection_reduction.configure(compile_context, _projection_weights, &_projection_eff_bias, GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
    }

    // Pre-transpose weights to be used in GEMM.
    _transpose_input_to_forget_weights.configure(compile_context, input_to_forget_weights, &_input_to_forget_weights_transposed);
    _transpose_input_to_cell_weights.configure(compile_context, input_to_cell_weights, &_input_to_cell_weights_transposed);
    _transpose_input_to_output_weights.configure(compile_context, input_to_output_weights, &_input_to_output_weights_transposed);
    _transpose_recurrent_to_forget_weights.configure(compile_context, recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
    _transpose_recurrent_to_cell_weights.configure(compile_context, recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
    _transpose_recurrent_to_output_weights.configure(compile_context, recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
    if(!_has_cifg)
    {
        _transpose_input_to_input_weights.configure(compile_context, lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
        _transpose_recurrent_to_input_weights.configure(compile_context, lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
    }
    if(_has_projection)
    {
        _transpose_projection_weights.configure(compile_context, _projection_weights, &_projection_weights_transposed);
    }

    GEMMLowpOutputStageInfo gemmlowp_info;
    gemmlowp_info.type               = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
    gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
    gemmlowp_info.output_data_type   = DataType::QSYMM16;

    const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
    // Forget gate.
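    // f_t = sigmoid(W_if . x_t + W_hf . h_{t-1} [+ peephole] [+ layer norm]): the input
    // and recurrent contributions are computed as separate GEMMs, each requantized to the
    // shared forget_intermediate_scale, then accumulated in place.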
    const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
    const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
    configure_mm(compile_context, _mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
                 input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
                 &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
                 mm_out_info, forget_gate_outstage_info);

    const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
    configure_mm(compile_context, _mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
                 output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
                 &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
                 mm_out_info, forget_gate_outstage_info);

    _accumulate_input_recurrent_forget.configure(compile_context, ArithmeticOperation::ADD, &_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
                                                 ConvertPolicy::SATURATE);
    _input_to_forget_outstage_res.allocator()->allocate();

    if(_has_peephole)
    {
        _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
        _memory_group.manage(&_mul_cell_to_forget_res);
        _pixelwise_mul_cell_to_forget.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
        _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
        _memory_group.manage(&_cell_to_forget_outstage_res);
        const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
        quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
        _cell_to_forget_outstage.configure(compile_context, &_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
        _mul_cell_to_forget_res.allocator()->allocate();
        _accumulate_cell_forget.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
                                          ConvertPolicy::SATURATE);
        _cell_to_forget_outstage_res.allocator()->allocate();
    }

    CLTensor *forget_activation_input = &_recurrent_to_forget_outstage_res;

    if(_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Forget, &_recurrent_to_forget_outstage_res);
        _recurrent_to_forget_outstage_res.allocator()->allocate();
        forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
    }

    // Output quantization info of Sigmoid and Tanh activations
    const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);

    const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
    _memory_group.manage(&_forget_gate);
    _forget_gate.allocator()->init(forget_gate_info);
    _forget_gate_sigmoid.configure(compile_context, forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    forget_activation_input->allocator()->allocate();

    // Modulation gate.
    const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
    const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
    configure_mm(compile_context, _mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
                 input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
                 &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
                 mm_out_info, cell_outstage_info);

    const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
    configure_mm(compile_context, _mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
                 output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
                 &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
                 mm_out_info, cell_outstage_info);

    _accumulate_input_recurrent_modulation.configure(compile_context, ArithmeticOperation::ADD, &_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
                                                     ConvertPolicy::SATURATE);
    _input_to_cell_outstage_res.allocator()->allocate();

    CLTensor *cell_activation_input = &_recurrent_to_cell_outstage_res;

    if(_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Cell, &_recurrent_to_cell_outstage_res);
        _recurrent_to_cell_outstage_res.allocator()->allocate();
        cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
    }

    const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
    _memory_group.manage(&_cell_gate);
    _cell_gate.allocator()->init(cell_gate_info);
    _cell_gate_tanh.configure(compile_context, cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
    cell_activation_input->allocator()->allocate();

    // Input gate.
    const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
    _input_gate.allocator()->init(input_gate_info);
    _memory_group.manage(&_input_gate);
    if(_has_cifg)
    {
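        // With CIFG the input gate is not computed independently but coupled to the
        // forget gate as i_t = 1 - f_t, using a ones tensor with the forget gate's info.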
        _ones.allocator()->init(*_forget_gate.info());
        _input_gate_sub.configure(compile_context, ArithmeticOperation::SUB, &_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
        _ones.allocator()->allocate();
    }
    else
    {
        const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
        const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
        configure_mm(compile_context, _mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
                     input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
                     &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
                     mm_out_info, input_outstage_info);

        const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
        configure_mm(compile_context, _mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
                     output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
                     &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
                     mm_out_info, input_outstage_info);
        _accumulate_input_recurrent_input.configure(compile_context, ArithmeticOperation::ADD, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res,
                                                    ConvertPolicy::SATURATE);
        _input_to_input_outstage_res.allocator()->allocate();

        if(_has_peephole)
        {
            _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
            _memory_group.manage(&_mul_cell_to_input_res);
            _pixelwise_mul_cell_to_input.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
            const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
            quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
            _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
            _memory_group.manage(&_cell_to_input_outstage_res);
            _cell_to_input_outstage.configure(compile_context, &_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
            _mul_cell_to_input_res.allocator()->allocate();
            _accumulate_cell_input.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
            _cell_to_input_outstage_res.allocator()->allocate();
        }

        CLTensor *input_activation_input = &_recurrent_to_input_outstage_res;

        if(_has_layer_norm)
        {
            configure_layer_norm(LayerNormGate::Input, &_recurrent_to_input_outstage_res);
            _recurrent_to_input_outstage_res.allocator()->allocate();
            input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
        }

        _input_gate_sigmoid.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
        input_activation_input->allocator()->allocate();
    }
    // Cell.
    // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
    _pixelwise_mul_forget_cell.configure(compile_context, &_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    const float      cell_gate_scale      = _cell_gate.info()->quantization_info().uniform().scale;
    const float      mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
    const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
    _memory_group.manage(&_mul_input_cell_res);
    _mul_input_cell_res.allocator()->init(mul_input_cell_info);
    _pixelwise_mul_input_cell.configure(compile_context, &_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    _cell_gate.allocator()->allocate();
    _add_forget_cell.configure(compile_context, ArithmeticOperation::ADD, &_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
    _mul_input_cell_res.allocator()->allocate();
    _forget_gate.allocator()->allocate();
    if(_has_cell_clipping)
    {
        _cell_clip.configure(compile_context, cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
    }
    // Output gate.
    const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
    const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
    configure_mm(compile_context, _mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
                 input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
                 &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
                 mm_out_info, output_outstage_info);

    const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
    configure_mm(compile_context, _mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
                 output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
                 &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
                 mm_out_info, output_outstage_info);

    _accumulate_input_recurrent_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
                                                 ConvertPolicy::SATURATE);
    _input_to_output_outstage_res.allocator()->allocate();

    if(_has_peephole)
    {
        // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
        // Here we are not using the output stage because all operations are done in float
        _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
        _memory_group.manage(&_mul_cell_to_output_res);
        _pixelwise_mul_cell_to_output.configure(compile_context, cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);

        const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
        quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
        _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
        _memory_group.manage(&_cell_to_output_outstage_res);
        _cell_to_output_outstage.configure(compile_context, &_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
        _mul_cell_to_output_res.allocator()->allocate();

        _accumulate_cell_to_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
                                             ConvertPolicy::SATURATE);
        _cell_to_output_outstage_res.allocator()->allocate();
    }

    CLTensor *output_activation_input = &_recurrent_to_output_outstage_res;

    if(_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Output, &_recurrent_to_output_outstage_res);
        _recurrent_to_output_outstage_res.allocator()->allocate();
        output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
    }

    const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
    _memory_group.manage(&_output_gate);
    _output_gate.allocator()->init(output_gate_info);
    _output_gate_sigmoid.configure(compile_context, output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    output_activation_input->allocator()->allocate();

    // Hidden.
    _hidden_tanh.configure(compile_context, cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
    // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
    _memory_group.manage(&_hidden_mul_res);
    const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
    _hidden_mul_res.allocator()->init(hidden_mul_res);
    _pixelwise_mul_hidden.configure(compile_context, &_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    _output_gate.allocator()->allocate();
    _input_gate.allocator()->allocate();
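    // The product of two Q0.15 operands (output gate and tanh(cell state)) carries an
    // implicit scale of 2^-30, hence the two 2^-15 factors in the rescale to the hidden
    // state quantization below.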
    const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
    quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
    gemmlowp_info.gemmlowp_offset  = lstm_params.hidden_state_zero();
    gemmlowp_info.output_data_type = output_state_in->info()->data_type();

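    // When the projection changes the row width (num_units != output_size), intermediate
    // results are staged in temporary tensors and moved with TensorCopyKernel instead of
    // being written to output_state_out directly.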
    _projection_tensor_copy_required = (num_units != output_size);
    ICLTensor *hidden_gate_result    = output_state_out;

    _memory_group.manage(&_hidden_gate);

    if(_projection_tensor_copy_required)
    {
        _hidden_gate.allocator()->init(*output_state_out->info());
        _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
        hidden_gate_result = &_hidden_gate;
    }

    _hidden_outstage.configure(compile_context, &_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
    _hidden_mul_res.allocator()->allocate();

    // Projection.
    if(_has_projection)
    {
        const TensorInfo              projection_outstage_info(*output_state_out->info());
        const UniformQuantizationInfo qprojection      = _projection_weights->info()->quantization_info().uniform();
        const float                   projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
        gemmlowp_info.gemmlowp_offset    = qoutput_state_in.offset;
        gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
        gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
        gemmlowp_info.output_data_type   = DataType::QASYMM8_SIGNED;

        TensorInfo projection_mm_out_info{ mm_out_info };
        projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));

        configure_mm(compile_context, _mm_projection, _projection_outstage, gemmlowp_info,
                     hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
                     &_mm_projection_res, &_projection_outstage_res, projection_scale,
                     projection_mm_out_info, projection_outstage_info);

        ICLTensor *accumulate_destination = output_state_out;

        if(_projection_tensor_copy_required)
        {
            _hidden_gate.allocator()->allocate();
            _projection_accumulate_res.allocator()->init(*output_state_out->info());
            _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
            _projection_output_to_accumulate_copy.configure(*output_state_out, _projection_accumulate_res);
            accumulate_destination = &_projection_accumulate_res;
        }

        _accumulate_projection.configure(compile_context, ArithmeticOperation::ADD, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
        _projection_outstage_res.allocator()->allocate();

        if(_projection_tensor_copy_required)
        {
            _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
            _projection_accumulate_res.allocator()->allocate();
        }

        int8_t quantized_projection_clip{ 0 };
        if(lstm_params.projection_clip() > 0.0f)
        {
            quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
        }

        if(quantized_projection_clip > 0)
        {
            _projection_clip.configure(compile_context, output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
                                                                                                       quantized_projection_clip));
            _has_projection_clipping = true;
        }
    }
    else
    {
        if(_projection_tensor_copy_required)
        {
            _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
            _hidden_gate.allocator()->allocate();
        }
    }

    // Copy output_state_out to output
    _copy_output.configure(compile_context, output_state_out, output);
}

Status CLQLSTMLayer::validate(const ITensorInfo *input,
                              const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
                              const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
                              const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
                              const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
                              const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
                              const LSTMParams<ITensorInfo> &lstm_params)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
                                        recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
                                        cell_state_out, output_state_out, output);

    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");

    const unsigned int input_size  = input->dimension(0);
    const unsigned int batch_size  = input->dimension(1);
    const unsigned int num_units   = input_to_output_weights->dimension(1);
    const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);

    ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
    ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights, input_to_cell_weights);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QSYMM8);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                                       recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);

    ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
    ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);

    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);

    ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
    ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
    ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);

    // Check whether peephole weights are all there or none
    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());

        if(!lstm_params.has_cifg_opt())
        {
            ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
        }
    }

    const UniformQuantizationInfo qinput           = input->quantization_info().uniform();
    const UniformQuantizationInfo qcell_state_in   = cell_state_in->quantization_info().uniform();
    const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();

    // Calculate and decompose effective scales for optimizing matmul calculation
    const int32_t cell_shift = log2(qcell_state_in.scale);
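    // The check below effectively requires the (power-of-two) cell state scale to be
    // 2^-9 or smaller; cell_shift is its base-2 exponent.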
    ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);

    // Calculate quantized parameters for clipping.
    int16_t quantized_cell_clip = 0;
    if(lstm_params.cell_clip() > 0.0f)
    {
        quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
    }

    // Precompute effective bias for optimizing the matmul computations.
    const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
    const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
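    // Stand-in S32 bias infos with the shapes the reduction kernels produce; validate()
    // mirrors configure() using TensorInfo placeholders only.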
    if(!lstm_params.has_cifg_opt())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.recurrent_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
                                                                                                                                                                    true)));
    }
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
    if(lstm_params.has_projection())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
                                                                                                                                                                       lstm_params.hidden_state_zero(),
                                                                                                                                                                       true)));
    }

    const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info());
    const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());

    // Validate weights transpose
    ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_forget_weights, &input_weights_transposed));
    ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_cell_weights, &input_weights_transposed));
    ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_output_weights, &input_weights_transposed));
    ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_forget_weights, &recurrent_weights_transposed));
    ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_cell_weights, &recurrent_weights_transposed));
    ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_output_weights, &recurrent_weights_transposed));
    if(!lstm_params.has_cifg_opt())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed));
        ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed));
    }
    if(lstm_params.has_projection())
    {
        const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
        ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
    }

    GEMMLowpOutputStageInfo gemmlowp_info;
    gemmlowp_info.type               = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
    gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
    gemmlowp_info.output_data_type   = DataType::QSYMM16;

    const bool has_layer_norm = lstm_params.use_layer_norm();

    // Forget gate.
    const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
    const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
    const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));

    const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));

    ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));

    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
                                                                              RoundingPolicy::TO_ZERO));
        const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
        ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
    }

    if(has_layer_norm)
    {
        const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
        const ITensorInfo *b_info = forget_gate_bias;
        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
    }

    // Output quantization info of Sigmoid and Tanh activations
    const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);

    const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
    ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));

    // Modulation gate.
    const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
    const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));

    const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100716
717 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
718
Sheri Zhang3a353982020-04-21 13:10:24 +0100719 if(has_layer_norm)
720 {
721 const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
722 const ITensorInfo *b_info = cell_bias;
723 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
724 }
725
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100726 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
727 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
728
729 // Input gate.
730 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
731 if(lstm_params.has_cifg_opt())
732 {
733 ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::SUB, &input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights, lstm_params.recurrent_to_input_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());

        const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
        const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
        ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));

        const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
        ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));

        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));

        if(lstm_params.has_peephole_opt())
        {
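            // Peephole: the cell state modulates the input gate. The product is computed in
            // S32 and requantized to the input intermediate scale by the output stage below.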
            ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
                                                                                  RoundingPolicy::TO_ZERO));
            const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
            ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
            ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
        }

        if(has_layer_norm)
        {
            const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
            const ITensorInfo *b_info = lstm_params.input_gate_bias();
            ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
        }

        ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC, 1.f, 1.f)));
    }

    // Cell.
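    // New cell state: cell_state_out = forget_gate * cell_state_in + input_gate * cell_gate,
    // with the products and the accumulation saturated in QSYMM16.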
    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
    ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
    if(quantized_cell_clip > 0)
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
                                                                quantized_cell_clip)));
    }
    // Output gate.
    const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
    const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));

    const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));

    ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
        // TODO(COMPMID-3395): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
        // Here we are not using the output stage because all operations are done in float
        // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
        // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
        ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
                                                                              RoundingPolicy::TO_ZERO));
        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
    }

    if(has_layer_norm)
    {
        const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
        const ITensorInfo *b_info = output_gate_bias;
        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
    }

    const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
    ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));

    // Hidden.
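    // hidden = output_gate * tanh(cell_state_out); input_gate_info is reused below purely
    // as a shape/type descriptor for the tanh output.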
    ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
    const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
    const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);

    ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
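    // Both gate outputs are Q0.15 (scale 2^-15), so their product carries scale 2^-30;
    // the multiplier computed below maps that product onto the hidden state scale.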
    const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
    ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
    gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));

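    // When num_units != output_size, the projection accumulator and output_state_out have
    // different row widths, so explicit copies through an intermediate buffer are required.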
    const bool projection_tensor_copy_required = num_units != output_size;

    // Projection.
    if(lstm_params.has_projection())
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights, lstm_params.projection_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.projection_bias());

        const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
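        // The projection maps the hidden state back into the output-state domain:
        // real multiplier = w_scale * hidden_state_scale / output_state_scale.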
        const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
        ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
        gemmlowp_info.gemmlowp_offset    = qoutput_state_in.offset;
        gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
        gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
        gemmlowp_info.output_data_type   = DataType::QASYMM8_SIGNED;

        const TensorInfo projection_outstage_info(*output_state_out);
        const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());

        TensorInfo projection_mm_out_info{ mm_out_info };
        projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));

        ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
                                                &projection_outstage_info));

        if(projection_tensor_copy_required)
        {
            ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info));
        }

        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));

        if(projection_tensor_copy_required)
        {
            ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
        }

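        // The clip threshold is given in the real domain; quantize it into the projection's
        // QASYMM8_SIGNED domain and apply clipping only if it maps to a positive value.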
        int8_t quantized_projection_clip{ 0 };
        if(lstm_params.projection_clip() > 0.0f)
        {
            quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
        }

        if(quantized_projection_clip > 0)
        {
            ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
                                                                    quantized_projection_clip)));
        }
    }
    else
    {
        if(projection_tensor_copy_required)
        {
            ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
        }
    }

    if(cell_state_out->total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
    }

    if(output_state_out->total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
    }

    ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
    return Status{};
}

void CLQLSTMLayer::run()
{
    prepare();

    // Acquire all the temporaries
    MemoryGroupResourceScope scope_mg(_memory_group);

    // Forget gate.
    _mm_input_to_forget.run();
    _input_to_forget_outstage.run();

    _mm_recurrent_to_forget.run();
    _recurrent_to_forget_outstage.run();
    CLScheduler::get().enqueue(_accumulate_input_recurrent_forget);

    if(_has_peephole)
    {
        CLScheduler::get().enqueue(_pixelwise_mul_cell_to_forget);
        _cell_to_forget_outstage.run();
        CLScheduler::get().enqueue(_accumulate_cell_forget);
    }

    if(_has_layer_norm)
    {
        CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Forget));
    }

    _forget_gate_sigmoid.run();

    // Modulation gate.
    _mm_input_to_cell.run();
    _input_to_cell_outstage.run();

    _mm_recurrent_to_cell.run();
    _recurrent_to_cell_outstage.run();
    CLScheduler::get().enqueue(_accumulate_input_recurrent_modulation);

    if(_has_layer_norm)
    {
        CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Cell));
    }

    _cell_gate_tanh.run();

    // Input gate.
    if(_has_cifg)
    {
        CLScheduler::get().enqueue(_input_gate_sub);
    }
    else
    {
        _mm_input_to_input.run();
        _input_to_input_outstage.run();
        _mm_recurrent_to_input.run();
        _recurrent_to_input_outstage.run();
        CLScheduler::get().enqueue(_accumulate_input_recurrent_input);

        if(_has_peephole)
        {
            CLScheduler::get().enqueue(_pixelwise_mul_cell_to_input);
            _cell_to_input_outstage.run();
            CLScheduler::get().enqueue(_accumulate_cell_input);
        }

        if(_has_layer_norm)
        {
            CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input));
        }

        _input_gate_sigmoid.run();
    }

    // Cell.
    CLScheduler::get().enqueue(_pixelwise_mul_forget_cell);
    CLScheduler::get().enqueue(_pixelwise_mul_input_cell);
    CLScheduler::get().enqueue(_add_forget_cell);
    if(_has_cell_clipping)
    {
        _cell_clip.run();
    }

    // Output gate.
    _mm_input_to_output.run();
    _input_to_output_outstage.run();
    _mm_recurrent_to_output.run();
    _recurrent_to_output_outstage.run();
    CLScheduler::get().enqueue(_accumulate_input_recurrent_output);
    if(_has_peephole)
    {
        CLScheduler::get().enqueue(_pixelwise_mul_cell_to_output);
        _cell_to_output_outstage.run();
        CLScheduler::get().enqueue(_accumulate_cell_to_output);
    }

    if(_has_layer_norm)
    {
        CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Output));
    }

    _output_gate_sigmoid.run();

    // Hidden.
    _hidden_tanh.run();
    CLScheduler::get().enqueue(_pixelwise_mul_hidden);
    _hidden_outstage.run();

    // Projection.
    if(_has_projection)
    {
        _mm_projection.run();
        _projection_outstage.run();

        if(_projection_tensor_copy_required)
        {
            _projection_output_to_accumulate_copy.run();
        }

        CLScheduler::get().enqueue(_accumulate_projection);

        if(_projection_tensor_copy_required)
        {
            _projection_accumulate_to_output_copy.run();
        }

        if(_has_projection_clipping)
        {
            _projection_clip.run();
        }
    }
    else
    {
        if(_projection_tensor_copy_required)
        {
            _hidden_to_output_copy.run();
        }
    }

    // Copy output_state_out to output
    CLScheduler::get().enqueue(_copy_output);
}

void CLQLSTMLayer::prepare()
{
    if(!_is_prepared)
    {
        // Pre-transpose weights to be used in GEMM.
        _input_to_forget_weights_transposed.allocator()->allocate();
        _input_to_cell_weights_transposed.allocator()->allocate();
        _input_to_output_weights_transposed.allocator()->allocate();
        _recurrent_to_forget_weights_transposed.allocator()->allocate();
        _recurrent_to_cell_weights_transposed.allocator()->allocate();
        _recurrent_to_output_weights_transposed.allocator()->allocate();
        _transpose_input_to_forget_weights.run();
        _transpose_input_to_cell_weights.run();
        _transpose_input_to_output_weights.run();
        _transpose_recurrent_to_forget_weights.run();
        _transpose_recurrent_to_cell_weights.run();
        _transpose_recurrent_to_output_weights.run();

        // Precompute effective biases
        if(_has_cifg)
        {
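            // 32767 is 1.0 in Q0.15: with CIFG the input gate is computed as
            // (1 - forget_gate), so _ones must hold the QSYMM16 representation of one.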
            _ones.map(true);
            std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
            _ones.unmap();
        }
        else
        {
            _input_to_input_eff_bias.allocator()->allocate();
            _recurrent_to_input_eff_bias.allocator()->allocate();
            CLScheduler::get().enqueue(_input_to_input_reduction);
            CLScheduler::get().enqueue(_recurrent_to_input_reduction);

            _input_to_input_weights_transposed.allocator()->allocate();
            _recurrent_to_input_weights_transposed.allocator()->allocate();
            _transpose_input_to_input_weights.run();
            _transpose_recurrent_to_input_weights.run();
            _input_to_input_weights->mark_as_unused();
            _recurrent_to_input_weights->mark_as_unused();
        }
        _input_to_forget_eff_bias.allocator()->allocate();
        _recurrent_to_forget_eff_bias.allocator()->allocate();
        _input_to_cell_eff_bias.allocator()->allocate();
        _recurrent_to_cell_eff_bias.allocator()->allocate();
        _input_to_output_eff_bias.allocator()->allocate();
        _recurrent_to_output_eff_bias.allocator()->allocate();
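        // The reduction kernels fold the tensor zero points into the effective biases by
        // summing the weight rows scaled by the negated offsets (set up at configure time).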
        CLScheduler::get().enqueue(_input_to_forget_reduction);
        CLScheduler::get().enqueue(_recurrent_to_forget_reduction);
        CLScheduler::get().enqueue(_input_to_cell_reduction);
        CLScheduler::get().enqueue(_recurrent_to_cell_reduction);
        CLScheduler::get().enqueue(_input_to_output_reduction);
        CLScheduler::get().enqueue(_recurrent_to_output_reduction);

        if(_has_projection)
        {
            if(_projection_bias != nullptr)
            {
                _projection_eff_bias.allocator()->allocate();
                CLScheduler::get().enqueue(_projection_reduction);
                _projection_bias->mark_as_unused();
            }

            _projection_weights_transposed.allocator()->allocate();
            _transpose_projection_weights.run();
            _projection_weights->mark_as_unused();

            if(!_projection_tensor_copy_required)
            {
                _hidden_gate.mark_as_unused();
                _projection_accumulate_res.mark_as_unused();
            }
        }

        // Mark weights as unused
        _input_to_forget_weights->mark_as_unused();
        _input_to_cell_weights->mark_as_unused();
        _input_to_output_weights->mark_as_unused();
        _recurrent_to_forget_weights->mark_as_unused();
        _recurrent_to_cell_weights->mark_as_unused();
        _recurrent_to_output_weights->mark_as_unused();

        CLScheduler::get().queue().finish();
        _is_prepared = true;
    }
}

} // namespace arm_compute