blob: a22800c6c8982627a8a10eae799298a5f0c3851f [file] [log] [blame]
//
// Copyright © 2021, 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
Matthew Sloyanebe392d2023-03-30 10:12:08 +010011#include <DelegateTestInterpreter.hpp>
Sadik Armagan937565b2021-04-21 14:03:28 +010012
Sadik Armagan937565b2021-04-21 14:03:28 +010013#include <tensorflow/lite/version.h>
14
Sadik Armagan937565b2021-04-21 14:03:28 +010015namespace
16{
17std::vector<char> CreateCastTfLiteModel(tflite::TensorType inputTensorType,
18 tflite::TensorType outputTensorType,
19 const std::vector <int32_t>& tensorShape,
20 float quantScale = 1.0f,
21 int quantOffset = 0)
22{
23 using namespace tflite;
24 flatbuffers::FlatBufferBuilder flatBufferBuilder;
25
26 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
Ryan OShea238ecd92023-03-07 11:44:23 +000027 buffers.push_back(CreateBuffer(flatBufferBuilder));
28 buffers.push_back(CreateBuffer(flatBufferBuilder));
29 buffers.push_back(CreateBuffer(flatBufferBuilder));
Sadik Armagan937565b2021-04-21 14:03:28 +010030
31 auto quantizationParameters =
32 CreateQuantizationParameters(flatBufferBuilder,
33 0,
34 0,
35 flatBufferBuilder.CreateVector<float>({quantScale}),
36 flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
37
38 std::array<flatbuffers::Offset<Tensor>, 2> tensors;
39 tensors[0] = CreateTensor(flatBufferBuilder,
40 flatBufferBuilder.CreateVector<int32_t>(tensorShape.data(),
41 tensorShape.size()),
42 inputTensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000043 1,
Sadik Armagan937565b2021-04-21 14:03:28 +010044 flatBufferBuilder.CreateString("input"),
45 quantizationParameters);
46 tensors[1] = CreateTensor(flatBufferBuilder,
47 flatBufferBuilder.CreateVector<int32_t>(tensorShape.data(),
48 tensorShape.size()),
49 outputTensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000050 2,
Sadik Armagan937565b2021-04-21 14:03:28 +010051 flatBufferBuilder.CreateString("output"),
52 quantizationParameters);
53
54 const std::vector<int32_t> operatorInputs({0});
55 const std::vector<int32_t> operatorOutputs({1});
56
57 flatbuffers::Offset<Operator> castOperator =
58 CreateOperator(flatBufferBuilder,
59 0,
60 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
61 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
62 BuiltinOptions_CastOptions,
63 CreateCastOptions(flatBufferBuilder).Union());
64
65 flatbuffers::Offset<flatbuffers::String> modelDescription =
66 flatBufferBuilder.CreateString("ArmnnDelegate: CAST Operator Model");
67 flatbuffers::Offset<OperatorCode> operatorCode =
68 CreateOperatorCode(flatBufferBuilder, tflite::BuiltinOperator_CAST);
69
70 const std::vector<int32_t> subgraphInputs({0});
71 const std::vector<int32_t> subgraphOutputs({1});
72 flatbuffers::Offset<SubGraph> subgraph =
73 CreateSubGraph(flatBufferBuilder,
74 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
75 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
76 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
77 flatBufferBuilder.CreateVector(&castOperator, 1));
78
79 flatbuffers::Offset<Model> flatbufferModel =
80 CreateModel(flatBufferBuilder,
81 TFLITE_SCHEMA_VERSION,
82 flatBufferBuilder.CreateVector(&operatorCode, 1),
83 flatBufferBuilder.CreateVector(&subgraph, 1),
84 modelDescription,
85 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
86
Matthew Sloyanebe392d2023-03-30 10:12:08 +010087 flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
Sadik Armagan937565b2021-04-21 14:03:28 +010088 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
89 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
90}
91
92template<typename T, typename K>
93void CastTest(tflite::TensorType inputTensorType,
94 tflite::TensorType outputTensorType,
Sadik Armagan937565b2021-04-21 14:03:28 +010095 std::vector<int32_t>& shape,
96 std::vector<T>& inputValues,
97 std::vector<K>& expectedOutputValues,
98 float quantScale = 1.0f,
Colm Donelaneff204a2023-11-28 15:46:09 +000099 int quantOffset = 0,
100 const std::vector<armnn::BackendId>& backends = {})
Sadik Armagan937565b2021-04-21 14:03:28 +0100101{
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100102 using namespace delegateTestInterpreter;
Sadik Armagan937565b2021-04-21 14:03:28 +0100103 std::vector<char> modelBuffer = CreateCastTfLiteModel(inputTensorType,
104 outputTensorType,
105 shape,
106 quantScale,
107 quantOffset);
108
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100109 // Setup interpreter with just TFLite Runtime.
110 auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
111 CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
112 CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
113 CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
114 std::vector<K> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<K>(0);
115 std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
Sadik Armagan937565b2021-04-21 14:03:28 +0100116
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100117 // Setup interpreter with Arm NN Delegate applied.
Colm Donelaneff204a2023-11-28 15:46:09 +0000118 auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, CaptureAvailableBackends(backends));
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100119 CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
120 CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
121 CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
122 std::vector<K> armnnOutputValues = armnnInterpreter.GetOutputResult<K>(0);
123 std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
Sadik Armagan937565b2021-04-21 14:03:28 +0100124
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100125 armnnDelegate::CompareOutputData<K>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
126 armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, shape);
Sadik Armagan937565b2021-04-21 14:03:28 +0100127
Matthew Sloyanebe392d2023-03-30 10:12:08 +0100128 tfLiteInterpreter.Cleanup();
129 armnnInterpreter.Cleanup();
Sadik Armagan937565b2021-04-21 14:03:28 +0100130}
131
132} // anonymous namespace