blob: 37062c34003ca78f6f1ab37bb40dcd92b55aa23c [file] [log] [blame]
Sadik Armagan6e36a642020-11-10 21:18:41 +00001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Sadik Armaganf0a6dec2021-03-25 07:46:55 +00008#include "TestUtils.hpp"
9
Sadik Armagan6e36a642020-11-10 21:18:41 +000010#include <armnn_delegate.hpp>
11
12#include <flatbuffers/flatbuffers.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
15#include <tensorflow/lite/model.h>
16#include <tensorflow/lite/schema/schema_generated.h>
17#include <tensorflow/lite/version.h>
18
19#include <doctest/doctest.h>
20
21namespace
22{
23
24template <typename T>
25std::vector<char> CreateFullyConnectedTfLiteModel(tflite::TensorType tensorType,
26 tflite::ActivationFunctionType activationType,
27 const std::vector <int32_t>& inputTensorShape,
28 const std::vector <int32_t>& weightsTensorShape,
29 const std::vector <int32_t>& biasTensorShape,
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000030 std::vector <int32_t>& outputTensorShape,
31 std::vector <T>& weightsData,
32 bool constantWeights = true,
Sadik Armagan6e36a642020-11-10 21:18:41 +000033 float quantScale = 1.0f,
34 int quantOffset = 0,
35 float outputQuantScale = 2.0f,
36 int outputQuantOffset = 0)
37{
38 using namespace tflite;
39 flatbuffers::FlatBufferBuilder flatBufferBuilder;
40 std::array<flatbuffers::Offset<tflite::Buffer>, 3> buffers;
41 buffers[0] = CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({}));
Sadik Armagan6e36a642020-11-10 21:18:41 +000042
43 auto biasTensorType = ::tflite::TensorType_FLOAT32;
Narumol Prangnawarat55518ca2020-11-20 14:50:54 +000044 if (tensorType == ::tflite::TensorType_INT8)
Sadik Armagan6e36a642020-11-10 21:18:41 +000045 {
46 biasTensorType = ::tflite::TensorType_INT32;
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000047 }
48 if (constantWeights)
49 {
50 buffers[1] = CreateBuffer(flatBufferBuilder,
51 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(weightsData.data()),
52 sizeof(T) * weightsData.size()));
Sadik Armagan6e36a642020-11-10 21:18:41 +000053
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000054 if (tensorType == ::tflite::TensorType_INT8)
55 {
56 std::vector<int32_t> biasData = { 10 };
57 buffers[2] = CreateBuffer(flatBufferBuilder,
58 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(biasData.data()),
59 sizeof(int32_t) * biasData.size()));
60
61 }
62 else
63 {
64 std::vector<float> biasData = { 10 };
65 buffers[2] = CreateBuffer(flatBufferBuilder,
66 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(biasData.data()),
67 sizeof(float) * biasData.size()));
68 }
Sadik Armagan6e36a642020-11-10 21:18:41 +000069 }
70 else
71 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +000072 buffers[1] = CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({}));
73 buffers[2] = CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({}));
Sadik Armagan6e36a642020-11-10 21:18:41 +000074 }
75
76 auto quantizationParameters =
77 CreateQuantizationParameters(flatBufferBuilder,
78 0,
79 0,
80 flatBufferBuilder.CreateVector<float>({ quantScale }),
81 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
82
83 auto outputQuantizationParameters =
84 CreateQuantizationParameters(flatBufferBuilder,
85 0,
86 0,
87 flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
88 flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));
89
90 std::array<flatbuffers::Offset<Tensor>, 4> tensors;
91 tensors[0] = CreateTensor(flatBufferBuilder,
92 flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
93 inputTensorShape.size()),
94 tensorType,
95 0,
96 flatBufferBuilder.CreateString("input_0"),
97 quantizationParameters);
98 tensors[1] = CreateTensor(flatBufferBuilder,
99 flatBufferBuilder.CreateVector<int32_t>(weightsTensorShape.data(),
100 weightsTensorShape.size()),
101 tensorType,
102 1,
103 flatBufferBuilder.CreateString("weights"),
104 quantizationParameters);
105 tensors[2] = CreateTensor(flatBufferBuilder,
106 flatBufferBuilder.CreateVector<int32_t>(biasTensorShape.data(),
107 biasTensorShape.size()),
108 biasTensorType,
109 2,
110 flatBufferBuilder.CreateString("bias"),
111 quantizationParameters);
112
113 tensors[3] = CreateTensor(flatBufferBuilder,
114 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
115 outputTensorShape.size()),
116 tensorType,
117 0,
118 flatBufferBuilder.CreateString("output"),
119 outputQuantizationParameters);
120
121
122 // create operator
123 tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_FullyConnectedOptions;
124 flatbuffers::Offset<void> operatorBuiltinOptions =
125 CreateFullyConnectedOptions(flatBufferBuilder,
126 activationType,
127 FullyConnectedOptionsWeightsFormat_DEFAULT, false).Union();
128
Keith Davis892fafe2020-11-26 17:40:35 +0000129 const std::vector<int> operatorInputs{0, 1, 2};
130 const std::vector<int> operatorOutputs{3};
Sadik Armagan6e36a642020-11-10 21:18:41 +0000131 flatbuffers::Offset <Operator> fullyConnectedOperator =
132 CreateOperator(flatBufferBuilder,
133 0,
134 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
135 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
136 operatorBuiltinOptionsType, operatorBuiltinOptions);
137
Keith Davis892fafe2020-11-26 17:40:35 +0000138 const std::vector<int> subgraphInputs{0, 1, 2};
139 const std::vector<int> subgraphOutputs{3};
Sadik Armagan6e36a642020-11-10 21:18:41 +0000140 flatbuffers::Offset <SubGraph> subgraph =
141 CreateSubGraph(flatBufferBuilder,
142 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
143 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
144 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
145 flatBufferBuilder.CreateVector(&fullyConnectedOperator, 1));
146
147 flatbuffers::Offset <flatbuffers::String> modelDescription =
148 flatBufferBuilder.CreateString("ArmnnDelegate: FullyConnected Operator Model");
149 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
150 tflite::BuiltinOperator_FULLY_CONNECTED);
151
152 flatbuffers::Offset <Model> flatbufferModel =
153 CreateModel(flatBufferBuilder,
154 TFLITE_SCHEMA_VERSION,
155 flatBufferBuilder.CreateVector(&operatorCode, 1),
156 flatBufferBuilder.CreateVector(&subgraph, 1),
157 modelDescription,
158 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
159
160 flatBufferBuilder.Finish(flatbufferModel);
161
162 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
163 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
164}
165
166template <typename T>
167void FullyConnectedTest(std::vector<armnn::BackendId>& backends,
168 tflite::TensorType tensorType,
169 tflite::ActivationFunctionType activationType,
170 const std::vector <int32_t>& inputTensorShape,
171 const std::vector <int32_t>& weightsTensorShape,
172 const std::vector <int32_t>& biasTensorShape,
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000173 std::vector <int32_t>& outputTensorShape,
174 std::vector <T>& inputValues,
175 std::vector <T>& expectedOutputValues,
176 std::vector <T>& weightsData,
177 bool constantWeights = true,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000178 float quantScale = 1.0f,
179 int quantOffset = 0)
180{
181 using namespace tflite;
182
183 std::vector<char> modelBuffer = CreateFullyConnectedTfLiteModel(tensorType,
184 activationType,
185 inputTensorShape,
186 weightsTensorShape,
187 biasTensorShape,
188 outputTensorShape,
189 weightsData,
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000190 constantWeights,
Sadik Armagan6e36a642020-11-10 21:18:41 +0000191 quantScale,
192 quantOffset);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000193 const Model* tfLiteModel = GetModel(modelBuffer.data());
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000194
Sadik Armagan6e36a642020-11-10 21:18:41 +0000195 // Create TfLite Interpreters
196 std::unique_ptr<Interpreter> armnnDelegateInterpreter;
197 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
198 (&armnnDelegateInterpreter) == kTfLiteOk);
199 CHECK(armnnDelegateInterpreter != nullptr);
200 CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
201
202 std::unique_ptr<Interpreter> tfLiteInterpreter;
203 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
204 (&tfLiteInterpreter) == kTfLiteOk);
205 CHECK(tfLiteInterpreter != nullptr);
206 CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
207
208 // Create the ArmNN Delegate
209 armnnDelegate::DelegateOptions delegateOptions(backends);
210 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000211 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
212 armnnDelegate::TfLiteArmnnDelegateDelete);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000213 CHECK(theArmnnDelegate != nullptr);
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000214
Sadik Armagan6e36a642020-11-10 21:18:41 +0000215 // Modify armnnDelegateInterpreter to use armnnDelegate
216 CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
217
218 // Set input data
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000219 armnnDelegate::FillInput<T>(tfLiteInterpreter, 0, inputValues);
220 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 0, inputValues);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000221
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000222 if (!constantWeights)
Sadik Armagan6e36a642020-11-10 21:18:41 +0000223 {
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000224 armnnDelegate::FillInput<T>(tfLiteInterpreter, 1, weightsData);
225 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 1, weightsData);
226
227 if (tensorType == ::tflite::TensorType_INT8)
228 {
229 std::vector <int32_t> biasData = {10};
230 armnnDelegate::FillInput<int32_t>(tfLiteInterpreter, 2, biasData);
231 armnnDelegate::FillInput<int32_t>(armnnDelegateInterpreter, 2, biasData);
232 }
233 else
234 {
235 std::vector<float> biasData = {10};
236 armnnDelegate::FillInput<float>(tfLiteInterpreter, 2, biasData);
237 armnnDelegate::FillInput<float>(armnnDelegateInterpreter, 2, biasData);
238 }
Sadik Armagan6e36a642020-11-10 21:18:41 +0000239 }
240
241 // Run EnqueWorkload
242 CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
243 CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
244
245 // Compare output data
Sadik Armaganf0a6dec2021-03-25 07:46:55 +0000246 armnnDelegate::CompareOutputData<T>(tfLiteInterpreter,
247 armnnDelegateInterpreter,
248 outputTensorShape,
249 expectedOutputValues);
250 armnnDelegateInterpreter.reset(nullptr);
Sadik Armagan6e36a642020-11-10 21:18:41 +0000251}
252
253} // anonymous namespace