blob: e30b1dbb867d2927406d5c15795a4946ccec91d4 [file] [log] [blame]
Manuel Bottini10c53f12019-07-17 16:11:53 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2019-2020 Arm Limited.
Manuel Bottini10c53f12019-07-17 16:11:53 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h"
26
27#include "arm_compute/core/Utils.h"
28#include "arm_compute/core/Validate.h"
29#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
30
Manuel Bottini10c53f12019-07-17 16:11:53 +010031#include <memory>
Manuel Bottini10c53f12019-07-17 16:11:53 +010032
33namespace arm_compute
34{
namespace
{
// Fixed quantization info structures used throughout the quantized LSTM layer.
// All intermediate gate tensors use QSYMM16 fixed-point formats; the number of
// integer bits determines the representable range (scale = 2^bits / 32768).
const QuantizationInfo qasymm(1.f / 128.f, 128);   // QASYMM8 mapping real values in [-1, 1) (inputs/output state)
const QuantizationInfo qsymm_3(8.f / 32768.f, 0);  // qsymm16 with 3 integer bits, range [-8, 8)
const QuantizationInfo qsymm_4(16.f / 32768.f, 0); // qsymm16 with 4 integer bits, range [-16, 16)
const QuantizationInfo qsymm_0(1.f / 32768.f, 0);  // qsymm16 with 0 integer bits, range [-1, 1)
} // namespace
43
44CLLSTMLayerQuantized::CLLSTMLayerQuantized(std::shared_ptr<IMemoryManager> memory_manager)
45 : _memory_group(std::move(memory_manager)), _gemmlowp(), _output_stage(), _transpose_weights(), _concat_input_weights(), _concat_recurrent_weights(), _concat_weights(), _concat_inputs(),
46 _concat_bias(), _sigmoid_forget_gate(), _sigmoid_input_gate(), _sigmoid_output_gate(), _tanh_modulation_gate(), _tanh_output_state(), _add_cell_state_tmps(), _add2(), _mul_forget_gate_cell_state(),
47 _mul_input_gate_input_mod_gate(), _mul_output_state_tmp_output_gate(), _slice_input_tensor(), _slice_forget_tensor(), _slice_cell_tensor(), _slice_output_tensor(), _dequantize(), _quantize(),
48 _input_to_input_weights(nullptr), _input_to_forget_weights(nullptr), _input_to_cell_weights(nullptr), _input_to_output_weights(nullptr), _recurrent_to_input_weights(nullptr),
49 _recurrent_to_forget_weights(nullptr), _recurrent_to_cell_weights(nullptr), _recurrent_to_output_weights(nullptr), _input_gate_bias(nullptr), _forget_gate_bias(nullptr), _cell_bias(nullptr),
50 _output_gate_bias(nullptr), _recurrent_weights(), _input_weights(), _weights(), _input(), _weights_transposed(), _output_highp(), _output_lowp(), _bias(), _forget_gate_input(), _input_gate_input(),
51 _output_gate_input(), _input_modulation_gate_input(), _forget_gate_output(), _input_gate_output(), _output_gate_output(), _input_modulation_gate_output(), _cell_state_tmp1(), _cell_state_tmp2(),
52 _output_state_tmp(), _output_state_out_symm(), _output_state_out_f32(), _is_prepared(false)
53{
54}
55
56void CLLSTMLayerQuantized::configure(const ICLTensor *input,
57 const ICLTensor *input_to_input_weights, const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
58 const ICLTensor *recurrent_to_input_weights, const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
59 const ICLTensor *input_gate_bias, const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
60 ICLTensor *cell_state_in, const ICLTensor *output_state_in,
61 ICLTensor *cell_state_out, ICLTensor *output_state_out)
62{
Manuel Bottini2b84be52020-04-08 10:15:51 +010063 configure(CLKernelLibrary::get().get_compile_context(), input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_input_weights,
64 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out,
65 output_state_out);
66}
67
// Builds the whole quantized LSTM cell graph:
//   1) concatenate the four input-to-gate and four recurrent-to-gate weight matrices (and biases),
//   2) run one GEMMLowp of the concatenated [input | output_state_in] against the transposed weights,
//   3) requantize the S32 accumulator to QSYMM16 and slice it into the four per-gate inputs,
//   4) apply the gate activations and elementwise ops to produce cell_state_out and output_state_out.
// NOTE: the ordering of manage()/init()/configure()/allocate() calls below drives the memory-group
// lifetime tracking and must not be reordered.
void CLLSTMLayerQuantized::configure(const CLCompileContext &compile_context, const ICLTensor *input,
                                     const ICLTensor *input_to_input_weights, const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                                     const ICLTensor *recurrent_to_input_weights, const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                                     const ICLTensor *input_gate_bias, const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
                                     ICLTensor *cell_state_in, const ICLTensor *output_state_in,
                                     ICLTensor *cell_state_out, ICLTensor *output_state_out)
{
    // Fail fast on nulls, then run the full static validation before touching any state.
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                 recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                 input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);

    ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayerQuantized::validate(input->info(), input_to_input_weights->info(), input_to_forget_weights->info(), input_to_cell_weights->info(),
                                                              input_to_output_weights->info(),
                                                              recurrent_to_input_weights->info(), recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
                                                              input_gate_bias->info(), forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(), cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info()));

    // input is [input_size, batch_size]; input-to-gate weights are [input_size, output_size]
    // (validate() rejects anything above 2 dimensions).
    const int input_size  = input->info()->dimension(0);
    const int batch_size  = input->info()->dimension(1);
    const int output_size = input_to_input_weights->info()->dimension(1);

    const QuantizationInfo qweights = input_to_input_weights->info()->quantization_info(); // Weights quantization

    // Initialize the outputs if the caller left them empty: cell state is QSYMM16 with 4 integer
    // bits, output state is QASYMM8 with the fixed qasymm quantization.
    auto_init_if_empty(*cell_state_out->info(), TensorInfo(TensorShape(batch_size, output_size), 1, DataType::QSYMM16, qsymm_4));
    auto_init_if_empty(*output_state_out->info(), TensorInfo(TensorShape(batch_size, output_size), 1, DataType::QASYMM8, qasymm));

    // Stash the raw weight/bias pointers; presumably consumed by prepare()/run() (outside this
    // chunk) — TODO confirm.
    _input_to_input_weights      = input_to_input_weights;
    _input_to_forget_weights     = input_to_forget_weights;
    _input_to_cell_weights       = input_to_cell_weights;
    _input_to_output_weights     = input_to_output_weights;
    _recurrent_to_input_weights  = recurrent_to_input_weights;
    _recurrent_to_forget_weights = recurrent_to_forget_weights;
    _recurrent_to_cell_weights   = recurrent_to_cell_weights;
    _recurrent_to_output_weights = recurrent_to_output_weights;
    _input_gate_bias             = input_gate_bias;
    _forget_gate_bias            = forget_gate_bias;
    _cell_bias                   = cell_bias;
    _output_gate_bias            = output_gate_bias;

    // Weights concatenation: stack the four input-to-gate matrices along Y -> [input_size, 4 * output_size]
    std::vector<const ICLTensor *> inputs_weights_vector;
    inputs_weights_vector.emplace_back(input_to_input_weights);
    inputs_weights_vector.emplace_back(input_to_forget_weights);
    inputs_weights_vector.emplace_back(input_to_cell_weights);
    inputs_weights_vector.emplace_back(input_to_output_weights);

    // Same for the recurrent matrices -> [output_size, 4 * output_size]
    std::vector<const ICLTensor *> recurrent_weights_vector;
    recurrent_weights_vector.emplace_back(recurrent_to_input_weights);
    recurrent_weights_vector.emplace_back(recurrent_to_forget_weights);
    recurrent_weights_vector.emplace_back(recurrent_to_cell_weights);
    recurrent_weights_vector.emplace_back(recurrent_to_output_weights);

    _input_weights.allocator()->init(TensorInfo(TensorShape(input_size, 4 * output_size), 1, DataType::QASYMM8, qweights));
    _concat_input_weights.configure(compile_context, inputs_weights_vector, &_input_weights, Window::DimY);

    _recurrent_weights.allocator()->init(TensorInfo(TensorShape(output_size, 4 * output_size), 1, DataType::QASYMM8, qweights));
    _concat_recurrent_weights.configure(compile_context, recurrent_weights_vector, &_recurrent_weights, Window::DimY);

    // Join the two blocks along X -> [output_size + input_size, 4 * output_size], then transpose
    // for the GEMM. NOTE(review): weights are ordered [recurrent | input] while the GEMM input
    // below is [input | output_state_in] — ordering is relied upon by the GEMM/transpose semantics;
    // confirm before changing either.
    std::vector<const ICLTensor *> weights_vector;
    weights_vector.emplace_back(&_recurrent_weights);
    weights_vector.emplace_back(&_input_weights);

    _weights.allocator()->init(TensorInfo(TensorShape(output_size + input_size, 4 * output_size), 1, DataType::QASYMM8, qweights));
    _concat_weights.configure(compile_context, weights_vector, &_weights, Window::DimX);
    _transpose_weights.configure(compile_context, &_weights, &_weights_transposed);

    // Input concatenation: [input | output_state_in] -> [output_size + input_size, batch_size]
    std::vector<const ICLTensor *> input_vector;
    input_vector.emplace_back(input);
    input_vector.emplace_back(output_state_in);

    _memory_group.manage(&_input);
    _input.allocator()->init(TensorInfo(TensorShape(output_size + input_size, batch_size), 1, DataType::QASYMM8, qasymm));
    _concat_inputs.configure(compile_context, input_vector, &_input, Window::DimX);

    // Bias concatenation: four S32 bias vectors -> [4 * output_size]
    std::vector<const ICLTensor *> bias_vector;
    bias_vector.emplace_back(input_gate_bias);
    bias_vector.emplace_back(forget_gate_bias);
    bias_vector.emplace_back(cell_bias);
    bias_vector.emplace_back(output_gate_bias);

    _bias.allocator()->init(TensorInfo(TensorShape(4 * output_size), 1, DataType::S32));
    _concat_bias.configure(compile_context, bias_vector, &_bias, Window::DimX);

    // Invert the offset for gemmlowp (negated zero-points during configuration; the original
    // values are restored right after _gemmlowp is configured, see "Set the offset back" below)
    _input.info()->set_quantization_info(QuantizationInfo(qasymm.uniform().scale, -qasymm.uniform().offset));
    _weights_transposed.info()->set_quantization_info(QuantizationInfo(qweights.uniform().scale, -qweights.uniform().offset));

    // Run gemmlowp: S32 accumulator of shape [4 * output_size, batch_size]
    _memory_group.manage(&_output_highp);
    _output_highp.allocator()->init(TensorInfo(TensorShape(4 * output_size, batch_size), 1, DataType::S32));
    _gemmlowp.configure(compile_context, &_input, &_weights_transposed, nullptr, &_output_highp);
    _input.allocator()->allocate();

    // Set the offset back
    _input.info()->set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
    _weights_transposed.info()->set_quantization_info(QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));

    // multiplier = (input_scale * weights_scale) / output_scale, where output_scale = 2 ^ (-12)
    // (the qsymm_3 scale), hence the 4096 factor.
    _output_lowp.allocator()->init(TensorInfo(_output_highp.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_3));

    const float multiplier        = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
    int         output_multiplier = 0;
    int         output_shift      = 0;
    quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);

    // Requantize the GEMM result from S32 to QSYMM16 (3 integer bits), folding in the bias.
    _memory_group.manage(&_output_lowp);
    _output_stage.configure(compile_context, &_output_highp, &_bias, &_output_lowp, output_multiplier, output_shift);
    _output_highp.allocator()->allocate();
    _bias.allocator()->allocate();

    // Get the gate tensors: slice _output_lowp along X into the four per-gate chunks
    // (input, forget, modulation, output). With batch_size == 1 the slice coordinates are 1D.
    if(batch_size > 1)
    {
        _memory_group.manage(&_input_gate_input);
        _slice_input_tensor.configure(compile_context, &_output_lowp, &_input_gate_input, { 0, 0 }, { output_size, batch_size });
        _memory_group.manage(&_forget_gate_input);
        _slice_forget_tensor.configure(compile_context, &_output_lowp, &_forget_gate_input, { output_size, 0 }, { 2 * output_size, batch_size });
        _memory_group.manage(&_input_modulation_gate_input);
        _slice_cell_tensor.configure(compile_context, &_output_lowp, &_input_modulation_gate_input, { 2 * output_size, 0 }, { 3 * output_size, batch_size });
        _memory_group.manage(&_output_gate_input);
        _slice_output_tensor.configure(compile_context, &_output_lowp, &_output_gate_input, { 3 * output_size, 0 }, { 4 * output_size, batch_size });
        _output_lowp.allocator()->allocate();
    }
    else
    {
        _memory_group.manage(&_input_gate_input);
        _slice_input_tensor.configure(compile_context, &_output_lowp, &_input_gate_input, { 0 }, { output_size });
        _memory_group.manage(&_forget_gate_input);
        _slice_forget_tensor.configure(compile_context, &_output_lowp, &_forget_gate_input, { output_size }, { 2 * output_size });
        _memory_group.manage(&_input_modulation_gate_input);
        _slice_cell_tensor.configure(compile_context, &_output_lowp, &_input_modulation_gate_input, { 2 * output_size }, { 3 * output_size });
        _memory_group.manage(&_output_gate_input);
        _slice_output_tensor.configure(compile_context, &_output_lowp, &_output_gate_input, { 3 * output_size }, { 4 * output_size });
        _output_lowp.allocator()->allocate();
    }

    // Forget gate: sigmoid, output in QSYMM16 with 0 integer bits
    _memory_group.manage(&_forget_gate_output);
    _forget_gate_output.allocator()->init(TensorInfo(_forget_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _sigmoid_forget_gate.configure(compile_context, &_forget_gate_input, &_forget_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    _forget_gate_input.allocator()->allocate();

    // Input gate: sigmoid
    _memory_group.manage(&_input_gate_output);
    _input_gate_output.allocator()->init(TensorInfo(_input_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _sigmoid_input_gate.configure(compile_context, &_input_gate_input, &_input_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    _input_gate_input.allocator()->allocate();

    // Input modulation gate equation: tanh
    _memory_group.manage(&_input_modulation_gate_output);
    _input_modulation_gate_output.allocator()->init(TensorInfo(_input_modulation_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _tanh_modulation_gate.configure(compile_context, &_input_modulation_gate_input, &_input_modulation_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f));
    _input_modulation_gate_input.allocator()->allocate();

    // Output gate: sigmoid
    _memory_group.manage(&_output_gate_output);
    _output_gate_output.allocator()->init(TensorInfo(_output_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _sigmoid_output_gate.configure(compile_context, &_output_gate_input, &_output_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    _output_gate_input.allocator()->allocate();

    // Long term memory: cell_state_out = forget_gate * cell_state_in + input_gate * input_modulation_gate
    _memory_group.manage(&_cell_state_tmp1);
    _cell_state_tmp1.allocator()->init(TensorInfo(_forget_gate_output.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_4));
    _mul_forget_gate_cell_state.configure(compile_context, &_forget_gate_output, cell_state_in, &_cell_state_tmp1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    _forget_gate_output.allocator()->allocate();

    _memory_group.manage(&_cell_state_tmp2);
    _cell_state_tmp2.allocator()->init(TensorInfo(_input_gate_output.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_4));
    _mul_input_gate_input_mod_gate.configure(compile_context, &_input_gate_output, &_input_modulation_gate_output, &_cell_state_tmp2, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    _input_modulation_gate_output.allocator()->allocate();
    _input_gate_output.allocator()->allocate();

    _add_cell_state_tmps.configure(compile_context, &_cell_state_tmp1, &_cell_state_tmp2, cell_state_out, ConvertPolicy::SATURATE);
    _cell_state_tmp1.allocator()->allocate();
    _cell_state_tmp2.allocator()->allocate();

    // Short term memory: output_state = output_gate * tanh(cell_state_out)
    _memory_group.manage(&_output_state_tmp);
    _output_state_tmp.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _tanh_output_state.configure(compile_context, cell_state_out, &_output_state_tmp, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f));

    _memory_group.manage(&_output_state_out_symm);
    _output_state_out_symm.allocator()->init(TensorInfo(_output_gate_output.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
    _mul_output_state_tmp_output_gate.configure(compile_context, &_output_state_tmp, &_output_gate_output, &_output_state_out_symm, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    _output_gate_output.allocator()->allocate();
    _output_state_tmp.allocator()->allocate();

    // Requantize the output state from QSYMM16 to QASYMM8 via an F32 round-trip (dequantize, then quantize)
    _memory_group.manage(&_output_state_out_f32);
    _output_state_out_f32.allocator()->init(TensorInfo(_output_state_out_symm.info()->tensor_shape(), 1, DataType::F32));
    _dequantize.configure(compile_context, &_output_state_out_symm, &_output_state_out_f32);
    _output_state_out_symm.allocator()->allocate();

    _quantize.configure(compile_context, &_output_state_out_f32, output_state_out);
    _output_state_out_f32.allocator()->allocate();
}
265
266Status CLLSTMLayerQuantized::validate(const ITensorInfo *input,
267 const ITensorInfo *input_to_input_weights, const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
268 const ITensorInfo *recurrent_to_input_weights, const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
269 const ITensorInfo *input_gate_bias, const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
270 const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
271 const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out)
272{
273 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_input_weights,
274 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in,
275 output_state_in, cell_state_out, output_state_out);
Michele Di Giorgiof6f78762020-07-06 11:27:21 +0100276 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(input, DataType::QASYMM8);
Manuel Bottini10c53f12019-07-17 16:11:53 +0100277
278 const int input_size = input->dimension(0);
279 const int batch_size = input->dimension(1);
280 const int output_size = input_to_input_weights->dimension(1);
281
282 // Dimensionality checks
283 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
284 ARM_COMPUTE_RETURN_ERROR_ON(input_to_input_weights->num_dimensions() > 2);
285 ARM_COMPUTE_RETURN_ERROR_ON(input_gate_bias->num_dimensions() > 1);
286 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
287
288 TensorInfo input_weights_info(input_to_input_weights->clone()->set_tensor_shape(TensorShape(input_size, output_size)).set_data_type(DataType::QASYMM8));
289 TensorInfo recurrent_weights_info(input_to_input_weights->clone()->set_tensor_shape(TensorShape(output_size, output_size)).set_data_type(DataType::QASYMM8));
290 TensorInfo bias_info(input_gate_bias->clone()->set_tensor_shape(TensorShape(output_size)).set_data_type(DataType::S32));
291 TensorInfo output_state_info(cell_state_in->clone()->set_tensor_shape(TensorShape(output_size, batch_size)).set_data_type(DataType::QASYMM8).set_quantization_info(qasymm));
292 TensorInfo cell_state_info(cell_state_in->clone()->set_tensor_shape(TensorShape(output_size, batch_size)).set_data_type(DataType::QSYMM16).set_quantization_info(qsymm_4));
293
294 // Shape checks
295 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input_weights_info, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights);
296 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&recurrent_weights_info, recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
297 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&bias_info, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias);
298 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&cell_state_info, cell_state_in);
299 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&output_state_info, output_state_in);
300
301 // Data type checks
302 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input_weights_info, input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights);
303 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&recurrent_weights_info, recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
304 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&bias_info, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias);
305 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&cell_state_info, cell_state_in);
306 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&output_state_info, output_state_in);
307
308 // Quantization checks
309 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights);
310 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
311 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&cell_state_info, cell_state_in);
312 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&output_state_info, output_state_in);
313
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100314 // Validate internal functions
315 // _concat_input_weights
316 std::vector<const ITensorInfo *> inputs_weights_vector;
317 inputs_weights_vector.emplace_back(input_to_input_weights);
318 inputs_weights_vector.emplace_back(input_to_forget_weights);
319 inputs_weights_vector.emplace_back(input_to_cell_weights);
320 inputs_weights_vector.emplace_back(input_to_output_weights);
321 const QuantizationInfo qweights = input_to_input_weights->quantization_info(); // Weights quantization
322 const TensorInfo input_weights(TensorShape(input_size, 4 * output_size), 1, DataType::QASYMM8, qweights);
323 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_weights_vector, &input_weights, Window::DimY));
324
325 // _concat_recurrent_weights
326 std::vector<const ITensorInfo *> recurrent_weights_vector;
327 recurrent_weights_vector.emplace_back(recurrent_to_input_weights);
328 recurrent_weights_vector.emplace_back(recurrent_to_forget_weights);
329 recurrent_weights_vector.emplace_back(recurrent_to_cell_weights);
330 recurrent_weights_vector.emplace_back(recurrent_to_output_weights);
331 const TensorInfo recurrent_weights(TensorShape(output_size, 4 * output_size), 1, DataType::QASYMM8, qweights);
332 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(recurrent_weights_vector, &recurrent_weights, Window::DimY));
333
334 // _concat_weights
335 std::vector<const ITensorInfo *> weights_vector;
336 weights_vector.emplace_back(&recurrent_weights);
337 weights_vector.emplace_back(&input_weights);
338 const TensorInfo weights(TensorShape(input_size + output_size, 4 * output_size), 1, DataType::QASYMM8, qweights);
339 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(weights_vector, &weights, Window::DimX));
340 // _transpose_weights
341 const TensorShape weights_transposed_shape(weights.tensor_shape()[1], weights.tensor_shape()[0]);
342 TensorInfo weights_transposed = weights.clone()->set_is_resizable(true).set_tensor_shape(weights_transposed_shape);
343 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(&weights, &weights_transposed));
344
345 // _concat_inputs
346 std::vector<const ITensorInfo *> input_vector;
347 input_vector.emplace_back(input);
348 input_vector.emplace_back(output_state_in);
349 TensorInfo input_concatenated(TensorShape(output_size + input_size, batch_size), 1, DataType::QASYMM8, qasymm);
350 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(input_vector, &input_concatenated, Window::DimX));
351
352 // _concat_bias
353 std::vector<const ITensorInfo *> bias_vector;
354 bias_vector.emplace_back(input_gate_bias);
355 bias_vector.emplace_back(forget_gate_bias);
356 bias_vector.emplace_back(cell_bias);
357 bias_vector.emplace_back(output_gate_bias);
358
359 const TensorInfo bias_concatenated(TensorShape(4 * output_size), 1, DataType::S32);
360 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(bias_vector, &bias_concatenated, Window::DimX));
361
362 // Invert the offset for gemmlowp
363 input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, -qasymm.uniform().offset));
364 weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, -qweights.uniform().offset));
365
366 // _gemmlowp
367 const TensorInfo output_highp(TensorShape(4 * output_size, batch_size), 1, DataType::S32);
368 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(&input_concatenated, &weights_transposed, nullptr, &output_highp));
369
370 // Set the offset back
371 input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
372 weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));
373
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100374 const TensorInfo output_lowp(output_highp.tensor_shape(), 1, DataType::QSYMM16, qsymm_3);
375
Manuel Bottini07263982019-10-17 18:37:26 +0100376 const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
377 int output_multiplier = 0;
378 int output_shift = 0;
379 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
380
Michele Di Giorgio601ba3f2019-08-22 16:20:04 +0100381 // _output_stage
382 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(&output_highp, &bias_concatenated, &output_lowp));
383
384 TensorInfo input_gate_input;
385 TensorInfo forget_gate_input;
386 TensorInfo input_modulation_gate_input;
387 TensorInfo output_gate_input;
388
389 if(batch_size > 1)
390 {
391 // _slice_input_tensor
392 input_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
393 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&output_lowp, &input_gate_input, { 0, 0 }, { output_size, batch_size }));
394 // _slice_forget_tensor
395 forget_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
396 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&output_lowp, &forget_gate_input, { output_size, 0 }, { 2 * output_size, batch_size }));
397 // _slice_cell_tensor
398 input_modulation_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
399 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&output_lowp, &input_modulation_gate_input, { 2 * output_size, 0 }, { 3 * output_size, batch_size }));
400 // _slice_output_tensor
401 output_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
402 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&output_lowp, &output_gate_input, { 3 * output_size, 0 }, { 4 * output_size, batch_size }));
403 }
404 else
405 {
406 // _slice_input_tensor
407 input_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
408 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&output_lowp, &input_gate_input, { 0 }, { output_size }));
409 // _slice_forget_tensor
410 forget_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
411 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&output_lowp, &forget_gate_input, { output_size }, { 2 * output_size }));
412 // _slice_cell_tensor
413 input_modulation_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
414 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&output_lowp, &input_modulation_gate_input, { 2 * output_size }, { 3 * output_size }));
415 // _slice_output_tensor
416 output_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
417 ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&output_lowp, &output_gate_input, { 3 * output_size }, { 4 * output_size }));
418 }
419
420 // _sigmoid_forget_gate
421 const TensorInfo forget_gate_output(forget_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
422 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_gate_input, &forget_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
423 // _sigmoid_input_gate
424 const TensorInfo input_gate_output(input_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
425 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_gate_input, &input_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
426 // _tanh_modulation_gate
427 const TensorInfo input_modulation_gate_output(input_modulation_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
428 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_modulation_gate_input, &input_modulation_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f)));
429 // _sigmoid_output_gate
430 const TensorInfo output_gate_output(output_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
431 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_gate_input, &output_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
432
433 // _mul_forget_gate_cell_state
434 const TensorInfo cell_state_tmp1(forget_gate_output.tensor_shape(), 1, DataType::QSYMM16, qsymm_4);
435 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate_output, cell_state_in, &cell_state_tmp1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
436
437 // _mul_input_gate_input_mod_gate
438 const TensorInfo cell_state_tmp2(input_gate_output.tensor_shape(), 1, DataType::QSYMM16, qsymm_4);
439 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate_output, &input_modulation_gate_output, &cell_state_tmp2, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
440
441 // _add_cell_state_tmps
442 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp1, &cell_state_tmp2, cell_state_out, ConvertPolicy::SATURATE));
443
    // _tanh_output_state (tanh of the new cell state, feeding the output gate multiply)
445 const TensorInfo output_state_tmp(cell_state_out->tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
446 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, &output_state_tmp, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f)));
447
448 // _mul_output_state_tmp_output_gate
449 const TensorInfo output_state_out_symm(output_gate_output.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
450 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_state_tmp, &output_gate_output, &output_state_out_symm, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
451
452 // _dequantize
453 const TensorInfo output_state_out_f32(output_state_out_symm.tensor_shape(), 1, DataType::F32);
454 ARM_COMPUTE_RETURN_ON_ERROR(CLDequantizationLayer::validate(&output_state_out_symm, &output_state_out_f32));
455
456 // _quantize
457 ARM_COMPUTE_RETURN_ON_ERROR(CLQuantizationLayer::validate(&output_state_out_f32, output_state_out));
458
Manuel Bottini10c53f12019-07-17 16:11:53 +0100459 if(cell_state_out->total_size() != 0)
460 {
461 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&cell_state_info, cell_state_out);
462 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&cell_state_info, cell_state_out);
463 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&cell_state_info, cell_state_out);
464 }
465
466 if(output_state_out->total_size() != 0)
467 {
468 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&output_state_info, output_state_out);
469 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&output_state_info, output_state_out);
470 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&output_state_info, output_state_out);
471 }
472
473 return Status{};
474}
475
void CLLSTMLayerQuantized::run()
{
    // One-time weight/bias concatenation and transposition (no-op after the first call).
    prepare();

    // Acquire all the temporaries
    MemoryGroupResourceScope scope_mg(_memory_group);

    // Concat and transpose the input
    _concat_inputs.run();

    // Run gemmlowp: a single low-precision GEMM produces the pre-activations of all
    // four gates at once; the output stage requantizes the accumulators (to QSYMM16,
    // per the matching validate() path).
    _gemmlowp.run();
    _output_stage.run();

    // Slice the results: split the fused GEMM output into the four per-gate tensors
    _slice_input_tensor.run();
    _slice_forget_tensor.run();
    _slice_cell_tensor.run();
    _slice_output_tensor.run();

    // Gates
    // Forget gate
    _sigmoid_forget_gate.run();

    // Input gate
    _sigmoid_input_gate.run();

    // Input modulation gate
    _tanh_modulation_gate.run();

    // Output gate
    _sigmoid_output_gate.run();

    // Cell state (long term memory):
    // cell_state = forget_gate * cell_state_prev + input_gate * input_modulation_gate
    _mul_forget_gate_cell_state.run();
    _mul_input_gate_input_mod_gate.run();
    _add_cell_state_tmps.run();

    // Output state (short term memory): output_state = tanh(cell_state) * output_gate
    _tanh_output_state.run();
    _mul_output_state_tmp_output_gate.run();

    // Requantize output state from QSYMM16 to QASYMM8
    _dequantize.run();
    _quantize.run();
}
522
void CLLSTMLayerQuantized::prepare()
{
    // One-time setup: fuse the per-gate weight/bias tensors into single tensors so
    // run() can execute one GEMM. Each intermediate is freed as soon as its contents
    // have been consumed, keeping peak memory usage low.
    if(!_is_prepared)
    {
        // Concatenate the four input-to-gate weight matrices into _input_weights
        _input_weights.allocator()->allocate();
        _concat_input_weights.run();

        // Originals are no longer needed once concatenated
        _input_to_input_weights->mark_as_unused();
        _input_to_forget_weights->mark_as_unused();
        _input_to_cell_weights->mark_as_unused();
        _input_to_output_weights->mark_as_unused();

        // Same for the four recurrent-to-gate weight matrices
        _recurrent_weights.allocator()->allocate();
        _concat_recurrent_weights.run();
        _recurrent_to_input_weights->mark_as_unused();
        _recurrent_to_forget_weights->mark_as_unused();
        _recurrent_to_cell_weights->mark_as_unused();
        _recurrent_to_output_weights->mark_as_unused();

        // Concatenate input and recurrent weights into the single _weights tensor
        _weights.allocator()->allocate();
        _concat_weights.run();

        // The two intermediate concatenations can now be released
        _input_weights.mark_as_unused();
        _input_weights.allocator()->free();
        _recurrent_weights.mark_as_unused();
        _recurrent_weights.allocator()->free();

        // Transpose the fused weights into the layout the GEMM expects
        _weights_transposed.allocator()->allocate();
        _transpose_weights.run();

        // Untransposed fused weights no longer needed
        _weights.mark_as_unused();
        _weights.allocator()->free();

        // Concatenate the four gate biases into _bias and release the originals
        _bias.allocator()->allocate();
        _concat_bias.run();
        _input_gate_bias->mark_as_unused();
        _forget_gate_bias->mark_as_unused();
        _cell_bias->mark_as_unused();
        _output_gate_bias->mark_as_unused();

        _is_prepared = true;
    }
}
566
Michele Di Giorgio35ea9a72019-08-23 12:02:06 +0100567} // namespace arm_compute