blob: 3b50234c77cdca52b75244a967d31740c404b81a [file] [log] [blame]
Michalis Spyroubcedf512018-03-22 14:55:08 +00001/*
Sheri Zhang7e20e292021-02-02 11:49:34 +00002 * Copyright (c) 2018-2021 Arm Limited.
Michalis Spyroubcedf512018-03-22 14:55:08 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"
25
Michalis Spyroubcedf512018-03-22 14:55:08 +000026#include "arm_compute/core/Utils.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000027#include "arm_compute/core/utils/misc/InfoHelpers.h"
Michalis Spyroubcedf512018-03-22 14:55:08 +000028#include "arm_compute/core/utils/misc/ShapeCalculator.h"
29#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010030#include "arm_compute/core/Validate.h"
Michalis Spyroubcedf512018-03-22 14:55:08 +000031#include "arm_compute/runtime/CL/CLScheduler.h"
32
ramelg016d891572021-09-29 10:05:09 +010033#include "src/common/utils/Log.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010034#include "src/core/CL/kernels/CLFillBorderKernel.h"
35#include "src/gpu/cl/kernels/ClTransposeKernel.h"
ramelg016d891572021-09-29 10:05:09 +010036
Michele Di Giorgio47a89902020-03-09 19:32:33 +000037namespace arm_compute
38{
Michalis Spyroubcedf512018-03-22 14:55:08 +000039using namespace arm_compute::misc::shape_calculator;
Michele Di Giorgio47a89902020-03-09 19:32:33 +000040using namespace arm_compute::utils::info_helpers;
Michalis Spyroubcedf512018-03-22 14:55:08 +000041
42CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010043 : _memory_group(std::move(memory_manager)),
44 _fully_connected_input_gate(),
45 _accum_input_gate1(),
46 _subtract_input_gate(),
47 _pixelwise_mul_input_gate(),
48 _activation_input_gate(),
49 _fully_connected_forget_gate(),
50 _accum_forget_gate1(),
51 _pixelwise_mul_forget_gate(),
52 _activation_forget_gate(),
53 _fully_connected_cell_state(),
54 _gemm_cell_state1(),
55 _transpose_cell_state(std::make_unique<opencl::kernels::ClTransposeKernel>()),
56 _accum_cell_state1(),
57 _accum_cell_state2(),
58 _pixelwise_mul_cell_state1(),
59 _activation_cell_state(),
60 _cell_clip(),
61 _pixelwise_mul_cell_state2(),
62 _fully_connected_output(),
63 _pixelwise_mul_output_state1(),
64 _accum_output1(),
65 _activation_output(),
66 _activation_output_state(),
67 _pixelwise_mul_output_state2(),
68 _fully_connected_output_state(),
69 _projection_clip(),
70 _copy_cell_state(),
71 _copy_output(),
72 _concat_scratch_buffer(),
73 _concat_inputs_forget_gate(),
74 _concat_weights_forget_gate(),
75 _concat_weights_input_gate(),
76 _concat_weights_output(),
77 _ones_fill(),
78 _mean_std_norm_input_gate(),
79 _pixelwise_mul_input_gate_coeff(),
80 _accum_input_gate_bias(),
81 _mean_std_norm_forget_gate(),
82 _pixelwise_mul_forget_gate_coeff(),
83 _accum_forget_gate_bias(),
84 _mean_std_norm_cell_gate(),
85 _pixelwise_mul_cell_gate_coeff(),
86 _accum_cell_gate_bias(),
87 _mean_std_norm_output_gate(),
88 _pixelwise_mul_output_gate_coeff(),
89 _accum_output_gate_bias(),
90 _input_gate_out1(),
91 _input_gate_out2(),
92 _input_gate_out3(),
93 _input_gate_out4(),
94 _forget_gate_out1(),
95 _forget_gate_out2(),
96 _forget_gate_out3(),
97 _forget_gate_out4(),
98 _forget_gate_out5(),
99 _forget_gate_out6(),
100 _cell_state_out1(),
101 _cell_state_out2(),
102 _cell_state_out3(),
103 _cell_state_out4(),
104 _cell_state_out5(),
105 _output1(),
106 _output2(),
107 _output3(),
108 _output4(),
109 _cell_state_activation(),
110 _output_state1(),
111 _ones(),
112 _input_layer_norm_out1(),
113 _input_layer_norm_out2(),
114 _forget_layer_norm_out1(),
115 _forget_layer_norm_out2(),
116 _cell_layer_norm_out1(),
117 _cell_layer_norm_out2(),
118 _output_layer_norm_out1(),
119 _output_layer_norm_out2(),
120 _run_peephole_opt(false),
121 _run_cifg_opt(false),
122 _perform_cell_clipping(false),
123 _has_projection_weights(false),
124 _perform_projection_clipping(false),
125 _is_prepared(false),
126 _is_layer_norm_lstm(false)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000127{
128}
129
// Destructor, defaulted out-of-line in this .cpp. In this translation unit
// opencl::kernels::ClTransposeKernel (the object the constructor creates with
// std::make_unique) is a complete type, because its kernel header is included
// at the top of this file.
// NOTE(review): presumably the public header only forward-declares
// ClTransposeKernel, which is why this cannot be defaulted in-class — confirm
// before moving it.
CLLSTMLayer::~CLLSTMLayer() = default;
131
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100132void CLLSTMLayer::configure(const ICLTensor *input,
133 const ICLTensor *input_to_forget_weights,
134 const ICLTensor *input_to_cell_weights,
135 const ICLTensor *input_to_output_weights,
136 const ICLTensor *recurrent_to_forget_weights,
137 const ICLTensor *recurrent_to_cell_weights,
138 const ICLTensor *recurrent_to_output_weights,
139 const ICLTensor *forget_gate_bias,
140 const ICLTensor *cell_bias,
141 const ICLTensor *output_gate_bias,
142 const ICLTensor *output_state_in,
143 ICLTensor *cell_state_in,
144 ICLTensor *scratch_buffer,
145 ICLTensor *output_state_out,
146 ICLTensor *cell_state_out,
147 ICLTensor *output,
148 const LSTMParams<ICLTensor> &lstm_params,
149 const ActivationLayerInfo &activation_info,
150 float cell_threshold,
151 float projection_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000152{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100153 configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights,
154 input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
155 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state_in,
156 cell_state_in, scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
Manuel Bottini2b84be52020-04-08 10:15:51 +0100157 cell_threshold, projection_threshold);
158}
159
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100160void CLLSTMLayer::configure(const CLCompileContext &compile_context,
161 const ICLTensor *input,
162 const ICLTensor *input_to_forget_weights,
163 const ICLTensor *input_to_cell_weights,
164 const ICLTensor *input_to_output_weights,
165 const ICLTensor *recurrent_to_forget_weights,
166 const ICLTensor *recurrent_to_cell_weights,
167 const ICLTensor *recurrent_to_output_weights,
168 const ICLTensor *forget_gate_bias,
169 const ICLTensor *cell_bias,
170 const ICLTensor *output_gate_bias,
171 const ICLTensor *output_state_in,
172 ICLTensor *cell_state_in,
173 ICLTensor *scratch_buffer,
174 ICLTensor *output_state_out,
175 ICLTensor *cell_state_out,
176 ICLTensor *output,
177 const LSTMParams<ICLTensor> &lstm_params,
178 const ActivationLayerInfo &activation_info,
179 float cell_threshold,
180 float projection_threshold)
Manuel Bottini2b84be52020-04-08 10:15:51 +0100181{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100182 ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100183 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100184 forget_gate_bias, cell_bias, output_gate_bias, output_state_in, cell_state_in,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100185 scratch_buffer, output_state_out, cell_state_out, output);
186
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100187 ARM_COMPUTE_LOG_PARAMS(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
188 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
189 forget_gate_bias, cell_bias, output_gate_bias, output_state_in, cell_state_in,
190 scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
191 cell_threshold, projection_threshold);
ramelg016d891572021-09-29 10:05:09 +0100192
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100193 _is_layer_norm_lstm = lstm_params.use_layer_norm();
194
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100195 // Set lstm parameters
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000196 LSTMParams<ITensorInfo> lstm_params_info{};
197 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100198
199 // Validate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100200 ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(
201 input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
202 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
203 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(), output_state_in->info(),
204 cell_state_in->info(), scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
205 lstm_params_info, activation_info, cell_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000206
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100207 const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100208 // Configure block that calculates the forget gate
209 // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +0000210 // We optimize this as follows:
211 // forget_gate = Activation( (input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias
Michalis Spyroubcedf512018-03-22 14:55:08 +0000212 _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000213 _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +0100214 _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000215
John Kesapidescafec8f2019-02-19 15:53:59 +0000216 std::vector<const ICLTensor *> inputs_vector;
217 inputs_vector.emplace_back(input);
218 inputs_vector.emplace_back(output_state_in);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100219 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000220 _forget_gate_out2.allocator()->init(TensorInfo(concat_shape, 1, input->info()->data_type()));
221
Michalis Spyroubcedf512018-03-22 14:55:08 +0000222 _memory_group.manage(&_forget_gate_out2);
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100223 _concat_inputs_forget_gate.configure(compile_context, inputs_vector, &_forget_gate_out2, Window::DimX);
John Kesapidescafec8f2019-02-19 15:53:59 +0000224
225 std::vector<const ICLTensor *> weights_vector;
226
227 weights_vector.emplace_back(input_to_forget_weights);
228 weights_vector.emplace_back(recurrent_to_forget_weights);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100229 const TensorShape weights_concat_shape =
230 arm_compute::misc::shape_calculator::calculate_concatenate_shape(weights_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000231 _forget_gate_out6.allocator()->init(TensorInfo(weights_concat_shape, 1, input->info()->data_type()));
232
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100233 _concat_weights_forget_gate.configure(compile_context, weights_vector, &_forget_gate_out6, Window::DimX);
John Kesapidescafec8f2019-02-19 15:53:59 +0000234
Georgios Pinitas42a31722018-07-09 14:35:32 +0100235 _memory_group.manage(&_forget_gate_out5);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100236 _fully_connected_forget_gate.configure(compile_context, &_forget_gate_out2, &_forget_gate_out6,
237 (_is_layer_norm_lstm) ? nullptr : forget_gate_bias, &_forget_gate_out5);
John Kesapidescafec8f2019-02-19 15:53:59 +0000238 _memory_group.manage(&_forget_gate_out1);
239 _memory_group.manage(&_forget_gate_out3);
240 _forget_gate_out6.allocator()->allocate();
241
Georgios Pinitas42a31722018-07-09 14:35:32 +0100242 CLTensor *forget_gate_out = &_forget_gate_out5;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100243 if (lstm_params.has_peephole_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000244 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100245 _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000246
247 _run_peephole_opt = true;
248 _memory_group.manage(&_forget_gate_out4);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100249 _pixelwise_mul_forget_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(),
250 &_forget_gate_out4, 1, ConvertPolicy::SATURATE,
251 RoundingPolicy::TO_NEAREST_EVEN);
252 _accum_forget_gate1.configure(compile_context, &_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3,
253 ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000254 _forget_gate_out4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000255 _forget_gate_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000256 forget_gate_out = &_forget_gate_out3;
257 }
258 else
259 {
260 _forget_gate_out3.allocator()->allocate();
261 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100262 if (_is_layer_norm_lstm)
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100263 {
264 _forget_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
265 _forget_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
266 _memory_group.manage(&_forget_layer_norm_out1);
267 _memory_group.manage(&_forget_layer_norm_out2);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100268 _mean_std_norm_forget_gate.configure(compile_context, forget_gate_out);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100269 _pixelwise_mul_forget_gate_coeff.configure(compile_context, forget_gate_out,
270 lstm_params.forget_layer_norm_weights(), &_forget_layer_norm_out1, 1,
271 ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100272 // forget_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
273 forget_gate_out->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100274 _accum_forget_gate_bias.configure(compile_context, &_forget_layer_norm_out1, forget_gate_bias,
275 &_forget_layer_norm_out2, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100276 _forget_layer_norm_out1.allocator()->allocate();
277 forget_gate_out = &_forget_layer_norm_out2;
278 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100279 _activation_forget_gate.configure(compile_context, forget_gate_out, nullptr,
280 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000281
Michalis Spyroubcedf512018-03-22 14:55:08 +0000282 // Configure block that calculates the input gate
Georgios Pinitas42a31722018-07-09 14:35:32 +0100283 // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Michalis Spyroubcedf512018-03-22 14:55:08 +0000284 // input_gate = 1 - forget_gate, with CIFG
John Kesapidescafec8f2019-02-19 15:53:59 +0000285 // We optimize this as follows:
286 // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100287 _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000288 CLTensor *input_gate_out = &_input_gate_out1;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100289 if (lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000290 {
291 _memory_group.manage(&_input_gate_out1);
292 _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Sheri Zhang7e20e292021-02-02 11:49:34 +0000293 _ones_fill.configure(compile_context, &_ones, PixelValue(1, _ones.info()->data_type()));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100294 _subtract_input_gate.configure(compile_context, &_ones, forget_gate_out, &_input_gate_out1,
295 ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000296 _ones.allocator()->allocate();
297 _run_cifg_opt = true;
298 }
299 else
300 {
Michalis Spyroubcedf512018-03-22 14:55:08 +0000301 _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +0100302 _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000303
304 std::vector<const ICLTensor *> lstm_weights;
305 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
306 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100307 TensorShape lstm_weights_concat_shape =
308 arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000309 _input_gate_out2.allocator()->init(TensorInfo(lstm_weights_concat_shape, 1, input->info()->data_type()));
310
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100311 _concat_weights_input_gate.configure(compile_context, lstm_weights, &_input_gate_out2, Window::DimX);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000312
313 _memory_group.manage(&_input_gate_out1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000314
Michalis Spyroubcedf512018-03-22 14:55:08 +0000315 _memory_group.manage(&_input_gate_out3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100316 _fully_connected_input_gate.configure(compile_context, &_forget_gate_out2, &_input_gate_out2,
317 (_is_layer_norm_lstm) ? nullptr : lstm_params.input_gate_bias(),
318 &_input_gate_out3);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000319 _input_gate_out2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000320
321 input_gate_out = &_input_gate_out3;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100322 if (_run_peephole_opt)
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100323 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000324 _memory_group.manage(&_input_gate_out4);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100325 _pixelwise_mul_input_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(),
326 &_input_gate_out4, 1, ConvertPolicy::SATURATE,
327 RoundingPolicy::TO_NEAREST_EVEN);
328 _accum_input_gate1.configure(compile_context, &_input_gate_out3, &_input_gate_out4, &_input_gate_out1,
329 ConvertPolicy::SATURATE);
John Kesapidescafec8f2019-02-19 15:53:59 +0000330 _input_gate_out3.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000331 _input_gate_out4.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000332 input_gate_out = &_input_gate_out1;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100333 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000334 else
335 {
336 _input_gate_out1.allocator()->allocate();
337 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100338
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100339 if (_is_layer_norm_lstm)
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100340 {
341 _input_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
342 _input_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
343 _memory_group.manage(&_input_layer_norm_out1);
344 _memory_group.manage(&_input_layer_norm_out2);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100345 _mean_std_norm_input_gate.configure(compile_context, input_gate_out);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100346 _pixelwise_mul_input_gate_coeff.configure(compile_context, input_gate_out,
347 lstm_params.input_layer_norm_weights(), &_input_layer_norm_out1,
348 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100349 // input_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
350 input_gate_out->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100351 _accum_input_gate_bias.configure(compile_context, &_input_layer_norm_out1, lstm_params.input_gate_bias(),
352 &_input_layer_norm_out2, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100353 _input_layer_norm_out1.allocator()->allocate();
354 input_gate_out = &_input_layer_norm_out2;
355 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100356 _activation_input_gate.configure(compile_context, input_gate_out, nullptr,
357 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000358 }
359
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100360 // Configure block that calculates the cell state
361 // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000362 TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
363 _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
364 _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
365 _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
366 _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
367 _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
368
Michalis Spyroubcedf512018-03-22 14:55:08 +0000369 _memory_group.manage(&_cell_state_out1);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100370 _fully_connected_cell_state.configure(compile_context, input, input_to_cell_weights,
371 (_is_layer_norm_lstm) ? nullptr : cell_bias, &_cell_state_out1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000372 _memory_group.manage(&_cell_state_out2);
Teresa Charlin27886092021-02-25 20:15:01 +0000373 _transpose_cell_state->configure(compile_context, recurrent_to_cell_weights->info(), _cell_state_out2.info());
374 _recurrent_to_cell_weights = recurrent_to_cell_weights;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000375 _memory_group.manage(&_cell_state_out3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100376 _gemm_cell_state1.configure(compile_context, output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f,
377 0.f);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000378 _cell_state_out2.allocator()->allocate();
379 _memory_group.manage(&_cell_state_out4);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100380 _accum_cell_state1.configure(compile_context, &_cell_state_out1, &_cell_state_out3, &_cell_state_out4,
381 ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100382 CLTensor *cell_state_out_ptr = &_cell_state_out4;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100383 if (_is_layer_norm_lstm)
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100384 {
385 _cell_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
386 _cell_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
387 _memory_group.manage(&_cell_layer_norm_out1);
388 _memory_group.manage(&_cell_layer_norm_out2);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100389 _mean_std_norm_cell_gate.configure(compile_context, cell_state_out_ptr);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100390 _pixelwise_mul_cell_gate_coeff.configure(compile_context, cell_state_out_ptr,
391 lstm_params.cell_layer_norm_weights(), &_cell_layer_norm_out1, 1,
392 ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100393 // cell_state_out_ptr is going to be reassigned, so allocate the tensor that it was assigned to before
394 cell_state_out_ptr->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100395 _accum_cell_gate_bias.configure(compile_context, &_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2,
396 ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100397 _cell_layer_norm_out1.allocator()->allocate();
398 cell_state_out_ptr = &_cell_layer_norm_out2;
399 }
Manuel Bottini2b84be52020-04-08 10:15:51 +0100400 _activation_cell_state.configure(compile_context, cell_state_out_ptr, nullptr, activation_info);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000401 _memory_group.manage(&_cell_state_out5);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100402 _pixelwise_mul_cell_state1.configure(compile_context, cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1,
403 ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100404 cell_state_out_ptr->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100405 _pixelwise_mul_cell_state2.configure(compile_context, forget_gate_out, cell_state_in, &_cell_state_out3, 1,
406 ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
407 _accum_cell_state2.configure(compile_context, &_cell_state_out5, &_cell_state_out3, &_cell_state_out1,
408 ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000409 _cell_state_out3.allocator()->allocate();
410 _cell_state_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000411 // Perform clipping
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100412 if (cell_threshold != 0.f)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000413 {
414 _perform_cell_clipping = true;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100415 _cell_clip.configure(compile_context, &_cell_state_out1, nullptr,
416 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
417 cell_threshold, -cell_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000418 }
419
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100420 // Configure block that calculates the output
421 // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +0000422 // We optimize this as follows:
423 // output_state_out = Activation( (input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000424 _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000425 _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
426 std::vector<const ICLTensor *> in_out_weights;
427 in_out_weights.emplace_back(input_to_output_weights);
428 in_out_weights.emplace_back(recurrent_to_output_weights);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100429 TensorShape in_out_weights_concat_shape =
430 arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000431 _output2.allocator()->init(TensorInfo(in_out_weights_concat_shape, 1, input->info()->data_type()));
432
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100433 _concat_weights_output.configure(compile_context, in_out_weights, &_output2, Window::DimX);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000434
Michalis Spyroubcedf512018-03-22 14:55:08 +0000435 _memory_group.manage(&_output1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000436 _memory_group.manage(&_output4);
437
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100438 _fully_connected_output.configure(compile_context, &_forget_gate_out2, &_output2,
439 (_is_layer_norm_lstm) ? nullptr : output_gate_bias, &_output4);
John Kesapidescafec8f2019-02-19 15:53:59 +0000440
Michalis Spyroubcedf512018-03-22 14:55:08 +0000441 _output2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000442 _forget_gate_out2.allocator()->allocate();
443
444 CLTensor *output_gate_out = &_output4;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100445 if (lstm_params.has_peephole_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000446 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000447 _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000448
John Kesapidescafec8f2019-02-19 15:53:59 +0000449 _memory_group.manage(&_output3);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100450 _pixelwise_mul_output_state1.configure(compile_context, &_cell_state_out1, lstm_params.cell_to_output_weights(),
451 &_output3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100452 _accum_output1.configure(compile_context, &_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
John Kesapidescafec8f2019-02-19 15:53:59 +0000453 _output4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000454 output_gate_out = &_output1;
455
456 // Allocate intermediate buffers
John Kesapidescafec8f2019-02-19 15:53:59 +0000457 _output3.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000458 }
459 else
460 {
461 _output1.allocator()->allocate();
462 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100463 if (_is_layer_norm_lstm)
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100464 {
465 _output_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
466 _output_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
467 _memory_group.manage(&_output_layer_norm_out1);
468 _memory_group.manage(&_output_layer_norm_out2);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100469 _mean_std_norm_output_gate.configure(compile_context, output_gate_out);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100470 _pixelwise_mul_output_gate_coeff.configure(compile_context, output_gate_out,
471 lstm_params.output_layer_norm_weights(), &_output_layer_norm_out1, 1,
472 ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100473 // output_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
474 output_gate_out->allocator()->allocate();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100475 _accum_output_gate_bias.configure(compile_context, &_output_layer_norm_out1, output_gate_bias,
476 &_output_layer_norm_out2, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100477 _output_layer_norm_out1.allocator()->allocate();
478 output_gate_out = &_output_layer_norm_out2;
479 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100480 _activation_output.configure(compile_context, output_gate_out, nullptr,
481 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000482
Michalis Spyroubcedf512018-03-22 14:55:08 +0000483 // Configure block that calculates the output state
484 /** lstm_res = PixelwiseMul(output, Activation(cell_state))
485 *
486 * -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
487 * /
488 * output_state = --
489 * \
490 * -- lstm_res , otherwise
491 */
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100492 ICLTensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
493 _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
494 _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
495
Michalis Spyroubcedf512018-03-22 14:55:08 +0000496 _memory_group.manage(&_cell_state_activation);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100497 _activation_output_state.configure(compile_context, &_cell_state_out1, &_cell_state_activation, activation_info);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100498 _pixelwise_mul_output_state2.configure(compile_context, &_cell_state_activation, output_gate_out,
499 output_state_out_tmp, 1, ConvertPolicy::SATURATE,
500 RoundingPolicy::TO_NEAREST_EVEN);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000501 _cell_state_activation.allocator()->allocate();
502
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100503 if (lstm_params.has_projection())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000504 {
505 _has_projection_weights = true;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100506 _fully_connected_output_state.configure(compile_context, output_state_out_tmp, lstm_params.projection_weights(),
507 lstm_params.projection_bias(), output_state_out);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100508 _output_state1.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000509 // Perform clipping
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100510 if (projection_threshold != 0.f)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000511 {
512 _perform_projection_clipping = true;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100513 _projection_clip.configure(compile_context, output_state_out, nullptr,
514 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
515 -projection_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000516 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000517 }
518
519 // Copy cell state and output
Sheri Zhang7e20e292021-02-02 11:49:34 +0000520 _copy_cell_state.configure(compile_context, &_cell_state_out1, cell_state_out);
521 _copy_output.configure(compile_context, output_state_out, output);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000522
523 // Vector for holding the tensors to store in scratch buffer
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100524 std::vector<const ICLTensor *> scratch_inputs;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100525 if (!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000526 {
Georgios Pinitas4f859822019-02-06 18:08:04 +0000527 scratch_inputs.emplace_back(input_gate_out);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000528 }
529 scratch_inputs.emplace_back(&_cell_state_out1);
530 scratch_inputs.emplace_back(forget_gate_out);
531 scratch_inputs.emplace_back(output_gate_out);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100532 _concat_scratch_buffer.configure(compile_context, scratch_inputs, scratch_buffer, Window::DimX);
Georgios Pinitas4f859822019-02-06 18:08:04 +0000533 input_gate_out->allocator()->allocate();
Michele Di Giorgiodd2619a2018-11-05 16:46:09 +0000534 _cell_state_out1.allocator()->allocate();
535 forget_gate_out->allocator()->allocate();
536 output_gate_out->allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000537}
538
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100539Status CLLSTMLayer::validate(const ITensorInfo *input,
540 const ITensorInfo *input_to_forget_weights,
541 const ITensorInfo *input_to_cell_weights,
542 const ITensorInfo *input_to_output_weights,
543 const ITensorInfo *recurrent_to_forget_weights,
544 const ITensorInfo *recurrent_to_cell_weights,
545 const ITensorInfo *recurrent_to_output_weights,
546 const ITensorInfo *forget_gate_bias,
547 const ITensorInfo *cell_bias,
548 const ITensorInfo *output_gate_bias,
549 const ITensorInfo *output_state_in,
550 const ITensorInfo *cell_state_in,
551 const ITensorInfo *scratch_buffer,
552 const ITensorInfo *output_state_out,
553 const ITensorInfo *cell_state_out,
554 const ITensorInfo *output,
555 const LSTMParams<ITensorInfo> &lstm_params,
556 const ActivationLayerInfo &activation_info,
557 float cell_threshold,
558 float projection_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000559{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100560 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(
561 input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights,
562 recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
563 output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100564
565 // Check data types
Michalis Spyroubcedf512018-03-22 14:55:08 +0000566 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100567 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(
568 input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights,
569 recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
570 output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100571
572 // Check dimensions
Georgios Pinitas42447c12018-07-16 17:01:20 +0100573 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
574 ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
575 ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
576 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
577 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
578 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
579 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
580 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
581 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
582 ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100583 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
584 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100585 ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100586 ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
587 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100588 ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100589 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0) &&
590 cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000591
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100592 const unsigned int num_batches = input->dimension(1);
593 const unsigned int num_cells = input_to_output_weights->dimension(1);
594
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100595 if (lstm_params.use_layer_norm())
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100596 {
597 // If CIFG is used, input layer normalization weights tensor is omitted
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100598 if (lstm_params.has_cifg_opt())
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100599 {
600 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights() != nullptr);
601 }
602 else
603 {
604 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
605 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100606 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_cells);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100607 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
608 }
609
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100610 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(),
611 lstm_params.cell_layer_norm_weights(),
612 lstm_params.output_layer_norm_weights());
613 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.forget_layer_norm_weights(),
614 lstm_params.cell_layer_norm_weights(),
615 lstm_params.output_layer_norm_weights());
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100616 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
617 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
618 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100619 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_cells);
620 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_cells);
621 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_cells);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100622 }
623
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100624 // Check peephole optimization
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100625 if (lstm_params.has_peephole_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000626 {
Michalis Spyrou09daf4d2018-06-28 17:07:22 +0100627 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100628 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
629 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000630 }
631
632 TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000633 TensorShape num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
634 const TensorInfo units_out_transposed_info = TensorInfo(units_out_transposed_shape, 1, input->data_type());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000635 const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type());
636
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100637 TensorInfo input_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
638 TensorInfo forget_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
639 TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
640 TensorInfo cell_state_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
641
Michalis Spyroubcedf512018-03-22 14:55:08 +0000642 // Validate forget gate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100643 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(
644 input, input_to_forget_weights, (lstm_params.use_layer_norm()) ? nullptr : forget_gate_bias, &forget_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000645
646 std::vector<const ITensorInfo *> inputs_vector;
647 inputs_vector.emplace_back(input);
648 inputs_vector.emplace_back(output_state_in);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100649 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000650 TensorInfo forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
651
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100652 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector, &forget_gate_concat, Window::DimX));
John Kesapidescafec8f2019-02-19 15:53:59 +0000653
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100654 if (lstm_params.has_peephole_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000655 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100656 ARM_COMPUTE_RETURN_ON_ERROR(
657 CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1,
658 ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
659 ARM_COMPUTE_RETURN_ON_ERROR(
660 CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000661 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100662 if (lstm_params.use_layer_norm())
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100663 {
664 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&forget_gate));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100665 ARM_COMPUTE_RETURN_ON_ERROR(
666 CLPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1,
667 ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
668 ARM_COMPUTE_RETURN_ON_ERROR(
669 CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100670 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100671 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(
672 &forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000673
674 // Validate input gate
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100675 if (!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000676 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100677 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100678 lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100679 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
680 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100681 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100682
John Kesapidescafec8f2019-02-19 15:53:59 +0000683 std::vector<const ITensorInfo *> lstm_weights;
684 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
685 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100686 TensorShape lstm_weights_concat_shape =
687 arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
688 TensorInfo lstm_gate_concat = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100689 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(lstm_weights, &lstm_gate_concat, Window::DimX));
John Kesapidescafec8f2019-02-19 15:53:59 +0000690
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100691 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(
692 input, lstm_params.input_to_input_weights(),
693 (lstm_params.use_layer_norm()) ? nullptr : lstm_params.input_gate_bias(), &input_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000694
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100695 if (lstm_params.has_peephole_opt())
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100696 {
697 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
698 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100699 ARM_COMPUTE_RETURN_ON_ERROR(
700 CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1,
701 ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
702 ARM_COMPUTE_RETURN_ON_ERROR(
703 CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100704 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100705
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100706 if (lstm_params.use_layer_norm())
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100707 {
708 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&input_gate));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100709 ARM_COMPUTE_RETURN_ON_ERROR(
710 CLPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1,
711 ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
712 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(),
713 &input_gate, ConvertPolicy::SATURATE));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100714 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100715 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(
716 &input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000717 }
718 else
719 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100720 ARM_COMPUTE_RETURN_ON_ERROR(
721 CLArithmeticSubtraction::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000722 }
723
724 // Validate cell state
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100725 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(
726 input, input_to_cell_weights, (lstm_params.use_layer_norm()) ? nullptr : cell_bias, &cell_state_tmp));
727 ARM_COMPUTE_RETURN_ON_ERROR(
728 CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
729 ARM_COMPUTE_RETURN_ON_ERROR(
730 CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
731 if (lstm_params.use_layer_norm())
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100732 {
733 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100734 ARM_COMPUTE_RETURN_ON_ERROR(
735 CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp,
736 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
737 ARM_COMPUTE_RETURN_ON_ERROR(
738 CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100739 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100740 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100741 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(
742 &cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
743 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(
744 &cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
745 ARM_COMPUTE_RETURN_ON_ERROR(
746 CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
747 if (cell_threshold != 0.f)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000748 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100749 ARM_COMPUTE_RETURN_ON_ERROR(
750 CLActivationLayer::validate(&cell_state_tmp, nullptr,
751 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
752 cell_threshold, -cell_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000753 }
754
John Kesapidescafec8f2019-02-19 15:53:59 +0000755 std::vector<const ITensorInfo *> in_out_weights;
756 in_out_weights.emplace_back(input_to_output_weights);
757 in_out_weights.emplace_back(recurrent_to_output_weights);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100758 TensorShape in_out_weights_concat_shape =
759 arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
760 TensorInfo in_out_gate_concat = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100761 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(in_out_weights, &in_out_gate_concat, Window::DimX));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100762 // Validate output gate tmp
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100763 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(
764 input, input_to_output_weights, (lstm_params.use_layer_norm()) ? nullptr : output_gate_bias, &output_gate_tmp));
John Kesapidescafec8f2019-02-19 15:53:59 +0000765
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100766 if (lstm_params.has_peephole_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000767 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100768 ARM_COMPUTE_RETURN_ON_ERROR(
769 CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp,
770 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
771 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp,
772 ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000773 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100774 if (lstm_params.use_layer_norm())
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100775 {
776 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100777 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(
778 &output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
779 RoundingPolicy::TO_NEAREST_EVEN));
780 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp,
781 ConvertPolicy::SATURATE));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100782 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100783 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(
784 &output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000785
786 // Validate output state
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100787 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100788 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp,
789 1, ConvertPolicy::SATURATE,
790 RoundingPolicy::TO_NEAREST_EVEN));
791 if (lstm_params.has_projection())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000792 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100793 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(),
794 lstm_params.projection_bias(), output_state_out));
795 if (projection_threshold != 0.f)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000796 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100797 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(
798 output_state_out, output_state_out,
799 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold,
800 projection_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000801 }
802 }
803
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100804 // Validate copy kernel
Sheri Zhang7e20e292021-02-02 11:49:34 +0000805 ARM_COMPUTE_RETURN_ON_ERROR(CLCopy::validate(&cell_state_tmp, cell_state_out));
806 ARM_COMPUTE_RETURN_ON_ERROR(CLCopy::validate(output_state_out, output));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100807
808 // Validate scratch concatenation
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100809 std::vector<const ITensorInfo *> inputs_vector_info_raw;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100810 if (!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000811 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100812 inputs_vector_info_raw.push_back(&input_gate);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000813 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100814 inputs_vector_info_raw.push_back(&cell_state_tmp);
815 inputs_vector_info_raw.push_back(&forget_gate);
816 inputs_vector_info_raw.push_back(&output_gate_tmp);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000817
Georgios Pinitas09f24972019-05-17 18:14:40 +0100818 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000819 return Status{};
820}
821
822void CLLSTMLayer::run()
823{
John Kesapidescafec8f2019-02-19 15:53:59 +0000824 prepare();
825
Georgios Pinitasda953f22019-04-02 17:27:03 +0100826 MemoryGroupResourceScope scope_mg(_memory_group);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000827
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100828 _concat_inputs_forget_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000829
Michalis Spyroubcedf512018-03-22 14:55:08 +0000830 _fully_connected_forget_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000831
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100832 if (_run_peephole_opt)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000833 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100834 _pixelwise_mul_forget_gate.run();
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100835 _accum_forget_gate1.run();
836 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100837 if (_is_layer_norm_lstm)
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100838 {
839 _mean_std_norm_forget_gate.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100840 _pixelwise_mul_forget_gate_coeff.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100841 _accum_forget_gate_bias.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000842 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100843 _activation_forget_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000844
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100845 if (_run_cifg_opt)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000846 {
Sheri Zhang7e20e292021-02-02 11:49:34 +0000847 _ones_fill.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100848 _subtract_input_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000849 }
850 else
851 {
852 _fully_connected_input_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000853
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100854 if (_run_peephole_opt)
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100855 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100856 _pixelwise_mul_input_gate.run();
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100857 _accum_input_gate1.run();
858 }
859
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100860 if (_is_layer_norm_lstm)
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100861 {
862 _mean_std_norm_input_gate.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100863 _pixelwise_mul_input_gate_coeff.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100864 _accum_input_gate_bias.run();
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100865 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100866 _activation_input_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000867 }
868
869 _fully_connected_cell_state.run();
Teresa Charlin27886092021-02-25 20:15:01 +0000870 ITensorPack pack;
871 pack.add_tensor(TensorType::ACL_SRC, _recurrent_to_cell_weights);
872 pack.add_tensor(TensorType::ACL_DST, &_cell_state_out2);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100873 CLScheduler::get().enqueue_op(*_transpose_cell_state, pack, false);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000874 _gemm_cell_state1.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100875 _accum_cell_state1.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100876 if (_is_layer_norm_lstm)
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100877 {
878 _mean_std_norm_cell_gate.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100879 _pixelwise_mul_cell_gate_coeff.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100880 _accum_cell_gate_bias.run();
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100881 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100882 _activation_cell_state.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100883 _pixelwise_mul_cell_state1.run();
884 _pixelwise_mul_cell_state2.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100885 _accum_cell_state2.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000886
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100887 if (_perform_cell_clipping)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000888 {
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100889 _cell_clip.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000890 }
891
892 _fully_connected_output.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000893
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100894 if (_run_peephole_opt)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000895 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100896 _pixelwise_mul_output_state1.run();
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100897 _accum_output1.run();
898 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100899 if (_is_layer_norm_lstm)
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100900 {
901 _mean_std_norm_output_gate.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100902 _pixelwise_mul_output_gate_coeff.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100903 _accum_output_gate_bias.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000904 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100905 _activation_output.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000906
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100907 _activation_output_state.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100908 _pixelwise_mul_output_state2.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000909
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100910 if (_has_projection_weights)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000911 {
912 _fully_connected_output_state.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100913 if (_perform_projection_clipping)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000914 {
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100915 _projection_clip.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000916 }
917 }
918
Sheri Zhang7e20e292021-02-02 11:49:34 +0000919 _copy_cell_state.run();
920 _copy_output.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000921
922 _concat_scratch_buffer.run();
giuros01164a2722018-11-20 18:34:46 +0000923}
John Kesapidescafec8f2019-02-19 15:53:59 +0000924
925void CLLSTMLayer::prepare()
926{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100927 if (!_is_prepared)
John Kesapidescafec8f2019-02-19 15:53:59 +0000928 {
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100929 _concat_weights_forget_gate.run();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100930 if (!_run_cifg_opt)
John Kesapidescafec8f2019-02-19 15:53:59 +0000931 {
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100932 _concat_weights_input_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000933 }
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100934 _concat_weights_output.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000935 _is_prepared = true;
936 }
937}
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000938} // namespace arm_compute