blob: fc000bf95be82318979bb230a3a36f5d5fe57baa [file] [log] [blame]
Sadik Armagan8853c1f2018-10-22 09:04:18 +01001//
Finn Williamsb49ed182021-06-29 15:50:08 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
Sadik Armagan8853c1f2018-10-22 09:04:18 +01003// SPDX-License-Identifier: MIT
4//
5
Sadik Armagan8853c1f2018-10-22 09:04:18 +01006#include "ParserFlatbuffersFixture.hpp"
Sadik Armagan8853c1f2018-10-22 09:04:18 +01007
Sadik Armagan8853c1f2018-10-22 09:04:18 +01008
Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("TensorflowLiteParser_FullyConnected")
10{
Sadik Armagan8853c1f2018-10-22 09:04:18 +010011struct FullyConnectedFixture : public ParserFlatbuffersFixture
12{
13 explicit FullyConnectedFixture(const std::string& inputShape,
Finn Williamsd4fa5452021-03-01 12:31:41 +000014 const std::string& outputShape,
15 const std::string& filterShape,
16 const std::string& filterData,
17 const std::string biasShape = "",
18 const std::string biasData = "")
Sadik Armagan8853c1f2018-10-22 09:04:18 +010019 {
20 std::string inputTensors = "[ 0, 2 ]";
21 std::string biasTensor = "";
22 std::string biasBuffer = "";
23 if (biasShape.size() > 0 && biasData.size() > 0)
24 {
25 inputTensors = "[ 0, 2, 3 ]";
26 biasTensor = R"(
27 {
28 "shape": )" + biasShape + R"( ,
29 "type": "INT32",
30 "buffer": 3,
31 "name": "biasTensor",
32 "quantization": {
33 "min": [ 0.0 ],
34 "max": [ 255.0 ],
35 "scale": [ 1.0 ],
36 "zero_point": [ 0 ],
37 }
38 } )";
39 biasBuffer = R"(
40 { "data": )" + biasData + R"(, }, )";
41 }
42 m_JsonString = R"(
43 {
44 "version": 3,
45 "operator_codes": [ { "builtin_code": "FULLY_CONNECTED" } ],
46 "subgraphs": [ {
47 "tensors": [
48 {
49 "shape": )" + inputShape + R"(,
50 "type": "UINT8",
51 "buffer": 0,
52 "name": "inputTensor",
53 "quantization": {
54 "min": [ 0.0 ],
55 "max": [ 255.0 ],
56 "scale": [ 1.0 ],
57 "zero_point": [ 0 ],
58 }
59 },
60 {
61 "shape": )" + outputShape + R"(,
62 "type": "UINT8",
63 "buffer": 1,
64 "name": "outputTensor",
65 "quantization": {
66 "min": [ 0.0 ],
67 "max": [ 511.0 ],
68 "scale": [ 2.0 ],
69 "zero_point": [ 0 ],
70 }
71 },
72 {
73 "shape": )" + filterShape + R"(,
74 "type": "UINT8",
75 "buffer": 2,
76 "name": "filterTensor",
77 "quantization": {
78 "min": [ 0.0 ],
79 "max": [ 255.0 ],
80 "scale": [ 1.0 ],
81 "zero_point": [ 0 ],
82 }
83 }, )" + biasTensor + R"(
84 ],
85 "inputs": [ 0 ],
86 "outputs": [ 1 ],
87 "operators": [
88 {
89 "opcode_index": 0,
90 "inputs": )" + inputTensors + R"(,
91 "outputs": [ 1 ],
92 "builtin_options_type": "FullyConnectedOptions",
93 "builtin_options": {
94 "fused_activation_function": "NONE"
95 },
96 "custom_options_format": "FLEXBUFFERS"
97 }
98 ],
99 } ],
100 "buffers" : [
101 { },
102 { },
103 { "data": )" + filterData + R"(, }, )"
104 + biasBuffer + R"(
105 ]
106 }
107 )";
108 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
109 }
110};
111
112struct FullyConnectedWithNoBiasFixture : FullyConnectedFixture
113{
114 FullyConnectedWithNoBiasFixture()
115 : FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape
116 "[ 1, 1 ]", // outputShape
Nattapat Chaimanowongd8eee592018-10-26 10:24:14 +0100117 "[ 1, 4 ]", // filterShape
Sadik Armagan8853c1f2018-10-22 09:04:18 +0100118 "[ 2, 3, 4, 5 ]") // filterData
119 {}
120};
121
Sadik Armagan1625efc2021-06-10 18:24:34 +0100122TEST_CASE_FIXTURE(FullyConnectedWithNoBiasFixture, "FullyConnectedWithNoBias")
Sadik Armagan8853c1f2018-10-22 09:04:18 +0100123{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000124 RunTest<2, armnn::DataType::QAsymmU8>(
Sadik Armagan8853c1f2018-10-22 09:04:18 +0100125 0,
126 { 10, 20, 30, 40 },
127 { 400/2 });
128}
129
130struct FullyConnectedWithBiasFixture : FullyConnectedFixture
131{
132 FullyConnectedWithBiasFixture()
133 : FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape
134 "[ 1, 1 ]", // outputShape
Nattapat Chaimanowongd8eee592018-10-26 10:24:14 +0100135 "[ 1, 4 ]", // filterShape
Sadik Armagan8853c1f2018-10-22 09:04:18 +0100136 "[ 2, 3, 4, 5 ]", // filterData
137 "[ 1 ]", // biasShape
138 "[ 10, 0, 0, 0 ]" ) // biasData
139 {}
140};
141
Sadik Armagan1625efc2021-06-10 18:24:34 +0100142TEST_CASE_FIXTURE(FullyConnectedWithBiasFixture, "ParseFullyConnectedWithBias")
Sadik Armagan8853c1f2018-10-22 09:04:18 +0100143{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000144 RunTest<2, armnn::DataType::QAsymmU8>(
Sadik Armagan8853c1f2018-10-22 09:04:18 +0100145 0,
146 { 10, 20, 30, 40 },
147 { (400+10)/2 });
148}
149
Narumol Prangnawarat501f4d42019-04-24 15:52:20 +0100150struct FullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
151{
152 FullyConnectedWithBiasMultipleOutputsFixture()
153 : FullyConnectedFixture("[ 1, 4, 2, 1 ]", // inputShape
154 "[ 2, 1 ]", // outputShape
155 "[ 1, 4 ]", // filterShape
156 "[ 2, 3, 4, 5 ]", // filterData
157 "[ 1 ]", // biasShape
158 "[ 10, 0, 0, 0 ]" ) // biasData
159 {}
160};
161
Sadik Armagan1625efc2021-06-10 18:24:34 +0100162TEST_CASE_FIXTURE(FullyConnectedWithBiasMultipleOutputsFixture, "FullyConnectedWithBiasMultipleOutputs")
Narumol Prangnawarat501f4d42019-04-24 15:52:20 +0100163{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000164 RunTest<2, armnn::DataType::QAsymmU8>(
Narumol Prangnawarat501f4d42019-04-24 15:52:20 +0100165 0,
166 { 1, 2, 3, 4, 10, 20, 30, 40 },
167 { (40+10)/2, (400+10)/2 });
168}
169
Sadik Armagand109a4d2020-07-28 10:42:13 +0100170struct DynamicFullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
171{
172 DynamicFullyConnectedWithBiasMultipleOutputsFixture()
173 : FullyConnectedFixture("[ 1, 4, 2, 1 ]", // inputShape
174 "[ ]", // outputShape
175 "[ 1, 4 ]", // filterShape
176 "[ 2, 3, 4, 5 ]", // filterData
177 "[ 1 ]", // biasShape
178 "[ 10, 0, 0, 0 ]" ) // biasData
179 { }
180};
181
Sadik Armagan1625efc2021-06-10 18:24:34 +0100182TEST_CASE_FIXTURE(
183 DynamicFullyConnectedWithBiasMultipleOutputsFixture, "DynamicFullyConnectedWithBiasMultipleOutputs")
Sadik Armagand109a4d2020-07-28 10:42:13 +0100184{
185 RunTest<2,
186 armnn::DataType::QAsymmU8,
187 armnn::DataType::QAsymmU8>(0,
188 { { "inputTensor", { 1, 2, 3, 4, 10, 20, 30, 40} } },
189 { { "outputTensor", { (40+10)/2, (400+10)/2 } } },
190 true);
191}
192
Finn Williamsd4fa5452021-03-01 12:31:41 +0000193
194struct FullyConnectedNonConstWeightsFixture : public ParserFlatbuffersFixture
195{
196 explicit FullyConnectedNonConstWeightsFixture(const std::string& inputShape,
197 const std::string& outputShape,
198 const std::string& filterShape,
199 const std::string biasShape = "")
200 {
201 std::string inputTensors = "[ 0, 1 ]";
202 std::string biasTensor = "";
203 std::string biasBuffer = "";
204 std::string outputs = "2";
205 if (biasShape.size() > 0)
206 {
207 inputTensors = "[ 0, 1, 2 ]";
208 biasTensor = R"(
209 {
210 "shape": )" + biasShape + R"(,
211 "type": "INT32",
212 "buffer": 2,
213 "name": "bias",
214 "quantization": {
215 "scale": [ 1.0 ],
216 "zero_point": [ 0 ],
217 "details_type": 0,
218 "quantized_dimension": 0
219 },
220 "is_variable": true
221 }, )";
222
mathad01bf7edb62021-04-20 16:12:45 +0100223 biasBuffer = R"(,{ "data": [] } )";
Finn Williamsd4fa5452021-03-01 12:31:41 +0000224 outputs = "3";
225 }
226 m_JsonString = R"(
227 {
228 "version": 3,
229 "operator_codes": [
230 {
231 "builtin_code": "FULLY_CONNECTED",
232 "version": 1
233 }
234 ],
235 "subgraphs": [
236 {
237 "tensors": [
238 {
239 "shape": )" + inputShape + R"(,
240 "type": "INT8",
241 "buffer": 0,
242 "name": "input_0",
243 "quantization": {
244 "scale": [ 1.0 ],
245 "zero_point": [ 0 ],
246 "details_type": 0,
247 "quantized_dimension": 0
248 },
Finn Williamsd4fa5452021-03-01 12:31:41 +0000249 },
250 {
251 "shape": )" + filterShape + R"(,
252 "type": "INT8",
253 "buffer": 1,
254 "name": "weights",
255 "quantization": {
256 "scale": [ 1.0 ],
257 "zero_point": [ 0 ],
258 "details_type": 0,
259 "quantized_dimension": 0
260 },
Finn Williamsd4fa5452021-03-01 12:31:41 +0000261 },
262 )" + biasTensor + R"(
263 {
264 "shape": )" + outputShape + R"(,
265 "type": "INT8",
266 "buffer": 0,
267 "name": "output",
268 "quantization": {
269 "scale": [
270 2.0
271 ],
272 "zero_point": [
273 0
274 ],
275 "details_type": 0,
276 "quantized_dimension": 0
277 },
Finn Williamsd4fa5452021-03-01 12:31:41 +0000278 }
279 ],
280 "inputs": )" + inputTensors + R"(,
281 "outputs": [ )" + outputs + R"( ],
282 "operators": [
283 {
284 "opcode_index": 0,
285 "inputs": )" + inputTensors + R"(,
286 "outputs": [ )" + outputs + R"( ],
287 "builtin_options_type": "FullyConnectedOptions",
288 "builtin_options": {
289 "fused_activation_function": "NONE",
290 "weights_format": "DEFAULT",
291 "keep_num_dims": false,
292 "asymmetric_quantize_inputs": false
293 },
294 "custom_options_format": "FLEXBUFFERS"
295 }
296 ]
297 }
298 ],
299 "description": "ArmnnDelegate: FullyConnected Operator Model",
300 "buffers": [
301 {
302 "data": []
303 },
304 {
mathad01bf7edb62021-04-20 16:12:45 +0100305 "data": []
Finn Williamsd4fa5452021-03-01 12:31:41 +0000306 }
307 )" + biasBuffer + R"(
308 ]
309 }
310 )";
311 Setup();
312 }
313};
314
315struct FullyConnectedNonConstWeights : FullyConnectedNonConstWeightsFixture
316{
317 FullyConnectedNonConstWeights()
318 : FullyConnectedNonConstWeightsFixture("[ 1, 4, 1, 1 ]", // inputShape
319 "[ 1, 1 ]", // outputShape
320 "[ 1, 4 ]", // filterShape
321 "[ 1 ]" ) // biasShape
322
323 {}
324};
325
Sadik Armagan1625efc2021-06-10 18:24:34 +0100326TEST_CASE_FIXTURE(FullyConnectedNonConstWeights, "ParseFullyConnectedNonConstWeights")
Finn Williamsd4fa5452021-03-01 12:31:41 +0000327{
328 RunTest<2, armnn::DataType::QAsymmS8,
329 armnn::DataType::Signed32,
330 armnn::DataType::QAsymmS8>(
331 0,
332 {{{"input_0", { 1, 2, 3, 4 }},{"weights", { 2, 3, 4, 5 }}}},
333 {{"bias", { 10 }}},
334 {{"output", { 25 }}});
335}
336
337struct FullyConnectedNonConstWeightsNoBias : FullyConnectedNonConstWeightsFixture
338{
339 FullyConnectedNonConstWeightsNoBias()
340 : FullyConnectedNonConstWeightsFixture("[ 1, 4, 1, 1 ]", // inputShape
341 "[ 1, 1 ]", // outputShape
342 "[ 1, 4 ]") // filterShape
343
344 {}
345};
346
Sadik Armagan1625efc2021-06-10 18:24:34 +0100347TEST_CASE_FIXTURE(FullyConnectedNonConstWeightsNoBias, "ParseFullyConnectedNonConstWeightsNoBias")
Finn Williamsd4fa5452021-03-01 12:31:41 +0000348{
349 RunTest<2, armnn::DataType::QAsymmS8,
350 armnn::DataType::QAsymmS8>(
351 0,
352 {{{"input_0", { 1, 2, 3, 4 }},{"weights", { 2, 3, 4, 5 }}}},
353 {{"output", { 20 }}});
354}
355
Sadik Armagan1625efc2021-06-10 18:24:34 +0100356}