blob: 793d5ca1a9d042d7de44970bbb78ec5223121d2d [file] [log] [blame]
Michalis Spyroubcedf512018-03-22 14:55:08 +00001/*
Georgios Pinitas4f859822019-02-06 18:08:04 +00002 * Copyright (c) 2018-2019 ARM Limited.
Michalis Spyroubcedf512018-03-22 14:55:08 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"
25
26#include "arm_compute/core/PixelValue.h"
27#include "arm_compute/core/Utils.h"
28#include "arm_compute/core/Validate.h"
29#include "arm_compute/core/utils/misc/ShapeCalculator.h"
30#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
31#include "arm_compute/runtime/CL/CLScheduler.h"
32
33#include <cmath>
34#include <memory>
35#include <tuple>
36
37using namespace arm_compute;
38using namespace arm_compute::misc::shape_calculator;
39
40CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Michele Di Giorgio39438b42019-06-04 12:41:45 +010041 : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _accum_input_gate1(), _subtract_input_gate(), _pixelwise_mul_input_gate(), _activation_input_gate(),
42 _fully_connected_forget_gate(), _accum_forget_gate1(), _pixelwise_mul_forget_gate(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(), _transpose_cell_state(),
43 _accum_cell_state1(), _accum_cell_state2(), _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(), _pixelwise_mul_cell_state2(), _fully_connected_output(),
44 _pixelwise_mul_output_state1(), _accum_output1(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state2(), _fully_connected_output_state(), _projection_clip(),
45 _copy_cell_state(), _copy_output(), _concat_scratch_buffer(), _concat_inputs_forget_gate(), _concat_weights_forget_gate(), _concat_weights_input_gate(), _concat_weights_output(),
46 _ones_memset_kernel(), _mean_std_norm_input_gate(), _pixelwise_mul_input_gate_coeff(), _accum_input_gate_bias(), _mean_std_norm_forget_gate(), _pixelwise_mul_forget_gate_coeff(),
47 _accum_forget_gate_bias(), _mean_std_norm_cell_gate(), _pixelwise_mul_cell_gate_coeff(), _accum_cell_gate_bias(), _mean_std_norm_output_gate(), _pixelwise_mul_output_gate_coeff(),
48 _accum_output_gate_bias(), _input_gate_out1(), _input_gate_out2(), _input_gate_out3(), _input_gate_out4(), _forget_gate_out1(), _forget_gate_out2(), _forget_gate_out3(), _forget_gate_out4(),
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +010049 _forget_gate_out5(), _forget_gate_out6(), _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(), _cell_state_out5(), _output1(), _output2(), _output3(), _output4(),
Michele Di Giorgio39438b42019-06-04 12:41:45 +010050 _cell_state_activation(), _output_state1(), _ones(), _input_layer_norm_out1(), _input_layer_norm_out2(), _forget_layer_norm_out1(), _forget_layer_norm_out2(), _cell_layer_norm_out1(),
51 _cell_layer_norm_out2(), _output_layer_norm_out1(), _output_layer_norm_out2(), _run_peephole_opt(false), _run_cifg_opt(false), _perform_cell_clipping(false), _has_projection_weights(false),
52 _perform_projection_clipping(false), _is_prepared(false), _is_layer_norm_lstm(false)
Michalis Spyroubcedf512018-03-22 14:55:08 +000053{
54}
55
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010056void CLLSTMLayer::configure(const ICLTensor *input,
57 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
Michalis Spyroubcedf512018-03-22 14:55:08 +000058 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
59 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010060 const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
61 ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
62 const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +000063{
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010064 ARM_COMPUTE_ERROR_ON_NULLPTR(input,
65 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
66 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
67 forget_gate_bias, cell_bias, output_gate_bias,
68 output_state_in, cell_state_in,
69 scratch_buffer, output_state_out, cell_state_out, output);
70
Michele Di Giorgio39438b42019-06-04 12:41:45 +010071 _is_layer_norm_lstm = lstm_params.use_layer_norm();
72
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010073 // Set lstm parameters
Michalis Spyroubcedf512018-03-22 14:55:08 +000074 LSTMParams<ITensorInfo> lstm_params_info;
75 if(lstm_params.has_peephole_opt())
76 {
Michalis Spyrou09daf4d2018-06-28 17:07:22 +010077 lstm_params_info.set_peephole_params(lstm_params.cell_to_forget_weights()->info(), lstm_params.cell_to_output_weights()->info());
Michalis Spyroubcedf512018-03-22 14:55:08 +000078 }
79 if(lstm_params.has_projection())
80 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010081 lstm_params_info.set_projection_params(lstm_params.projection_weights()->info(),
82 lstm_params.projection_bias() != nullptr ? lstm_params.projection_bias()->info() : nullptr);
Michalis Spyroubcedf512018-03-22 14:55:08 +000083 }
84 if(!lstm_params.has_cifg_opt())
85 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010086 const ITensorInfo *cell_to_input_weights_info = (lstm_params.has_peephole_opt()) ? lstm_params.cell_to_input_weights()->info() : nullptr;
Michalis Spyroubcedf512018-03-22 14:55:08 +000087 lstm_params_info.set_cifg_params(lstm_params.input_to_input_weights()->info(), lstm_params.recurrent_to_input_weights()->info(),
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010088 cell_to_input_weights_info, lstm_params.input_gate_bias()->info());
Michalis Spyroubcedf512018-03-22 14:55:08 +000089 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010090
91 // Validate
Michalis Spyroubcedf512018-03-22 14:55:08 +000092 ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(),
93 input_to_cell_weights->info(), input_to_output_weights->info(),
94 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
95 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010096 output_state_in->info(), cell_state_in->info(),
97 scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
98 lstm_params_info, activation_info, cell_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +000099
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100100 const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100101 // Configure block that calculates the forget gate
102 // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +0000103 // We optimize this as follows:
104 // forget_gate = Activation( (input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias
Michalis Spyroubcedf512018-03-22 14:55:08 +0000105 _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000106 _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +0100107 _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000108
John Kesapidescafec8f2019-02-19 15:53:59 +0000109 std::vector<const ICLTensor *> inputs_vector;
110 inputs_vector.emplace_back(input);
111 inputs_vector.emplace_back(output_state_in);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100112 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000113 _forget_gate_out2.allocator()->init(TensorInfo(concat_shape, 1, input->info()->data_type()));
114
Michalis Spyroubcedf512018-03-22 14:55:08 +0000115 _memory_group.manage(&_forget_gate_out2);
John Kesapidescafec8f2019-02-19 15:53:59 +0000116 _concat_inputs_forget_gate.configure(input, output_state_in, &_forget_gate_out2);
117
118 std::vector<const ICLTensor *> weights_vector;
119
120 weights_vector.emplace_back(input_to_forget_weights);
121 weights_vector.emplace_back(recurrent_to_forget_weights);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100122 const TensorShape weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(weights_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000123 _forget_gate_out6.allocator()->init(TensorInfo(weights_concat_shape, 1, input->info()->data_type()));
124
125 _concat_weights_forget_gate.configure(input_to_forget_weights, recurrent_to_forget_weights, &_forget_gate_out6);
126
Georgios Pinitas42a31722018-07-09 14:35:32 +0100127 _memory_group.manage(&_forget_gate_out5);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100128 _fully_connected_forget_gate.configure(&_forget_gate_out2, &_forget_gate_out6, (_is_layer_norm_lstm) ? nullptr : forget_gate_bias, &_forget_gate_out5);
John Kesapidescafec8f2019-02-19 15:53:59 +0000129 _memory_group.manage(&_forget_gate_out1);
130 _memory_group.manage(&_forget_gate_out3);
131 _forget_gate_out6.allocator()->allocate();
132
Georgios Pinitas42a31722018-07-09 14:35:32 +0100133 CLTensor *forget_gate_out = &_forget_gate_out5;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000134 if(lstm_params.has_peephole_opt())
135 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100136 _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000137
138 _run_peephole_opt = true;
139 _memory_group.manage(&_forget_gate_out4);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100140 _pixelwise_mul_forget_gate.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100141 _accum_forget_gate1.configure(&_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000142 _forget_gate_out4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000143 _forget_gate_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000144 forget_gate_out = &_forget_gate_out3;
145 }
146 else
147 {
148 _forget_gate_out3.allocator()->allocate();
149 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100150 if(_is_layer_norm_lstm)
151 {
152 _forget_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
153 _forget_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
154 _memory_group.manage(&_forget_layer_norm_out1);
155 _memory_group.manage(&_forget_layer_norm_out2);
156 _mean_std_norm_forget_gate.configure(forget_gate_out);
157 _pixelwise_mul_forget_gate_coeff.configure(forget_gate_out, lstm_params.forget_layer_norm_weights(), &_forget_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
158 // forget_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
159 forget_gate_out->allocator()->allocate();
160 _accum_forget_gate_bias.configure(ArithmeticOperation::ADD, &_forget_layer_norm_out1, forget_gate_bias, &_forget_layer_norm_out2, ConvertPolicy::SATURATE);
161 _forget_layer_norm_out1.allocator()->allocate();
162 forget_gate_out = &_forget_layer_norm_out2;
163 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000164 _activation_forget_gate.configure(forget_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000165
Michalis Spyroubcedf512018-03-22 14:55:08 +0000166 // Configure block that calculates the input gate
Georgios Pinitas42a31722018-07-09 14:35:32 +0100167 // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Michalis Spyroubcedf512018-03-22 14:55:08 +0000168 // input_gate = 1 - forget_gate, with CIFG
John Kesapidescafec8f2019-02-19 15:53:59 +0000169 // We optimize this as follows:
170 // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100171 _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000172 CLTensor *input_gate_out = &_input_gate_out1;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000173 if(lstm_params.has_cifg_opt())
174 {
175 _memory_group.manage(&_input_gate_out1);
176 _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100177 _ones_memset_kernel.configure(&_ones, PixelValue(1, _ones.info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000178 _subtract_input_gate.configure(ArithmeticOperation::SUB, &_ones, forget_gate_out, &_input_gate_out1, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000179 _ones.allocator()->allocate();
180 _run_cifg_opt = true;
181 }
182 else
183 {
Michalis Spyroubcedf512018-03-22 14:55:08 +0000184 _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +0100185 _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000186
187 std::vector<const ICLTensor *> lstm_weights;
188 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
189 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100190 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000191 _input_gate_out2.allocator()->init(TensorInfo(lstm_weights_concat_shape, 1, input->info()->data_type()));
192
193 _concat_weights_input_gate.configure(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), &_input_gate_out2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000194
195 _memory_group.manage(&_input_gate_out1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000196
Michalis Spyroubcedf512018-03-22 14:55:08 +0000197 _memory_group.manage(&_input_gate_out3);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100198 _fully_connected_input_gate.configure(&_forget_gate_out2, &_input_gate_out2, (_is_layer_norm_lstm) ? nullptr : lstm_params.input_gate_bias(), &_input_gate_out3);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000199 _input_gate_out2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000200
201 input_gate_out = &_input_gate_out3;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100202 if(_run_peephole_opt)
203 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000204 _memory_group.manage(&_input_gate_out4);
205 _pixelwise_mul_input_gate.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100206 _accum_input_gate1.configure(&_input_gate_out3, &_input_gate_out4, &_input_gate_out1, ConvertPolicy::SATURATE);
John Kesapidescafec8f2019-02-19 15:53:59 +0000207 _input_gate_out3.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000208 _input_gate_out4.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000209 input_gate_out = &_input_gate_out1;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100210 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000211 else
212 {
213 _input_gate_out1.allocator()->allocate();
214 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100215
216 if(_is_layer_norm_lstm)
217 {
218 _input_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
219 _input_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
220 _memory_group.manage(&_input_layer_norm_out1);
221 _memory_group.manage(&_input_layer_norm_out2);
222 _mean_std_norm_input_gate.configure(input_gate_out);
223 _pixelwise_mul_input_gate_coeff.configure(input_gate_out, lstm_params.input_layer_norm_weights(), &_input_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
224 // input_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
225 input_gate_out->allocator()->allocate();
226 _accum_input_gate_bias.configure(ArithmeticOperation::ADD, &_input_layer_norm_out1, lstm_params.input_gate_bias(), &_input_layer_norm_out2, ConvertPolicy::SATURATE);
227 _input_layer_norm_out1.allocator()->allocate();
228 input_gate_out = &_input_layer_norm_out2;
229 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000230 _activation_input_gate.configure(input_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000231 }
232
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100233 // Configure block that calculates the cell state
234 // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000235 TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
236 _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
237 _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
238 _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
239 _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
240 _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
241
Michalis Spyroubcedf512018-03-22 14:55:08 +0000242 _memory_group.manage(&_cell_state_out1);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100243 _fully_connected_cell_state.configure(input, input_to_cell_weights, (_is_layer_norm_lstm) ? nullptr : cell_bias, &_cell_state_out1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000244 _memory_group.manage(&_cell_state_out2);
Georgios Pinitas42a31722018-07-09 14:35:32 +0100245 _transpose_cell_state.configure(recurrent_to_cell_weights, &_cell_state_out2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000246 _memory_group.manage(&_cell_state_out3);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100247 _gemm_cell_state1.configure(output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000248 _cell_state_out2.allocator()->allocate();
249 _memory_group.manage(&_cell_state_out4);
giuros01164a2722018-11-20 18:34:46 +0000250 _accum_cell_state1.configure(ArithmeticOperation::ADD, &_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100251 CLTensor *cell_state_out_ptr = &_cell_state_out4;
252 if(_is_layer_norm_lstm)
253 {
254 _cell_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
255 _cell_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
256 _memory_group.manage(&_cell_layer_norm_out1);
257 _memory_group.manage(&_cell_layer_norm_out2);
258 _mean_std_norm_cell_gate.configure(cell_state_out_ptr);
259 _pixelwise_mul_cell_gate_coeff.configure(cell_state_out_ptr, lstm_params.cell_layer_norm_weights(), &_cell_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
260 // cell_state_out_ptr is going to be reassigned, so allocate the tensor that it was assigned to before
261 cell_state_out_ptr->allocator()->allocate();
262 _accum_cell_gate_bias.configure(ArithmeticOperation::ADD, &_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2, ConvertPolicy::SATURATE);
263 _cell_layer_norm_out1.allocator()->allocate();
264 cell_state_out_ptr = &_cell_layer_norm_out2;
265 }
266 _activation_cell_state.configure(cell_state_out_ptr, nullptr, activation_info);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000267 _memory_group.manage(&_cell_state_out5);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100268 _pixelwise_mul_cell_state1.configure(cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
269 cell_state_out_ptr->allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000270 _pixelwise_mul_cell_state2.configure(forget_gate_out, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
giuros01164a2722018-11-20 18:34:46 +0000271 _accum_cell_state2.configure(ArithmeticOperation::ADD, &_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000272 _cell_state_out3.allocator()->allocate();
273 _cell_state_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000274 // Perform clipping
275 if(cell_threshold != 0.f)
276 {
277 _perform_cell_clipping = true;
278 _cell_clip.configure(&_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
279 }
280
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100281 // Configure block that calculates the output
282 // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +0000283 // We optimize this as follows:
284 // output_state_out = Activation( (input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000285 _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000286 _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
287 std::vector<const ICLTensor *> in_out_weights;
288 in_out_weights.emplace_back(input_to_output_weights);
289 in_out_weights.emplace_back(recurrent_to_output_weights);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100290 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000291 _output2.allocator()->init(TensorInfo(in_out_weights_concat_shape, 1, input->info()->data_type()));
292
293 _concat_weights_output.configure(input_to_output_weights, recurrent_to_output_weights, &_output2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000294
Michalis Spyroubcedf512018-03-22 14:55:08 +0000295 _memory_group.manage(&_output1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000296 _memory_group.manage(&_output4);
297
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100298 _fully_connected_output.configure(&_forget_gate_out2, &_output2, (_is_layer_norm_lstm) ? nullptr : output_gate_bias, &_output4);
John Kesapidescafec8f2019-02-19 15:53:59 +0000299
Michalis Spyroubcedf512018-03-22 14:55:08 +0000300 _output2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000301 _forget_gate_out2.allocator()->allocate();
302
303 CLTensor *output_gate_out = &_output4;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000304 if(lstm_params.has_peephole_opt())
305 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000306 _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000307
John Kesapidescafec8f2019-02-19 15:53:59 +0000308 _memory_group.manage(&_output3);
309 _pixelwise_mul_output_state1.configure(&_cell_state_out1, lstm_params.cell_to_output_weights(), &_output3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100310 _accum_output1.configure(&_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
John Kesapidescafec8f2019-02-19 15:53:59 +0000311 _output4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000312 output_gate_out = &_output1;
313
314 // Allocate intermediate buffers
John Kesapidescafec8f2019-02-19 15:53:59 +0000315 _output3.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000316 }
317 else
318 {
319 _output1.allocator()->allocate();
320 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100321 if(_is_layer_norm_lstm)
322 {
323 _output_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
324 _output_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
325 _memory_group.manage(&_output_layer_norm_out1);
326 _memory_group.manage(&_output_layer_norm_out2);
327 _mean_std_norm_output_gate.configure(output_gate_out);
328 _pixelwise_mul_output_gate_coeff.configure(output_gate_out, lstm_params.output_layer_norm_weights(), &_output_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
329 // output_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
330 output_gate_out->allocator()->allocate();
331 _accum_output_gate_bias.configure(ArithmeticOperation::ADD, &_output_layer_norm_out1, output_gate_bias, &_output_layer_norm_out2, ConvertPolicy::SATURATE);
332 _output_layer_norm_out1.allocator()->allocate();
333 output_gate_out = &_output_layer_norm_out2;
334 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100335 _activation_output.configure(output_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000336
Michalis Spyroubcedf512018-03-22 14:55:08 +0000337 // Configure block that calculates the output state
338 /** lstm_res = PixelwiseMul(output, Activation(cell_state))
339 *
340 * -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
341 * /
342 * output_state = --
343 * \
344 * -- lstm_res , otherwise
345 */
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100346 ICLTensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
347 _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
348 _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
349
Michalis Spyroubcedf512018-03-22 14:55:08 +0000350 _memory_group.manage(&_cell_state_activation);
351 _activation_output_state.configure(&_cell_state_out1, &_cell_state_activation, activation_info);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100352 _pixelwise_mul_output_state2.configure(&_cell_state_activation, output_gate_out, output_state_out_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000353 _cell_state_activation.allocator()->allocate();
354
355 if(lstm_params.has_projection())
356 {
357 _has_projection_weights = true;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100358 _fully_connected_output_state.configure(output_state_out_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out);
359 _output_state1.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000360 // Perform clipping
361 if(projection_threshold != 0.f)
362 {
363 _perform_projection_clipping = true;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100364 _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000365 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000366 }
367
368 // Copy cell state and output
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100369 _copy_cell_state.configure(&_cell_state_out1, cell_state_out);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100370 _copy_output.configure(output_state_out, output);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000371
372 // Vector for holding the tensors to store in scratch buffer
373 std::vector<ICLTensor *> scratch_inputs;
Georgios Pinitas0cc37c32018-11-14 15:54:26 +0000374 if(!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000375 {
Georgios Pinitas4f859822019-02-06 18:08:04 +0000376 scratch_inputs.emplace_back(input_gate_out);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000377 }
378 scratch_inputs.emplace_back(&_cell_state_out1);
379 scratch_inputs.emplace_back(forget_gate_out);
380 scratch_inputs.emplace_back(output_gate_out);
Georgios Pinitas09f24972019-05-17 18:14:40 +0100381 _concat_scratch_buffer.configure(scratch_inputs, scratch_buffer, Window::DimX);
Georgios Pinitas4f859822019-02-06 18:08:04 +0000382 input_gate_out->allocator()->allocate();
Michele Di Giorgiodd2619a2018-11-05 16:46:09 +0000383 _cell_state_out1.allocator()->allocate();
384 forget_gate_out->allocator()->allocate();
385 output_gate_out->allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000386}
387
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100388Status CLLSTMLayer::validate(const ITensorInfo *input,
389 const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
Michalis Spyroubcedf512018-03-22 14:55:08 +0000390 const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
391 const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100392 const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
393 const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
Michalis Spyroubcedf512018-03-22 14:55:08 +0000394 const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
395{
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100396 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input,
397 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
398 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
399 forget_gate_bias, cell_bias, output_gate_bias,
400 output_state_in, cell_state_in,
401 scratch_buffer, output_state_out, cell_state_out, output);
402
403 // Check data types
Michalis Spyroubcedf512018-03-22 14:55:08 +0000404 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100405 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input,
406 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
407 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
408 forget_gate_bias, cell_bias, output_gate_bias,
409 output_state_in, cell_state_in,
410 scratch_buffer, output_state_out, cell_state_out, output);
411
412 // Check dimensions
Georgios Pinitas42447c12018-07-16 17:01:20 +0100413 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
414 ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
415 ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
416 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
417 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
418 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
419 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
420 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
421 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
422 ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100423 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
424 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100425 ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100426 ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
427 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100428 ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100429 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0)
430 && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000431
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100432 const unsigned int num_batches = input->dimension(1);
433 const unsigned int num_cells = input_to_output_weights->dimension(1);
434
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100435 if(lstm_params.use_layer_norm())
436 {
437 // If CIFG is used, input layer normalization weights tensor is omitted
438 if(lstm_params.has_cifg_opt())
439 {
440 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights() != nullptr);
441 }
442 else
443 {
444 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
445 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
446 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_batches);
447 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
448 }
449
450 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
451 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
452 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
453 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
454 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
455 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_batches);
456 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_batches);
457 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_batches);
458 }
459
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100460 // Check peephole optimization
Michalis Spyroubcedf512018-03-22 14:55:08 +0000461 if(lstm_params.has_peephole_opt())
462 {
Michalis Spyrou09daf4d2018-06-28 17:07:22 +0100463 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100464 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
465 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000466 }
467
468 TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000469 TensorShape num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
470 const TensorInfo units_out_transposed_info = TensorInfo(units_out_transposed_shape, 1, input->data_type());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000471 const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type());
472
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100473 TensorInfo input_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
474 TensorInfo forget_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
475 TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
476 TensorInfo cell_state_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
477
Michalis Spyroubcedf512018-03-22 14:55:08 +0000478 // Validate forget gate
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100479 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, (lstm_params.use_layer_norm()) ? nullptr : forget_gate_bias, &forget_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000480
481 std::vector<const ITensorInfo *> inputs_vector;
482 inputs_vector.emplace_back(input);
483 inputs_vector.emplace_back(output_state_in);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100484 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000485 TensorInfo forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
486
487 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(input, output_state_in, &forget_gate_concat));
488
Michalis Spyroubcedf512018-03-22 14:55:08 +0000489 if(lstm_params.has_peephole_opt())
490 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100491 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
492 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000493 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100494 if(lstm_params.use_layer_norm())
495 {
496 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&forget_gate));
497 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
498 RoundingPolicy::TO_NEAREST_EVEN));
499 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
500 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100501 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000502
503 // Validate input gate
504 if(!lstm_params.has_cifg_opt())
505 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100506 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
507 lstm_params.recurrent_to_input_weights(),
508 lstm_params.input_gate_bias());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100509 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
510 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100511 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100512
John Kesapidescafec8f2019-02-19 15:53:59 +0000513 std::vector<const ITensorInfo *> lstm_weights;
514 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
515 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100516 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000517 TensorInfo lstm_gate_concat = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
518 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), &lstm_gate_concat));
519
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100520 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), (lstm_params.use_layer_norm()) ? nullptr : lstm_params.input_gate_bias(), &input_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000521
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100522 if(lstm_params.has_peephole_opt())
523 {
524 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
525 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
526 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
527 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
528 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100529
530 if(lstm_params.use_layer_norm())
531 {
532 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&input_gate));
533 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
534 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
535 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100536 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000537 }
538 else
539 {
giuros01164a2722018-11-20 18:34:46 +0000540 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::SUB, &forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000541 }
542
543 // Validate cell state
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100544 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, (lstm_params.use_layer_norm()) ? nullptr : cell_bias, &cell_state_tmp));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100545 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
546 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100547 if(lstm_params.use_layer_norm())
548 {
549 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
550 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
551 RoundingPolicy::TO_NEAREST_EVEN));
552 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
553 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100554 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, activation_info));
555 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
556 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
557 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000558 if(cell_threshold != 0.f)
559 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100560 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
561 cell_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000562 }
563
John Kesapidescafec8f2019-02-19 15:53:59 +0000564 std::vector<const ITensorInfo *> in_out_weights;
565 in_out_weights.emplace_back(input_to_output_weights);
566 in_out_weights.emplace_back(recurrent_to_output_weights);
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100567 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
John Kesapidescafec8f2019-02-19 15:53:59 +0000568 TensorInfo in_out_gate_concat = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
569 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(input_to_output_weights, recurrent_to_output_weights, &in_out_gate_concat));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100570 // Validate output gate tmp
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100571 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, (lstm_params.use_layer_norm()) ? nullptr : output_gate_bias, &output_gate_tmp));
John Kesapidescafec8f2019-02-19 15:53:59 +0000572
Michalis Spyroubcedf512018-03-22 14:55:08 +0000573 if(lstm_params.has_peephole_opt())
574 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100575 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
576 RoundingPolicy::TO_NEAREST_EVEN));
577 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000578 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100579 if(lstm_params.use_layer_norm())
580 {
581 ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
582 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
583 RoundingPolicy::TO_NEAREST_EVEN));
584 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
585 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100586 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000587
588 // Validate output state
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100589 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
590 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000591 if(lstm_params.has_projection())
592 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100593 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000594 if(projection_threshold != 0.f)
595 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100596 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(output_state_out, output_state_out,
597 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000598 }
599 }
600
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100601 // Validate copy kernel
602 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(&cell_state_tmp, cell_state_out));
603 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
604
605 // Validate scratch concatenation
606 std::vector<ITensorInfo *> inputs_vector_info_raw;
Georgios Pinitas0cc37c32018-11-14 15:54:26 +0000607 if(!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000608 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100609 inputs_vector_info_raw.push_back(&input_gate);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000610 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100611 inputs_vector_info_raw.push_back(&cell_state_tmp);
612 inputs_vector_info_raw.push_back(&forget_gate);
613 inputs_vector_info_raw.push_back(&output_gate_tmp);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000614
Georgios Pinitas09f24972019-05-17 18:14:40 +0100615 ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000616 return Status{};
617}
618
619void CLLSTMLayer::run()
620{
John Kesapidescafec8f2019-02-19 15:53:59 +0000621 prepare();
622
Georgios Pinitasda953f22019-04-02 17:27:03 +0100623 MemoryGroupResourceScope scope_mg(_memory_group);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000624
John Kesapidescafec8f2019-02-19 15:53:59 +0000625 CLScheduler::get().enqueue(_concat_inputs_forget_gate);
626
Michalis Spyroubcedf512018-03-22 14:55:08 +0000627 _fully_connected_forget_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000628
629 if(_run_peephole_opt)
630 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100631 CLScheduler::get().enqueue(_pixelwise_mul_forget_gate);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100632 _accum_forget_gate1.run();
633 }
634 if(_is_layer_norm_lstm)
635 {
636 _mean_std_norm_forget_gate.run();
637 CLScheduler::get().enqueue(_pixelwise_mul_forget_gate_coeff);
638 CLScheduler::get().enqueue(_accum_forget_gate_bias);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000639 }
640 CLScheduler::get().enqueue(_activation_forget_gate);
641
642 if(_run_cifg_opt)
643 {
Georgios Pinitasdbfc2dc2019-04-02 12:51:21 +0100644 CLScheduler::get().enqueue(_ones_memset_kernel);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000645 CLScheduler::get().enqueue(_subtract_input_gate);
646 }
647 else
648 {
649 _fully_connected_input_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000650
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100651 if(_run_peephole_opt)
652 {
653 CLScheduler::get().enqueue(_pixelwise_mul_input_gate);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100654 _accum_input_gate1.run();
655 }
656
657 if(_is_layer_norm_lstm)
658 {
659 _mean_std_norm_input_gate.run();
660 CLScheduler::get().enqueue(_pixelwise_mul_input_gate_coeff);
661 CLScheduler::get().enqueue(_accum_input_gate_bias);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100662 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000663 CLScheduler::get().enqueue(_activation_input_gate);
664 }
665
666 _fully_connected_cell_state.run();
Georgios Pinitas42a31722018-07-09 14:35:32 +0100667 CLScheduler::get().enqueue(_transpose_cell_state);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000668 _gemm_cell_state1.run();
669 CLScheduler::get().enqueue(_accum_cell_state1);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100670 if(_is_layer_norm_lstm)
671 {
672 _mean_std_norm_cell_gate.run();
673 CLScheduler::get().enqueue(_pixelwise_mul_cell_gate_coeff);
674 CLScheduler::get().enqueue(_accum_cell_gate_bias);
675 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000676 CLScheduler::get().enqueue(_activation_cell_state);
677 CLScheduler::get().enqueue(_pixelwise_mul_cell_state1);
678 CLScheduler::get().enqueue(_pixelwise_mul_cell_state2);
679 CLScheduler::get().enqueue(_accum_cell_state2);
680
681 if(_perform_cell_clipping)
682 {
683 CLScheduler::get().enqueue(_cell_clip);
684 }
685
686 _fully_connected_output.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000687
688 if(_run_peephole_opt)
689 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100690 CLScheduler::get().enqueue(_pixelwise_mul_output_state1);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100691 _accum_output1.run();
692 }
693 if(_is_layer_norm_lstm)
694 {
695 _mean_std_norm_output_gate.run();
696 CLScheduler::get().enqueue(_pixelwise_mul_output_gate_coeff);
697 CLScheduler::get().enqueue(_accum_output_gate_bias);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000698 }
699 CLScheduler::get().enqueue(_activation_output);
700
701 CLScheduler::get().enqueue(_activation_output_state);
Georgios Pinitas42a31722018-07-09 14:35:32 +0100702 CLScheduler::get().enqueue(_pixelwise_mul_output_state2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000703
704 if(_has_projection_weights)
705 {
706 _fully_connected_output_state.run();
707 if(_perform_projection_clipping)
708 {
709 CLScheduler::get().enqueue(_projection_clip);
710 }
711 }
712
713 CLScheduler::get().enqueue(_copy_cell_state);
714 CLScheduler::get().enqueue(_copy_output);
715
716 _concat_scratch_buffer.run();
giuros01164a2722018-11-20 18:34:46 +0000717}
John Kesapidescafec8f2019-02-19 15:53:59 +0000718
719void CLLSTMLayer::prepare()
720{
721 if(!_is_prepared)
722 {
723 CLScheduler::get().enqueue(_concat_weights_forget_gate);
724 if(!_run_cifg_opt)
725 {
726 CLScheduler::get().enqueue(_concat_weights_input_gate);
727 }
728 CLScheduler::get().enqueue(_concat_weights_output);
729 _is_prepared = true;
730 }
731}