//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <tensorflow/lite/version.h>
#include <tensorflow/lite/c/common.h>

#include <doctest/doctest.h>

namespace
{

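// Serialises a single LSTM operator into a TfLite flatbuffer model. Each optional LSTM
// input (CIFG, peephole, projection and layer-normalisation tensors) is emitted only
// when its has* flag is set; otherwise that operator input slot is marked as optional.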
template <typename T>
std::vector<char> CreateLstmTfLiteModel(tflite::TensorType tensorType,
                                        int32_t batchSize,
                                        int32_t inputSize,
                                        int32_t outputSize,
                                        int32_t numUnits,
                                        bool hasInputToInputWeights,
                                        const std::vector<T>& inputToInputWeights,
                                        const std::vector<T>& inputToForgetWeights,
                                        const std::vector<T>& inputToCellWeights,
                                        const std::vector<T>& inputToOutputWeights,
                                        bool hasRecurrentToInputWeights,
                                        const std::vector<T>& recurrentToInputWeights,
                                        const std::vector<T>& recurrentToForgetWeights,
                                        const std::vector<T>& recurrentToCellWeights,
                                        const std::vector<T>& recurrentToOutputWeights,
                                        bool hasCellToInputWeights,
                                        const std::vector<T>& cellToInputWeights,
                                        bool hasCellToForgetWeights,
                                        const std::vector<T>& cellToForgetWeights,
                                        bool hasCellToOutputWeights,
                                        const std::vector<T>& cellToOutputWeights,
                                        bool hasInputGateBias,
                                        const std::vector<T>& inputGateBias,
                                        const std::vector<T>& forgetGateBias,
                                        const std::vector<T>& cellBias,
                                        const std::vector<T>& outputGateBias,
                                        bool hasProjectionWeights,
                                        const std::vector<T>& projectionWeights,
                                        bool hasProjectionBias,
                                        const std::vector<T>& projectionBias,
                                        bool hasInputLayerNormWeights,
                                        const std::vector<T>& inputLayerNormWeights,
                                        bool hasForgetLayerNormWeights,
                                        const std::vector<T>& forgetLayerNormWeights,
                                        bool hasCellLayerNormWeights,
                                        const std::vector<T>& cellLayerNormWeights,
                                        bool hasOutputLayerNormWeights,
                                        const std::vector<T>& outputLayerNormWeights,
                                        tflite::ActivationFunctionType activationFunction,
                                        float clippingThresCell,
                                        float clippingThresProj,
                                        float quantScale = 1.0f,
                                        int quantOffset = 0,
                                        float outputQuantScale = 2.0f,
                                        int outputQuantOffset = 0)
{

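    // Static shapes for the constant tensors. The hard-coded 2 and 4 presumably match
    // the inputSize and outputSize used by the tests that call this helper
    // (tensorInfo0 is currently unused).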
    std::vector<int32_t> tensorInfo0 {};
    std::vector<int32_t> tensorInfo4 {numUnits};
    std::vector<int32_t> tensorInfo8 {numUnits, static_cast<int32_t>(2)};
    std::vector<int32_t> tensorInfo16 {numUnits, static_cast<int32_t>(4)};

    std::vector<int32_t> inputShape {batchSize, inputSize};
    std::vector<int32_t> outputShape {batchSize, outputSize};

    std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
    std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};

    std::vector<int> operatorInputs;
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;
    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
    std::vector<flatbuffers::Offset<Tensor>> tensors;

    auto quantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    auto outputQuantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));

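    // Tensor 0: the network input. A buffer is pushed immediately before every tensor,
    // so buffers.size() - 1 serves as both the buffer index and the tensor index.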
    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
                                                                           inputShape.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("input_0"),
                                   quantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

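    // Optional input-to-input weights: omitted in CIFG (coupled input and forget gate)
    // mode, in which case kTfLiteOptionalTensor (-1) marks the unused input slot.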
    if (hasInputToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToInputWeights.data()),
                                                        sizeof(T) * inputToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                               tensorInfo8.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputToInputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

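    // Mandatory input-to-forget/cell/output weight matrices (shape tensorInfo8).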
    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToForgetWeights.data()),
                                                    sizeof(T) * inputToForgetWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                           tensorInfo8.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToForgetWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToCellWeights.data()),
                                                    sizeof(T) * inputToCellWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                           tensorInfo8.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToCellWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToOutputWeights.data()),
                                                    sizeof(T) * inputToOutputWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo8.data(),
                                                                           tensorInfo8.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToOutputWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

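    // Recurrent weight matrices (shape tensorInfo16); recurrent-to-input weights are
    // optional for the same CIFG reason as above.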
    if (hasRecurrentToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
                                                        sizeof(T) * recurrentToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                               tensorInfo16.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("recurrentToInputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToForgetWeights.data()),
                                                    sizeof(T) * recurrentToForgetWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                           tensorInfo16.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToForgetWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToCellWeights.data()),
                                                    sizeof(T) * recurrentToCellWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                           tensorInfo16.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToCellWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToOutputWeights.data()),
                                                    sizeof(T) * recurrentToOutputWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo16.data(),
                                                                           tensorInfo16.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToOutputWeights"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

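    // Optional peephole connections: cell-to-gate weight vectors of length numUnits.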
    if (hasCellToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToInputWeights.data()),
                                                        sizeof(T) * cellToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToInputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellToForgetWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToForgetWeights.data()),
                                                        sizeof(T) * cellToForgetWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToForgetWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellToOutputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToOutputWeights.data()),
                                                        sizeof(T) * cellToOutputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToOutputWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

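    // Gate bias vectors of length numUnits; the input gate bias is only present when
    // the network has an input gate (i.e. not in CIFG mode).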
    if (hasInputGateBias)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
                                                        sizeof(T) * inputGateBias.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputGateBias"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetGateBias.data()),
                                                    sizeof(T) * forgetGateBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                           tensorInfo4.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("forgetGateBias"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellBias.data()),
                                                    sizeof(T) * cellBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                           tensorInfo4.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("cellBias"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputGateBias.data()),
                                                    sizeof(T) * outputGateBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                           tensorInfo4.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("outputGateBias"),
                                   outputQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

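    // Optional projection weights and bias applied to the LSTM output.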
    if (hasProjectionWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(projectionWeights.data()),
                                                        sizeof(T) * projectionWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("projectionWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasProjectionBias)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(projectionBias.data()),
                                                        sizeof(T) * projectionBias.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("projectionBias"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

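    // Output-state and cell-state inputs, created as variable tensors (is_variable =
    // true) so their contents persist between invocations.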
    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
                                                                           outputStateInDimensions.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("outputStateInInfo"),
                                   outputQuantizationParameters,
                                   true));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
                                                                           cellStateInDimensions.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("cellStateInInfo"),
                                   outputQuantizationParameters,
                                   true));
    operatorInputs.push_back(buffers.size() - 1);

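    // Optional layer-normalisation coefficients, one vector of length numUnits per gate.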
    if (hasInputLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputLayerNormWeights.data()),
                                                        sizeof(T) * inputLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasForgetLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetLayerNormWeights.data()),
                                                        sizeof(T) * forgetLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("forgetLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellLayerNormWeights.data()),
                                                        sizeof(T) * cellLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasOutputLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputLayerNormWeights.data()),
                                                        sizeof(T) * outputLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfo4.data(),
                                                                               tensorInfo4.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("outputLayerNormWeights"),
                                       outputQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }
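
    // The subgraph's sole output tensor.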
    int outputBufferId = buffers.size();
    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
                                                                           outputShape.size()),
                                   tensorType,
                                   outputBufferId,
                                   flatBufferBuilder.CreateString("output"),
                                   outputQuantizationParameters));
    std::vector<int> operatorOutputs;
    operatorOutputs.push_back(buffers.size() - 1);

    // Create the single LSTM operator, wire it into a subgraph and finish the model.
    tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_LSTMOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions =
        CreateLSTMOptions(flatBufferBuilder,
                          activationFunction,
                          clippingThresCell,
                          clippingThresProj).Union();

    flatbuffers::Offset<Operator> lstmOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType, operatorBuiltinOptions);

    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       flatBufferBuilder.CreateVector(&lstmOperator, 1));

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: LSTM Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
                                                                        tflite::BuiltinOperator_LSTM);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

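// Builds the LSTM model above, runs it on both the reference TfLite interpreter and an
// interpreter using the ArmNN delegate, and checks both against the expected output.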
template <typename T>
void LstmTestImpl(std::vector<armnn::BackendId>& backends,
                  tflite::TensorType tensorType,
                  int32_t batchSize,
                  int32_t inputSize,
                  int32_t outputSize,
                  int32_t numUnits,
                  bool hasInputToInputWeights,
                  const std::vector<T>& inputToInputWeights,
                  const std::vector<T>& inputToForgetWeights,
                  const std::vector<T>& inputToCellWeights,
                  const std::vector<T>& inputToOutputWeights,
                  bool hasRecurrentToInputWeights,
                  const std::vector<T>& recurrentToInputWeights,
                  const std::vector<T>& recurrentToForgetWeights,
                  const std::vector<T>& recurrentToCellWeights,
                  const std::vector<T>& recurrentToOutputWeights,
                  bool hasCellToInputWeights,
                  const std::vector<T>& cellToInputWeights,
                  bool hasCellToForgetWeights,
                  const std::vector<T>& cellToForgetWeights,
                  bool hasCellToOutputWeights,
                  const std::vector<T>& cellToOutputWeights,
                  bool hasInputGateBias,
                  const std::vector<T>& inputGateBias,
                  const std::vector<T>& forgetGateBias,
                  const std::vector<T>& cellBias,
                  const std::vector<T>& outputGateBias,
                  bool hasProjectionWeights,
                  const std::vector<T>& projectionWeights,
                  bool hasProjectionBias,
                  const std::vector<T>& projectionBias,
                  bool hasInputLayerNormWeights,
                  const std::vector<T>& inputLayerNormWeights,
                  bool hasForgetLayerNormWeights,
                  const std::vector<T>& forgetLayerNormWeights,
                  bool hasCellLayerNormWeights,
                  const std::vector<T>& cellLayerNormWeights,
                  bool hasOutputLayerNormWeights,
                  const std::vector<T>& outputLayerNormWeights,
                  std::vector<T>& inputValues,
                  std::vector<T>& expectedOutputValues,
                  tflite::ActivationFunctionType activationFunction,
                  float clippingThresCell,
                  float clippingThresProj)
{
    using namespace tflite;

    std::vector<char> modelBuffer = CreateLstmTfLiteModel(tensorType,
                                                          batchSize,
                                                          inputSize,
                                                          outputSize,
                                                          numUnits,
                                                          hasInputToInputWeights,
                                                          inputToInputWeights,
                                                          inputToForgetWeights,
                                                          inputToCellWeights,
                                                          inputToOutputWeights,
                                                          hasRecurrentToInputWeights,
                                                          recurrentToInputWeights,
                                                          recurrentToForgetWeights,
                                                          recurrentToCellWeights,
                                                          recurrentToOutputWeights,
                                                          hasCellToInputWeights,
                                                          cellToInputWeights,
                                                          hasCellToForgetWeights,
                                                          cellToForgetWeights,
                                                          hasCellToOutputWeights,
                                                          cellToOutputWeights,
                                                          hasInputGateBias,
                                                          inputGateBias,
                                                          forgetGateBias,
                                                          cellBias,
                                                          outputGateBias,
                                                          hasProjectionWeights,
                                                          projectionWeights,
                                                          hasProjectionBias,
                                                          projectionBias,
                                                          hasInputLayerNormWeights,
                                                          inputLayerNormWeights,
                                                          hasForgetLayerNormWeights,
                                                          forgetLayerNormWeights,
                                                          hasCellLayerNormWeights,
                                                          cellLayerNormWeights,
                                                          hasOutputLayerNormWeights,
                                                          outputLayerNormWeights,
                                                          activationFunction,
                                                          clippingThresCell,
                                                          clippingThresProj);

    const Model* tfLiteModel = GetModel(modelBuffer.data());
    // Create two TfLite interpreters: one to run with the ArmNN delegate and one as reference
    std::unique_ptr<Interpreter> armnnDelegateInterpreter;
    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
              (&armnnDelegateInterpreter) == kTfLiteOk);
    CHECK(armnnDelegateInterpreter != nullptr);
    CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);

    std::unique_ptr<Interpreter> tfLiteInterpreter;
    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
              (&tfLiteInterpreter) == kTfLiteOk);
    CHECK(tfLiteInterpreter != nullptr);
    CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);

    // Create the ArmNN Delegate
    armnnDelegate::DelegateOptions delegateOptions(backends);
    std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
        theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
                         armnnDelegate::TfLiteArmnnDelegateDelete);
    CHECK(theArmnnDelegate != nullptr);
    // Modify armnnDelegateInterpreter to use armnnDelegate
    CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);

    // Set input data
    auto tfLiteDelegateInputId = tfLiteInterpreter->inputs()[0];
    auto tfLiteDelegateInputData = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateInputId);
    for (unsigned int i = 0; i < inputValues.size(); ++i)
    {
        tfLiteDelegateInputData[i] = inputValues[i];
    }

    auto armnnDelegateInputId = armnnDelegateInterpreter->inputs()[0];
    auto armnnDelegateInputData = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateInputId);
    for (unsigned int i = 0; i < inputValues.size(); ++i)
    {
        armnnDelegateInputData[i] = inputValues[i];
    }

    // Run inference on both interpreters
    CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
    CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);

    // Compare output data
    auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
    auto tfLiteDelegateOutputData = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateOutputId);
    auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
    auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateOutputId);

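    // Both interpreters must match the expected values and agree with each other.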
    armnnDelegate::CompareData(expectedOutputValues.data(), armnnDelegateOutputData, expectedOutputValues.size());
    armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteDelegateOutputData, expectedOutputValues.size());
    armnnDelegate::CompareData(tfLiteDelegateOutputData, armnnDelegateOutputData, expectedOutputValues.size());
}

} // anonymous namespace