//
// Copyright © 2021, 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010011#include <DelegateTestInterpreter.hpp>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010012
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010013#include <tensorflow/lite/version.h>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010014
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010015#include <armnn/utility/NumericCast.hpp>
16#include <armnn/TypesUtils.hpp>
17
18#include <armnn/Types.hpp>
19
20#include <initializer_list>
21#include <iterator>
22#include <vector>
23
24namespace
25{
26
Ryan OShea238ecd92023-03-07 11:44:23 +000027template<typename T>
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010028std::vector<char> CreateUnidirectionalSequenceLstmTfLiteModel(tflite::TensorType tensorType,
29 int32_t batchSize,
30 int32_t timeSize,
31 int32_t inputSize,
32 int32_t outputSize,
33 int32_t numUnits,
34 bool hasInputToInputWeights,
35 const std::vector<T>& inputToInputWeights,
36 const std::vector<T>& inputToForgetWeights,
37 const std::vector<T>& inputToCellWeights,
38 const std::vector<T>& inputToOutputWeights,
39 bool hasRecurrentToInputWeights,
40 const std::vector<T>& recurrentToInputWeights,
41 const std::vector<T>& recurrentToForgetWeights,
42 const std::vector<T>& recurrentToCellWeights,
43 const std::vector<T>& recurrentToOutputWeights,
44 bool hasCellToInputWeights,
45 const std::vector<T>& cellToInputWeights,
46 bool hasCellToForgetWeights,
47 const std::vector<T>& cellToForgetWeights,
48 bool hasCellToOutputWeights,
49 const std::vector<T>& cellToOutputWeights,
50 bool hasInputGateBias,
51 const std::vector<float>& inputGateBias,
52 const std::vector<float>& forgetGateBias,
53 const std::vector<float>& cellBias,
54 const std::vector<float>& outputGateBias,
55 bool hasProjectionWeights,
56 const std::vector<T>& projectionWeights,
57 bool hasProjectionBias,
58 const std::vector<float>& projectionBias,
59 bool hasInputLayerNormWeights,
60 const std::vector<float>& inputLayerNormWeights,
61 bool hasForgetLayerNormWeights,
62 const std::vector<float>& forgetLayerNormWeights,
63 bool hasCellLayerNormWeights,
64 const std::vector<float>& cellLayerNormWeights,
65 bool hasOutputLayerNormWeights,
66 const std::vector<float>& outputLayerNormWeights,
67 tflite::ActivationFunctionType activationFunction,
68 float clippingThresCell,
69 float clippingThresProj,
70 bool isTimeMajor,
71 float quantScale,
Ryan OShea238ecd92023-03-07 11:44:23 +000072 int quantOffset = 0)
Narumol Prangnawarat7684b182021-08-12 14:48:15 +010073{
74
75 std::vector<int32_t> tensorInfo0{};
76 std::vector<int32_t> tensorInfoNumUnits{numUnits};
77 std::vector<int32_t> tensorInfoInputSize{numUnits, inputSize};
78 std::vector<int32_t> tensorInfoOutputSize{numUnits, outputSize};
79
80 std::vector<int32_t> inputShape;
81 std::vector<int32_t> outputShape;
82 if (isTimeMajor)
83 {
84 inputShape = {timeSize, batchSize, inputSize};
85 outputShape = {timeSize, batchSize, outputSize};
86 }
87 else
88 {
89 inputShape = {batchSize, timeSize, inputSize};
90 outputShape = {batchSize, timeSize, outputSize};
91 }
92 std::vector<int32_t> outputStateInDimensions{batchSize, outputSize};
93 std::vector<int32_t> cellStateInDimensions{batchSize, numUnits};
94 std::vector<int32_t> projectionWeightDimensions{outputSize, numUnits};
95 std::vector<int32_t> projectionBiasDimensions{outputSize};
96
97 std::vector<int> operatorInputs;
98 using namespace tflite;
Ryan OShea238ecd92023-03-07 11:44:23 +000099 flatbuffers::FlatBufferBuilder flatBufferBuilder;
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100100 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
Ryan OShea238ecd92023-03-07 11:44:23 +0000101 std::vector<flatbuffers::Offset<Tensor>> tensors;
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100102
103 auto quantizationParameters =
Ryan OShea238ecd92023-03-07 11:44:23 +0000104 CreateQuantizationParameters(flatBufferBuilder,
105 0,
106 0,
107 flatBufferBuilder.CreateVector<float>({1.0f}),
108 flatBufferBuilder.CreateVector<int64_t>({0}));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100109
110 auto weightQuantizationParameters =
Ryan OShea238ecd92023-03-07 11:44:23 +0000111 CreateQuantizationParameters(flatBufferBuilder,
112 0,
113 0,
114 flatBufferBuilder.CreateVector<float>({quantScale}),
115 flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100116
Ryan OShea238ecd92023-03-07 11:44:23 +0000117 buffers.push_back(CreateBuffer(flatBufferBuilder));
118 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100119 tensors.push_back(CreateTensor(flatBufferBuilder,
120 flatBufferBuilder.CreateVector<int32_t>(inputShape.data(),
121 inputShape.size()),
122 ::tflite::TensorType_FLOAT32,
123 buffers.size() - 1,
124 flatBufferBuilder.CreateString("input_0")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000125 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100126
127 if (hasInputToInputWeights)
128 {
129 buffers.push_back(
130 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000131 flatBufferBuilder.CreateVector(
132 reinterpret_cast<const uint8_t*>(inputToInputWeights.data()),
133 sizeof(T) * inputToInputWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100134 tensors.push_back(CreateTensor(flatBufferBuilder,
135 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
136 tensorInfoInputSize.size()),
137 tensorType,
138 buffers.size() - 1,
139 flatBufferBuilder.CreateString("inputToInputWeights"),
140 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000141 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100142 }
143 else
144 {
145 operatorInputs.push_back(kTfLiteOptionalTensor);
146 }
147
148 buffers.push_back(
149 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000150 flatBufferBuilder.CreateVector(
151 reinterpret_cast<const uint8_t*>(inputToForgetWeights.data()),
152 sizeof(T) * inputToForgetWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100153 tensors.push_back(CreateTensor(flatBufferBuilder,
154 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
155 tensorInfoInputSize.size()),
156 tensorType,
157 buffers.size() - 1,
158 flatBufferBuilder.CreateString("inputToForgetWeights"),
159 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000160 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100161
162 buffers.push_back(
163 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000164 flatBufferBuilder.CreateVector(
165 reinterpret_cast<const uint8_t*>(inputToCellWeights.data()),
166 sizeof(T) * inputToCellWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100167 tensors.push_back(CreateTensor(flatBufferBuilder,
168 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
169 tensorInfoInputSize.size()),
170 tensorType,
171 buffers.size() - 1,
172 flatBufferBuilder.CreateString("inputToCellWeights"),
173 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000174 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100175
176 buffers.push_back(
177 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000178 flatBufferBuilder.CreateVector(
179 reinterpret_cast<const uint8_t*>(inputToOutputWeights.data()),
180 sizeof(T) * inputToOutputWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100181 tensors.push_back(CreateTensor(flatBufferBuilder,
182 flatBufferBuilder.CreateVector<int32_t>(tensorInfoInputSize.data(),
183 tensorInfoInputSize.size()),
184 tensorType,
185 buffers.size() - 1,
186 flatBufferBuilder.CreateString("inputToOutputWeights"),
187 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000188 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100189
190 if (hasRecurrentToInputWeights)
191 {
192 buffers.push_back(CreateBuffer(
193 flatBufferBuilder,
194 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(recurrentToInputWeights.data()),
195 sizeof(T) * recurrentToInputWeights.size())));
196 tensors.push_back(CreateTensor(flatBufferBuilder,
197 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
198 tensorInfoOutputSize.size()),
199 tensorType,
200 buffers.size() - 1,
201 flatBufferBuilder.CreateString("recurrentToInputWeights"),
202 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000203 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100204 }
205 else
206 {
207 operatorInputs.push_back(kTfLiteOptionalTensor);
208 }
209
210 buffers.push_back(
211 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000212 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
213 recurrentToForgetWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100214 sizeof(T) * recurrentToForgetWeights.size())));
215 tensors.push_back(CreateTensor(flatBufferBuilder,
216 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
217 tensorInfoOutputSize.size()),
218 tensorType,
219 buffers.size() - 1,
220 flatBufferBuilder.CreateString("recurrentToForgetWeights"),
221 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000222 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100223
224 buffers.push_back(
225 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000226 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
227 recurrentToCellWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100228 sizeof(T) * recurrentToCellWeights.size())));
229 tensors.push_back(CreateTensor(flatBufferBuilder,
230 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
231 tensorInfoOutputSize.size()),
232 tensorType,
233 buffers.size() - 1,
234 flatBufferBuilder.CreateString("recurrentToCellWeights"),
235 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000236 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100237
238 buffers.push_back(
239 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000240 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
241 recurrentToOutputWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100242 sizeof(T) * recurrentToOutputWeights.size())));
243 tensors.push_back(CreateTensor(flatBufferBuilder,
244 flatBufferBuilder.CreateVector<int32_t>(tensorInfoOutputSize.data(),
245 tensorInfoOutputSize.size()),
246 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000247 buffers.size() - 1,
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100248 flatBufferBuilder.CreateString("recurrentToOutputWeights"),
249 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000250 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100251
252 if (hasCellToInputWeights)
253 {
254 buffers.push_back(
255 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000256 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
257 cellToInputWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100258 sizeof(T) * cellToInputWeights.size())));
259 tensors.push_back(CreateTensor(flatBufferBuilder,
260 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
261 tensorInfoNumUnits.size()),
262 tensorType,
263 buffers.size() - 1,
264 flatBufferBuilder.CreateString("cellToInputWeights"),
265 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000266 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100267 }
268 else
269 {
270 operatorInputs.push_back(kTfLiteOptionalTensor);
271 }
272
273 if (hasCellToForgetWeights)
274 {
275 buffers.push_back(
276 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000277 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
278 cellToForgetWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100279 sizeof(T) * cellToForgetWeights.size())));
280 tensors.push_back(CreateTensor(flatBufferBuilder,
281 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
282 tensorInfoNumUnits.size()),
283 tensorType,
284 buffers.size() - 1,
285 flatBufferBuilder.CreateString("cellToForgetWeights"),
286 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000287 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100288 }
289 else
290 {
291 operatorInputs.push_back(kTfLiteOptionalTensor);
292 }
293
294 if (hasCellToOutputWeights)
295 {
296 buffers.push_back(
297 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000298 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
299 cellToOutputWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100300 sizeof(T) * cellToOutputWeights.size())));
301 tensors.push_back(CreateTensor(flatBufferBuilder,
302 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
303 tensorInfoNumUnits.size()),
304 tensorType,
305 buffers.size() - 1,
306 flatBufferBuilder.CreateString("cellToOutputWeights"),
307 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000308 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100309 }
310 else
311 {
312 operatorInputs.push_back(kTfLiteOptionalTensor);
313 }
314
315 if (hasInputGateBias)
316 {
317 buffers.push_back(
318 CreateBuffer(flatBufferBuilder,
319 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(inputGateBias.data()),
320 sizeof(float) * inputGateBias.size())));
321 tensors.push_back(CreateTensor(flatBufferBuilder,
322 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
323 tensorInfoNumUnits.size()),
324 ::tflite::TensorType_FLOAT32,
325 buffers.size() - 1,
326 flatBufferBuilder.CreateString("inputGateBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000327 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100328 }
329 else
330 {
331 operatorInputs.push_back(kTfLiteOptionalTensor);
332 }
333
334 buffers.push_back(
335 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000336 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(forgetGateBias.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100337 sizeof(float) * forgetGateBias.size())));
338 tensors.push_back(CreateTensor(flatBufferBuilder,
339 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
340 tensorInfoNumUnits.size()),
341 ::tflite::TensorType_FLOAT32,
342 buffers.size() - 1,
343 flatBufferBuilder.CreateString("forgetGateBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000344 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100345
346 buffers.push_back(
347 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000348 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(cellBias.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100349 sizeof(float) * cellBias.size())));
350 tensors.push_back(CreateTensor(flatBufferBuilder,
351 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
352 tensorInfoNumUnits.size()),
353 ::tflite::TensorType_FLOAT32,
354 buffers.size() - 1,
355 flatBufferBuilder.CreateString("cellBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000356 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100357
358 buffers.push_back(
359 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000360 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(outputGateBias.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100361 sizeof(float) * outputGateBias.size())));
362 tensors.push_back(CreateTensor(flatBufferBuilder,
363 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
364 tensorInfoNumUnits.size()),
365 ::tflite::TensorType_FLOAT32,
366 buffers.size() - 1,
367 flatBufferBuilder.CreateString("outputGateBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000368 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100369
370 if (hasProjectionWeights)
371 {
372 buffers.push_back(
373 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000374 flatBufferBuilder.CreateVector(
375 reinterpret_cast<const uint8_t*>(projectionWeights.data()),
376 sizeof(T) * projectionWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100377 tensors.push_back(CreateTensor(flatBufferBuilder,
378 flatBufferBuilder.CreateVector<int32_t>(projectionWeightDimensions.data(),
379 projectionWeightDimensions.size()),
380 tensorType,
381 buffers.size() - 1,
382 flatBufferBuilder.CreateString("projectionWeights"),
383 weightQuantizationParameters));
Ryan OShea238ecd92023-03-07 11:44:23 +0000384 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100385 }
386 else
387 {
388 operatorInputs.push_back(kTfLiteOptionalTensor);
389 }
390
391 if (hasProjectionBias)
392 {
393 buffers.push_back(
394 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000395 flatBufferBuilder.CreateVector(
396 reinterpret_cast<const uint8_t*>(projectionBias.data()),
397 sizeof(float) * projectionBias.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100398 tensors.push_back(CreateTensor(flatBufferBuilder,
399 flatBufferBuilder.CreateVector<int32_t>(projectionBiasDimensions.data(),
400 projectionBiasDimensions.size()),
401 ::tflite::TensorType_FLOAT32,
402 buffers.size() - 1,
403 flatBufferBuilder.CreateString("projectionBias")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000404 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100405 }
406 else
407 {
408 operatorInputs.push_back(kTfLiteOptionalTensor);
409 }
410
Ryan OShea238ecd92023-03-07 11:44:23 +0000411 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100412 tensors.push_back(CreateTensor(flatBufferBuilder,
413 flatBufferBuilder.CreateVector<int32_t>(outputStateInDimensions.data(),
414 outputStateInDimensions.size()),
415 ::tflite::TensorType_FLOAT32,
416 buffers.size() - 1,
417 flatBufferBuilder.CreateString("outputStateInInfo"),
418 quantizationParameters,
419 true));
Ryan OShea238ecd92023-03-07 11:44:23 +0000420 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100421
Ryan OShea238ecd92023-03-07 11:44:23 +0000422 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100423 tensors.push_back(CreateTensor(flatBufferBuilder,
424 flatBufferBuilder.CreateVector<int32_t>(cellStateInDimensions.data(),
425 cellStateInDimensions.size()),
426 ::tflite::TensorType_FLOAT32,
427 buffers.size() - 1,
428 flatBufferBuilder.CreateString("cellStateInInfo"),
429 quantizationParameters,
430 true));
Ryan OShea238ecd92023-03-07 11:44:23 +0000431 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100432
433 if (hasInputLayerNormWeights)
434 {
435 buffers.push_back(
436 CreateBuffer(flatBufferBuilder,
437 flatBufferBuilder.CreateVector(
Ryan OShea238ecd92023-03-07 11:44:23 +0000438 reinterpret_cast<const uint8_t*>(inputLayerNormWeights.data()),
439 sizeof(float) * inputLayerNormWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100440 tensors.push_back(CreateTensor(flatBufferBuilder,
441 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
442 tensorInfoNumUnits.size()),
443 ::tflite::TensorType_FLOAT32,
444 buffers.size() - 1,
445 flatBufferBuilder.CreateString("inputLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000446 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100447 }
448 else
449 {
450 operatorInputs.push_back(kTfLiteOptionalTensor);
451 }
452
453 if (hasForgetLayerNormWeights)
454 {
455 buffers.push_back(
456 CreateBuffer(flatBufferBuilder,
457 flatBufferBuilder.CreateVector(
Ryan OShea238ecd92023-03-07 11:44:23 +0000458 reinterpret_cast<const uint8_t*>(forgetLayerNormWeights.data()),
459 sizeof(float) * forgetLayerNormWeights.size())));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100460 tensors.push_back(CreateTensor(flatBufferBuilder,
461 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
462 tensorInfoNumUnits.size()),
463 ::tflite::TensorType_FLOAT32,
464 buffers.size() - 1,
465 flatBufferBuilder.CreateString("forgetLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000466 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100467 }
468 else
469 {
470 operatorInputs.push_back(kTfLiteOptionalTensor);
471 }
472
473 if (hasCellLayerNormWeights)
474 {
475 buffers.push_back(
476 CreateBuffer(flatBufferBuilder,
Ryan OShea238ecd92023-03-07 11:44:23 +0000477 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(
478 cellLayerNormWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100479 sizeof(float) * cellLayerNormWeights.size())));
480 tensors.push_back(CreateTensor(flatBufferBuilder,
481 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
482 tensorInfoNumUnits.size()),
483 ::tflite::TensorType_FLOAT32,
484 buffers.size() - 1,
485 flatBufferBuilder.CreateString("cellLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000486 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100487 }
488 else
489 {
490 operatorInputs.push_back(kTfLiteOptionalTensor);
491 }
492
493 if (hasOutputLayerNormWeights)
494 {
495 buffers.push_back(
496 CreateBuffer(flatBufferBuilder,
497 flatBufferBuilder.CreateVector(
Ryan OShea238ecd92023-03-07 11:44:23 +0000498 reinterpret_cast<const uint8_t*>(outputLayerNormWeights.data()),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100499 sizeof(float) * outputLayerNormWeights.size())));
500 tensors.push_back(CreateTensor(flatBufferBuilder,
501 flatBufferBuilder.CreateVector<int32_t>(tensorInfoNumUnits.data(),
502 tensorInfoNumUnits.size()),
503 ::tflite::TensorType_FLOAT32,
504 buffers.size() - 1,
505 flatBufferBuilder.CreateString("outputLayerNormWeights")));
Ryan OShea238ecd92023-03-07 11:44:23 +0000506 operatorInputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100507 }
508 else
509 {
510 operatorInputs.push_back(kTfLiteOptionalTensor);
511 }
Ryan OShea238ecd92023-03-07 11:44:23 +0000512 buffers.push_back(CreateBuffer(flatBufferBuilder));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100513 tensors.push_back(CreateTensor(flatBufferBuilder,
514 flatBufferBuilder.CreateVector<int32_t>(outputShape.data(),
515 outputShape.size()),
516 ::tflite::TensorType_FLOAT32,
Ryan OShea238ecd92023-03-07 11:44:23 +0000517 buffers.size() - 1,
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100518 flatBufferBuilder.CreateString("output")));
519 std::vector<int> operatorOutputs;
Ryan OShea238ecd92023-03-07 11:44:23 +0000520 operatorOutputs.push_back(tensors.size() - 1);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100521
522 // create operator
Ryan OShea238ecd92023-03-07 11:44:23 +0000523 tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
524 flatbuffers::Offset<void> operatorBuiltinOptions =
525 CreateUnidirectionalSequenceLSTMOptions(flatBufferBuilder,
526 activationFunction,
527 clippingThresCell,
528 clippingThresProj,
529 isTimeMajor).Union();
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100530
531 flatbuffers::Offset<Operator> lstmOperator =
Ryan OShea238ecd92023-03-07 11:44:23 +0000532 CreateOperator(flatBufferBuilder,
533 0,
534 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
535 operatorInputs.size()),
536 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
537 operatorOutputs.size()),
538 operatorBuiltinOptionsType, operatorBuiltinOptions);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100539
Ryan OShea238ecd92023-03-07 11:44:23 +0000540 flatbuffers::Offset<SubGraph> subgraph =
541 CreateSubGraph(flatBufferBuilder,
542 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
543 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(),
544 operatorInputs.size()),
545 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(),
546 operatorOutputs.size()),
547 flatBufferBuilder.CreateVector(&lstmOperator, 1));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100548
Ryan OShea238ecd92023-03-07 11:44:23 +0000549 flatbuffers::Offset<flatbuffers::String> modelDescription =
550 flatBufferBuilder.CreateString(
551 "ArmnnDelegate: UnidirectionalSequenceLSTM Operator Model");
552 flatbuffers::Offset<OperatorCode> operatorCode =
553 CreateOperatorCode(flatBufferBuilder,
554 tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100555
Ryan OShea238ecd92023-03-07 11:44:23 +0000556 flatbuffers::Offset<Model> flatbufferModel =
557 CreateModel(flatBufferBuilder,
558 TFLITE_SCHEMA_VERSION,
559 flatBufferBuilder.CreateVector(&operatorCode, 1),
560 flatBufferBuilder.CreateVector(&subgraph, 1),
561 modelDescription,
562 flatBufferBuilder.CreateVector(buffers));
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100563
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100564 flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100565
566 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
567 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
568}
569
Ryan OShea238ecd92023-03-07 11:44:23 +0000570template<typename T>
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000571void UnidirectionalSequenceLstmTestImpl(tflite::TensorType tensorType,
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100572 int32_t batchSize,
573 int32_t timeSize,
574 int32_t inputSize,
575 int32_t outputSize,
576 int32_t numUnits,
577 bool hasInputToInputWeights,
578 const std::vector<T>& inputToInputWeights,
579 const std::vector<T>& inputToForgetWeights,
580 const std::vector<T>& inputToCellWeights,
581 const std::vector<T>& inputToOutputWeights,
582 bool hasRecurrentToInputWeights,
583 const std::vector<T>& recurrentToInputWeights,
584 const std::vector<T>& recurrentToForgetWeights,
585 const std::vector<T>& recurrentToCellWeights,
586 const std::vector<T>& recurrentToOutputWeights,
587 bool hasCellToInputWeights,
588 const std::vector<T>& cellToInputWeights,
589 bool hasCellToForgetWeights,
590 const std::vector<T>& cellToForgetWeights,
591 bool hasCellToOutputWeights,
592 const std::vector<T>& cellToOutputWeights,
593 bool hasInputGateBias,
594 const std::vector<float>& inputGateBias,
595 const std::vector<float>& forgetGateBias,
596 const std::vector<float>& cellBias,
597 const std::vector<float>& outputGateBias,
598 bool hasProjectionWeights,
599 const std::vector<T>& projectionWeights,
600 bool hasProjectionBias,
601 const std::vector<float>& projectionBias,
602 bool hasInputLayerNormWeights,
603 const std::vector<float>& inputLayerNormWeights,
604 bool hasForgetLayerNormWeights,
605 const std::vector<float>& forgetLayerNormWeights,
606 bool hasCellLayerNormWeights,
607 const std::vector<float>& cellLayerNormWeights,
608 bool hasOutputLayerNormWeights,
609 const std::vector<float>& outputLayerNormWeights,
610 std::vector<float>& inputValues,
611 std::vector<float>& expectedOutputValues,
612 tflite::ActivationFunctionType activationFunction,
613 float clippingThresCell,
614 float clippingThresProj,
615 bool isTimeMajor,
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000616 const std::vector<armnn::BackendId>& backends = {},
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100617 float quantScale = 0.1f)
618{
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100619 using namespace delegateTestInterpreter;
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100620
621 std::vector<char> modelBuffer = CreateUnidirectionalSequenceLstmTfLiteModel(tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000622 batchSize,
623 timeSize,
624 inputSize,
625 outputSize,
626 numUnits,
627 hasInputToInputWeights,
628 inputToInputWeights,
629 inputToForgetWeights,
630 inputToCellWeights,
631 inputToOutputWeights,
632 hasRecurrentToInputWeights,
633 recurrentToInputWeights,
634 recurrentToForgetWeights,
635 recurrentToCellWeights,
636 recurrentToOutputWeights,
637 hasCellToInputWeights,
638 cellToInputWeights,
639 hasCellToForgetWeights,
640 cellToForgetWeights,
641 hasCellToOutputWeights,
642 cellToOutputWeights,
643 hasInputGateBias,
644 inputGateBias,
645 forgetGateBias,
646 cellBias,
647 outputGateBias,
648 hasProjectionWeights,
649 projectionWeights,
650 hasProjectionBias,
651 projectionBias,
652 hasInputLayerNormWeights,
653 inputLayerNormWeights,
654 hasForgetLayerNormWeights,
655 forgetLayerNormWeights,
656 hasCellLayerNormWeights,
657 cellLayerNormWeights,
658 hasOutputLayerNormWeights,
659 outputLayerNormWeights,
660 activationFunction,
661 clippingThresCell,
662 clippingThresProj,
663 isTimeMajor,
664 quantScale);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100665
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100666 std::vector<int32_t> outputShape;
667 if (isTimeMajor)
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100668 {
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100669 outputShape = {timeSize, batchSize, outputSize};
670 }
671 else
672 {
673 outputShape = {batchSize, timeSize, outputSize};
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100674 }
675
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100676 // Setup interpreter with just TFLite Runtime.
677 auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
678 CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
679 CHECK(tfLiteInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
680 CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
681 std::vector<float> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<float>(0);
682 std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100683
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100684 // Setup interpreter with Arm NN Delegate applied.
Colm Donelan7bcae3c2024-01-22 10:07:14 +0000685 auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, CaptureAvailableBackends(backends));
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100686 CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
687 CHECK(armnnInterpreter.FillInputTensor<float>(inputValues, 0) == kTfLiteOk);
688 CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
689 std::vector<float> armnnOutputValues = armnnInterpreter.GetOutputResult<float>(0);
690 std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100691
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100692 armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputShape);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100693
694 if (tensorType == ::tflite::TensorType_INT8)
695 {
696 // Allow 2% tolerance for Quantized weights
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100697 armnnDelegate::CompareData(expectedOutputValues.data(), armnnOutputValues.data(),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100698 expectedOutputValues.size(), 2);
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100699 armnnDelegate::CompareData(expectedOutputValues.data(), tfLiteOutputValues.data(),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100700 expectedOutputValues.size(), 2);
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100701 armnnDelegate::CompareData(tfLiteOutputValues.data(), armnnOutputValues.data(),
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100702 expectedOutputValues.size(), 2);
703 }
704 else
705 {
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100706 armnnDelegate::CompareOutputData<float>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100707 }
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100708
709 tfLiteInterpreter.Cleanup();
710 armnnInterpreter.Cleanup();
Narumol Prangnawarat7684b182021-08-12 14:48:15 +0100711}
712
713} // anonymous namespace