blob: 4ff517509d61c98d871b19225deef85087666ec9 [file] [log] [blame]
Mike Kelly8ae17b32021-02-17 13:45:50 +00001//
Ryan OShea238ecd92023-03-07 11:44:23 +00002// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
Mike Kelly8ae17b32021-02-17 13:45:50 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010011#include <DelegateTestInterpreter.hpp>
Mike Kelly8ae17b32021-02-17 13:45:50 +000012
13#include <flatbuffers/flatbuffers.h>
Mike Kelly8ae17b32021-02-17 13:45:50 +000014#include <tensorflow/lite/kernels/register.h>
Mike Kelly8ae17b32021-02-17 13:45:50 +000015#include <tensorflow/lite/version.h>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010016
17#include <schema_generated.h>
Mike Kelly8ae17b32021-02-17 13:45:50 +000018
19#include <doctest/doctest.h>
20
21namespace
22{
23
24template <typename T>
25std::vector<char> CreateLstmTfLiteModel(tflite::TensorType tensorType,
26 int32_t batchSize,
27 int32_t inputSize,
28 int32_t outputSize,
29 int32_t numUnits,
30 bool hasInputToInputWeights,
31 const std::vector<T>& inputToInputWeights,
32 const std::vector<T>& inputToForgetWeights,
33 const std::vector<T>& inputToCellWeights,
34 const std::vector<T>& inputToOutputWeights,
35 bool hasRecurrentToInputWeights,
36 const std::vector<T>& recurrentToInputWeights,
37 const std::vector<T>& recurrentToForgetWeights,
38 const std::vector<T>& recurrentToCellWeights,
39 const std::vector<T>& recurrentToOutputWeights,
40 bool hasCellToInputWeights,
41 const std::vector<T>& cellToInputWeights,
42 bool hasCellToForgetWeights,
43 const std::vector<T>& cellToForgetWeights,
44 bool hasCellToOutputWeights,
45 const std::vector<T>& cellToOutputWeights,
46 bool hasInputGateBias,
47 const std::vector<T>& inputGateBias,
48 const std::vector<T>& forgetGateBias,
49 const std::vector<T>& cellBias,
50 const std::vector<T>& outputGateBias,
51 bool hasProjectionWeights,
52 const std::vector<T>& projectionWeights,
53 bool hasProjectionBias,
54 const std::vector<T>& projectionBias,
55 bool hasInputLayerNormWeights,
56 const std::vector<T>& inputLayerNormWeights,
57 bool hasForgetLayerNormWeights,
58 const std::vector<T>& forgetLayerNormWeights,
59 bool hasCellLayerNormWeights,
60 const std::vector<T>& cellLayerNormWeights,
61 bool hasOutputLayerNormWeights,
62 const std::vector<T>& outputLayerNormWeights,
63 tflite::ActivationFunctionType activationFunction,
64 float clippingThresCell,
65 float clippingThresProj,
66 float quantScale = 1.0f,
67 int quantOffset = 0,
68 float outputQuantScale = 2.0f,
69 int outputQuantOffset = 0)
70{
71
72 std::vector <int32_t> tensorInfo0 {};
73 std::vector <int32_t> tensorInfo4 {numUnits};
74 std::vector <int32_t> tensorInfo8 {numUnits, static_cast<int32_t>(2)};
75 std::vector <int32_t> tensorInfo16 {numUnits, static_cast<int32_t>(4)};
76
77 std::vector<int32_t> inputShape {batchSize , inputSize};
78 std::vector<int32_t> outputShape {batchSize , outputSize};
79
80 std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
81 std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};
82
83 std::vector<int> operatorInputs;
84 using namespace tflite;
85 flatbuffers::FlatBufferBuilder flatBufferBuilder;
86 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
87 std::vector<flatbuffers::Offset<Tensor>> tensors;
88
89 auto quantizationParameters =
90 CreateQuantizationParameters(flatBufferBuilder,
91 0,
92 0,
93 flatBufferBuilder.CreateVector<float>({ quantScale }),
94 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
95
96 auto outputQuantizationParameters =
97 CreateQuantizationParameters(flatBufferBuilder,
98 0,
99 0,
100 flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
101 flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));
102
Ryan OShea238ecd92023-03-07 11:44:23 +0000103 buffers.push_back(CreateBuffer(flatBufferBuilder));
Mike Kelly8ae17b32021-02-17 13:45:50 +0000104 tensors.push_back(CreateTensor(flatBufferBuilder,
105 flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
106 inputShape.size()),
107 tensorType,
108 buffers.size() - 1,
109 flatBufferBuilder.CreateString("input_0"),
110 quantizationParameters));
111 operatorInputs.push_back(buffers.size() - 1);
112
113 if (hasInputToInputWeights)
114 {
115 buffers.push_back(
116 CreateBuffer(flatBufferBuilder,
117 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToInputWeights.data()),
118 sizeof(T) * inputToInputWeights.size())));
119 tensors.push_back(CreateTensor(flatBufferBuilder,
120 flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
121 tensorInfo8.size()),
122 tensorType,
123 buffers.size() - 1,
124 flatBufferBuilder.CreateString("inputToInputWeights"),
125 outputQuantizationParameters));
126 operatorInputs.push_back(buffers.size() - 1);
127 }
128 else
129 {
130 operatorInputs.push_back(kTfLiteOptionalTensor);
131 }
132
133 buffers.push_back(
134 CreateBuffer(flatBufferBuilder,
135 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToForgetWeights.data()),
136 sizeof(T) * inputToForgetWeights.size())));
137 tensors.push_back(CreateTensor(flatBufferBuilder,
138 flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
139 tensorInfo8.size()),
140 tensorType,
141 buffers.size() - 1,
142 flatBufferBuilder.CreateString("inputToForgetWeights"),
143 outputQuantizationParameters));
144 operatorInputs.push_back(buffers.size() - 1);
145
146 buffers.push_back(
147 CreateBuffer(flatBufferBuilder,
148 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToCellWeights.data()),
149 sizeof(T) * inputToCellWeights.size())));
150 tensors.push_back(CreateTensor(flatBufferBuilder,
151 flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
152 tensorInfo8.size()),
153 tensorType,
154 buffers.size() - 1,
155 flatBufferBuilder.CreateString("inputToCellWeights"),
156 outputQuantizationParameters));
157 operatorInputs.push_back(buffers.size() - 1);
158
159 buffers.push_back(
160 CreateBuffer(flatBufferBuilder,
161 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(inputToOutputWeights.data()),
162 sizeof(T) * inputToOutputWeights.size())));
163 tensors.push_back(CreateTensor(flatBufferBuilder,
164 flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
165 tensorInfo8.size()),
166 tensorType,
167 buffers.size() - 1,
168 flatBufferBuilder.CreateString("inputToOutputWeights"),
169 outputQuantizationParameters));
170 operatorInputs.push_back(buffers.size() - 1);
171
172 if (hasRecurrentToInputWeights)
173 {
174 buffers.push_back(CreateBuffer(
175 flatBufferBuilder,
176 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
177 sizeof(T) * recurrentToInputWeights.size())));
178 tensors.push_back(CreateTensor(flatBufferBuilder,
179 flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
180 tensorInfo16.size()),
181 tensorType,
182 buffers.size() - 1,
183 flatBufferBuilder.CreateString("recurrentToInputWeights"),
184 outputQuantizationParameters));
185 operatorInputs.push_back(buffers.size() - 1);
186 }
187 else
188 {
189 operatorInputs.push_back(kTfLiteOptionalTensor);
190 }
191
192 buffers.push_back(
193 CreateBuffer(flatBufferBuilder,
194 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(recurrentToForgetWeights.data()),
195 sizeof(T) * recurrentToForgetWeights.size())));
196 tensors.push_back(CreateTensor(flatBufferBuilder,
197 flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
198 tensorInfo16.size()),
199 tensorType,
200 buffers.size() - 1,
201 flatBufferBuilder.CreateString("recurrentToForgetWeights"),
202 outputQuantizationParameters));
203 operatorInputs.push_back(buffers.size() - 1);
204
205 buffers.push_back(
206 CreateBuffer(flatBufferBuilder,
207 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(recurrentToCellWeights.data()),
208 sizeof(T) * recurrentToCellWeights.size())));
209 tensors.push_back(CreateTensor(flatBufferBuilder,
210 flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
211 tensorInfo16.size()),
212 tensorType,
213 buffers.size() - 1,
214 flatBufferBuilder.CreateString("recurrentToCellWeights"),
215 outputQuantizationParameters));
216 operatorInputs.push_back(buffers.size() - 1);
217
218 buffers.push_back(
219 CreateBuffer(flatBufferBuilder,
220 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(recurrentToOutputWeights.data()),
221 sizeof(T) * recurrentToOutputWeights.size())));
222 tensors.push_back(CreateTensor(flatBufferBuilder,
223 flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
224 tensorInfo16.size()),
225 tensorType,
226 buffers.size() - 1 ,
227 flatBufferBuilder.CreateString("recurrentToOutputWeights"),
228 outputQuantizationParameters));
229 operatorInputs.push_back(buffers.size() - 1);
230
231 if (hasCellToInputWeights)
232 {
233 buffers.push_back(
234 CreateBuffer(flatBufferBuilder,
235 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToInputWeights.data()),
236 sizeof(T) * cellToInputWeights.size())));
237 tensors.push_back(CreateTensor(flatBufferBuilder,
238 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
239 tensorInfo4.size()),
240 tensorType,
241 buffers.size() - 1,
242 flatBufferBuilder.CreateString("cellToInputWeights"),
243 outputQuantizationParameters));
244 operatorInputs.push_back(buffers.size() - 1);
245 }
246 else
247 {
248 operatorInputs.push_back(kTfLiteOptionalTensor);
249 }
250
251 if (hasCellToForgetWeights)
252 {
253 buffers.push_back(
254 CreateBuffer(flatBufferBuilder,
255 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToForgetWeights.data()),
256 sizeof(T) * cellToForgetWeights.size())));
257 tensors.push_back(CreateTensor(flatBufferBuilder,
258 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
259 tensorInfo4.size()),
260 tensorType,
261 buffers.size() - 1,
262 flatBufferBuilder.CreateString("cellToForgetWeights"),
263 outputQuantizationParameters));
264 operatorInputs.push_back(buffers.size() - 1);
265 }
266 else
267 {
268 operatorInputs.push_back(kTfLiteOptionalTensor);
269 }
270
271 if (hasCellToOutputWeights)
272 {
273 buffers.push_back(
274 CreateBuffer(flatBufferBuilder,
275 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToOutputWeights.data()),
276 sizeof(T) * cellToOutputWeights.size())));
277 tensors.push_back(CreateTensor(flatBufferBuilder,
278 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
279 tensorInfo4.size()),
280 tensorType,
281 buffers.size() - 1,
282 flatBufferBuilder.CreateString("cellToOutputWeights"),
283 outputQuantizationParameters));
284 operatorInputs.push_back(buffers.size() - 1);
285 }
286 else
287 {
288 operatorInputs.push_back(kTfLiteOptionalTensor);
289 }
290
291 if (hasInputGateBias)
292 {
293 buffers.push_back(
294 CreateBuffer(flatBufferBuilder,
295 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
296 sizeof(T) * inputGateBias.size())));
297 tensors.push_back(CreateTensor(flatBufferBuilder,
298 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
299 tensorInfo4.size()),
300 tensorType,
301 buffers.size() - 1,
302 flatBufferBuilder.CreateString("inputGateBias"),
303 outputQuantizationParameters));
304 operatorInputs.push_back(buffers.size() - 1);
305 }
306 else
307 {
308 operatorInputs.push_back(kTfLiteOptionalTensor);
309 }
310
311 buffers.push_back(
312 CreateBuffer(flatBufferBuilder,
313 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(forgetGateBias.data()),
314 sizeof(T) * forgetGateBias.size())));
315 tensors.push_back(CreateTensor(flatBufferBuilder,
316 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
317 tensorInfo4.size()),
318 tensorType,
319 buffers.size() - 1,
320 flatBufferBuilder.CreateString("forgetGateBias"),
321 outputQuantizationParameters));
322 operatorInputs.push_back(buffers.size() - 1);
323
324 buffers.push_back(
325 CreateBuffer(flatBufferBuilder,
326 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(cellBias.data()),
327 sizeof(T) * cellBias.size())));
328 tensors.push_back(CreateTensor(flatBufferBuilder,
329 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
330 tensorInfo4.size()),
331 tensorType,
332 buffers.size() - 1,
333 flatBufferBuilder.CreateString("cellBias"),
334 outputQuantizationParameters));
335 operatorInputs.push_back(buffers.size() - 1);
336
337 buffers.push_back(
338 CreateBuffer(flatBufferBuilder,
339 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(outputGateBias.data()),
340 sizeof(T) * outputGateBias.size())));
341 tensors.push_back(CreateTensor(flatBufferBuilder,
342 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
343 tensorInfo4.size()),
344 tensorType,
345 buffers.size() - 1,
346 flatBufferBuilder.CreateString("outputGateBias"),
347 outputQuantizationParameters));
348 operatorInputs.push_back(buffers.size() - 1);
349
350 if (hasProjectionWeights)
351 {
352 buffers.push_back(
353 CreateBuffer(flatBufferBuilder,
354 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(projectionWeights.data()),
355 sizeof(T) * projectionWeights.size())));
356 tensors.push_back(CreateTensor(flatBufferBuilder,
357 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
358 tensorInfo4.size()),
359 tensorType,
360 buffers.size() - 1,
361 flatBufferBuilder.CreateString("outputGateBias"),
362 outputQuantizationParameters));
363 operatorInputs.push_back(buffers.size() - 1);
364 }
365 else
366 {
367 operatorInputs.push_back(kTfLiteOptionalTensor);
368 }
369
370 if (hasProjectionBias)
371 {
372 buffers.push_back(
373 CreateBuffer(flatBufferBuilder,
374 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(projectionBias.data()),
375 sizeof(T) * projectionBias.size())));
376 tensors.push_back(CreateTensor(flatBufferBuilder,
377 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
378 tensorInfo4.size()),
379 tensorType,
380 buffers.size() - 1,
381 flatBufferBuilder.CreateString("projectionBias"),
382 outputQuantizationParameters));
383 operatorInputs.push_back(buffers.size() - 1);
384 }
385 else
386 {
387 operatorInputs.push_back(kTfLiteOptionalTensor);
388 }
389
Ryan OShea238ecd92023-03-07 11:44:23 +0000390 buffers.push_back(CreateBuffer(flatBufferBuilder));
Mike Kelly8ae17b32021-02-17 13:45:50 +0000391 tensors.push_back(CreateTensor(flatBufferBuilder,
392 flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
393 outputStateInDimensions.size()),
394 tensorType,
395 buffers.size() - 1,
396 flatBufferBuilder.CreateString("outputStateInInfo"),
397 outputQuantizationParameters,
398 true));
399 operatorInputs.push_back(buffers.size() - 1);
400
Ryan OShea238ecd92023-03-07 11:44:23 +0000401 buffers.push_back(CreateBuffer(flatBufferBuilder));
Mike Kelly8ae17b32021-02-17 13:45:50 +0000402 tensors.push_back(CreateTensor(flatBufferBuilder,
403 flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
404 cellStateInDimensions.size()),
405 tensorType,
406 buffers.size() - 1,
407 flatBufferBuilder.CreateString("cellStateInInfo"),
408 outputQuantizationParameters,
409 true));
410 operatorInputs.push_back(buffers.size() - 1);
411
412 if (hasInputLayerNormWeights)
413 {
414 buffers.push_back(
415 CreateBuffer(flatBufferBuilder,
416 flatBufferBuilder.CreateVector(
417 reinterpret_cast<const uint8_t *>(inputLayerNormWeights.data()),
418 sizeof(T) * inputLayerNormWeights.size())));
419 tensors.push_back(CreateTensor(flatBufferBuilder,
420 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
421 tensorInfo4.size()),
422 tensorType,
423 buffers.size() - 1,
424 flatBufferBuilder.CreateString("inputLayerNormWeights"),
425 outputQuantizationParameters));
426 operatorInputs.push_back(buffers.size() - 1);
427 }
428 else
429 {
430 operatorInputs.push_back(kTfLiteOptionalTensor);
431 }
432
433 if (hasForgetLayerNormWeights)
434 {
435 buffers.push_back(
436 CreateBuffer(flatBufferBuilder,
437 flatBufferBuilder.CreateVector(
438 reinterpret_cast<const uint8_t *>(forgetLayerNormWeights.data()),
439 sizeof(T) * forgetLayerNormWeights.size())));
440 tensors.push_back(CreateTensor(flatBufferBuilder,
441 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
442 tensorInfo4.size()),
443 tensorType,
444 buffers.size() - 1,
445 flatBufferBuilder.CreateString("forgetLayerNormWeights"),
446 outputQuantizationParameters));
447 operatorInputs.push_back(buffers.size() - 1);
448 }
449 else
450 {
451 operatorInputs.push_back(kTfLiteOptionalTensor);
452 }
453
454 if (hasCellLayerNormWeights)
455 {
456 buffers.push_back(
457 CreateBuffer(flatBufferBuilder,
458 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t *>(cellLayerNormWeights.data()),
459 sizeof(T) * cellLayerNormWeights.size())));
460 tensors.push_back(CreateTensor(flatBufferBuilder,
461 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
462 tensorInfo4.size()),
463 tensorType,
464 buffers.size() - 1,
465 flatBufferBuilder.CreateString("cellLayerNormWeights"),
466 outputQuantizationParameters));
467 operatorInputs.push_back(buffers.size() - 1);
468 }
469 else
470 {
471 operatorInputs.push_back(kTfLiteOptionalTensor);
472 }
473
474 if (hasOutputLayerNormWeights)
475 {
476 buffers.push_back(
477 CreateBuffer(flatBufferBuilder,
478 flatBufferBuilder.CreateVector(
479 reinterpret_cast<const uint8_t *>(outputLayerNormWeights.data()),
480 sizeof(T) * outputLayerNormWeights.size())));
481 tensors.push_back(CreateTensor(flatBufferBuilder,
482 flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
483 tensorInfo4.size()),
484 tensorType,
485 buffers.size() - 1,
486 flatBufferBuilder.CreateString("outputLayerNormWeights"),
487 outputQuantizationParameters));
488 operatorInputs.push_back(buffers.size() - 1);
489 }
490 else
491 {
492 operatorInputs.push_back(kTfLiteOptionalTensor);
493 }
494 int outputBufferId = buffers.size();
Ryan OShea238ecd92023-03-07 11:44:23 +0000495 buffers.push_back(CreateBuffer(flatBufferBuilder));
Mike Kelly8ae17b32021-02-17 13:45:50 +0000496 tensors.push_back(CreateTensor(flatBufferBuilder,
497 flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
498 outputShape.size()),
499 tensorType,
500 outputBufferId,
501 flatBufferBuilder.CreateString("output"),
502 outputQuantizationParameters));
503 std::vector<int> operatorOutputs;
504 operatorOutputs.push_back(buffers.size() - 1);
505
506 // create operator
507 tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_LSTMOptions;
508 flatbuffers::Offset<void> operatorBuiltinOptions =
509 CreateLSTMOptions(flatBufferBuilder,
510 activationFunction,
511 clippingThresCell,
512 clippingThresProj).Union();
513
514 flatbuffers::Offset <Operator> lstmOperator =
515 CreateOperator(flatBufferBuilder,
516 0,
517 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
518 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
519 operatorBuiltinOptionsType, operatorBuiltinOptions);
520
521 flatbuffers::Offset <SubGraph> subgraph =
522 CreateSubGraph(flatBufferBuilder,
523 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
524 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
525 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
526 flatBufferBuilder.CreateVector(&lstmOperator, 1));
527
528 flatbuffers::Offset <flatbuffers::String> modelDescription =
529 flatBufferBuilder.CreateString("ArmnnDelegate: LSTM Operator Model");
530 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
531 tflite::BuiltinOperator_LSTM);
532
533 flatbuffers::Offset <Model> flatbufferModel =
534 CreateModel(flatBufferBuilder,
535 TFLITE_SCHEMA_VERSION,
536 flatBufferBuilder.CreateVector(&operatorCode, 1),
537 flatBufferBuilder.CreateVector(&subgraph, 1),
538 modelDescription,
539 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
540
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100541 flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
Mike Kelly8ae17b32021-02-17 13:45:50 +0000542
543 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
544 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
545}
546
/// Round-trips a single-operator LSTM model through the reference TfLite
/// runtime and the Arm NN delegate, and compares the two results.
///
/// Builds the model from the supplied weights/biases (the 'has*' flags select
/// the CIFG/peephole/projection/layer-norm variants), runs it once with plain
/// TfLite and once with the Arm NN delegate applied to 'backends', then checks
/// both outputs against 'expectedOutputValues' and the expected
/// {batchSize, outputSize} output shape. Failures are reported via doctest
/// CHECK macros.
template <typename T>
void LstmTestImpl(std::vector<armnn::BackendId>& backends,
                  tflite::TensorType tensorType,
                  int32_t batchSize,
                  int32_t inputSize,
                  int32_t outputSize,
                  int32_t numUnits,
                  bool hasInputToInputWeights,
                  const std::vector<T>& inputToInputWeights,
                  const std::vector<T>& inputToForgetWeights,
                  const std::vector<T>& inputToCellWeights,
                  const std::vector<T>& inputToOutputWeights,
                  bool hasRecurrentToInputWeights,
                  const std::vector<T>& recurrentToInputWeights,
                  const std::vector<T>& recurrentToForgetWeights,
                  const std::vector<T>& recurrentToCellWeights,
                  const std::vector<T>& recurrentToOutputWeights,
                  bool hasCellToInputWeights,
                  const std::vector<T>& cellToInputWeights,
                  bool hasCellToForgetWeights,
                  const std::vector<T>& cellToForgetWeights,
                  bool hasCellToOutputWeights,
                  const std::vector<T>& cellToOutputWeights,
                  bool hasInputGateBias,
                  const std::vector<T>& inputGateBias,
                  const std::vector<T>& forgetGateBias,
                  const std::vector<T>& cellBias,
                  const std::vector<T>& outputGateBias,
                  bool hasProjectionWeights,
                  const std::vector<T>& projectionWeights,
                  bool hasProjectionBias,
                  const std::vector<T>& projectionBias,
                  bool hasInputLayerNormWeights,
                  const std::vector<T>& inputLayerNormWeights,
                  bool hasForgetLayerNormWeights,
                  const std::vector<T>& forgetLayerNormWeights,
                  bool hasCellLayerNormWeights,
                  const std::vector<T>& cellLayerNormWeights,
                  bool hasOutputLayerNormWeights,
                  const std::vector<T>& outputLayerNormWeights,
                  std::vector<T>& inputValues,
                  std::vector<T>& expectedOutputValues,
                  tflite::ActivationFunctionType activationFunction,
                  float clippingThresCell,
                  float clippingThresProj)
{
    using namespace delegateTestInterpreter;

    // Serialise the model once; both interpreters below load the same bytes,
    // so any output divergence is attributable to the delegate.
    // NOTE: argument order here must match CreateLstmTfLiteModel exactly —
    // all 36 model parameters are positional.
    std::vector<char> modelBuffer = CreateLstmTfLiteModel(tensorType,
                                                          batchSize,
                                                          inputSize,
                                                          outputSize,
                                                          numUnits,
                                                          hasInputToInputWeights,
                                                          inputToInputWeights,
                                                          inputToForgetWeights,
                                                          inputToCellWeights,
                                                          inputToOutputWeights,
                                                          hasRecurrentToInputWeights,
                                                          recurrentToInputWeights,
                                                          recurrentToForgetWeights,
                                                          recurrentToCellWeights,
                                                          recurrentToOutputWeights,
                                                          hasCellToInputWeights,
                                                          cellToInputWeights,
                                                          hasCellToForgetWeights,
                                                          cellToForgetWeights,
                                                          hasCellToOutputWeights,
                                                          cellToOutputWeights,
                                                          hasInputGateBias,
                                                          inputGateBias,
                                                          forgetGateBias,
                                                          cellBias,
                                                          outputGateBias,
                                                          hasProjectionWeights,
                                                          projectionWeights,
                                                          hasProjectionBias,
                                                          projectionBias,
                                                          hasInputLayerNormWeights,
                                                          inputLayerNormWeights,
                                                          hasForgetLayerNormWeights,
                                                          forgetLayerNormWeights,
                                                          hasCellLayerNormWeights,
                                                          cellLayerNormWeights,
                                                          hasOutputLayerNormWeights,
                                                          outputLayerNormWeights,
                                                          activationFunction,
                                                          clippingThresCell,
                                                          clippingThresProj);

    std::vector<int32_t> expectedOutputShape {batchSize , outputSize};

    // Setup interpreter with just TFLite Runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);

    // Setup interpreter with Arm NN Delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);

    // Compare both runs against each other and against the expected reference
    // values/shape.
    armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
}
661
662} // anonymous namespace