blob: ef9f87a5d52ab724cce08e2507ff7384fa904d73 [file] [log] [blame]
Sadik Armagan8b9858d2020-11-09 08:26:22 +00001//
Ryan OShea238ecd92023-03-07 11:44:23 +00002// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan8b9858d2020-11-09 08:26:22 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Jan Eilersfe73b042020-11-18 10:36:46 +00008#include "TestUtils.hpp"
9
Sadik Armagan8b9858d2020-11-09 08:26:22 +000010#include <armnn_delegate.hpp>
11
12#include <flatbuffers/flatbuffers.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
15#include <tensorflow/lite/model.h>
Teresa Charlinad1b3d72023-03-14 12:10:28 +000016#include <schema_generated.h>
Sadik Armagan8b9858d2020-11-09 08:26:22 +000017#include <tensorflow/lite/version.h>
18
19#include <doctest/doctest.h>
20
21namespace
22{
23
24std::vector<char> CreateComparisonTfLiteModel(tflite::BuiltinOperator comparisonOperatorCode,
25 tflite::TensorType tensorType,
26 const std::vector <int32_t>& input0TensorShape,
27 const std::vector <int32_t>& input1TensorShape,
28 const std::vector <int32_t>& outputTensorShape,
29 float quantScale = 1.0f,
30 int quantOffset = 0)
31{
32 using namespace tflite;
33 flatbuffers::FlatBufferBuilder flatBufferBuilder;
34
35 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
Ryan OShea238ecd92023-03-07 11:44:23 +000036 buffers.push_back(CreateBuffer(flatBufferBuilder));
37 buffers.push_back(CreateBuffer(flatBufferBuilder));
38 buffers.push_back(CreateBuffer(flatBufferBuilder));
39 buffers.push_back(CreateBuffer(flatBufferBuilder));
Sadik Armagan8b9858d2020-11-09 08:26:22 +000040
41 auto quantizationParameters =
42 CreateQuantizationParameters(flatBufferBuilder,
43 0,
44 0,
45 flatBufferBuilder.CreateVector<float>({ quantScale }),
46 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
47
48 std::array<flatbuffers::Offset<Tensor>, 3> tensors;
49 tensors[0] = CreateTensor(flatBufferBuilder,
50 flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
51 input0TensorShape.size()),
52 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000053 1,
Sadik Armagan8b9858d2020-11-09 08:26:22 +000054 flatBufferBuilder.CreateString("input_0"),
55 quantizationParameters);
56 tensors[1] = CreateTensor(flatBufferBuilder,
57 flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
58 input1TensorShape.size()),
59 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000060 2,
Sadik Armagan8b9858d2020-11-09 08:26:22 +000061 flatBufferBuilder.CreateString("input_1"),
62 quantizationParameters);
63 tensors[2] = CreateTensor(flatBufferBuilder,
64 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
65 outputTensorShape.size()),
66 ::tflite::TensorType_BOOL,
Ryan OShea238ecd92023-03-07 11:44:23 +000067 3);
Sadik Armagan8b9858d2020-11-09 08:26:22 +000068
69 // create operator
70 tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_EqualOptions;;
71 flatbuffers::Offset<void> operatorBuiltinOptions = CreateEqualOptions(flatBufferBuilder).Union();
72 switch (comparisonOperatorCode)
73 {
74 case BuiltinOperator_EQUAL:
75 {
76 operatorBuiltinOptionsType = BuiltinOptions_EqualOptions;
77 operatorBuiltinOptions = CreateEqualOptions(flatBufferBuilder).Union();
78 break;
79 }
80 case BuiltinOperator_NOT_EQUAL:
81 {
82 operatorBuiltinOptionsType = BuiltinOptions_NotEqualOptions;
83 operatorBuiltinOptions = CreateNotEqualOptions(flatBufferBuilder).Union();
84 break;
85 }
86 case BuiltinOperator_GREATER:
87 {
88 operatorBuiltinOptionsType = BuiltinOptions_GreaterOptions;
89 operatorBuiltinOptions = CreateGreaterOptions(flatBufferBuilder).Union();
90 break;
91 }
92 case BuiltinOperator_GREATER_EQUAL:
93 {
94 operatorBuiltinOptionsType = BuiltinOptions_GreaterEqualOptions;
95 operatorBuiltinOptions = CreateGreaterEqualOptions(flatBufferBuilder).Union();
96 break;
97 }
98 case BuiltinOperator_LESS:
99 {
100 operatorBuiltinOptionsType = BuiltinOptions_LessOptions;
101 operatorBuiltinOptions = CreateLessOptions(flatBufferBuilder).Union();
102 break;
103 }
104 case BuiltinOperator_LESS_EQUAL:
105 {
106 operatorBuiltinOptionsType = BuiltinOptions_LessEqualOptions;
107 operatorBuiltinOptions = CreateLessEqualOptions(flatBufferBuilder).Union();
108 break;
109 }
110 default:
111 break;
112 }
Keith Davis892fafe2020-11-26 17:40:35 +0000113 const std::vector<int32_t> operatorInputs{0, 1};
114 const std::vector<int32_t> operatorOutputs{2};
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000115 flatbuffers::Offset <Operator> comparisonOperator =
116 CreateOperator(flatBufferBuilder,
117 0,
118 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
119 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
120 operatorBuiltinOptionsType,
121 operatorBuiltinOptions);
122
Keith Davis892fafe2020-11-26 17:40:35 +0000123 const std::vector<int> subgraphInputs{0, 1};
124 const std::vector<int> subgraphOutputs{2};
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000125 flatbuffers::Offset <SubGraph> subgraph =
126 CreateSubGraph(flatBufferBuilder,
127 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
128 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
129 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
130 flatBufferBuilder.CreateVector(&comparisonOperator, 1));
131
132 flatbuffers::Offset <flatbuffers::String> modelDescription =
133 flatBufferBuilder.CreateString("ArmnnDelegate: Comparison Operator Model");
134 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, comparisonOperatorCode);
135
136 flatbuffers::Offset <Model> flatbufferModel =
137 CreateModel(flatBufferBuilder,
138 TFLITE_SCHEMA_VERSION,
139 flatBufferBuilder.CreateVector(&operatorCode, 1),
140 flatBufferBuilder.CreateVector(&subgraph, 1),
141 modelDescription,
142 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
143
144 flatBufferBuilder.Finish(flatbufferModel);
145
146 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
147 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
148}
149
150template <typename T>
151void ComparisonTest(tflite::BuiltinOperator comparisonOperatorCode,
152 tflite::TensorType tensorType,
153 std::vector<armnn::BackendId>& backends,
154 std::vector<int32_t>& input0Shape,
155 std::vector<int32_t>& input1Shape,
156 std::vector<int32_t>& outputShape,
157 std::vector<T>& input0Values,
158 std::vector<T>& input1Values,
159 std::vector<bool>& expectedOutputValues,
160 float quantScale = 1.0f,
161 int quantOffset = 0)
162{
163 using namespace tflite;
164 std::vector<char> modelBuffer = CreateComparisonTfLiteModel(comparisonOperatorCode,
165 tensorType,
166 input0Shape,
167 input1Shape,
168 outputShape,
169 quantScale,
170 quantOffset);
171
172 const Model* tfLiteModel = GetModel(modelBuffer.data());
173 // Create TfLite Interpreters
174 std::unique_ptr<Interpreter> armnnDelegateInterpreter;
175 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
176 (&armnnDelegateInterpreter) == kTfLiteOk);
177 CHECK(armnnDelegateInterpreter != nullptr);
178 CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
179
180 std::unique_ptr<Interpreter> tfLiteInterpreter;
181 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
182 (&tfLiteInterpreter) == kTfLiteOk);
183 CHECK(tfLiteInterpreter != nullptr);
184 CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
185
186 // Create the ArmNN Delegate
187 armnnDelegate::DelegateOptions delegateOptions(backends);
188 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
189 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
190 armnnDelegate::TfLiteArmnnDelegateDelete);
191 CHECK(theArmnnDelegate != nullptr);
192 // Modify armnnDelegateInterpreter to use armnnDelegate
193 CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
194
195 // Set input data
196 auto tfLiteDelegateInput0Id = tfLiteInterpreter->inputs()[0];
197 auto tfLiteDelageInput0Data = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateInput0Id);
198 for (unsigned int i = 0; i < input0Values.size(); ++i)
199 {
200 tfLiteDelageInput0Data[i] = input0Values[i];
201 }
202
203 auto tfLiteDelegateInput1Id = tfLiteInterpreter->inputs()[1];
204 auto tfLiteDelageInput1Data = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateInput1Id);
205 for (unsigned int i = 0; i < input1Values.size(); ++i)
206 {
207 tfLiteDelageInput1Data[i] = input1Values[i];
208 }
209
210 auto armnnDelegateInput0Id = armnnDelegateInterpreter->inputs()[0];
211 auto armnnDelegateInput0Data = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateInput0Id);
212 for (unsigned int i = 0; i < input0Values.size(); ++i)
213 {
214 armnnDelegateInput0Data[i] = input0Values[i];
215 }
216
217 auto armnnDelegateInput1Id = armnnDelegateInterpreter->inputs()[1];
218 auto armnnDelegateInput1Data = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateInput1Id);
219 for (unsigned int i = 0; i < input1Values.size(); ++i)
220 {
221 armnnDelegateInput1Data[i] = input1Values[i];
222 }
223
224 // Run EnqueWorkload
225 CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
226 CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
227 // Compare output data
228 auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
229 auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<bool>(tfLiteDelegateOutputId);
230 auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
231 auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<bool>(armnnDelegateOutputId);
232
Jan Eilersfe73b042020-11-18 10:36:46 +0000233 armnnDelegate::CompareData(expectedOutputValues , armnnDelegateOutputData, expectedOutputValues.size());
234 armnnDelegate::CompareData(expectedOutputValues , tfLiteDelageOutputData , expectedOutputValues.size());
235 armnnDelegate::CompareData(tfLiteDelageOutputData, armnnDelegateOutputData, expectedOutputValues.size());
Sadik Armagan8b9858d2020-11-09 08:26:22 +0000236}
237
238} // anonymous namespace