//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5#include <boost/test/unit_test.hpp>
6#include "ParserFlatbuffersFixture.hpp"
7#include "../TfLiteParser.hpp"
8
9using armnnTfLiteParser::TfLiteParser;
10using ModelPtr = TfLiteParser::ModelPtr;
11
12BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
13
14struct GetInputsOutputsMainFixture : public ParserFlatbuffersFixture
15{
16 explicit GetInputsOutputsMainFixture(const std::string& inputs, const std::string& outputs)
17 {
18 m_JsonString = R"(
19 {
20 "version": 3,
21 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
22 "subgraphs": [
23 {
24 "tensors": [
25 {
26 "shape": [ 1, 1, 1, 1 ] ,
27 "type": "UINT8",
28 "buffer": 0,
29 "name": "OutputTensor",
30 "quantization": {
31 "min": [ 0.0 ],
32 "max": [ 255.0 ],
33 "scale": [ 1.0 ],
34 "zero_point": [ 0 ]
35 }
36 },
37 {
38 "shape": [ 1, 2, 2, 1 ] ,
39 "type": "UINT8",
40 "buffer": 1,
41 "name": "InputTensor",
42 "quantization": {
43 "min": [ -1.2 ],
44 "max": [ 25.5 ],
45 "scale": [ 0.25 ],
46 "zero_point": [ 10 ]
47 }
48 }
49 ],
50 "inputs": [ 1 ],
51 "outputs": [ 0 ],
52 "operators": [ {
53 "opcode_index": 0,
54 "inputs": )"
55 + inputs
56 + R"(,
57 "outputs": )"
58 + outputs
59 + R"(,
60 "builtin_options_type": "Pool2DOptions",
61 "builtin_options":
62 {
63 "padding": "VALID",
64 "stride_w": 2,
65 "stride_h": 2,
66 "filter_width": 2,
67 "filter_height": 2,
68 "fused_activation_function": "NONE"
69 },
70 "custom_options_format": "FLEXBUFFERS"
71 } ]
72 },
73 {
74 "tensors": [
75 {
76 "shape": [ 1, 3, 3, 1 ],
77 "type": "UINT8",
78 "buffer": 0,
79 "name": "ConvInputTensor",
80 "quantization": {
81 "scale": [ 1.0 ],
82 "zero_point": [ 0 ],
83 }
84 },
85 {
86 "shape": [ 1, 1, 1, 1 ],
87 "type": "UINT8",
88 "buffer": 1,
89 "name": "ConvOutputTensor",
90 "quantization": {
91 "min": [ 0.0 ],
92 "max": [ 511.0 ],
93 "scale": [ 2.0 ],
94 "zero_point": [ 0 ],
95 }
96 },
97 {
98 "shape": [ 1, 3, 3, 1 ],
99 "type": "UINT8",
100 "buffer": 2,
101 "name": "filterTensor",
102 "quantization": {
103 "min": [ 0.0 ],
104 "max": [ 255.0 ],
105 "scale": [ 1.0 ],
106 "zero_point": [ 0 ],
107 }
108 }
109 ],
110 "inputs": [ 0 ],
111 "outputs": [ 1 ],
112 "operators": [
113 {
114 "opcode_index": 0,
115 "inputs": [ 0, 2 ],
116 "outputs": [ 1 ],
117 "builtin_options_type": "Conv2DOptions",
118 "builtin_options": {
119 "padding": "VALID",
120 "stride_w": 1,
121 "stride_h": 1,
122 "fused_activation_function": "NONE"
123 },
124 "custom_options_format": "FLEXBUFFERS"
125 }
126 ],
127 }
128 ],
129 "description": "Test Subgraph Inputs Outputs",
130 "buffers" : [
131 { },
132 { },
133 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
134 { },
135 ]
136 })";
137
138 ReadStringToBinary();
139 }
140
141};
142
143struct GetEmptyInputsOutputsFixture : GetInputsOutputsMainFixture
144{
145 GetEmptyInputsOutputsFixture() : GetInputsOutputsMainFixture("[ ]", "[ ]") {}
146};
147
148struct GetInputsOutputsFixture : GetInputsOutputsMainFixture
149{
150 GetInputsOutputsFixture() : GetInputsOutputsMainFixture("[ 1 ]", "[ 0 ]") {}
151};
152
153BOOST_FIXTURE_TEST_CASE(GetEmptyInputs, GetEmptyInputsOutputsFixture)
154{
155 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
156 TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetInputs(model, 0, 0);
157 BOOST_CHECK_EQUAL(0, tensors.size());
158}
159
160BOOST_FIXTURE_TEST_CASE(GetEmptyOutputs, GetEmptyInputsOutputsFixture)
161{
162 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
163 TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetOutputs(model, 0, 0);
164 BOOST_CHECK_EQUAL(0, tensors.size());
165}
166
167BOOST_FIXTURE_TEST_CASE(GetInputs, GetInputsOutputsFixture)
168{
169 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
170 TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetInputs(model, 0, 0);
171 BOOST_CHECK_EQUAL(1, tensors.size());
172 CheckTensors(tensors[0], 4, { 1, 2, 2, 1 }, tflite::TensorType::TensorType_UINT8, 1,
173 "InputTensor", { -1.2f }, { 25.5f }, { 0.25f }, { 10 });
174}
175
176BOOST_FIXTURE_TEST_CASE(GetOutputs, GetInputsOutputsFixture)
177{
178 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
179 TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetOutputs(model, 0, 0);
180 BOOST_CHECK_EQUAL(1, tensors.size());
181 CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 0,
182 "OutputTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
183}
184
185BOOST_FIXTURE_TEST_CASE(GetInputsMultipleInputs, GetInputsOutputsFixture)
186{
187 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
188 TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetInputs(model, 1, 0);
189 BOOST_CHECK_EQUAL(2, tensors.size());
190 CheckTensors(tensors[0], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 0,
191 "ConvInputTensor", { }, { }, { 1.0f }, { 0 });
192 CheckTensors(tensors[1], 4, { 1, 3, 3, 1 }, tflite::TensorType::TensorType_UINT8, 2,
193 "filterTensor", { 0.0f }, { 255.0f }, { 1.0f }, { 0 });
194}
195
196BOOST_FIXTURE_TEST_CASE(GetOutputs2, GetInputsOutputsFixture)
197{
198 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
199 TfLiteParser::TensorRawPtrVector tensors = TfLiteParser::GetOutputs(model, 1, 0);
200 BOOST_CHECK_EQUAL(1, tensors.size());
201 CheckTensors(tensors[0], 4, { 1, 1, 1, 1 }, tflite::TensorType::TensorType_UINT8, 1,
202 "ConvOutputTensor", { 0.0f }, { 511.0f }, { 2.0f }, { 0 });
203}
204
205BOOST_AUTO_TEST_CASE(GetInputsNullModel)
206{
207 BOOST_CHECK_THROW(TfLiteParser::GetInputs(nullptr, 0, 0), armnn::ParseException);
208}
209
210BOOST_AUTO_TEST_CASE(GetOutputsNullModel)
211{
212 BOOST_CHECK_THROW(TfLiteParser::GetOutputs(nullptr, 0, 0), armnn::ParseException);
213}
214
215BOOST_FIXTURE_TEST_CASE(GetInputsInvalidSubgraph, GetInputsOutputsFixture)
216{
217 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
218 BOOST_CHECK_THROW(TfLiteParser::GetInputs(model, 2, 0), armnn::ParseException);
219}
220
221BOOST_FIXTURE_TEST_CASE(GetOutputsInvalidSubgraph, GetInputsOutputsFixture)
222{
223 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
224 BOOST_CHECK_THROW(TfLiteParser::GetOutputs(model, 2, 0), armnn::ParseException);
225}
226
227BOOST_FIXTURE_TEST_CASE(GetInputsInvalidOperator, GetInputsOutputsFixture)
228{
229 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
230 BOOST_CHECK_THROW(TfLiteParser::GetInputs(model, 0, 1), armnn::ParseException);
231}
232
233BOOST_FIXTURE_TEST_CASE(GetOutputsInvalidOperator, GetInputsOutputsFixture)
234{
235 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
236 BOOST_CHECK_THROW(TfLiteParser::GetOutputs(model, 0, 1), armnn::ParseException);
237}
238
239BOOST_AUTO_TEST_SUITE_END()