//
// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <schema_generated.h>

#include <doctest/doctest.h>

namespace
{

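// Builds a single-operator FULLY_CONNECTED TFLite model in memory and returns the
// serialised flatbuffer. When constantWeights is true, the weights (and a bias of
// { 10 }) are embedded in the model as constant buffers; otherwise they are left
// empty and must be supplied as runtime inputs.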
template <typename T>
std::vector<char> CreateFullyConnectedTfLiteModel(tflite::TensorType tensorType,
                                                  tflite::ActivationFunctionType activationType,
                                                  const std::vector<int32_t>& inputTensorShape,
                                                  const std::vector<int32_t>& weightsTensorShape,
                                                  const std::vector<int32_t>& biasTensorShape,
                                                  std::vector<int32_t>& outputTensorShape,
                                                  std::vector<T>& weightsData,
                                                  bool constantWeights = true,
                                                  float quantScale = 1.0f,
                                                  int quantOffset = 0,
                                                  float outputQuantScale = 2.0f,
                                                  int outputQuantOffset = 0)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;
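    // Five buffers: index 0 is the empty buffer TFLite reserves as a sentinel,
    // and indices 1-4 back the input, weights, bias and output tensors.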
    std::array<flatbuffers::Offset<tflite::Buffer>, 5> buffers;
    buffers[0] = CreateBuffer(flatBufferBuilder);
    buffers[1] = CreateBuffer(flatBufferBuilder);

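    // Quantized (INT8) models carry a higher-precision INT32 bias;
    // float models keep a FLOAT32 bias.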
    auto biasTensorType = ::tflite::TensorType_FLOAT32;
    if (tensorType == ::tflite::TensorType_INT8)
    {
        biasTensorType = ::tflite::TensorType_INT32;
    }
    if (constantWeights)
    {
        buffers[2] = CreateBuffer(flatBufferBuilder,
                                  flatBufferBuilder.CreateVector(
                                      reinterpret_cast<const uint8_t*>(weightsData.data()),
                                      sizeof(T) * weightsData.size()));

        if (tensorType == ::tflite::TensorType_INT8)
        {
            std::vector<int32_t> biasData = { 10 };
            buffers[3] = CreateBuffer(flatBufferBuilder,
                                      flatBufferBuilder.CreateVector(
                                          reinterpret_cast<const uint8_t*>(biasData.data()),
                                          sizeof(int32_t) * biasData.size()));
        }
        else
        {
            std::vector<float> biasData = { 10 };
            buffers[3] = CreateBuffer(flatBufferBuilder,
                                      flatBufferBuilder.CreateVector(
                                          reinterpret_cast<const uint8_t*>(biasData.data()),
                                          sizeof(float) * biasData.size()));
        }
    }
    else
    {
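        // Non-constant weights: leave the weight and bias buffers empty so
        // they become graph inputs that are filled in at runtime.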
        buffers[2] = CreateBuffer(flatBufferBuilder);
        buffers[3] = CreateBuffer(flatBufferBuilder);
    }
    buffers[4] = CreateBuffer(flatBufferBuilder);

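    // Input, weights and bias share one set of quantization parameters;
    // the output gets its own scale and offset.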
    auto quantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ quantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    auto outputQuantizationParameters =
        CreateQuantizationParameters(flatBufferBuilder,
                                     0,
                                     0,
                                     flatBufferBuilder.CreateVector<float>({ outputQuantScale }),
                                     flatBufferBuilder.CreateVector<int64_t>({ outputQuantOffset }));

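    // Four tensors: 0 = input, 1 = weights, 2 = bias, 3 = output,
    // backed by buffers 1-4 respectively.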
    std::array<flatbuffers::Offset<Tensor>, 4> tensors;
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
                                                                      inputTensorShape.size()),
                              tensorType,
                              1,
                              flatBufferBuilder.CreateString("input_0"),
                              quantizationParameters);
    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(weightsTensorShape.data(),
                                                                      weightsTensorShape.size()),
                              tensorType,
                              2,
                              flatBufferBuilder.CreateString("weights"),
                              quantizationParameters);
    tensors[2] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(biasTensorShape.data(),
                                                                      biasTensorShape.size()),
                              biasTensorType,
                              3,
                              flatBufferBuilder.CreateString("bias"),
                              quantizationParameters);
    tensors[3] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
                                                                      outputTensorShape.size()),
                              tensorType,
                              4,
                              flatBufferBuilder.CreateString("output"),
                              outputQuantizationParameters);

    // Create the FULLY_CONNECTED operator and its builtin options.
    tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_FullyConnectedOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions =
        CreateFullyConnectedOptions(flatBufferBuilder,
                                    activationType,
                                    FullyConnectedOptionsWeightsFormat_DEFAULT, false).Union();

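    // Tensors 0 (input), 1 (weights) and 2 (bias) feed the operator and tensor 3
    // receives its result. The same indices are exposed as subgraph inputs so
    // that non-constant weights and bias can be supplied at runtime.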
    const std::vector<int> operatorInputs{0, 1, 2};
    const std::vector<int> operatorOutputs{3};
    flatbuffers::Offset<Operator> fullyConnectedOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType, operatorBuiltinOptions);

    const std::vector<int> subgraphInputs{0, 1, 2};
    const std::vector<int> subgraphOutputs{3};
    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                       flatBufferBuilder.CreateVector(&fullyConnectedOperator, 1));

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: FullyConnected Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode =
        CreateOperatorCode(flatBufferBuilder, tflite::BuiltinOperator_FULLY_CONNECTED);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

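// Runs the generated model through both a plain TFLite interpreter and one with
// the Arm NN delegate applied, then checks that both produce the expected
// output values and shape.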
template <typename T>
void FullyConnectedTest(std::vector<armnn::BackendId>& backends,
                        tflite::TensorType tensorType,
                        tflite::ActivationFunctionType activationType,
                        const std::vector<int32_t>& inputTensorShape,
                        const std::vector<int32_t>& weightsTensorShape,
                        const std::vector<int32_t>& biasTensorShape,
                        std::vector<int32_t>& outputTensorShape,
                        std::vector<T>& inputValues,
                        std::vector<T>& expectedOutputValues,
                        std::vector<T>& weightsData,
                        bool constantWeights = true,
                        float quantScale = 1.0f,
                        int quantOffset = 0)
{
    using namespace delegateTestInterpreter;

    std::vector<char> modelBuffer = CreateFullyConnectedTfLiteModel(tensorType,
                                                                    activationType,
                                                                    inputTensorShape,
                                                                    weightsTensorShape,
                                                                    biasTensorShape,
                                                                    outputTensorShape,
                                                                    weightsData,
                                                                    constantWeights,
                                                                    quantScale,
                                                                    quantOffset);

    // Set up an interpreter using just the TFLite runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);

    // Set up an interpreter with the Arm NN delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);

    CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);

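    // With non-constant weights the model's weight and bias buffers are empty,
    // so feed them through the extra graph inputs (indices 1 and 2). The bias
    // value matches the { 10 } embedded in the constant-weights path.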
    if (!constantWeights)
    {
        CHECK(tfLiteInterpreter.FillInputTensor<T>(weightsData, 1) == kTfLiteOk);
        CHECK(armnnInterpreter.FillInputTensor<T>(weightsData, 1) == kTfLiteOk);

        if (tensorType == ::tflite::TensorType_INT8)
        {
            std::vector<int32_t> biasData = { 10 };
            CHECK(tfLiteInterpreter.FillInputTensor<int32_t>(biasData, 2) == kTfLiteOk);
            CHECK(armnnInterpreter.FillInputTensor<int32_t>(biasData, 2) == kTfLiteOk);
        }
        else
        {
            std::vector<float> biasData = { 10 };
            CHECK(tfLiteInterpreter.FillInputTensor<float>(biasData, 2) == kTfLiteOk);
            CHECK(armnnInterpreter.FillInputTensor<float>(biasData, 2) == kTfLiteOk);
        }
    }

    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);

    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
    std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
    std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);

    armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
    armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShape);

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
}
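
// A minimal usage sketch (backend, shapes and values are illustrative, not taken
// from a specific test case):
//
//     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
//     std::vector<int32_t> inputShape   { 1, 4, 1, 1 };
//     std::vector<int32_t> weightsShape { 1, 4 };
//     std::vector<int32_t> biasShape    { 1 };
//     std::vector<int32_t> outputShape  { 1, 1 };
//     std::vector<float>   inputValues    { 10, 20, 30, 40 };
//     std::vector<float>   weightsData    { 2, 3, 4, 5 };
//     std::vector<float>   expectedOutput { 410 }; // 10*2 + 20*3 + 30*4 + 40*5 + bias(10)
//     FullyConnectedTest<float>(backends, ::tflite::TensorType_FLOAT32,
//                               tflite::ActivationFunctionType_NONE,
//                               inputShape, weightsShape, biasShape, outputShape,
//                               inputValues, expectedOutput, weightsData);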

} // anonymous namespace