blob: 058b6027c2259945f134a5a213e659ba91695ea5 [file] [log] [blame]
Michalis Spyroubcedf512018-03-22 14:55:08 +00001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2018-2020 Arm Limited.
Michalis Spyroubcedf512018-03-22 14:55:08 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"
25
Michalis Spyroubcedf512018-03-22 14:55:08 +000026#include "arm_compute/core/Utils.h"
27#include "arm_compute/core/Validate.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000028#include "arm_compute/core/utils/misc/InfoHelpers.h"
Michalis Spyroubcedf512018-03-22 14:55:08 +000029#include "arm_compute/core/utils/misc/ShapeCalculator.h"
30#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
31#include "arm_compute/runtime/CL/CLScheduler.h"
32
Michele Di Giorgio47a89902020-03-09 19:32:33 +000033namespace arm_compute
34{
Michalis Spyroubcedf512018-03-22 14:55:08 +000035using namespace arm_compute::misc::shape_calculator;
Michele Di Giorgio47a89902020-03-09 19:32:33 +000036using namespace arm_compute::utils::info_helpers;
Michalis Spyroubcedf512018-03-22 14:55:08 +000037
38CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Michele Di Giorgio39438b42019-06-04 12:41:45 +010039 : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _accum_input_gate1(), _subtract_input_gate(), _pixelwise_mul_input_gate(), _activation_input_gate(),
40 _fully_connected_forget_gate(), _accum_forget_gate1(), _pixelwise_mul_forget_gate(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(), _transpose_cell_state(),
41 _accum_cell_state1(), _accum_cell_state2(), _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(), _pixelwise_mul_cell_state2(), _fully_connected_output(),
42 _pixelwise_mul_output_state1(), _accum_output1(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state2(), _fully_connected_output_state(), _projection_clip(),
43 _copy_cell_state(), _copy_output(), _concat_scratch_buffer(), _concat_inputs_forget_gate(), _concat_weights_forget_gate(), _concat_weights_input_gate(), _concat_weights_output(),
44 _ones_memset_kernel(), _mean_std_norm_input_gate(), _pixelwise_mul_input_gate_coeff(), _accum_input_gate_bias(), _mean_std_norm_forget_gate(), _pixelwise_mul_forget_gate_coeff(),
45 _accum_forget_gate_bias(), _mean_std_norm_cell_gate(), _pixelwise_mul_cell_gate_coeff(), _accum_cell_gate_bias(), _mean_std_norm_output_gate(), _pixelwise_mul_output_gate_coeff(),
46 _accum_output_gate_bias(), _input_gate_out1(), _input_gate_out2(), _input_gate_out3(), _input_gate_out4(), _forget_gate_out1(), _forget_gate_out2(), _forget_gate_out3(), _forget_gate_out4(),
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +010047 _forget_gate_out5(), _forget_gate_out6(), _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(), _cell_state_out5(), _output1(), _output2(), _output3(), _output4(),
Michele Di Giorgio39438b42019-06-04 12:41:45 +010048 _cell_state_activation(), _output_state1(), _ones(), _input_layer_norm_out1(), _input_layer_norm_out2(), _forget_layer_norm_out1(), _forget_layer_norm_out2(), _cell_layer_norm_out1(),
49 _cell_layer_norm_out2(), _output_layer_norm_out1(), _output_layer_norm_out2(), _run_peephole_opt(false), _run_cifg_opt(false), _perform_cell_clipping(false), _has_projection_weights(false),
50 _perform_projection_clipping(false), _is_prepared(false), _is_layer_norm_lstm(false)
Michalis Spyroubcedf512018-03-22 14:55:08 +000051{
52}
53
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010054void CLLSTMLayer::configure(const ICLTensor *input,
55 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
Michalis Spyroubcedf512018-03-22 14:55:08 +000056 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
57 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
Michalis Spyrou1009e872020-07-27 12:48:34 +010058 const ICLTensor *output_state_in, ICLTensor *cell_state_in,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010059 ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
60 const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +000061{
Manuel Bottini2b84be52020-04-08 10:15:51 +010062 configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
63 recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
64 cell_threshold, projection_threshold);
65}
66
67void CLLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
68 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
69 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
70 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
Michalis Spyrou1009e872020-07-27 12:48:34 +010071 const ICLTensor *output_state_in, ICLTensor *cell_state_in,
Manuel Bottini2b84be52020-04-08 10:15:51 +010072 ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
73 const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
74{
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010075 ARM_COMPUTE_ERROR_ON_NULLPTR(input,
76 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
77 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
78 forget_gate_bias, cell_bias, output_gate_bias,
79 output_state_in, cell_state_in,
80 scratch_buffer, output_state_out, cell_state_out, output);
81
Michele Di Giorgio39438b42019-06-04 12:41:45 +010082 _is_layer_norm_lstm = lstm_params.use_layer_norm();
83
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010084 // Set lstm parameters
Michele Di Giorgio47a89902020-03-09 19:32:33 +000085 LSTMParams<ITensorInfo> lstm_params_info{};
86 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010087
88 // Validate
Michalis Spyroubcedf512018-03-22 14:55:08 +000089 ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(),
90 input_to_cell_weights->info(), input_to_output_weights->info(),
91 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
92 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010093 output_state_in->info(), cell_state_in->info(),
94 scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
95 lstm_params_info, activation_info, cell_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +000096
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010097 const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010098 // Configure block that calculates the forget gate
99 // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +0000100 // We optimize this as follows:
101 // forget_gate = Activation( (input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias
Michalis Spyroubcedf512018-03-22 14:55:08 +0000102 _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000103 _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +0100104 _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000105
John Kesapidescafec8f2019-02-19 15:53:59 +0000106 std::vector<const ICLTensor *> inputs_vector;
107 inputs_vector.emplace_back(input);
108 inputs_vector.emplace_back(output_state_in);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100109 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000110 _forget_gate_out2.allocator()->init(TensorInfo(concat_shape, 1, input->info()->data_type()));
111
Michalis Spyroubcedf512018-03-22 14:55:08 +0000112 _memory_group.manage(&_forget_gate_out2);
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100113 _concat_inputs_forget_gate.configure(compile_context, inputs_vector, &_forget_gate_out2, Window::DimX);
John Kesapidescafec8f2019-02-19 15:53:59 +0000114
115 std::vector<const ICLTensor *> weights_vector;
116
117 weights_vector.emplace_back(input_to_forget_weights);
118 weights_vector.emplace_back(recurrent_to_forget_weights);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100119 const TensorShape weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(weights_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000120 _forget_gate_out6.allocator()->init(TensorInfo(weights_concat_shape, 1, input->info()->data_type()));
121
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100122 _concat_weights_forget_gate.configure(compile_context, weights_vector, &_forget_gate_out6, Window::DimX);
John Kesapidescafec8f2019-02-19 15:53:59 +0000123
Georgios Pinitas42a31722018-07-09 14:35:32 +0100124 _memory_group.manage(&_forget_gate_out5);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100125 _fully_connected_forget_gate.configure(compile_context, &_forget_gate_out2, &_forget_gate_out6, (_is_layer_norm_lstm) ? nullptr : forget_gate_bias, &_forget_gate_out5);
John Kesapidescafec8f2019-02-19 15:53:59 +0000126 _memory_group.manage(&_forget_gate_out1);
127 _memory_group.manage(&_forget_gate_out3);
128 _forget_gate_out6.allocator()->allocate();
129
Georgios Pinitas42a31722018-07-09 14:35:32 +0100130 CLTensor *forget_gate_out = &_forget_gate_out5;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000131 if(lstm_params.has_peephole_opt())
132 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100133 _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000134
135 _run_peephole_opt = true;
136 _memory_group.manage(&_forget_gate_out4);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100137 _pixelwise_mul_forget_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
138 _accum_forget_gate1.configure(compile_context, &_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000139 _forget_gate_out4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000140 _forget_gate_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000141 forget_gate_out = &_forget_gate_out3;
142 }
143 else
144 {
145 _forget_gate_out3.allocator()->allocate();
146 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100147 if(_is_layer_norm_lstm)
148 {
149 _forget_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
150 _forget_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
151 _memory_group.manage(&_forget_layer_norm_out1);
152 _memory_group.manage(&_forget_layer_norm_out2);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100153 _mean_std_norm_forget_gate.configure(compile_context, forget_gate_out);
154 _pixelwise_mul_forget_gate_coeff.configure(compile_context, forget_gate_out, lstm_params.forget_layer_norm_weights(), &_forget_layer_norm_out1, 1, ConvertPolicy::SATURATE,
155 RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100156 // forget_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
157 forget_gate_out->allocator()->allocate();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100158 _accum_forget_gate_bias.configure(compile_context, &_forget_layer_norm_out1, forget_gate_bias, &_forget_layer_norm_out2, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100159 _forget_layer_norm_out1.allocator()->allocate();
160 forget_gate_out = &_forget_layer_norm_out2;
161 }
Manuel Bottini2b84be52020-04-08 10:15:51 +0100162 _activation_forget_gate.configure(compile_context, forget_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000163
Michalis Spyroubcedf512018-03-22 14:55:08 +0000164 // Configure block that calculates the input gate
Georgios Pinitas42a31722018-07-09 14:35:32 +0100165 // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Michalis Spyroubcedf512018-03-22 14:55:08 +0000166 // input_gate = 1 - forget_gate, with CIFG
John Kesapidescafec8f2019-02-19 15:53:59 +0000167 // We optimize this as follows:
168 // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100169 _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000170 CLTensor *input_gate_out = &_input_gate_out1;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000171 if(lstm_params.has_cifg_opt())
172 {
173 _memory_group.manage(&_input_gate_out1);
174 _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Manuel Bottini2b84be52020-04-08 10:15:51 +0100175 _ones_memset_kernel.configure(compile_context, &_ones, PixelValue(1, _ones.info()->data_type()));
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100176 _subtract_input_gate.configure(compile_context, &_ones, forget_gate_out, &_input_gate_out1, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000177 _ones.allocator()->allocate();
178 _run_cifg_opt = true;
179 }
180 else
181 {
Michalis Spyroubcedf512018-03-22 14:55:08 +0000182 _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +0100183 _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000184
185 std::vector<const ICLTensor *> lstm_weights;
186 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
187 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100188 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000189 _input_gate_out2.allocator()->init(TensorInfo(lstm_weights_concat_shape, 1, input->info()->data_type()));
190
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100191 _concat_weights_input_gate.configure(compile_context, lstm_weights, &_input_gate_out2, Window::DimX);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000192
193 _memory_group.manage(&_input_gate_out1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000194
Michalis Spyroubcedf512018-03-22 14:55:08 +0000195 _memory_group.manage(&_input_gate_out3);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100196 _fully_connected_input_gate.configure(compile_context, &_forget_gate_out2, &_input_gate_out2, (_is_layer_norm_lstm) ? nullptr : lstm_params.input_gate_bias(), &_input_gate_out3);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000197 _input_gate_out2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000198
199 input_gate_out = &_input_gate_out3;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100200 if(_run_peephole_opt)
201 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000202 _memory_group.manage(&_input_gate_out4);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100203 _pixelwise_mul_input_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
204 _accum_input_gate1.configure(compile_context, &_input_gate_out3, &_input_gate_out4, &_input_gate_out1, ConvertPolicy::SATURATE);
John Kesapidescafec8f2019-02-19 15:53:59 +0000205 _input_gate_out3.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000206 _input_gate_out4.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000207 input_gate_out = &_input_gate_out1;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100208 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000209 else
210 {
211 _input_gate_out1.allocator()->allocate();
212 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100213
214 if(_is_layer_norm_lstm)
215 {
216 _input_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
217 _input_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
218 _memory_group.manage(&_input_layer_norm_out1);
219 _memory_group.manage(&_input_layer_norm_out2);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100220 _mean_std_norm_input_gate.configure(compile_context, input_gate_out);
221 _pixelwise_mul_input_gate_coeff.configure(compile_context, input_gate_out, lstm_params.input_layer_norm_weights(), &_input_layer_norm_out1, 1, ConvertPolicy::SATURATE,
222 RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100223 // input_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
224 input_gate_out->allocator()->allocate();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100225 _accum_input_gate_bias.configure(compile_context, &_input_layer_norm_out1, lstm_params.input_gate_bias(), &_input_layer_norm_out2, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100226 _input_layer_norm_out1.allocator()->allocate();
227 input_gate_out = &_input_layer_norm_out2;
228 }
Manuel Bottini2b84be52020-04-08 10:15:51 +0100229 _activation_input_gate.configure(compile_context, input_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000230 }
231
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100232 // Configure block that calculates the cell state
233 // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000234 TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
235 _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
236 _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
237 _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
238 _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
239 _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
240
Michalis Spyroubcedf512018-03-22 14:55:08 +0000241 _memory_group.manage(&_cell_state_out1);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100242 _fully_connected_cell_state.configure(compile_context, input, input_to_cell_weights, (_is_layer_norm_lstm) ? nullptr : cell_bias, &_cell_state_out1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000243 _memory_group.manage(&_cell_state_out2);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100244 _transpose_cell_state.configure(compile_context, recurrent_to_cell_weights, &_cell_state_out2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000245 _memory_group.manage(&_cell_state_out3);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100246 _gemm_cell_state1.configure(compile_context, output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000247 _cell_state_out2.allocator()->allocate();
248 _memory_group.manage(&_cell_state_out4);
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100249 _accum_cell_state1.configure(compile_context, &_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100250 CLTensor *cell_state_out_ptr = &_cell_state_out4;
251 if(_is_layer_norm_lstm)
252 {
253 _cell_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
254 _cell_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
255 _memory_group.manage(&_cell_layer_norm_out1);
256 _memory_group.manage(&_cell_layer_norm_out2);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100257 _mean_std_norm_cell_gate.configure(compile_context, cell_state_out_ptr);
258 _pixelwise_mul_cell_gate_coeff.configure(compile_context, cell_state_out_ptr, lstm_params.cell_layer_norm_weights(), &_cell_layer_norm_out1, 1, ConvertPolicy::SATURATE,
259 RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100260 // cell_state_out_ptr is going to be reassigned, so allocate the tensor that it was assigned to before
261 cell_state_out_ptr->allocator()->allocate();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100262 _accum_cell_gate_bias.configure(compile_context, &_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100263 _cell_layer_norm_out1.allocator()->allocate();
264 cell_state_out_ptr = &_cell_layer_norm_out2;
265 }
Manuel Bottini2b84be52020-04-08 10:15:51 +0100266 _activation_cell_state.configure(compile_context, cell_state_out_ptr, nullptr, activation_info);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000267 _memory_group.manage(&_cell_state_out5);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100268 _pixelwise_mul_cell_state1.configure(compile_context, cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100269 cell_state_out_ptr->allocator()->allocate();
Manuel Bottini2b84be52020-04-08 10:15:51 +0100270 _pixelwise_mul_cell_state2.configure(compile_context, forget_gate_out, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100271 _accum_cell_state2.configure(compile_context, &_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000272 _cell_state_out3.allocator()->allocate();
273 _cell_state_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000274 // Perform clipping
275 if(cell_threshold != 0.f)
276 {
277 _perform_cell_clipping = true;
Manuel Bottini2b84be52020-04-08 10:15:51 +0100278 _cell_clip.configure(compile_context, &_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000279 }
280
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100281 // Configure block that calculates the output
282 // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +0000283 // We optimize this as follows:
284 // output_state_out = Activation( (input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000285 _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000286 _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
287 std::vector<const ICLTensor *> in_out_weights;
288 in_out_weights.emplace_back(input_to_output_weights);
289 in_out_weights.emplace_back(recurrent_to_output_weights);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100290 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000291 _output2.allocator()->init(TensorInfo(in_out_weights_concat_shape, 1, input->info()->data_type()));
292
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100293 _concat_weights_output.configure(compile_context, in_out_weights, &_output2, Window::DimX);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000294
Michalis Spyroubcedf512018-03-22 14:55:08 +0000295 _memory_group.manage(&_output1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000296 _memory_group.manage(&_output4);
297
Manuel Bottini2b84be52020-04-08 10:15:51 +0100298 _fully_connected_output.configure(compile_context, &_forget_gate_out2, &_output2, (_is_layer_norm_lstm) ? nullptr : output_gate_bias, &_output4);
John Kesapidescafec8f2019-02-19 15:53:59 +0000299
Michalis Spyroubcedf512018-03-22 14:55:08 +0000300 _output2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000301 _forget_gate_out2.allocator()->allocate();
302
303 CLTensor *output_gate_out = &_output4;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000304 if(lstm_params.has_peephole_opt())
305 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000306 _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000307
John Kesapidescafec8f2019-02-19 15:53:59 +0000308 _memory_group.manage(&_output3);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100309 _pixelwise_mul_output_state1.configure(compile_context, &_cell_state_out1, lstm_params.cell_to_output_weights(), &_output3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
310 _accum_output1.configure(compile_context, &_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
John Kesapidescafec8f2019-02-19 15:53:59 +0000311 _output4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000312 output_gate_out = &_output1;
313
314 // Allocate intermediate buffers
John Kesapidescafec8f2019-02-19 15:53:59 +0000315 _output3.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000316 }
317 else
318 {
319 _output1.allocator()->allocate();
320 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100321 if(_is_layer_norm_lstm)
322 {
323 _output_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
324 _output_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
325 _memory_group.manage(&_output_layer_norm_out1);
326 _memory_group.manage(&_output_layer_norm_out2);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100327 _mean_std_norm_output_gate.configure(compile_context, output_gate_out);
328 _pixelwise_mul_output_gate_coeff.configure(compile_context, output_gate_out, lstm_params.output_layer_norm_weights(), &_output_layer_norm_out1, 1, ConvertPolicy::SATURATE,
329 RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100330 // output_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
331 output_gate_out->allocator()->allocate();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100332 _accum_output_gate_bias.configure(compile_context, &_output_layer_norm_out1, output_gate_bias, &_output_layer_norm_out2, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100333 _output_layer_norm_out1.allocator()->allocate();
334 output_gate_out = &_output_layer_norm_out2;
335 }
Manuel Bottini2b84be52020-04-08 10:15:51 +0100336 _activation_output.configure(compile_context, output_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000337
Michalis Spyroubcedf512018-03-22 14:55:08 +0000338 // Configure block that calculates the output state
339 /** lstm_res = PixelwiseMul(output, Activation(cell_state))
340 *
341 * -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
342 * /
343 * output_state = --
344 * \
345 * -- lstm_res , otherwise
346 */
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100347 ICLTensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
348 _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
349 _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
350
Michalis Spyroubcedf512018-03-22 14:55:08 +0000351 _memory_group.manage(&_cell_state_activation);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100352 _activation_output_state.configure(compile_context, &_cell_state_out1, &_cell_state_activation, activation_info);
353 _pixelwise_mul_output_state2.configure(compile_context, &_cell_state_activation, output_gate_out, output_state_out_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000354 _cell_state_activation.allocator()->allocate();
355
356 if(lstm_params.has_projection())
357 {
358 _has_projection_weights = true;
Manuel Bottini2b84be52020-04-08 10:15:51 +0100359 _fully_connected_output_state.configure(compile_context, output_state_out_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100360 _output_state1.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000361 // Perform clipping
362 if(projection_threshold != 0.f)
363 {
364 _perform_projection_clipping = true;
Manuel Bottini2b84be52020-04-08 10:15:51 +0100365 _projection_clip.configure(compile_context, output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000366 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000367 }
368
369 // Copy cell state and output
Manuel Bottini2b84be52020-04-08 10:15:51 +0100370 _copy_cell_state.configure(compile_context, &_cell_state_out1, cell_state_out);
371 _copy_output.configure(compile_context, output_state_out, output);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000372
373 // Vector for holding the tensors to store in scratch buffer
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100374 std::vector<const ICLTensor *> scratch_inputs;
Georgios Pinitas0cc37c32018-11-14 15:54:26 +0000375 if(!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000376 {
Georgios Pinitas4f859822019-02-06 18:08:04 +0000377 scratch_inputs.emplace_back(input_gate_out);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000378 }
379 scratch_inputs.emplace_back(&_cell_state_out1);
380 scratch_inputs.emplace_back(forget_gate_out);
381 scratch_inputs.emplace_back(output_gate_out);
Manuel Bottini2b84be52020-04-08 10:15:51 +0100382 _concat_scratch_buffer.configure(compile_context, scratch_inputs, scratch_buffer, Window::DimX);
Georgios Pinitas4f859822019-02-06 18:08:04 +0000383 input_gate_out->allocator()->allocate();
Michele Di Giorgiodd2619a2018-11-05 16:46:09 +0000384 _cell_state_out1.allocator()->allocate();
385 forget_gate_out->allocator()->allocate();
386 output_gate_out->allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000387}
388
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100389Status CLLSTMLayer::validate(const ITensorInfo *input,
390 const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
Michalis Spyroubcedf512018-03-22 14:55:08 +0000391 const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
392 const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100393 const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
394 const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
Michalis Spyroubcedf512018-03-22 14:55:08 +0000395 const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
396{
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100397 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input,
398 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
399 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
400 forget_gate_bias, cell_bias, output_gate_bias,
401 output_state_in, cell_state_in,
402 scratch_buffer, output_state_out, cell_state_out, output);
403
404 // Check data types
Michalis Spyroubcedf512018-03-22 14:55:08 +0000405 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100406 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input,
407 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
408 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
409 forget_gate_bias, cell_bias, output_gate_bias,
410 output_state_in, cell_state_in,
411 scratch_buffer, output_state_out, cell_state_out, output);
412
413 // Check dimensions
Georgios Pinitas42447c12018-07-16 17:01:20 +0100414 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
415 ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
416 ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
417 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
418 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
419 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
420 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
421 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
422 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
423 ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100424 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
425 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100426 ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100427 ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
428 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100429 ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100430 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0)
431 && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000432
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100433 const unsigned int num_batches = input->dimension(1);
434 const unsigned int num_cells = input_to_output_weights->dimension(1);
435
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100436 if(lstm_params.use_layer_norm())
437 {
438 // If CIFG is used, input layer normalization weights tensor is omitted
439 if(lstm_params.has_cifg_opt())
440 {
441 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights() != nullptr);
442 }
443 else
444 {
445 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
446 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100447 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_cells);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100448 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
449 }
450
451 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
452 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
453 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
454 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
455 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
Sang-Hoon Parkd5c020a2020-05-06 21:01:19 +0100456 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_cells);
457 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_cells);
458 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_cells);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100459 }
460
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100461 // Check peephole optimization
Michalis Spyroubcedf512018-03-22 14:55:08 +0000462 if(lstm_params.has_peephole_opt())
463 {
Michalis Spyrou09daf4d2018-06-28 17:07:22 +0100464 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100465 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
466 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000467 }
468
469 TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000470 TensorShape num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
471 const TensorInfo units_out_transposed_info = TensorInfo(units_out_transposed_shape, 1, input->data_type());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000472 const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type());
473
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100474 TensorInfo input_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
475 TensorInfo forget_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
476 TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
477 TensorInfo cell_state_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
478
Michalis Spyroubcedf512018-03-22 14:55:08 +0000479 // Validate forget gate
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100480 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, (lstm_params.use_layer_norm()) ? nullptr : forget_gate_bias, &forget_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000481
482 std::vector<const ITensorInfo *> inputs_vector;
483 inputs_vector.emplace_back(input);
484 inputs_vector.emplace_back(output_state_in);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100485 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000486 TensorInfo forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
487
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100488 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector, &forget_gate_concat, Window::DimX));
John Kesapidescafec8f2019-02-19 15:53:59 +0000489
Michalis Spyroubcedf512018-03-22 14:55:08 +0000490 if(lstm_params.has_peephole_opt())
491 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100492 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100493 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000494 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100495 if(lstm_params.use_layer_norm())
496 {
497 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&forget_gate));
Michalis Spyrou1009e872020-07-27 12:48:34 +0100498 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
499 RoundingPolicy::TO_NEAREST_EVEN));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100500 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
501 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100502 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000503
504 // Validate input gate
505 if(!lstm_params.has_cifg_opt())
506 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100507 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
508 lstm_params.recurrent_to_input_weights(),
509 lstm_params.input_gate_bias());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100510 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
511 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100512 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100513
John Kesapidescafec8f2019-02-19 15:53:59 +0000514 std::vector<const ITensorInfo *> lstm_weights;
515 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
516 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100517 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000518 TensorInfo lstm_gate_concat = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100519 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(lstm_weights, &lstm_gate_concat, Window::DimX));
John Kesapidescafec8f2019-02-19 15:53:59 +0000520
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100521 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), (lstm_params.use_layer_norm()) ? nullptr : lstm_params.input_gate_bias(), &input_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000522
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100523 if(lstm_params.has_peephole_opt())
524 {
525 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
526 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
Michalis Spyrou1009e872020-07-27 12:48:34 +0100527 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100528 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
529 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100530
531 if(lstm_params.use_layer_norm())
532 {
533 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&input_gate));
Michalis Spyrou1009e872020-07-27 12:48:34 +0100534 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100535 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
536 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100537 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000538 }
539 else
540 {
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100541 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtraction::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000542 }
543
544 // Validate cell state
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100545 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, (lstm_params.use_layer_norm()) ? nullptr : cell_bias, &cell_state_tmp));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100546 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
547 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100548 if(lstm_params.use_layer_norm())
549 {
550 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
Michalis Spyrou1009e872020-07-27 12:48:34 +0100551 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
552 RoundingPolicy::TO_NEAREST_EVEN));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100553 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
554 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100555 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
Michalis Spyrou1009e872020-07-27 12:48:34 +0100556 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
557 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100558 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000559 if(cell_threshold != 0.f)
560 {
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100561 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
562 cell_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000563 }
564
John Kesapidescafec8f2019-02-19 15:53:59 +0000565 std::vector<const ITensorInfo *> in_out_weights;
566 in_out_weights.emplace_back(input_to_output_weights);
567 in_out_weights.emplace_back(recurrent_to_output_weights);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100568 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000569 TensorInfo in_out_gate_concat = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100570 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(in_out_weights, &in_out_gate_concat, Window::DimX));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100571 // Validate output gate tmp
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100572 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, (lstm_params.use_layer_norm()) ? nullptr : output_gate_bias, &output_gate_tmp));
John Kesapidescafec8f2019-02-19 15:53:59 +0000573
Michalis Spyroubcedf512018-03-22 14:55:08 +0000574 if(lstm_params.has_peephole_opt())
575 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100576 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
577 RoundingPolicy::TO_NEAREST_EVEN));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100578 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000579 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100580 if(lstm_params.use_layer_norm())
581 {
582 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
Michalis Spyrou1009e872020-07-27 12:48:34 +0100583 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
584 RoundingPolicy::TO_NEAREST_EVEN));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100585 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
586 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100587 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000588
589 // Validate output state
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100590 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
Michalis Spyrou1009e872020-07-27 12:48:34 +0100591 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000592 if(lstm_params.has_projection())
593 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100594 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000595 if(projection_threshold != 0.f)
596 {
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100597 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output_state_out, output_state_out,
598 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000599 }
600 }
601
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100602 // Validate copy kernel
603 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(&cell_state_tmp, cell_state_out));
604 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
605
606 // Validate scratch concatenation
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100607 std::vector<const ITensorInfo *> inputs_vector_info_raw;
Georgios Pinitas0cc37c32018-11-14 15:54:26 +0000608 if(!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000609 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100610 inputs_vector_info_raw.push_back(&input_gate);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000611 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100612 inputs_vector_info_raw.push_back(&cell_state_tmp);
613 inputs_vector_info_raw.push_back(&forget_gate);
614 inputs_vector_info_raw.push_back(&output_gate_tmp);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000615
Georgios Pinitas09f24972019-05-17 18:14:40 +0100616 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000617 return Status{};
618}
619
620void CLLSTMLayer::run()
621{
John Kesapidescafec8f2019-02-19 15:53:59 +0000622 prepare();
623
Georgios Pinitasda953f22019-04-02 17:27:03 +0100624 MemoryGroupResourceScope scope_mg(_memory_group);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000625
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100626 _concat_inputs_forget_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000627
Michalis Spyroubcedf512018-03-22 14:55:08 +0000628 _fully_connected_forget_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000629
630 if(_run_peephole_opt)
631 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100632 _pixelwise_mul_forget_gate.run();
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100633 _accum_forget_gate1.run();
634 }
635 if(_is_layer_norm_lstm)
636 {
637 _mean_std_norm_forget_gate.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100638 _pixelwise_mul_forget_gate_coeff.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100639 _accum_forget_gate_bias.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000640 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100641 _activation_forget_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000642
643 if(_run_cifg_opt)
644 {
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100645 CLScheduler::get().enqueue(_ones_memset_kernel);
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100646 _subtract_input_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000647 }
648 else
649 {
650 _fully_connected_input_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000651
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100652 if(_run_peephole_opt)
653 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100654 _pixelwise_mul_input_gate.run();
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100655 _accum_input_gate1.run();
656 }
657
658 if(_is_layer_norm_lstm)
659 {
660 _mean_std_norm_input_gate.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100661 _pixelwise_mul_input_gate_coeff.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100662 _accum_input_gate_bias.run();
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100663 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100664 _activation_input_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000665 }
666
667 _fully_connected_cell_state.run();
Georgios Pinitas42a31722018-07-09 14:35:32 +0100668 CLScheduler::get().enqueue(_transpose_cell_state);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000669 _gemm_cell_state1.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100670 _accum_cell_state1.run();
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100671 if(_is_layer_norm_lstm)
672 {
673 _mean_std_norm_cell_gate.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100674 _pixelwise_mul_cell_gate_coeff.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100675 _accum_cell_gate_bias.run();
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100676 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100677 _activation_cell_state.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100678 _pixelwise_mul_cell_state1.run();
679 _pixelwise_mul_cell_state2.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100680 _accum_cell_state2.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000681
682 if(_perform_cell_clipping)
683 {
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100684 _cell_clip.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000685 }
686
687 _fully_connected_output.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000688
689 if(_run_peephole_opt)
690 {
Michalis Spyrou1009e872020-07-27 12:48:34 +0100691 _pixelwise_mul_output_state1.run();
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100692 _accum_output1.run();
693 }
694 if(_is_layer_norm_lstm)
695 {
696 _mean_std_norm_output_gate.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100697 _pixelwise_mul_output_gate_coeff.run();
Michalis Spyrouad7515d2020-07-24 00:02:23 +0100698 _accum_output_gate_bias.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000699 }
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100700 _activation_output.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000701
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100702 _activation_output_state.run();
Michalis Spyrou1009e872020-07-27 12:48:34 +0100703 _pixelwise_mul_output_state2.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000704
705 if(_has_projection_weights)
706 {
707 _fully_connected_output_state.run();
708 if(_perform_projection_clipping)
709 {
Georgios Pinitasab23dd02020-07-06 14:57:36 +0100710 _projection_clip.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000711 }
712 }
713
714 CLScheduler::get().enqueue(_copy_cell_state);
715 CLScheduler::get().enqueue(_copy_output);
716
717 _concat_scratch_buffer.run();
giuros01164a2722018-11-20 18:34:46 +0000718}
John Kesapidescafec8f2019-02-19 15:53:59 +0000719
720void CLLSTMLayer::prepare()
721{
722 if(!_is_prepared)
723 {
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100724 _concat_weights_forget_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000725 if(!_run_cifg_opt)
726 {
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100727 _concat_weights_input_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000728 }
Michele Di Giorgiof932d2c2020-07-06 11:27:21 +0100729 _concat_weights_output.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000730 _is_prepared = true;
731 }
732}
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000733} // namespace arm_compute