/*
 * Copyright (c) 2018-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE
25#define ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE
26
#include "tests/Globals.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/ArithmeticOperations.h"
#include "tests/validation/reference/ConcatenateLayer.h"
#include "tests/validation/reference/FullyConnectedLayer.h"
#include "tests/validation/reference/GEMM.h"
#include "tests/validation/reference/MeanStdDevNormalizationLayer.h"
#include "tests/validation/reference/PixelWiseMultiplication.h"
#include "tests/validation/reference/Transpose.h"

#include <random>
#include <type_traits>
#include <utility>
#include <vector>
38
39namespace arm_compute
40{
41namespace test
42{
43namespace validation
44{
45template <typename TensorType, typename AccessorType, typename FunctionType, typename FunctionParams, typename T>
46class LSTMLayerValidationFixture : public framework::Fixture
47{
48public:
Michalis Spyroubcedf512018-03-22 14:55:08 +000049 void setup(TensorShape input_shape, TensorShape input_weights_shape, TensorShape recurrent_weights_shape, TensorShape cell_bias_shape, TensorShape output_cell_shape, TensorShape output_shape,
Michele Di Giorgio39438b42019-06-04 12:41:45 +010050 TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt,
51 bool use_layer_norm)
Michalis Spyroubcedf512018-03-22 14:55:08 +000052 {
53 _target = compute_target(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold,
Michele Di Giorgio39438b42019-06-04 12:41:45 +010054 data_type, projection_opt, peephole_opt, use_layer_norm);
Michalis Spyroubcedf512018-03-22 14:55:08 +000055 _reference = compute_reference(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold,
Michele Di Giorgio39438b42019-06-04 12:41:45 +010056 data_type, projection_opt, peephole_opt, use_layer_norm);
Michalis Spyroubcedf512018-03-22 14:55:08 +000057 }
58
59protected:
60 template <typename U>
61 void fill(U &&tensor, int i)
62 {
Giorgio Arena4bdd1772020-12-17 16:47:07 +000063 static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
Giorgio Arena33b103b2021-01-08 10:37:15 +000064 using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
Giorgio Arena4bdd1772020-12-17 16:47:07 +000065
66 DistributionType distribution{ T(-1.0f), T(1.0f) };
Michalis Spyroubcedf512018-03-22 14:55:08 +000067 library->fill(tensor, distribution, i);
68 }
69 template <typename U>
70 void fill_custom_val(U &&tensor, float num, int i)
71 {
Giorgio Arena4bdd1772020-12-17 16:47:07 +000072 static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
Giorgio Arena33b103b2021-01-08 10:37:15 +000073 using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
Giorgio Arena4bdd1772020-12-17 16:47:07 +000074
75 DistributionType distribution{ T(num), T(num) };
Michalis Spyroubcedf512018-03-22 14:55:08 +000076 library->fill(tensor, distribution, i);
77 }
78 TensorType compute_target(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
79 const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
Michele Di Giorgio39438b42019-06-04 12:41:45 +010080 float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm)
Michalis Spyroubcedf512018-03-22 14:55:08 +000081 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010082 const unsigned int num_cells = input_weights_shape.y();
83 const unsigned int num_outputs = recurrent_weights_shape.x();
Michalis Spyroubcedf512018-03-22 14:55:08 +000084
85 // Create tensors
86 TensorType input = create_tensor<TensorType>(input_shape, data_type);
87 TensorType input_to_forget_w = create_tensor<TensorType>(input_weights_shape, data_type);
88 TensorType input_to_cell_w = create_tensor<TensorType>(input_weights_shape, data_type);
89 TensorType input_to_output_w = create_tensor<TensorType>(input_weights_shape, data_type);
90 TensorType recurrent_to_forget_w = create_tensor<TensorType>(recurrent_weights_shape, data_type);
91 TensorType recurrent_to_cell_w = create_tensor<TensorType>(recurrent_weights_shape, data_type);
92 TensorType recurrent_to_output_w = create_tensor<TensorType>(recurrent_weights_shape, data_type);
93 TensorType forget_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
94 TensorType cell_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
95 TensorType output_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010096 TensorType output_state_in = create_tensor<TensorType>(output_shape, data_type);
97 TensorType cell_state_in = create_tensor<TensorType>(output_cell_shape, data_type);
Michalis Spyroubcedf512018-03-22 14:55:08 +000098 TensorType scratch = create_tensor<TensorType>(scratch_shape, data_type);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +010099 TensorType output_state_out = create_tensor<TensorType>(output_shape, data_type);
100 TensorType cell_state_out = create_tensor<TensorType>(output_cell_shape, data_type);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000101 TensorType output = create_tensor<TensorType>(output_shape, data_type);
102 TensorType input_to_input_w;
103 TensorType recurrent_to_input_w;
104 TensorType cell_to_input_w;
105 TensorType cell_to_forget_w;
106 TensorType input_gate_bias;
107 TensorType cell_to_output_w;
108 TensorType projection_w;
109 TensorType projection_bias;
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100110 TensorType input_layer_norm_w;
111 TensorType forget_layer_norm_w;
112 TensorType cell_layer_norm_w;
113 TensorType output_layer_norm_w;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000114
Georgios Pinitas0cc37c32018-11-14 15:54:26 +0000115 bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000116
117 FunctionParams lstm_params;
118
119 if(!cifg_opt)
120 {
121 input_to_input_w = create_tensor<TensorType>(input_weights_shape, data_type);
122 recurrent_to_input_w = create_tensor<TensorType>(recurrent_weights_shape, data_type);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100123 if(peephole_opt)
124 {
125 cell_to_input_w = create_tensor<TensorType>(cell_bias_shape, data_type);
126 }
127 input_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000128 lstm_params.set_cifg_params(&input_to_input_w, &recurrent_to_input_w, &cell_to_input_w, &input_gate_bias);
129 }
130
131 if(peephole_opt)
132 {
Michalis Spyroubcedf512018-03-22 14:55:08 +0000133 cell_to_forget_w = create_tensor<TensorType>(cell_bias_shape, data_type);
134 cell_to_output_w = create_tensor<TensorType>(cell_bias_shape, data_type);
Michalis Spyrou09daf4d2018-06-28 17:07:22 +0100135 lstm_params.set_peephole_params(&cell_to_forget_w, &cell_to_output_w);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000136 }
137
138 if(projection_opt)
139 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100140 projection_w = create_tensor<TensorType>(TensorShape(num_cells, num_outputs), data_type);
141 projection_bias = create_tensor<TensorType>(TensorShape(num_outputs), data_type);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000142 lstm_params.set_projection_params(&projection_w, &projection_bias);
143 }
144
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100145 if(use_layer_norm)
146 {
147 forget_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
148 cell_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
149 output_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
150 if(!cifg_opt)
151 {
152 input_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
153 lstm_params.set_layer_normalization_params(&input_layer_norm_w, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w);
154 }
155 else
156 {
157 lstm_params.set_layer_normalization_params(nullptr, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w);
158 }
159 }
160
Michalis Spyroubcedf512018-03-22 14:55:08 +0000161 // Create and configure function
162 FunctionType lstm;
163 lstm.configure(&input, &input_to_forget_w, &input_to_cell_w, &input_to_output_w, &recurrent_to_forget_w,
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100164 &recurrent_to_cell_w, &recurrent_to_output_w, &forget_gate_bias, &cell_bias, &output_gate_bias,
165 &output_state_in, &cell_state_in,
166 &scratch, &output_state_out, &cell_state_out, &output,
167 lstm_params, info, cell_threshold, projection_threshold);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000168
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100169 ARM_COMPUTE_ASSERT(input.info()->is_resizable());
170 ARM_COMPUTE_ASSERT(input_to_forget_w.info()->is_resizable());
171 ARM_COMPUTE_ASSERT(input_to_cell_w.info()->is_resizable());
172 ARM_COMPUTE_ASSERT(input_to_output_w.info()->is_resizable());
173 ARM_COMPUTE_ASSERT(recurrent_to_forget_w.info()->is_resizable());
174 ARM_COMPUTE_ASSERT(recurrent_to_cell_w.info()->is_resizable());
175 ARM_COMPUTE_ASSERT(recurrent_to_output_w.info()->is_resizable());
176 ARM_COMPUTE_ASSERT(forget_gate_bias.info()->is_resizable());
177 ARM_COMPUTE_ASSERT(cell_bias.info()->is_resizable());
178 ARM_COMPUTE_ASSERT(output_gate_bias.info()->is_resizable());
179 ARM_COMPUTE_ASSERT(output_state_in.info()->is_resizable());
180 ARM_COMPUTE_ASSERT(cell_state_in.info()->is_resizable());
181 ARM_COMPUTE_ASSERT(scratch.info()->is_resizable());
182 ARM_COMPUTE_ASSERT(output_state_out.info()->is_resizable());
183 ARM_COMPUTE_ASSERT(cell_state_out.info()->is_resizable());
184 ARM_COMPUTE_ASSERT(output.info()->is_resizable());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000185
186 // Allocate tensors
187 input.allocator()->allocate();
188 input_to_forget_w.allocator()->allocate();
189 input_to_cell_w.allocator()->allocate();
190 input_to_output_w.allocator()->allocate();
191 recurrent_to_forget_w.allocator()->allocate();
192 recurrent_to_cell_w.allocator()->allocate();
193 recurrent_to_output_w.allocator()->allocate();
194 forget_gate_bias.allocator()->allocate();
195 cell_bias.allocator()->allocate();
196 output_gate_bias.allocator()->allocate();
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100197 output_state_in.allocator()->allocate();
198 cell_state_in.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000199 scratch.allocator()->allocate();
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100200 output_state_out.allocator()->allocate();
201 cell_state_out.allocator()->allocate();
Michalis Spyroubcedf512018-03-22 14:55:08 +0000202 output.allocator()->allocate();
203
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100204 ARM_COMPUTE_ASSERT(!input.info()->is_resizable());
205 ARM_COMPUTE_ASSERT(!input_to_forget_w.info()->is_resizable());
206 ARM_COMPUTE_ASSERT(!input_to_cell_w.info()->is_resizable());
207 ARM_COMPUTE_ASSERT(!input_to_output_w.info()->is_resizable());
208 ARM_COMPUTE_ASSERT(!recurrent_to_forget_w.info()->is_resizable());
209 ARM_COMPUTE_ASSERT(!recurrent_to_cell_w.info()->is_resizable());
210 ARM_COMPUTE_ASSERT(!recurrent_to_output_w.info()->is_resizable());
211 ARM_COMPUTE_ASSERT(!forget_gate_bias.info()->is_resizable());
212 ARM_COMPUTE_ASSERT(!cell_bias.info()->is_resizable());
213 ARM_COMPUTE_ASSERT(!output_gate_bias.info()->is_resizable());
214 ARM_COMPUTE_ASSERT(!output_state_in.info()->is_resizable());
215 ARM_COMPUTE_ASSERT(!cell_state_in.info()->is_resizable());
216 ARM_COMPUTE_ASSERT(!scratch.info()->is_resizable());
217 ARM_COMPUTE_ASSERT(!output_state_out.info()->is_resizable());
218 ARM_COMPUTE_ASSERT(!cell_state_out.info()->is_resizable());
219 ARM_COMPUTE_ASSERT(!output.info()->is_resizable());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000220
221 // Fill tensors
222 fill(AccessorType(input), 0);
223 fill(AccessorType(input_to_forget_w), 1);
224 fill(AccessorType(input_to_cell_w), 2);
225 fill(AccessorType(input_to_output_w), 3);
226 fill(AccessorType(recurrent_to_forget_w), 4);
227 fill(AccessorType(recurrent_to_cell_w), 5);
228 fill(AccessorType(recurrent_to_output_w), 6);
229 fill(AccessorType(forget_gate_bias), 7);
230 fill(AccessorType(cell_bias), 8);
231 fill(AccessorType(output_gate_bias), 9);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100232 fill(AccessorType(output_state_in), 10);
233 fill(AccessorType(cell_state_in), 11);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000234 fill(AccessorType(scratch), 12);
235
236 if(!cifg_opt)
237 {
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100238 ARM_COMPUTE_ASSERT(input_to_input_w.info()->is_resizable());
239 ARM_COMPUTE_ASSERT(recurrent_to_input_w.info()->is_resizable());
240 ARM_COMPUTE_ASSERT(cell_to_input_w.info()->is_resizable());
241 ARM_COMPUTE_ASSERT(input_gate_bias.info()->is_resizable());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000242 input_to_input_w.allocator()->allocate();
243 recurrent_to_input_w.allocator()->allocate();
244 cell_to_input_w.allocator()->allocate();
245 input_gate_bias.allocator()->allocate();
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100246 ARM_COMPUTE_ASSERT(!input_to_input_w.info()->is_resizable());
247 ARM_COMPUTE_ASSERT(!recurrent_to_input_w.info()->is_resizable());
248 ARM_COMPUTE_ASSERT(!cell_to_input_w.info()->is_resizable());
249 ARM_COMPUTE_ASSERT(!input_gate_bias.info()->is_resizable());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000250 fill(AccessorType(input_to_input_w), 13);
251 fill(AccessorType(recurrent_to_input_w), 14);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100252 if(peephole_opt)
253 {
254 fill(AccessorType(cell_to_input_w), 15);
255 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000256 fill(AccessorType(recurrent_to_input_w), 16);
257 fill(AccessorType(input_gate_bias), 17);
258 }
259
260 if(peephole_opt)
261 {
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100262 ARM_COMPUTE_ASSERT(cell_to_forget_w.info()->is_resizable());
263 ARM_COMPUTE_ASSERT(cell_to_output_w.info()->is_resizable());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000264 cell_to_forget_w.allocator()->allocate();
265 cell_to_output_w.allocator()->allocate();
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100266 ARM_COMPUTE_ASSERT(!cell_to_forget_w.info()->is_resizable());
267 ARM_COMPUTE_ASSERT(!cell_to_output_w.info()->is_resizable());
Georgios Pinitas4f859822019-02-06 18:08:04 +0000268 fill(AccessorType(cell_to_forget_w), 18);
269 fill(AccessorType(cell_to_output_w), 19);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000270 }
271
272 if(projection_opt)
273 {
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100274 ARM_COMPUTE_ASSERT(projection_w.info()->is_resizable());
275 ARM_COMPUTE_ASSERT(projection_bias.info()->is_resizable());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000276
277 projection_w.allocator()->allocate();
278 projection_bias.allocator()->allocate();
279
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100280 ARM_COMPUTE_ASSERT(!projection_w.info()->is_resizable());
281 ARM_COMPUTE_ASSERT(!projection_bias.info()->is_resizable());
Michalis Spyroubcedf512018-03-22 14:55:08 +0000282
Georgios Pinitas4f859822019-02-06 18:08:04 +0000283 fill(AccessorType(projection_w), 20);
284 fill(AccessorType(projection_bias), 21);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000285 }
286
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100287 if(use_layer_norm)
288 {
289 if(!cifg_opt)
290 {
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100291 ARM_COMPUTE_ASSERT(input_layer_norm_w.info()->is_resizable());
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100292
293 input_layer_norm_w.allocator()->allocate();
294
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100295 ARM_COMPUTE_ASSERT(!input_layer_norm_w.info()->is_resizable());
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100296
297 fill(AccessorType(input_layer_norm_w), 22);
298 }
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100299 ARM_COMPUTE_ASSERT(forget_layer_norm_w.info()->is_resizable());
300 ARM_COMPUTE_ASSERT(cell_layer_norm_w.info()->is_resizable());
301 ARM_COMPUTE_ASSERT(output_layer_norm_w.info()->is_resizable());
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100302
303 forget_layer_norm_w.allocator()->allocate();
304 cell_layer_norm_w.allocator()->allocate();
305 output_layer_norm_w.allocator()->allocate();
306
Michele Di Giorgio4fc10b32021-04-30 18:30:41 +0100307 ARM_COMPUTE_ASSERT(!forget_layer_norm_w.info()->is_resizable());
308 ARM_COMPUTE_ASSERT(!cell_layer_norm_w.info()->is_resizable());
309 ARM_COMPUTE_ASSERT(!output_layer_norm_w.info()->is_resizable());
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100310
311 fill(AccessorType(forget_layer_norm_w), 23);
312 fill(AccessorType(cell_layer_norm_w), 24);
313 fill(AccessorType(output_layer_norm_w), 25);
314 }
315
Michalis Spyroubcedf512018-03-22 14:55:08 +0000316 // Compute function
317 lstm.run();
318
Georgios Pinitas4f859822019-02-06 18:08:04 +0000319 _target_scratch = std::move(scratch);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000320 return output;
321 }
322
323 SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
324 const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100325 float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm)
Michalis Spyroubcedf512018-03-22 14:55:08 +0000326 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100327 const unsigned int num_cells = input_weights_shape.y();
328 const unsigned int num_outputs = recurrent_weights_shape.x();
329
330 // Create projection weights shape
331 TensorShape projection_weights_shape(num_cells, num_outputs);
332
Michalis Spyroubcedf512018-03-22 14:55:08 +0000333 // Create projection bias shape
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100334 TensorShape projection_bias_shape(num_outputs);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000335
336 TensorShape gemm_shape{ 1, output_shape.y() };
337 SimpleTensor<T> gemm_out{ gemm_shape, data_type };
338
339 // Create reference
340 SimpleTensor<T> input{ input_shape, data_type };
341 SimpleTensor<T> input_to_input_w{ input_weights_shape, data_type };
342 SimpleTensor<T> input_to_forget_w{ input_weights_shape, data_type };
343 SimpleTensor<T> input_to_cell_w{ input_weights_shape, data_type };
344 SimpleTensor<T> input_to_output_w{ input_weights_shape, data_type };
345 SimpleTensor<T> recurrent_to_input_w{ recurrent_weights_shape, data_type };
346 SimpleTensor<T> recurrent_to_forget_w{ recurrent_weights_shape, data_type };
347 SimpleTensor<T> recurrent_to_cell_w{ recurrent_weights_shape, data_type };
348 SimpleTensor<T> recurrent_to_output_w{ recurrent_weights_shape, data_type };
349 SimpleTensor<T> cell_to_input_w{ cell_bias_shape, data_type };
350 SimpleTensor<T> cell_to_forget_w{ cell_bias_shape, data_type };
351 SimpleTensor<T> cell_to_output_w{ cell_bias_shape, data_type };
352 SimpleTensor<T> input_gate_bias{ cell_bias_shape, data_type };
353 SimpleTensor<T> forget_gate_bias{ cell_bias_shape, data_type };
354 SimpleTensor<T> cell_bias{ cell_bias_shape, data_type };
355 SimpleTensor<T> output_gate_bias{ cell_bias_shape, data_type };
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100356 SimpleTensor<T> projection_w{ projection_weights_shape, data_type };
Michalis Spyroubcedf512018-03-22 14:55:08 +0000357 SimpleTensor<T> projection_bias{ projection_bias_shape, data_type };
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100358 SimpleTensor<T> output_state_in{ output_shape, data_type };
359 SimpleTensor<T> cell_state_in{ output_cell_shape, data_type };
Michalis Spyroubcedf512018-03-22 14:55:08 +0000360 SimpleTensor<T> scratch{ scratch_shape, data_type };
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100361 SimpleTensor<T> output_state_out{ output_shape, data_type };
362 SimpleTensor<T> cell_state_out{ output_cell_shape, data_type };
Michalis Spyroubcedf512018-03-22 14:55:08 +0000363 SimpleTensor<T> output{ output_shape, data_type };
364
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100365 bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
366
Michalis Spyroubcedf512018-03-22 14:55:08 +0000367 // Fill reference
368 fill(input, 0);
369 fill(input_to_forget_w, 1);
370 fill(input_to_cell_w, 2);
371 fill(input_to_output_w, 3);
372 fill(recurrent_to_forget_w, 4);
373 fill(recurrent_to_cell_w, 5);
374 fill(recurrent_to_output_w, 6);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100375 if(use_layer_norm)
376 {
377 fill_custom_val(forget_gate_bias, 0.f, 7);
378 fill_custom_val(cell_bias, 0.f, 8);
379 fill_custom_val(output_gate_bias, 0.f, 9);
380 }
381 else
382 {
383 fill(forget_gate_bias, 7);
384 fill(cell_bias, 8);
385 fill(output_gate_bias, 9);
386 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100387 fill(output_state_in, 10);
388 fill(cell_state_in, 11);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000389 fill(scratch, 12);
390 fill(input_to_input_w, 13);
391 fill(recurrent_to_input_w, 14);
392 fill(cell_to_input_w, 15);
393 fill(recurrent_to_input_w, 16);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100394 if(!cifg_opt && use_layer_norm)
395 {
396 fill_custom_val(input_gate_bias, 0.f, 17);
397 }
398 else
399 {
400 fill(input_gate_bias, 17);
401 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000402 fill(cell_to_forget_w, 18);
403 fill(cell_to_output_w, 19);
404 fill(projection_w, 20);
405 fill(projection_bias, 21);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000406
Michalis Spyroubcedf512018-03-22 14:55:08 +0000407 // Compute forget_gate
408 SimpleTensor<T> fully_connected_forget = reference::fully_connected_layer(input, input_to_forget_w, forget_gate_bias, output_cell_shape);
409 SimpleTensor<T> transposed_weights = reference::transpose(recurrent_to_forget_w);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100410 SimpleTensor<T> gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100411 SimpleTensor<T> forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_forget, gemm, data_type, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000412
413 if(peephole_opt)
414 {
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100415 SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, data_type);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100416 forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000417 }
418
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100419 if(use_layer_norm)
420 {
421 SimpleTensor<T> forget_layer_norm_w{ cell_bias_shape, data_type };
422 fill(forget_layer_norm_w, 23);
423 forget_gate = reference::mean_std_normalization_layer(forget_gate);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100424 forget_gate = reference::pixel_wise_multiplication<T, T, T>(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100425 fill(forget_gate_bias, 7);
426 forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, forget_gate_bias, data_type, ConvertPolicy::SATURATE);
427 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000428 forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
429
430 // Compute input_gate
431 SimpleTensor<T> input_gate;
432 if(cifg_opt)
433 {
434 SimpleTensor<T> ones{ cell_bias_shape, data_type };
435 fill_custom_val(ones, 1.f, 0);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100436 input_gate = reference::arithmetic_operation<T>(reference::ArithmeticOperation::SUB, ones, forget_gate, data_type, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000437 }
438 else
439 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100440 SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
441 transposed_weights = reference::transpose(recurrent_to_input_w);
442 gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100443 input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100444 if(peephole_opt)
445 {
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100446 SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100447 input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100448 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100449 if(use_layer_norm)
450 {
451 SimpleTensor<T> input_layer_norm_w{ cell_bias_shape, data_type };
452 fill(input_layer_norm_w, 22);
453 input_gate = reference::mean_std_normalization_layer(input_gate);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100454 input_gate = reference::pixel_wise_multiplication<T, T, T>(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100455 fill(input_gate_bias, 17);
456 input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, input_gate_bias, data_type, ConvertPolicy::SATURATE);
457 }
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100458 input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000459 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000460 // Compute cell_state
461 SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape);
462 transposed_weights = reference::transpose(recurrent_to_cell_w);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100463 gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100464 SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100465 cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100466 if(use_layer_norm)
467 {
468 SimpleTensor<T> cell_layer_norm_w{ cell_bias_shape, data_type };
469 fill(cell_layer_norm_w, 24);
470 cell_state_out = reference::mean_std_normalization_layer(cell_state_out);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100471 cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100472 fill(cell_bias, 8);
473 cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE);
474 }
Pablo Marquez Tello9454cf72022-02-16 11:15:58 +0000475 cell_state_out = reference::activation_layer(cell_state_out, info);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100476 cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100477 cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
Pablo Marquez Tello9454cf72022-02-16 11:15:58 +0000478
Michalis Spyroubcedf512018-03-22 14:55:08 +0000479 if(cell_threshold != 0.f)
480 {
Pablo Marquez Tello9454cf72022-02-16 11:15:58 +0000481 cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000482 }
483
484 // Compute output
485 SimpleTensor<T> fully_connected_output = reference::fully_connected_layer(input, input_to_output_w, output_gate_bias, output_cell_shape);
486 transposed_weights = reference::transpose(recurrent_to_output_w);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100487 gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100488 output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000489 if(peephole_opt)
490 {
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100491 pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
Georgios Pinitascbf39c62018-09-10 15:07:45 +0100492 output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000493 }
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100494 if(use_layer_norm)
495 {
496 SimpleTensor<T> output_layer_norm_w{ cell_bias_shape, data_type };
497 fill(output_layer_norm_w, 25);
498 output = reference::mean_std_normalization_layer(output);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100499 output = reference::pixel_wise_multiplication<T, T, T>(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
Michele Di Giorgio39438b42019-06-04 12:41:45 +0100500 fill(output_gate_bias, 9);
501 output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, output_gate_bias, data_type, ConvertPolicy::SATURATE);
502 }
Michalis Spyroubcedf512018-03-22 14:55:08 +0000503 output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
504
505 // Compute output state
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100506 SimpleTensor<T> cell_state_activation = reference::activation_layer(cell_state_out, info);
Michele Di Giorgio9428a182020-03-30 14:10:20 +0100507 output_state_out = reference::pixel_wise_multiplication<T, T, T>(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000508
509 if(projection_opt)
510 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100511 SimpleTensor<T> fully_connected_projection = reference::fully_connected_layer(output_state_out, projection_w, projection_bias, output_cell_shape);
Michalis Spyroubcedf512018-03-22 14:55:08 +0000512 if(projection_threshold != 0.f)
513 {
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100514 output_state_out = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
Michalis Spyroubcedf512018-03-22 14:55:08 +0000515 }
516 }
Georgios Pinitas4f859822019-02-06 18:08:04 +0000517 std::vector<SimpleTensor<T>> scratch_inputs;
518 if(!cifg_opt)
519 {
520 scratch_inputs.emplace_back(std::move(input_gate));
521 }
522 scratch_inputs.emplace_back(std::move(cell_state_out));
523 scratch_inputs.emplace_back(std::move(forget_gate));
524 scratch_inputs.emplace_back(std::move(output));
Pablo Tello3dd5b682019-03-04 14:14:02 +0000525 scratch = reference::concatenate_layer(scratch_inputs, scratch, Window::DimX);
Georgios Pinitas4f859822019-02-06 18:08:04 +0000526 _reference_scratch = std::move(scratch);
Georgios Pinitas8bc745d2018-07-18 19:51:24 +0100527 return output_state_out;
Michalis Spyroubcedf512018-03-22 14:55:08 +0000528 }
529
530 TensorType _target{};
Georgios Pinitas4f859822019-02-06 18:08:04 +0000531 TensorType _target_scratch{};
Michalis Spyroubcedf512018-03-22 14:55:08 +0000532 SimpleTensor<T> _reference{};
Georgios Pinitas4f859822019-02-06 18:08:04 +0000533 SimpleTensor<T> _reference_scratch{};
Michalis Spyroubcedf512018-03-22 14:55:08 +0000534};
535} // namespace validation
536} // namespace test
537} // namespace arm_compute
538#endif /* ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE */