//
// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <tensorflow/lite/version.h>

#include <doctest/doctest.h>

namespace
{

24template <typename T>
25std::vector<char> CreateFullyConnectedTfLiteModel(tflite::TensorType tensorType,
26 tflite::ActivationFunctionType activationType,
27 const std::vector <int32_t>& inputTensorShape,
28 const std::vector <int32_t>& weightsTensorShape,
29 const std::vector <int32_t>& biasTensorShape,
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000030 std::vector <int32_t>& outputTensorShape,
31 std::vector <T>& weightsData,
32 bool constantWeights = true,
Sadik Armagan6e36a642020-11-10 21:18:41 +000033 float quantScale = 1.0f,
34 int quantOffset = 0,
35 float outputQuantScale = 2.0f,
36 int outputQuantOffset = 0)
37{
38 using namespace tflite;
39 flatbuffers::FlatBufferBuilder flatBufferBuilder;
Ryan OShea238ecd92023-03-07 11:44:23 +000040 std::array<flatbuffers::Offset<tflite::Buffer>, 5> buffers;
41 buffers[0] = CreateBuffer(flatBufferBuilder);
42 buffers[1] = CreateBuffer(flatBufferBuilder);
Sadik Armagan6e36a642020-11-10 21:18:41 +000043
44 auto biasTensorType = ::tflite::TensorType_FLOAT32;
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +000045 if (tensorType == ::tflite::TensorType_INT8)
Sadik Armagan6e36a642020-11-10 21:18:41 +000046 {
47 biasTensorType = ::tflite::TensorType_INT32;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000048 }
49 if (constantWeights)
50 {
Ryan OShea238ecd92023-03-07 11:44:23 +000051 buffers[2] = CreateBuffer(flatBufferBuilder,
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000052 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(weightsData.data()),
53 sizeof(T) * weightsData.size()));
Sadik Armagan6e36a642020-11-10 21:18:41 +000054
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000055 if (tensorType == ::tflite::TensorType_INT8)
56 {
57 std::vector<int32_t> biasData = { 10 };
Ryan OShea238ecd92023-03-07 11:44:23 +000058 buffers[3] = CreateBuffer(flatBufferBuilder,
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000059 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(biasData.data()),
60 sizeof(int32_t) * biasData.size()));
61
62 }
63 else
64 {
65 std::vector<float> biasData = { 10 };
Ryan OShea238ecd92023-03-07 11:44:23 +000066 buffers[3] = CreateBuffer(flatBufferBuilder,
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000067 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(biasData.data()),
68 sizeof(float) * biasData.size()));
69 }
Sadik Armagan6e36a642020-11-10 21:18:41 +000070 }
71 else
72 {
Ryan OShea238ecd92023-03-07 11:44:23 +000073 buffers[2] = CreateBuffer(flatBufferBuilder);
74 buffers[3] = CreateBuffer(flatBufferBuilder);
Sadik Armagan6e36a642020-11-10 21:18:41 +000075 }
Ryan OShea238ecd92023-03-07 11:44:23 +000076 buffers[4] = CreateBuffer(flatBufferBuilder);
Sadik Armagan6e36a642020-11-10 21:18:41 +000077
78 auto quantizationParameters =
79 CreateQuantizationParameters(flatBufferBuilder,
80 0,
81 0,
82 flatBufferBuilder.CreateVector<float>({ quantScale }),
83 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
84
85 auto outputQuantizationParameters =
86 CreateQuantizationParameters(flatBufferBuilder,
87 0,
88 0,
89 flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
90 flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));
91
92 std::array<flatbuffers::Offset<Tensor>, 4> tensors;
93 tensors[0] = CreateTensor(flatBufferBuilder,
94 flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
95 inputTensorShape.size()),
96 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000097 1,
Sadik Armagan6e36a642020-11-10 21:18:41 +000098 flatBufferBuilder.CreateString("input_0"),
99 quantizationParameters);
100 tensors[1] = CreateTensor(flatBufferBuilder,
101 flatBufferBuilder.CreateVector<int32_t>(weightsTensorShape.data(),
102 weightsTensorShape.size()),
103 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000104 2,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000105 flatBufferBuilder.CreateString("weights"),
106 quantizationParameters);
107 tensors[2] = CreateTensor(flatBufferBuilder,
108 flatBufferBuilder.CreateVector<int32_t>(biasTensorShape.data(),
109 biasTensorShape.size()),
110 biasTensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000111 3,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000112 flatBufferBuilder.CreateString("bias"),
113 quantizationParameters);
114
115 tensors[3] = CreateTensor(flatBufferBuilder,
116 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
117 outputTensorShape.size()),
118 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000119 4,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000120 flatBufferBuilder.CreateString("output"),
121 outputQuantizationParameters);
122
123
124 // create operator
125 tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_FullyConnectedOptions;
126 flatbuffers::Offset<void> operatorBuiltinOptions =
127 CreateFullyConnectedOptions(flatBufferBuilder,
128 activationType,
129 FullyConnectedOptionsWeightsFormat_DEFAULT, false).Union();
130
Keith Davis892fafe2020-11-26 17:40:35 +0000131 const std::vector<int> operatorInputs{0, 1, 2};
132 const std::vector<int> operatorOutputs{3};
Sadik Armagan6e36a642020-11-10 21:18:41 +0000133 flatbuffers::Offset <Operator> fullyConnectedOperator =
134 CreateOperator(flatBufferBuilder,
135 0,
136 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
137 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
138 operatorBuiltinOptionsType, operatorBuiltinOptions);
139
Keith Davis892fafe2020-11-26 17:40:35 +0000140 const std::vector<int> subgraphInputs{0, 1, 2};
141 const std::vector<int> subgraphOutputs{3};
Sadik Armagan6e36a642020-11-10 21:18:41 +0000142 flatbuffers::Offset <SubGraph> subgraph =
143 CreateSubGraph(flatBufferBuilder,
144 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
145 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
146 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
147 flatBufferBuilder.CreateVector(&fullyConnectedOperator, 1));
148
149 flatbuffers::Offset <flatbuffers::String> modelDescription =
150 flatBufferBuilder.CreateString("ArmnnDelegate: FullyConnected Operator Model");
151 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
152 tflite::BuiltinOperator_FULLY_CONNECTED);
153
154 flatbuffers::Offset <Model> flatbufferModel =
155 CreateModel(flatBufferBuilder,
156 TFLITE_SCHEMA_VERSION,
157 flatBufferBuilder.CreateVector(&operatorCode, 1),
158 flatBufferBuilder.CreateVector(&subgraph, 1),
159 modelDescription,
160 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
161
162 flatBufferBuilder.Finish(flatbufferModel);
163
164 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
165 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
166}
167
168template <typename T>
169void FullyConnectedTest(std::vector<armnn::BackendId>& backends,
170 tflite::TensorType tensorType,
171 tflite::ActivationFunctionType activationType,
172 const std::vector <int32_t>& inputTensorShape,
173 const std::vector <int32_t>& weightsTensorShape,
174 const std::vector <int32_t>& biasTensorShape,
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000175 std::vector <int32_t>& outputTensorShape,
176 std::vector <T>& inputValues,
177 std::vector <T>& expectedOutputValues,
178 std::vector <T>& weightsData,
179 bool constantWeights = true,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000180 float quantScale = 1.0f,
181 int quantOffset = 0)
182{
183 using namespace tflite;
184
185 std::vector<char> modelBuffer = CreateFullyConnectedTfLiteModel(tensorType,
186 activationType,
187 inputTensorShape,
188 weightsTensorShape,
189 biasTensorShape,
190 outputTensorShape,
191 weightsData,
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000192 constantWeights,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000193 quantScale,
194 quantOffset);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000195 const Model* tfLiteModel = GetModel(modelBuffer.data());
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000196
Sadik Armagan6e36a642020-11-10 21:18:41 +0000197 // Create TfLite Interpreters
198 std::unique_ptr<Interpreter> armnnDelegateInterpreter;
199 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
200 (&armnnDelegateInterpreter) == kTfLiteOk);
201 CHECK(armnnDelegateInterpreter != nullptr);
202 CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
203
204 std::unique_ptr<Interpreter> tfLiteInterpreter;
205 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
206 (&tfLiteInterpreter) == kTfLiteOk);
207 CHECK(tfLiteInterpreter != nullptr);
208 CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
209
210 // Create the ArmNN Delegate
211 armnnDelegate::DelegateOptions delegateOptions(backends);
212 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000213 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
214 armnnDelegate::TfLiteArmnnDelegateDelete);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000215 CHECK(theArmnnDelegate != nullptr);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000216
Sadik Armagan6e36a642020-11-10 21:18:41 +0000217 // Modify armnnDelegateInterpreter to use armnnDelegate
218 CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
219
220 // Set input data
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000221 armnnDelegate::FillInput<T>(tfLiteInterpreter, 0, inputValues);
222 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 0, inputValues);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000223
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000224 if (!constantWeights)
Sadik Armagan6e36a642020-11-10 21:18:41 +0000225 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000226 armnnDelegate::FillInput<T>(tfLiteInterpreter, 1, weightsData);
227 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 1, weightsData);
228
229 if (tensorType == ::tflite::TensorType_INT8)
230 {
231 std::vector <int32_t> biasData = {10};
232 armnnDelegate::FillInput<int32_t>(tfLiteInterpreter, 2, biasData);
233 armnnDelegate::FillInput<int32_t>(armnnDelegateInterpreter, 2, biasData);
234 }
235 else
236 {
237 std::vector<float> biasData = {10};
238 armnnDelegate::FillInput<float>(tfLiteInterpreter, 2, biasData);
239 armnnDelegate::FillInput<float>(armnnDelegateInterpreter, 2, biasData);
240 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000241 }
242
243 // Run EnqueWorkload
244 CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
245 CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
246
247 // Compare output data
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000248 armnnDelegate::CompareOutputData<T>(tfLiteInterpreter,
249 armnnDelegateInterpreter,
250 outputTensorShape,
251 expectedOutputValues);
252 armnnDelegateInterpreter.reset(nullptr);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000253}

} // anonymous namespace