//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5#include <boost/test/unit_test.hpp>
6#include "ParserFlatbuffersFixture.hpp"
7#include "../TfLiteParser.hpp"
8
9using armnnTfLiteParser::TfLiteParser;
10using ModelPtr = TfLiteParser::ModelPtr;
11using TensorRawPtr = TfLiteParser::TensorRawPtr;
12
13BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15struct GetSubgraphInputsOutputsMainFixture : public ParserFlatbuffersFixture
16{
17 explicit GetSubgraphInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
18 {
19 m_JsonString = R"(
20 {
21 "version": 3,
22 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
23 "subgraphs": [
24 {
25 "tensors": [
26 {
27 "shape": [ 1, 1, 1, 1 ] ,
28 "type": "UINT8",
29 "buffer": 0,
30 "name": "OutputTensor",
31 "quantization": {
32 "min": [ 0.0 ],
33 "max": [ 255.0 ],
34 "scale": [ 1.0 ],
35 "zero_point": [ 0 ]
36 }
37 },
38 {
39 "shape": [ 1, 2, 2, 1 ] ,
40 "type": "UINT8",
41 "buffer": 1,
42 "name": "InputTensor",
43 "quantization": {
44 "min": [ -1.2 ],
45 "max": [ 25.5 ],
46 "scale": [ 0.25 ],
47 "zero_point": [ 10 ]
48 }
49 }
50 ],
51 "inputs": )"
52 + inputs
53 + R"(,
54 "outputs": )"
55 + outputs
56 + R"(,
57 "operators": [ {
58 "opcode_index": 0,
59 "inputs": [ 1 ],
60 "outputs": [ 0 ],
61 "builtin_options_type": "Pool2DOptions",
62 "builtin_options":
63 {
64 "padding": "VALID",
65 "stride_w": 2,
66 "stride_h": 2,
67 "filter_width": 2,
68 "filter_height": 2,
69 "fused_activation_function": "NONE"
70 },
71 "custom_options_format": "FLEXBUFFERS"
72 } ]
73 },
74 {
75 "tensors": [
76 {
77 "shape": [ 1, 3, 3, 1 ],
78 "type": "UINT8",
79 "buffer": 0,
80 "name": "ConvInputTensor",
81 "quantization": {
82 "scale": [ 1.0 ],
83 "zero_point": [ 0 ],
84 }
85 },
86 {
87 "shape": [ 1, 1, 1, 1 ],
88 "type": "UINT8",
89 "buffer": 1,
90 "name": "ConvOutputTensor",
91 "quantization": {
92 "min": [ 0.0 ],
93 "max": [ 511.0 ],
94 "scale": [ 2.0 ],
95 "zero_point": [ 0 ],
96 }
97 },
98 {
99 "shape": [ 1, 3, 3, 1 ],
100 "type": "UINT8",
101 "buffer": 2,
102 "name": "filterTensor",
103 "quantization": {
104 "min": [ 0.0 ],
105 "max": [ 255.0 ],
106 "scale": [ 1.0 ],
107 "zero_point": [ 0 ],
108 }
109 }
110 ],
111 "inputs": [ 0 ],
112 "outputs": [ 1 ],
113 "operators": [
114 {
115 "opcode_index": 0,
116 "inputs": [ 0, 2 ],
117 "outputs": [ 1 ],
118 "builtin_options_type": "Conv2DOptions",
119 "builtin_options": {
120 "padding": "VALID",
121 "stride_w": 1,
122 "stride_h": 1,
123 "fused_activation_function": "NONE"
124 },
125 "custom_options_format": "FLEXBUFFERS"
126 }
127 ],
128 }
129 ],
130 "description": "Test Subgraph Inputs Outputs",
131 "buffers" : [
132 { },
133 { },
134 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
135 { },
136 ]
137 })";
138
139 ReadStringToBinary();
140 }
141
142};
143
144struct GetEmptySubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
145{
146 GetEmptySubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ ]", "[ ]") {}
147};
148
149struct GetSubgraphInputsOutputsFixture : GetSubgraphInputsOutputsMainFixture
150{
151 GetSubgraphInputsOutputsFixture() : GetSubgraphInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
152};
153
154BOOST_FIXTURE_TEST_CASE(GetEmptySubgraphInputs, GetEmptySubgraphInputsOutputsFixture)
155{
156 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
157 TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphInputs(model, 0);
158 BOOST_CHECK_EQUAL(0, subgraphTensors.size());
159}
160
161BOOST_FIXTURE_TEST_CASE(GetEmptySubgraphOutputs, GetEmptySubgraphInputsOutputsFixture)
162{
163 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
164 TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphOutputs(model, 0);
165 BOOST_CHECK_EQUAL(0, subgraphTensors.size());
166}
167
168BOOST_FIXTURE_TEST_CASE(GetSubgraphInputs, GetSubgraphInputsOutputsFixture)
169{
170 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
171 TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphInputs(model, 0);
172 BOOST_CHECK_EQUAL(1, subgraphTensors.size());
173 BOOST_CHECK_EQUAL(1, subgraphTensors[0].first);
174 CheckTensors(subgraphTensors[0].second, 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
175 "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
176}
177
178BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputsSimpleQuantized, GetSubgraphInputsOutputsFixture)
179{
180 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
181 TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphOutputs(model, 0);
182 BOOST_CHECK_EQUAL(1, subgraphTensors.size());
183 BOOST_CHECK_EQUAL(0, subgraphTensors[0].first);
184 CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
185 "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
186}
187
188BOOST_FIXTURE_TEST_CASE(GetSubgraphInputsEmptyMinMax, GetSubgraphInputsOutputsFixture)
189{
190 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
191 TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphInputs(model, 1);
192 BOOST_CHECK_EQUAL(1, subgraphTensors.size());
193 BOOST_CHECK_EQUAL(0, subgraphTensors[0].first);
194 CheckTensors(subgraphTensors[0].second, 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
195 "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
196}
197
198BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputs, GetSubgraphInputsOutputsFixture)
199{
200 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
201 TfLiteParser::TensorIdRawPtrVector subgraphTensors = TfLiteParser::GetSubgraphOutputs(model, 1);
202 BOOST_CHECK_EQUAL(1, subgraphTensors.size());
203 BOOST_CHECK_EQUAL(1, subgraphTensors[0].first);
204 CheckTensors(subgraphTensors[0].second, 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
205 "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
206}
207
208BOOST_AUTO_TEST_CASE(GetSubgraphInputsNullModel)
209{
210 BOOST_CHECK_THROW(TfLiteParser::GetSubgraphInputs(nullptr, 0), armnn::ParseException);
211}
212
213BOOST_AUTO_TEST_CASE(GetSubgraphOutputsNullModel)
214{
215 BOOST_CHECK_THROW(TfLiteParser::GetSubgraphOutputs(nullptr, 0), armnn::ParseException);
216}
217
218BOOST_FIXTURE_TEST_CASE(GetSubgraphInputsInvalidSubgraph, GetSubgraphInputsOutputsFixture)
219{
220 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
221 BOOST_CHECK_THROW(TfLiteParser::GetSubgraphInputs(model, 2), armnn::ParseException);
222}
223
224BOOST_FIXTURE_TEST_CASE(GetSubgraphOutputsInvalidSubgraph, GetSubgraphInputsOutputsFixture)
225{
226 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
227 BOOST_CHECK_THROW(TfLiteParser::GetSubgraphOutputs(model, 2), armnn::ParseException);
228}
229
230BOOST_AUTO_TEST_SUITE_END()