//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <armnn/Logging.hpp>

#if defined(ARMNN_CAFFE_PARSER)
#include <armnnCaffeParser/ICaffeParser.hpp>
#endif
#if defined(ARMNN_ONNX_PARSER)
#include <armnnOnnxParser/IOnnxParser.hpp>
#endif
#if defined(ARMNN_SERIALIZER)
#include <armnnSerializer/ISerializer.hpp>
#endif
#if defined(ARMNN_TF_PARSER)
#include <armnnTfParser/ITfParser.hpp>
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include <armnnTfLiteParser/ITfLiteParser.hpp>
#endif

#include <HeapProfiling.hpp>

#include <boost/format.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/classification.hpp>
#include <boost/program_options.hpp>

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>

namespace
{

namespace po = boost::program_options;

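// Reads a tensor shape from a stream of comma-separated dimensions,
// e.g. "1,3,224,224" -> TensorShape({ 1, 3, 224, 224 }).
// Tokens that are not valid numbers are logged and skipped.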
armnn::TensorShape ParseTensorShape(std::istream& stream)
{
    std::vector<unsigned int> result;
    std::string line;

    while (std::getline(stream, line))
    {
        std::vector<std::string> tokens;
        try
        {
            // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
            boost::split(tokens, line, boost::algorithm::is_any_of(","), boost::token_compress_on);
        }
        catch (const std::exception& e)
        {
            ARMNN_LOG(error) << "An error occurred when splitting tokens: " << e.what();
            continue;
        }
        for (const std::string& token : tokens)
        {
            if (!token.empty())
            {
                try
                {
                    result.push_back(boost::numeric_cast<unsigned int>(std::stoi(token)));
                }
                catch (const std::exception&)
                {
                    ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
                }
            }
        }
    }

    return armnn::TensorShape(boost::numeric_cast<unsigned int>(result.size()), result.data());
}

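// Helpers for validating the boost::program_options configuration.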
bool CheckOption(const po::variables_map& vm,
                 const char* option)
{
    if (option == nullptr)
    {
        return false;
    }

    // Check whether 'option' is provided.
    return vm.find(option) != vm.end();
}

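// Throws po::error if 'option' was supplied on the command line but 'required' was not.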
void CheckOptionDependency(const po::variables_map& vm,
                           const char* option,
                           const char* required)
{
    if (option == nullptr || required == nullptr)
    {
        throw po::error("Invalid option to check dependency for");
    }

    // Check that if 'option' is provided, 'required' is also provided.
    if (CheckOption(vm, option) && !vm[option].defaulted())
    {
        if (!CheckOption(vm, required) || vm[required].defaulted())
        {
            throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
        }
    }
}

void CheckOptionDependencies(const po::variables_map& vm)
{
    CheckOptionDependency(vm, "model-path", "model-format");
    CheckOptionDependency(vm, "model-path", "input-name");
    CheckOptionDependency(vm, "model-path", "output-name");
    CheckOptionDependency(vm, "input-tensor-shape", "model-path");
}

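// Parses the converter's command-line arguments into the output parameters;
// returns EXIT_SUCCESS, or EXIT_FAILURE on a parse or validation error.
// An illustrative invocation (binary name and file paths are hypothetical):
//   ArmnnConverter -f tensorflow-binary -m model.pb -i input -o output -p model.armnn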
int ParseCommandLineArgs(int argc, const char* argv[],
                         std::string& modelFormat,
                         std::string& modelPath,
                         std::vector<std::string>& inputNames,
                         std::vector<std::string>& inputTensorShapeStrs,
                         std::vector<std::string>& outputNames,
                         std::string& outputPath, bool& isModelBinary)
{
    po::options_description desc("Options");

    desc.add_options()
        ("help", "Display usage information")
        ("model-format,f", po::value(&modelFormat)->required(), "Format of the model file"
#if defined(ARMNN_CAFFE_PARSER)
         ", caffe-binary, caffe-text"
#endif
#if defined(ARMNN_ONNX_PARSER)
         ", onnx-binary, onnx-text"
#endif
#if defined(ARMNN_TF_PARSER)
         ", tensorflow-binary, tensorflow-text"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
         ", tflite-binary"
#endif
         ".")
        ("model-path,m", po::value(&modelPath)->required(), "Path to model file.")
        ("input-name,i", po::value<std::vector<std::string>>()->multitoken(),
         "Identifier of the input tensors in the network, separated by whitespace.")
        ("input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
         "The shape of the input tensor in the network as a flat array of integers, separated by comma."
         " Multiple shapes are separated by whitespace."
         " This parameter is optional, depending on the network.")
        ("output-name,o", po::value<std::vector<std::string>>()->multitoken(),
         "Identifier of the output tensor in the network.")
        ("output-path,p", po::value(&outputPath)->required(), "Path to serialize the network to.");

    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(argc, argv, desc), vm);

        if (CheckOption(vm, "help") || argc <= 1)
        {
            std::cout << "Convert a neural network model from provided file to ArmNN format." << std::endl;
            std::cout << std::endl;
            std::cout << desc << std::endl;
            exit(EXIT_SUCCESS);
        }
        po::notify(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    try
    {
        CheckOptionDependencies(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    if (modelFormat.find("bin") != std::string::npos)
    {
        isModelBinary = true;
    }
    else if (modelFormat.find("text") != std::string::npos)
    {
        isModelBinary = false;
    }
    else
    {
        ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
        return EXIT_FAILURE;
    }

    if (!vm["input-tensor-shape"].empty())
    {
        inputTensorShapeStrs = vm["input-tensor-shape"].as<std::vector<std::string>>();
    }

    inputNames = vm["input-name"].as<std::vector<std::string>>();
    outputNames = vm["output-name"].as<std::vector<std::string>>();

    return EXIT_SUCCESS;
}

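// Empty tag type used to select the CreateNetwork overload for a particular
// parser at compile time (see ArmnnConverter::CreateNetwork below).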
template<typename T>
struct ParserType
{
    typedef T parserType;
};

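// Parses a model file with the requested ArmNN parser and serializes the
// resulting network to the output path in the ArmNN serialized format.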
class ArmnnConverter
{
public:
    ArmnnConverter(const std::string& modelPath,
                   const std::vector<std::string>& inputNames,
                   const std::vector<armnn::TensorShape>& inputShapes,
                   const std::vector<std::string>& outputNames,
                   const std::string& outputPath,
                   bool isModelBinary)
    : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
      m_ModelPath(modelPath),
      m_InputNames(inputNames),
      m_InputShapes(inputShapes),
      m_OutputNames(outputNames),
      m_OutputPath(outputPath),
      m_IsModelBinary(isModelBinary) {}

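    // Writes the parsed network to m_OutputPath using the ArmNN serializer.
    // Returns false if no network has been created or saving to the stream fails.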
    bool Serialize()
    {
        if (m_NetworkPtr.get() == nullptr)
        {
            return false;
        }

        auto serializer(armnnSerializer::ISerializer::Create());

        serializer->Serialize(*m_NetworkPtr);

        std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);

        bool retVal = serializer->SaveSerializedToStream(file);

        return retVal;
    }

    template <typename IParser>
    bool CreateNetwork ()
    {
        return CreateNetwork (ParserType<IParser>());
    }

private:
    armnn::INetworkPtr m_NetworkPtr;
    std::string m_ModelPath;
    std::vector<std::string> m_InputNames;
    std::vector<armnn::TensorShape> m_InputShapes;
    std::vector<std::string> m_OutputNames;
    std::string m_OutputPath;
    bool m_IsModelBinary;

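    // Generic overload for parsers that accept explicit input shapes and output names
    // (e.g. Caffe and TensorFlow); TfLite and ONNX, which only take the model path,
    // are handled by the specialised overloads below.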
    template <typename IParser>
    bool CreateNetwork (ParserType<IParser>)
    {
        // Create a network from a file on disk
        auto parser(IParser::Create());

        std::map<std::string, armnn::TensorShape> inputShapes;
        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }

            for (size_t i = 0; i < numInputShapes; i++)
            {
                inputShapes[m_InputNames[i]] = m_InputShapes[i];
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
        }

        return m_NetworkPtr.get() != nullptr;
    }

#if defined(ARMNN_TF_LITE_PARSER)
    bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
    {
        // Create a network from a file on disk
        auto parser(armnnTfLiteParser::ITfLiteParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif

#if defined(ARMNN_ONNX_PARSER)
    bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
    {
        // Create a network from a file on disk
        auto parser(armnnOnnxParser::IOnnxParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif

};

} // anonymous namespace

int main(int argc, const char* argv[])
{

#if (!defined(ARMNN_CAFFE_PARSER) \
     && !defined(ARMNN_ONNX_PARSER) \
     && !defined(ARMNN_TF_PARSER) \
     && !defined(ARMNN_TF_LITE_PARSER))
    ARMNN_LOG(fatal) << "Not built with any of the supported parsers, Caffe, Onnx, Tensorflow, or TfLite.";
    return EXIT_FAILURE;
#endif

#if !defined(ARMNN_SERIALIZER)
    ARMNN_LOG(fatal) << "Not built with Serializer support.";
    return EXIT_FAILURE;
#endif

#ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
#else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
#endif

    armnn::ConfigureLogging(true, true, level);

    std::string modelFormat;
    std::string modelPath;

    std::vector<std::string> inputNames;
    std::vector<std::string> inputTensorShapeStrs;
    std::vector<armnn::TensorShape> inputTensorShapes;

    std::vector<std::string> outputNames;
    std::string outputPath;

    bool isModelBinary = true;

    if (ParseCommandLineArgs(
        argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
        != EXIT_SUCCESS)
    {
        return EXIT_FAILURE;
    }

    for (const std::string& shapeStr : inputTensorShapeStrs)
    {
        if (!shapeStr.empty())
        {
            std::stringstream ss(shapeStr);

            try
            {
                armnn::TensorShape shape = ParseTensorShape(ss);
                inputTensorShapes.push_back(shape);
            }
            catch (const armnn::InvalidArgumentException& e)
            {
                ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
                return EXIT_FAILURE;
            }
        }
    }

    ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);

    if (modelFormat.find("caffe") != std::string::npos)
    {
#if defined(ARMNN_CAFFE_PARSER)
        if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>())
        {
            ARMNN_LOG(fatal) << "Failed to load model from file";
            return EXIT_FAILURE;
        }
#else
        ARMNN_LOG(fatal) << "Not built with Caffe parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("onnx") != std::string::npos)
    {
#if defined(ARMNN_ONNX_PARSER)
        if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
        {
            ARMNN_LOG(fatal) << "Failed to load model from file";
            return EXIT_FAILURE;
        }
#else
        ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("tensorflow") != std::string::npos)
    {
#if defined(ARMNN_TF_PARSER)
        if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
        {
            ARMNN_LOG(fatal) << "Failed to load model from file";
            return EXIT_FAILURE;
        }
#else
        ARMNN_LOG(fatal) << "Not built with Tensorflow parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("tflite") != std::string::npos)
    {
#if defined(ARMNN_TF_LITE_PARSER)
        if (!isModelBinary)
        {
            ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat
                             << "'. Only 'binary' format supported for tflite files";
469 return EXIT_FAILURE;
470 }
471
472 if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
473 {
Derek Lamberti08446972019-11-26 16:38:31 +0000474 ARMNN_LOG(fatal) << "Failed to load model from file";
Sadik Armagan232cfc22019-03-13 18:33:10 +0000475 return EXIT_FAILURE;
476 }
477#else
Derek Lamberti08446972019-11-26 16:38:31 +0000478 ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
Sadik Armagan232cfc22019-03-13 18:33:10 +0000479 return EXIT_FAILURE;
480#endif
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000481 }
482 else
483 {
Derek Lamberti08446972019-11-26 16:38:31 +0000484 ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000485 return EXIT_FAILURE;
486 }
487
488 if (!converter.Serialize())
489 {
Derek Lamberti08446972019-11-26 16:38:31 +0000490 ARMNN_LOG(fatal) << "Failed to serialize model";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000491 return EXIT_FAILURE;
492 }
493
494 return EXIT_SUCCESS;
495}