blob: c27f8d854b2b1c0b7f27b885993bd13993da798f [file] [log] [blame]
//
// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <armnn/Types.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <doctest/doctest.h>

#include <initializer_list>
#include <iterator>
#include <vector>
28
29namespace
30{
31
Ryan OShea238ecd92023-03-07 11:44:23 +000032template<typename T>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010033std::vector<char> CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType tensorType,
34 int32_t batchSize,
35 int32_t timeSize,
36 int32_t inputSize,
37 int32_t outputSize,
38 int32_t numUnits,
39 bool hasInputToInputWeights,
40 const std::vector<T>& inputToInputWeights,
41 const std::vector<T>& inputToForgetWeights,
42 const std::vector<T>& inputToCellWeights,
43 const std::vector<T>& inputToOutputWeights,
44 bool hasRecurrentToInputWeights,
45 const std::vector<T>& recurrentToInputWeights,
46 const std::vector<T>& recurrentToForgetWeights,
47 const std::vector<T>& recurrentToCellWeights,
48 const std::vector<T>& recurrentToOutputWeights,
49 bool hasCellToInputWeights,
50 const std::vector<T>& cellToInputWeights,
51 bool hasCellToForgetWeights,
52 const std::vector<T>& cellToForgetWeights,
53 bool hasCellToOutputWeights,
54 const std::vector<T>& cellToOutputWeights,
55 bool hasInputGateBias,
56 const std::vector<float>& inputGateBias,
57 const std::vector<float>& forgetGateBias,
58 const std::vector<float>& cellBias,
59 const std::vector<float>& outputGateBias,
60 bool hasProjectionWeights,
61 const std::vector<T>& projectionWeights,
62 bool hasProjectionBias,
63 const std::vector<float>& projectionBias,
64 bool hasInputLayerNormWeights,
65 const std::vector<float>& inputLayerNormWeights,
66 bool hasForgetLayerNormWeights,
67 const std::vector<float>& forgetLayerNormWeights,
68 bool hasCellLayerNormWeights,
69 const std::vector<float>& cellLayerNormWeights,
70 bool hasOutputLayerNormWeights,
71 const std::vector<float>& outputLayerNormWeights,
72 tflite::ActivationFunctionType activationFunction,
73 float clippingThresCell,
74 float clippingThresProj,
75 bool isTimeMajor,
76 float quantScale,
Ryan OShea238ecd92023-03-07 11:44:23 +000077 int quantOffset = 0)
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010078{
79
80 std::vector<int32_t> tensorInfo0{};
81 std::vector<int32_t> tensorInfoNumUnits{numUnits};
82 std::vector<int32_t> tensorInfoInputSize{numUnits, inputSize};
83 std::vector<int32_t> tensorInfoOutputSize{numUnits, outputSize};
84
85 std::vector<int32_t> inputShape;
86 std::vector<int32_t> outputShape;
87 if (isTimeMajor)
88 {
89 inputShape = {timeSize, batchSize, inputSize};
90 outputShape = {timeSize, batchSize, outputSize};
91 }
92 else
93 {
94 inputShape = {batchSize, timeSize, inputSize};
95 outputShape = {batchSize, timeSize, outputSize};
96 }
97 std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
98 std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};
99 std::vector<int32_t> projectionWeightDimensions{outputSize, numUnits};
100 std::vector<int32_t> projectionBiasDimensions{outputSize};
101
102 std::vector<int> operatorInputs;
103 using namespace tflite;
Ryan OShea238ecd92023-03-07 11:44:23 +0000104 flatbuffers::FlatBufferBuilder flatBufferBuilder;
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100105 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
Ryan OShea238ecd92023-03-07 11:44:23 +0000106 std::vector<flatbuffers::Offset<Tensor>> tensors;
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100107
108 auto quantizationParameters =
Ryan OShea238ecd92023-03-07 11:44:23 +0000109 CreateQuantizationParameters(flatBufferBuilder,
110 0,
111 0,
112 flatBufferBuilder.CreateVector<float>({1.0f}),
113 flatBufferBuilder.CreateVector<int64_t>({0}));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100114
115 auto weightQuantizationParameters =
Ryan OShea238ecd92023-03-07 11:44:23 +0000116 CreateQuantizationParameters(flatBufferBuilder,
117 0,
118 0,
119 flatBufferBuilder.CreateVector<float>({quantScale}),
120 flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100121
Ryan OShea238ecd92023-03-07 11:44:23 +0000122 buffers.push_back(CreateBuffer(flatBufferBuilder));
123 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100124 tensors.push_back(CreateTensor(flatBufferBuilder,
125 flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
126 inputShape.size()),
127 ::tflite::TensorType_FLOAT32,
128 buffers.size() - 1,
129 flatBufferBuilder.CreateString("input_0")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000130 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100131
132 if (hasInputToInputWeights)
133 {
134 buffers.push_back(
135 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000136 flatBufferBuilder.CreateVector(
137 reinterpret_cast<const uint8_t*>(inputToInputWeights.data()),
138 sizeof(T) * inputToInputWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100139 tensors.push_back(CreateTensor(flatBufferBuilder,
140 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
141 tensorInfoInputSize.size()),
142 tensorType,
143 buffers.size() - 1,
144 flatBufferBuilder.CreateString("inputToInputWeights"),
145 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000146 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100147 }
148 else
149 {
150 operatorInputs.push_back(kTfLiteOptionalTensor);
151 }
152
153 buffers.push_back(
154 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000155 flatBufferBuilder.CreateVector(
156 reinterpret_cast<const uint8_t*>(inputToForgetWeights.data()),
157 sizeof(T) * inputToForgetWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100158 tensors.push_back(CreateTensor(flatBufferBuilder,
159 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
160 tensorInfoInputSize.size()),
161 tensorType,
162 buffers.size() - 1,
163 flatBufferBuilder.CreateString("inputToForgetWeights"),
164 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000165 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100166
167 buffers.push_back(
168 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000169 flatBufferBuilder.CreateVector(
170 reinterpret_cast<const uint8_t*>(inputToCellWeights.data()),
171 sizeof(T) * inputToCellWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100172 tensors.push_back(CreateTensor(flatBufferBuilder,
173 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
174 tensorInfoInputSize.size()),
175 tensorType,
176 buffers.size() - 1,
177 flatBufferBuilder.CreateString("inputToCellWeights"),
178 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000179 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100180
181 buffers.push_back(
182 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000183 flatBufferBuilder.CreateVector(
184 reinterpret_cast<const uint8_t*>(inputToOutputWeights.data()),
185 sizeof(T) * inputToOutputWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100186 tensors.push_back(CreateTensor(flatBufferBuilder,
187 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
188 tensorInfoInputSize.size()),
189 tensorType,
190 buffers.size() - 1,
191 flatBufferBuilder.CreateString("inputToOutputWeights"),
192 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000193 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100194
195 if (hasRecurrentToInputWeights)
196 {
197 buffers.push_back(CreateBuffer(
198 flatBufferBuilder,
199 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
200 sizeof(T) * recurrentToInputWeights.size())));
201 tensors.push_back(CreateTensor(flatBufferBuilder,
202 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
203 tensorInfoOutputSize.size()),
204 tensorType,
205 buffers.size() - 1,
206 flatBufferBuilder.CreateString("recurrentToInputWeights"),
207 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000208 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100209 }
210 else
211 {
212 operatorInputs.push_back(kTfLiteOptionalTensor);
213 }
214
215 buffers.push_back(
216 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000217 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
218 recurrentToForgetWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100219 sizeof(T) * recurrentToForgetWeights.size())));
220 tensors.push_back(CreateTensor(flatBufferBuilder,
221 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
222 tensorInfoOutputSize.size()),
223 tensorType,
224 buffers.size() - 1,
225 flatBufferBuilder.CreateString("recurrentToForgetWeights"),
226 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000227 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100228
229 buffers.push_back(
230 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000231 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
232 recurrentToCellWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100233 sizeof(T) * recurrentToCellWeights.size())));
234 tensors.push_back(CreateTensor(flatBufferBuilder,
235 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
236 tensorInfoOutputSize.size()),
237 tensorType,
238 buffers.size() - 1,
239 flatBufferBuilder.CreateString("recurrentToCellWeights"),
240 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000241 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100242
243 buffers.push_back(
244 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000245 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
246 recurrentToOutputWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100247 sizeof(T) * recurrentToOutputWeights.size())));
248 tensors.push_back(CreateTensor(flatBufferBuilder,
249 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
250 tensorInfoOutputSize.size()),
251 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000252 buffers.size() - 1,
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100253 flatBufferBuilder.CreateString("recurrentToOutputWeights"),
254 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000255 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100256
257 if (hasCellToInputWeights)
258 {
259 buffers.push_back(
260 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000261 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
262 cellToInputWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100263 sizeof(T) * cellToInputWeights.size())));
264 tensors.push_back(CreateTensor(flatBufferBuilder,
265 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
266 tensorInfoNumUnits.size()),
267 tensorType,
268 buffers.size() - 1,
269 flatBufferBuilder.CreateString("cellToInputWeights"),
270 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000271 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100272 }
273 else
274 {
275 operatorInputs.push_back(kTfLiteOptionalTensor);
276 }
277
278 if (hasCellToForgetWeights)
279 {
280 buffers.push_back(
281 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000282 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
283 cellToForgetWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100284 sizeof(T) * cellToForgetWeights.size())));
285 tensors.push_back(CreateTensor(flatBufferBuilder,
286 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
287 tensorInfoNumUnits.size()),
288 tensorType,
289 buffers.size() - 1,
290 flatBufferBuilder.CreateString("cellToForgetWeights"),
291 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000292 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100293 }
294 else
295 {
296 operatorInputs.push_back(kTfLiteOptionalTensor);
297 }
298
299 if (hasCellToOutputWeights)
300 {
301 buffers.push_back(
302 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000303 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
304 cellToOutputWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100305 sizeof(T) * cellToOutputWeights.size())));
306 tensors.push_back(CreateTensor(flatBufferBuilder,
307 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
308 tensorInfoNumUnits.size()),
309 tensorType,
310 buffers.size() - 1,
311 flatBufferBuilder.CreateString("cellToOutputWeights"),
312 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000313 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100314 }
315 else
316 {
317 operatorInputs.push_back(kTfLiteOptionalTensor);
318 }
319
320 if (hasInputGateBias)
321 {
322 buffers.push_back(
323 CreateBuffer(flatBufferBuilder,
324 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
325 sizeof(float) * inputGateBias.size())));
326 tensors.push_back(CreateTensor(flatBufferBuilder,
327 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
328 tensorInfoNumUnits.size()),
329 ::tflite::TensorType_FLOAT32,
330 buffers.size() - 1,
331 flatBufferBuilder.CreateString("inputGateBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000332 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100333 }
334 else
335 {
336 operatorInputs.push_back(kTfLiteOptionalTensor);
337 }
338
339 buffers.push_back(
340 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000341 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetGateBias.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100342 sizeof(float) * forgetGateBias.size())));
343 tensors.push_back(CreateTensor(flatBufferBuilder,
344 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
345 tensorInfoNumUnits.size()),
346 ::tflite::TensorType_FLOAT32,
347 buffers.size() - 1,
348 flatBufferBuilder.CreateString("forgetGateBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000349 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100350
351 buffers.push_back(
352 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000353 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellBias.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100354 sizeof(float) * cellBias.size())));
355 tensors.push_back(CreateTensor(flatBufferBuilder,
356 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
357 tensorInfoNumUnits.size()),
358 ::tflite::TensorType_FLOAT32,
359 buffers.size() - 1,
360 flatBufferBuilder.CreateString("cellBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000361 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100362
363 buffers.push_back(
364 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000365 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputGateBias.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100366 sizeof(float) * outputGateBias.size())));
367 tensors.push_back(CreateTensor(flatBufferBuilder,
368 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
369 tensorInfoNumUnits.size()),
370 ::tflite::TensorType_FLOAT32,
371 buffers.size() - 1,
372 flatBufferBuilder.CreateString("outputGateBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000373 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100374
375 if (hasProjectionWeights)
376 {
377 buffers.push_back(
378 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000379 flatBufferBuilder.CreateVector(
380 reinterpret_cast<const uint8_t*>(projectionWeights.data()),
381 sizeof(T) * projectionWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100382 tensors.push_back(CreateTensor(flatBufferBuilder,
383 flatBufferBuilder.CreateVector<int32_t>(projectionWeightDimensions.data(),
384 projectionWeightDimensions.size()),
385 tensorType,
386 buffers.size() - 1,
387 flatBufferBuilder.CreateString("projectionWeights"),
388 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000389 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100390 }
391 else
392 {
393 operatorInputs.push_back(kTfLiteOptionalTensor);
394 }
395
396 if (hasProjectionBias)
397 {
398 buffers.push_back(
399 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000400 flatBufferBuilder.CreateVector(
401 reinterpret_cast<const uint8_t*>(projectionBias.data()),
402 sizeof(float) * projectionBias.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100403 tensors.push_back(CreateTensor(flatBufferBuilder,
404 flatBufferBuilder.CreateVector<int32_t>(projectionBiasDimensions.data(),
405 projectionBiasDimensions.size()),
406 ::tflite::TensorType_FLOAT32,
407 buffers.size() - 1,
408 flatBufferBuilder.CreateString("projectionBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000409 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100410 }
411 else
412 {
413 operatorInputs.push_back(kTfLiteOptionalTensor);
414 }
415
Ryan OShea238ecd92023-03-07 11:44:23 +0000416 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100417 tensors.push_back(CreateTensor(flatBufferBuilder,
418 flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
419 outputStateInDimensions.size()),
420 ::tflite::TensorType_FLOAT32,
421 buffers.size() - 1,
422 flatBufferBuilder.CreateString("outputStateInInfo"),
423 quantizationParameters,
424 true));
Ryan OShea238ecd92023-03-07 11:44:23 +0000425 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100426
Ryan OShea238ecd92023-03-07 11:44:23 +0000427 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100428 tensors.push_back(CreateTensor(flatBufferBuilder,
429 flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
430 cellStateInDimensions.size()),
431 ::tflite::TensorType_FLOAT32,
432 buffers.size() - 1,
433 flatBufferBuilder.CreateString("cellStateInInfo"),
434 quantizationParameters,
435 true));
Ryan OShea238ecd92023-03-07 11:44:23 +0000436 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100437
438 if (hasInputLayerNormWeights)
439 {
440 buffers.push_back(
441 CreateBuffer(flatBufferBuilder,
442 flatBufferBuilder.CreateVector(
Ryan OShea238ecd92023-03-07 11:44:23 +0000443 reinterpret_cast<const uint8_t*>(inputLayerNormWeights.data()),
444 sizeof(float) * inputLayerNormWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100445 tensors.push_back(CreateTensor(flatBufferBuilder,
446 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
447 tensorInfoNumUnits.size()),
448 ::tflite::TensorType_FLOAT32,
449 buffers.size() - 1,
450 flatBufferBuilder.CreateString("inputLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000451 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100452 }
453 else
454 {
455 operatorInputs.push_back(kTfLiteOptionalTensor);
456 }
457
458 if (hasForgetLayerNormWeights)
459 {
460 buffers.push_back(
461 CreateBuffer(flatBufferBuilder,
462 flatBufferBuilder.CreateVector(
Ryan OShea238ecd92023-03-07 11:44:23 +0000463 reinterpret_cast<const uint8_t*>(forgetLayerNormWeights.data()),
464 sizeof(float) * forgetLayerNormWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100465 tensors.push_back(CreateTensor(flatBufferBuilder,
466 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
467 tensorInfoNumUnits.size()),
468 ::tflite::TensorType_FLOAT32,
469 buffers.size() - 1,
470 flatBufferBuilder.CreateString("forgetLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000471 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100472 }
473 else
474 {
475 operatorInputs.push_back(kTfLiteOptionalTensor);
476 }
477
478 if (hasCellLayerNormWeights)
479 {
480 buffers.push_back(
481 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000482 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
483 cellLayerNormWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100484 sizeof(float) * cellLayerNormWeights.size())));
485 tensors.push_back(CreateTensor(flatBufferBuilder,
486 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
487 tensorInfoNumUnits.size()),
488 ::tflite::TensorType_FLOAT32,
489 buffers.size() - 1,
490 flatBufferBuilder.CreateString("cellLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000491 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100492 }
493 else
494 {
495 operatorInputs.push_back(kTfLiteOptionalTensor);
496 }
497
498 if (hasOutputLayerNormWeights)
499 {
500 buffers.push_back(
501 CreateBuffer(flatBufferBuilder,
502 flatBufferBuilder.CreateVector(
Ryan OShea238ecd92023-03-07 11:44:23 +0000503 reinterpret_cast<const uint8_t*>(outputLayerNormWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100504 sizeof(float) * outputLayerNormWeights.size())));
505 tensors.push_back(CreateTensor(flatBufferBuilder,
506 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
507 tensorInfoNumUnits.size()),
508 ::tflite::TensorType_FLOAT32,
509 buffers.size() - 1,
510 flatBufferBuilder.CreateString("outputLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000511 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100512 }
513 else
514 {
515 operatorInputs.push_back(kTfLiteOptionalTensor);
516 }
Ryan OShea238ecd92023-03-07 11:44:23 +0000517 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100518 tensors.push_back(CreateTensor(flatBufferBuilder,
519 flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
520 outputShape.size()),
521 ::tflite::TensorType_FLOAT32,
Ryan OShea238ecd92023-03-07 11:44:23 +0000522 buffers.size() - 1,
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100523 flatBufferBuilder.CreateString("output")));
524 std::vector<int> operatorOutputs;
Ryan OShea238ecd92023-03-07 11:44:23 +0000525 operatorOutputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100526
527 // create operator
Ryan OShea238ecd92023-03-07 11:44:23 +0000528 tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
529 flatbuffers::Offset<void> operatorBuiltinOptions =
530 CreateUnidirectionalSequenceLSTMOptions(flatBufferBuilder,
531 activationFunction,
532 clippingThresCell,
533 clippingThresProj,
534 isTimeMajor).Union();
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100535
536 flatbuffers::Offset<Operator> lstmOperator =
Ryan OShea238ecd92023-03-07 11:44:23 +0000537 CreateOperator(flatBufferBuilder,
538 0,
539 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
540 operatorInputs.size()),
541 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
542 operatorOutputs.size()),
543 operatorBuiltinOptionsType, operatorBuiltinOptions);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100544
Ryan OShea238ecd92023-03-07 11:44:23 +0000545 flatbuffers::Offset<SubGraph> subgraph =
546 CreateSubGraph(flatBufferBuilder,
547 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
548 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
549 operatorInputs.size()),
550 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
551 operatorOutputs.size()),
552 flatBufferBuilder.CreateVector(&lstmOperator, 1));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100553
Ryan OShea238ecd92023-03-07 11:44:23 +0000554 flatbuffers::Offset<flatbuffers::String> modelDescription =
555 flatBufferBuilder.CreateString(
556 "ArmnnDelegate: UnidirectionalSequenceLSTM Operator Model");
557 flatbuffers::Offset<OperatorCode> operatorCode =
558 CreateOperatorCode(flatBufferBuilder,
559 tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100560
Ryan OShea238ecd92023-03-07 11:44:23 +0000561 flatbuffers::Offset<Model> flatbufferModel =
562 CreateModel(flatBufferBuilder,
563 TFLITE_SCHEMA_VERSION,
564 flatBufferBuilder.CreateVector(&operatorCode, 1),
565 flatBufferBuilder.CreateVector(&subgraph, 1),
566 modelDescription,
567 flatBufferBuilder.CreateVector(buffers));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100568
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100569 flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100570
571 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
572 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
573}
574
Ryan OShea238ecd92023-03-07 11:44:23 +0000575template<typename T>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100576void UnidirectionalSequenceLstmTestImpl(std::vector<armnn::BackendId>& backends,
577 tflite::TensorType tensorType,
578 int32_t batchSize,
579 int32_t timeSize,
580 int32_t inputSize,
581 int32_t outputSize,
582 int32_t numUnits,
583 bool hasInputToInputWeights,
584 const std::vector<T>& inputToInputWeights,
585 const std::vector<T>& inputToForgetWeights,
586 const std::vector<T>& inputToCellWeights,
587 const std::vector<T>& inputToOutputWeights,
588 bool hasRecurrentToInputWeights,
589 const std::vector<T>& recurrentToInputWeights,
590 const std::vector<T>& recurrentToForgetWeights,
591 const std::vector<T>& recurrentToCellWeights,
592 const std::vector<T>& recurrentToOutputWeights,
593 bool hasCellToInputWeights,
594 const std::vector<T>& cellToInputWeights,
595 bool hasCellToForgetWeights,
596 const std::vector<T>& cellToForgetWeights,
597 bool hasCellToOutputWeights,
598 const std::vector<T>& cellToOutputWeights,
599 bool hasInputGateBias,
600 const std::vector<float>& inputGateBias,
601 const std::vector<float>& forgetGateBias,
602 const std::vector<float>& cellBias,
603 const std::vector<float>& outputGateBias,
604 bool hasProjectionWeights,
605 const std::vector<T>& projectionWeights,
606 bool hasProjectionBias,
607 const std::vector<float>& projectionBias,
608 bool hasInputLayerNormWeights,
609 const std::vector<float>& inputLayerNormWeights,
610 bool hasForgetLayerNormWeights,
611 const std::vector<float>& forgetLayerNormWeights,
612 bool hasCellLayerNormWeights,
613 const std::vector<float>& cellLayerNormWeights,
614 bool hasOutputLayerNormWeights,
615 const std::vector<float>& outputLayerNormWeights,
616 std::vector<float>& inputValues,
617 std::vector<float>& expectedOutputValues,
618 tflite::ActivationFunctionType activationFunction,
619 float clippingThresCell,
620 float clippingThresProj,
621 bool isTimeMajor,
622 float quantScale = 0.1f)
623{
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100624 using namespace delegateTestInterpreter;
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100625
626 std::vector<char> modelBuffer = CreateUnidirectionalSequenceLstmTfLiteModel(tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000627 batchSize,
628 timeSize,
629 inputSize,
630 outputSize,
631 numUnits,
632 hasInputToInputWeights,
633 inputToInputWeights,
634 inputToForgetWeights,
635 inputToCellWeights,
636 inputToOutputWeights,
637 hasRecurrentToInputWeights,
638 recurrentToInputWeights,
639 recurrentToForgetWeights,
640 recurrentToCellWeights,
641 recurrentToOutputWeights,
642 hasCellToInputWeights,
643 cellToInputWeights,
644 hasCellToForgetWeights,
645 cellToForgetWeights,
646 hasCellToOutputWeights,
647 cellToOutputWeights,
648 hasInputGateBias,
649 inputGateBias,
650 forgetGateBias,
651 cellBias,
652 outputGateBias,
653 hasProjectionWeights,
654 projectionWeights,
655 hasProjectionBias,
656 projectionBias,
657 hasInputLayerNormWeights,
658 inputLayerNormWeights,
659 hasForgetLayerNormWeights,
660 forgetLayerNormWeights,
661 hasCellLayerNormWeights,
662 cellLayerNormWeights,
663 hasOutputLayerNormWeights,
664 outputLayerNormWeights,
665 activationFunction,
666 clippingThresCell,
667 clippingThresProj,
668 isTimeMajor,
669 quantScale);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100670
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100671 std::vector<int32_t> outputShape;
672 if (isTimeMajor)
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100673 {
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100674 outputShape = {timeSize, batchSize, outputSize};
675 }
676 else
677 {
678 outputShape = {batchSize, timeSize, outputSize};
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100679 }
680
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100681 // Setup interpreter with just TFLite Runtime.
682 auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
683 CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
684 CHECK(tfLiteInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
685 CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
686 std::vector<float> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<float>(0);
687 std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100688
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100689 // Setup interpreter with Arm NN Delegate applied.
690 auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
691 CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
692 CHECK(armnnInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
693 CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
694 std::vector<float> armnnOutputValues = armnnInterpreter.GetOutputResult<float>(0);
695 std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100696
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100697 armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100698
699 if (tensorType == ::tflite::TensorType_INT8)
700 {
701 // Allow 2% tolerance for Quantized weights
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100702 armnnDelegate::CompareData(expectedOutputValues.data(), armnnOutputValues.data(),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100703 expectedOutputValues.size(), 2);
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100704 armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteOutputValues.data(),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100705 expectedOutputValues.size(), 2);
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100706 armnnDelegate::CompareData(tfLiteOutputValues.data(), armnnOutputValues.data(),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100707 expectedOutputValues.size(), 2);
708 }
709 else
710 {
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100711 armnnDelegate::CompareOutputData<float>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100712 }
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100713
714 tfLiteInterpreter.Cleanup();
715 armnnInterpreter.Cleanup();
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100716}
717
718} // anonymous namespace