blob: 47ee7c2410f17540bd97140fcc2ba3735502a065 [file] [log] [blame]
Sadik Armagan67e95f22020-10-29 16:14:54 +00001//
Ryan OShea238ecd92023-03-07 11:44:23 +00002// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan67e95f22020-10-29 16:14:54 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Sadik Armaganf7ac72c2021-05-05 15:03:50 +01008#include "TestUtils.hpp"
9
Sadik Armagan67e95f22020-10-29 16:14:54 +000010#include <armnn_delegate.hpp>
11
12#include <flatbuffers/flatbuffers.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
15#include <tensorflow/lite/model.h>
Teresa Charlinad1b3d72023-03-14 12:10:28 +000016#include <schema_generated.h>
Sadik Armagan67e95f22020-10-29 16:14:54 +000017#include <tensorflow/lite/version.h>
18
19#include <doctest/doctest.h>
20
21namespace
22{
23
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010024template <typename T>
Sadik Armagan67e95f22020-10-29 16:14:54 +000025std::vector<char> CreateElementwiseBinaryTfLiteModel(tflite::BuiltinOperator binaryOperatorCode,
26 tflite::ActivationFunctionType activationType,
27 tflite::TensorType tensorType,
28 const std::vector <int32_t>& input0TensorShape,
29 const std::vector <int32_t>& input1TensorShape,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000030 const std::vector <int32_t>& outputTensorShape,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010031 std::vector<T>& input1Values,
32 bool constantInput = false,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000033 float quantScale = 1.0f,
34 int quantOffset = 0)
Sadik Armagan67e95f22020-10-29 16:14:54 +000035{
36 using namespace tflite;
37 flatbuffers::FlatBufferBuilder flatBufferBuilder;
38
39 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
Ryan OShea238ecd92023-03-07 11:44:23 +000040 buffers.push_back(CreateBuffer(flatBufferBuilder));
41 buffers.push_back(CreateBuffer(flatBufferBuilder));
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010042 if (constantInput)
43 {
44 buffers.push_back(
45 CreateBuffer(flatBufferBuilder,
46 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(input1Values.data()),
47 sizeof(T) * input1Values.size())));
48 }
49 else
50 {
Ryan OShea238ecd92023-03-07 11:44:23 +000051 buffers.push_back(CreateBuffer(flatBufferBuilder));
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010052 }
Ryan OShea238ecd92023-03-07 11:44:23 +000053 buffers.push_back(CreateBuffer(flatBufferBuilder));
Sadik Armagan67e95f22020-10-29 16:14:54 +000054
Sadik Armagan21a94ff2020-11-09 08:38:30 +000055 auto quantizationParameters =
56 CreateQuantizationParameters(flatBufferBuilder,
57 0,
58 0,
59 flatBufferBuilder.CreateVector<float>({ quantScale }),
60 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
61
62
Sadik Armagan67e95f22020-10-29 16:14:54 +000063 std::array<flatbuffers::Offset<Tensor>, 3> tensors;
64 tensors[0] = CreateTensor(flatBufferBuilder,
65 flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
66 input0TensorShape.size()),
Sadik Armagan21a94ff2020-11-09 08:38:30 +000067 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000068 1,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000069 flatBufferBuilder.CreateString("input_0"),
70 quantizationParameters);
Sadik Armagan67e95f22020-10-29 16:14:54 +000071 tensors[1] = CreateTensor(flatBufferBuilder,
72 flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
73 input1TensorShape.size()),
Sadik Armagan21a94ff2020-11-09 08:38:30 +000074 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000075 2,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000076 flatBufferBuilder.CreateString("input_1"),
77 quantizationParameters);
Sadik Armagan67e95f22020-10-29 16:14:54 +000078 tensors[2] = CreateTensor(flatBufferBuilder,
79 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
80 outputTensorShape.size()),
Sadik Armagan21a94ff2020-11-09 08:38:30 +000081 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000082 3,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000083 flatBufferBuilder.CreateString("output"),
84 quantizationParameters);
Sadik Armagan67e95f22020-10-29 16:14:54 +000085
86 // create operator
87 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_NONE;
88 flatbuffers::Offset<void> operatorBuiltinOptions = 0;
89 switch (binaryOperatorCode)
90 {
91 case BuiltinOperator_ADD:
92 {
93 operatorBuiltinOptionsType = BuiltinOptions_AddOptions;
94 operatorBuiltinOptions = CreateAddOptions(flatBufferBuilder, activationType).Union();
95 break;
96 }
97 case BuiltinOperator_DIV:
98 {
99 operatorBuiltinOptionsType = BuiltinOptions_DivOptions;
100 operatorBuiltinOptions = CreateDivOptions(flatBufferBuilder, activationType).Union();
101 break;
102 }
Sadik Armagan21a94ff2020-11-09 08:38:30 +0000103 case BuiltinOperator_MAXIMUM:
104 {
105 operatorBuiltinOptionsType = BuiltinOptions_MaximumMinimumOptions;
106 operatorBuiltinOptions = CreateMaximumMinimumOptions(flatBufferBuilder).Union();
107 break;
108 }
109 case BuiltinOperator_MINIMUM:
110 {
111 operatorBuiltinOptionsType = BuiltinOptions_MaximumMinimumOptions;
112 operatorBuiltinOptions = CreateMaximumMinimumOptions(flatBufferBuilder).Union();
113 break;
114 }
Sadik Armagan67e95f22020-10-29 16:14:54 +0000115 case BuiltinOperator_MUL:
116 {
117 operatorBuiltinOptionsType = BuiltinOptions_MulOptions;
118 operatorBuiltinOptions = CreateMulOptions(flatBufferBuilder, activationType).Union();
119 break;
120 }
121 case BuiltinOperator_SUB:
122 {
123 operatorBuiltinOptionsType = BuiltinOptions_SubOptions;
124 operatorBuiltinOptions = CreateSubOptions(flatBufferBuilder, activationType).Union();
125 break;
126 }
Jim Flynn4b2f3472021-10-13 21:20:07 +0100127 case BuiltinOperator_FLOOR_DIV:
128 {
129 operatorBuiltinOptionsType = tflite::BuiltinOptions_FloorDivOptions;
130 operatorBuiltinOptions = CreateSubOptions(flatBufferBuilder, activationType).Union();
131 break;
132 }
Sadik Armagan67e95f22020-10-29 16:14:54 +0000133 default:
134 break;
135 }
Keith Davis892fafe2020-11-26 17:40:35 +0000136 const std::vector<int32_t> operatorInputs{0, 1};
137 const std::vector<int32_t> operatorOutputs{2};
Sadik Armagan67e95f22020-10-29 16:14:54 +0000138 flatbuffers::Offset <Operator> elementwiseBinaryOperator =
139 CreateOperator(flatBufferBuilder,
140 0,
141 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
142 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
143 operatorBuiltinOptionsType,
144 operatorBuiltinOptions);
145
Keith Davis892fafe2020-11-26 17:40:35 +0000146 const std::vector<int> subgraphInputs{0, 1};
147 const std::vector<int> subgraphOutputs{2};
Sadik Armagan67e95f22020-10-29 16:14:54 +0000148 flatbuffers::Offset <SubGraph> subgraph =
149 CreateSubGraph(flatBufferBuilder,
150 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
151 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
152 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
153 flatBufferBuilder.CreateVector(&elementwiseBinaryOperator, 1));
154
155 flatbuffers::Offset <flatbuffers::String> modelDescription =
156 flatBufferBuilder.CreateString("ArmnnDelegate: Elementwise Binary Operator Model");
157 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, binaryOperatorCode);
158
159 flatbuffers::Offset <Model> flatbufferModel =
160 CreateModel(flatBufferBuilder,
161 TFLITE_SCHEMA_VERSION,
162 flatBufferBuilder.CreateVector(&operatorCode, 1),
163 flatBufferBuilder.CreateVector(&subgraph, 1),
164 modelDescription,
165 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
166
167 flatBufferBuilder.Finish(flatbufferModel);
168
169 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
170 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
171}
172
Sadik Armagan21a94ff2020-11-09 08:38:30 +0000173template <typename T>
174void ElementwiseBinaryTest(tflite::BuiltinOperator binaryOperatorCode,
175 tflite::ActivationFunctionType activationType,
176 tflite::TensorType tensorType,
177 std::vector<armnn::BackendId>& backends,
178 std::vector<int32_t>& input0Shape,
179 std::vector<int32_t>& input1Shape,
180 std::vector<int32_t>& outputShape,
181 std::vector<T>& input0Values,
182 std::vector<T>& input1Values,
183 std::vector<T>& expectedOutputValues,
184 float quantScale = 1.0f,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100185 int quantOffset = 0,
186 bool constantInput = false)
Sadik Armagan67e95f22020-10-29 16:14:54 +0000187{
188 using namespace tflite;
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100189 std::vector<char> modelBuffer = CreateElementwiseBinaryTfLiteModel<T>(binaryOperatorCode,
190 activationType,
191 tensorType,
192 input0Shape,
193 input1Shape,
194 outputShape,
195 input1Values,
196 constantInput,
197 quantScale,
198 quantOffset);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000199
200 const Model* tfLiteModel = GetModel(modelBuffer.data());
201 // Create TfLite Interpreters
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100202 std::unique_ptr <Interpreter> armnnDelegateInterpreter;
Sadik Armagan67e95f22020-10-29 16:14:54 +0000203 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
204 (&armnnDelegateInterpreter) == kTfLiteOk);
205 CHECK(armnnDelegateInterpreter != nullptr);
206 CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
207
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100208 std::unique_ptr <Interpreter> tfLiteInterpreter;
Sadik Armagan67e95f22020-10-29 16:14:54 +0000209 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
210 (&tfLiteInterpreter) == kTfLiteOk);
211 CHECK(tfLiteInterpreter != nullptr);
212 CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
213
214 // Create the ArmNN Delegate
215 armnnDelegate::DelegateOptions delegateOptions(backends);
216 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100217 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
218 armnnDelegate::TfLiteArmnnDelegateDelete);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000219 CHECK(theArmnnDelegate != nullptr);
220 // Modify armnnDelegateInterpreter to use armnnDelegate
221 CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
222
223 // Set input data
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100224 armnnDelegate::FillInput<T>(tfLiteInterpreter, 0, input0Values);
225 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 0, input0Values);
226 if (!constantInput)
Sadik Armagan67e95f22020-10-29 16:14:54 +0000227 {
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100228 armnnDelegate::FillInput<T>(tfLiteInterpreter, 1, input1Values);
229 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 1, input1Values);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000230 }
Sadik Armagan67e95f22020-10-29 16:14:54 +0000231 // Run EnqueWorkload
232 CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
233 CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
234
235 // Compare output data
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100236 armnnDelegate::CompareOutputData<T>(tfLiteInterpreter,
237 armnnDelegateInterpreter,
238 outputShape,
239 expectedOutputValues);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000240 armnnDelegateInterpreter.reset(nullptr);
241}
242
Sadik Armagan21a94ff2020-11-09 08:38:30 +0000243} // anonymous namespace