blob: 1a08cdeb06e28945912af19964f2f7022541589f [file] [log] [blame]
Michalis Spyrou25f45a42018-08-08 12:53:05 +01001/*
Pablo Marquez Tello9454cf72022-02-16 11:15:58 +00002 * Copyright (c) 2018-2022 Arm Limited.
Michalis Spyrou25f45a42018-08-08 12:53:05 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/NEON/functions/NELSTMLayer.h"
25
Michalis Spyrou25f45a42018-08-08 12:53:05 +010026#include "arm_compute/core/Utils.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000027#include "arm_compute/core/utils/misc/InfoHelpers.h"
Michalis Spyrou25f45a42018-08-08 12:53:05 +010028#include "arm_compute/core/utils/misc/ShapeCalculator.h"
29#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010030#include "arm_compute/core/Validate.h"
Michalis Spyrou25f45a42018-08-08 12:53:05 +010031#include "arm_compute/runtime/common/LSTMParams.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010032
ramelg01cbbb0382021-09-17 17:36:57 +010033#include "src/common/utils/Log.h"
Michalis Spyrou25f45a42018-08-08 12:53:05 +010034
Michele Di Giorgio47a89902020-03-09 19:32:33 +000035namespace arm_compute
36{
Michalis Spyrou25f45a42018-08-08 12:53:05 +010037using namespace arm_compute::misc::shape_calculator;
Michele Di Giorgio47a89902020-03-09 19:32:33 +000038using namespace arm_compute::utils::info_helpers;
Michalis Spyrou25f45a42018-08-08 12:53:05 +010039
Michalis Spyrouebcebf12020-10-21 00:04:14 +010040NELSTMLayer::~NELSTMLayer() = default;
41
Michalis Spyrou25f45a42018-08-08 12:53:05 +010042NELSTMLayer::NELSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010043 : _memory_group(std::move(memory_manager)),
44 _fully_connected_input_gate(),
45 _accum_input_gate1(),
46 _subtract_input_gate(),
47 _pixelwise_mul_input_gate(),
48 _activation_input_gate(),
49 _fully_connected_forget_gate(),
50 _accum_forget_gate1(),
51 _pixelwise_mul_forget_gate(),
52 _activation_forget_gate(),
53 _fully_connected_cell_state(),
54 _gemm_cell_state1(),
55 _transpose_cell_state(),
56 _accum_cell_state1(),
57 _accum_cell_state2(),
58 _pixelwise_mul_cell_state1(),
59 _activation_cell_state(),
60 _cell_clip(),
61 _pixelwise_mul_cell_state2(),
62 _fully_connected_output(),
63 _pixelwise_mul_output_state1(),
64 _accum_output1(),
65 _activation_output(),
66 _activation_output_state(),
67 _pixelwise_mul_output_state2(),
68 _fully_connected_output_state(),
69 _projection_clip(),
70 _copy_cell_state(),
71 _copy_output(),
72 _concat_scratch_buffer(),
73 _concat_inputs_forget_gate(),
74 _concat_weights_forget_gate(),
75 _concat_weights_input_gate(),
76 _concat_weights_output(),
77 _mean_std_norm_input_gate(),
78 _pixelwise_mul_input_gate_coeff(),
79 _accum_input_gate_bias(),
80 _mean_std_norm_forget_gate(),
81 _pixelwise_mul_forget_gate_coeff(),
82 _accum_forget_gate_bias(),
83 _mean_std_norm_cell_gate(),
84 _pixelwise_mul_cell_gate_coeff(),
85 _accum_cell_gate_bias(),
86 _mean_std_norm_output_gate(),
87 _pixelwise_mul_output_gate_coeff(),
88 _accum_output_gate_bias(),
89 _input_gate_out1(),
90 _input_gate_out2(),
91 _input_gate_out3(),
92 _input_gate_out4(),
93 _forget_gate_out1(),
94 _forget_gate_out2(),
95 _forget_gate_out3(),
96 _forget_gate_out4(),
97 _forget_gate_out5(),
98 _forget_gate_out6(),
99 _cell_state_out1(),
100 _cell_state_out2(),
101 _cell_state_out3(),
102 _cell_state_out4(),
103 _cell_state_out5(),
104 _output1(),
105 _output2(),
106 _output3(),
107 _output4(),
108 _cell_state_activation(),
109 _output_state1(),
110 _ones(),
111 _input_layer_norm_out1(),
112 _input_layer_norm_out2(),
113 _forget_layer_norm_out1(),
114 _forget_layer_norm_out2(),
115 _cell_layer_norm_out1(),
116 _cell_layer_norm_out2(),
117 _output_layer_norm_out1(),
118 _output_layer_norm_out2(),
119 _run_peephole_opt(false),
120 _run_cifg_opt(false),
121 _perform_cell_clipping(false),
122 _has_projection_weights(false),
123 _perform_projection_clipping(false),
124 _is_prepared(false),
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100125 _is_layer_norm_lstm(false)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100126{
127}
128
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100129void NELSTMLayer::configure(const ITensor *input,
130 const ITensor *input_to_forget_weights,
131 const ITensor *input_to_cell_weights,
132 const ITensor *input_to_output_weights,
133 const ITensor *recurrent_to_forget_weights,
134 const ITensor *recurrent_to_cell_weights,
135 const ITensor *recurrent_to_output_weights,
136 const ITensor *forget_gate_bias,
137 const ITensor *cell_bias,
138 const ITensor *output_gate_bias,
139 const ITensor *output_state_in,
140 const ITensor *cell_state_in,
141 ITensor *scratch_buffer,
142 ITensor *output_state_out,
143 ITensor *cell_state_out,
144 ITensor *output,
145 const LSTMParams<ITensor> &lstm_params,
146 const ActivationLayerInfo &activation_info,
147 float cell_threshold,
148 float projection_threshold)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100149{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100150 ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100151 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100152 forget_gate_bias, cell_bias, output_gate_bias, output_state_in, cell_state_in,
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100153 scratch_buffer, output_state_out, cell_state_out, output);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100154 ARM_COMPUTE_LOG_PARAMS(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
ramelg01cbbb0382021-09-17 17:36:57 +0100155 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100156 forget_gate_bias, cell_bias, output_gate_bias, output_state_in, cell_state_in,
157 scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
158 cell_threshold, projection_threshold);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100159
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100160 _is_layer_norm_lstm = lstm_params.use_layer_norm();
161
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100162 // Set lstm parameters
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000163 LSTMParams<ITensorInfo> lstm_params_info{};
164 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100165
166 // Validate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100167 ARM_COMPUTE_ERROR_THROW_ON(NELSTMLayer::validate(
168 input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
169 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
170 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(), output_state_in->info(),
171 cell_state_in->info(), scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
172 lstm_params_info, activation_info, cell_threshold, projection_threshold));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100173
Georgios Pinitasda953f22019-04-02 17:27:03 +0100174 const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100175
176 // Configure block that calculates the forget gate
177 // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
John Kesapides917959c2019-02-04 12:37:29 +0000178 // We optimize this as follows:
179 // forget_gate = Activation( (input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100180 _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100181 _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
182 _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
183
John Kesapides917959c2019-02-04 12:37:29 +0000184 std::vector<const ITensor *> inputs_vector;
185 inputs_vector.emplace_back(input);
186 inputs_vector.emplace_back(output_state_in);
187
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100188 _memory_group.manage(&_forget_gate_out2);
Georgios Pinitas09f24972019-05-17 18:14:40 +0100189 _concat_inputs_forget_gate.configure(inputs_vector, &_forget_gate_out2, Window::DimX);
John Kesapides917959c2019-02-04 12:37:29 +0000190
191 std::vector<const ITensor *> weights_vector;
192
193 weights_vector.emplace_back(input_to_forget_weights);
194 weights_vector.emplace_back(recurrent_to_forget_weights);
195
Georgios Pinitas09f24972019-05-17 18:14:40 +0100196 _concat_weights_forget_gate.configure(weights_vector, &_forget_gate_out6, Window::DimX);
John Kesapides917959c2019-02-04 12:37:29 +0000197
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100198 _memory_group.manage(&_forget_gate_out5);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100199 _fully_connected_forget_gate.configure(&_forget_gate_out2, &_forget_gate_out6,
200 (_is_layer_norm_lstm) ? nullptr : forget_gate_bias, &_forget_gate_out5);
John Kesapides917959c2019-02-04 12:37:29 +0000201 _memory_group.manage(&_forget_gate_out1);
202 _memory_group.manage(&_forget_gate_out3);
203 _forget_gate_out6.allocator()->allocate();
204
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100205 Tensor *forget_gate_out = &_forget_gate_out5;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100206 if (lstm_params.has_peephole_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100207 {
208 _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
209
210 _run_peephole_opt = true;
211 _memory_group.manage(&_forget_gate_out4);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100212 _pixelwise_mul_forget_gate.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1,
213 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
214 _accum_forget_gate1.configure(&_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3,
215 ConvertPolicy::SATURATE);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100216 _forget_gate_out4.allocator()->allocate();
217 _forget_gate_out5.allocator()->allocate();
218 forget_gate_out = &_forget_gate_out3;
219 }
220 else
221 {
222 _forget_gate_out3.allocator()->allocate();
223 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100224 if (_is_layer_norm_lstm)
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100225 {
226 _forget_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
227 _forget_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
228 _memory_group.manage(&_forget_layer_norm_out1);
229 _memory_group.manage(&_forget_layer_norm_out2);
230 _mean_std_norm_forget_gate.configure(forget_gate_out);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100231 _pixelwise_mul_forget_gate_coeff.configure(forget_gate_out, lstm_params.forget_layer_norm_weights(),
232 &_forget_layer_norm_out1, 1, ConvertPolicy::SATURATE,
233 RoundingPolicy::TO_ZERO);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100234 // forget_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
235 forget_gate_out->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100236 _accum_forget_gate_bias.configure(&_forget_layer_norm_out1, forget_gate_bias, &_forget_layer_norm_out2,
237 ConvertPolicy::SATURATE);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100238 _forget_layer_norm_out1.allocator()->allocate();
239 forget_gate_out = &_forget_layer_norm_out2;
240 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100241 _activation_forget_gate.configure(forget_gate_out, nullptr,
242 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100243
244 // Configure block that calculates the input gate
245 // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
246 // input_gate = 1 - forget_gate, with CIFG
John Kesapides917959c2019-02-04 12:37:29 +0000247 // We optimize this as follows:
248 // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100249 _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000250 Tensor *input_gate_out = &_input_gate_out1;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100251 if (lstm_params.has_cifg_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100252 {
253 _memory_group.manage(&_input_gate_out1);
254 _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000255 _subtract_input_gate.configure(&_ones, forget_gate_out, &_input_gate_out1, ConvertPolicy::SATURATE);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100256 _ones.allocator()->allocate();
257 _run_cifg_opt = true;
258 }
259 else
260 {
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100261 _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
262 _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapides917959c2019-02-04 12:37:29 +0000263
264 std::vector<const ITensor *> lstm_weights;
265 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
266 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
267
Georgios Pinitas09f24972019-05-17 18:14:40 +0100268 _concat_weights_input_gate.configure(lstm_weights, &_input_gate_out2, Window::DimX);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100269
270 _memory_group.manage(&_input_gate_out1);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100271 _memory_group.manage(&_input_gate_out4);
John Kesapides917959c2019-02-04 12:37:29 +0000272
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100273 _fully_connected_input_gate.configure(&_forget_gate_out2, &_input_gate_out2,
274 (_is_layer_norm_lstm) ? nullptr : lstm_params.input_gate_bias(),
275 &_input_gate_out3);
John Kesapides917959c2019-02-04 12:37:29 +0000276 _input_gate_out2.allocator()->allocate();
277 input_gate_out = &_input_gate_out3;
278
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100279 if (_run_peephole_opt)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100280 {
John Kesapides917959c2019-02-04 12:37:29 +0000281 _memory_group.manage(&_input_gate_out4);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100282 _pixelwise_mul_input_gate.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out4,
283 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
284 _accum_input_gate1.configure(&_input_gate_out3, &_input_gate_out4, &_input_gate_out1,
285 ConvertPolicy::SATURATE);
John Kesapides917959c2019-02-04 12:37:29 +0000286 _input_gate_out3.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000287 _input_gate_out4.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000288 input_gate_out = &_input_gate_out1;
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100289 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000290 else
291 {
292 _input_gate_out1.allocator()->allocate();
293 }
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100294
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100295 if (_is_layer_norm_lstm)
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100296 {
297 _input_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
298 _input_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
299 _memory_group.manage(&_input_layer_norm_out1);
300 _memory_group.manage(&_input_layer_norm_out2);
301 _mean_std_norm_input_gate.configure(input_gate_out);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100302 _pixelwise_mul_input_gate_coeff.configure(input_gate_out, lstm_params.input_layer_norm_weights(),
303 &_input_layer_norm_out1, 1, ConvertPolicy::SATURATE,
304 RoundingPolicy::TO_ZERO);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100305 // input_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
306 input_gate_out->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100307 _accum_input_gate_bias.configure(&_input_layer_norm_out1, lstm_params.input_gate_bias(),
308 &_input_layer_norm_out2, ConvertPolicy::SATURATE);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100309 _input_layer_norm_out1.allocator()->allocate();
310 input_gate_out = &_input_layer_norm_out2;
311 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100312 _activation_input_gate.configure(input_gate_out, nullptr,
313 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100314 }
315
316 // Configure block that calculates the cell state
317 // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
318 TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
319 _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
320 _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
321 _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
322 _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
323 _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
324
325 _memory_group.manage(&_cell_state_out1);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100326 _fully_connected_cell_state.configure(input, input_to_cell_weights, (_is_layer_norm_lstm) ? nullptr : cell_bias,
327 &_cell_state_out1);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100328 _memory_group.manage(&_cell_state_out2);
329 _transpose_cell_state.configure(recurrent_to_cell_weights, &_cell_state_out2);
330 _memory_group.manage(&_cell_state_out3);
331 _gemm_cell_state1.configure(output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
332 _cell_state_out2.allocator()->allocate();
333 _memory_group.manage(&_cell_state_out4);
334 _accum_cell_state1.configure(&_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100335 Tensor *cell_state_out_ptr = &_cell_state_out4;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100336 if (_is_layer_norm_lstm)
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100337 {
338 _cell_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
339 _cell_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
340 _memory_group.manage(&_cell_layer_norm_out1);
341 _memory_group.manage(&_cell_layer_norm_out2);
342 _mean_std_norm_cell_gate.configure(cell_state_out_ptr);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100343 _pixelwise_mul_cell_gate_coeff.configure(cell_state_out_ptr, lstm_params.cell_layer_norm_weights(),
344 &_cell_layer_norm_out1, 1, ConvertPolicy::SATURATE,
345 RoundingPolicy::TO_ZERO);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100346 // cell_state_out_ptr is going to be reassigned, so allocate the tensor that it was assigned to before
347 cell_state_out_ptr->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100348 _accum_cell_gate_bias.configure(&_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2,
349 ConvertPolicy::SATURATE);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100350 _cell_layer_norm_out1.allocator()->allocate();
351 cell_state_out_ptr = &_cell_layer_norm_out2;
352 }
353 _activation_cell_state.configure(cell_state_out_ptr, nullptr, activation_info);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100354 _memory_group.manage(&_cell_state_out5);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100355 _pixelwise_mul_cell_state1.configure(cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1,
356 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100357 cell_state_out_ptr->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100358 _pixelwise_mul_cell_state2.configure(forget_gate_out, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE,
359 RoundingPolicy::TO_ZERO);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100360 _accum_cell_state2.configure(&_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
361 _cell_state_out3.allocator()->allocate();
362 _cell_state_out5.allocator()->allocate();
363 // Perform clipping
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100364 if (cell_threshold != 0.f)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100365 {
366 _perform_cell_clipping = true;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100367 _cell_clip.configure(&_cell_state_out1, nullptr,
368 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
369 cell_threshold, -cell_threshold));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100370 }
371
372 // Configure block that calculates the output
373 // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
John Kesapides917959c2019-02-04 12:37:29 +0000374 // We optimize this as follows:
375 // output_state_out = Activation( (input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100376 _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapides917959c2019-02-04 12:37:29 +0000377 _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100378
John Kesapides917959c2019-02-04 12:37:29 +0000379 std::vector<const ITensor *> in_out_weights;
380 in_out_weights.emplace_back(input_to_output_weights);
381 in_out_weights.emplace_back(recurrent_to_output_weights);
382
Georgios Pinitas09f24972019-05-17 18:14:40 +0100383 _concat_weights_output.configure(in_out_weights, &_output2, Window::DimX);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100384 _memory_group.manage(&_output1);
John Kesapides917959c2019-02-04 12:37:29 +0000385 _memory_group.manage(&_output4);
386
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100387 _fully_connected_output.configure(&_forget_gate_out2, &_output2, (_is_layer_norm_lstm) ? nullptr : output_gate_bias,
388 &_output4);
John Kesapides917959c2019-02-04 12:37:29 +0000389
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100390 _output2.allocator()->allocate();
John Kesapides917959c2019-02-04 12:37:29 +0000391 _forget_gate_out2.allocator()->allocate();
392
393 Tensor *output_gate_out = &_output4;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100394 if (lstm_params.has_peephole_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100395 {
John Kesapides917959c2019-02-04 12:37:29 +0000396 _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100397
John Kesapides917959c2019-02-04 12:37:29 +0000398 _memory_group.manage(&_output3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100399 _pixelwise_mul_output_state1.configure(&_cell_state_out1, lstm_params.cell_to_output_weights(), &_output3, 1,
400 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100401 _accum_output1.configure(&_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
John Kesapides917959c2019-02-04 12:37:29 +0000402 _output4.allocator()->allocate();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100403 output_gate_out = &_output1;
404
405 // Allocate intermediate buffers
John Kesapides917959c2019-02-04 12:37:29 +0000406 _output3.allocator()->allocate();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100407 }
408 else
409 {
410 _output1.allocator()->allocate();
411 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100412 if (_is_layer_norm_lstm)
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100413 {
414 _output_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
415 _output_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
416 _memory_group.manage(&_output_layer_norm_out1);
417 _memory_group.manage(&_output_layer_norm_out2);
418 _mean_std_norm_output_gate.configure(output_gate_out);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100419 _pixelwise_mul_output_gate_coeff.configure(output_gate_out, lstm_params.output_layer_norm_weights(),
420 &_output_layer_norm_out1, 1, ConvertPolicy::SATURATE,
421 RoundingPolicy::TO_ZERO);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100422 // output_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
423 output_gate_out->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100424 _accum_output_gate_bias.configure(&_output_layer_norm_out1, output_gate_bias, &_output_layer_norm_out2,
425 ConvertPolicy::SATURATE);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100426 _output_layer_norm_out1.allocator()->allocate();
427 output_gate_out = &_output_layer_norm_out2;
428 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100429 _activation_output.configure(output_gate_out, nullptr,
430 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100431
432 // Configure block that calculates the output state
433 /** lstm_res = PixelwiseMul(output, Activation(cell_state))
434 *
435 * -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
436 * /
437 * output_state = --
438 * \
439 * -- lstm_res , otherwise
440 */
441 ITensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
442 _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
443 _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
444
445 _memory_group.manage(&_cell_state_activation);
446 _activation_output_state.configure(&_cell_state_out1, &_cell_state_activation, activation_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100447 _pixelwise_mul_output_state2.configure(&_cell_state_activation, output_gate_out, output_state_out_tmp, 1,
448 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100449 _cell_state_activation.allocator()->allocate();
Georgios Pinitas13a20802019-01-16 18:21:08 +0000450 output_gate_out->allocator()->allocate();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100451
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100452 if (lstm_params.has_projection())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100453 {
454 _has_projection_weights = true;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100455 _fully_connected_output_state.configure(output_state_out_tmp, lstm_params.projection_weights(),
456 lstm_params.projection_bias(), output_state_out);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100457 _output_state1.allocator()->allocate();
458 // Perform clipping
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100459 if (projection_threshold != 0.f)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100460 {
461 _perform_projection_clipping = true;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100462 _projection_clip.configure(output_state_out, nullptr,
463 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
464 -projection_threshold, projection_threshold));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100465 }
466 }
467
468 // Copy cell state and output
469 _copy_cell_state.configure(&_cell_state_out1, cell_state_out);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100470 _copy_output.configure(output_state_out, output);
471
472 // Vector for holding the tensors to store in scratch buffer
Georgios Pinitas4667ddd2020-07-13 21:21:33 +0100473 std::vector<const ITensor *> scratch_inputs;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100474 if (!lstm_params.has_cifg_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100475 {
Georgios Pinitas4f859822019-02-06 18:08:04 +0000476 scratch_inputs.emplace_back(input_gate_out);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100477 }
478 scratch_inputs.emplace_back(&_cell_state_out1);
479 scratch_inputs.emplace_back(forget_gate_out);
480 scratch_inputs.emplace_back(output_gate_out);
Georgios Pinitas09f24972019-05-17 18:14:40 +0100481 _concat_scratch_buffer.configure(scratch_inputs, scratch_buffer, Window::DimX);
Georgios Pinitas4f859822019-02-06 18:08:04 +0000482 input_gate_out->allocator()->allocate();
483 _cell_state_out1.allocator()->allocate();
484 forget_gate_out->allocator()->allocate();
485 output_gate_out->allocator()->allocate();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100486}
487
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100488Status NELSTMLayer::validate(const ITensorInfo *input,
489 const ITensorInfo *input_to_forget_weights,
490 const ITensorInfo *input_to_cell_weights,
491 const ITensorInfo *input_to_output_weights,
492 const ITensorInfo *recurrent_to_forget_weights,
493 const ITensorInfo *recurrent_to_cell_weights,
494 const ITensorInfo *recurrent_to_output_weights,
495 const ITensorInfo *forget_gate_bias,
496 const ITensorInfo *cell_bias,
497 const ITensorInfo *output_gate_bias,
498 const ITensorInfo *output_state_in,
499 const ITensorInfo *cell_state_in,
500 const ITensorInfo *scratch_buffer,
501 const ITensorInfo *output_state_out,
502 const ITensorInfo *cell_state_out,
503 const ITensorInfo *output,
504 const LSTMParams<ITensorInfo> &lstm_params,
505 const ActivationLayerInfo &activation_info,
506 float cell_threshold,
507 float projection_threshold)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100508{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100509 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(
510 input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights,
511 recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
512 output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100513
514 // Check data types
515 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100516 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(
517 input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights,
518 recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
519 output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100520
521 // Check dimensions
522 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
523 ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
524 ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
525 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
526 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
527 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
528 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
529 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
530 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
531 ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
532 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
533 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
534 ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
535 ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
536 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
537 ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100538 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0) &&
539 cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100540
541 const unsigned int num_batches = input->dimension(1);
542 const unsigned int num_cells = input_to_output_weights->dimension(1);
543
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100544 if (lstm_params.use_layer_norm())
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100545 {
546 // If CIFG is used, input layer normalization weights tensor is omitted
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100547 if (lstm_params.has_cifg_opt())
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100548 {
549 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights() != nullptr);
550 }
551 else
552 {
553 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
554 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100555 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_cells);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100556 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
557 }
558
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100559 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(),
560 lstm_params.cell_layer_norm_weights(),
561 lstm_params.output_layer_norm_weights());
562 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.forget_layer_norm_weights(),
563 lstm_params.cell_layer_norm_weights(),
564 lstm_params.output_layer_norm_weights());
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100565 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
566 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
567 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100568 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_cells);
569 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_cells);
570 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_cells);
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100571 }
572
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100573 // Check peephole optimization
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100574 if (lstm_params.has_peephole_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100575 {
576 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
577 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
578 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
579 }
580
581 TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
582 TensorShape num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
583 const TensorInfo units_out_transposed_info = TensorInfo(units_out_transposed_shape, 1, input->data_type());
584 const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type());
585
586 TensorInfo input_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
587 TensorInfo forget_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
588 TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
589 TensorInfo cell_state_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
590
John Kesapides917959c2019-02-04 12:37:29 +0000591 std::vector<const ITensorInfo *> inputs_vector;
592 inputs_vector.emplace_back(input);
593 inputs_vector.emplace_back(output_state_in);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100594 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
Georgios Pinitas09f24972019-05-17 18:14:40 +0100595 TensorInfo forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
596 ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(inputs_vector, &forget_gate_concat, Window::DimX));
John Kesapides917959c2019-02-04 12:37:29 +0000597
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100598 // Validate forget gate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100599 ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(
600 input, input_to_forget_weights, (lstm_params.use_layer_norm()) ? nullptr : forget_gate_bias, &forget_gate));
John Kesapides917959c2019-02-04 12:37:29 +0000601
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100602 if (lstm_params.has_peephole_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100603 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100604 ARM_COMPUTE_RETURN_ON_ERROR(
605 NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1,
606 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
607 ARM_COMPUTE_RETURN_ON_ERROR(
608 NEArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100609 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100610 if (lstm_params.use_layer_norm())
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100611 {
612 ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&forget_gate));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100613 ARM_COMPUTE_RETURN_ON_ERROR(
614 NEPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1,
615 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
616 ARM_COMPUTE_RETURN_ON_ERROR(
617 NEArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100618 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100619 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(
620 &forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100621
622 // Validate input gate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100623 if (!lstm_params.has_cifg_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100624 {
625 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100626 lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100627 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
628 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
629 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
630
John Kesapides917959c2019-02-04 12:37:29 +0000631 std::vector<const ITensorInfo *> lstm_weights;
632 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
633 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100634 TensorShape lstm_weights_concat_shape =
635 arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
636 TensorInfo lstm_gate_concat = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
Georgios Pinitas09f24972019-05-17 18:14:40 +0100637 ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(lstm_weights, &lstm_gate_concat, Window::DimX));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100638 ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(
639 input, lstm_params.input_to_input_weights(),
640 (lstm_params.use_layer_norm()) ? nullptr : lstm_params.input_gate_bias(), &input_gate));
John Kesapides917959c2019-02-04 12:37:29 +0000641
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100642 if (lstm_params.has_peephole_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100643 {
644 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
645 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100646 ARM_COMPUTE_RETURN_ON_ERROR(
647 NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1,
648 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
649 ARM_COMPUTE_RETURN_ON_ERROR(
650 NEArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100651 }
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100652
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100653 if (lstm_params.use_layer_norm())
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100654 {
655 ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&input_gate));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100656 ARM_COMPUTE_RETURN_ON_ERROR(
657 NEPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1,
658 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
659 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(),
660 &input_gate, ConvertPolicy::SATURATE));
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100661 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100662 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(
663 &input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100664 }
665 else
666 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100667 ARM_COMPUTE_RETURN_ON_ERROR(
668 NEArithmeticSubtraction::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100669 }
670
671 // Validate cell state
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100672 ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(
673 input, input_to_cell_weights, (lstm_params.use_layer_norm()) ? nullptr : cell_bias, &cell_state_tmp));
674 ARM_COMPUTE_RETURN_ON_ERROR(
675 NEGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
676 ARM_COMPUTE_RETURN_ON_ERROR(
677 NEArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
678 if (lstm_params.use_layer_norm())
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100679 {
680 ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100681 ARM_COMPUTE_RETURN_ON_ERROR(
682 NEPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp,
683 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
684 ARM_COMPUTE_RETURN_ON_ERROR(
685 NEArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100686 }
Georgios Pinitas1fd2c802020-06-16 17:44:46 +0100687 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100688 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1,
689 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
690 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1,
691 ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
692 ARM_COMPUTE_RETURN_ON_ERROR(
693 NEArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
694 if (cell_threshold != 0.f)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100695 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100696 ARM_COMPUTE_RETURN_ON_ERROR(
697 NEActivationLayer::validate(&cell_state_tmp, nullptr,
698 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
699 cell_threshold, -cell_threshold)));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100700 }
701
702 // Validate output gate tmp
John Kesapides917959c2019-02-04 12:37:29 +0000703 std::vector<const ITensorInfo *> in_out_weights;
704 in_out_weights.emplace_back(input_to_output_weights);
705 in_out_weights.emplace_back(recurrent_to_output_weights);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100706 TensorShape in_out_weights_concat_shape =
707 arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
708 TensorInfo in_out_gate_concat = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
Georgios Pinitas09f24972019-05-17 18:14:40 +0100709 ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(in_out_weights, &in_out_gate_concat, Window::DimX));
John Kesapides917959c2019-02-04 12:37:29 +0000710
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100711 ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(
712 input, input_to_output_weights, (lstm_params.use_layer_norm()) ? nullptr : output_gate_bias, &output_gate_tmp));
John Kesapides917959c2019-02-04 12:37:29 +0000713
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100714 if (lstm_params.has_peephole_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100715 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100716 ARM_COMPUTE_RETURN_ON_ERROR(
717 NEPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp,
718 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
719 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp,
720 ConvertPolicy::SATURATE));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100721 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100722 if (lstm_params.use_layer_norm())
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100723 {
724 ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100725 ARM_COMPUTE_RETURN_ON_ERROR(
726 NEPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(),
727 &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
728 ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp,
729 ConvertPolicy::SATURATE));
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100730 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100731 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(
732 &output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100733
734 // Validate output state
Georgios Pinitas1fd2c802020-06-16 17:44:46 +0100735 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100736 ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(
737 &cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
738 if (lstm_params.has_projection())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100739 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100740 ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(),
741 lstm_params.projection_bias(), output_state_out));
742 if (projection_threshold != 0.f)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100743 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100744 ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(
745 output_state_out, output_state_out,
746 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold,
747 projection_threshold)));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100748 }
749 }
750
751 // Validate copy kernel
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100752 ARM_COMPUTE_RETURN_ON_ERROR(NECopy::validate(&cell_state_tmp, cell_state_out));
753 ARM_COMPUTE_RETURN_ON_ERROR(NECopy::validate(output_state_out, output));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100754
755 // Validate scratch concatenation
Georgios Pinitas4667ddd2020-07-13 21:21:33 +0100756 std::vector<const ITensorInfo *> inputs_vector_info_raw;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100757 if (!lstm_params.has_cifg_opt())
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100758 {
759 inputs_vector_info_raw.push_back(&input_gate);
760 }
761 inputs_vector_info_raw.push_back(&cell_state_tmp);
762 inputs_vector_info_raw.push_back(&forget_gate);
763 inputs_vector_info_raw.push_back(&output_gate_tmp);
764
Georgios Pinitas09f24972019-05-17 18:14:40 +0100765 ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100766 return Status{};
767}
768
769void NELSTMLayer::run()
770{
John Kesapides917959c2019-02-04 12:37:29 +0000771 prepare();
772
Georgios Pinitasda953f22019-04-02 17:27:03 +0100773 MemoryGroupResourceScope scope_mg(_memory_group);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100774
Michalis Spyrou2761c2f2019-03-22 13:06:08 +0000775 _concat_inputs_forget_gate.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100776 _fully_connected_forget_gate.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100777
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100778 if (_run_peephole_opt)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100779 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100780 _pixelwise_mul_forget_gate.run();
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100781 _accum_forget_gate1.run();
782 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100783 if (_is_layer_norm_lstm)
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100784 {
785 _mean_std_norm_forget_gate.run();
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100786 _pixelwise_mul_forget_gate_coeff.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100787 _accum_forget_gate_bias.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100788 }
Georgios Pinitas1fd2c802020-06-16 17:44:46 +0100789 _activation_forget_gate.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100790
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100791 if (_run_cifg_opt)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100792 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100793 if (_ones.info()->data_type() == DataType::F16)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100794 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100795 std::fill_n(reinterpret_cast<half *>(_ones.buffer()),
796 _ones.info()->total_size() / _ones.info()->element_size(), 1);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100797 }
798 else
799 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100800 std::fill_n(reinterpret_cast<float *>(_ones.buffer()),
801 _ones.info()->total_size() / _ones.info()->element_size(), 1);
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100802 }
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100803 _subtract_input_gate.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100804 }
805 else
806 {
807 _fully_connected_input_gate.run();
John Kesapides917959c2019-02-04 12:37:29 +0000808
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100809 if (_run_peephole_opt)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100810 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100811 _pixelwise_mul_input_gate.run();
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100812 _accum_input_gate1.run();
813 }
814
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100815 if (_is_layer_norm_lstm)
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100816 {
817 _mean_std_norm_input_gate.run();
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100818 _pixelwise_mul_input_gate_coeff.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100819 _accum_input_gate_bias.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100820 }
Georgios Pinitas1fd2c802020-06-16 17:44:46 +0100821 _activation_input_gate.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100822 }
823
824 _fully_connected_cell_state.run();
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100825 _transpose_cell_state.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100826 _gemm_cell_state1.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100827 _accum_cell_state1.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100828 if (_is_layer_norm_lstm)
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100829 {
830 _mean_std_norm_cell_gate.run();
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100831 _pixelwise_mul_cell_gate_coeff.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100832 _accum_cell_gate_bias.run();
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100833 }
Pablo Marquez Tello9454cf72022-02-16 11:15:58 +0000834
Georgios Pinitas1fd2c802020-06-16 17:44:46 +0100835 _activation_cell_state.run();
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100836 _pixelwise_mul_cell_state1.run();
837 _pixelwise_mul_cell_state2.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100838 _accum_cell_state2.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100839
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100840 if (_perform_cell_clipping)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100841 {
Georgios Pinitas1fd2c802020-06-16 17:44:46 +0100842 _cell_clip.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100843 }
844
845 _fully_connected_output.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100846 if (_run_peephole_opt)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100847 {
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100848 _pixelwise_mul_output_state1.run();
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100849 _accum_output1.run();
850 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100851 if (_is_layer_norm_lstm)
Michele Di Giorgio0cbfda62019-06-13 17:01:29 +0100852 {
853 _mean_std_norm_output_gate.run();
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100854 _pixelwise_mul_output_gate_coeff.run();
Michalis Spyrou173ba9b2020-06-23 17:25:43 +0100855 _accum_output_gate_bias.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100856 }
Georgios Pinitas1fd2c802020-06-16 17:44:46 +0100857 _activation_output.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100858
Georgios Pinitas1fd2c802020-06-16 17:44:46 +0100859 _activation_output_state.run();
Michalis Spyrou6eb73452020-07-02 17:39:25 +0100860 _pixelwise_mul_output_state2.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100861
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100862 if (_has_projection_weights)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100863 {
864 _fully_connected_output_state.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100865 if (_perform_projection_clipping)
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100866 {
Georgios Pinitas1fd2c802020-06-16 17:44:46 +0100867 _projection_clip.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100868 }
869 }
870
Michalis Spyrouebcebf12020-10-21 00:04:14 +0100871 _copy_cell_state.run();
872 _copy_output.run();
Michalis Spyrou25f45a42018-08-08 12:53:05 +0100873
874 _concat_scratch_buffer.run();
John Kesapides917959c2019-02-04 12:37:29 +0000875}
876
877void NELSTMLayer::prepare()
878{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100879 if (!_is_prepared)
John Kesapides917959c2019-02-04 12:37:29 +0000880 {
John Kesapides917959c2019-02-04 12:37:29 +0000881 _concat_weights_forget_gate.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100882 if (!_run_cifg_opt)
John Kesapides917959c2019-02-04 12:37:29 +0000883 {
884 _concat_weights_input_gate.run();
885 }
886 _concat_weights_output.run();
887 _is_prepared = true;
888 }
889}
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000890} // namespace arm_compute