//
// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <doctest/doctest.h>

namespace
{

template <typename T>
std::vector<char> CreateLstmTfLiteModel(tflite::TensorType tensorType,
                                        int32_t batchSize,
                                        int32_t inputSize,
                                        int32_t outputSize,
                                        int32_t numUnits,
                                        bool hasInputToInputWeights,
                                        const std::vector<T>& inputToInputWeights,
                                        const std::vector<T>& inputToForgetWeights,
                                        const std::vector<T>& inputToCellWeights,
                                        const std::vector<T>& inputToOutputWeights,
                                        bool hasRecurrentToInputWeights,
                                        const std::vector<T>& recurrentToInputWeights,
                                        const std::vector<T>& recurrentToForgetWeights,
                                        const std::vector<T>& recurrentToCellWeights,
                                        const std::vector<T>& recurrentToOutputWeights,
                                        bool hasCellToInputWeights,
                                        const std::vector<T>& cellToInputWeights,
                                        bool hasCellToForgetWeights,
                                        const std::vector<T>& cellToForgetWeights,
                                        bool hasCellToOutputWeights,
                                        const std::vector<T>& cellToOutputWeights,
                                        bool hasInputGateBias,
                                        const std::vector<T>& inputGateBias,
                                        const std::vector<T>& forgetGateBias,
                                        const std::vector<T>& cellBias,
                                        const std::vector<T>& outputGateBias,
                                        bool hasProjectionWeights,
                                        const std::vector<T>& projectionWeights,
                                        bool hasProjectionBias,
                                        const std::vector<T>& projectionBias,
                                        bool hasInputLayerNormWeights,
                                        const std::vector<T>& inputLayerNormWeights,
                                        bool hasForgetLayerNormWeights,
                                        const std::vector<T>& forgetLayerNormWeights,
                                        bool hasCellLayerNormWeights,
                                        const std::vector<T>& cellLayerNormWeights,
                                        bool hasOutputLayerNormWeights,
                                        const std::vector<T>& outputLayerNormWeights,
                                        tflite::ActivationFunctionType activationFunction,
                                        float clippingThresCell,
                                        float clippingThresProj,
                                        float quantScale = 1.0f,
                                        int quantOffset = 0,
                                        float outputQuantScale = 2.0f,
                                        int outputQuantOffset = 0)
{
    // Weight and bias tensor shapes. The 8- and 16-element shapes hardcode an
    // input size of 2 and an output size of 4, matching the accompanying tests.
    std::vector<int32_t> tensorInfo0 {};
    std::vector<int32_t> tensorInfo4 {numUnits};
    std::vector<int32_t> tensorInfo8 {numUnits, static_cast<int32_t>(2)};
    std::vector<int32_t> tensorInfo16 {numUnits, static_cast<int32_t>(4)};

    std::vector<int32_t> inputShape {batchSize, inputSize};
    std::vector<int32_t> outputShape {batchSize, outputSize};

    std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
    std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};

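    // The tensors below are pushed in the order the TFLite LSTM kernel expects
    // its operator inputs: input (0), input-to-gate weights (1-4), recurrent-to-gate
    // weights (5-8), peephole cell-to-gate weights (9-11), gate biases (12-15),
    // projection weights/bias (16-17), output and cell state (18-19), and the
    // optional layer normalisation coefficients (20-23). Optional inputs that are
    // absent are marked with kTfLiteOptionalTensor (-1).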
    std::vector<int> operatorInputs;
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;
    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
    std::vector<flatbuffers::Offset<Tensor>> tensors;

    auto quantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    auto outputQuantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));

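    // Two quantisation parameter sets are created up front: 'quantizationParameters'
    // (quantScale/quantOffset) annotates the network input, while
    // 'outputQuantizationParameters' (outputQuantScale/outputQuantOffset) is reused
    // for the constant weight/bias tensors, the state tensors and the output.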
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
                                                                           inputShape.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("input_0"),
                                   quantizationParameters));
    // Buffers and tensors are pushed in lockstep, so 'buffers.size() - 1' is also the
    // index of the tensor just created and is used as the operator input index.
    operatorInputs.push_back(buffers.size() - 1);

    if (hasInputToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToInputWeights.data()),
                                                        sizeof(T) * inputToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                               tensorInfo8.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputToInputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

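    // The buffer/tensor/operator-input pattern below repeats for every weight and
    // bias. A hypothetical helper (a sketch only, not used by this file) could
    // factor it out:
    //
    //     auto addConstTensor = [&](const std::vector<T>& data,
    //                               const std::vector<int32_t>& shape,
    //                               const char* name)
    //     {
    //         buffers.push_back(CreateBuffer(
    //             flatBufferBuilder,
    //             flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
    //                                            sizeof(T) * data.size())));
    //         tensors.push_back(CreateTensor(flatBufferBuilder,
    //                                        flatBufferBuilder.CreateVector<int32_t>(shape.data(),
    //                                                                                shape.size()),
    //                                        tensorType,
    //                                        buffers.size() - 1,
    //                                        flatBufferBuilder.CreateString(name),
    //                                        outputQuantizationParameters));
    //         operatorInputs.push_back(buffers.size() - 1);
    //     };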
    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToForgetWeights.data()),
                                                    sizeof(T) * inputToForgetWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                           tensorInfo8.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToForgetWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToCellWeights.data()),
                                                    sizeof(T) * inputToCellWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                           tensorInfo8.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToCellWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToOutputWeights.data()),
                                                    sizeof(T) * inputToOutputWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                           tensorInfo8.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToOutputWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    if (hasRecurrentToInputWeights)
    {
        buffers.push_back(CreateBuffer(
            flatBufferBuilder,
            flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
                                           sizeof(T) * recurrentToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                               tensorInfo16.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("recurrentToInputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToForgetWeights.data()),
                                                    sizeof(T) * recurrentToForgetWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                           tensorInfo16.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToForgetWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToCellWeights.data()),
                                                    sizeof(T) * recurrentToCellWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                           tensorInfo16.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToCellWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToOutputWeights.data()),
                                                    sizeof(T) * recurrentToOutputWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                           tensorInfo16.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToOutputWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    if (hasCellToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToInputWeights.data()),
                                                        sizeof(T) * cellToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToInputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellToForgetWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToForgetWeights.data()),
                                                        sizeof(T) * cellToForgetWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToForgetWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellToOutputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToOutputWeights.data()),
                                                        sizeof(T) * cellToOutputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToOutputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasInputGateBias)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
                                                        sizeof(T) * inputGateBias.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputGateBias"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetGateBias.data()),
                                                    sizeof(T) * forgetGateBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                           tensorInfo4.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("forgetGateBias"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellBias.data()),
                                                    sizeof(T) * cellBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                           tensorInfo4.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("cellBias"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputGateBias.data()),
                                                    sizeof(T) * outputGateBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                           tensorInfo4.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("outputGateBias"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    if (hasProjectionWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(projectionWeights.data()),
                                                        sizeof(T) * projectionWeights.size())));
        // Note: this reuses tensorInfo4, although the TFLite LSTM kernel nominally
        // expects projection weights of shape {outputSize, numUnits}.
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("projectionWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasProjectionBias)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(projectionBias.data()),
                                                        sizeof(T) * projectionBias.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("projectionBias"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

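    // The two state tensors are created with is_variable = true: the TFLite runtime
    // treats them as persistent read/write state rather than as constant inputs.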
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
                                                                           outputStateInDimensions.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("outputStateInInfo"),
                                   outputQuantizationParameters,
                                   true));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(CreateBuffer(flatBufferBuilder));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
                                                                           cellStateInDimensions.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("cellStateInInfo"),
                                   outputQuantizationParameters,
                                   true));
    operatorInputs.push_back(buffers.size() - 1);

    if (hasInputLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputLayerNormWeights.data()),
                                                        sizeof(T) * inputLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasForgetLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetLayerNormWeights.data()),
                                                        sizeof(T) * forgetLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("forgetLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellLayerNormWeights.data()),
                                                        sizeof(T) * cellLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasOutputLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputLayerNormWeights.data()),
                                                        sizeof(T) * outputLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("outputLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    int outputBufferId = buffers.size();
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
                                                                           outputShape.size()),
                                   tensorType,
                                   outputBufferId,
                                   flatBufferBuilder.CreateString("output"),
                                   outputQuantizationParameters));
    std::vector<int> operatorOutputs;
    operatorOutputs.push_back(buffers.size() - 1);

    // Create the LSTM operator.
    tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_LSTMOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions =
        CreateLSTMOptions(flatBufferBuilder,
                          activationFunction,
                          clippingThresCell,
                          clippingThresProj).Union();
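
    // In TFLite semantics a cell or projection clipping threshold of 0.0f means
    // clipping is disabled, so tests can pass 0.0f to exercise the unclipped path.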

    flatbuffers::Offset<Operator> lstmOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType, operatorBuiltinOptions);

    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       flatBufferBuilder.CreateVector(&lstmOperator, 1));

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: LSTM Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
                                                                        tflite::BuiltinOperator_LSTM);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

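// A minimal sanity-check sketch (assuming the generated TFLite schema header is
// available, which provides tflite::VerifyModelBuffer) for buffers produced above:
//
//     std::vector<char> modelBuffer = CreateLstmTfLiteModel(/* ... */);
//     flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(modelBuffer.data()),
//                                    modelBuffer.size());
//     CHECK(tflite::VerifyModelBuffer(verifier));
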
template <typename T>
void LstmTestImpl(std::vector<armnn::BackendId>& backends,
                  tflite::TensorType tensorType,
                  int32_t batchSize,
                  int32_t inputSize,
                  int32_t outputSize,
                  int32_t numUnits,
                  bool hasInputToInputWeights,
                  const std::vector<T>& inputToInputWeights,
                  const std::vector<T>& inputToForgetWeights,
                  const std::vector<T>& inputToCellWeights,
                  const std::vector<T>& inputToOutputWeights,
                  bool hasRecurrentToInputWeights,
                  const std::vector<T>& recurrentToInputWeights,
                  const std::vector<T>& recurrentToForgetWeights,
                  const std::vector<T>& recurrentToCellWeights,
                  const std::vector<T>& recurrentToOutputWeights,
                  bool hasCellToInputWeights,
                  const std::vector<T>& cellToInputWeights,
                  bool hasCellToForgetWeights,
                  const std::vector<T>& cellToForgetWeights,
                  bool hasCellToOutputWeights,
                  const std::vector<T>& cellToOutputWeights,
                  bool hasInputGateBias,
                  const std::vector<T>& inputGateBias,
                  const std::vector<T>& forgetGateBias,
                  const std::vector<T>& cellBias,
                  const std::vector<T>& outputGateBias,
                  bool hasProjectionWeights,
                  const std::vector<T>& projectionWeights,
                  bool hasProjectionBias,
                  const std::vector<T>& projectionBias,
                  bool hasInputLayerNormWeights,
                  const std::vector<T>& inputLayerNormWeights,
                  bool hasForgetLayerNormWeights,
                  const std::vector<T>& forgetLayerNormWeights,
                  bool hasCellLayerNormWeights,
                  const std::vector<T>& cellLayerNormWeights,
                  bool hasOutputLayerNormWeights,
                  const std::vector<T>& outputLayerNormWeights,
                  std::vector<T>& inputValues,
                  std::vector<T>& expectedOutputValues,
                  tflite::ActivationFunctionType activationFunction,
                  float clippingThresCell,
                  float clippingThresProj)
{
    using namespace delegateTestInterpreter;

    std::vector<char> modelBuffer = CreateLstmTfLiteModel(tensorType,
                                                          batchSize,
                                                          inputSize,
                                                          outputSize,
                                                          numUnits,
                                                          hasInputToInputWeights,
                                                          inputToInputWeights,
                                                          inputToForgetWeights,
                                                          inputToCellWeights,
                                                          inputToOutputWeights,
                                                          hasRecurrentToInputWeights,
                                                          recurrentToInputWeights,
                                                          recurrentToForgetWeights,
                                                          recurrentToCellWeights,
                                                          recurrentToOutputWeights,
                                                          hasCellToInputWeights,
                                                          cellToInputWeights,
                                                          hasCellToForgetWeights,
                                                          cellToForgetWeights,
                                                          hasCellToOutputWeights,
                                                          cellToOutputWeights,
                                                          hasInputGateBias,
                                                          inputGateBias,
                                                          forgetGateBias,
                                                          cellBias,
                                                          outputGateBias,
                                                          hasProjectionWeights,
                                                          projectionWeights,
                                                          hasProjectionBias,
                                                          projectionBias,
                                                          hasInputLayerNormWeights,
                                                          inputLayerNormWeights,
                                                          hasForgetLayerNormWeights,
                                                          forgetLayerNormWeights,
                                                          hasCellLayerNormWeights,
                                                          cellLayerNormWeights,
                                                          hasOutputLayerNormWeights,
                                                          outputLayerNormWeights,
                                                          activationFunction,
                                                          clippingThresCell,
                                                          clippingThresProj);

    std::vector<int32_t> expectedOutputShape {batchSize, outputSize};

    // Set up an interpreter using only the TFLite runtime to produce reference results.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);

    // Set up an interpreter with the Arm NN delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);

    // Check that both runs agree with each other and with the expected reference output.
    armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, expectedOutputShape);

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
}
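
// Concrete TEST_CASEs (e.g. in the LSTM test sources of this suite) populate the
// weight, bias, input and expected-output vectors and call LstmTestImpl with the
// backends under test.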

} // anonymous namespace