blob: 9ae3412a11bd1b8fee9b089311091b2bb2534c32 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include <boost/test/unit_test.hpp>
6#include "ParserFlatbuffersFixture.hpp"
7#include "../TfLiteParser.hpp"
8
telsoa01c577f2c2018-08-31 09:22:23 +01009using armnnTfLiteParser::TfLiteParser;
10using ModelPtr = TfLiteParser::ModelPtr;
Derek Lambertiff05cc52019-04-26 13:05:17 +010011using SubgraphPtr = TfLiteParser::SubgraphPtr;
telsoa01c577f2c2018-08-31 09:22:23 +010012using OperatorPtr = TfLiteParser::OperatorPtr;
13
14BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
15
16struct LoadModelFixture : public ParserFlatbuffersFixture
17{
18 explicit LoadModelFixture()
19 {
20 m_JsonString = R"(
21 {
22 "version": 3,
23 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
24 "subgraphs": [
25 {
26 "tensors": [
27 {
28 "shape": [ 1, 1, 1, 1 ] ,
29 "type": "UINT8",
30 "buffer": 0,
31 "name": "OutputTensor",
32 "quantization": {
33 "min": [ 0.0 ],
34 "max": [ 255.0 ],
35 "scale": [ 1.0 ],
36 "zero_point": [ 0 ]
37 }
38 },
39 {
40 "shape": [ 1, 2, 2, 1 ] ,
41 "type": "UINT8",
42 "buffer": 1,
43 "name": "InputTensor",
44 "quantization": {
45 "min": [ 0.0 ],
46 "max": [ 255.0 ],
47 "scale": [ 1.0 ],
48 "zero_point": [ 0 ]
49 }
50 }
51 ],
52 "inputs": [ 1 ],
53 "outputs": [ 0 ],
54 "operators": [ {
55 "opcode_index": 0,
56 "inputs": [ 1 ],
57 "outputs": [ 0 ],
58 "builtin_options_type": "Pool2DOptions",
59 "builtin_options":
60 {
61 "padding": "VALID",
62 "stride_w": 2,
63 "stride_h": 2,
64 "filter_width": 2,
65 "filter_height": 2,
66 "fused_activation_function": "NONE"
67 },
68 "custom_options_format": "FLEXBUFFERS"
69 } ]
70 },
71 {
72 "tensors": [
73 {
74 "shape": [ 1, 3, 3, 1 ],
75 "type": "UINT8",
76 "buffer": 0,
77 "name": "ConvInputTensor",
78 "quantization": {
79 "scale": [ 1.0 ],
80 "zero_point": [ 0 ],
81 }
82 },
83 {
84 "shape": [ 1, 1, 1, 1 ],
85 "type": "UINT8",
86 "buffer": 1,
87 "name": "ConvOutputTensor",
88 "quantization": {
89 "min": [ 0.0 ],
90 "max": [ 511.0 ],
91 "scale": [ 2.0 ],
92 "zero_point": [ 0 ],
93 }
94 },
95 {
96 "shape": [ 1, 3, 3, 1 ],
97 "type": "UINT8",
98 "buffer": 2,
99 "name": "filterTensor",
100 "quantization": {
101 "min": [ 0.0 ],
102 "max": [ 255.0 ],
103 "scale": [ 1.0 ],
104 "zero_point": [ 0 ],
105 }
106 }
107 ],
108 "inputs": [ 0 ],
109 "outputs": [ 1 ],
110 "operators": [
111 {
112 "opcode_index": 1,
113 "inputs": [ 0, 2 ],
114 "outputs": [ 1 ],
115 "builtin_options_type": "Conv2DOptions",
116 "builtin_options": {
117 "padding": "VALID",
118 "stride_w": 1,
119 "stride_h": 1,
120 "fused_activation_function": "NONE"
121 },
122 "custom_options_format": "FLEXBUFFERS"
123 }
124 ],
125 }
126 ],
127 "description": "Test loading a model",
128 "buffers" : [ {}, {} ]
129 })";
130
131 ReadStringToBinary();
132 }
133
134 void CheckModel(const ModelPtr& model, uint32_t version, size_t opcodeSize,
135 const std::vector<tflite::BuiltinOperator>& opcodes,
136 size_t subgraphs, const std::string desc, size_t buffers)
137 {
138 BOOST_CHECK(model);
139 BOOST_CHECK_EQUAL(version, model->version);
140 BOOST_CHECK_EQUAL(opcodeSize, model->operator_codes.size());
141 CheckBuiltinOperators(opcodes, model->operator_codes);
142 BOOST_CHECK_EQUAL(subgraphs, model->subgraphs.size());
143 BOOST_CHECK_EQUAL(desc, model->description);
144 BOOST_CHECK_EQUAL(buffers, model->buffers.size());
145 }
146
147 void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators,
148 const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result)
149 {
150 BOOST_CHECK_EQUAL(expectedOperators.size(), result.size());
151 for (size_t i = 0; i < expectedOperators.size(); i++)
152 {
153 BOOST_CHECK_EQUAL(expectedOperators[i], result[i]->builtin_code);
154 }
155 }
156
Derek Lambertiff05cc52019-04-26 13:05:17 +0100157 void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs,
telsoa01c577f2c2018-08-31 09:22:23 +0100158 const std::vector<int32_t>& outputs, size_t operators, const std::string& name)
159 {
160 BOOST_CHECK(subgraph);
161 BOOST_CHECK_EQUAL(tensors, subgraph->tensors.size());
162 BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end());
163 BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
164 subgraph->outputs.begin(), subgraph->outputs.end());
165 BOOST_CHECK_EQUAL(operators, subgraph->operators.size());
166 BOOST_CHECK_EQUAL(name, subgraph->name);
167 }
168
169 void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode, const std::vector<int32_t>& inputs,
170 const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType,
171 tflite::CustomOptionsFormat custom_options_format)
172 {
173 BOOST_CHECK(operatorPtr);
174 BOOST_CHECK_EQUAL(opcode, operatorPtr->opcode_index);
175 BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(),
176 operatorPtr->inputs.begin(), operatorPtr->inputs.end());
177 BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
178 operatorPtr->outputs.begin(), operatorPtr->outputs.end());
179 BOOST_CHECK_EQUAL(optionType, operatorPtr->builtin_options.type);
180 BOOST_CHECK_EQUAL(custom_options_format, operatorPtr->custom_options_format);
181 }
182};
183
184BOOST_FIXTURE_TEST_CASE(LoadModelFromBinary, LoadModelFixture)
185{
186 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
187 CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
188 2, "Test loading a model", 2);
189 CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
190 CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
191 CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
192 tflite::CustomOptionsFormat_FLEXBUFFERS);
193 CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
194 tflite::CustomOptionsFormat_FLEXBUFFERS);
195}
196
197BOOST_FIXTURE_TEST_CASE(LoadModelFromFile, LoadModelFixture)
198{
Matthew Bentham9fc8c0f2019-03-20 12:46:58 +0000199 using namespace boost::filesystem;
200 std::string fname = unique_path(temp_directory_path() / "%%%%-%%%%-%%%%.tflite").string();
telsoa01c577f2c2018-08-31 09:22:23 +0100201 bool saved = flatbuffers::SaveFile(fname.c_str(),
202 reinterpret_cast<char *>(m_GraphBinary.data()),
203 m_GraphBinary.size(), true);
204 BOOST_CHECK_MESSAGE(saved, "Cannot save test file");
205
206 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromFile(fname.c_str());
207 CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
208 2, "Test loading a model", 2);
209 CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
210 CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
211 CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
212 tflite::CustomOptionsFormat_FLEXBUFFERS);
213 CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
214 tflite::CustomOptionsFormat_FLEXBUFFERS);
Matthew Bentham9fc8c0f2019-03-20 12:46:58 +0000215 remove(fname);
telsoa01c577f2c2018-08-31 09:22:23 +0100216}
217
218BOOST_AUTO_TEST_CASE(LoadNullBinary)
219{
220 BOOST_CHECK_THROW(TfLiteParser::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException);
221}
222
223BOOST_AUTO_TEST_CASE(LoadInvalidBinary)
224{
225 std::string testData = "invalid data";
226 BOOST_CHECK_THROW(TfLiteParser::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData),
227 testData.length()), armnn::ParseException);
228}
229
230BOOST_AUTO_TEST_CASE(LoadFileNotFound)
231{
232 BOOST_CHECK_THROW(TfLiteParser::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException);
233}
234
235BOOST_AUTO_TEST_CASE(LoadNullPtrFile)
236{
237 BOOST_CHECK_THROW(TfLiteParser::LoadModelFromFile(nullptr), armnn::InvalidArgumentException);
238}
239
240BOOST_AUTO_TEST_SUITE_END()