blob: c11cd2ba68a3359e40d5272513cadc037a700811 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
Finn Williamsb49ed182021-06-29 15:50:08 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
Sadik Armagan1625efc2021-06-10 18:24:34 +01005
#include "ParserFlatbuffersFixture.hpp"

#include <armnnUtils/Filesystem.hpp>

#include <system_error>
Francis Murtagh532a29d2020-06-29 11:50:01 +01009
Kevin May7d96b162021-02-03 17:38:41 +000010using armnnTfLiteParser::TfLiteParserImpl;
11using ModelPtr = TfLiteParserImpl::ModelPtr;
12using SubgraphPtr = TfLiteParserImpl::SubgraphPtr;
13using OperatorPtr = TfLiteParserImpl::OperatorPtr;
telsoa01c577f2c2018-08-31 09:22:23 +010014
Sadik Armagan1625efc2021-06-10 18:24:34 +010015TEST_SUITE("TensorflowLiteParser_LoadModel")
16{
telsoa01c577f2c2018-08-31 09:22:23 +010017struct LoadModelFixture : public ParserFlatbuffersFixture
18{
19 explicit LoadModelFixture()
20 {
21 m_JsonString = R"(
22 {
23 "version": 3,
24 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
25 "subgraphs": [
26 {
27 "tensors": [
28 {
29 "shape": [ 1, 1, 1, 1 ] ,
30 "type": "UINT8",
31 "buffer": 0,
32 "name": "OutputTensor",
33 "quantization": {
34 "min": [ 0.0 ],
35 "max": [ 255.0 ],
36 "scale": [ 1.0 ],
37 "zero_point": [ 0 ]
38 }
39 },
40 {
41 "shape": [ 1, 2, 2, 1 ] ,
42 "type": "UINT8",
43 "buffer": 1,
44 "name": "InputTensor",
45 "quantization": {
46 "min": [ 0.0 ],
47 "max": [ 255.0 ],
48 "scale": [ 1.0 ],
49 "zero_point": [ 0 ]
50 }
51 }
52 ],
53 "inputs": [ 1 ],
54 "outputs": [ 0 ],
55 "operators": [ {
56 "opcode_index": 0,
57 "inputs": [ 1 ],
58 "outputs": [ 0 ],
59 "builtin_options_type": "Pool2DOptions",
60 "builtin_options":
61 {
62 "padding": "VALID",
63 "stride_w": 2,
64 "stride_h": 2,
65 "filter_width": 2,
66 "filter_height": 2,
67 "fused_activation_function": "NONE"
68 },
69 "custom_options_format": "FLEXBUFFERS"
70 } ]
71 },
72 {
73 "tensors": [
74 {
75 "shape": [ 1, 3, 3, 1 ],
76 "type": "UINT8",
77 "buffer": 0,
78 "name": "ConvInputTensor",
79 "quantization": {
80 "scale": [ 1.0 ],
81 "zero_point": [ 0 ],
82 }
83 },
84 {
85 "shape": [ 1, 1, 1, 1 ],
86 "type": "UINT8",
87 "buffer": 1,
88 "name": "ConvOutputTensor",
89 "quantization": {
90 "min": [ 0.0 ],
91 "max": [ 511.0 ],
92 "scale": [ 2.0 ],
93 "zero_point": [ 0 ],
94 }
95 },
96 {
97 "shape": [ 1, 3, 3, 1 ],
98 "type": "UINT8",
99 "buffer": 2,
100 "name": "filterTensor",
101 "quantization": {
102 "min": [ 0.0 ],
103 "max": [ 255.0 ],
104 "scale": [ 1.0 ],
105 "zero_point": [ 0 ],
106 }
107 }
108 ],
109 "inputs": [ 0 ],
110 "outputs": [ 1 ],
111 "operators": [
112 {
113 "opcode_index": 1,
114 "inputs": [ 0, 2 ],
115 "outputs": [ 1 ],
116 "builtin_options_type": "Conv2DOptions",
117 "builtin_options": {
118 "padding": "VALID",
119 "stride_w": 1,
120 "stride_h": 1,
121 "fused_activation_function": "NONE"
122 },
123 "custom_options_format": "FLEXBUFFERS"
124 }
125 ],
126 }
127 ],
128 "description": "Test loading a model",
129 "buffers" : [ {}, {} ]
130 })";
131
132 ReadStringToBinary();
133 }
134
135 void CheckModel(const ModelPtr& model, uint32_t version, size_t opcodeSize,
136 const std::vector<tflite::BuiltinOperator>& opcodes,
137 size_t subgraphs, const std::string desc, size_t buffers)
138 {
Sadik Armagan1625efc2021-06-10 18:24:34 +0100139 CHECK(model);
140 CHECK_EQ(version, model->version);
141 CHECK_EQ(opcodeSize, model->operator_codes.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100142 CheckBuiltinOperators(opcodes, model->operator_codes);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100143 CHECK_EQ(subgraphs, model->subgraphs.size());
144 CHECK_EQ(desc, model->description);
145 CHECK_EQ(buffers, model->buffers.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100146 }
147
148 void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators,
149 const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result)
150 {
Sadik Armagan1625efc2021-06-10 18:24:34 +0100151 CHECK_EQ(expectedOperators.size(), result.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100152 for (size_t i = 0; i < expectedOperators.size(); i++)
153 {
Sadik Armagan1625efc2021-06-10 18:24:34 +0100154 CHECK_EQ(expectedOperators[i], result[i]->builtin_code);
telsoa01c577f2c2018-08-31 09:22:23 +0100155 }
156 }
157
Derek Lambertiff05cc52019-04-26 13:05:17 +0100158 void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs,
telsoa01c577f2c2018-08-31 09:22:23 +0100159 const std::vector<int32_t>& outputs, size_t operators, const std::string& name)
160 {
Sadik Armagan1625efc2021-06-10 18:24:34 +0100161 CHECK(subgraph);
162 CHECK_EQ(tensors, subgraph->tensors.size());
163 CHECK(std::equal(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end()));
164 CHECK(std::equal(outputs.begin(), outputs.end(),
165 subgraph->outputs.begin(), subgraph->outputs.end()));
166 CHECK_EQ(operators, subgraph->operators.size());
167 CHECK_EQ(name, subgraph->name);
telsoa01c577f2c2018-08-31 09:22:23 +0100168 }
169
170 void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode, const std::vector<int32_t>& inputs,
171 const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType,
172 tflite::CustomOptionsFormat custom_options_format)
173 {
Sadik Armagan1625efc2021-06-10 18:24:34 +0100174 CHECK(operatorPtr);
175 CHECK_EQ(opcode, operatorPtr->opcode_index);
176 CHECK(std::equal(inputs.begin(), inputs.end(),
177 operatorPtr->inputs.begin(), operatorPtr->inputs.end()));
178 CHECK(std::equal(outputs.begin(), outputs.end(),
179 operatorPtr->outputs.begin(), operatorPtr->outputs.end()));
180 CHECK_EQ(optionType, operatorPtr->builtin_options.type);
181 CHECK_EQ(custom_options_format, operatorPtr->custom_options_format);
telsoa01c577f2c2018-08-31 09:22:23 +0100182 }
183};
184
Sadik Armagan1625efc2021-06-10 18:24:34 +0100185TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromBinary")
telsoa01c577f2c2018-08-31 09:22:23 +0100186{
Kevin May7d96b162021-02-03 17:38:41 +0000187 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
188 m_GraphBinary.size());
telsoa01c577f2c2018-08-31 09:22:23 +0100189 CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
190 2, "Test loading a model", 2);
191 CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
192 CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
193 CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
194 tflite::CustomOptionsFormat_FLEXBUFFERS);
195 CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
196 tflite::CustomOptionsFormat_FLEXBUFFERS);
197}
198
Sadik Armagan1625efc2021-06-10 18:24:34 +0100199TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromFile")
telsoa01c577f2c2018-08-31 09:22:23 +0100200{
Francis Murtagh532a29d2020-06-29 11:50:01 +0100201 using namespace fs;
202 fs::path fname = armnnUtils::Filesystem::NamedTempFile("Armnn-tfLite-LoadModelFromFile-TempFile.csv");
telsoa01c577f2c2018-08-31 09:22:23 +0100203 bool saved = flatbuffers::SaveFile(fname.c_str(),
204 reinterpret_cast<char *>(m_GraphBinary.data()),
205 m_GraphBinary.size(), true);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100206 CHECK_MESSAGE(saved, "Cannot save test file");
telsoa01c577f2c2018-08-31 09:22:23 +0100207
Kevin May7d96b162021-02-03 17:38:41 +0000208 TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromFile(fname.c_str());
telsoa01c577f2c2018-08-31 09:22:23 +0100209 CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
210 2, "Test loading a model", 2);
211 CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
212 CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
213 CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
214 tflite::CustomOptionsFormat_FLEXBUFFERS);
215 CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
216 tflite::CustomOptionsFormat_FLEXBUFFERS);
Matthew Bentham9fc8c0f2019-03-20 12:46:58 +0000217 remove(fname);
telsoa01c577f2c2018-08-31 09:22:23 +0100218}
219
Sadik Armagan1625efc2021-06-10 18:24:34 +0100220TEST_CASE("LoadNullBinary")
telsoa01c577f2c2018-08-31 09:22:23 +0100221{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100222 CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException);
telsoa01c577f2c2018-08-31 09:22:23 +0100223}
224
Sadik Armagan1625efc2021-06-10 18:24:34 +0100225TEST_CASE("LoadInvalidBinary")
telsoa01c577f2c2018-08-31 09:22:23 +0100226{
227 std::string testData = "invalid data";
Sadik Armagan1625efc2021-06-10 18:24:34 +0100228 CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData),
telsoa01c577f2c2018-08-31 09:22:23 +0100229 testData.length()), armnn::ParseException);
230}
231
Sadik Armagan1625efc2021-06-10 18:24:34 +0100232TEST_CASE("LoadFileNotFound")
telsoa01c577f2c2018-08-31 09:22:23 +0100233{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100234 CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException);
telsoa01c577f2c2018-08-31 09:22:23 +0100235}
236
Sadik Armagan1625efc2021-06-10 18:24:34 +0100237TEST_CASE("LoadNullPtrFile")
telsoa01c577f2c2018-08-31 09:22:23 +0100238{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100239 CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile(nullptr), armnn::InvalidArgumentException);
telsoa01c577f2c2018-08-31 09:22:23 +0100240}
241
Sadik Armagan1625efc2021-06-10 18:24:34 +0100242}