blob: c058d83bc6630983e2d3d3ff1863bb2b617c47a6 [file] [log] [blame]
//
// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010011#include <DelegateTestInterpreter.hpp>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010012
13#include <flatbuffers/flatbuffers.h>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010014#include <tensorflow/lite/kernels/register.h>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010015#include <tensorflow/lite/version.h>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010016
17#include <schema_generated.h>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010018
19#include <doctest/doctest.h>
20
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010021#include <armnn/utility/IgnoreUnused.hpp>
22#include <armnn/utility/NumericCast.hpp>
23#include <armnn/TypesUtils.hpp>
24
25#include <armnn/Types.hpp>
26
27#include <initializer_list>
28#include <iterator>
29#include <vector>
30
31namespace
32{
33
Ryan OShea238ecd92023-03-07 11:44:23 +000034template<typename T>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010035std::vector<char> CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType tensorType,
36 int32_t batchSize,
37 int32_t timeSize,
38 int32_t inputSize,
39 int32_t outputSize,
40 int32_t numUnits,
41 bool hasInputToInputWeights,
42 const std::vector<T>& inputToInputWeights,
43 const std::vector<T>& inputToForgetWeights,
44 const std::vector<T>& inputToCellWeights,
45 const std::vector<T>& inputToOutputWeights,
46 bool hasRecurrentToInputWeights,
47 const std::vector<T>& recurrentToInputWeights,
48 const std::vector<T>& recurrentToForgetWeights,
49 const std::vector<T>& recurrentToCellWeights,
50 const std::vector<T>& recurrentToOutputWeights,
51 bool hasCellToInputWeights,
52 const std::vector<T>& cellToInputWeights,
53 bool hasCellToForgetWeights,
54 const std::vector<T>& cellToForgetWeights,
55 bool hasCellToOutputWeights,
56 const std::vector<T>& cellToOutputWeights,
57 bool hasInputGateBias,
58 const std::vector<float>& inputGateBias,
59 const std::vector<float>& forgetGateBias,
60 const std::vector<float>& cellBias,
61 const std::vector<float>& outputGateBias,
62 bool hasProjectionWeights,
63 const std::vector<T>& projectionWeights,
64 bool hasProjectionBias,
65 const std::vector<float>& projectionBias,
66 bool hasInputLayerNormWeights,
67 const std::vector<float>& inputLayerNormWeights,
68 bool hasForgetLayerNormWeights,
69 const std::vector<float>& forgetLayerNormWeights,
70 bool hasCellLayerNormWeights,
71 const std::vector<float>& cellLayerNormWeights,
72 bool hasOutputLayerNormWeights,
73 const std::vector<float>& outputLayerNormWeights,
74 tflite::ActivationFunctionType activationFunction,
75 float clippingThresCell,
76 float clippingThresProj,
77 bool isTimeMajor,
78 float quantScale,
Ryan OShea238ecd92023-03-07 11:44:23 +000079 int quantOffset = 0)
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010080{
81
82 std::vector<int32_t> tensorInfo0{};
83 std::vector<int32_t> tensorInfoNumUnits{numUnits};
84 std::vector<int32_t> tensorInfoInputSize{numUnits, inputSize};
85 std::vector<int32_t> tensorInfoOutputSize{numUnits, outputSize};
86
87 std::vector<int32_t> inputShape;
88 std::vector<int32_t> outputShape;
89 if (isTimeMajor)
90 {
91 inputShape = {timeSize, batchSize, inputSize};
92 outputShape = {timeSize, batchSize, outputSize};
93 }
94 else
95 {
96 inputShape = {batchSize, timeSize, inputSize};
97 outputShape = {batchSize, timeSize, outputSize};
98 }
99 std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
100 std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};
101 std::vector<int32_t> projectionWeightDimensions{outputSize, numUnits};
102 std::vector<int32_t> projectionBiasDimensions{outputSize};
103
104 std::vector<int> operatorInputs;
105 using namespace tflite;
Ryan OShea238ecd92023-03-07 11:44:23 +0000106 flatbuffers::FlatBufferBuilder flatBufferBuilder;
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100107 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
Ryan OShea238ecd92023-03-07 11:44:23 +0000108 std::vector<flatbuffers::Offset<Tensor>> tensors;
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100109
110 auto quantizationParameters =
Ryan OShea238ecd92023-03-07 11:44:23 +0000111 CreateQuantizationParameters(flatBufferBuilder,
112 0,
113 0,
114 flatBufferBuilder.CreateVector<float>({1.0f}),
115 flatBufferBuilder.CreateVector<int64_t>({0}));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100116
117 auto weightQuantizationParameters =
Ryan OShea238ecd92023-03-07 11:44:23 +0000118 CreateQuantizationParameters(flatBufferBuilder,
119 0,
120 0,
121 flatBufferBuilder.CreateVector<float>({quantScale}),
122 flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100123
Ryan OShea238ecd92023-03-07 11:44:23 +0000124 buffers.push_back(CreateBuffer(flatBufferBuilder));
125 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100126 tensors.push_back(CreateTensor(flatBufferBuilder,
127 flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
128 inputShape.size()),
129 ::tflite::TensorType_FLOAT32,
130 buffers.size() - 1,
131 flatBufferBuilder.CreateString("input_0")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000132 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100133
134 if (hasInputToInputWeights)
135 {
136 buffers.push_back(
137 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000138 flatBufferBuilder.CreateVector(
139 reinterpret_cast<const uint8_t*>(inputToInputWeights.data()),
140 sizeof(T) * inputToInputWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100141 tensors.push_back(CreateTensor(flatBufferBuilder,
142 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
143 tensorInfoInputSize.size()),
144 tensorType,
145 buffers.size() - 1,
146 flatBufferBuilder.CreateString("inputToInputWeights"),
147 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000148 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100149 }
150 else
151 {
152 operatorInputs.push_back(kTfLiteOptionalTensor);
153 }
154
155 buffers.push_back(
156 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000157 flatBufferBuilder.CreateVector(
158 reinterpret_cast<const uint8_t*>(inputToForgetWeights.data()),
159 sizeof(T) * inputToForgetWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100160 tensors.push_back(CreateTensor(flatBufferBuilder,
161 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
162 tensorInfoInputSize.size()),
163 tensorType,
164 buffers.size() - 1,
165 flatBufferBuilder.CreateString("inputToForgetWeights"),
166 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000167 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100168
169 buffers.push_back(
170 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000171 flatBufferBuilder.CreateVector(
172 reinterpret_cast<const uint8_t*>(inputToCellWeights.data()),
173 sizeof(T) * inputToCellWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100174 tensors.push_back(CreateTensor(flatBufferBuilder,
175 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
176 tensorInfoInputSize.size()),
177 tensorType,
178 buffers.size() - 1,
179 flatBufferBuilder.CreateString("inputToCellWeights"),
180 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000181 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100182
183 buffers.push_back(
184 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000185 flatBufferBuilder.CreateVector(
186 reinterpret_cast<const uint8_t*>(inputToOutputWeights.data()),
187 sizeof(T) * inputToOutputWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100188 tensors.push_back(CreateTensor(flatBufferBuilder,
189 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
190 tensorInfoInputSize.size()),
191 tensorType,
192 buffers.size() - 1,
193 flatBufferBuilder.CreateString("inputToOutputWeights"),
194 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000195 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100196
197 if (hasRecurrentToInputWeights)
198 {
199 buffers.push_back(CreateBuffer(
200 flatBufferBuilder,
201 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
202 sizeof(T) * recurrentToInputWeights.size())));
203 tensors.push_back(CreateTensor(flatBufferBuilder,
204 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
205 tensorInfoOutputSize.size()),
206 tensorType,
207 buffers.size() - 1,
208 flatBufferBuilder.CreateString("recurrentToInputWeights"),
209 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000210 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100211 }
212 else
213 {
214 operatorInputs.push_back(kTfLiteOptionalTensor);
215 }
216
217 buffers.push_back(
218 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000219 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
220 recurrentToForgetWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100221 sizeof(T) * recurrentToForgetWeights.size())));
222 tensors.push_back(CreateTensor(flatBufferBuilder,
223 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
224 tensorInfoOutputSize.size()),
225 tensorType,
226 buffers.size() - 1,
227 flatBufferBuilder.CreateString("recurrentToForgetWeights"),
228 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000229 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100230
231 buffers.push_back(
232 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000233 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
234 recurrentToCellWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100235 sizeof(T) * recurrentToCellWeights.size())));
236 tensors.push_back(CreateTensor(flatBufferBuilder,
237 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
238 tensorInfoOutputSize.size()),
239 tensorType,
240 buffers.size() - 1,
241 flatBufferBuilder.CreateString("recurrentToCellWeights"),
242 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000243 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100244
245 buffers.push_back(
246 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000247 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
248 recurrentToOutputWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100249 sizeof(T) * recurrentToOutputWeights.size())));
250 tensors.push_back(CreateTensor(flatBufferBuilder,
251 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
252 tensorInfoOutputSize.size()),
253 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000254 buffers.size() - 1,
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100255 flatBufferBuilder.CreateString("recurrentToOutputWeights"),
256 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000257 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100258
259 if (hasCellToInputWeights)
260 {
261 buffers.push_back(
262 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000263 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
264 cellToInputWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100265 sizeof(T) * cellToInputWeights.size())));
266 tensors.push_back(CreateTensor(flatBufferBuilder,
267 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
268 tensorInfoNumUnits.size()),
269 tensorType,
270 buffers.size() - 1,
271 flatBufferBuilder.CreateString("cellToInputWeights"),
272 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000273 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100274 }
275 else
276 {
277 operatorInputs.push_back(kTfLiteOptionalTensor);
278 }
279
280 if (hasCellToForgetWeights)
281 {
282 buffers.push_back(
283 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000284 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
285 cellToForgetWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100286 sizeof(T) * cellToForgetWeights.size())));
287 tensors.push_back(CreateTensor(flatBufferBuilder,
288 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
289 tensorInfoNumUnits.size()),
290 tensorType,
291 buffers.size() - 1,
292 flatBufferBuilder.CreateString("cellToForgetWeights"),
293 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000294 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100295 }
296 else
297 {
298 operatorInputs.push_back(kTfLiteOptionalTensor);
299 }
300
301 if (hasCellToOutputWeights)
302 {
303 buffers.push_back(
304 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000305 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
306 cellToOutputWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100307 sizeof(T) * cellToOutputWeights.size())));
308 tensors.push_back(CreateTensor(flatBufferBuilder,
309 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
310 tensorInfoNumUnits.size()),
311 tensorType,
312 buffers.size() - 1,
313 flatBufferBuilder.CreateString("cellToOutputWeights"),
314 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000315 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100316 }
317 else
318 {
319 operatorInputs.push_back(kTfLiteOptionalTensor);
320 }
321
322 if (hasInputGateBias)
323 {
324 buffers.push_back(
325 CreateBuffer(flatBufferBuilder,
326 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
327 sizeof(float) * inputGateBias.size())));
328 tensors.push_back(CreateTensor(flatBufferBuilder,
329 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
330 tensorInfoNumUnits.size()),
331 ::tflite::TensorType_FLOAT32,
332 buffers.size() - 1,
333 flatBufferBuilder.CreateString("inputGateBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000334 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100335 }
336 else
337 {
338 operatorInputs.push_back(kTfLiteOptionalTensor);
339 }
340
341 buffers.push_back(
342 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000343 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetGateBias.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100344 sizeof(float) * forgetGateBias.size())));
345 tensors.push_back(CreateTensor(flatBufferBuilder,
346 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
347 tensorInfoNumUnits.size()),
348 ::tflite::TensorType_FLOAT32,
349 buffers.size() - 1,
350 flatBufferBuilder.CreateString("forgetGateBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000351 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100352
353 buffers.push_back(
354 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000355 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellBias.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100356 sizeof(float) * cellBias.size())));
357 tensors.push_back(CreateTensor(flatBufferBuilder,
358 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
359 tensorInfoNumUnits.size()),
360 ::tflite::TensorType_FLOAT32,
361 buffers.size() - 1,
362 flatBufferBuilder.CreateString("cellBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000363 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100364
365 buffers.push_back(
366 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000367 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputGateBias.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100368 sizeof(float) * outputGateBias.size())));
369 tensors.push_back(CreateTensor(flatBufferBuilder,
370 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
371 tensorInfoNumUnits.size()),
372 ::tflite::TensorType_FLOAT32,
373 buffers.size() - 1,
374 flatBufferBuilder.CreateString("outputGateBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000375 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100376
377 if (hasProjectionWeights)
378 {
379 buffers.push_back(
380 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000381 flatBufferBuilder.CreateVector(
382 reinterpret_cast<const uint8_t*>(projectionWeights.data()),
383 sizeof(T) * projectionWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100384 tensors.push_back(CreateTensor(flatBufferBuilder,
385 flatBufferBuilder.CreateVector<int32_t>(projectionWeightDimensions.data(),
386 projectionWeightDimensions.size()),
387 tensorType,
388 buffers.size() - 1,
389 flatBufferBuilder.CreateString("projectionWeights"),
390 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000391 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100392 }
393 else
394 {
395 operatorInputs.push_back(kTfLiteOptionalTensor);
396 }
397
398 if (hasProjectionBias)
399 {
400 buffers.push_back(
401 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000402 flatBufferBuilder.CreateVector(
403 reinterpret_cast<const uint8_t*>(projectionBias.data()),
404 sizeof(float) * projectionBias.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100405 tensors.push_back(CreateTensor(flatBufferBuilder,
406 flatBufferBuilder.CreateVector<int32_t>(projectionBiasDimensions.data(),
407 projectionBiasDimensions.size()),
408 ::tflite::TensorType_FLOAT32,
409 buffers.size() - 1,
410 flatBufferBuilder.CreateString("projectionBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000411 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100412 }
413 else
414 {
415 operatorInputs.push_back(kTfLiteOptionalTensor);
416 }
417
Ryan OShea238ecd92023-03-07 11:44:23 +0000418 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100419 tensors.push_back(CreateTensor(flatBufferBuilder,
420 flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
421 outputStateInDimensions.size()),
422 ::tflite::TensorType_FLOAT32,
423 buffers.size() - 1,
424 flatBufferBuilder.CreateString("outputStateInInfo"),
425 quantizationParameters,
426 true));
Ryan OShea238ecd92023-03-07 11:44:23 +0000427 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100428
Ryan OShea238ecd92023-03-07 11:44:23 +0000429 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100430 tensors.push_back(CreateTensor(flatBufferBuilder,
431 flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
432 cellStateInDimensions.size()),
433 ::tflite::TensorType_FLOAT32,
434 buffers.size() - 1,
435 flatBufferBuilder.CreateString("cellStateInInfo"),
436 quantizationParameters,
437 true));
Ryan OShea238ecd92023-03-07 11:44:23 +0000438 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100439
440 if (hasInputLayerNormWeights)
441 {
442 buffers.push_back(
443 CreateBuffer(flatBufferBuilder,
444 flatBufferBuilder.CreateVector(
Ryan OShea238ecd92023-03-07 11:44:23 +0000445 reinterpret_cast<const uint8_t*>(inputLayerNormWeights.data()),
446 sizeof(float) * inputLayerNormWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100447 tensors.push_back(CreateTensor(flatBufferBuilder,
448 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
449 tensorInfoNumUnits.size()),
450 ::tflite::TensorType_FLOAT32,
451 buffers.size() - 1,
452 flatBufferBuilder.CreateString("inputLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000453 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100454 }
455 else
456 {
457 operatorInputs.push_back(kTfLiteOptionalTensor);
458 }
459
460 if (hasForgetLayerNormWeights)
461 {
462 buffers.push_back(
463 CreateBuffer(flatBufferBuilder,
464 flatBufferBuilder.CreateVector(
Ryan OShea238ecd92023-03-07 11:44:23 +0000465 reinterpret_cast<const uint8_t*>(forgetLayerNormWeights.data()),
466 sizeof(float) * forgetLayerNormWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100467 tensors.push_back(CreateTensor(flatBufferBuilder,
468 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
469 tensorInfoNumUnits.size()),
470 ::tflite::TensorType_FLOAT32,
471 buffers.size() - 1,
472 flatBufferBuilder.CreateString("forgetLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000473 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100474 }
475 else
476 {
477 operatorInputs.push_back(kTfLiteOptionalTensor);
478 }
479
480 if (hasCellLayerNormWeights)
481 {
482 buffers.push_back(
483 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000484 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
485 cellLayerNormWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100486 sizeof(float) * cellLayerNormWeights.size())));
487 tensors.push_back(CreateTensor(flatBufferBuilder,
488 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
489 tensorInfoNumUnits.size()),
490 ::tflite::TensorType_FLOAT32,
491 buffers.size() - 1,
492 flatBufferBuilder.CreateString("cellLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000493 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100494 }
495 else
496 {
497 operatorInputs.push_back(kTfLiteOptionalTensor);
498 }
499
500 if (hasOutputLayerNormWeights)
501 {
502 buffers.push_back(
503 CreateBuffer(flatBufferBuilder,
504 flatBufferBuilder.CreateVector(
Ryan OShea238ecd92023-03-07 11:44:23 +0000505 reinterpret_cast<const uint8_t*>(outputLayerNormWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100506 sizeof(float) * outputLayerNormWeights.size())));
507 tensors.push_back(CreateTensor(flatBufferBuilder,
508 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
509 tensorInfoNumUnits.size()),
510 ::tflite::TensorType_FLOAT32,
511 buffers.size() - 1,
512 flatBufferBuilder.CreateString("outputLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000513 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100514 }
515 else
516 {
517 operatorInputs.push_back(kTfLiteOptionalTensor);
518 }
Ryan OShea238ecd92023-03-07 11:44:23 +0000519 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100520 tensors.push_back(CreateTensor(flatBufferBuilder,
521 flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
522 outputShape.size()),
523 ::tflite::TensorType_FLOAT32,
Ryan OShea238ecd92023-03-07 11:44:23 +0000524 buffers.size() - 1,
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100525 flatBufferBuilder.CreateString("output")));
526 std::vector<int> operatorOutputs;
Ryan OShea238ecd92023-03-07 11:44:23 +0000527 operatorOutputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100528
529 // create operator
Ryan OShea238ecd92023-03-07 11:44:23 +0000530 tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
531 flatbuffers::Offset<void> operatorBuiltinOptions =
532 CreateUnidirectionalSequenceLSTMOptions(flatBufferBuilder,
533 activationFunction,
534 clippingThresCell,
535 clippingThresProj,
536 isTimeMajor).Union();
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100537
538 flatbuffers::Offset<Operator> lstmOperator =
Ryan OShea238ecd92023-03-07 11:44:23 +0000539 CreateOperator(flatBufferBuilder,
540 0,
541 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
542 operatorInputs.size()),
543 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
544 operatorOutputs.size()),
545 operatorBuiltinOptionsType, operatorBuiltinOptions);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100546
Ryan OShea238ecd92023-03-07 11:44:23 +0000547 flatbuffers::Offset<SubGraph> subgraph =
548 CreateSubGraph(flatBufferBuilder,
549 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
550 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
551 operatorInputs.size()),
552 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
553 operatorOutputs.size()),
554 flatBufferBuilder.CreateVector(&lstmOperator, 1));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100555
Ryan OShea238ecd92023-03-07 11:44:23 +0000556 flatbuffers::Offset<flatbuffers::String> modelDescription =
557 flatBufferBuilder.CreateString(
558 "ArmnnDelegate: UnidirectionalSequenceLSTM Operator Model");
559 flatbuffers::Offset<OperatorCode> operatorCode =
560 CreateOperatorCode(flatBufferBuilder,
561 tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100562
Ryan OShea238ecd92023-03-07 11:44:23 +0000563 flatbuffers::Offset<Model> flatbufferModel =
564 CreateModel(flatBufferBuilder,
565 TFLITE_SCHEMA_VERSION,
566 flatBufferBuilder.CreateVector(&operatorCode, 1),
567 flatBufferBuilder.CreateVector(&subgraph, 1),
568 modelDescription,
569 flatBufferBuilder.CreateVector(buffers));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100570
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100571 flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100572
573 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
574 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
575}
576
Ryan OShea238ecd92023-03-07 11:44:23 +0000577template<typename T>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100578void UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId>& backends,
579 tflite::TensorType tensorType,
580 int32_t batchSize,
581 int32_t timeSize,
582 int32_t inputSize,
583 int32_t outputSize,
584 int32_t numUnits,
585 bool hasInputToInputWeights,
586 const std::vector<T>& inputToInputWeights,
587 const std::vector<T>& inputToForgetWeights,
588 const std::vector<T>& inputToCellWeights,
589 const std::vector<T>& inputToOutputWeights,
590 bool hasRecurrentToInputWeights,
591 const std::vector<T>& recurrentToInputWeights,
592 const std::vector<T>& recurrentToForgetWeights,
593 const std::vector<T>& recurrentToCellWeights,
594 const std::vector<T>& recurrentToOutputWeights,
595 bool hasCellToInputWeights,
596 const std::vector<T>& cellToInputWeights,
597 bool hasCellToForgetWeights,
598 const std::vector<T>& cellToForgetWeights,
599 bool hasCellToOutputWeights,
600 const std::vector<T>& cellToOutputWeights,
601 bool hasInputGateBias,
602 const std::vector<float>& inputGateBias,
603 const std::vector<float>& forgetGateBias,
604 const std::vector<float>& cellBias,
605 const std::vector<float>& outputGateBias,
606 bool hasProjectionWeights,
607 const std::vector<T>& projectionWeights,
608 bool hasProjectionBias,
609 const std::vector<float>& projectionBias,
610 bool hasInputLayerNormWeights,
611 const std::vector<float>& inputLayerNormWeights,
612 bool hasForgetLayerNormWeights,
613 const std::vector<float>& forgetLayerNormWeights,
614 bool hasCellLayerNormWeights,
615 const std::vector<float>& cellLayerNormWeights,
616 bool hasOutputLayerNormWeights,
617 const std::vector<float>& outputLayerNormWeights,
618 std::vector<float>& inputValues,
619 std::vector<float>& expectedOutputValues,
620 tflite::ActivationFunctionType activationFunction,
621 float clippingThresCell,
622 float clippingThresProj,
623 bool isTimeMajor,
624 float quantScale = 0.1f)
625{
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100626 using namespace delegateTestInterpreter;
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100627
628 std::vector<char> modelBuffer = CreateUnidirectionalSequenceLstmTfLiteModel(tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000629 batchSize,
630 timeSize,
631 inputSize,
632 outputSize,
633 numUnits,
634 hasInputToInputWeights,
635 inputToInputWeights,
636 inputToForgetWeights,
637 inputToCellWeights,
638 inputToOutputWeights,
639 hasRecurrentToInputWeights,
640 recurrentToInputWeights,
641 recurrentToForgetWeights,
642 recurrentToCellWeights,
643 recurrentToOutputWeights,
644 hasCellToInputWeights,
645 cellToInputWeights,
646 hasCellToForgetWeights,
647 cellToForgetWeights,
648 hasCellToOutputWeights,
649 cellToOutputWeights,
650 hasInputGateBias,
651 inputGateBias,
652 forgetGateBias,
653 cellBias,
654 outputGateBias,
655 hasProjectionWeights,
656 projectionWeights,
657 hasProjectionBias,
658 projectionBias,
659 hasInputLayerNormWeights,
660 inputLayerNormWeights,
661 hasForgetLayerNormWeights,
662 forgetLayerNormWeights,
663 hasCellLayerNormWeights,
664 cellLayerNormWeights,
665 hasOutputLayerNormWeights,
666 outputLayerNormWeights,
667 activationFunction,
668 clippingThresCell,
669 clippingThresProj,
670 isTimeMajor,
671 quantScale);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100672
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100673 std::vector<int32_t> outputShape;
674 if (isTimeMajor)
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100675 {
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100676 outputShape = {timeSize, batchSize, outputSize};
677 }
678 else
679 {
680 outputShape = {batchSize, timeSize, outputSize};
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100681 }
682
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100683 // Setup interpreter with just TFLite Runtime.
684 auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
685 CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
686 CHECK(tfLiteInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
687 CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
688 std::vector<float> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<float>(0);
689 std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100690
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100691 // Setup interpreter with Arm NN Delegate applied.
692 auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
693 CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
694 CHECK(armnnInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
695 CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
696 std::vector<float> armnnOutputValues = armnnInterpreter.GetOutputResult<float>(0);
697 std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100698
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100699 armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100700
701 if (tensorType == ::tflite::TensorType_INT8)
702 {
703 // Allow 2% tolerance for Quantized weights
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100704 armnnDelegate::CompareData(expectedOutputValues.data(), armnnOutputValues.data(),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100705 expectedOutputValues.size(), 2);
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100706 armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteOutputValues.data(),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100707 expectedOutputValues.size(), 2);
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100708 armnnDelegate::CompareData(tfLiteOutputValues.data(), armnnOutputValues.data(),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100709 expectedOutputValues.size(), 2);
710 }
711 else
712 {
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100713 armnnDelegate::CompareOutputData<float>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100714 }
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100715
716 tfLiteInterpreter.Cleanup();
717 armnnInterpreter.Cleanup();
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100718}
719
720} // anonymous namespace