//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <tensorflow/lite/version.h>

#include <doctest/doctest.h>

#include <armnn/Types.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <initializer_list>
#include <iterator>
#include <vector>

namespace
{

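// Builds a FlatBuffer model containing a single UNIDIRECTIONAL_SEQUENCE_LSTM operator.
// Each has* flag controls an optional input (CIFG, peephole, projection or layer
// normalisation tensors); when a flag is false the corresponding operator input is
// recorded as kTfLiteOptionalTensor. Weight tensors use the templated type T with
// quantScale/quantOffset, while bias and state tensors are always FLOAT32.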
template <typename T>
std::vector<char> CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType tensorType,
                                                              int32_t batchSize,
                                                              int32_t timeSize,
                                                              int32_t inputSize,
                                                              int32_t outputSize,
                                                              int32_t numUnits,
                                                              bool hasInputToInputWeights,
                                                              const std::vector<T>& inputToInputWeights,
                                                              const std::vector<T>& inputToForgetWeights,
                                                              const std::vector<T>& inputToCellWeights,
                                                              const std::vector<T>& inputToOutputWeights,
                                                              bool hasRecurrentToInputWeights,
                                                              const std::vector<T>& recurrentToInputWeights,
                                                              const std::vector<T>& recurrentToForgetWeights,
                                                              const std::vector<T>& recurrentToCellWeights,
                                                              const std::vector<T>& recurrentToOutputWeights,
                                                              bool hasCellToInputWeights,
                                                              const std::vector<T>& cellToInputWeights,
                                                              bool hasCellToForgetWeights,
                                                              const std::vector<T>& cellToForgetWeights,
                                                              bool hasCellToOutputWeights,
                                                              const std::vector<T>& cellToOutputWeights,
                                                              bool hasInputGateBias,
                                                              const std::vector<float>& inputGateBias,
                                                              const std::vector<float>& forgetGateBias,
                                                              const std::vector<float>& cellBias,
                                                              const std::vector<float>& outputGateBias,
                                                              bool hasProjectionWeights,
                                                              const std::vector<T>& projectionWeights,
                                                              bool hasProjectionBias,
                                                              const std::vector<float>& projectionBias,
                                                              bool hasInputLayerNormWeights,
                                                              const std::vector<float>& inputLayerNormWeights,
                                                              bool hasForgetLayerNormWeights,
                                                              const std::vector<float>& forgetLayerNormWeights,
                                                              bool hasCellLayerNormWeights,
                                                              const std::vector<float>& cellLayerNormWeights,
                                                              bool hasOutputLayerNormWeights,
                                                              const std::vector<float>& outputLayerNormWeights,
                                                              tflite::ActivationFunctionType activationFunction,
                                                              float clippingThresCell,
                                                              float clippingThresProj,
                                                              bool isTimeMajor,
                                                              float quantScale,
                                                              int quantOffset = 0)
{

    std::vector<int32_t> tensorInfo0{};
    std::vector<int32_t> tensorInfoNumUnits{numUnits};
    std::vector<int32_t> tensorInfoInputSize{numUnits, inputSize};
    std::vector<int32_t> tensorInfoOutputSize{numUnits, outputSize};

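    // Sequence input/output layout depends on the time_major option:
    // time-major tensors are laid out as [timeSize, batchSize, ...],
    // batch-major tensors as [batchSize, timeSize, ...].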
    std::vector<int32_t> inputShape;
    std::vector<int32_t> outputShape;
    if (isTimeMajor)
    {
        inputShape = {timeSize, batchSize, inputSize};
        outputShape = {timeSize, batchSize, outputSize};
    }
    else
    {
        inputShape = {batchSize, timeSize, inputSize};
        outputShape = {batchSize, timeSize, outputSize};
    }
    std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
    std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};
    std::vector<int32_t> projectionWeightDimensions{outputSize, numUnits};
    std::vector<int32_t> projectionBiasDimensions{outputSize};

    std::vector<int> operatorInputs;
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;
    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
    std::vector<flatbuffers::Offset<Tensor>> tensors;

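    // Identity quantization parameters (scale 1.0, offset 0) for the float
    // activation/state tensors, and caller-supplied parameters for the weights.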
    auto quantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ 1.0f }),
                                     flatBufferBuilder.CreateVector<int64_t>({ 0 }));

    auto weightQuantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
                                                                           inputShape.size()),
                                   ::tflite::TensorType_FLOAT32,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("input_0")));
    operatorInputs.push_back(buffers.size() - 1);

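    // UNIDIRECTIONAL_SEQUENCE_LSTM expects a fixed positional list of 24 inputs.
    // Omitted optional inputs are recorded as kTfLiteOptionalTensor (-1) so that the
    // indices of the remaining inputs stay correct. Every tensor gets its own buffer,
    // so buffers.size() - 1 doubles as the index of the most recently added tensor.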
    if (hasInputToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToInputWeights.data()),
                                                        sizeof(T) * inputToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
                                                                               tensorInfoInputSize.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputToInputWeights"),
                                       weightQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToForgetWeights.data()),
                                                    sizeof(T) * inputToForgetWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
                                                                           tensorInfoInputSize.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToForgetWeights"),
                                   weightQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToCellWeights.data()),
                                                    sizeof(T) * inputToCellWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
                                                                           tensorInfoInputSize.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToCellWeights"),
                                   weightQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputToOutputWeights.data()),
                                                    sizeof(T) * inputToOutputWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
                                                                           tensorInfoInputSize.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("inputToOutputWeights"),
                                   weightQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    if (hasRecurrentToInputWeights)
    {
        buffers.push_back(CreateBuffer(
            flatBufferBuilder,
            flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
                                           sizeof(T) * recurrentToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
                                                                               tensorInfoOutputSize.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("recurrentToInputWeights"),
                                       weightQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToForgetWeights.data()),
                                                    sizeof(T) * recurrentToForgetWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
                                                                           tensorInfoOutputSize.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToForgetWeights"),
                                   weightQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToCellWeights.data()),
                                                    sizeof(T) * recurrentToCellWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
                                                                           tensorInfoOutputSize.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToCellWeights"),
                                   weightQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToOutputWeights.data()),
                                                    sizeof(T) * recurrentToOutputWeights.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
                                                                           tensorInfoOutputSize.size()),
                                   tensorType,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("recurrentToOutputWeights"),
                                   weightQuantizationParameters));
    operatorInputs.push_back(buffers.size() - 1);

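    // Peephole connections: optional cell-to-gate weight vectors of size numUnits.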
    if (hasCellToInputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToInputWeights.data()),
                                                        sizeof(T) * cellToInputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                               tensorInfoNumUnits.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToInputWeights"),
                                       weightQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellToForgetWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToForgetWeights.data()),
                                                        sizeof(T) * cellToForgetWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                               tensorInfoNumUnits.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToForgetWeights"),
                                       weightQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellToOutputWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellToOutputWeights.data()),
                                                        sizeof(T) * cellToOutputWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                               tensorInfoNumUnits.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellToOutputWeights"),
                                       weightQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

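    // Gate bias tensors are always FLOAT32, even when the weights are quantized.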
    if (hasInputGateBias)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
                                                        sizeof(float) * inputGateBias.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                               tensorInfoNumUnits.size()),
                                       ::tflite::TensorType_FLOAT32,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputGateBias")));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetGateBias.data()),
                                                    sizeof(float) * forgetGateBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                           tensorInfoNumUnits.size()),
                                   ::tflite::TensorType_FLOAT32,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("forgetGateBias")));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellBias.data()),
                                                    sizeof(float) * cellBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                           tensorInfoNumUnits.size()),
                                   ::tflite::TensorType_FLOAT32,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("cellBias")));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(
        CreateBuffer(flatBufferBuilder,
                     flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputGateBias.data()),
                                                    sizeof(float) * outputGateBias.size())));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                           tensorInfoNumUnits.size()),
                                   ::tflite::TensorType_FLOAT32,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("outputGateBias")));
    operatorInputs.push_back(buffers.size() - 1);

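    // The optional projection layer maps the cell output from numUnits down to outputSize.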
    if (hasProjectionWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(projectionWeights.data()),
                                                        sizeof(T) * projectionWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(projectionWeightDimensions.data(),
                                                                               projectionWeightDimensions.size()),
                                       tensorType,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("projectionWeights"),
                                       weightQuantizationParameters));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasProjectionBias)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(projectionBias.data()),
                                                        sizeof(float) * projectionBias.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(projectionBiasDimensions.data(),
                                                                               projectionBiasDimensions.size()),
                                       ::tflite::TensorType_FLOAT32,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("projectionBias")));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

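    // outputStateIn and cellStateIn are variable tensors (is_variable = true): they
    // hold the recurrent hidden and cell state and are zeroed by the runtime when
    // tensors are allocated, rather than being read from a constant buffer.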
    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
                                                                           outputStateInDimensions.size()),
                                   ::tflite::TensorType_FLOAT32,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("outputStateInInfo"),
                                   quantizationParameters,
                                   true));
    operatorInputs.push_back(buffers.size() - 1);

    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
                                                                           cellStateInDimensions.size()),
                                   ::tflite::TensorType_FLOAT32,
                                   buffers.size() - 1,
                                   flatBufferBuilder.CreateString("cellStateInInfo"),
                                   quantizationParameters,
                                   true));
    operatorInputs.push_back(buffers.size() - 1);

    if (hasInputLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(
                             reinterpret_cast<const uint8_t*>(inputLayerNormWeights.data()),
                             sizeof(float) * inputLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                               tensorInfoNumUnits.size()),
                                       ::tflite::TensorType_FLOAT32,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("inputLayerNormWeights")));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasForgetLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(
                             reinterpret_cast<const uint8_t*>(forgetLayerNormWeights.data()),
                             sizeof(float) * forgetLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                               tensorInfoNumUnits.size()),
                                       ::tflite::TensorType_FLOAT32,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("forgetLayerNormWeights")));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasCellLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(
                             reinterpret_cast<const uint8_t*>(cellLayerNormWeights.data()),
                             sizeof(float) * cellLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                               tensorInfoNumUnits.size()),
                                       ::tflite::TensorType_FLOAT32,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("cellLayerNormWeights")));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }

    if (hasOutputLayerNormWeights)
    {
        buffers.push_back(
            CreateBuffer(flatBufferBuilder,
                         flatBufferBuilder.CreateVector(
                             reinterpret_cast<const uint8_t*>(outputLayerNormWeights.data()),
                             sizeof(float) * outputLayerNormWeights.size())));
        tensors.push_back(CreateTensor(flatBufferBuilder,
                                       flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
                                                                               tensorInfoNumUnits.size()),
                                       ::tflite::TensorType_FLOAT32,
                                       buffers.size() - 1,
                                       flatBufferBuilder.CreateString("outputLayerNormWeights")));
        operatorInputs.push_back(buffers.size() - 1);
    }
    else
    {
        operatorInputs.push_back(kTfLiteOptionalTensor);
    }
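
    // Single output tensor holding the hidden-state output for every timestep.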
    int outputBufferId = buffers.size();
    buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
    tensors.push_back(CreateTensor(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
                                                                           outputShape.size()),
                                   ::tflite::TensorType_FLOAT32,
                                   outputBufferId,
                                   flatBufferBuilder.CreateString("output")));
    std::vector<int> operatorOutputs;
    operatorOutputs.push_back(buffers.size() - 1);

    // Create the single UNIDIRECTIONAL_SEQUENCE_LSTM operator
    tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions =
        CreateUnidirectionalSequenceLSTMOptions(flatBufferBuilder,
                                                activationFunction,
                                                clippingThresCell,
                                                clippingThresProj,
                                                isTimeMajor).Union();

    flatbuffers::Offset<Operator> lstmOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType, operatorBuiltinOptions);

    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       flatBufferBuilder.CreateVector(&lstmOperator, 1));

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: UnidirectionalSequenceLSTM Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode =
        CreateOperatorCode(flatBufferBuilder, tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

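// Builds the single-operator model above and runs it on two interpreters, one with the
// ArmNN delegate attached and one on the reference TfLite runtime, then checks both
// results against expectedOutputValues (with a wider tolerance for quantized weights).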
template <typename T>
void UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId>& backends,
                                        tflite::TensorType tensorType,
                                        int32_t batchSize,
                                        int32_t timeSize,
                                        int32_t inputSize,
                                        int32_t outputSize,
                                        int32_t numUnits,
                                        bool hasInputToInputWeights,
                                        const std::vector<T>& inputToInputWeights,
                                        const std::vector<T>& inputToForgetWeights,
                                        const std::vector<T>& inputToCellWeights,
                                        const std::vector<T>& inputToOutputWeights,
                                        bool hasRecurrentToInputWeights,
                                        const std::vector<T>& recurrentToInputWeights,
                                        const std::vector<T>& recurrentToForgetWeights,
                                        const std::vector<T>& recurrentToCellWeights,
                                        const std::vector<T>& recurrentToOutputWeights,
                                        bool hasCellToInputWeights,
                                        const std::vector<T>& cellToInputWeights,
                                        bool hasCellToForgetWeights,
                                        const std::vector<T>& cellToForgetWeights,
                                        bool hasCellToOutputWeights,
                                        const std::vector<T>& cellToOutputWeights,
                                        bool hasInputGateBias,
                                        const std::vector<float>& inputGateBias,
                                        const std::vector<float>& forgetGateBias,
                                        const std::vector<float>& cellBias,
                                        const std::vector<float>& outputGateBias,
                                        bool hasProjectionWeights,
                                        const std::vector<T>& projectionWeights,
                                        bool hasProjectionBias,
                                        const std::vector<float>& projectionBias,
                                        bool hasInputLayerNormWeights,
                                        const std::vector<float>& inputLayerNormWeights,
                                        bool hasForgetLayerNormWeights,
                                        const std::vector<float>& forgetLayerNormWeights,
                                        bool hasCellLayerNormWeights,
                                        const std::vector<float>& cellLayerNormWeights,
                                        bool hasOutputLayerNormWeights,
                                        const std::vector<float>& outputLayerNormWeights,
                                        std::vector<float>& inputValues,
                                        std::vector<float>& expectedOutputValues,
                                        tflite::ActivationFunctionType activationFunction,
                                        float clippingThresCell,
                                        float clippingThresProj,
                                        bool isTimeMajor,
                                        float quantScale = 0.1f)
{
    using namespace tflite;

    std::vector<char> modelBuffer = CreateUnidirectionalSequenceLstmTfLiteModel(tensorType,
                                                                                batchSize,
                                                                                timeSize,
                                                                                inputSize,
                                                                                outputSize,
                                                                                numUnits,
                                                                                hasInputToInputWeights,
                                                                                inputToInputWeights,
                                                                                inputToForgetWeights,
                                                                                inputToCellWeights,
                                                                                inputToOutputWeights,
                                                                                hasRecurrentToInputWeights,
                                                                                recurrentToInputWeights,
                                                                                recurrentToForgetWeights,
                                                                                recurrentToCellWeights,
                                                                                recurrentToOutputWeights,
                                                                                hasCellToInputWeights,
                                                                                cellToInputWeights,
                                                                                hasCellToForgetWeights,
                                                                                cellToForgetWeights,
                                                                                hasCellToOutputWeights,
                                                                                cellToOutputWeights,
                                                                                hasInputGateBias,
                                                                                inputGateBias,
                                                                                forgetGateBias,
                                                                                cellBias,
                                                                                outputGateBias,
                                                                                hasProjectionWeights,
                                                                                projectionWeights,
                                                                                hasProjectionBias,
                                                                                projectionBias,
                                                                                hasInputLayerNormWeights,
                                                                                inputLayerNormWeights,
                                                                                hasForgetLayerNormWeights,
                                                                                forgetLayerNormWeights,
                                                                                hasCellLayerNormWeights,
                                                                                cellLayerNormWeights,
                                                                                hasOutputLayerNormWeights,
                                                                                outputLayerNormWeights,
                                                                                activationFunction,
                                                                                clippingThresCell,
                                                                                clippingThresProj,
                                                                                isTimeMajor,
                                                                                quantScale);

    const Model* tfLiteModel = GetModel(modelBuffer.data());

    // Build two interpreters from the same model: one that will be modified to run
    // through the ArmNN delegate, and one that stays on the reference TfLite runtime.
    std::unique_ptr<Interpreter> armnnDelegateInterpreter;
    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
              (&armnnDelegateInterpreter) == kTfLiteOk);
    CHECK(armnnDelegateInterpreter != nullptr);
    CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);

    std::unique_ptr<Interpreter> tfLiteInterpreter;
    CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
              (&tfLiteInterpreter) == kTfLiteOk);
    CHECK(tfLiteInterpreter != nullptr);
    CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);

    // Create the ArmNN Delegate
    armnnDelegate::DelegateOptions delegateOptions(backends);
    std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
        theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
                         armnnDelegate::TfLiteArmnnDelegateDelete);
    CHECK(theArmnnDelegate != nullptr);
    // Modify armnnDelegateInterpreter to use armnnDelegate
    CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);

    // Set the same input data on both interpreters
    auto tfLiteDelegateInputId = tfLiteInterpreter->inputs()[0];
    auto tfLiteDelegateInputData = tfLiteInterpreter->typed_tensor<float>(tfLiteDelegateInputId);
    for (unsigned int i = 0; i < inputValues.size(); ++i)
    {
        tfLiteDelegateInputData[i] = inputValues[i];
    }

    auto armnnDelegateInputId = armnnDelegateInterpreter->inputs()[0];
    auto armnnDelegateInputData = armnnDelegateInterpreter->typed_tensor<float>(armnnDelegateInputId);
    for (unsigned int i = 0; i < inputValues.size(); ++i)
    {
        armnnDelegateInputData[i] = inputValues[i];
    }

    // Run inference on both interpreters
    CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
    CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);

    // Compare output data
    auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
    auto tfLiteDelegateOutputData = tfLiteInterpreter->typed_tensor<float>(tfLiteDelegateOutputId);
    auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
    auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<float>(armnnDelegateOutputId);

    if (tensorType == ::tflite::TensorType_INT8)
    {
        // Allow a 2% tolerance for quantized weights
        armnnDelegate::CompareData(expectedOutputValues.data(), armnnDelegateOutputData,
                                   expectedOutputValues.size(), 2);
        armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteDelegateOutputData,
                                   expectedOutputValues.size(), 2);
        armnnDelegate::CompareData(tfLiteDelegateOutputData, armnnDelegateOutputData,
                                   expectedOutputValues.size(), 2);
    }
    else
    {
        armnnDelegate::CompareData(expectedOutputValues.data(), armnnDelegateOutputData, expectedOutputValues.size());
        armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteDelegateOutputData, expectedOutputValues.size());
        armnnDelegate::CompareData(tfLiteDelegateOutputData, armnnDelegateOutputData, expectedOutputValues.size());
    }
}

} // anonymous namespace