blob: 6ddf555b5c584056231e46761cc4ba5dabb2e55f [file] [log] [blame]
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001/*
Sheri Zhang7e20e292021-02-02 11:49:34 +00002 * Copyright (c) 2020-2021 Arm Limited.
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLQLSTMLayer.h"
25
26#include "arm_compute/core/KernelDescriptors.h"
27#include "arm_compute/core/QuantizationInfo.h"
28#include "arm_compute/core/Utils.h"
29#include "arm_compute/core/Validate.h"
30#include "arm_compute/core/utils/misc/InfoHelpers.h"
31#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
32#include "arm_compute/runtime/CL/CLScheduler.h"
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010033#include "src/core/CL/kernels/CLFillBorderKernel.h"
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010034#include "src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010035#include "src/core/helpers/WindowHelpers.h"
Georgios Pinitas7891a732021-08-20 21:39:25 +010036#include "src/gpu/cl/kernels/ClGemmLowpReductionKernel.h"
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010037
38namespace arm_compute
39{
40using namespace arm_compute::utils::info_helpers;
Georgios Pinitas4a578b92021-06-25 12:13:49 +010041using namespace arm_compute::opencl::kernels;
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010042namespace
43{
44Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
45 float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
46{
47 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
48 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
49 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
50 return Status{};
51}
52} // namespace
53
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +010054Status CLQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
55{
56 ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
57 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
58 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
59 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
60 return Status{};
61}
62
63void CLQLSTMLayer::TensorCopyKernel::configure(ICLTensor &src, ICLTensor &dst)
64{
65 ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
66 _src = &src;
67 _dst = &dst;
68 _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
69 _window = calculate_max_window(*_src->info(), Steps());
70}
71
72void CLQLSTMLayer::TensorCopyKernel::run()
73{
74 auto &q = CLScheduler::get().queue();
75
76 _src->map(q, true);
77 _dst->map(q, true);
78
79 Iterator input_iter{ _src, _window };
80 Iterator output_iter{ _dst, _window };
81
82 execute_window_loop(_window, [&](const Coordinates &)
83 {
84 memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
85 },
86 input_iter, output_iter);
87
88 _src->unmap(q);
89 _dst->unmap(q);
90}
91
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010092CLQLSTMLayer::CLQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Georgios Pinitas4a578b92021-06-25 12:13:49 +010093 : _input_to_input_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
94 _recurrent_to_input_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
95 _input_to_forget_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
96 _recurrent_to_forget_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
97 _input_to_cell_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
98 _recurrent_to_cell_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
99 _input_to_output_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
100 _recurrent_to_output_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
101 _projection_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100102 _layer_norms(),
Sheri Zhang7e20e292021-02-02 11:49:34 +0000103 _copy_output()
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100104{
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100105 for(auto &norm : _layer_norms)
106 {
Georgios Pinitas40f51a62020-11-21 03:04:18 +0000107 norm = std::make_unique<CLQLSTMLayerNormalizationKernel>();
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100108 }
109
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100110 _memory_group = MemoryGroup(std::move(memory_manager));
111}
112
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100113CLQLSTMLayer::~CLQLSTMLayer() = default;
114
115void CLQLSTMLayer::configure_layer_norm(LayerNormGate g, const ICLTensor *in)
116{
117 ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
118
119 CLTensor *out = &get_layer_norm_output(g);
120 _memory_group.manage(out);
121 out->allocator()->init(*(in->info()));
122
123 get_layer_norm(g).configure(in, out, get_layer_norm_weight(g), get_layer_norm_bias(g));
124}
125
126Status CLQLSTMLayer::validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
127{
128 // Output quantization scale will be different, but ignored here
129 // since it will be configured at configure() stage.
130 const TensorInfo out
131 {
132 in
133 };
134 return CLQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
135}
136
Manuel Bottini2b84be52020-04-08 10:15:51 +0100137void CLQLSTMLayer::configure_mm(const CLCompileContext &compile_context, CLGEMMLowpMatrixMultiplyCore &mm, CLGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100138 const ICLTensor *mm_input, const ICLTensor *mm_weights, const ICLTensor *bias,
139 CLTensor *mm_res, CLTensor *outstage_res, float gemmlowp_scale,
140 const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
141{
142 _memory_group.manage(mm_res);
143 _memory_group.manage(outstage_res);
144
145 mm_res->allocator()->init(mm_res_info);
146 outstage_res->allocator()->init(outstage_tensor_info);
147
148 // Configure matrix-multiplication
Manuel Bottini2b84be52020-04-08 10:15:51 +0100149 mm.configure(compile_context, mm_input, mm_weights, nullptr, mm_res);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100150
151 // Configure output stage
152 quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100153 outstage.configure(compile_context, mm_res, bias, outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100154 mm_res->allocator()->allocate();
155}
156
157void CLQLSTMLayer::configure(const ICLTensor *input,
158 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
159 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
160 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100161 ICLTensor *cell_state_in, ICLTensor *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100162 ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100163 const LSTMParams<ICLTensor> &lstm_params)
164{
Manuel Bottini2b84be52020-04-08 10:15:51 +0100165 configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
166 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
Michalis Spyroue6bd70c2020-05-21 15:10:25 +0100167 cell_state_in, output_state_in, cell_state_out, output_state_out, output, lstm_params);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100168}
169
170void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
171 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
172 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
173 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100174 ICLTensor *cell_state_in, ICLTensor *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100175 ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100176 const LSTMParams<ICLTensor> &lstm_params)
177{
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100178 ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
179 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100180 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
181 cell_state_out, output_state_out, output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100182
183 // Set lstm parameters
184 LSTMParams<ITensorInfo> lstm_params_info{};
185 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
186
187 // Validate
188 ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
189 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
190 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100191 cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
192 lstm_params_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100193
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100194 const int batch_size = input->info()->dimension(1);
195 const int num_units = input_to_output_weights->info()->dimension(1);
196 const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100197
198 const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
199 const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
200 const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
201
202 _projection_bias = lstm_params.projection_bias();
203 _input_to_forget_weights = input_to_forget_weights;
204 _input_to_cell_weights = input_to_cell_weights;
205 _input_to_output_weights = input_to_output_weights;
206 _recurrent_to_forget_weights = recurrent_to_forget_weights;
207 _recurrent_to_cell_weights = recurrent_to_cell_weights;
208 _recurrent_to_output_weights = recurrent_to_output_weights;
209 _projection_weights = lstm_params.projection_weights();
210
Sheri Zhang3a353982020-04-21 13:10:24 +0100211 // Layer normalization
212 _has_layer_norm = lstm_params.use_layer_norm();
213 if(_has_layer_norm)
214 {
215 set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
216 set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
217 set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
218 set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
219
220 set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
221 set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
222 set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
223 set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
224 }
225
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100226 _has_cifg = lstm_params.has_cifg_opt();
227 _has_projection = lstm_params.has_projection();
228 _has_peephole = lstm_params.has_peephole_opt();
229
230 // Calculate and decompose effective scales for optimizing matmul calculation
231 const int32_t cell_shift = log2(qcell_state_in.scale);
232
233 // Calculate quantized parameters for clipping.
234 int16_t quantized_cell_clip = 0;
235 if(lstm_params.cell_clip() > 0.0f)
236 {
237 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
238 }
239 _has_cell_clipping = quantized_cell_clip > 0;
240
241 // Precompute effective bias for optimizing the matmul computations.
242 if(!_has_cifg)
243 {
244 _input_to_input_weights = lstm_params.input_to_input_weights();
245 _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
246
Georgios Pinitas4a578b92021-06-25 12:13:49 +0100247 _input_to_input_reduction->configure(compile_context, _input_to_input_weights->info(), _input_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
248 _recurrent_to_input_reduction->configure(compile_context, _recurrent_to_input_weights->info(), _recurrent_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false,
249 -qoutput_state_in.offset, true));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100250 }
Georgios Pinitas4a578b92021-06-25 12:13:49 +0100251 _input_to_forget_reduction->configure(compile_context, input_to_forget_weights->info(), _input_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
252 _recurrent_to_forget_reduction->configure(compile_context, recurrent_to_forget_weights->info(), _recurrent_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false,
253 -qoutput_state_in.offset, true));
254 _input_to_cell_reduction->configure(compile_context, input_to_cell_weights->info(), _input_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
255 _recurrent_to_cell_reduction->configure(compile_context, recurrent_to_cell_weights->info(), _recurrent_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
256 true));
257 _input_to_output_reduction->configure(compile_context, input_to_output_weights->info(), _input_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
258 _recurrent_to_output_reduction->configure(compile_context, recurrent_to_output_weights->info(), _recurrent_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false,
259 -qoutput_state_in.offset, true));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100260 if(_has_projection)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100261 {
Georgios Pinitas4a578b92021-06-25 12:13:49 +0100262 _projection_reduction->configure(compile_context, _projection_weights->info(), _projection_eff_bias.info(), GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100263 if(_projection_bias != nullptr)
264 {
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100265 _projection_bias_add.configure(compile_context, _projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100266 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100267 }
268
269 // Pre-transpose weights to be used in GEMM.
Manuel Bottini2b84be52020-04-08 10:15:51 +0100270 _transpose_input_to_forget_weights.configure(compile_context, input_to_forget_weights, &_input_to_forget_weights_transposed);
271 _transpose_input_to_cell_weights.configure(compile_context, input_to_cell_weights, &_input_to_cell_weights_transposed);
272 _transpose_input_to_output_weights.configure(compile_context, input_to_output_weights, &_input_to_output_weights_transposed);
273 _transpose_recurrent_to_forget_weights.configure(compile_context, recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
274 _transpose_recurrent_to_cell_weights.configure(compile_context, recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
275 _transpose_recurrent_to_output_weights.configure(compile_context, recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100276 if(!_has_cifg)
277 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100278 _transpose_input_to_input_weights.configure(compile_context, lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
279 _transpose_recurrent_to_input_weights.configure(compile_context, lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100280 }
281 if(_has_projection)
282 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100283 _transpose_projection_weights.configure(compile_context, _projection_weights, &_projection_weights_transposed);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100284 }
285
286 GEMMLowpOutputStageInfo gemmlowp_info;
287 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
288 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
289 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
290 gemmlowp_info.output_data_type = DataType::QSYMM16;
291
292 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
293 // Forget gate.
294 const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
295 const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100296 configure_mm(compile_context, _mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100297 input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
298 &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
299 mm_out_info, forget_gate_outstage_info);
300
301 const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100302 configure_mm(compile_context, _mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100303 output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
304 &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
305 mm_out_info, forget_gate_outstage_info);
306
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100307 _accumulate_input_recurrent_forget.configure(compile_context, &_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100308 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100309 _input_to_forget_outstage_res.allocator()->allocate();
310
311 if(_has_peephole)
312 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100313 _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100314 _memory_group.manage(&_mul_cell_to_forget_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100315 _pixelwise_mul_cell_to_forget.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100316 _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
317 _memory_group.manage(&_cell_to_forget_outstage_res);
318 const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
319 quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100320 _cell_to_forget_outstage.configure(compile_context, &_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100321 _mul_cell_to_forget_res.allocator()->allocate();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100322 _accumulate_cell_forget.configure(compile_context, &_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100323 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100324 _cell_to_forget_outstage_res.allocator()->allocate();
325 }
326
Sheri Zhang3a353982020-04-21 13:10:24 +0100327 CLTensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
328
329 if(_has_layer_norm)
330 {
331 configure_layer_norm(LayerNormGate::Forget, &_recurrent_to_forget_outstage_res);
332 _recurrent_to_forget_outstage_res.allocator()->allocate();
333 forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
334 }
335
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100336 // Output quantization info of Sigmoid and Tanh activations
337 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
338
339 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
340 _memory_group.manage(&_forget_gate);
341 _forget_gate.allocator()->init(forget_gate_info);
Sheri Zhang3a353982020-04-21 13:10:24 +0100342 _forget_gate_sigmoid.configure(compile_context, forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
343 forget_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100344
345 // Modulation gate.
346 const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
347 const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100348 configure_mm(compile_context, _mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100349 input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
350 &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
351 mm_out_info, cell_outstage_info);
352
353 const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100354 configure_mm(compile_context, _mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100355 output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
356 &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
357 mm_out_info, cell_outstage_info);
358
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100359 _accumulate_input_recurrent_modulation.configure(compile_context, &_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100360 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100361 _input_to_cell_outstage_res.allocator()->allocate();
362
Sheri Zhang3a353982020-04-21 13:10:24 +0100363 CLTensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
364
365 if(_has_layer_norm)
366 {
367 configure_layer_norm(LayerNormGate::Cell, &_recurrent_to_cell_outstage_res);
368 _recurrent_to_cell_outstage_res.allocator()->allocate();
369 cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
370 }
371
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100372 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
373 _memory_group.manage(&_cell_gate);
374 _cell_gate.allocator()->init(cell_gate_info);
Sheri Zhang3a353982020-04-21 13:10:24 +0100375 _cell_gate_tanh.configure(compile_context, cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
376 cell_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100377
378 // Input gate.
379 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
380 _input_gate.allocator()->init(input_gate_info);
381 _memory_group.manage(&_input_gate);
382 if(_has_cifg)
383 {
384 _ones.allocator()->init(*_forget_gate.info());
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100385 _input_gate_sub.configure(compile_context, &_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100386 _ones.allocator()->allocate();
387 }
388 else
389 {
390 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
391 const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100392 configure_mm(compile_context, _mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100393 input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
394 &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
395 mm_out_info, input_outstage_info);
396
397 const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100398 configure_mm(compile_context, _mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100399 output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100400 &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
401 mm_out_info, input_outstage_info);
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100402 _accumulate_input_recurrent_input.configure(compile_context, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100403 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100404 _input_to_input_outstage_res.allocator()->allocate();
405
406 if(_has_peephole)
407 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100408 _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100409 _memory_group.manage(&_mul_cell_to_input_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100410 _pixelwise_mul_cell_to_input.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100411 const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
412 quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
413 _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
414 _memory_group.manage(&_cell_to_input_outstage_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100415 _cell_to_input_outstage.configure(compile_context, &_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100416 _mul_cell_to_input_res.allocator()->allocate();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100417 _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100418 _cell_to_input_outstage_res.allocator()->allocate();
419 }
420
Sheri Zhang3a353982020-04-21 13:10:24 +0100421 CLTensor *input_activation_input = &_recurrent_to_input_outstage_res;
422
423 if(_has_layer_norm)
424 {
425 configure_layer_norm(LayerNormGate::Input, &_recurrent_to_input_outstage_res);
426 _recurrent_to_input_outstage_res.allocator()->allocate();
427 input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
428 }
429
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100430 _input_gate_sigmoid.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sheri Zhang3a353982020-04-21 13:10:24 +0100431 input_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100432 }
433 // Cell.
Michalis Spyrou1009e872020-07-27 12:48:34 +0100434 // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
Manuel Bottini2b84be52020-04-08 10:15:51 +0100435 _pixelwise_mul_forget_cell.configure(compile_context, &_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100436 const float cell_gate_scale = _cell_gate.info()->quantization_info().uniform().scale;
437 const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
438 const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
439 _memory_group.manage(&_mul_input_cell_res);
440 _mul_input_cell_res.allocator()->init(mul_input_cell_info);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100441 _pixelwise_mul_input_cell.configure(compile_context, &_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100442 _cell_gate.allocator()->allocate();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100443 _add_forget_cell.configure(compile_context, &_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100444 _mul_input_cell_res.allocator()->allocate();
445 _forget_gate.allocator()->allocate();
446 if(_has_cell_clipping)
447 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100448 _cell_clip.configure(compile_context, cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100449 }
450 // Output gate.
451 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
452 const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100453 configure_mm(compile_context, _mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100454 input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
455 &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
456 mm_out_info, output_outstage_info);
457
458 const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100459 configure_mm(compile_context, _mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100460 output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
461 &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
462 mm_out_info, output_outstage_info);
463
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100464 _accumulate_input_recurrent_output.configure(compile_context, &_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100465 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100466 _input_to_output_outstage_res.allocator()->allocate();
467
468 if(_has_peephole)
469 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100470 // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100471 // Here we are not using the output stage because all operations are done in float
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100472 _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100473 _memory_group.manage(&_mul_cell_to_output_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100474 _pixelwise_mul_cell_to_output.configure(compile_context, cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100475
476 const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
477 quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
478 _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
479 _memory_group.manage(&_cell_to_output_outstage_res);
480 _cell_to_output_outstage.configure(compile_context, &_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100481 _mul_cell_to_output_res.allocator()->allocate();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100482
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100483 _accumulate_cell_to_output.configure(compile_context, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100484 ConvertPolicy::SATURATE);
485 _cell_to_output_outstage_res.allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100486 }
487
Sheri Zhang3a353982020-04-21 13:10:24 +0100488 CLTensor *output_activation_input = &_recurrent_to_output_outstage_res;
489
490 if(_has_layer_norm)
491 {
492 configure_layer_norm(LayerNormGate::Output, &_recurrent_to_output_outstage_res);
493 _recurrent_to_output_outstage_res.allocator()->allocate();
494 output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
495 }
496
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100497 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
498 _memory_group.manage(&_output_gate);
499 _output_gate.allocator()->init(output_gate_info);
Sheri Zhang3a353982020-04-21 13:10:24 +0100500 _output_gate_sigmoid.configure(compile_context, output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
501 output_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100502
503 // Hidden.
Manuel Bottini2b84be52020-04-08 10:15:51 +0100504 _hidden_tanh.configure(compile_context, cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
Michalis Spyrou1009e872020-07-27 12:48:34 +0100505 // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100506 _memory_group.manage(&_hidden_mul_res);
507 const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
508 _hidden_mul_res.allocator()->init(hidden_mul_res);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100509 _pixelwise_mul_hidden.configure(compile_context, &_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100510 _output_gate.allocator()->allocate();
511 _input_gate.allocator()->allocate();
512 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
513 quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
514 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
515 gemmlowp_info.output_data_type = output_state_in->info()->data_type();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100516
517 _projection_tensor_copy_required = (num_units != output_size);
518 ICLTensor *hidden_gate_result = output_state_out;
519
520 _memory_group.manage(&_hidden_gate);
521
522 if(_projection_tensor_copy_required)
523 {
524 _hidden_gate.allocator()->init(*output_state_out->info());
525 _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
526 hidden_gate_result = &_hidden_gate;
527 }
528
529 _hidden_outstage.configure(compile_context, &_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100530 _hidden_mul_res.allocator()->allocate();
531
532 // Projection.
533 if(_has_projection)
534 {
535 const TensorInfo projection_outstage_info(*output_state_out->info());
536 const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
537 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
538 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
539 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
540 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
541 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
542
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100543 TensorInfo projection_mm_out_info{ mm_out_info };
544 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100545
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100546 configure_mm(compile_context, _mm_projection, _projection_outstage, gemmlowp_info,
547 hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
548 &_mm_projection_res, &_projection_outstage_res, projection_scale,
549 projection_mm_out_info, projection_outstage_info);
550
551 ICLTensor *accumulate_destination = output_state_out;
552
553 if(_projection_tensor_copy_required)
554 {
555 _hidden_gate.allocator()->allocate();
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100556 _projection_accumulate_res.allocator()->init(*output_state_in->info());
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100557 _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100558 _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100559 accumulate_destination = &_projection_accumulate_res;
560 }
561
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100562 _accumulate_projection.configure(compile_context, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100563 _projection_outstage_res.allocator()->allocate();
564
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100565 if(_projection_tensor_copy_required)
566 {
567 _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
568 _projection_accumulate_res.allocator()->allocate();
569 }
570
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100571 int8_t quantized_projection_clip{ 0 };
572 if(lstm_params.projection_clip() > 0.0f)
573 {
574 quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
575 }
576
577 if(quantized_projection_clip > 0)
578 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100579 _projection_clip.configure(compile_context, output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
580 quantized_projection_clip));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100581 _has_projection_clipping = true;
582 }
583 }
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100584 else
585 {
586 if(_projection_tensor_copy_required)
587 {
588 _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
589 _hidden_gate.allocator()->allocate();
590 }
591 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100592
593 // Copy output_state_out to output
Sheri Zhang7e20e292021-02-02 11:49:34 +0000594 _copy_output.configure(compile_context, output_state_out, output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100595}
596
597Status CLQLSTMLayer::validate(const ITensorInfo *input,
598 const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
599 const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
600 const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
601 const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100602 const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100603 const LSTMParams<ITensorInfo> &lstm_params)
604{
605 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100606 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
607 cell_state_out, output_state_out, output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100608
609 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
610 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
611
612 const unsigned int input_size = input->dimension(0);
613 const unsigned int batch_size = input->dimension(1);
614 const unsigned int num_units = input_to_output_weights->dimension(1);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100615 const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100616
617 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
618 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
619 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights, input_to_cell_weights);
620 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
621 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
622 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
623 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QSYMM8);
624 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
625 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
626
627 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
628 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
629 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
630 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
631 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);
632
633 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
634 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
635 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
636 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);
637
638 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
639 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
640 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
641 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);
642
643 // Check whether peephole weights are all there or none
644 if(lstm_params.has_peephole_opt())
645 {
646 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
647 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
648 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
649 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
650 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
651 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
652
653 if(!lstm_params.has_cifg_opt())
654 {
655 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
656 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
657 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
658 }
659 }
660
661 const UniformQuantizationInfo qinput = input->quantization_info().uniform();
662 const UniformQuantizationInfo qcell_state_in = cell_state_in->quantization_info().uniform();
663 const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
664
665 // Calculate and decompose effective scales for optimizing matmul calculation
666 const int32_t cell_shift = log2(qcell_state_in.scale);
667 ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
668
669 // Calculate quantized parameters for clipping.
670 int16_t quantized_cell_clip = 0;
671 if(lstm_params.cell_clip() > 0.0f)
672 {
673 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
674 }
675
676 // Precompute effective bias for optimizing the matmul computations.
677 const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100678 const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100679 if(!lstm_params.has_cifg_opt())
680 {
Georgios Pinitas4a578b92021-06-25 12:13:49 +0100681 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
682 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(lstm_params.recurrent_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100683 true)));
684 }
Georgios Pinitas4a578b92021-06-25 12:13:49 +0100685 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
686 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(recurrent_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
687 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
688 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
689 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
690 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100691 if(lstm_params.has_projection())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100692 {
Georgios Pinitas4a578b92021-06-25 12:13:49 +0100693 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100694 lstm_params.hidden_state_zero(),
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100695 true)));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100696 if(lstm_params.projection_bias() != nullptr)
697 {
698 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100699 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info,
700 &projection_eff_bias_info, ConvertPolicy::SATURATE));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100701 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100702 }
703
704 const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info());
705 const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
706
707 // Validate weights transpose
708 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_forget_weights, &input_weights_transposed));
709 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_cell_weights, &input_weights_transposed));
710 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_output_weights, &input_weights_transposed));
711 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_forget_weights, &recurrent_weights_transposed));
712 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_cell_weights, &recurrent_weights_transposed));
713 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_output_weights, &recurrent_weights_transposed));
714 if(!lstm_params.has_cifg_opt())
715 {
716 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed));
717 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed));
718 }
719 if(lstm_params.has_projection())
720 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100721 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
722 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100723 }
724
725 GEMMLowpOutputStageInfo gemmlowp_info;
726 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
727 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
728 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
729 gemmlowp_info.output_data_type = DataType::QSYMM16;
730
Sheri Zhang3a353982020-04-21 13:10:24 +0100731 const bool has_layer_norm = lstm_params.use_layer_norm();
732
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100733 // Forget gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100734 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_intermediate_scale() == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100735 const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
736 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
737 const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100738 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100739
740 const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100741 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100742
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100743 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100744
745 if(lstm_params.has_peephole_opt())
746 {
747 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
Michalis Spyrou1009e872020-07-27 12:48:34 +0100748 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
749 RoundingPolicy::TO_ZERO));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100750 const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
751 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
752 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100753 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100754 }
755
Sheri Zhang3a353982020-04-21 13:10:24 +0100756 if(has_layer_norm)
757 {
758 const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
759 const ITensorInfo *b_info = forget_gate_bias;
760 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
761 }
762
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100763 // Output quantization info of Sigmoid and Tanh activations
764 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
765
766 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
767 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
768
769 // Modulation gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100770 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_intermediate_scale() == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100771 const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
772 const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100773 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100774
775 const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100776 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &input_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100777
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100778 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100779
Sheri Zhang3a353982020-04-21 13:10:24 +0100780 if(has_layer_norm)
781 {
782 const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
783 const ITensorInfo *b_info = cell_bias;
784 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
785 }
786
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100787 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
788 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
789
790 // Input gate.
791 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
792 if(lstm_params.has_cifg_opt())
793 {
794 ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100795 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100796 }
797 else
798 {
799 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
800 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
801 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
802 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights, lstm_params.recurrent_to_input_weights());
803 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
804 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
805
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100806 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_intermediate_scale() == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100807 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
808 const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100809 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100810
811 const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100812 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100813
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100814 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100815
816 if(lstm_params.has_peephole_opt())
817 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100818 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
819 RoundingPolicy::TO_ZERO));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100820 const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
821 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100822 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100823 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100824 }
825
Sheri Zhang3a353982020-04-21 13:10:24 +0100826 if(has_layer_norm)
827 {
828 const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
829 const ITensorInfo *b_info = lstm_params.input_gate_bias();
830 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
831 }
832
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100833 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC, 1.f, 1.f)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100834 }
835 // Cell.
Michalis Spyrou1009e872020-07-27 12:48:34 +0100836 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
837 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100838 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100839 if(quantized_cell_clip > 0)
840 {
841 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
842 quantized_cell_clip)));
843 }
844 // Output gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100845 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_intermediate_scale() == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100846 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
847 const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100848 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100849
850 const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100851 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100852
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100853 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100854 if(lstm_params.has_peephole_opt())
855 {
856 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
857 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
858 // Here we are not using the output stage because all operations are done in float
859 // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
860 // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Michalis Spyrou1009e872020-07-27 12:48:34 +0100861 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
862 RoundingPolicy::TO_ZERO));
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100863 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100864 }
865
Sheri Zhang3a353982020-04-21 13:10:24 +0100866 if(has_layer_norm)
867 {
868 const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
869 const ITensorInfo *b_info = output_gate_bias;
870 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
871 }
872
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100873 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
874 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
875
876 // Hidden.
877 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
878 const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100879 const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
880
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100881 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.hidden_state_scale() == 0);
Michalis Spyrou1009e872020-07-27 12:48:34 +0100882 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100883 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
884 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
Sang-Hoon Park9f893752020-10-20 15:33:31 +0100885 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
886 gemmlowp_info.output_data_type = hidden_out_info.data_type();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100887 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
888
889 const bool projection_tensor_copy_required = num_units != output_size;
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100890
891 // Projection.
892 if(lstm_params.has_projection())
893 {
894 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights, lstm_params.projection_weights());
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100895 ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100896
897 const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
898 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
899 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
900 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
901 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
902 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
903 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
904
905 const TensorInfo projection_outstage_info(*output_state_out);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100906 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
907
908 TensorInfo projection_mm_out_info{ mm_out_info };
909 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
910
911 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
912 &projection_outstage_info));
913
914 if(projection_tensor_copy_required)
915 {
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100916 ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(*output_state_in, projection_outstage_info));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100917 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100918
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100919 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100920
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100921 if(projection_tensor_copy_required)
922 {
923 ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
924 }
925
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100926 int8_t quantized_projection_clip{ 0 };
927 if(lstm_params.projection_clip() > 0.0f)
928 {
929 quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
930 }
931
932 if(quantized_projection_clip > 0)
933 {
934 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
935 quantized_projection_clip)));
936 }
937 }
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100938 else
939 {
940 if(projection_tensor_copy_required)
941 {
942 ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
943 }
944 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100945
946 if(cell_state_out->total_size() > 0)
947 {
948 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
949 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
950 }
951
952 if(output_state_out->total_size() > 0)
953 {
954 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
955 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
956 }
957
Sheri Zhang7e20e292021-02-02 11:49:34 +0000958 ARM_COMPUTE_RETURN_ON_ERROR(CLCopy::validate(output_state_out, output));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100959 return Status{};
960}
961
962void CLQLSTMLayer::run()
963{
964 prepare();
965
966 // Acquire all the temporaries
967 MemoryGroupResourceScope scope_mg(_memory_group);
968
969 // Forget gate.
970 _mm_input_to_forget.run();
971 _input_to_forget_outstage.run();
972
973 _mm_recurrent_to_forget.run();
974 _recurrent_to_forget_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100975 _accumulate_input_recurrent_forget.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100976
977 if(_has_peephole)
978 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100979 _pixelwise_mul_cell_to_forget.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100980 _cell_to_forget_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100981 _accumulate_cell_forget.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100982 }
983
Sheri Zhang3a353982020-04-21 13:10:24 +0100984 if(_has_layer_norm)
985 {
986 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Forget));
987 }
988
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100989 _forget_gate_sigmoid.run();
990
991 // Modulation gate.
992 _mm_input_to_cell.run();
993 _input_to_cell_outstage.run();
994
995 _mm_recurrent_to_cell.run();
996 _recurrent_to_cell_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100997 _accumulate_input_recurrent_modulation.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100998
Sheri Zhang3a353982020-04-21 13:10:24 +0100999 if(_has_layer_norm)
1000 {
1001 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Cell));
1002 }
1003
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001004 _cell_gate_tanh.run();
1005
1006 // Input gate
1007 if(_has_cifg)
1008 {
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001009 _input_gate_sub.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001010 }
1011 else
1012 {
1013 _mm_input_to_input.run();
1014 _input_to_input_outstage.run();
1015 _mm_recurrent_to_input.run();
1016 _recurrent_to_input_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001017 _accumulate_input_recurrent_input.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001018
1019 if(_has_peephole)
1020 {
Michalis Spyrou1009e872020-07-27 12:48:34 +01001021 _pixelwise_mul_cell_to_input.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001022 _cell_to_input_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001023 _accumulate_cell_input.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001024 }
1025
Sheri Zhang3a353982020-04-21 13:10:24 +01001026 if(_has_layer_norm)
1027 {
1028 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input));
1029 }
1030
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001031 _input_gate_sigmoid.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001032 }
1033
1034 // Cell.
Michalis Spyrou1009e872020-07-27 12:48:34 +01001035 _pixelwise_mul_forget_cell.run();
1036 _pixelwise_mul_input_cell.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001037 _add_forget_cell.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001038 if(_has_cell_clipping)
1039 {
1040 _cell_clip.run();
1041 }
1042
1043 // Output gate.
1044 _mm_input_to_output.run();
1045 _input_to_output_outstage.run();
1046 _mm_recurrent_to_output.run();
1047 _recurrent_to_output_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001048 _accumulate_input_recurrent_output.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001049 if(_has_peephole)
1050 {
Michalis Spyrou1009e872020-07-27 12:48:34 +01001051 _pixelwise_mul_cell_to_output.run();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001052 _cell_to_output_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001053 _accumulate_cell_to_output.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001054 }
1055
Sheri Zhang3a353982020-04-21 13:10:24 +01001056 if(_has_layer_norm)
1057 {
1058 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Output));
1059 }
1060
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001061 _output_gate_sigmoid.run();
1062
1063 // Hidden.
1064 _hidden_tanh.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +01001065 _pixelwise_mul_hidden.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001066 _hidden_outstage.run();
1067
1068 // Projection.
1069 if(_has_projection)
1070 {
1071 _mm_projection.run();
1072 _projection_outstage.run();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001073
1074 if(_projection_tensor_copy_required)
1075 {
1076 _projection_output_to_accumulate_copy.run();
1077 }
1078
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001079 _accumulate_projection.run();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001080
1081 if(_projection_tensor_copy_required)
1082 {
1083 _projection_accumulate_to_output_copy.run();
1084 }
1085
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001086 if(_has_projection_clipping)
1087 {
1088 _projection_clip.run();
1089 }
1090 }
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001091 else
1092 {
1093 if(_projection_tensor_copy_required)
1094 {
1095 _hidden_to_output_copy.run();
1096 }
1097 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +01001098
1099 // Copy output_state_out to output
Sheri Zhang7e20e292021-02-02 11:49:34 +00001100 _copy_output.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001101}
1102
1103void CLQLSTMLayer::prepare()
1104{
1105 if(!_is_prepared)
1106 {
1107 // Pre-transpose weights to be used in GEMM.
1108 _input_to_forget_weights_transposed.allocator()->allocate();
1109 _input_to_cell_weights_transposed.allocator()->allocate();
1110 _input_to_output_weights_transposed.allocator()->allocate();
1111 _recurrent_to_forget_weights_transposed.allocator()->allocate();
1112 _recurrent_to_cell_weights_transposed.allocator()->allocate();
1113 _recurrent_to_output_weights_transposed.allocator()->allocate();
1114 _transpose_input_to_forget_weights.run();
1115 _transpose_input_to_cell_weights.run();
1116 _transpose_input_to_output_weights.run();
1117 _transpose_recurrent_to_forget_weights.run();
1118 _transpose_recurrent_to_cell_weights.run();
1119 _transpose_recurrent_to_output_weights.run();
1120
1121 // Precompute effective biases
1122 if(_has_cifg)
1123 {
1124 _ones.map(true);
1125 std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
1126 _ones.unmap();
1127 }
1128 else
1129 {
1130 _input_to_input_eff_bias.allocator()->allocate();
1131 _recurrent_to_input_eff_bias.allocator()->allocate();
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001132
1133 ITensorPack input_to_input_red_pack = { { ACL_SRC, _input_to_input_weights }, { ACL_DST, &_input_to_input_eff_bias } };
1134 CLScheduler::get().enqueue_op(*_input_to_input_reduction, input_to_input_red_pack, false);
1135
1136 ITensorPack rec_to_input_red_pack = { { ACL_SRC, _recurrent_to_input_weights }, { ACL_DST, &_recurrent_to_input_eff_bias } };
1137 CLScheduler::get().enqueue_op(*_recurrent_to_input_reduction, rec_to_input_red_pack, false);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001138
1139 _input_to_input_weights_transposed.allocator()->allocate();
1140 _recurrent_to_input_weights_transposed.allocator()->allocate();
1141 _transpose_input_to_input_weights.run();
1142 _transpose_recurrent_to_input_weights.run();
1143 _input_to_input_weights->mark_as_unused();
1144 _recurrent_to_input_weights->mark_as_unused();
1145 }
1146 _input_to_forget_eff_bias.allocator()->allocate();
1147 _recurrent_to_forget_eff_bias.allocator()->allocate();
1148 _input_to_cell_eff_bias.allocator()->allocate();
1149 _recurrent_to_cell_eff_bias.allocator()->allocate();
1150 _input_to_output_eff_bias.allocator()->allocate();
1151 _recurrent_to_output_eff_bias.allocator()->allocate();
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001152
1153 ITensorPack input_to_forget_red_pack = { { ACL_SRC, _input_to_forget_weights }, { ACL_DST, &_input_to_forget_eff_bias } };
1154 CLScheduler::get().enqueue_op(*_input_to_forget_reduction, input_to_forget_red_pack, false);
1155
1156 ITensorPack rec_to_forget_red_pack = { { ACL_SRC, _recurrent_to_forget_weights }, { ACL_DST, &_recurrent_to_forget_eff_bias } };
1157 CLScheduler::get().enqueue_op(*_recurrent_to_forget_reduction, rec_to_forget_red_pack, false);
1158
1159 ITensorPack input_to_cell_red_pack = { { ACL_SRC, _input_to_cell_weights }, { ACL_DST, &_input_to_cell_eff_bias } };
1160 CLScheduler::get().enqueue_op(*_input_to_cell_reduction, input_to_cell_red_pack, false);
1161
1162 ITensorPack rec_to_cell_red_pack = { { ACL_SRC, _recurrent_to_cell_weights }, { ACL_DST, &_recurrent_to_cell_eff_bias } };
1163 CLScheduler::get().enqueue_op(*_recurrent_to_cell_reduction, rec_to_cell_red_pack, false);
1164
1165 ITensorPack input_to_output_red_pack = { { ACL_SRC, _input_to_output_weights }, { ACL_DST, &_input_to_output_eff_bias } };
1166 CLScheduler::get().enqueue_op(*_input_to_output_reduction, input_to_output_red_pack, false);
1167
1168 ITensorPack rec_to_output_red_pack = { { ACL_SRC, _recurrent_to_output_weights }, { ACL_DST, &_recurrent_to_output_eff_bias } };
1169 CLScheduler::get().enqueue_op(*_recurrent_to_output_reduction, rec_to_output_red_pack, false);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001170
1171 if(_has_projection)
1172 {
Michele Di Giorgio11c562c2020-06-10 16:34:50 +01001173 _projection_eff_bias.allocator()->allocate();
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001174 ITensorPack proj_red_pack{ { ACL_SRC, _projection_weights }, { ACL_DST, &_projection_eff_bias } };
1175 CLScheduler::get().enqueue_op(*_projection_reduction, proj_red_pack, false);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001176 if(_projection_bias != nullptr)
1177 {
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001178 _projection_bias_add.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001179 _projection_bias->mark_as_unused();
1180 }
1181
1182 _projection_weights_transposed.allocator()->allocate();
1183 _transpose_projection_weights.run();
1184 _projection_weights->mark_as_unused();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001185
1186 if(!_projection_tensor_copy_required)
1187 {
1188 _hidden_gate.mark_as_unused();
1189 _projection_accumulate_res.mark_as_unused();
1190 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001191 }
1192
1193 // Mark weights as unused
1194 _input_to_forget_weights->mark_as_unused();
1195 _input_to_cell_weights->mark_as_unused();
1196 _input_to_output_weights->mark_as_unused();
1197 _recurrent_to_forget_weights->mark_as_unused();
1198 _recurrent_to_cell_weights->mark_as_unused();
1199 _recurrent_to_output_weights->mark_as_unused();
1200
1201 CLScheduler::get().queue().finish();
1202 _is_prepared = true;
1203 }
1204}
1205
1206} // namespace arm_compute