blob: 12f6f8929034c0bd26092ec200b66a6eb59689a1 [file] [log] [blame]
/*
 * Copyright (c) 2020-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/CL/functions/CLQLSTMLayer.h"
25
26#include "arm_compute/core/KernelDescriptors.h"
27#include "arm_compute/core/QuantizationInfo.h"
28#include "arm_compute/core/Utils.h"
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010029#include "arm_compute/core/utils/misc/InfoHelpers.h"
30#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010031#include "arm_compute/core/Validate.h"
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010032#include "arm_compute/runtime/CL/CLScheduler.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010033
34#include "src/common/utils/Log.h"
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010035#include "src/core/CL/kernels/CLFillBorderKernel.h"
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010036#include "src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010037#include "src/core/helpers/WindowHelpers.h"
Georgios Pinitas7891a732021-08-20 21:39:25 +010038#include "src/gpu/cl/kernels/ClGemmLowpReductionKernel.h"
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010039
40namespace arm_compute
41{
42using namespace arm_compute::utils::info_helpers;
Georgios Pinitas4a578b92021-06-25 12:13:49 +010043using namespace arm_compute::opencl::kernels;
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010044namespace
45{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010046Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info,
47 const ITensorInfo *mm_input,
48 const ITensorInfo *mm_weights,
49 const ITensorInfo *bias,
50 float gemmlowp_scale,
51 const TensorInfo *mm_res_info,
52 const TensorInfo *outstage_tensor_info)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010053{
54 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010055 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(
56 gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
57 ARM_COMPUTE_RETURN_ON_ERROR(
58 CLGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010059 return Status{};
60}
61} // namespace
62
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +010063Status CLQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
64{
65 ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
66 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
67 ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
68 ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
69 return Status{};
70}
71
72void CLQLSTMLayer::TensorCopyKernel::configure(ICLTensor &src, ICLTensor &dst)
73{
74 ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
75 _src = &src;
76 _dst = &dst;
77 _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
78 _window = calculate_max_window(*_src->info(), Steps());
79}
80
81void CLQLSTMLayer::TensorCopyKernel::run()
82{
83 auto &q = CLScheduler::get().queue();
84
85 _src->map(q, true);
86 _dst->map(q, true);
87
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010088 Iterator input_iter{_src, _window};
89 Iterator output_iter{_dst, _window};
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +010090
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010091 execute_window_loop(
92 _window, [&](const Coordinates &) { memcpy(output_iter.ptr(), input_iter.ptr(), _row_size); }, input_iter,
93 output_iter);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +010094
95 _src->unmap(q);
96 _dst->unmap(q);
97}
98
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +010099CLQLSTMLayer::CLQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Georgios Pinitas4a578b92021-06-25 12:13:49 +0100100 : _input_to_input_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
101 _recurrent_to_input_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
102 _input_to_forget_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
103 _recurrent_to_forget_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
104 _input_to_cell_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
105 _recurrent_to_cell_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
106 _input_to_output_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
107 _recurrent_to_output_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
108 _projection_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100109 _layer_norms(),
Sheri Zhang7e20e292021-02-02 11:49:34 +0000110 _copy_output()
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100111{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100112 for (auto &norm : _layer_norms)
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100113 {
Georgios Pinitas40f51a62020-11-21 03:04:18 +0000114 norm = std::make_unique<CLQLSTMLayerNormalizationKernel>();
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100115 }
116
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100117 _memory_group = MemoryGroup(std::move(memory_manager));
118}
119
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100120CLQLSTMLayer::~CLQLSTMLayer() = default;
121
122void CLQLSTMLayer::configure_layer_norm(LayerNormGate g, const ICLTensor *in)
123{
124 ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
125
126 CLTensor *out = &get_layer_norm_output(g);
127 _memory_group.manage(out);
128 out->allocator()->init(*(in->info()));
129
130 get_layer_norm(g).configure(in, out, get_layer_norm_weight(g), get_layer_norm_bias(g));
131}
132
133Status CLQLSTMLayer::validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
134{
135 // Output quantization scale will be different, but ignored here
136 // since it will be configured at configure() stage.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100137 const TensorInfo out{in};
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100138 return CLQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
139}
140
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100141void CLQLSTMLayer::configure_mm(const CLCompileContext &compile_context,
142 CLGEMMLowpMatrixMultiplyCore &mm,
143 CLGEMMLowpOutputStage &outstage,
144 GEMMLowpOutputStageInfo &gemmlowp_info,
145 const ICLTensor *mm_input,
146 const ICLTensor *mm_weights,
147 const ICLTensor *bias,
148 CLTensor *mm_res,
149 CLTensor *outstage_res,
150 float gemmlowp_scale,
151 const TensorInfo &mm_res_info,
152 const TensorInfo &outstage_tensor_info)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100153{
154 _memory_group.manage(mm_res);
155 _memory_group.manage(outstage_res);
156
157 mm_res->allocator()->init(mm_res_info);
158 outstage_res->allocator()->init(outstage_tensor_info);
159
160 // Configure matrix-multiplication
Manuel Bottini2b84be52020-04-08 10:15:51 +0100161 mm.configure(compile_context, mm_input, mm_weights, nullptr, mm_res);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100162
163 // Configure output stage
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100164 quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier,
165 &gemmlowp_info.gemmlowp_shift);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100166 outstage.configure(compile_context, mm_res, bias, outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100167 mm_res->allocator()->allocate();
168}
169
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100170void CLQLSTMLayer::configure(const ICLTensor *input,
171 const ICLTensor *input_to_forget_weights,
172 const ICLTensor *input_to_cell_weights,
173 const ICLTensor *input_to_output_weights,
174 const ICLTensor *recurrent_to_forget_weights,
175 const ICLTensor *recurrent_to_cell_weights,
176 const ICLTensor *recurrent_to_output_weights,
177 const ICLTensor *forget_gate_bias,
178 const ICLTensor *cell_bias,
179 const ICLTensor *output_gate_bias,
180 ICLTensor *cell_state_in,
181 ICLTensor *output_state_in,
182 ICLTensor *cell_state_out,
183 ICLTensor *output_state_out,
184 ICLTensor *output,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100185 const LSTMParams<ICLTensor> &lstm_params)
186{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100187 configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights,
188 input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
189 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in,
190 output_state_in, cell_state_out, output_state_out, output, lstm_params);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100191}
192
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100193void CLQLSTMLayer::configure(const CLCompileContext &compile_context,
194 const ICLTensor *input,
195 const ICLTensor *input_to_forget_weights,
196 const ICLTensor *input_to_cell_weights,
197 const ICLTensor *input_to_output_weights,
198 const ICLTensor *recurrent_to_forget_weights,
199 const ICLTensor *recurrent_to_cell_weights,
200 const ICLTensor *recurrent_to_output_weights,
201 const ICLTensor *forget_gate_bias,
202 const ICLTensor *cell_bias,
203 const ICLTensor *output_gate_bias,
204 ICLTensor *cell_state_in,
205 ICLTensor *output_state_in,
206 ICLTensor *cell_state_out,
207 ICLTensor *output_state_out,
208 ICLTensor *output,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100209 const LSTMParams<ICLTensor> &lstm_params)
210{
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100211 ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
212 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100213 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
214 cell_state_out, output_state_out, output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100215
ramelg016d891572021-09-29 10:05:09 +0100216 ARM_COMPUTE_LOG_PARAMS(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
217 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
218 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
219 cell_state_out, output_state_out, output, lstm_params);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100220 // Set lstm parameters
221 LSTMParams<ITensorInfo> lstm_params_info{};
222 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
223
224 // Validate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100225 ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::validate(
226 input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
227 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
228 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(), cell_state_in->info(),
229 output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(), lstm_params_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100230
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100231 const int batch_size = input->info()->dimension(1);
232 const int num_units = input_to_output_weights->info()->dimension(1);
233 const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100234
235 const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
236 const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
237 const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
238
239 _projection_bias = lstm_params.projection_bias();
240 _input_to_forget_weights = input_to_forget_weights;
241 _input_to_cell_weights = input_to_cell_weights;
242 _input_to_output_weights = input_to_output_weights;
243 _recurrent_to_forget_weights = recurrent_to_forget_weights;
244 _recurrent_to_cell_weights = recurrent_to_cell_weights;
245 _recurrent_to_output_weights = recurrent_to_output_weights;
246 _projection_weights = lstm_params.projection_weights();
247
Sheri Zhang3a353982020-04-21 13:10:24 +0100248 // Layer normalization
249 _has_layer_norm = lstm_params.use_layer_norm();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100250 if (_has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +0100251 {
252 set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
253 set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
254 set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
255 set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
256
257 set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
258 set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
259 set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
260 set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
261 }
262
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100263 _has_cifg = lstm_params.has_cifg_opt();
264 _has_projection = lstm_params.has_projection();
265 _has_peephole = lstm_params.has_peephole_opt();
266
267 // Calculate and decompose effective scales for optimizing matmul calculation
268 const int32_t cell_shift = log2(qcell_state_in.scale);
269
270 // Calculate quantized parameters for clipping.
271 int16_t quantized_cell_clip = 0;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100272 if (lstm_params.cell_clip() > 0.0f)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100273 {
274 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
275 }
276 _has_cell_clipping = quantized_cell_clip > 0;
277
278 // Precompute effective bias for optimizing the matmul computations.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100279 if (!_has_cifg)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100280 {
281 _input_to_input_weights = lstm_params.input_to_input_weights();
282 _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
283
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100284 _input_to_input_reduction->configure(compile_context, _input_to_input_weights->info(),
285 _input_to_input_eff_bias.info(),
286 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
287 _recurrent_to_input_reduction->configure(
288 compile_context, _recurrent_to_input_weights->info(), _recurrent_to_input_eff_bias.info(),
289 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100290 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100291 _input_to_forget_reduction->configure(compile_context, input_to_forget_weights->info(),
292 _input_to_forget_eff_bias.info(),
293 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
294 _recurrent_to_forget_reduction->configure(
295 compile_context, recurrent_to_forget_weights->info(), _recurrent_to_forget_eff_bias.info(),
296 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
297 _input_to_cell_reduction->configure(compile_context, input_to_cell_weights->info(), _input_to_cell_eff_bias.info(),
298 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
299 _recurrent_to_cell_reduction->configure(
300 compile_context, recurrent_to_cell_weights->info(), _recurrent_to_cell_eff_bias.info(),
301 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
302 _input_to_output_reduction->configure(compile_context, input_to_output_weights->info(),
303 _input_to_output_eff_bias.info(),
304 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
305 _recurrent_to_output_reduction->configure(
306 compile_context, recurrent_to_output_weights->info(), _recurrent_to_output_eff_bias.info(),
307 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
308 if (_has_projection)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100309 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100310 _projection_reduction->configure(
311 compile_context, _projection_weights->info(), _projection_eff_bias.info(),
312 GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
313 if (_projection_bias != nullptr)
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100314 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100315 _projection_bias_add.configure(compile_context, _projection_bias, &_projection_eff_bias,
316 &_projection_eff_bias, ConvertPolicy::SATURATE);
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100317 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100318 }
319
320 // Pre-transpose weights to be used in GEMM.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100321 _transpose_input_to_forget_weights.configure(compile_context, input_to_forget_weights,
322 &_input_to_forget_weights_transposed);
323 _transpose_input_to_cell_weights.configure(compile_context, input_to_cell_weights,
324 &_input_to_cell_weights_transposed);
325 _transpose_input_to_output_weights.configure(compile_context, input_to_output_weights,
326 &_input_to_output_weights_transposed);
327 _transpose_recurrent_to_forget_weights.configure(compile_context, recurrent_to_forget_weights,
328 &_recurrent_to_forget_weights_transposed);
329 _transpose_recurrent_to_cell_weights.configure(compile_context, recurrent_to_cell_weights,
330 &_recurrent_to_cell_weights_transposed);
331 _transpose_recurrent_to_output_weights.configure(compile_context, recurrent_to_output_weights,
332 &_recurrent_to_output_weights_transposed);
333 if (!_has_cifg)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100334 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100335 _transpose_input_to_input_weights.configure(compile_context, lstm_params.input_to_input_weights(),
336 &_input_to_input_weights_transposed);
337 _transpose_recurrent_to_input_weights.configure(compile_context, lstm_params.recurrent_to_input_weights(),
338 &_recurrent_to_input_weights_transposed);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100339 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100340 if (_has_projection)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100341 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100342 _transpose_projection_weights.configure(compile_context, _projection_weights, &_projection_weights_transposed);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100343 }
344
345 GEMMLowpOutputStageInfo gemmlowp_info;
346 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
347 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
348 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
349 gemmlowp_info.output_data_type = DataType::QSYMM16;
350
351 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
352 // Forget gate.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100353 const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16,
354 QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
355 const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale *
356 qinput.scale / lstm_params.forget_intermediate_scale();
357 configure_mm(compile_context, _mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info, input,
358 &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias, &_mm_input_to_forget_res,
359 &_input_to_forget_outstage_res, input_to_forget_scale, mm_out_info, forget_gate_outstage_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100360
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100361 const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale *
362 qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100363 configure_mm(compile_context, _mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100364 output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
365 &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
366 mm_out_info, forget_gate_outstage_info);
367
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100368 _accumulate_input_recurrent_forget.configure(compile_context, &_input_to_forget_outstage_res,
369 &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100370 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100371 _input_to_forget_outstage_res.allocator()->allocate();
372
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100373 if (_has_peephole)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100374 {
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100375 _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100376 _memory_group.manage(&_mul_cell_to_forget_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100377 _pixelwise_mul_cell_to_forget.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(),
378 &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE,
379 RoundingPolicy::TO_ZERO);
380 _cell_to_forget_outstage_res.allocator()->init(
381 TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16,
382 QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100383 _memory_group.manage(&_cell_to_forget_outstage_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100384 const float cell_to_forget_scale =
385 std::pow(2, cell_shift) *
386 lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale /
387 lstm_params.forget_intermediate_scale();
388 quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier,
389 &gemmlowp_info.gemmlowp_shift);
390 _cell_to_forget_outstage.configure(compile_context, &_mul_cell_to_forget_res, nullptr,
391 &_cell_to_forget_outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100392 _mul_cell_to_forget_res.allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100393 _accumulate_cell_forget.configure(compile_context, &_recurrent_to_forget_outstage_res,
394 &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100395 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100396 _cell_to_forget_outstage_res.allocator()->allocate();
397 }
398
Sheri Zhang3a353982020-04-21 13:10:24 +0100399 CLTensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
400
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100401 if (_has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +0100402 {
403 configure_layer_norm(LayerNormGate::Forget, &_recurrent_to_forget_outstage_res);
404 _recurrent_to_forget_outstage_res.allocator()->allocate();
405 forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
406 }
407
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100408 // Output quantization info of Sigmoid and Tanh activations
409 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
410
411 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
412 _memory_group.manage(&_forget_gate);
413 _forget_gate.allocator()->init(forget_gate_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100414 _forget_gate_sigmoid.configure(compile_context, forget_activation_input, &_forget_gate,
415 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sheri Zhang3a353982020-04-21 13:10:24 +0100416 forget_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100417
418 // Modulation gate.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100419 const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16,
420 QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
421 const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale *
422 qinput.scale / lstm_params.cell_intermediate_scale();
423 configure_mm(compile_context, _mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info, input,
424 &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias, &_mm_input_to_cell_res,
425 &_input_to_cell_outstage_res, input_to_cell_scale, mm_out_info, cell_outstage_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100426
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100427 const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale *
428 qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
429 configure_mm(compile_context, _mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info, output_state_in,
430 &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias, &_mm_recurrent_to_cell_res,
431 &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale, mm_out_info, cell_outstage_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100432
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100433 _accumulate_input_recurrent_modulation.configure(compile_context, &_input_to_cell_outstage_res,
434 &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100435 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100436 _input_to_cell_outstage_res.allocator()->allocate();
437
Sheri Zhang3a353982020-04-21 13:10:24 +0100438 CLTensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
439
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100440 if (_has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +0100441 {
442 configure_layer_norm(LayerNormGate::Cell, &_recurrent_to_cell_outstage_res);
443 _recurrent_to_cell_outstage_res.allocator()->allocate();
444 cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
445 }
446
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100447 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
448 _memory_group.manage(&_cell_gate);
449 _cell_gate.allocator()->init(cell_gate_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100450 _cell_gate_tanh.configure(compile_context, cell_activation_input, &_cell_gate,
451 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
Sheri Zhang3a353982020-04-21 13:10:24 +0100452 cell_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100453
454 // Input gate.
455 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
456 _input_gate.allocator()->init(input_gate_info);
457 _memory_group.manage(&_input_gate);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100458 if (_has_cifg)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100459 {
460 _ones.allocator()->init(*_forget_gate.info());
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100461 _input_gate_sub.configure(compile_context, &_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100462 _ones.allocator()->allocate();
463 }
464 else
465 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100466 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
467 QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
468 const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale *
469 qinput.scale / lstm_params.input_intermediate_scale();
470 configure_mm(compile_context, _mm_input_to_input, _input_to_input_outstage, gemmlowp_info, input,
471 &_input_to_input_weights_transposed, &_input_to_input_eff_bias, &_mm_input_to_input_res,
472 &_input_to_input_outstage_res, input_to_input_scale, mm_out_info, input_outstage_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100473
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100474 const float recurrent_to_input_scale =
475 _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale /
476 lstm_params.input_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100477 configure_mm(compile_context, _mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100478 output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100479 &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
480 mm_out_info, input_outstage_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100481 _accumulate_input_recurrent_input.configure(compile_context, &_input_to_input_outstage_res,
482 &_recurrent_to_input_outstage_res,
483 &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100484 _input_to_input_outstage_res.allocator()->allocate();
485
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100486 if (_has_peephole)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100487 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100488 _mul_cell_to_input_res.allocator()->init(
489 TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100490 _memory_group.manage(&_mul_cell_to_input_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100491 _pixelwise_mul_cell_to_input.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(),
492 &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE,
493 RoundingPolicy::TO_ZERO);
494 const float cell_to_input_scale =
495 std::pow(2, cell_shift) *
496 lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale /
497 lstm_params.input_intermediate_scale();
498 quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier,
499 &gemmlowp_info.gemmlowp_shift);
500 _cell_to_input_outstage_res.allocator()->init(
501 TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16,
502 QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100503 _memory_group.manage(&_cell_to_input_outstage_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100504 _cell_to_input_outstage.configure(compile_context, &_mul_cell_to_input_res, nullptr,
505 &_cell_to_input_outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100506 _mul_cell_to_input_res.allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100507 _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res,
508 &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100509 _cell_to_input_outstage_res.allocator()->allocate();
510 }
511
Sheri Zhang3a353982020-04-21 13:10:24 +0100512 CLTensor *input_activation_input = &_recurrent_to_input_outstage_res;
513
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100514 if (_has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +0100515 {
516 configure_layer_norm(LayerNormGate::Input, &_recurrent_to_input_outstage_res);
517 _recurrent_to_input_outstage_res.allocator()->allocate();
518 input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
519 }
520
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100521 _input_gate_sigmoid.configure(compile_context, input_activation_input, &_input_gate,
522 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sheri Zhang3a353982020-04-21 13:10:24 +0100523 input_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100524 }
525 // Cell.
Michalis Spyrou1009e872020-07-27 12:48:34 +0100526 // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100527 _pixelwise_mul_forget_cell.configure(compile_context, &_forget_gate, cell_state_in, &_forget_gate, 1.f,
528 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100529 const float cell_gate_scale = _cell_gate.info()->quantization_info().uniform().scale;
530 const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100531 const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
532 QuantizationInfo(mul_input_cell_scale, 0));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100533 _memory_group.manage(&_mul_input_cell_res);
534 _mul_input_cell_res.allocator()->init(mul_input_cell_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100535 _pixelwise_mul_input_cell.configure(compile_context, &_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f,
536 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100537 _cell_gate.allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100538 _add_forget_cell.configure(compile_context, &_forget_gate, &_mul_input_cell_res, cell_state_out,
539 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100540 _mul_input_cell_res.allocator()->allocate();
541 _forget_gate.allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100542 if (_has_cell_clipping)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100543 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100544 _cell_clip.configure(compile_context, cell_state_out, nullptr,
545 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
546 -quantized_cell_clip, quantized_cell_clip));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100547 }
548 // Output gate.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100549 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
550 QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
551 const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale *
552 qinput.scale / lstm_params.output_intermediate_scale();
553 configure_mm(compile_context, _mm_input_to_output, _input_to_output_outstage, gemmlowp_info, input,
554 &_input_to_output_weights_transposed, &_input_to_output_eff_bias, &_mm_input_to_output_res,
555 &_input_to_output_outstage_res, input_to_output_scale, mm_out_info, output_outstage_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100556
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100557 const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale *
558 qoutput_state_in.scale / lstm_params.output_intermediate_scale();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100559 configure_mm(compile_context, _mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100560 output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
561 &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
562 mm_out_info, output_outstage_info);
563
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100564 _accumulate_input_recurrent_output.configure(compile_context, &_recurrent_to_output_outstage_res,
565 &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100566 ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100567 _input_to_output_outstage_res.allocator()->allocate();
568
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100569 if (_has_peephole)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100570 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100571 // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100572 // Here we are not using the output stage because all operations are done in float
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100573 _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100574 _memory_group.manage(&_mul_cell_to_output_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100575 _pixelwise_mul_cell_to_output.configure(compile_context, cell_state_out, lstm_params.cell_to_output_weights(),
576 &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE,
577 RoundingPolicy::TO_ZERO);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100578
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100579 const float cell_to_output_scale =
580 std::pow(2, cell_shift) *
581 lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale /
582 lstm_params.output_intermediate_scale();
583 quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier,
584 &gemmlowp_info.gemmlowp_shift);
585 _cell_to_output_outstage_res.allocator()->init(
586 TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16,
587 QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100588 _memory_group.manage(&_cell_to_output_outstage_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100589 _cell_to_output_outstage.configure(compile_context, &_mul_cell_to_output_res, nullptr,
590 &_cell_to_output_outstage_res, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100591 _mul_cell_to_output_res.allocator()->allocate();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100592
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100593 _accumulate_cell_to_output.configure(compile_context, &_recurrent_to_output_outstage_res,
594 &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100595 ConvertPolicy::SATURATE);
596 _cell_to_output_outstage_res.allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100597 }
598
Sheri Zhang3a353982020-04-21 13:10:24 +0100599 CLTensor *output_activation_input = &_recurrent_to_output_outstage_res;
600
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100601 if (_has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +0100602 {
603 configure_layer_norm(LayerNormGate::Output, &_recurrent_to_output_outstage_res);
604 _recurrent_to_output_outstage_res.allocator()->allocate();
605 output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
606 }
607
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100608 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
609 _memory_group.manage(&_output_gate);
610 _output_gate.allocator()->init(output_gate_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100611 _output_gate_sigmoid.configure(compile_context, output_activation_input, &_output_gate,
612 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Sheri Zhang3a353982020-04-21 13:10:24 +0100613 output_activation_input->allocator()->allocate();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100614
615 // Hidden.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100616 _hidden_tanh.configure(compile_context, cell_state_out, &_input_gate,
617 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
Michalis Spyrou1009e872020-07-27 12:48:34 +0100618 // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100619 _memory_group.manage(&_hidden_mul_res);
620 const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
621 _hidden_mul_res.allocator()->init(hidden_mul_res);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100622 _pixelwise_mul_hidden.configure(compile_context, &_output_gate, &_input_gate, &_hidden_mul_res, 1.f,
623 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100624 _output_gate.allocator()->allocate();
625 _input_gate.allocator()->allocate();
626 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100627 quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier,
628 &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100629 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
630 gemmlowp_info.output_data_type = output_state_in->info()->data_type();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100631
632 _projection_tensor_copy_required = (num_units != output_size);
633 ICLTensor *hidden_gate_result = output_state_out;
634
635 _memory_group.manage(&_hidden_gate);
636
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100637 if (_projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100638 {
639 _hidden_gate.allocator()->init(*output_state_out->info());
640 _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
641 hidden_gate_result = &_hidden_gate;
642 }
643
644 _hidden_outstage.configure(compile_context, &_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100645 _hidden_mul_res.allocator()->allocate();
646
647 // Projection.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100648 if (_has_projection)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100649 {
650 const TensorInfo projection_outstage_info(*output_state_out->info());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100651 const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
652 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
653 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
654 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
655 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
656 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100657
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100658 TensorInfo projection_mm_out_info{mm_out_info};
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100659 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100660
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100661 configure_mm(compile_context, _mm_projection, _projection_outstage, gemmlowp_info, hidden_gate_result,
662 &_projection_weights_transposed, &_projection_eff_bias, &_mm_projection_res,
663 &_projection_outstage_res, projection_scale, projection_mm_out_info, projection_outstage_info);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100664
665 ICLTensor *accumulate_destination = output_state_out;
666
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100667 if (_projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100668 {
669 _hidden_gate.allocator()->allocate();
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100670 _projection_accumulate_res.allocator()->init(*output_state_in->info());
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100671 _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
Sang-Hoon Park840a72c2020-09-23 13:24:13 +0100672 _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100673 accumulate_destination = &_projection_accumulate_res;
674 }
675
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100676 _accumulate_projection.configure(compile_context, &_projection_outstage_res, accumulate_destination,
677 accumulate_destination, ConvertPolicy::SATURATE);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100678 _projection_outstage_res.allocator()->allocate();
679
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100680 if (_projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100681 {
682 _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
683 _projection_accumulate_res.allocator()->allocate();
684 }
685
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100686 int8_t quantized_projection_clip{0};
687 if (lstm_params.projection_clip() > 0.0f)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100688 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100689 quantized_projection_clip =
690 utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100691 }
692
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100693 if (quantized_projection_clip > 0)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100694 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100695 _projection_clip.configure(compile_context, output_state_out, nullptr,
696 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
697 -quantized_projection_clip, quantized_projection_clip));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100698 _has_projection_clipping = true;
699 }
700 }
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100701 else
702 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100703 if (_projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100704 {
705 _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
706 _hidden_gate.allocator()->allocate();
707 }
708 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +0100709
710 // Copy output_state_out to output
Sheri Zhang7e20e292021-02-02 11:49:34 +0000711 _copy_output.configure(compile_context, output_state_out, output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100712}
713
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100714Status CLQLSTMLayer::validate(const ITensorInfo *input,
715 const ITensorInfo *input_to_forget_weights,
716 const ITensorInfo *input_to_cell_weights,
717 const ITensorInfo *input_to_output_weights,
718 const ITensorInfo *recurrent_to_forget_weights,
719 const ITensorInfo *recurrent_to_cell_weights,
720 const ITensorInfo *recurrent_to_output_weights,
721 const ITensorInfo *forget_gate_bias,
722 const ITensorInfo *cell_bias,
723 const ITensorInfo *output_gate_bias,
724 const ITensorInfo *cell_state_in,
725 const ITensorInfo *output_state_in,
726 const ITensorInfo *cell_state_out,
727 const ITensorInfo *output_state_out,
728 const ITensorInfo *output,
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100729 const LSTMParams<ITensorInfo> &lstm_params)
730{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100731 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
732 recurrent_to_forget_weights, recurrent_to_cell_weights,
733 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
734 cell_state_in, output_state_in, cell_state_out, output_state_out, output);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100735
736 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
737 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
738
739 const unsigned int input_size = input->dimension(0);
740 const unsigned int batch_size = input->dimension(1);
741 const unsigned int num_units = input_to_output_weights->dimension(1);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100742 const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100743
744 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
745 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100746 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights,
747 input_to_cell_weights);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100748 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
749 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100750 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights,
751 recurrent_to_cell_weights);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100752 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QSYMM8);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100753 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights,
754 input_to_output_weights, recurrent_to_forget_weights,
755 recurrent_to_cell_weights, recurrent_to_output_weights);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100756
757 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
758 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
759 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
760 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
761 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);
762
763 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
764 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
765 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
766 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);
767
768 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
769 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
770 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
771 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);
772
773 // Check whether peephole weights are all there or none
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100774 if (lstm_params.has_peephole_opt())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100775 {
776 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100777 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1,
778 DataType::QSYMM16);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100779 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
780 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100781 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(),
782 lstm_params.cell_to_output_weights());
783 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(),
784 lstm_params.cell_to_output_weights());
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100785
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100786 if (!lstm_params.has_cifg_opt())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100787 {
788 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100789 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(),
790 lstm_params.cell_to_input_weights());
791 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(),
792 lstm_params.cell_to_input_weights());
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100793 }
794 }
795
796 const UniformQuantizationInfo qinput = input->quantization_info().uniform();
797 const UniformQuantizationInfo qcell_state_in = cell_state_in->quantization_info().uniform();
798 const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
799
800 // Calculate and decompose effective scales for optimizing matmul calculation
801 const int32_t cell_shift = log2(qcell_state_in.scale);
802 ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
803
804 // Calculate quantized parameters for clipping.
805 int16_t quantized_cell_clip = 0;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100806 if (lstm_params.cell_clip() > 0.0f)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100807 {
808 quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
809 }
810
811 // Precompute effective bias for optimizing the matmul computations.
812 const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +0100813 const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100814 if (!lstm_params.has_cifg_opt())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100815 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100816 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(
817 lstm_params.input_to_input_weights(), &eff_bias_info,
818 GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
819 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(
820 lstm_params.recurrent_to_input_weights(), &eff_bias_info,
821 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100822 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100823 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(
824 input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
825 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(
826 recurrent_to_forget_weights, &eff_bias_info,
827 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
828 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(
829 input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
830 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(
831 recurrent_to_cell_weights, &eff_bias_info,
832 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
833 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(
834 input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
835 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(
836 recurrent_to_output_weights, &eff_bias_info,
837 GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
838 if (lstm_params.has_projection())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100839 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100840 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(
841 lstm_params.projection_weights(), &projection_eff_bias_info,
842 GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true)));
843 if (lstm_params.projection_bias() != nullptr)
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100844 {
845 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100846 ARM_COMPUTE_RETURN_ON_ERROR(
847 CLArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info,
848 &projection_eff_bias_info, ConvertPolicy::SATURATE));
Michele Di Giorgio11c562c2020-06-10 16:34:50 +0100849 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100850 }
851
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100852 const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1,
853 input_to_forget_weights->data_type(),
854 input_to_forget_weights->quantization_info());
855 const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1,
856 recurrent_to_forget_weights->data_type(),
857 recurrent_to_forget_weights->quantization_info());
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100858
859 // Validate weights transpose
860 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_forget_weights, &input_weights_transposed));
861 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_cell_weights, &input_weights_transposed));
862 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_output_weights, &input_weights_transposed));
863 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_forget_weights, &recurrent_weights_transposed));
864 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_cell_weights, &recurrent_weights_transposed));
865 ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_output_weights, &recurrent_weights_transposed));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100866 if (!lstm_params.has_cifg_opt())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100867 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100868 ARM_COMPUTE_RETURN_ON_ERROR(
869 CLTranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed));
870 ARM_COMPUTE_RETURN_ON_ERROR(
871 CLTranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100872 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100873 if (lstm_params.has_projection())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100874 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100875 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1,
876 lstm_params.projection_weights()->data_type(),
877 lstm_params.projection_weights()->quantization_info());
878 ARM_COMPUTE_RETURN_ON_ERROR(
879 CLTranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100880 }
881
882 GEMMLowpOutputStageInfo gemmlowp_info;
883 gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
884 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
885 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
886 gemmlowp_info.output_data_type = DataType::QSYMM16;
887
Sheri Zhang3a353982020-04-21 13:10:24 +0100888 const bool has_layer_norm = lstm_params.use_layer_norm();
889
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100890 // Forget gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100891 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_intermediate_scale() == 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100892 const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
893 QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100894 const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100895 const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale /
896 lstm_params.forget_intermediate_scale();
897 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info,
898 input_to_forget_scale, &mm_out_info, &forget_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100899
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100900 const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale *
901 qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
902 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed,
903 &eff_bias_info, recurrent_to_forget_scale, &mm_out_info,
904 &forget_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100905
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100906 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info,
907 &forget_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100908
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100909 if (lstm_params.has_peephole_opt())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100910 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100911 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1,
912 DataType::QSYMM16);
913 ARM_COMPUTE_RETURN_ON_ERROR(
914 CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f,
915 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
916 const float cell_to_forget_scale = std::pow(2, cell_shift) *
917 lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale /
918 lstm_params.forget_intermediate_scale();
919 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(
920 cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
921 ARM_COMPUTE_RETURN_ON_ERROR(
922 CLGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
923 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info,
924 &forget_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100925 }
926
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100927 if (has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +0100928 {
929 const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
930 const ITensorInfo *b_info = forget_gate_bias;
931 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
932 }
933
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100934 // Output quantization info of Sigmoid and Tanh activations
935 const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
936
937 const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100938 ARM_COMPUTE_RETURN_ON_ERROR(
939 CLActivationLayer::validate(&forget_outstage_info, &forget_gate_info,
940 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100941
942 // Modulation gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100943 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_intermediate_scale() == 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100944 const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
945 QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
946 const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale /
947 lstm_params.cell_intermediate_scale();
948 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info,
949 input_to_cell_scale, &mm_out_info, &cell_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100950
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100951 const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale *
952 qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
953 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed,
954 &eff_bias_info, recurrent_to_cell_scale, &mm_out_info,
955 &cell_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100956
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100957 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info,
958 &cell_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100959
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100960 if (has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +0100961 {
962 const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
963 const ITensorInfo *b_info = cell_bias;
964 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
965 }
966
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100967 const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100968 ARM_COMPUTE_RETURN_ON_ERROR(
969 CLActivationLayer::validate(&cell_outstage_info, &cell_gate_info,
970 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100971
972 // Input gate.
973 const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100974 if (lstm_params.has_cifg_opt())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100975 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100976 ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr,
977 "Input gate bias must not be present when CIFG is used");
978 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info,
979 &forget_gate_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100980 }
981 else
982 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100983 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
984 lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
985 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(
986 input_to_forget_weights, lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100987 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100988 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights,
989 lstm_params.recurrent_to_input_weights());
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +0100990 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
991 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
992
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +0100993 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_intermediate_scale() == 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100994 const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
995 QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
996 const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale *
997 qinput.scale / lstm_params.input_intermediate_scale();
998 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info,
999 input_to_input_scale, &mm_out_info, &input_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001000
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001001 const float recurrent_to_input_scale =
1002 lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale /
1003 lstm_params.input_intermediate_scale();
1004 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed,
1005 &eff_bias_info, recurrent_to_input_scale, &mm_out_info,
1006 &input_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001007
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001008 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info,
1009 &input_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001010
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001011 if (lstm_params.has_peephole_opt())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001012 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001013 ARM_COMPUTE_RETURN_ON_ERROR(
1014 CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info,
1015 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
1016 const float cell_to_input_scale = std::pow(2, cell_shift) *
1017 lstm_params.cell_to_input_weights()->quantization_info().uniform().scale /
1018 lstm_params.input_intermediate_scale();
1019 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(
1020 cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
1021 ARM_COMPUTE_RETURN_ON_ERROR(
1022 CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
1023 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info,
1024 &input_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001025 }
1026
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001027 if (has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +01001028 {
1029 const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
1030 const ITensorInfo *b_info = lstm_params.input_gate_bias();
1031 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
1032 }
1033
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001034 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(
1035 &input_outstage_info, &input_gate_info,
1036 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC, 1.f, 1.f)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001037 }
1038 // Cell.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001039 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(
1040 &forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
1041 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(
1042 &input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
1043 ARM_COMPUTE_RETURN_ON_ERROR(
1044 CLArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
1045 if (quantized_cell_clip > 0)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001046 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001047 ARM_COMPUTE_RETURN_ON_ERROR(
1048 CLActivationLayer::validate(cell_state_out, nullptr,
1049 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1050 -quantized_cell_clip, quantized_cell_clip)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001051 }
1052 // Output gate.
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +01001053 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_intermediate_scale() == 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001054 const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16,
1055 QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
1056 const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale /
1057 lstm_params.output_intermediate_scale();
1058 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info,
1059 input_to_output_scale, &mm_out_info, &output_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001060
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001061 const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale *
1062 qoutput_state_in.scale / lstm_params.output_intermediate_scale();
1063 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed,
1064 &eff_bias_info, recurrent_to_output_scale, &mm_out_info,
1065 &output_outstage_info));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001066
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001067 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info,
1068 &output_outstage_info, ConvertPolicy::SATURATE));
1069 if (lstm_params.has_peephole_opt())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001070 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001071 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1,
1072 DataType::QSYMM16);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001073 // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
1074 // Here we are not using the output stage because all operations are done in float
1075 // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
1076 // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001077 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(
1078 cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
1079 RoundingPolicy::TO_ZERO));
1080 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info,
1081 &output_outstage_info, ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001082 }
1083
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001084 if (has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +01001085 {
1086 const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
1087 const ITensorInfo *b_info = output_gate_bias;
1088 ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
1089 }
1090
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001091 const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001092 ARM_COMPUTE_RETURN_ON_ERROR(
1093 CLActivationLayer::validate(&output_outstage_info, &output_gate_info,
1094 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001095
1096 // Hidden.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001097 ARM_COMPUTE_RETURN_ON_ERROR(
1098 CLActivationLayer::validate(cell_state_out, &input_gate_info,
1099 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001100 const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001101 const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
1102
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +01001103 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.hidden_state_scale() == 0);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001104 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(
1105 &output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001106 const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001107 ARM_COMPUTE_RETURN_ON_ERROR(
1108 quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier,
1109 &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
Sang-Hoon Park9f893752020-10-20 15:33:31 +01001110 gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
1111 gemmlowp_info.output_data_type = hidden_out_info.data_type();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001112 ARM_COMPUTE_RETURN_ON_ERROR(
1113 CLGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001114
1115 const bool projection_tensor_copy_required = num_units != output_size;
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001116
1117 // Projection.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001118 if (lstm_params.has_projection())
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001119 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001120 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights,
1121 lstm_params.projection_weights());
Sang-Hoon Parkee4833d2020-05-20 09:13:32 +01001122 ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001123
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001124 const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
1125 const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
1126 ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(
1127 projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001128 gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
1129 gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
1130 gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
1131 gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
1132
1133 const TensorInfo projection_outstage_info(*output_state_out);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001134 const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1,
1135 lstm_params.projection_weights()->data_type(),
1136 lstm_params.projection_weights()->quantization_info());
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001137
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001138 TensorInfo projection_mm_out_info{mm_out_info};
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001139 projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
1140
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001141 ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed,
1142 &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001143 &projection_outstage_info));
1144
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001145 if (projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001146 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001147 ARM_COMPUTE_RETURN_ON_ERROR(
1148 CLQLSTMLayer::TensorCopyKernel::validate(*output_state_in, projection_outstage_info));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001149 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001150
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001151 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(output_state_out, output_state_out, output_state_out,
1152 ConvertPolicy::SATURATE));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001153
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001154 if (projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001155 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001156 ARM_COMPUTE_RETURN_ON_ERROR(
1157 CLQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001158 }
1159
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001160 int8_t quantized_projection_clip{0};
1161 if (lstm_params.projection_clip() > 0.0f)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001162 {
1163 quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
1164 }
1165
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001166 if (quantized_projection_clip > 0)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001167 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001168 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(
1169 output_state_out, nullptr,
1170 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1171 -quantized_projection_clip, quantized_projection_clip)));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001172 }
1173 }
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001174 else
1175 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001176 if (projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001177 {
1178 ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
1179 }
1180 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001181
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001182 if (cell_state_out->total_size() > 0)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001183 {
1184 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
1185 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
1186 }
1187
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001188 if (output_state_out->total_size() > 0)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001189 {
1190 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
1191 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
1192 }
1193
Sheri Zhang7e20e292021-02-02 11:49:34 +00001194 ARM_COMPUTE_RETURN_ON_ERROR(CLCopy::validate(output_state_out, output));
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001195 return Status{};
1196}
1197
1198void CLQLSTMLayer::run()
1199{
1200 prepare();
1201
1202 // Acquire all the temporaries
1203 MemoryGroupResourceScope scope_mg(_memory_group);
1204
1205 // Forget gate.
1206 _mm_input_to_forget.run();
1207 _input_to_forget_outstage.run();
1208
1209 _mm_recurrent_to_forget.run();
1210 _recurrent_to_forget_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001211 _accumulate_input_recurrent_forget.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001212
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001213 if (_has_peephole)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001214 {
Michalis Spyrou1009e872020-07-27 12:48:34 +01001215 _pixelwise_mul_cell_to_forget.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001216 _cell_to_forget_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001217 _accumulate_cell_forget.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001218 }
1219
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001220 if (_has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +01001221 {
1222 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Forget));
1223 }
1224
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001225 _forget_gate_sigmoid.run();
1226
1227 // Modulation gate.
1228 _mm_input_to_cell.run();
1229 _input_to_cell_outstage.run();
1230
1231 _mm_recurrent_to_cell.run();
1232 _recurrent_to_cell_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001233 _accumulate_input_recurrent_modulation.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001234
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001235 if (_has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +01001236 {
1237 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Cell));
1238 }
1239
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001240 _cell_gate_tanh.run();
1241
1242 // Input gate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001243 if (_has_cifg)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001244 {
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001245 _input_gate_sub.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001246 }
1247 else
1248 {
1249 _mm_input_to_input.run();
1250 _input_to_input_outstage.run();
1251 _mm_recurrent_to_input.run();
1252 _recurrent_to_input_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001253 _accumulate_input_recurrent_input.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001254
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001255 if (_has_peephole)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001256 {
Michalis Spyrou1009e872020-07-27 12:48:34 +01001257 _pixelwise_mul_cell_to_input.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001258 _cell_to_input_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001259 _accumulate_cell_input.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001260 }
1261
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001262 if (_has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +01001263 {
1264 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input));
1265 }
1266
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001267 _input_gate_sigmoid.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001268 }
1269
1270 // Cell.
Michalis Spyrou1009e872020-07-27 12:48:34 +01001271 _pixelwise_mul_forget_cell.run();
1272 _pixelwise_mul_input_cell.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001273 _add_forget_cell.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001274 if (_has_cell_clipping)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001275 {
1276 _cell_clip.run();
1277 }
1278
1279 // Output gate.
1280 _mm_input_to_output.run();
1281 _input_to_output_outstage.run();
1282 _mm_recurrent_to_output.run();
1283 _recurrent_to_output_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001284 _accumulate_input_recurrent_output.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001285 if (_has_peephole)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001286 {
Michalis Spyrou1009e872020-07-27 12:48:34 +01001287 _pixelwise_mul_cell_to_output.run();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001288 _cell_to_output_outstage.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001289 _accumulate_cell_to_output.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001290 }
1291
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001292 if (_has_layer_norm)
Sheri Zhang3a353982020-04-21 13:10:24 +01001293 {
1294 CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Output));
1295 }
1296
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001297 _output_gate_sigmoid.run();
1298
1299 // Hidden.
1300 _hidden_tanh.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +01001301 _pixelwise_mul_hidden.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001302 _hidden_outstage.run();
1303
1304 // Projection.
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001305 if (_has_projection)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001306 {
1307 _mm_projection.run();
1308 _projection_outstage.run();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001309
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001310 if (_projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001311 {
1312 _projection_output_to_accumulate_copy.run();
1313 }
1314
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001315 _accumulate_projection.run();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001316
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001317 if (_projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001318 {
1319 _projection_accumulate_to_output_copy.run();
1320 }
1321
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001322 if (_has_projection_clipping)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001323 {
1324 _projection_clip.run();
1325 }
1326 }
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001327 else
1328 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001329 if (_projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001330 {
1331 _hidden_to_output_copy.run();
1332 }
1333 }
Michele Di Giorgiobeb2d452020-05-11 16:17:51 +01001334
1335 // Copy output_state_out to output
Sheri Zhang7e20e292021-02-02 11:49:34 +00001336 _copy_output.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001337}
1338
1339void CLQLSTMLayer::prepare()
1340{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001341 if (!_is_prepared)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001342 {
1343 // Pre-transpose weights to be used in GEMM.
1344 _input_to_forget_weights_transposed.allocator()->allocate();
1345 _input_to_cell_weights_transposed.allocator()->allocate();
1346 _input_to_output_weights_transposed.allocator()->allocate();
1347 _recurrent_to_forget_weights_transposed.allocator()->allocate();
1348 _recurrent_to_cell_weights_transposed.allocator()->allocate();
1349 _recurrent_to_output_weights_transposed.allocator()->allocate();
1350 _transpose_input_to_forget_weights.run();
1351 _transpose_input_to_cell_weights.run();
1352 _transpose_input_to_output_weights.run();
1353 _transpose_recurrent_to_forget_weights.run();
1354 _transpose_recurrent_to_cell_weights.run();
1355 _transpose_recurrent_to_output_weights.run();
1356
1357 // Precompute effective biases
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001358 if (_has_cifg)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001359 {
1360 _ones.map(true);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001361 std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()),
1362 _ones.info()->total_size() / _ones.info()->element_size(), 32767);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001363 _ones.unmap();
1364 }
1365 else
1366 {
1367 _input_to_input_eff_bias.allocator()->allocate();
1368 _recurrent_to_input_eff_bias.allocator()->allocate();
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001369
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001370 ITensorPack input_to_input_red_pack = {{ACL_SRC, _input_to_input_weights},
1371 {ACL_DST, &_input_to_input_eff_bias}};
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001372 CLScheduler::get().enqueue_op(*_input_to_input_reduction, input_to_input_red_pack, false);
1373
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001374 ITensorPack rec_to_input_red_pack = {{ACL_SRC, _recurrent_to_input_weights},
1375 {ACL_DST, &_recurrent_to_input_eff_bias}};
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001376 CLScheduler::get().enqueue_op(*_recurrent_to_input_reduction, rec_to_input_red_pack, false);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001377
1378 _input_to_input_weights_transposed.allocator()->allocate();
1379 _recurrent_to_input_weights_transposed.allocator()->allocate();
1380 _transpose_input_to_input_weights.run();
1381 _transpose_recurrent_to_input_weights.run();
1382 _input_to_input_weights->mark_as_unused();
1383 _recurrent_to_input_weights->mark_as_unused();
1384 }
1385 _input_to_forget_eff_bias.allocator()->allocate();
1386 _recurrent_to_forget_eff_bias.allocator()->allocate();
1387 _input_to_cell_eff_bias.allocator()->allocate();
1388 _recurrent_to_cell_eff_bias.allocator()->allocate();
1389 _input_to_output_eff_bias.allocator()->allocate();
1390 _recurrent_to_output_eff_bias.allocator()->allocate();
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001391
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001392 ITensorPack input_to_forget_red_pack = {{ACL_SRC, _input_to_forget_weights},
1393 {ACL_DST, &_input_to_forget_eff_bias}};
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001394 CLScheduler::get().enqueue_op(*_input_to_forget_reduction, input_to_forget_red_pack, false);
1395
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001396 ITensorPack rec_to_forget_red_pack = {{ACL_SRC, _recurrent_to_forget_weights},
1397 {ACL_DST, &_recurrent_to_forget_eff_bias}};
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001398 CLScheduler::get().enqueue_op(*_recurrent_to_forget_reduction, rec_to_forget_red_pack, false);
1399
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001400 ITensorPack input_to_cell_red_pack = {{ACL_SRC, _input_to_cell_weights}, {ACL_DST, &_input_to_cell_eff_bias}};
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001401 CLScheduler::get().enqueue_op(*_input_to_cell_reduction, input_to_cell_red_pack, false);
1402
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001403 ITensorPack rec_to_cell_red_pack = {{ACL_SRC, _recurrent_to_cell_weights},
1404 {ACL_DST, &_recurrent_to_cell_eff_bias}};
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001405 CLScheduler::get().enqueue_op(*_recurrent_to_cell_reduction, rec_to_cell_red_pack, false);
1406
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001407 ITensorPack input_to_output_red_pack = {{ACL_SRC, _input_to_output_weights},
1408 {ACL_DST, &_input_to_output_eff_bias}};
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001409 CLScheduler::get().enqueue_op(*_input_to_output_reduction, input_to_output_red_pack, false);
1410
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001411 ITensorPack rec_to_output_red_pack = {{ACL_SRC, _recurrent_to_output_weights},
1412 {ACL_DST, &_recurrent_to_output_eff_bias}};
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001413 CLScheduler::get().enqueue_op(*_recurrent_to_output_reduction, rec_to_output_red_pack, false);
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001414
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001415 if (_has_projection)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001416 {
Michele Di Giorgio11c562c2020-06-10 16:34:50 +01001417 _projection_eff_bias.allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001418 ITensorPack proj_red_pack{{ACL_SRC, _projection_weights}, {ACL_DST, &_projection_eff_bias}};
Georgios Pinitas4a578b92021-06-25 12:13:49 +01001419 CLScheduler::get().enqueue_op(*_projection_reduction, proj_red_pack, false);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001420 if (_projection_bias != nullptr)
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001421 {
Michalis Spyrouad7515d2020-07-24 00:02:23 +01001422 _projection_bias_add.run();
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001423 _projection_bias->mark_as_unused();
1424 }
1425
1426 _projection_weights_transposed.allocator()->allocate();
1427 _transpose_projection_weights.run();
1428 _projection_weights->mark_as_unused();
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001429
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001430 if (!_projection_tensor_copy_required)
Sang-Hoon Parka7431ae2020-05-12 11:13:30 +01001431 {
1432 _hidden_gate.mark_as_unused();
1433 _projection_accumulate_res.mark_as_unused();
1434 }
Michele Di Giorgio1c1b3aa2020-04-02 17:35:42 +01001435 }
1436
1437 // Mark weights as unused
1438 _input_to_forget_weights->mark_as_unused();
1439 _input_to_cell_weights->mark_as_unused();
1440 _input_to_output_weights->mark_as_unused();
1441 _recurrent_to_forget_weights->mark_as_unused();
1442 _recurrent_to_cell_weights->mark_as_unused();
1443 _recurrent_to_output_weights->mark_as_unused();
1444
1445 CLScheduler::get().queue().finish();
1446 _is_prepared = true;
1447 }
1448}
1449
1450} // namespace arm_compute