blob: 100e8e96d526c2d95920af56a5384236ca78d4b7 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include <boost/test/unit_test.hpp>
6#include "ParserFlatbuffersFixture.hpp"
7#include "../TfLiteParser.hpp"
8
Kevin May7d96b162021-02-03 17:38:41 +00009using armnnTfLiteParser::TfLiteParserImpl;
10using ModelPtr = TfLiteParserImpl::ModelPtr;
11using TensorRawPtr = TfLiteParserImpl::TensorRawPtr;
telsoa01c577f2c2018-08-31 09:22:23 +010012
13BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15struct GetSubgraphInputsOutputsMainFixture : public ParserFlatbuffersFixture
16{
17 explicit GetSubgraphInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
18 {
19 m_JsonString = R"(
20 {
21 "version": 3,
22 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
23 "subgraphs": [
24 {
25 "tensors": [
26 {
27 "shape": [ 1, 1, 1, 1 ] ,
28 "type": "UINT8",
29 "buffer": 0,
30 "name": "OutputTensor",
31 "quantization": {
32 "min": [ 0.0 ],
33 "max": [ 255.0 ],
34 "scale": [ 1.0 ],
35 "zero_point": [ 0 ]
36 }
37 },
38 {
39 "shape": [ 1, 2, 2, 1 ] ,
40 "type": "UINT8",
41 "buffer": 1,
42 "name": "InputTensor",
43 "quantization": {
44 "min": [ -1.2 ],
45 "max": [ 25.5 ],
46 "scale": [ 0.25 ],
47 "zero_point": [ 10 ]
48 }
49 }
50 ],
51 "inputs": )"
52 + inputs
53 + R"(,
54 "outputs": )"
55 + outputs
56 + R"(,
57 "operators": [ {
58 "opcode_index": 0,
59 "inputs": [ 1 ],
60 "outputs": [ 0 ],
61 "builtin_options_type": "Pool2DOptions",
62 "builtin_options":
63 {
64 "padding": "VALID",
65 "stride_w": 2,
66 "stride_h": 2,
67 "filter_width": 2,
68 "filter_height": 2,
69 "fused_activation_function": "NONE"
70 },
71 "custom_options_format": "FLEXBUFFERS"
72 } ]
73 },
74 {
75 "tensors": [
76 {
77 "shape": [ 1, 3, 3, 1 ],
78 "type": "UINT8",
79 "buffer": 0,
80 "name": "ConvInputTensor",
81 "quantization": {
82 "scale": [ 1.0 ],
83 "zero_point": [ 0 ],
84 }
85 },
86 {
87 "shape": [ 1, 1, 1, 1 ],
88 "type": "UINT8",
89 "buffer": 1,
90 "name": "ConvOutputTensor",
91 "quantization": {
92 "min": [ 0.0 ],
93 "max": [ 511.0 ],
94 "scale": [ 2.0 ],
95 "zero_point": [ 0 ],
96 }
97 },
98 {
99 "shape": [ 1, 3, 3, 1 ],
100 "type": "UINT8",
101 "buffer": 2,
102 "name": "filterTensor",
103 "quantization": {
104 "min": [ 0.0 ],
105 "max": [ 255.0 ],
106 "scale": [ 1.0 ],
107 "zero_point": [ 0 ],
108 }
109 }
110 ],
111 "inputs": [ 0 ],
112 "outputs": [ 1 ],
113 "operators": [
114 {
115 "opcode_index": 0,
116 "inputs": [ 0, 2 ],
117 "outputs": [ 1 ],
118 "builtin_options_type": "Conv2DOptions",
119 "builtin_options": {
120 "padding": "VALID",
121 "stride_w": 1,
122 "stride_h": 1,
123 "fused_activation_function": "NONE"
124 },
125 "custom_options_format": "FLEXBUFFERS"
126 }
127 ],
128 }
129 ],
130 "description": "Test Subgraph Inputs Outputs",
131 "buffers" : [
132 { },
133 { },
134 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
135 { },
136 ]
137 })";
138
139 ReadStringToBinary();
140 }
141
142};
143
144struct GetEmptySubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
145{
146 GetEmptySubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ ]", "[ ]") {}
147};
148
149struct GetSubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
150{
151 GetSubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
152};
153
154BOOST_FIXTURE_TEST_CASE(GetEmptySubgraphInputs, GetEmptySubgraphInputsOutputsFixture)
155{
Kevin May7d96b162021-02-03 17:38:41 +0000156 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
157 m_GraphBinary.size());
158 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0);
telsoa01c577f2c2018-08-31 09:22:23 +0100159 BOOST_CHECK_EQUAL(0, subgraphTensors.size());
160}
161
162BOOST_FIXTURE_TEST_CASE(GetEmptySubgraphOutputs, GetEmptySubgraphInputsOutputsFixture)
163{
Kevin May7d96b162021-02-03 17:38:41 +0000164 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
165 m_GraphBinary.size());
166 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0);
telsoa01c577f2c2018-08-31 09:22:23 +0100167 BOOST_CHECK_EQUAL(0, subgraphTensors.size());
168}
169
170BOOST_FIXTURE_TEST_CASE(GetSubgraphInputs, GetSubgraphInputsOutputsFixture)
171{
Kevin May7d96b162021-02-03 17:38:41 +0000172 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
173 m_GraphBinary.size());
174 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 0);
telsoa01c577f2c2018-08-31 09:22:23 +0100175 BOOST_CHECK_EQUAL(1, subgraphTensors.size());
176 BOOST_CHECK_EQUAL(1, subgraphTensors[0].first);
177 CheckTensors(subgraphTensors[0].second, 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
178 "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
179}
180
181BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputsSimpleQuantized, GetSubgraphInputsOutputsFixture)
182{
Kevin May7d96b162021-02-03 17:38:41 +0000183 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
184 m_GraphBinary.size());
185 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 0);
telsoa01c577f2c2018-08-31 09:22:23 +0100186 BOOST_CHECK_EQUAL(1, subgraphTensors.size());
187 BOOST_CHECK_EQUAL(0, subgraphTensors[0].first);
188 CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
189 "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
190}
191
192BOOST_FIXTURE_TEST_CASE(GetSubgraphInputsEmptyMinMax, GetSubgraphInputsOutputsFixture)
193{
Kevin May7d96b162021-02-03 17:38:41 +0000194 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
195 m_GraphBinary.size());
196 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphInputs(model, 1);
telsoa01c577f2c2018-08-31 09:22:23 +0100197 BOOST_CHECK_EQUAL(1, subgraphTensors.size());
198 BOOST_CHECK_EQUAL(0, subgraphTensors[0].first);
199 CheckTensors(subgraphTensors[0].second, 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
200 "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
201}
202
203BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputs, GetSubgraphInputsOutputsFixture)
204{
Kevin May7d96b162021-02-03 17:38:41 +0000205 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
206 m_GraphBinary.size());
207 TfLiteParserImpl::TensorIdRawPtrVector subgraphTensors = TfLiteParserImpl::GetSubgraphOutputs(model, 1);
telsoa01c577f2c2018-08-31 09:22:23 +0100208 BOOST_CHECK_EQUAL(1, subgraphTensors.size());
209 BOOST_CHECK_EQUAL(1, subgraphTensors[0].first);
210 CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
211 "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
212}
213
214BOOST_AUTO_TEST_CASE(GetSubgraphInputsNullModel)
215{
Kevin May7d96b162021-02-03 17:38:41 +0000216 BOOST_CHECK_THROW(TfLiteParserImpl::GetSubgraphInputs(nullptr, 0), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100217}
218
219BOOST_AUTO_TEST_CASE(GetSubgraphOutputsNullModel)
220{
Kevin May7d96b162021-02-03 17:38:41 +0000221 BOOST_CHECK_THROW(TfLiteParserImpl::GetSubgraphOutputs(nullptr, 0), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100222}
223
224BOOST_FIXTURE_TEST_CASE(GetSubgraphInputsInvalidSubgraph, GetSubgraphInputsOutputsFixture)
225{
Kevin May7d96b162021-02-03 17:38:41 +0000226 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
227 m_GraphBinary.size());
228 BOOST_CHECK_THROW(TfLiteParserImpl::GetSubgraphInputs(model, 2), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100229}
230
231BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputsInvalidSubgraph, GetSubgraphInputsOutputsFixture)
232{
Kevin May7d96b162021-02-03 17:38:41 +0000233 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
234 m_GraphBinary.size());
235 BOOST_CHECK_THROW(TfLiteParserImpl::GetSubgraphOutputs(model, 2), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100236}
237
238BOOST_AUTO_TEST_SUITE_END()