blob: 3a3917784b50946263fe7f8c082bb6b339289723 [file] [log] [blame]
Michalis Spyroubcedf512018-03-22 14:55:08 +00001/*
Michele Di Giorgio47a89902020-03-09 19:32:33 +00002 * Copyright (c) 2018-2020 ARM Limited.
Michalis Spyroubcedf512018-03-22 14:55:08 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"
25
Michalis Spyroubcedf512018-03-22 14:55:08 +000026#include "arm_compute/core/Utils.h"
27#include "arm_compute/core/Validate.h"
Michele Di Giorgio47a89902020-03-09 19:32:33 +000028#include "arm_compute/core/utils/misc/InfoHelpers.h"
Michalis Spyroubcedf512018-03-22 14:55:08 +000029#include "arm_compute/core/utils/misc/ShapeCalculator.h"
30#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
31#include "arm_compute/runtime/CL/CLScheduler.h"
32
Michele Di Giorgio47a89902020-03-09 19:32:33 +000033namespace arm_compute
34{
Michalis Spyroubcedf512018-03-22 14:55:08 +000035using namespace arm_compute::misc::shape_calculator;
Michele Di Giorgio47a89902020-03-09 19:32:33 +000036using namespace arm_compute::utils::info_helpers;
Michalis Spyroubcedf512018-03-22 14:55:08 +000037
38CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Michele Di Giorgio39438b42019-06-04 12:41:45 +010039 : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _accum_input_gate1(), _subtract_input_gate(), _pixelwise_mul_input_gate(), _activation_input_gate(),
40 _fully_connected_forget_gate(), _accum_forget_gate1(), _pixelwise_mul_forget_gate(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(), _transpose_cell_state(),
41 _accum_cell_state1(), _accum_cell_state2(), _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(), _pixelwise_mul_cell_state2(), _fully_connected_output(),
42 _pixelwise_mul_output_state1(), _accum_output1(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state2(), _fully_connected_output_state(), _projection_clip(),
43 _copy_cell_state(), _copy_output(), _concat_scratch_buffer(), _concat_inputs_forget_gate(), _concat_weights_forget_gate(), _concat_weights_input_gate(), _concat_weights_output(),
44 _ones_memset_kernel(), _mean_std_norm_input_gate(), _pixelwise_mul_input_gate_coeff(), _accum_input_gate_bias(), _mean_std_norm_forget_gate(), _pixelwise_mul_forget_gate_coeff(),
45 _accum_forget_gate_bias(), _mean_std_norm_cell_gate(), _pixelwise_mul_cell_gate_coeff(), _accum_cell_gate_bias(), _mean_std_norm_output_gate(), _pixelwise_mul_output_gate_coeff(),
46 _accum_output_gate_bias(), _input_gate_out1(), _input_gate_out2(), _input_gate_out3(), _input_gate_out4(), _forget_gate_out1(), _forget_gate_out2(), _forget_gate_out3(), _forget_gate_out4(),
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +010047 _forget_gate_out5(), _forget_gate_out6(), _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(), _cell_state_out5(), _output1(), _output2(), _output3(), _output4(),
Michele Di Giorgio39438b42019-06-04 12:41:45 +010048 _cell_state_activation(), _output_state1(), _ones(), _input_layer_norm_out1(), _input_layer_norm_out2(), _forget_layer_norm_out1(), _forget_layer_norm_out2(), _cell_layer_norm_out1(),
49 _cell_layer_norm_out2(), _output_layer_norm_out1(), _output_layer_norm_out2(), _run_peephole_opt(false), _run_cifg_opt(false), _perform_cell_clipping(false), _has_projection_weights(false),
50 _perform_projection_clipping(false), _is_prepared(false), _is_layer_norm_lstm(false)
Michalis Spyroubcedf512018-03-22 14:55:08 +000051{
52}
53
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010054void CLLSTMLayer::configure(const ICLTensor *input,
55 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
Michalis Spyroubcedf512018-03-22 14:55:08 +000056 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
57 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010058 const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
59 ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
60 const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +000061{
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010062 ARM_COMPUTE_ERROR_ON_NULLPTR(input,
63 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
64 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
65 forget_gate_bias, cell_bias, output_gate_bias,
66 output_state_in, cell_state_in,
67 scratch_buffer, output_state_out, cell_state_out, output);
68
Michele Di Giorgio39438b42019-06-04 12:41:45 +010069 _is_layer_norm_lstm = lstm_params.use_layer_norm();
70
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010071 // Set lstm parameters
Michele Di Giorgio47a89902020-03-09 19:32:33 +000072 LSTMParams<ITensorInfo> lstm_params_info{};
73 build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010074
75 // Validate
Michalis Spyroubcedf512018-03-22 14:55:08 +000076 ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(),
77 input_to_cell_weights->info(), input_to_output_weights->info(),
78 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
79 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010080 output_state_in->info(), cell_state_in->info(),
81 scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
82 lstm_params_info, activation_info, cell_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +000083
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010084 const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010085 // Configure block that calculates the forget gate
86 // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +000087 // We optimize this as follows:
88 // forget_gate = Activation( (input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias
Michalis Spyroubcedf512018-03-22 14:55:08 +000089 _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +000090 _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +010091 _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +000092
John Kesapidescafec8f2019-02-19 15:53:59 +000093 std::vector<const ICLTensor *> inputs_vector;
94 inputs_vector.emplace_back(input);
95 inputs_vector.emplace_back(output_state_in);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +010096 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +000097 _forget_gate_out2.allocator()->init(TensorInfo(concat_shape, 1, input->info()->data_type()));
98
Michalis Spyroubcedf512018-03-22 14:55:08 +000099 _memory_group.manage(&_forget_gate_out2);
John Kesapidescafec8f2019-02-19 15:53:59 +0000100 _concat_inputs_forget_gate.configure(input, output_state_in, &_forget_gate_out2);
101
102 std::vector<const ICLTensor *> weights_vector;
103
104 weights_vector.emplace_back(input_to_forget_weights);
105 weights_vector.emplace_back(recurrent_to_forget_weights);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100106 const TensorShape weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(weights_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000107 _forget_gate_out6.allocator()->init(TensorInfo(weights_concat_shape, 1, input->info()->data_type()));
108
109 _concat_weights_forget_gate.configure(input_to_forget_weights, recurrent_to_forget_weights, &_forget_gate_out6);
110
Georgios Pinitas42a31722018-07-09 14:35:32 +0100111 _memory_group.manage(&_forget_gate_out5);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100112 _fully_connected_forget_gate.configure(&_forget_gate_out2, &_forget_gate_out6, (_is_layer_norm_lstm) ? nullptr : forget_gate_bias, &_forget_gate_out5);
John Kesapidescafec8f2019-02-19 15:53:59 +0000113 _memory_group.manage(&_forget_gate_out1);
114 _memory_group.manage(&_forget_gate_out3);
115 _forget_gate_out6.allocator()->allocate();
116
Georgios Pinitas42a31722018-07-09 14:35:32 +0100117 CLTensor *forget_gate_out = &_forget_gate_out5;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000118 if(lstm_params.has_peephole_opt())
119 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100120 _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000121
122 _run_peephole_opt = true;
123 _memory_group.manage(&_forget_gate_out4);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100124 _pixelwise_mul_forget_gate.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100125 _accum_forget_gate1.configure(&_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000126 _forget_gate_out4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000127 _forget_gate_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000128 forget_gate_out = &_forget_gate_out3;
129 }
130 else
131 {
132 _forget_gate_out3.allocator()->allocate();
133 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100134 if(_is_layer_norm_lstm)
135 {
136 _forget_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
137 _forget_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
138 _memory_group.manage(&_forget_layer_norm_out1);
139 _memory_group.manage(&_forget_layer_norm_out2);
140 _mean_std_norm_forget_gate.configure(forget_gate_out);
141 _pixelwise_mul_forget_gate_coeff.configure(forget_gate_out, lstm_params.forget_layer_norm_weights(), &_forget_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
142 // forget_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
143 forget_gate_out->allocator()->allocate();
144 _accum_forget_gate_bias.configure(ArithmeticOperation::ADD, &_forget_layer_norm_out1, forget_gate_bias, &_forget_layer_norm_out2, ConvertPolicy::SATURATE);
145 _forget_layer_norm_out1.allocator()->allocate();
146 forget_gate_out = &_forget_layer_norm_out2;
147 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000148 _activation_forget_gate.configure(forget_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000149
Michalis Spyroubcedf512018-03-22 14:55:08 +0000150 // Configure block that calculates the input gate
Georgios Pinitas42a31722018-07-09 14:35:32 +0100151 // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Michalis Spyroubcedf512018-03-22 14:55:08 +0000152 // input_gate = 1 - forget_gate, with CIFG
John Kesapidescafec8f2019-02-19 15:53:59 +0000153 // We optimize this as follows:
154 // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100155 _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000156 CLTensor *input_gate_out = &_input_gate_out1;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000157 if(lstm_params.has_cifg_opt())
158 {
159 _memory_group.manage(&_input_gate_out1);
160 _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100161 _ones_memset_kernel.configure(&_ones, PixelValue(1, _ones.info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000162 _subtract_input_gate.configure(ArithmeticOperation::SUB, &_ones, forget_gate_out, &_input_gate_out1, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000163 _ones.allocator()->allocate();
164 _run_cifg_opt = true;
165 }
166 else
167 {
Michalis Spyroubcedf512018-03-22 14:55:08 +0000168 _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +0100169 _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000170
171 std::vector<const ICLTensor *> lstm_weights;
172 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
173 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100174 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000175 _input_gate_out2.allocator()->init(TensorInfo(lstm_weights_concat_shape, 1, input->info()->data_type()));
176
177 _concat_weights_input_gate.configure(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), &_input_gate_out2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000178
179 _memory_group.manage(&_input_gate_out1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000180
Michalis Spyroubcedf512018-03-22 14:55:08 +0000181 _memory_group.manage(&_input_gate_out3);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100182 _fully_connected_input_gate.configure(&_forget_gate_out2, &_input_gate_out2, (_is_layer_norm_lstm) ? nullptr : lstm_params.input_gate_bias(), &_input_gate_out3);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000183 _input_gate_out2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000184
185 input_gate_out = &_input_gate_out3;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100186 if(_run_peephole_opt)
187 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000188 _memory_group.manage(&_input_gate_out4);
189 _pixelwise_mul_input_gate.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100190 _accum_input_gate1.configure(&_input_gate_out3, &_input_gate_out4, &_input_gate_out1, ConvertPolicy::SATURATE);
John Kesapidescafec8f2019-02-19 15:53:59 +0000191 _input_gate_out3.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000192 _input_gate_out4.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000193 input_gate_out = &_input_gate_out1;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100194 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000195 else
196 {
197 _input_gate_out1.allocator()->allocate();
198 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100199
200 if(_is_layer_norm_lstm)
201 {
202 _input_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
203 _input_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
204 _memory_group.manage(&_input_layer_norm_out1);
205 _memory_group.manage(&_input_layer_norm_out2);
206 _mean_std_norm_input_gate.configure(input_gate_out);
207 _pixelwise_mul_input_gate_coeff.configure(input_gate_out, lstm_params.input_layer_norm_weights(), &_input_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
208 // input_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
209 input_gate_out->allocator()->allocate();
210 _accum_input_gate_bias.configure(ArithmeticOperation::ADD, &_input_layer_norm_out1, lstm_params.input_gate_bias(), &_input_layer_norm_out2, ConvertPolicy::SATURATE);
211 _input_layer_norm_out1.allocator()->allocate();
212 input_gate_out = &_input_layer_norm_out2;
213 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000214 _activation_input_gate.configure(input_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000215 }
216
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100217 // Configure block that calculates the cell state
218 // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000219 TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
220 _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
221 _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
222 _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
223 _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
224 _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
225
Michalis Spyroubcedf512018-03-22 14:55:08 +0000226 _memory_group.manage(&_cell_state_out1);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100227 _fully_connected_cell_state.configure(input, input_to_cell_weights, (_is_layer_norm_lstm) ? nullptr : cell_bias, &_cell_state_out1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000228 _memory_group.manage(&_cell_state_out2);
Georgios Pinitas42a31722018-07-09 14:35:32 +0100229 _transpose_cell_state.configure(recurrent_to_cell_weights, &_cell_state_out2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000230 _memory_group.manage(&_cell_state_out3);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100231 _gemm_cell_state1.configure(output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000232 _cell_state_out2.allocator()->allocate();
233 _memory_group.manage(&_cell_state_out4);
giuros01164a2722018-11-20 18:34:46 +0000234 _accum_cell_state1.configure(ArithmeticOperation::ADD, &_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100235 CLTensor *cell_state_out_ptr = &_cell_state_out4;
236 if(_is_layer_norm_lstm)
237 {
238 _cell_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
239 _cell_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
240 _memory_group.manage(&_cell_layer_norm_out1);
241 _memory_group.manage(&_cell_layer_norm_out2);
242 _mean_std_norm_cell_gate.configure(cell_state_out_ptr);
243 _pixelwise_mul_cell_gate_coeff.configure(cell_state_out_ptr, lstm_params.cell_layer_norm_weights(), &_cell_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
244 // cell_state_out_ptr is going to be reassigned, so allocate the tensor that it was assigned to before
245 cell_state_out_ptr->allocator()->allocate();
246 _accum_cell_gate_bias.configure(ArithmeticOperation::ADD, &_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2, ConvertPolicy::SATURATE);
247 _cell_layer_norm_out1.allocator()->allocate();
248 cell_state_out_ptr = &_cell_layer_norm_out2;
249 }
250 _activation_cell_state.configure(cell_state_out_ptr, nullptr, activation_info);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000251 _memory_group.manage(&_cell_state_out5);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100252 _pixelwise_mul_cell_state1.configure(cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
253 cell_state_out_ptr->allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000254 _pixelwise_mul_cell_state2.configure(forget_gate_out, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
giuros01164a2722018-11-20 18:34:46 +0000255 _accum_cell_state2.configure(ArithmeticOperation::ADD, &_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000256 _cell_state_out3.allocator()->allocate();
257 _cell_state_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000258 // Perform clipping
259 if(cell_threshold != 0.f)
260 {
261 _perform_cell_clipping = true;
262 _cell_clip.configure(&_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
263 }
264
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100265 // Configure block that calculates the output
266 // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +0000267 // We optimize this as follows:
268 // output_state_out = Activation( (input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000269 _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000270 _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
271 std::vector<const ICLTensor *> in_out_weights;
272 in_out_weights.emplace_back(input_to_output_weights);
273 in_out_weights.emplace_back(recurrent_to_output_weights);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100274 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000275 _output2.allocator()->init(TensorInfo(in_out_weights_concat_shape, 1, input->info()->data_type()));
276
277 _concat_weights_output.configure(input_to_output_weights, recurrent_to_output_weights, &_output2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000278
Michalis Spyroubcedf512018-03-22 14:55:08 +0000279 _memory_group.manage(&_output1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000280 _memory_group.manage(&_output4);
281
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100282 _fully_connected_output.configure(&_forget_gate_out2, &_output2, (_is_layer_norm_lstm) ? nullptr : output_gate_bias, &_output4);
John Kesapidescafec8f2019-02-19 15:53:59 +0000283
Michalis Spyroubcedf512018-03-22 14:55:08 +0000284 _output2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000285 _forget_gate_out2.allocator()->allocate();
286
287 CLTensor *output_gate_out = &_output4;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000288 if(lstm_params.has_peephole_opt())
289 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000290 _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000291
John Kesapidescafec8f2019-02-19 15:53:59 +0000292 _memory_group.manage(&_output3);
293 _pixelwise_mul_output_state1.configure(&_cell_state_out1, lstm_params.cell_to_output_weights(), &_output3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100294 _accum_output1.configure(&_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
John Kesapidescafec8f2019-02-19 15:53:59 +0000295 _output4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000296 output_gate_out = &_output1;
297
298 // Allocate intermediate buffers
John Kesapidescafec8f2019-02-19 15:53:59 +0000299 _output3.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000300 }
301 else
302 {
303 _output1.allocator()->allocate();
304 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100305 if(_is_layer_norm_lstm)
306 {
307 _output_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
308 _output_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
309 _memory_group.manage(&_output_layer_norm_out1);
310 _memory_group.manage(&_output_layer_norm_out2);
311 _mean_std_norm_output_gate.configure(output_gate_out);
312 _pixelwise_mul_output_gate_coeff.configure(output_gate_out, lstm_params.output_layer_norm_weights(), &_output_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
313 // output_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
314 output_gate_out->allocator()->allocate();
315 _accum_output_gate_bias.configure(ArithmeticOperation::ADD, &_output_layer_norm_out1, output_gate_bias, &_output_layer_norm_out2, ConvertPolicy::SATURATE);
316 _output_layer_norm_out1.allocator()->allocate();
317 output_gate_out = &_output_layer_norm_out2;
318 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100319 _activation_output.configure(output_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000320
Michalis Spyroubcedf512018-03-22 14:55:08 +0000321 // Configure block that calculates the output state
322 /** lstm_res = PixelwiseMul(output, Activation(cell_state))
323 *
324 * -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
325 * /
326 * output_state = --
327 * \
328 * -- lstm_res , otherwise
329 */
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100330 ICLTensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
331 _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
332 _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
333
Michalis Spyroubcedf512018-03-22 14:55:08 +0000334 _memory_group.manage(&_cell_state_activation);
335 _activation_output_state.configure(&_cell_state_out1, &_cell_state_activation, activation_info);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100336 _pixelwise_mul_output_state2.configure(&_cell_state_activation, output_gate_out, output_state_out_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000337 _cell_state_activation.allocator()->allocate();
338
339 if(lstm_params.has_projection())
340 {
341 _has_projection_weights = true;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100342 _fully_connected_output_state.configure(output_state_out_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out);
343 _output_state1.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000344 // Perform clipping
345 if(projection_threshold != 0.f)
346 {
347 _perform_projection_clipping = true;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100348 _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000349 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000350 }
351
352 // Copy cell state and output
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100353 _copy_cell_state.configure(&_cell_state_out1, cell_state_out);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100354 _copy_output.configure(output_state_out, output);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000355
356 // Vector for holding the tensors to store in scratch buffer
357 std::vector<ICLTensor *> scratch_inputs;
Georgios Pinitas0cc37c32018-11-14 15:54:26 +0000358 if(!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000359 {
Georgios Pinitas4f859822019-02-06 18:08:04 +0000360 scratch_inputs.emplace_back(input_gate_out);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000361 }
362 scratch_inputs.emplace_back(&_cell_state_out1);
363 scratch_inputs.emplace_back(forget_gate_out);
364 scratch_inputs.emplace_back(output_gate_out);
Georgios Pinitas09f24972019-05-17 18:14:40 +0100365 _concat_scratch_buffer.configure(scratch_inputs, scratch_buffer, Window::DimX);
Georgios Pinitas4f859822019-02-06 18:08:04 +0000366 input_gate_out->allocator()->allocate();
Michele Di Giorgiodd2619a2018-11-05 16:46:09 +0000367 _cell_state_out1.allocator()->allocate();
368 forget_gate_out->allocator()->allocate();
369 output_gate_out->allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000370}
371
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100372Status CLLSTMLayer::validate(const ITensorInfo *input,
373 const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
Michalis Spyroubcedf512018-03-22 14:55:08 +0000374 const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
375 const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100376 const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
377 const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
Michalis Spyroubcedf512018-03-22 14:55:08 +0000378 const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
379{
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100380 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input,
381 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
382 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
383 forget_gate_bias, cell_bias, output_gate_bias,
384 output_state_in, cell_state_in,
385 scratch_buffer, output_state_out, cell_state_out, output);
386
387 // Check data types
Michalis Spyroubcedf512018-03-22 14:55:08 +0000388 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100389 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input,
390 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
391 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
392 forget_gate_bias, cell_bias, output_gate_bias,
393 output_state_in, cell_state_in,
394 scratch_buffer, output_state_out, cell_state_out, output);
395
396 // Check dimensions
Georgios Pinitas42447c12018-07-16 17:01:20 +0100397 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
398 ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
399 ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
400 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
401 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
402 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
403 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
404 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
405 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
406 ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100407 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
408 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100409 ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100410 ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
411 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100412 ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100413 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0)
414 && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000415
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100416 const unsigned int num_batches = input->dimension(1);
417 const unsigned int num_cells = input_to_output_weights->dimension(1);
418
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100419 if(lstm_params.use_layer_norm())
420 {
421 // If CIFG is used, input layer normalization weights tensor is omitted
422 if(lstm_params.has_cifg_opt())
423 {
424 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights() != nullptr);
425 }
426 else
427 {
428 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
429 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
430 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_batches);
431 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
432 }
433
434 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
435 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
436 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
437 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
438 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
439 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_batches);
440 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_batches);
441 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_batches);
442 }
443
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100444 // Check peephole optimization
Michalis Spyroubcedf512018-03-22 14:55:08 +0000445 if(lstm_params.has_peephole_opt())
446 {
Michalis Spyrou09daf4d2018-06-28 17:07:22 +0100447 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100448 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
449 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000450 }
451
452 TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000453 TensorShape num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
454 const TensorInfo units_out_transposed_info = TensorInfo(units_out_transposed_shape, 1, input->data_type());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000455 const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type());
456
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100457 TensorInfo input_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
458 TensorInfo forget_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
459 TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
460 TensorInfo cell_state_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
461
Michalis Spyroubcedf512018-03-22 14:55:08 +0000462 // Validate forget gate
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100463 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, (lstm_params.use_layer_norm()) ? nullptr : forget_gate_bias, &forget_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000464
465 std::vector<const ITensorInfo *> inputs_vector;
466 inputs_vector.emplace_back(input);
467 inputs_vector.emplace_back(output_state_in);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100468 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000469 TensorInfo forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
470
471 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(input, output_state_in, &forget_gate_concat));
472
Michalis Spyroubcedf512018-03-22 14:55:08 +0000473 if(lstm_params.has_peephole_opt())
474 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100475 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
476 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000477 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100478 if(lstm_params.use_layer_norm())
479 {
480 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&forget_gate));
481 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
482 RoundingPolicy::TO_NEAREST_EVEN));
483 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
484 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100485 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000486
487 // Validate input gate
488 if(!lstm_params.has_cifg_opt())
489 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100490 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
491 lstm_params.recurrent_to_input_weights(),
492 lstm_params.input_gate_bias());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100493 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
494 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100495 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100496
John Kesapidescafec8f2019-02-19 15:53:59 +0000497 std::vector<const ITensorInfo *> lstm_weights;
498 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
499 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100500 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000501 TensorInfo lstm_gate_concat = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
502 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), &lstm_gate_concat));
503
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100504 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), (lstm_params.use_layer_norm()) ? nullptr : lstm_params.input_gate_bias(), &input_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000505
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100506 if(lstm_params.has_peephole_opt())
507 {
508 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
509 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
510 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
511 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
512 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100513
514 if(lstm_params.use_layer_norm())
515 {
516 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&input_gate));
517 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
518 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
519 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100520 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000521 }
522 else
523 {
giuros01164a2722018-11-20 18:34:46 +0000524 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::SUB, &forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000525 }
526
527 // Validate cell state
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100528 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, (lstm_params.use_layer_norm()) ? nullptr : cell_bias, &cell_state_tmp));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100529 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
530 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100531 if(lstm_params.use_layer_norm())
532 {
533 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
534 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
535 RoundingPolicy::TO_NEAREST_EVEN));
536 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
537 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100538 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, activation_info));
539 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
540 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
541 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000542 if(cell_threshold != 0.f)
543 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100544 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
545 cell_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000546 }
547
John Kesapidescafec8f2019-02-19 15:53:59 +0000548 std::vector<const ITensorInfo *> in_out_weights;
549 in_out_weights.emplace_back(input_to_output_weights);
550 in_out_weights.emplace_back(recurrent_to_output_weights);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100551 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000552 TensorInfo in_out_gate_concat = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
553 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(input_to_output_weights, recurrent_to_output_weights, &in_out_gate_concat));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100554 // Validate output gate tmp
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100555 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, (lstm_params.use_layer_norm()) ? nullptr : output_gate_bias, &output_gate_tmp));
John Kesapidescafec8f2019-02-19 15:53:59 +0000556
Michalis Spyroubcedf512018-03-22 14:55:08 +0000557 if(lstm_params.has_peephole_opt())
558 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100559 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
560 RoundingPolicy::TO_NEAREST_EVEN));
561 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000562 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100563 if(lstm_params.use_layer_norm())
564 {
565 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
566 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
567 RoundingPolicy::TO_NEAREST_EVEN));
568 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
569 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100570 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000571
572 // Validate output state
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100573 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
574 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000575 if(lstm_params.has_projection())
576 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100577 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000578 if(projection_threshold != 0.f)
579 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100580 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(output_state_out, output_state_out,
581 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000582 }
583 }
584
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100585 // Validate copy kernel
586 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(&cell_state_tmp, cell_state_out));
587 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
588
589 // Validate scratch concatenation
590 std::vector<ITensorInfo *> inputs_vector_info_raw;
Georgios Pinitas0cc37c32018-11-14 15:54:26 +0000591 if(!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000592 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100593 inputs_vector_info_raw.push_back(&input_gate);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000594 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100595 inputs_vector_info_raw.push_back(&cell_state_tmp);
596 inputs_vector_info_raw.push_back(&forget_gate);
597 inputs_vector_info_raw.push_back(&output_gate_tmp);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000598
Georgios Pinitas09f24972019-05-17 18:14:40 +0100599 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000600 return Status{};
601}
602
603void CLLSTMLayer::run()
604{
John Kesapidescafec8f2019-02-19 15:53:59 +0000605 prepare();
606
Georgios Pinitasda953f22019-04-02 17:27:03 +0100607 MemoryGroupResourceScope scope_mg(_memory_group);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000608
John Kesapidescafec8f2019-02-19 15:53:59 +0000609 CLScheduler::get().enqueue(_concat_inputs_forget_gate);
610
Michalis Spyroubcedf512018-03-22 14:55:08 +0000611 _fully_connected_forget_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000612
613 if(_run_peephole_opt)
614 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100615 CLScheduler::get().enqueue(_pixelwise_mul_forget_gate);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100616 _accum_forget_gate1.run();
617 }
618 if(_is_layer_norm_lstm)
619 {
620 _mean_std_norm_forget_gate.run();
621 CLScheduler::get().enqueue(_pixelwise_mul_forget_gate_coeff);
622 CLScheduler::get().enqueue(_accum_forget_gate_bias);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000623 }
624 CLScheduler::get().enqueue(_activation_forget_gate);
625
626 if(_run_cifg_opt)
627 {
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100628 CLScheduler::get().enqueue(_ones_memset_kernel);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000629 CLScheduler::get().enqueue(_subtract_input_gate);
630 }
631 else
632 {
633 _fully_connected_input_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000634
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100635 if(_run_peephole_opt)
636 {
637 CLScheduler::get().enqueue(_pixelwise_mul_input_gate);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100638 _accum_input_gate1.run();
639 }
640
641 if(_is_layer_norm_lstm)
642 {
643 _mean_std_norm_input_gate.run();
644 CLScheduler::get().enqueue(_pixelwise_mul_input_gate_coeff);
645 CLScheduler::get().enqueue(_accum_input_gate_bias);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100646 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000647 CLScheduler::get().enqueue(_activation_input_gate);
648 }
649
650 _fully_connected_cell_state.run();
Georgios Pinitas42a31722018-07-09 14:35:32 +0100651 CLScheduler::get().enqueue(_transpose_cell_state);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000652 _gemm_cell_state1.run();
653 CLScheduler::get().enqueue(_accum_cell_state1);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100654 if(_is_layer_norm_lstm)
655 {
656 _mean_std_norm_cell_gate.run();
657 CLScheduler::get().enqueue(_pixelwise_mul_cell_gate_coeff);
658 CLScheduler::get().enqueue(_accum_cell_gate_bias);
659 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000660 CLScheduler::get().enqueue(_activation_cell_state);
661 CLScheduler::get().enqueue(_pixelwise_mul_cell_state1);
662 CLScheduler::get().enqueue(_pixelwise_mul_cell_state2);
663 CLScheduler::get().enqueue(_accum_cell_state2);
664
665 if(_perform_cell_clipping)
666 {
667 CLScheduler::get().enqueue(_cell_clip);
668 }
669
670 _fully_connected_output.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000671
672 if(_run_peephole_opt)
673 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100674 CLScheduler::get().enqueue(_pixelwise_mul_output_state1);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100675 _accum_output1.run();
676 }
677 if(_is_layer_norm_lstm)
678 {
679 _mean_std_norm_output_gate.run();
680 CLScheduler::get().enqueue(_pixelwise_mul_output_gate_coeff);
681 CLScheduler::get().enqueue(_accum_output_gate_bias);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000682 }
683 CLScheduler::get().enqueue(_activation_output);
684
685 CLScheduler::get().enqueue(_activation_output_state);
Georgios Pinitas42a31722018-07-09 14:35:32 +0100686 CLScheduler::get().enqueue(_pixelwise_mul_output_state2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000687
688 if(_has_projection_weights)
689 {
690 _fully_connected_output_state.run();
691 if(_perform_projection_clipping)
692 {
693 CLScheduler::get().enqueue(_projection_clip);
694 }
695 }
696
697 CLScheduler::get().enqueue(_copy_cell_state);
698 CLScheduler::get().enqueue(_copy_output);
699
700 _concat_scratch_buffer.run();
giuros01164a2722018-11-20 18:34:46 +0000701}
John Kesapidescafec8f2019-02-19 15:53:59 +0000702
703void CLLSTMLayer::prepare()
704{
705 if(!_is_prepared)
706 {
707 CLScheduler::get().enqueue(_concat_weights_forget_gate);
708 if(!_run_cifg_opt)
709 {
710 CLScheduler::get().enqueue(_concat_weights_input_gate);
711 }
712 CLScheduler::get().enqueue(_concat_weights_output);
713 _is_prepared = true;
714 }
715}
Michele Di Giorgio47a89902020-03-09 19:32:33 +0000716} // namespace arm_compute