blob: f063410972997bdc3be26818179bff5d7a763c0e [file] [log] [blame]
/*
 * Copyright (c) 2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/CL/functions/CLQLSTMLayer.h"
25
26#include "arm_compute/core/KernelDescriptors.h"
27#include "arm_compute/core/QuantizationInfo.h"
28#include "arm_compute/core/Utils.h"
29#include "arm_compute/core/Validate.h"
30#include "arm_compute/core/utils/misc/InfoHelpers.h"
31#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
32#include "arm_compute/runtime/CL/CLScheduler.h"
33
34namespace arm_compute
35{
36using namespace arm_compute::utils::info_helpers;
37namespace
38{
39Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
40 float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
41{
42 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
43 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
44 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
45 return Status{};
46}
47} // namespace
48
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +010049Status CLQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
50{
51 ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
52 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
53 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
54 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
55 return Status{};
56}
57
58void CLQLSTMLayer::TensorCopyKernel::configure(ICLTensor &src, ICLTensor &dst)
59{
60 ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
61 _src = &src;
62 _dst = &dst;
63 _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
64 _window = calculate_max_window(*_src->info(), Steps());
65}
66
67void CLQLSTMLayer::TensorCopyKernel::run()
68{
69 auto &q = CLScheduler::get().queue();
70
71 _src->map(q, true);
72 _dst->map(q, true);
73
74 Iterator input_iter{ _src, _window };
75 Iterator output_iter{ _dst, _window };
76
77 execute_window_loop(_window, [&](const Coordinates &)
78 {
79 memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
80 },
81 input_iter, output_iter);
82
83 _src->unmap(q);
84 _dst->unmap(q);
85}
86
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010087CLQLSTMLayer::CLQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
88{
89 _memory_group = MemoryGroup(std::move(memory_manager));
90}
91
Manuel Bottini2b84be52020-04-08 10:15:51 +010092void CLQLSTMLayer::configure_mm(const CLCompileContext &compile_context, CLGEMMLowpMatrixMultiplyCore &mm, CLGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010093 const ICLTensor *mm_input, const ICLTensor *mm_weights, const ICLTensor *bias,
94 CLTensor *mm_res, CLTensor *outstage_res, float gemmlowp_scale,
95 const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
96{
97 _memory_group.manage(mm_res);
98 _memory_group.manage(outstage_res);
99
100 mm_res->allocator()->init(mm_res_info);
101 outstage_res->allocator()->init(outstage_tensor_info);
102
103 // Configure matrix-multiplication
Manuel Bottini2b84be52020-04-08 10:15:51 +0100104 mm.configure(compile_context, mm_input, mm_weights, nullptr, mm_res);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100105
106 // Configure output stage
107 quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100108 outstage.configure(compile_context, mm_res, bias, outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100109 mm_res->allocator()->allocate();
110}
111
112void CLQLSTMLayer::configure(const ICLTensor *input,
113 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
114 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
115 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
116 const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100117 ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100118 const LSTMParams<ICLTensor> &lstm_params)
119{
Manuel Bottini2b84be52020-04-08 10:15:51 +0100120 configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
121 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
Michalis Spyroue6bd70c2020-05-21 15:10:25 +0100122 cell_state_in, output_state_in, cell_state_out, output_state_out, output, lstm_params);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100123}
124
125void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
126 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
127 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
128 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
129 const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100130 ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100131 const LSTMParams<ICLTensor> &lstm_params)
132{
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100133 ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
134 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100135 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
136 cell_state_out, output_state_out, output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100137
138 // Set lstm parameters
139 LSTMParams<ITensorInfo> lstm_params_info{};
140 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
141
142 // Validate
143 ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
144 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
145 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100146 cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
147 lstm_params_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100148
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100149 const int batch_size = input->info()->dimension(1);
150 const int num_units = input_to_output_weights->info()->dimension(1);
151 const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100152
153 const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
154 const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
155 const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
156
157 _projection_bias = lstm_params.projection_bias();
158 _input_to_forget_weights = input_to_forget_weights;
159 _input_to_cell_weights = input_to_cell_weights;
160 _input_to_output_weights = input_to_output_weights;
161 _recurrent_to_forget_weights = recurrent_to_forget_weights;
162 _recurrent_to_cell_weights = recurrent_to_cell_weights;
163 _recurrent_to_output_weights = recurrent_to_output_weights;
164 _projection_weights = lstm_params.projection_weights();
165
Sheri Zhang3a353982020-04-21 13:10:24 +0100166 // Layer normalization
167 _has_layer_norm = lstm_params.use_layer_norm();
168 if(_has_layer_norm)
169 {
170 set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
171 set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
172 set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
173 set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
174
175 set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
176 set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
177 set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
178 set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
179 }
180
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100181 _has_cifg = lstm_params.has_cifg_opt();
182 _has_projection = lstm_params.has_projection();
183 _has_peephole = lstm_params.has_peephole_opt();
184
185 // Calculate and decompose effective scales for optimizing matmul calculation
186 const int32_t cell_shift = log2(qcell_state_in.scale);
187
188 // Calculate quantized parameters for clipping.
189 int16_t quantized_cell_clip = 0;
190 if(lstm_params.cell_clip() > 0.0f)
191 {
192 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
193 }
194 _has_cell_clipping = quantized_cell_clip > 0;
195
196 // Precompute effective bias for optimizing the matmul computations.
197 if(!_has_cifg)
198 {
199 _input_to_input_weights = lstm_params.input_to_input_weights();
200 _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
201
Manuel Bottini2b84be52020-04-08 10:15:51 +0100202 _input_to_input_reduction.configure(compile_context, _input_to_input_weights, &_input_to_input_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
203 _recurrent_to_input_reduction.configure(compile_context, _recurrent_to_input_weights, &_recurrent_to_input_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100204 }
Manuel Bottini2b84be52020-04-08 10:15:51 +0100205 _input_to_forget_reduction.configure(compile_context, input_to_forget_weights, &_input_to_forget_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
206 _recurrent_to_forget_reduction.configure(compile_context, recurrent_to_forget_weights, &_recurrent_to_forget_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
207 _input_to_cell_reduction.configure(compile_context, input_to_cell_weights, &_input_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
208 _recurrent_to_cell_reduction.configure(compile_context, recurrent_to_cell_weights, &_recurrent_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
209 _input_to_output_reduction.configure(compile_context, input_to_output_weights, &_input_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
210 _recurrent_to_output_reduction.configure(compile_context, recurrent_to_output_weights, &_recurrent_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100211 if(_has_projection)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100212 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100213 _projection_reduction.configure(compile_context, _projection_weights, &_projection_eff_bias, GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100214 if(_projection_bias != nullptr)
215 {
216 _projection_bias_add.configure(compile_context, ArithmeticOperation::ADD, _projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
217 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100218 }
219
220 // Pre-transpose weights to be used in GEMM.
Manuel Bottini2b84be52020-04-08 10:15:51 +0100221 _transpose_input_to_forget_weights.configure(compile_context, input_to_forget_weights, &_input_to_forget_weights_transposed);
222 _transpose_input_to_cell_weights.configure(compile_context, input_to_cell_weights, &_input_to_cell_weights_transposed);
223 _transpose_input_to_output_weights.configure(compile_context, input_to_output_weights, &_input_to_output_weights_transposed);
224 _transpose_recurrent_to_forget_weights.configure(compile_context, recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
225 _transpose_recurrent_to_cell_weights.configure(compile_context, recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
226 _transpose_recurrent_to_output_weights.configure(compile_context, recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100227 if(!_has_cifg)
228 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100229 _transpose_input_to_input_weights.configure(compile_context, lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
230 _transpose_recurrent_to_input_weights.configure(compile_context, lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100231 }
232 if(_has_projection)
233 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100234 _transpose_projection_weights.configure(compile_context, _projection_weights, &_projection_weights_transposed);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100235 }
236
237 GEMMLowpOutputStageInfo gemmlowp_info;
238 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
239 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
240 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
241 gemmlowp_info.output_data_type = DataType::QSYMM16;
242
243 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
244 // Forget gate.
245 const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
246 const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100247 configure_mm(compile_context, _mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100248 input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
249 &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
250 mm_out_info, forget_gate_outstage_info);
251
252 const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100253 configure_mm(compile_context, _mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100254 output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
255 &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
256 mm_out_info, forget_gate_outstage_info);
257
Manuel Bottini2b84be52020-04-08 10:15:51 +0100258 _accumulate_input_recurrent_forget.configure(compile_context, ArithmeticOperation::ADD, &_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
259 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100260 _input_to_forget_outstage_res.allocator()->allocate();
261
262 if(_has_peephole)
263 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100264 _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100265 _memory_group.manage(&_mul_cell_to_forget_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100266 _pixelwise_mul_cell_to_forget.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100267 _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
268 _memory_group.manage(&_cell_to_forget_outstage_res);
269 const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
270 quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100271 _cell_to_forget_outstage.configure(compile_context, &_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100272 _mul_cell_to_forget_res.allocator()->allocate();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100273 _accumulate_cell_forget.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
274 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100275 _cell_to_forget_outstage_res.allocator()->allocate();
276 }
277
Sheri Zhang3a353982020-04-21 13:10:24 +0100278 CLTensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
279
280 if(_has_layer_norm)
281 {
282 configure_layer_norm(LayerNormGate::Forget, &_recurrent_to_forget_outstage_res);
283 _recurrent_to_forget_outstage_res.allocator()->allocate();
284 forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
285 }
286
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100287 // Output quantization info of Sigmoid and Tanh activations
288 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
289
290 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
291 _memory_group.manage(&_forget_gate);
292 _forget_gate.allocator()->init(forget_gate_info);
Sheri Zhang3a353982020-04-21 13:10:24 +0100293 _forget_gate_sigmoid.configure(compile_context, forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
294 forget_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100295
296 // Modulation gate.
297 const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
298 const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100299 configure_mm(compile_context, _mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100300 input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
301 &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
302 mm_out_info, cell_outstage_info);
303
304 const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100305 configure_mm(compile_context, _mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100306 output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
307 &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
308 mm_out_info, cell_outstage_info);
309
Manuel Bottini2b84be52020-04-08 10:15:51 +0100310 _accumulate_input_recurrent_modulation.configure(compile_context, ArithmeticOperation::ADD, &_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
311 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100312 _input_to_cell_outstage_res.allocator()->allocate();
313
Sheri Zhang3a353982020-04-21 13:10:24 +0100314 CLTensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
315
316 if(_has_layer_norm)
317 {
318 configure_layer_norm(LayerNormGate::Cell, &_recurrent_to_cell_outstage_res);
319 _recurrent_to_cell_outstage_res.allocator()->allocate();
320 cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
321 }
322
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100323 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
324 _memory_group.manage(&_cell_gate);
325 _cell_gate.allocator()->init(cell_gate_info);
Sheri Zhang3a353982020-04-21 13:10:24 +0100326 _cell_gate_tanh.configure(compile_context, cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
327 cell_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100328
329 // Input gate.
330 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
331 _input_gate.allocator()->init(input_gate_info);
332 _memory_group.manage(&_input_gate);
333 if(_has_cifg)
334 {
335 _ones.allocator()->init(*_forget_gate.info());
Manuel Bottini2b84be52020-04-08 10:15:51 +0100336 _input_gate_sub.configure(compile_context, ArithmeticOperation::SUB, &_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100337 _ones.allocator()->allocate();
338 }
339 else
340 {
341 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
342 const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100343 configure_mm(compile_context, _mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100344 input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
345 &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
346 mm_out_info, input_outstage_info);
347
348 const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100349 configure_mm(compile_context, _mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100350 output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100351 &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
352 mm_out_info, input_outstage_info);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100353 _accumulate_input_recurrent_input.configure(compile_context, ArithmeticOperation::ADD, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res,
354 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100355 _input_to_input_outstage_res.allocator()->allocate();
356
357 if(_has_peephole)
358 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100359 _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100360 _memory_group.manage(&_mul_cell_to_input_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100361 _pixelwise_mul_cell_to_input.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100362 const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
363 quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
364 _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
365 _memory_group.manage(&_cell_to_input_outstage_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100366 _cell_to_input_outstage.configure(compile_context, &_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100367 _mul_cell_to_input_res.allocator()->allocate();
368 _accumulate_cell_input.configure(ArithmeticOperation::ADD, &_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
369 _cell_to_input_outstage_res.allocator()->allocate();
370 }
371
Sheri Zhang3a353982020-04-21 13:10:24 +0100372 CLTensor *input_activation_input = &_recurrent_to_input_outstage_res;
373
374 if(_has_layer_norm)
375 {
376 configure_layer_norm(LayerNormGate::Input, &_recurrent_to_input_outstage_res);
377 _recurrent_to_input_outstage_res.allocator()->allocate();
378 input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
379 }
380
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100381 _input_gate_sigmoid.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sheri Zhang3a353982020-04-21 13:10:24 +0100382 input_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100383 }
384 // Cell.
385 // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100386 _pixelwise_mul_forget_cell.configure(compile_context, &_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100387 const float cell_gate_scale = _cell_gate.info()->quantization_info().uniform().scale;
388 const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
389 const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
390 _memory_group.manage(&_mul_input_cell_res);
391 _mul_input_cell_res.allocator()->init(mul_input_cell_info);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100392 _pixelwise_mul_input_cell.configure(compile_context, &_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100393 _cell_gate.allocator()->allocate();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100394 _add_forget_cell.configure(compile_context, ArithmeticOperation::ADD, &_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100395 _mul_input_cell_res.allocator()->allocate();
396 _forget_gate.allocator()->allocate();
397 if(_has_cell_clipping)
398 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100399 _cell_clip.configure(compile_context, cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100400 }
401 // Output gate.
402 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
403 const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100404 configure_mm(compile_context, _mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100405 input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
406 &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
407 mm_out_info, output_outstage_info);
408
409 const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100410 configure_mm(compile_context, _mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100411 output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
412 &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
413 mm_out_info, output_outstage_info);
414
Manuel Bottini2b84be52020-04-08 10:15:51 +0100415 _accumulate_input_recurrent_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
416 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100417 _input_to_output_outstage_res.allocator()->allocate();
418
419 if(_has_peephole)
420 {
421 // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
422 // Here we are not using the output stage because all operations are done in float
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100423 _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100424 _memory_group.manage(&_mul_cell_to_output_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100425 _pixelwise_mul_cell_to_output.configure(compile_context, cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100426
427 const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
428 quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
429 _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
430 _memory_group.manage(&_cell_to_output_outstage_res);
431 _cell_to_output_outstage.configure(compile_context, &_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100432 _mul_cell_to_output_res.allocator()->allocate();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100433
434 _accumulate_cell_to_output.configure(compile_context, ArithmeticOperation::ADD, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
435 ConvertPolicy::SATURATE);
436 _cell_to_output_outstage_res.allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100437 }
438
Sheri Zhang3a353982020-04-21 13:10:24 +0100439 CLTensor *output_activation_input = &_recurrent_to_output_outstage_res;
440
441 if(_has_layer_norm)
442 {
443 configure_layer_norm(LayerNormGate::Output, &_recurrent_to_output_outstage_res);
444 _recurrent_to_output_outstage_res.allocator()->allocate();
445 output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
446 }
447
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100448 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
449 _memory_group.manage(&_output_gate);
450 _output_gate.allocator()->init(output_gate_info);
Sheri Zhang3a353982020-04-21 13:10:24 +0100451 _output_gate_sigmoid.configure(compile_context, output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
452 output_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100453
454 // Hidden.
Manuel Bottini2b84be52020-04-08 10:15:51 +0100455 _hidden_tanh.configure(compile_context, cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100456 // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
457 _memory_group.manage(&_hidden_mul_res);
458 const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
459 _hidden_mul_res.allocator()->init(hidden_mul_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100460 _pixelwise_mul_hidden.configure(compile_context, &_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100461 _output_gate.allocator()->allocate();
462 _input_gate.allocator()->allocate();
463 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
464 quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
465 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
466 gemmlowp_info.output_data_type = output_state_in->info()->data_type();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100467
468 _projection_tensor_copy_required = (num_units != output_size);
469 ICLTensor *hidden_gate_result = output_state_out;
470
471 _memory_group.manage(&_hidden_gate);
472
473 if(_projection_tensor_copy_required)
474 {
475 _hidden_gate.allocator()->init(*output_state_out->info());
476 _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
477 hidden_gate_result = &_hidden_gate;
478 }
479
480 _hidden_outstage.configure(compile_context, &_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100481 _hidden_mul_res.allocator()->allocate();
482
483 // Projection.
484 if(_has_projection)
485 {
486 const TensorInfo projection_outstage_info(*output_state_out->info());
487 const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
488 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
489 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
490 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
491 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
492 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
493
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100494 TensorInfo projection_mm_out_info{ mm_out_info };
495 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100496
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100497 configure_mm(compile_context, _mm_projection, _projection_outstage, gemmlowp_info,
498 hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
499 &_mm_projection_res, &_projection_outstage_res, projection_scale,
500 projection_mm_out_info, projection_outstage_info);
501
502 ICLTensor *accumulate_destination = output_state_out;
503
504 if(_projection_tensor_copy_required)
505 {
506 _hidden_gate.allocator()->allocate();
507 _projection_accumulate_res.allocator()->init(*output_state_out->info());
508 _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
509 _projection_output_to_accumulate_copy.configure(*output_state_out, _projection_accumulate_res);
510 accumulate_destination = &_projection_accumulate_res;
511 }
512
513 _accumulate_projection.configure(compile_context, ArithmeticOperation::ADD, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100514 _projection_outstage_res.allocator()->allocate();
515
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100516 if(_projection_tensor_copy_required)
517 {
518 _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
519 _projection_accumulate_res.allocator()->allocate();
520 }
521
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100522 int8_t quantized_projection_clip{ 0 };
523 if(lstm_params.projection_clip() > 0.0f)
524 {
525 quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
526 }
527
528 if(quantized_projection_clip > 0)
529 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100530 _projection_clip.configure(compile_context, output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
531 quantized_projection_clip));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100532 _has_projection_clipping = true;
533 }
534 }
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100535 else
536 {
537 if(_projection_tensor_copy_required)
538 {
539 _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
540 _hidden_gate.allocator()->allocate();
541 }
542 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100543
544 // Copy output_state_out to output
545 _copy_output.configure(compile_context, output_state_out, output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100546}
547
548Status CLQLSTMLayer::validate(const ITensorInfo *input,
549 const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
550 const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
551 const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
552 const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100553 const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100554 const LSTMParams<ITensorInfo> &lstm_params)
555{
556 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100557 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
558 cell_state_out, output_state_out, output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100559
560 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
561 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
562
563 const unsigned int input_size = input->dimension(0);
564 const unsigned int batch_size = input->dimension(1);
565 const unsigned int num_units = input_to_output_weights->dimension(1);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100566 const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100567
568 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
569 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
570 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights, input_to_cell_weights);
571 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
572 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
573 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
574 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QSYMM8);
575 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
576 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
577
578 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
579 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
580 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
581 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
582 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);
583
584 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
585 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
586 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
587 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);
588
589 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
590 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
591 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
592 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);
593
594 // Check whether peephole weights are all there or none
595 if(lstm_params.has_peephole_opt())
596 {
597 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
598 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
599 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
600 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
601 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
602 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
603
604 if(!lstm_params.has_cifg_opt())
605 {
606 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
607 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
608 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
609 }
610 }
611
612 const UniformQuantizationInfo qinput = input->quantization_info().uniform();
613 const UniformQuantizationInfo qcell_state_in = cell_state_in->quantization_info().uniform();
614 const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
615
616 // Calculate and decompose effective scales for optimizing matmul calculation
617 const int32_t cell_shift = log2(qcell_state_in.scale);
618 ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
619
620 // Calculate quantized parameters for clipping.
621 int16_t quantized_cell_clip = 0;
622 if(lstm_params.cell_clip() > 0.0f)
623 {
624 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
625 }
626
627 // Precompute effective bias for optimizing the matmul computations.
628 const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100629 const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100630 if(!lstm_params.has_cifg_opt())
631 {
632 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
633 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.recurrent_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
634 true)));
635 }
636 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
637 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
638 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
639 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
640 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
641 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100642 if(lstm_params.has_projection())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100643 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100644 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
645 lstm_params.hidden_state_zero(),
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100646 true)));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100647 if(lstm_params.projection_bias() != nullptr)
648 {
649 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
650 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, lstm_params.projection_bias(), &projection_eff_bias_info,
651 &projection_eff_bias_info, ConvertPolicy::SATURATE));
652 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100653 }
654
655 const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info());
656 const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
657
658 // Validate weights transpose
659 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_forget_weights, &input_weights_transposed));
660 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_cell_weights, &input_weights_transposed));
661 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_output_weights, &input_weights_transposed));
662 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_forget_weights, &recurrent_weights_transposed));
663 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_cell_weights, &recurrent_weights_transposed));
664 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_output_weights, &recurrent_weights_transposed));
665 if(!lstm_params.has_cifg_opt())
666 {
667 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed));
668 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed));
669 }
670 if(lstm_params.has_projection())
671 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100672 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
673 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100674 }
675
676 GEMMLowpOutputStageInfo gemmlowp_info;
677 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
678 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
679 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
680 gemmlowp_info.output_data_type = DataType::QSYMM16;
681
Sheri Zhang3a353982020-04-21 13:10:24 +0100682 const bool has_layer_norm = lstm_params.use_layer_norm();
683
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100684 // Forget gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100685 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_intermediate_scale() == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100686 const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
687 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
688 const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100689 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100690
691 const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100692 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100693
694 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
695
696 if(lstm_params.has_peephole_opt())
697 {
698 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
699 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
700 RoundingPolicy::TO_ZERO));
701 const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
702 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
703 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
704 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
705 }
706
Sheri Zhang3a353982020-04-21 13:10:24 +0100707 if(has_layer_norm)
708 {
709 const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
710 const ITensorInfo *b_info = forget_gate_bias;
711 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
712 }
713
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100714 // Output quantization info of Sigmoid and Tanh activations
715 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
716
717 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
718 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
719
720 // Modulation gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100721 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_intermediate_scale() == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100722 const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
723 const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100724 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100725
726 const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100727 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100728
729 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
730
Sheri Zhang3a353982020-04-21 13:10:24 +0100731 if(has_layer_norm)
732 {
733 const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
734 const ITensorInfo *b_info = cell_bias;
735 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
736 }
737
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100738 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
739 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
740
741 // Input gate.
742 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
743 if(lstm_params.has_cifg_opt())
744 {
745 ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
746 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::SUB, &input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
747 }
748 else
749 {
750 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
751 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
752 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
753 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights, lstm_params.recurrent_to_input_weights());
754 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
755 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
756
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100757 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_intermediate_scale() == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100758 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
759 const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100760 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100761
762 const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100763 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100764
765 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
766
767 if(lstm_params.has_peephole_opt())
768 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100769 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100770 RoundingPolicy::TO_ZERO));
771 const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
772 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100773 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100774 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
775 }
776
Sheri Zhang3a353982020-04-21 13:10:24 +0100777 if(has_layer_norm)
778 {
779 const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
780 const ITensorInfo *b_info = lstm_params.input_gate_bias();
781 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
782 }
783
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100784 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC, 1.f, 1.f)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100785 }
786 // Cell.
787 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
788 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
789 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
790 if(quantized_cell_clip > 0)
791 {
792 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
793 quantized_cell_clip)));
794 }
795 // Output gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100796 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_intermediate_scale() == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100797 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
798 const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100799 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100800
801 const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100802 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100803
804 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
805 if(lstm_params.has_peephole_opt())
806 {
807 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
808 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
809 // Here we are not using the output stage because all operations are done in float
810 // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
811 // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
812 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
813 RoundingPolicy::TO_ZERO));
814 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
815 }
816
Sheri Zhang3a353982020-04-21 13:10:24 +0100817 if(has_layer_norm)
818 {
819 const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
820 const ITensorInfo *b_info = output_gate_bias;
821 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
822 }
823
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100824 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
825 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
826
827 // Hidden.
828 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
829 const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100830 const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
831
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100832 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.hidden_state_scale() == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100833 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
834 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
835 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
836 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100837 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
838
839 const bool projection_tensor_copy_required = num_units != output_size;
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100840
841 // Projection.
842 if(lstm_params.has_projection())
843 {
844 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights, lstm_params.projection_weights());
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100845 ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100846
847 const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
848 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
849 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
850 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
851 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
852 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
853 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
854
855 const TensorInfo projection_outstage_info(*output_state_out);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100856 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
857
858 TensorInfo projection_mm_out_info{ mm_out_info };
859 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
860
861 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
862 &projection_outstage_info));
863
864 if(projection_tensor_copy_required)
865 {
866 ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info));
867 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100868
869 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
870
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100871 if(projection_tensor_copy_required)
872 {
873 ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
874 }
875
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100876 int8_t quantized_projection_clip{ 0 };
877 if(lstm_params.projection_clip() > 0.0f)
878 {
879 quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
880 }
881
882 if(quantized_projection_clip > 0)
883 {
884 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
885 quantized_projection_clip)));
886 }
887 }
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100888 else
889 {
890 if(projection_tensor_copy_required)
891 {
892 ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
893 }
894 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100895
896 if(cell_state_out->total_size() > 0)
897 {
898 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
899 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
900 }
901
902 if(output_state_out->total_size() > 0)
903 {
904 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
905 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
906 }
907
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100908 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100909 return Status{};
910}
911
912void CLQLSTMLayer::run()
913{
914 prepare();
915
916 // Acquire all the temporaries
917 MemoryGroupResourceScope scope_mg(_memory_group);
918
919 // Forget gate.
920 _mm_input_to_forget.run();
921 _input_to_forget_outstage.run();
922
923 _mm_recurrent_to_forget.run();
924 _recurrent_to_forget_outstage.run();
925 CLScheduler::get().enqueue(_accumulate_input_recurrent_forget);
926
927 if(_has_peephole)
928 {
929 CLScheduler::get().enqueue(_pixelwise_mul_cell_to_forget);
930 _cell_to_forget_outstage.run();
931 CLScheduler::get().enqueue(_accumulate_cell_forget);
932 }
933
Sheri Zhang3a353982020-04-21 13:10:24 +0100934 if(_has_layer_norm)
935 {
936 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Forget));
937 }
938
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100939 _forget_gate_sigmoid.run();
940
941 // Modulation gate.
942 _mm_input_to_cell.run();
943 _input_to_cell_outstage.run();
944
945 _mm_recurrent_to_cell.run();
946 _recurrent_to_cell_outstage.run();
947 CLScheduler::get().enqueue(_accumulate_input_recurrent_modulation);
948
Sheri Zhang3a353982020-04-21 13:10:24 +0100949 if(_has_layer_norm)
950 {
951 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Cell));
952 }
953
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100954 _cell_gate_tanh.run();
955
956 // Input gate
957 if(_has_cifg)
958 {
959 CLScheduler::get().enqueue(_input_gate_sub);
960 }
961 else
962 {
963 _mm_input_to_input.run();
964 _input_to_input_outstage.run();
965 _mm_recurrent_to_input.run();
966 _recurrent_to_input_outstage.run();
967 CLScheduler::get().enqueue(_accumulate_input_recurrent_input);
968
969 if(_has_peephole)
970 {
971 CLScheduler::get().enqueue(_pixelwise_mul_cell_to_input);
972 _cell_to_input_outstage.run();
973 CLScheduler::get().enqueue(_accumulate_cell_input);
974 }
975
Sheri Zhang3a353982020-04-21 13:10:24 +0100976 if(_has_layer_norm)
977 {
978 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input));
979 }
980
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100981 _input_gate_sigmoid.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100982 }
983
984 // Cell.
985 CLScheduler::get().enqueue(_pixelwise_mul_forget_cell);
986 CLScheduler::get().enqueue(_pixelwise_mul_input_cell);
987 CLScheduler::get().enqueue(_add_forget_cell);
988 if(_has_cell_clipping)
989 {
990 _cell_clip.run();
991 }
992
993 // Output gate.
994 _mm_input_to_output.run();
995 _input_to_output_outstage.run();
996 _mm_recurrent_to_output.run();
997 _recurrent_to_output_outstage.run();
998 CLScheduler::get().enqueue(_accumulate_input_recurrent_output);
999 if(_has_peephole)
1000 {
1001 CLScheduler::get().enqueue(_pixelwise_mul_cell_to_output);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001002 _cell_to_output_outstage.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001003 CLScheduler::get().enqueue(_accumulate_cell_to_output);
1004 }
1005
Sheri Zhang3a353982020-04-21 13:10:24 +01001006 if(_has_layer_norm)
1007 {
1008 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Output));
1009 }
1010
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001011 _output_gate_sigmoid.run();
1012
1013 // Hidden.
1014 _hidden_tanh.run();
1015 CLScheduler::get().enqueue(_pixelwise_mul_hidden);
1016 _hidden_outstage.run();
1017
1018 // Projection.
1019 if(_has_projection)
1020 {
1021 _mm_projection.run();
1022 _projection_outstage.run();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001023
1024 if(_projection_tensor_copy_required)
1025 {
1026 _projection_output_to_accumulate_copy.run();
1027 }
1028
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001029 CLScheduler::get().enqueue(_accumulate_projection);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001030
1031 if(_projection_tensor_copy_required)
1032 {
1033 _projection_accumulate_to_output_copy.run();
1034 }
1035
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001036 if(_has_projection_clipping)
1037 {
1038 _projection_clip.run();
1039 }
1040 }
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001041 else
1042 {
1043 if(_projection_tensor_copy_required)
1044 {
1045 _hidden_to_output_copy.run();
1046 }
1047 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +01001048
1049 // Copy output_state_out to output
1050 CLScheduler::get().enqueue(_copy_output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001051}
1052
1053void CLQLSTMLayer::prepare()
1054{
1055 if(!_is_prepared)
1056 {
1057 // Pre-transpose weights to be used in GEMM.
1058 _input_to_forget_weights_transposed.allocator()->allocate();
1059 _input_to_cell_weights_transposed.allocator()->allocate();
1060 _input_to_output_weights_transposed.allocator()->allocate();
1061 _recurrent_to_forget_weights_transposed.allocator()->allocate();
1062 _recurrent_to_cell_weights_transposed.allocator()->allocate();
1063 _recurrent_to_output_weights_transposed.allocator()->allocate();
1064 _transpose_input_to_forget_weights.run();
1065 _transpose_input_to_cell_weights.run();
1066 _transpose_input_to_output_weights.run();
1067 _transpose_recurrent_to_forget_weights.run();
1068 _transpose_recurrent_to_cell_weights.run();
1069 _transpose_recurrent_to_output_weights.run();
1070
1071 // Precompute effective biases
1072 if(_has_cifg)
1073 {
1074 _ones.map(true);
1075 std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
1076 _ones.unmap();
1077 }
1078 else
1079 {
1080 _input_to_input_eff_bias.allocator()->allocate();
1081 _recurrent_to_input_eff_bias.allocator()->allocate();
1082 CLScheduler::get().enqueue(_input_to_input_reduction);
1083 CLScheduler::get().enqueue(_recurrent_to_input_reduction);
1084
1085 _input_to_input_weights_transposed.allocator()->allocate();
1086 _recurrent_to_input_weights_transposed.allocator()->allocate();
1087 _transpose_input_to_input_weights.run();
1088 _transpose_recurrent_to_input_weights.run();
1089 _input_to_input_weights->mark_as_unused();
1090 _recurrent_to_input_weights->mark_as_unused();
1091 }
1092 _input_to_forget_eff_bias.allocator()->allocate();
1093 _recurrent_to_forget_eff_bias.allocator()->allocate();
1094 _input_to_cell_eff_bias.allocator()->allocate();
1095 _recurrent_to_cell_eff_bias.allocator()->allocate();
1096 _input_to_output_eff_bias.allocator()->allocate();
1097 _recurrent_to_output_eff_bias.allocator()->allocate();
1098 CLScheduler::get().enqueue(_input_to_forget_reduction);
1099 CLScheduler::get().enqueue(_recurrent_to_forget_reduction);
1100 CLScheduler::get().enqueue(_input_to_cell_reduction);
1101 CLScheduler::get().enqueue(_recurrent_to_cell_reduction);
1102 CLScheduler::get().enqueue(_input_to_output_reduction);
1103 CLScheduler::get().enqueue(_recurrent_to_output_reduction);
1104
1105 if(_has_projection)
1106 {
Michele Di Giorgio11c562c2020-06-10 16:34:50 +01001107 _projection_eff_bias.allocator()->allocate();
1108 CLScheduler::get().enqueue(_projection_reduction);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001109 if(_projection_bias != nullptr)
1110 {
Michele Di Giorgio11c562c2020-06-10 16:34:50 +01001111 CLScheduler::get().enqueue(_projection_bias_add);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001112 _projection_bias->mark_as_unused();
1113 }
1114
1115 _projection_weights_transposed.allocator()->allocate();
1116 _transpose_projection_weights.run();
1117 _projection_weights->mark_as_unused();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001118
1119 if(!_projection_tensor_copy_required)
1120 {
1121 _hidden_gate.mark_as_unused();
1122 _projection_accumulate_res.mark_as_unused();
1123 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001124 }
1125
1126 // Mark weights as unused
1127 _input_to_forget_weights->mark_as_unused();
1128 _input_to_cell_weights->mark_as_unused();
1129 _input_to_output_weights->mark_as_unused();
1130 _recurrent_to_forget_weights->mark_as_unused();
1131 _recurrent_to_cell_weights->mark_as_unused();
1132 _recurrent_to_output_weights->mark_as_unused();
1133
1134 CLScheduler::get().queue().finish();
1135 _is_prepared = true;
1136 }
1137}
1138
1139} // namespace arm_compute