blob: a004762a4e3bfe6494093d21b8e1fa8190f2c08e [file] [log] [blame]
Michalis Spyroubcedf512018-03-22 14:55:08 +00001/*
Georgios Pinitas4f859822019-02-06 18:08:04 +00002 * Copyright (c) 2018-2019 ARM Limited.
Michalis Spyroubcedf512018-03-22 14:55:08 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"
25
26#include "arm_compute/core/PixelValue.h"
27#include "arm_compute/core/Utils.h"
28#include "arm_compute/core/Validate.h"
29#include "arm_compute/core/utils/misc/ShapeCalculator.h"
30#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
31#include "arm_compute/runtime/CL/CLScheduler.h"
32
33#include <cmath>
34#include <memory>
35#include <tuple>
36
37using namespace arm_compute;
38using namespace arm_compute::misc::shape_calculator;
39
40CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
Georgios Pinitas42a31722018-07-09 14:35:32 +010041 : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _gemm_input_gate(), _transpose_input_gate(), _accum_input_gate1(), _accum_input_gate2(), _subtract_input_gate(),
42 _pixelwise_mul_input_gate(), _activation_input_gate(), _fully_connected_forget_gate(), _gemm_forget_gate(), _transpose_forget_gate(), _accum_forget_gate1(), _accum_forget_gate2(),
43 _pixelwise_mul_forget_gate(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(), _gemm_cell_state2(), _transpose_cell_state(), _accum_cell_state1(), _accum_cell_state2(),
44 _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(), _pixelwise_mul_cell_state2(), _fully_connected_output(), _gemm_output(), _pixelwise_mul_output_state1(), _transpose_output(),
45 _accum_output1(), _accum_output2(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state2(), _fully_connected_output_state(), _gemm_output_state(), _accum_output_state(),
John Kesapidescafec8f2019-02-19 15:53:59 +000046 _projection_clip(), _copy_cell_state(), _copy_output(), _concat_scratch_buffer(), _concat_inputs_forget_gate(), _concat_weights_forget_gate(), _concat_weights_input_gate(), _concat_weights_output(),
47 _input_gate_out1(), _input_gate_out2(), _input_gate_out3(), _input_gate_out4(), _forget_gate_out1(), _forget_gate_out2(), _forget_gate_out3(), _forget_gate_out4(), _forget_gate_out5(),
48 _forget_gate_out6(), _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(), _cell_state_out5(), _output1(), _output2(), _output3(), _output4(), _cell_state_activation(),
49 _output_state1(), _ones(), _run_peephole_opt(false), _run_cifg_opt(false), _perform_cell_clipping(false), _has_projection_weights(false), _perform_projection_clipping(false), _is_prepared(false)
Michalis Spyroubcedf512018-03-22 14:55:08 +000050{
51}
52
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010053void CLLSTMLayer::configure(const ICLTensor *input,
54 const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
Michalis Spyroubcedf512018-03-22 14:55:08 +000055 const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
56 const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010057 const ICLTensor *output_state_in, const ICLTensor *cell_state_in,
58 ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
59 const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +000060{
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010061 ARM_COMPUTE_ERROR_ON_NULLPTR(input,
62 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
63 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
64 forget_gate_bias, cell_bias, output_gate_bias,
65 output_state_in, cell_state_in,
66 scratch_buffer, output_state_out, cell_state_out, output);
67
68 // Set lstm parameters
Michalis Spyroubcedf512018-03-22 14:55:08 +000069 LSTMParams<ITensorInfo> lstm_params_info;
70 if(lstm_params.has_peephole_opt())
71 {
Michalis Spyrou09daf4d2018-06-28 17:07:22 +010072 lstm_params_info.set_peephole_params(lstm_params.cell_to_forget_weights()->info(), lstm_params.cell_to_output_weights()->info());
Michalis Spyroubcedf512018-03-22 14:55:08 +000073 }
74 if(lstm_params.has_projection())
75 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010076 lstm_params_info.set_projection_params(lstm_params.projection_weights()->info(),
77 lstm_params.projection_bias() != nullptr ? lstm_params.projection_bias()->info() : nullptr);
Michalis Spyroubcedf512018-03-22 14:55:08 +000078 }
79 if(!lstm_params.has_cifg_opt())
80 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010081 const ITensorInfo *cell_to_input_weights_info = (lstm_params.has_peephole_opt()) ? lstm_params.cell_to_input_weights()->info() : nullptr;
Michalis Spyroubcedf512018-03-22 14:55:08 +000082 lstm_params_info.set_cifg_params(lstm_params.input_to_input_weights()->info(), lstm_params.recurrent_to_input_weights()->info(),
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010083 cell_to_input_weights_info, lstm_params.input_gate_bias()->info());
Michalis Spyroubcedf512018-03-22 14:55:08 +000084 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010085
86 // Validate
Michalis Spyroubcedf512018-03-22 14:55:08 +000087 ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(),
88 input_to_cell_weights->info(), input_to_output_weights->info(),
89 recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
90 forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010091 output_state_in->info(), cell_state_in->info(),
92 scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
93 lstm_params_info, activation_info, cell_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +000094
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010095 const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010096 // Configure block that calculates the forget gate
97 // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +000098 // We optimize this as follows:
99 // forget_gate = Activation( (input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias
Michalis Spyroubcedf512018-03-22 14:55:08 +0000100 _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000101 _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +0100102 _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000103
John Kesapidescafec8f2019-02-19 15:53:59 +0000104 std::vector<const ICLTensor *> inputs_vector;
105 inputs_vector.emplace_back(input);
106 inputs_vector.emplace_back(output_state_in);
107 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(inputs_vector);
108 _forget_gate_out2.allocator()->init(TensorInfo(concat_shape, 1, input->info()->data_type()));
109
Michalis Spyroubcedf512018-03-22 14:55:08 +0000110 _memory_group.manage(&_forget_gate_out2);
John Kesapidescafec8f2019-02-19 15:53:59 +0000111 _concat_inputs_forget_gate.configure(input, output_state_in, &_forget_gate_out2);
112
113 std::vector<const ICLTensor *> weights_vector;
114
115 weights_vector.emplace_back(input_to_forget_weights);
116 weights_vector.emplace_back(recurrent_to_forget_weights);
117 const TensorShape weights_concat_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(weights_vector);
118 _forget_gate_out6.allocator()->init(TensorInfo(weights_concat_shape, 1, input->info()->data_type()));
119
120 _concat_weights_forget_gate.configure(input_to_forget_weights, recurrent_to_forget_weights, &_forget_gate_out6);
121
Georgios Pinitas42a31722018-07-09 14:35:32 +0100122 _memory_group.manage(&_forget_gate_out5);
John Kesapidescafec8f2019-02-19 15:53:59 +0000123 _fully_connected_forget_gate.configure(&_forget_gate_out2, &_forget_gate_out6, forget_gate_bias, &_forget_gate_out5);
124 _memory_group.manage(&_forget_gate_out1);
125 _memory_group.manage(&_forget_gate_out3);
126 _forget_gate_out6.allocator()->allocate();
127
Georgios Pinitas42a31722018-07-09 14:35:32 +0100128 CLTensor *forget_gate_out = &_forget_gate_out5;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000129 if(lstm_params.has_peephole_opt())
130 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100131 _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000132
133 _run_peephole_opt = true;
134 _memory_group.manage(&_forget_gate_out4);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100135 _pixelwise_mul_forget_gate.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Georgios Pinitas42a31722018-07-09 14:35:32 +0100136 _accum_forget_gate2.configure(&_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000137 _forget_gate_out4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000138 _forget_gate_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000139 forget_gate_out = &_forget_gate_out3;
140 }
141 else
142 {
143 _forget_gate_out3.allocator()->allocate();
144 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000145 _activation_forget_gate.configure(forget_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000146
Michalis Spyroubcedf512018-03-22 14:55:08 +0000147 // Configure block that calculates the input gate
Georgios Pinitas42a31722018-07-09 14:35:32 +0100148 // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Michalis Spyroubcedf512018-03-22 14:55:08 +0000149 // input_gate = 1 - forget_gate, with CIFG
John Kesapidescafec8f2019-02-19 15:53:59 +0000150 // We optimize this as follows:
151 // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100152 _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000153 CLTensor *input_gate_out = &_input_gate_out1;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000154 if(lstm_params.has_cifg_opt())
155 {
156 _memory_group.manage(&_input_gate_out1);
157 _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas4f859822019-02-06 18:08:04 +0000158 _subtract_input_gate.configure(ArithmeticOperation::SUB, &_ones, forget_gate_out, &_input_gate_out1, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000159 _ones.allocator()->allocate();
160 _run_cifg_opt = true;
161 }
162 else
163 {
Michalis Spyroubcedf512018-03-22 14:55:08 +0000164 _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
Georgios Pinitas42a31722018-07-09 14:35:32 +0100165 _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000166
167 std::vector<const ICLTensor *> lstm_weights;
168 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
169 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
170 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(lstm_weights);
171 _input_gate_out2.allocator()->init(TensorInfo(lstm_weights_concat_shape, 1, input->info()->data_type()));
172
173 _concat_weights_input_gate.configure(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), &_input_gate_out2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000174
175 _memory_group.manage(&_input_gate_out1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000176
Michalis Spyroubcedf512018-03-22 14:55:08 +0000177 _memory_group.manage(&_input_gate_out3);
John Kesapidescafec8f2019-02-19 15:53:59 +0000178 _fully_connected_input_gate.configure(&_forget_gate_out2, &_input_gate_out2, lstm_params.input_gate_bias(), &_input_gate_out3);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000179 _input_gate_out2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000180
181 input_gate_out = &_input_gate_out3;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100182 if(_run_peephole_opt)
183 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000184 _memory_group.manage(&_input_gate_out4);
185 _pixelwise_mul_input_gate.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
186 _accum_input_gate2.configure(&_input_gate_out3, &_input_gate_out4, &_input_gate_out1, ConvertPolicy::SATURATE);
187 _input_gate_out3.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000188 _input_gate_out4.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000189 input_gate_out = &_input_gate_out1;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100190 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000191 else
192 {
193 _input_gate_out1.allocator()->allocate();
194 }
195 _activation_input_gate.configure(input_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000196 }
197
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100198 // Configure block that calculates the cell state
199 // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000200 TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
201 _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
202 _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
203 _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
204 _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
205 _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
206
Michalis Spyroubcedf512018-03-22 14:55:08 +0000207 _memory_group.manage(&_cell_state_out1);
Georgios Pinitas7d66a8e2018-07-17 12:28:42 +0100208 _fully_connected_cell_state.configure(input, input_to_cell_weights, cell_bias, &_cell_state_out1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000209 _memory_group.manage(&_cell_state_out2);
Georgios Pinitas42a31722018-07-09 14:35:32 +0100210 _transpose_cell_state.configure(recurrent_to_cell_weights, &_cell_state_out2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000211 _memory_group.manage(&_cell_state_out3);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100212 _gemm_cell_state1.configure(output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000213 _cell_state_out2.allocator()->allocate();
214 _memory_group.manage(&_cell_state_out4);
giuros01164a2722018-11-20 18:34:46 +0000215 _accum_cell_state1.configure(ArithmeticOperation::ADD, &_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000216 _activation_cell_state.configure(&_cell_state_out4, nullptr, activation_info);
217 _memory_group.manage(&_cell_state_out5);
Georgios Pinitas4f859822019-02-06 18:08:04 +0000218 _pixelwise_mul_cell_state1.configure(&_cell_state_out4, input_gate_out, &_cell_state_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000219 _cell_state_out4.allocator()->allocate();
Georgios Pinitas4f859822019-02-06 18:08:04 +0000220 _pixelwise_mul_cell_state2.configure(forget_gate_out, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
giuros01164a2722018-11-20 18:34:46 +0000221 _accum_cell_state2.configure(ArithmeticOperation::ADD, &_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000222 _cell_state_out3.allocator()->allocate();
223 _cell_state_out5.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000224 // Perform clipping
225 if(cell_threshold != 0.f)
226 {
227 _perform_cell_clipping = true;
228 _cell_clip.configure(&_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
229 }
230
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100231 // Configure block that calculates the output
232 // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
John Kesapidescafec8f2019-02-19 15:53:59 +0000233 // We optimize this as follows:
234 // output_state_out = Activation( (input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000235 _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
John Kesapidescafec8f2019-02-19 15:53:59 +0000236 _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
237 std::vector<const ICLTensor *> in_out_weights;
238 in_out_weights.emplace_back(input_to_output_weights);
239 in_out_weights.emplace_back(recurrent_to_output_weights);
240 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(in_out_weights);
241 _output2.allocator()->init(TensorInfo(in_out_weights_concat_shape, 1, input->info()->data_type()));
242
243 _concat_weights_output.configure(input_to_output_weights, recurrent_to_output_weights, &_output2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000244
Michalis Spyroubcedf512018-03-22 14:55:08 +0000245 _memory_group.manage(&_output1);
John Kesapidescafec8f2019-02-19 15:53:59 +0000246 _memory_group.manage(&_output4);
247
248 _fully_connected_output.configure(&_forget_gate_out2, &_output2, output_gate_bias, &_output4);
249
Michalis Spyroubcedf512018-03-22 14:55:08 +0000250 _output2.allocator()->allocate();
John Kesapidescafec8f2019-02-19 15:53:59 +0000251 _forget_gate_out2.allocator()->allocate();
252
253 CLTensor *output_gate_out = &_output4;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000254 if(lstm_params.has_peephole_opt())
255 {
John Kesapidescafec8f2019-02-19 15:53:59 +0000256 _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000257
John Kesapidescafec8f2019-02-19 15:53:59 +0000258 _memory_group.manage(&_output3);
259 _pixelwise_mul_output_state1.configure(&_cell_state_out1, lstm_params.cell_to_output_weights(), &_output3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
260 _accum_output2.configure(&_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
261 _output4.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000262 output_gate_out = &_output1;
263
264 // Allocate intermediate buffers
John Kesapidescafec8f2019-02-19 15:53:59 +0000265 _output3.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000266 }
267 else
268 {
269 _output1.allocator()->allocate();
270 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100271 _activation_output.configure(output_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000272
Michalis Spyroubcedf512018-03-22 14:55:08 +0000273 // Configure block that calculates the output state
274 /** lstm_res = PixelwiseMul(output, Activation(cell_state))
275 *
276 * -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
277 * /
278 * output_state = --
279 * \
280 * -- lstm_res , otherwise
281 */
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100282 ICLTensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
283 _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
284 _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
285
Michalis Spyroubcedf512018-03-22 14:55:08 +0000286 _memory_group.manage(&_cell_state_activation);
287 _activation_output_state.configure(&_cell_state_out1, &_cell_state_activation, activation_info);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100288 _pixelwise_mul_output_state2.configure(&_cell_state_activation, output_gate_out, output_state_out_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000289 _cell_state_activation.allocator()->allocate();
290
291 if(lstm_params.has_projection())
292 {
293 _has_projection_weights = true;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100294 _fully_connected_output_state.configure(output_state_out_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out);
295 _output_state1.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000296 // Perform clipping
297 if(projection_threshold != 0.f)
298 {
299 _perform_projection_clipping = true;
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100300 _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000301 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000302 }
303
304 // Copy cell state and output
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100305 _copy_cell_state.configure(&_cell_state_out1, cell_state_out);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100306 _copy_output.configure(output_state_out, output);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000307
308 // Vector for holding the tensors to store in scratch buffer
309 std::vector<ICLTensor *> scratch_inputs;
Georgios Pinitas0cc37c32018-11-14 15:54:26 +0000310 if(!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000311 {
Georgios Pinitas4f859822019-02-06 18:08:04 +0000312 scratch_inputs.emplace_back(input_gate_out);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000313 }
314 scratch_inputs.emplace_back(&_cell_state_out1);
315 scratch_inputs.emplace_back(forget_gate_out);
316 scratch_inputs.emplace_back(output_gate_out);
317 _concat_scratch_buffer.configure(scratch_inputs, scratch_buffer);
Georgios Pinitas4f859822019-02-06 18:08:04 +0000318 input_gate_out->allocator()->allocate();
Michele Di Giorgiodd2619a2018-11-05 16:46:09 +0000319 _cell_state_out1.allocator()->allocate();
320 forget_gate_out->allocator()->allocate();
321 output_gate_out->allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000322}
323
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100324Status CLLSTMLayer::validate(const ITensorInfo *input,
325 const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
Michalis Spyroubcedf512018-03-22 14:55:08 +0000326 const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
327 const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100328 const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
329 const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
Michalis Spyroubcedf512018-03-22 14:55:08 +0000330 const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
331{
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100332 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input,
333 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
334 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
335 forget_gate_bias, cell_bias, output_gate_bias,
336 output_state_in, cell_state_in,
337 scratch_buffer, output_state_out, cell_state_out, output);
338
339 // Check data types
Michalis Spyroubcedf512018-03-22 14:55:08 +0000340 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100341 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input,
342 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
343 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
344 forget_gate_bias, cell_bias, output_gate_bias,
345 output_state_in, cell_state_in,
346 scratch_buffer, output_state_out, cell_state_out, output);
347
348 // Check dimensions
Georgios Pinitas42447c12018-07-16 17:01:20 +0100349 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
350 ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
351 ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
352 ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
353 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
354 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
355 ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
356 ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
357 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
358 ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100359 ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
360 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100361 ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100362 ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
363 ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100364 ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100365 ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0)
366 && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000367
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100368 const unsigned int num_batches = input->dimension(1);
369 const unsigned int num_cells = input_to_output_weights->dimension(1);
370
371 // Check peephole optimization
Michalis Spyroubcedf512018-03-22 14:55:08 +0000372 if(lstm_params.has_peephole_opt())
373 {
Michalis Spyrou09daf4d2018-06-28 17:07:22 +0100374 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100375 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
376 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000377 }
378
379 TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000380 TensorShape num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
381 const TensorInfo units_out_transposed_info = TensorInfo(units_out_transposed_shape, 1, input->data_type());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000382 const TensorInfo num_units_transposed_info = TensorInfo(num_units_transposed_shape, 1, input->data_type());
383
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100384 TensorInfo input_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
385 TensorInfo forget_gate = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
386 TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
387 TensorInfo cell_state_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
388
Michalis Spyroubcedf512018-03-22 14:55:08 +0000389 // Validate forget gate
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100390 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, forget_gate_bias, &forget_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000391
392 std::vector<const ITensorInfo *> inputs_vector;
393 inputs_vector.emplace_back(input);
394 inputs_vector.emplace_back(output_state_in);
395 const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(inputs_vector);
396 TensorInfo forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
397
398 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(input, output_state_in, &forget_gate_concat));
399
Michalis Spyroubcedf512018-03-22 14:55:08 +0000400 if(lstm_params.has_peephole_opt())
401 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100402 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
403 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000404 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100405 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000406
407 // Validate input gate
408 if(!lstm_params.has_cifg_opt())
409 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100410 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
411 lstm_params.recurrent_to_input_weights(),
412 lstm_params.input_gate_bias());
Georgios Pinitas42447c12018-07-16 17:01:20 +0100413 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
414 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
Georgios Pinitas42447c12018-07-16 17:01:20 +0100415 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100416
John Kesapidescafec8f2019-02-19 15:53:59 +0000417 std::vector<const ITensorInfo *> lstm_weights;
418 lstm_weights.emplace_back(lstm_params.input_to_input_weights());
419 lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
420 TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(lstm_weights);
421 TensorInfo lstm_gate_concat = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
422 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), &lstm_gate_concat));
423
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100424 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), &input_gate));
John Kesapidescafec8f2019-02-19 15:53:59 +0000425
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100426 if(lstm_params.has_peephole_opt())
427 {
428 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
429 ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
430 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
431 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
432 }
433 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000434 }
435 else
436 {
giuros01164a2722018-11-20 18:34:46 +0000437 ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::SUB, &forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000438 }
439
440 // Validate cell state
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100441 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, cell_bias, &cell_state_tmp));
442 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
443 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
444 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, activation_info));
445 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
446 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
447 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000448 if(cell_threshold != 0.f)
449 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100450 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
451 cell_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000452 }
453
John Kesapidescafec8f2019-02-19 15:53:59 +0000454 std::vector<const ITensorInfo *> in_out_weights;
455 in_out_weights.emplace_back(input_to_output_weights);
456 in_out_weights.emplace_back(recurrent_to_output_weights);
457 TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(in_out_weights);
458 TensorInfo in_out_gate_concat = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
459 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenate2TensorsKernel::validate(input_to_output_weights, recurrent_to_output_weights, &in_out_gate_concat));
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100460 // Validate output gate tmp
461 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, output_gate_bias, &output_gate_tmp));
John Kesapidescafec8f2019-02-19 15:53:59 +0000462
Michalis Spyroubcedf512018-03-22 14:55:08 +0000463 if(lstm_params.has_peephole_opt())
464 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100465 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
466 RoundingPolicy::TO_NEAREST_EVEN));
467 ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000468 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100469 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000470
471 // Validate output state
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100472 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
473 ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000474 if(lstm_params.has_projection())
475 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100476 ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000477 if(projection_threshold != 0.f)
478 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100479 ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(output_state_out, output_state_out,
480 ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000481 }
482 }
483
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100484 // Validate copy kernel
485 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(&cell_state_tmp, cell_state_out));
486 ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
487
488 // Validate scratch concatenation
489 std::vector<ITensorInfo *> inputs_vector_info_raw;
Georgios Pinitas0cc37c32018-11-14 15:54:26 +0000490 if(!lstm_params.has_cifg_opt())
Michalis Spyroubcedf512018-03-22 14:55:08 +0000491 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100492 inputs_vector_info_raw.push_back(&input_gate);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000493 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100494 inputs_vector_info_raw.push_back(&cell_state_tmp);
495 inputs_vector_info_raw.push_back(&forget_gate);
496 inputs_vector_info_raw.push_back(&output_gate_tmp);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000497
498 ARM_COMPUTE_RETURN_ON_ERROR(CLWidthConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer));
499 return Status{};
500}
501
502void CLLSTMLayer::run()
503{
John Kesapidescafec8f2019-02-19 15:53:59 +0000504 prepare();
505
Michalis Spyroubcedf512018-03-22 14:55:08 +0000506 _memory_group.acquire();
507
John Kesapidescafec8f2019-02-19 15:53:59 +0000508 CLScheduler::get().enqueue(_concat_inputs_forget_gate);
509
Michalis Spyroubcedf512018-03-22 14:55:08 +0000510 _fully_connected_forget_gate.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000511
512 if(_run_peephole_opt)
513 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100514 CLScheduler::get().enqueue(_pixelwise_mul_forget_gate);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000515 _accum_forget_gate2.run();
516 }
517 CLScheduler::get().enqueue(_activation_forget_gate);
518
519 if(_run_cifg_opt)
520 {
521 _ones.map(true);
Georgios Pinitas42a31722018-07-09 14:35:32 +0100522 if(_ones.info()->data_type() == DataType::F16)
523 {
524 std::fill_n(reinterpret_cast<half *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 1);
525 }
526 else
527 {
528 std::fill_n(reinterpret_cast<float *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 1);
529 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000530 _ones.unmap();
531 CLScheduler::get().enqueue(_subtract_input_gate);
532 }
533 else
534 {
535 _fully_connected_input_gate.run();
John Kesapidescafec8f2019-02-19 15:53:59 +0000536
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100537 if(_run_peephole_opt)
538 {
539 CLScheduler::get().enqueue(_pixelwise_mul_input_gate);
540 _accum_input_gate2.run();
541 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000542 CLScheduler::get().enqueue(_activation_input_gate);
543 }
544
545 _fully_connected_cell_state.run();
Georgios Pinitas42a31722018-07-09 14:35:32 +0100546 CLScheduler::get().enqueue(_transpose_cell_state);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000547 _gemm_cell_state1.run();
548 CLScheduler::get().enqueue(_accum_cell_state1);
549 CLScheduler::get().enqueue(_activation_cell_state);
550 CLScheduler::get().enqueue(_pixelwise_mul_cell_state1);
551 CLScheduler::get().enqueue(_pixelwise_mul_cell_state2);
552 CLScheduler::get().enqueue(_accum_cell_state2);
553
554 if(_perform_cell_clipping)
555 {
556 CLScheduler::get().enqueue(_cell_clip);
557 }
558
559 _fully_connected_output.run();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000560
561 if(_run_peephole_opt)
562 {
Georgios Pinitas42a31722018-07-09 14:35:32 +0100563 CLScheduler::get().enqueue(_pixelwise_mul_output_state1);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000564 _accum_output2.run();
565 }
566 CLScheduler::get().enqueue(_activation_output);
567
568 CLScheduler::get().enqueue(_activation_output_state);
Georgios Pinitas42a31722018-07-09 14:35:32 +0100569 CLScheduler::get().enqueue(_pixelwise_mul_output_state2);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000570
571 if(_has_projection_weights)
572 {
573 _fully_connected_output_state.run();
574 if(_perform_projection_clipping)
575 {
576 CLScheduler::get().enqueue(_projection_clip);
577 }
578 }
579
580 CLScheduler::get().enqueue(_copy_cell_state);
581 CLScheduler::get().enqueue(_copy_output);
582
583 _concat_scratch_buffer.run();
584
585 _memory_group.release();
giuros01164a2722018-11-20 18:34:46 +0000586}
John Kesapidescafec8f2019-02-19 15:53:59 +0000587
588void CLLSTMLayer::prepare()
589{
590 if(!_is_prepared)
591 {
592 CLScheduler::get().enqueue(_concat_weights_forget_gate);
593 if(!_run_cifg_opt)
594 {
595 CLScheduler::get().enqueue(_concat_weights_input_gate);
596 }
597 CLScheduler::get().enqueue(_concat_weights_output);
598 _is_prepared = true;
599 }
600}