//
// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <doctest/doctest.h>

namespace
{

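// Builds a FlatBuffer model containing a single FULLY_CONNECTED operator
// (input, weights and bias feeding one output tensor) and returns it
// serialised as a byte vector ready to be loaded by an interpreter.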
template <typename T>
std::vector<char> CreateFullyConnectedTfLiteModel(tflite::TensorType tensorType,
                                                  tflite::ActivationFunctionType activationType,
                                                  const std::vector<int32_t>& inputTensorShape,
                                                  const std::vector<int32_t>& weightsTensorShape,
                                                  const std::vector<int32_t>& biasTensorShape,
                                                  std::vector<int32_t>& outputTensorShape,
                                                  std::vector<T>& weightsData,
                                                  bool constantWeights = true,
                                                  float quantScale = 1.0f,
                                                  int quantOffset = 0,
                                                  float outputQuantScale = 2.0f,
                                                  int outputQuantOffset = 0)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;
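    // Buffer 0 is the empty sentinel buffer that the TfLite schema reserves;
    // buffers 1-4 back the input, weights, bias and output tensors.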
    std::array<flatbuffers::Offset<tflite::Buffer>, 5> buffers;
    buffers[0] = CreateBuffer(flatBufferBuilder);
    buffers[1] = CreateBuffer(flatBufferBuilder);

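    // Quantized INT8 operators carry their bias as INT32; otherwise the bias stays FLOAT32.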
    auto biasTensorType = ::tflite::TensorType_FLOAT32;
    if (tensorType == ::tflite::TensorType_INT8)
    {
        biasTensorType = ::tflite::TensorType_INT32;
    }
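    // With constant weights the weights and bias data are embedded in the model's
    // buffers; otherwise those buffers stay empty and the caller feeds the data
    // in through the corresponding graph inputs at runtime.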
    if (constantWeights)
    {
        buffers[2] = CreateBuffer(flatBufferBuilder,
                                  flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(weightsData.data()),
                                                                 sizeof(T) * weightsData.size()));

        if (tensorType == ::tflite::TensorType_INT8)
        {
            std::vector<int32_t> biasData = { 10 };
            buffers[3] = CreateBuffer(flatBufferBuilder,
                                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(biasData.data()),
                                                                     sizeof(int32_t) * biasData.size()));
        }
        else
        {
            std::vector<float> biasData = { 10 };
            buffers[3] = CreateBuffer(flatBufferBuilder,
                                      flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(biasData.data()),
                                                                     sizeof(float) * biasData.size()));
        }
    }
    else
    {
        buffers[2] = CreateBuffer(flatBufferBuilder);
        buffers[3] = CreateBuffer(flatBufferBuilder);
    }
    buffers[4] = CreateBuffer(flatBufferBuilder);

    auto quantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    auto outputQuantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));

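    // Tensor i is backed by buffer i + 1 created above.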
    std::array<flatbuffers::Offset<Tensor>, 4> tensors;
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
                                                                      inputTensorShape.size()),
                              tensorType,
                              1,
                              flatBufferBuilder.CreateString("input_0"),
                              quantizationParameters);
    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(weightsTensorShape.data(),
                                                                      weightsTensorShape.size()),
                              tensorType,
                              2,
                              flatBufferBuilder.CreateString("weights"),
                              quantizationParameters);
    tensors[2] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(biasTensorShape.data(),
                                                                      biasTensorShape.size()),
                              biasTensorType,
                              3,
                              flatBufferBuilder.CreateString("bias"),
                              quantizationParameters);
    tensors[3] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
                                                                      outputTensorShape.size()),
                              tensorType,
                              4,
                              flatBufferBuilder.CreateString("output"),
                              outputQuantizationParameters);

    // Create the FULLY_CONNECTED operator and its fused-activation options.
    tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_FullyConnectedOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions =
        CreateFullyConnectedOptions(flatBufferBuilder,
                                    activationType,
                                    FullyConnectedOptionsWeightsFormat_DEFAULT, false).Union();

    const std::vector<int> operatorInputs{0, 1, 2};
    const std::vector<int> operatorOutputs{3};
    flatbuffers::Offset<Operator> fullyConnectedOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType, operatorBuiltinOptions);

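    // Weights and bias are always listed as subgraph inputs; when they are
    // constant their data comes from the buffers above and nothing needs to
    // be filled in at runtime.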
    const std::vector<int> subgraphInputs{0, 1, 2};
    const std::vector<int> subgraphOutputs{3};
    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                       flatBufferBuilder.CreateVector(&fullyConnectedOperator, 1));

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: FullyConnected Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
                                                                        tflite::BuiltinOperator_FULLY_CONNECTED);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

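// Runs the generated FULLY_CONNECTED model twice, once on the reference TfLite
// runtime and once with the Arm NN delegate applied, and checks that both runs
// agree with each other and with the expected output values and shape.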
template <typename T>
void FullyConnectedTest(std::vector<armnn::BackendId>& backends,
                        tflite::TensorType tensorType,
                        tflite::ActivationFunctionType activationType,
                        const std::vector<int32_t>& inputTensorShape,
                        const std::vector<int32_t>& weightsTensorShape,
                        const std::vector<int32_t>& biasTensorShape,
                        std::vector<int32_t>& outputTensorShape,
                        std::vector<T>& inputValues,
                        std::vector<T>& expectedOutputValues,
                        std::vector<T>& weightsData,
                        bool constantWeights = true,
                        float quantScale = 1.0f,
                        int quantOffset = 0)
{
    using namespace delegateTestInterpreter;

    std::vector<char> modelBuffer = CreateFullyConnectedTfLiteModel(tensorType,
                                                                    activationType,
                                                                    inputTensorShape,
                                                                    weightsTensorShape,
                                                                    biasTensorShape,
                                                                    outputTensorShape,
                                                                    weightsData,
                                                                    constantWeights,
                                                                    quantScale,
                                                                    quantOffset);

    // Setup interpreter with just TFLite Runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);

    // Setup interpreter with Arm NN Delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);

    CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);

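    // Non-constant weights and bias are graph inputs, so they have to be
    // supplied to both interpreters at runtime.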
    if (!constantWeights)
    {
        CHECK(tfLiteInterpreter.FillInputTensor<T>(weightsData, 1) == kTfLiteOk);
        CHECK(armnnInterpreter.FillInputTensor<T>(weightsData, 1) == kTfLiteOk);

        if (tensorType == ::tflite::TensorType_INT8)
        {
            std::vector<int32_t> biasData = {10};
            CHECK(tfLiteInterpreter.FillInputTensor<int32_t>(biasData, 2) == kTfLiteOk);
            CHECK(armnnInterpreter.FillInputTensor<int32_t>(biasData, 2) == kTfLiteOk);
        }
        else
        {
            std::vector<float> biasData = {10};
            CHECK(tfLiteInterpreter.FillInputTensor<float>(biasData, 2) == kTfLiteOk);
            CHECK(armnnInterpreter.FillInputTensor<float>(biasData, 2) == kTfLiteOk);
        }
    }

    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);

    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);

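    // The delegate run must agree with the reference run and with the expected data.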
    armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShape);

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
}
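
// A minimal usage sketch of this helper. The backend choice, shapes and values
// below are illustrative assumptions rather than an excerpt from the real test
// suite. FULLY_CONNECTED computes output = input . weights^T + bias, and the
// helper hard-codes the bias to 10, so 10*2 + 20*3 + 30*4 + 40*5 + 10 = 410:
//
//   std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
//   std::vector<int32_t> inputTensorShape   { 1, 4, 1, 1 };
//   std::vector<int32_t> weightsTensorShape { 1, 4 };
//   std::vector<int32_t> biasTensorShape    { 1 };
//   std::vector<int32_t> outputTensorShape  { 1, 1 };
//   std::vector<float> inputValues          { 10, 20, 30, 40 };
//   std::vector<float> weightsData          { 2, 3, 4, 5 };
//   std::vector<float> expectedOutputValues { 410 };
//
//   FullyConnectedTest<float>(backends,
//                             tflite::TensorType_FLOAT32,
//                             tflite::ActivationFunctionType_NONE,
//                             inputTensorShape,
//                             weightsTensorShape,
//                             biasTensorShape,
//                             outputTensorShape,
//                             inputValues,
//                             expectedOutputValues,
//                             weightsData);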

} // anonymous namespace