//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
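
// ArmnnConverter: command-line tool that loads a Caffe, ONNX, TensorFlow or TfLite model
// and serializes it to the ArmNN format.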
#include <armnn/ArmNN.hpp>

#if defined(ARMNN_CAFFE_PARSER)
#include <armnnCaffeParser/ICaffeParser.hpp>
#endif
#if defined(ARMNN_ONNX_PARSER)
#include <armnnOnnxParser/IOnnxParser.hpp>
#endif
#if defined(ARMNN_SERIALIZER)
#include <armnnSerializer/ISerializer.hpp>
#endif
#if defined(ARMNN_TF_PARSER)
#include <armnnTfParser/ITfParser.hpp>
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include <armnnTfLiteParser/ITfLiteParser.hpp>
#endif

#include <Logging.hpp>
#include <HeapProfiling.hpp>

#include <boost/format.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/classification.hpp>
#include <boost/program_options.hpp>

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

namespace
{

namespace po = boost::program_options;

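// Parses a tensor shape from a stream: each line is a comma-separated list of dimensions,
// e.g. "1,224,224,3". Tokens that are not valid numbers are logged and ignored.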
armnn::TensorShape ParseTensorShape(std::istream& stream)
{
    std::vector<unsigned int> result;
    std::string line;

    while (std::getline(stream, line))
    {
        std::vector<std::string> tokens;
        try
        {
            // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
            boost::split(tokens, line, boost::algorithm::is_any_of(","), boost::token_compress_on);
        }
        catch (const std::exception& e)
        {
            BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
            continue;
        }
        for (const std::string& token : tokens)
        {
            if (!token.empty())
            {
                try
                {
                    result.push_back(boost::numeric_cast<unsigned int>(std::stoi(token)));
                }
                catch (const std::exception&)
                {
                    BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
                }
            }
        }
    }

    return armnn::TensorShape(boost::numeric_cast<unsigned int>(result.size()), result.data());
}

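// Returns true if 'option' names a key present in the parsed variables map.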
bool CheckOption(const po::variables_map& vm,
                 const char* option)
{
    if (option == nullptr)
    {
        return false;
    }

    // Check whether 'option' is provided.
    return vm.find(option) != vm.end();
}

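// Throws po::error if 'option' was supplied on the command line but its required
// companion option 'required' was not.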
void CheckOptionDependency(const po::variables_map& vm,
                           const char* option,
                           const char* required)
{
    if (option == nullptr || required == nullptr)
    {
        throw po::error("Invalid option to check dependency for");
    }

    // Check that if 'option' is provided, 'required' is also provided.
    if (CheckOption(vm, option) && !vm[option].defaulted())
    {
        if (!CheckOption(vm, required) || vm[required].defaulted())
        {
            throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
        }
    }
}

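// Enforces the inter-option dependencies: a model path needs a format and input/output names,
// and input tensor shapes are only meaningful when a model path is given.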
void CheckOptionDependencies(const po::variables_map& vm)
{
    CheckOptionDependency(vm, "model-path", "model-format");
    CheckOptionDependency(vm, "model-path", "input-name");
    CheckOptionDependency(vm, "model-path", "output-name");
    CheckOptionDependency(vm, "input-tensor-shape", "model-path");
}

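// Parses the command-line arguments into the output parameters. Prints the usage text and exits
// when --help is given or no arguments are passed; returns EXIT_FAILURE on any parse or
// dependency error, EXIT_SUCCESS otherwise.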
int ParseCommandLineArgs(int argc, const char* argv[],
                         std::string& modelFormat,
                         std::string& modelPath,
                         std::vector<std::string>& inputNames,
                         std::vector<std::string>& inputTensorShapeStrs,
                         std::vector<std::string>& outputNames,
                         std::string& outputPath, bool& isModelBinary)
{
    po::options_description desc("Options");

    desc.add_options()
        ("help", "Display usage information")
        ("model-format,f", po::value(&modelFormat)->required(), "Format of the model file"
#if defined(ARMNN_CAFFE_PARSER)
         ", caffe-binary, caffe-text"
#endif
#if defined(ARMNN_ONNX_PARSER)
         ", onnx-binary, onnx-text"
#endif
#if defined(ARMNN_TF_PARSER)
         ", tensorflow-binary, tensorflow-text"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
         ", tflite-binary"
#endif
         ".")
        ("model-path,m", po::value(&modelPath)->required(), "Path to model file.")
        ("input-name,i", po::value<std::vector<std::string>>()->multitoken(),
         "Identifiers of the input tensors in the network, separated by whitespace.")
        ("input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
         "The shape of an input tensor in the network, as a comma-separated list of integers."
         " Multiple shapes are separated by whitespace."
         " This parameter is optional, depending on the network.")
        ("output-name,o", po::value<std::vector<std::string>>()->multitoken(),
         "Identifiers of the output tensors in the network, separated by whitespace.")
        ("output-path,p", po::value(&outputPath)->required(), "Path to serialize the network to.");

    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(argc, argv, desc), vm);

        if (CheckOption(vm, "help") || argc <= 1)
        {
            std::cout << "Convert a neural network model from the provided file to ArmNN format." << std::endl;
            std::cout << std::endl;
            std::cout << desc << std::endl;
            exit(EXIT_SUCCESS);
        }
        po::notify(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    try
    {
        CheckOptionDependencies(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    if (modelFormat.find("bin") != std::string::npos)
    {
        isModelBinary = true;
    }
    else if (modelFormat.find("text") != std::string::npos)
    {
        isModelBinary = false;
    }
    else
    {
        BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'.";
        return EXIT_FAILURE;
    }

    if (!vm["input-tensor-shape"].empty())
    {
        inputTensorShapeStrs = vm["input-tensor-shape"].as<std::vector<std::string>>();
    }

    inputNames = vm["input-name"].as<std::vector<std::string>>();
    outputNames = vm["output-name"].as<std::vector<std::string>>();

    return EXIT_SUCCESS;
}

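// Tag type used to dispatch CreateNetwork() to a parser-specific overload at compile time.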
template<typename T>
struct ParserType
{
    typedef T parserType;
};

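// Loads a model with the selected parser, holds the resulting armnn::INetwork and
// serializes it to the output path in the ArmNN format.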
class ArmnnConverter
{
public:
    ArmnnConverter(const std::string& modelPath,
                   const std::vector<std::string>& inputNames,
                   const std::vector<armnn::TensorShape>& inputShapes,
                   const std::vector<std::string>& outputNames,
                   const std::string& outputPath,
                   bool isModelBinary)
    : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork*){})),
      m_ModelPath(modelPath),
      m_InputNames(inputNames),
      m_InputShapes(inputShapes),
      m_OutputNames(outputNames),
      m_OutputPath(outputPath),
      m_IsModelBinary(isModelBinary) {}

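    // Serializes the parsed network to m_OutputPath using the ArmNN serializer.
    // Returns false if no network has been created or the serialized data could not be written.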
    bool Serialize()
    {
        if (m_NetworkPtr.get() == nullptr)
        {
            return false;
        }

        auto serializer(armnnSerializer::ISerializer::Create());

        serializer->Serialize(*m_NetworkPtr);

        std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);

        bool retVal = serializer->SaveSerializedToStream(file);

        return retVal;
    }

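    // Creates the network with the parser chosen at the call site; dispatches via the
    // ParserType tag to the generic overload below, or to a parser-specific one.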
    template <typename IParser>
    bool CreateNetwork()
    {
        return CreateNetwork(ParserType<IParser>());
    }

private:
    armnn::INetworkPtr m_NetworkPtr;
    std::string m_ModelPath;
    std::vector<std::string> m_InputNames;
    std::vector<armnn::TensorShape> m_InputShapes;
    std::vector<std::string> m_OutputNames;
    std::string m_OutputPath;
    bool m_IsModelBinary;

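    // Generic case: the parser is given explicit input tensor shapes (keyed by input name)
    // and the requested output names.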
    template <typename IParser>
    bool CreateNetwork(ParserType<IParser>)
    {
        // Create a network from a file on disk
        auto parser(IParser::Create());

        std::map<std::string, armnn::TensorShape> inputShapes;
        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }

            for (size_t i = 0; i < numInputShapes; i++)
            {
                inputShapes[m_InputNames[i]] = m_InputShapes[i];
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
        }

        return m_NetworkPtr.get() != nullptr;
    }

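    // TfLite overload: the parser reads a flatbuffer binary and takes only the model path;
    // the supplied shapes are checked for count but not passed to the parser.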
#if defined(ARMNN_TF_LITE_PARSER)
    bool CreateNetwork(ParserType<armnnTfLiteParser::ITfLiteParser>)
    {
        // Create a network from a file on disk
        auto parser(armnnTfLiteParser::ITfLiteParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif

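    // ONNX overload: the parser takes only the model path (binary or text); as with TfLite,
    // the supplied shapes are checked for count but not passed through.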
#if defined(ARMNN_ONNX_PARSER)
    bool CreateNetwork(ParserType<armnnOnnxParser::IOnnxParser>)
    {
        // Create a network from a file on disk
        auto parser(armnnOnnxParser::IOnnxParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif
};

} // anonymous namespace

int main(int argc, const char* argv[])
{

#if (!defined(ARMNN_CAFFE_PARSER) \
     && !defined(ARMNN_ONNX_PARSER) \
     && !defined(ARMNN_TF_PARSER) \
     && !defined(ARMNN_TF_LITE_PARSER))
    BOOST_LOG_TRIVIAL(fatal) << "Not built with any of the supported parsers (Caffe, ONNX, TensorFlow or TfLite).";
    return EXIT_FAILURE;
#endif

#if !defined(ARMNN_SERIALIZER)
    BOOST_LOG_TRIVIAL(fatal) << "Not built with Serializer support.";
    return EXIT_FAILURE;
#endif

#ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
#else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
#endif

    armnn::ConfigureLogging(true, true, level);
    armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);

    std::string modelFormat;
    std::string modelPath;

    std::vector<std::string> inputNames;
    std::vector<std::string> inputTensorShapeStrs;
    std::vector<armnn::TensorShape> inputTensorShapes;

    std::vector<std::string> outputNames;
    std::string outputPath;

    bool isModelBinary = true;

    if (ParseCommandLineArgs(
        argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
        != EXIT_SUCCESS)
    {
        return EXIT_FAILURE;
    }

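    // Each shape string is a comma-separated list of dimensions; convert it to an armnn::TensorShape.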
    for (const std::string& shapeStr : inputTensorShapeStrs)
    {
        if (!shapeStr.empty())
        {
            std::stringstream ss(shapeStr);

            try
            {
                armnn::TensorShape shape = ParseTensorShape(ss);
                inputTensorShapes.push_back(shape);
            }
            catch (const armnn::InvalidArgumentException& e)
            {
                BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what();
                return EXIT_FAILURE;
            }
        }
    }

    ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);

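    // Pick the parser from the model format string. Each branch is only available when the
    // corresponding parser was compiled in.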
    if (modelFormat.find("caffe") != std::string::npos)
    {
#if defined(ARMNN_CAFFE_PARSER)
        if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>())
        {
            BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
            return EXIT_FAILURE;
        }
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("onnx") != std::string::npos)
    {
#if defined(ARMNN_ONNX_PARSER)
        if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
        {
            BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
            return EXIT_FAILURE;
        }
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with ONNX parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("tensorflow") != std::string::npos)
    {
#if defined(ARMNN_TF_PARSER)
        if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
        {
            BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
            return EXIT_FAILURE;
        }
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with TensorFlow parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("tflite") != std::string::npos)
    {
#if defined(ARMNN_TF_LITE_PARSER)
        if (!isModelBinary)
        {
            BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat
                                     << "'. Only 'binary' format is supported for tflite files";
            return EXIT_FAILURE;
        }

        if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
        {
            BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
            return EXIT_FAILURE;
        }
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with TfLite parser support.";
        return EXIT_FAILURE;
#endif
    }
    else
    {
        BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'";
        return EXIT_FAILURE;
    }

    if (!converter.Serialize())
    {
        BOOST_LOG_TRIVIAL(fatal) << "Failed to serialize model";
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}