blob: f143e6989bf92a964b0b8ea9f95718d132e6f22f [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
Matthew Benthamf48afc62020-01-15 17:55:08 +00005#include <armnn/Logging.hpp>
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +00006
Sadik Armagan232cfc22019-03-13 18:33:10 +00007#if defined(ARMNN_CAFFE_PARSER)
8#include <armnnCaffeParser/ICaffeParser.hpp>
9#endif
10#if defined(ARMNN_ONNX_PARSER)
11#include <armnnOnnxParser/IOnnxParser.hpp>
12#endif
13#if defined(ARMNN_SERIALIZER)
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000014#include <armnnSerializer/ISerializer.hpp>
Sadik Armagan232cfc22019-03-13 18:33:10 +000015#endif
16#if defined(ARMNN_TF_PARSER)
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000017#include <armnnTfParser/ITfParser.hpp>
Sadik Armagan232cfc22019-03-13 18:33:10 +000018#endif
19#if defined(ARMNN_TF_LITE_PARSER)
20#include <armnnTfLiteParser/ITfLiteParser.hpp>
21#endif
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000022
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000023#include <HeapProfiling.hpp>
David Monahana8837bf2020-04-16 10:01:56 +010024#include "armnn/utility/StringUtils.hpp"
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000025
26#include <boost/format.hpp>
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000027#include <boost/program_options.hpp>
28
Les Bell10e6be42019-03-28 12:26:46 +000029#include <cstdlib>
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000030#include <fstream>
Les Bell10e6be42019-03-28 12:26:46 +000031#include <iostream>
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000032
33namespace
34{
35
36namespace po = boost::program_options;
37
38armnn::TensorShape ParseTensorShape(std::istream& stream)
39{
40 std::vector<unsigned int> result;
41 std::string line;
42
43 while (std::getline(stream, line))
44 {
David Monahana8837bf2020-04-16 10:01:56 +010045 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, ",");
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000046 for (const std::string& token : tokens)
47 {
48 if (!token.empty())
49 {
50 try
51 {
52 result.push_back(boost::numeric_cast<unsigned int>(std::stoi((token))));
53 }
54 catch (const std::exception&)
55 {
Derek Lamberti08446972019-11-26 16:38:31 +000056 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000057 }
58 }
59 }
60 }
61
62 return armnn::TensorShape(boost::numeric_cast<unsigned int>(result.size()), result.data());
63}
64
65bool CheckOption(const po::variables_map& vm,
66 const char* option)
67{
68 if (option == nullptr)
69 {
70 return false;
71 }
72
73 // Check whether 'option' is provided.
74 return vm.find(option) != vm.end();
75}
76
77void CheckOptionDependency(const po::variables_map& vm,
78 const char* option,
79 const char* required)
80{
81 if (option == nullptr || required == nullptr)
82 {
83 throw po::error("Invalid option to check dependency for");
84 }
85
86 // Check that if 'option' is provided, 'required' is also provided.
87 if (CheckOption(vm, option) && !vm[option].defaulted())
88 {
89 if (CheckOption(vm, required) == 0 || vm[required].defaulted())
90 {
91 throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
92 }
93 }
94}
95
96void CheckOptionDependencies(const po::variables_map& vm)
97{
98 CheckOptionDependency(vm, "model-path", "model-format");
99 CheckOptionDependency(vm, "model-path", "input-name");
100 CheckOptionDependency(vm, "model-path", "output-name");
101 CheckOptionDependency(vm, "input-tensor-shape", "model-path");
102}
103
104int ParseCommandLineArgs(int argc, const char* argv[],
105 std::string& modelFormat,
106 std::string& modelPath,
107 std::vector<std::string>& inputNames,
108 std::vector<std::string>& inputTensorShapeStrs,
109 std::vector<std::string>& outputNames,
110 std::string& outputPath, bool& isModelBinary)
111{
112 po::options_description desc("Options");
113
114 desc.add_options()
115 ("help", "Display usage information")
Sadik Armagan232cfc22019-03-13 18:33:10 +0000116 ("model-format,f", po::value(&modelFormat)->required(),"Format of the model file"
117#if defined(ARMNN_CAFFE_PARSER)
118 ", caffe-binary, caffe-text"
119#endif
120#if defined(ARMNN_ONNX_PARSER)
121 ", onnx-binary, onnx-text"
122#endif
Les Bell10e6be42019-03-28 12:26:46 +0000123#if defined(ARMNN_TF_PARSER)
Sadik Armagan232cfc22019-03-13 18:33:10 +0000124 ", tensorflow-binary, tensorflow-text"
125#endif
126#if defined(ARMNN_TF_LITE_PARSER)
127 ", tflite-binary"
128#endif
129 ".")
Les Bell10e6be42019-03-28 12:26:46 +0000130 ("model-path,m", po::value(&modelPath)->required(), "Path to model file.")
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000131 ("input-name,i", po::value<std::vector<std::string>>()->multitoken(),
Les Bell10e6be42019-03-28 12:26:46 +0000132 "Identifier of the input tensors in the network, separated by whitespace.")
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000133 ("input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
Les Bell10e6be42019-03-28 12:26:46 +0000134 "The shape of the input tensor in the network as a flat array of integers, separated by comma."
135 " Multiple shapes are separated by whitespace."
136 " This parameter is optional, depending on the network.")
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000137 ("output-name,o", po::value<std::vector<std::string>>()->multitoken(),
138 "Identifier of the output tensor in the network.")
139 ("output-path,p", po::value(&outputPath)->required(), "Path to serialize the network to.");
140
141 po::variables_map vm;
142 try
143 {
144 po::store(po::parse_command_line(argc, argv, desc), vm);
145
146 if (CheckOption(vm, "help") || argc <= 1)
147 {
Les Bell10e6be42019-03-28 12:26:46 +0000148 std::cout << "Convert a neural network model from provided file to ArmNN format." << std::endl;
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000149 std::cout << std::endl;
150 std::cout << desc << std::endl;
Les Bell10e6be42019-03-28 12:26:46 +0000151 exit(EXIT_SUCCESS);
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000152 }
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000153 po::notify(vm);
154 }
155 catch (const po::error& e)
156 {
157 std::cerr << e.what() << std::endl << std::endl;
158 std::cerr << desc << std::endl;
159 return EXIT_FAILURE;
160 }
161
162 try
163 {
164 CheckOptionDependencies(vm);
165 }
166 catch (const po::error& e)
167 {
168 std::cerr << e.what() << std::endl << std::endl;
169 std::cerr << desc << std::endl;
170 return EXIT_FAILURE;
171 }
172
173 if (modelFormat.find("bin") != std::string::npos)
174 {
175 isModelBinary = true;
176 }
Sadik Armagan232cfc22019-03-13 18:33:10 +0000177 else if (modelFormat.find("text") != std::string::npos)
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000178 {
179 isModelBinary = false;
180 }
181 else
182 {
Derek Lamberti08446972019-11-26 16:38:31 +0000183 ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000184 return EXIT_FAILURE;
185 }
186
Matthew Benthamc01b3912019-04-26 16:57:29 +0100187 if (!vm["input-tensor-shape"].empty())
188 {
189 inputTensorShapeStrs = vm["input-tensor-shape"].as<std::vector<std::string>>();
190 }
191
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000192 inputNames = vm["input-name"].as<std::vector<std::string>>();
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000193 outputNames = vm["output-name"].as<std::vector<std::string>>();
194
195 return EXIT_SUCCESS;
196}
197
/// Empty tag type carrying a parser interface type.
/// Used to select a CreateNetwork overload for a given parser.
template<typename T>
struct ParserType
{
    using parserType = T;
};
203
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000204class ArmnnConverter
205{
206public:
207 ArmnnConverter(const std::string& modelPath,
208 const std::vector<std::string>& inputNames,
209 const std::vector<armnn::TensorShape>& inputShapes,
210 const std::vector<std::string>& outputNames,
211 const std::string& outputPath,
212 bool isModelBinary)
213 : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
214 m_ModelPath(modelPath),
215 m_InputNames(inputNames),
216 m_InputShapes(inputShapes),
217 m_OutputNames(outputNames),
218 m_OutputPath(outputPath),
219 m_IsModelBinary(isModelBinary) {}
220
221 bool Serialize()
222 {
223 if (m_NetworkPtr.get() == nullptr)
224 {
225 return false;
226 }
227
228 auto serializer(armnnSerializer::ISerializer::Create());
229
230 serializer->Serialize(*m_NetworkPtr);
231
232 std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
233
234 bool retVal = serializer->SaveSerializedToStream(file);
235
236 return retVal;
237 }
238
239 template <typename IParser>
240 bool CreateNetwork ()
241 {
Sadik Armagan232cfc22019-03-13 18:33:10 +0000242 return CreateNetwork (ParserType<IParser>());
243 }
244
245private:
246 armnn::INetworkPtr m_NetworkPtr;
247 std::string m_ModelPath;
248 std::vector<std::string> m_InputNames;
249 std::vector<armnn::TensorShape> m_InputShapes;
250 std::vector<std::string> m_OutputNames;
251 std::string m_OutputPath;
252 bool m_IsModelBinary;
253
254 template <typename IParser>
255 bool CreateNetwork (ParserType<IParser>)
256 {
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000257 // Create a network from a file on disk
258 auto parser(IParser::Create());
259
260 std::map<std::string, armnn::TensorShape> inputShapes;
261 if (!m_InputShapes.empty())
262 {
263 const size_t numInputShapes = m_InputShapes.size();
264 const size_t numInputBindings = m_InputNames.size();
265 if (numInputShapes < numInputBindings)
266 {
267 throw armnn::Exception(boost::str(boost::format(
268 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
269 % numInputBindings % numInputShapes));
270 }
271
272 for (size_t i = 0; i < numInputShapes; i++)
273 {
274 inputShapes[m_InputNames[i]] = m_InputShapes[i];
275 }
276 }
277
278 {
279 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
280 m_NetworkPtr = (m_IsModelBinary ?
281 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
282 parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
283 }
284
285 return m_NetworkPtr.get() != nullptr;
286 }
287
Sadik Armagan232cfc22019-03-13 18:33:10 +0000288#if defined(ARMNN_TF_LITE_PARSER)
289 bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
290 {
291 // Create a network from a file on disk
292 auto parser(armnnTfLiteParser::ITfLiteParser::Create());
293
294 if (!m_InputShapes.empty())
295 {
296 const size_t numInputShapes = m_InputShapes.size();
297 const size_t numInputBindings = m_InputNames.size();
298 if (numInputShapes < numInputBindings)
299 {
300 throw armnn::Exception(boost::str(boost::format(
301 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
302 % numInputBindings % numInputShapes));
303 }
304 }
305
306 {
307 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
308 m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
309 }
310
311 return m_NetworkPtr.get() != nullptr;
312 }
313#endif
314
315#if defined(ARMNN_ONNX_PARSER)
316 bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
317 {
318 // Create a network from a file on disk
319 auto parser(armnnOnnxParser::IOnnxParser::Create());
320
321 if (!m_InputShapes.empty())
322 {
323 const size_t numInputShapes = m_InputShapes.size();
324 const size_t numInputBindings = m_InputNames.size();
325 if (numInputShapes < numInputBindings)
326 {
327 throw armnn::Exception(boost::str(boost::format(
328 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
329 % numInputBindings % numInputShapes));
330 }
331 }
332
333 {
334 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
335 m_NetworkPtr = (m_IsModelBinary ?
336 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
337 parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
338 }
339
340 return m_NetworkPtr.get() != nullptr;
341 }
342#endif
343
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000344};
345
346} // anonymous namespace
347
348int main(int argc, const char* argv[])
349{
350
Sadik Armagan232cfc22019-03-13 18:33:10 +0000351#if (!defined(ARMNN_CAFFE_PARSER) \
352 && !defined(ARMNN_ONNX_PARSER) \
353 && !defined(ARMNN_TF_PARSER) \
354 && !defined(ARMNN_TF_LITE_PARSER))
Derek Lamberti08446972019-11-26 16:38:31 +0000355 ARMNN_LOG(fatal) << "Not built with any of the supported parsers, Caffe, Onnx, Tensorflow, or TfLite.";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000356 return EXIT_FAILURE;
357#endif
358
359#if !defined(ARMNN_SERIALIZER)
Derek Lamberti08446972019-11-26 16:38:31 +0000360 ARMNN_LOG(fatal) << "Not built with Serializer support.";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000361 return EXIT_FAILURE;
362#endif
363
364#ifdef NDEBUG
365 armnn::LogSeverity level = armnn::LogSeverity::Info;
366#else
367 armnn::LogSeverity level = armnn::LogSeverity::Debug;
368#endif
369
370 armnn::ConfigureLogging(true, true, level);
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000371
372 std::string modelFormat;
373 std::string modelPath;
374
375 std::vector<std::string> inputNames;
376 std::vector<std::string> inputTensorShapeStrs;
377 std::vector<armnn::TensorShape> inputTensorShapes;
378
379 std::vector<std::string> outputNames;
380 std::string outputPath;
381
382 bool isModelBinary = true;
383
384 if (ParseCommandLineArgs(
385 argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
386 != EXIT_SUCCESS)
387 {
388 return EXIT_FAILURE;
389 }
390
391 for (const std::string& shapeStr : inputTensorShapeStrs)
392 {
393 if (!shapeStr.empty())
394 {
395 std::stringstream ss(shapeStr);
396
397 try
398 {
399 armnn::TensorShape shape = ParseTensorShape(ss);
400 inputTensorShapes.push_back(shape);
401 }
402 catch (const armnn::InvalidArgumentException& e)
403 {
Derek Lamberti08446972019-11-26 16:38:31 +0000404 ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000405 return EXIT_FAILURE;
406 }
407 }
408 }
409
410 ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
411
Derek Lambertic9e52792020-03-11 11:42:26 +0000412 try
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000413 {
Derek Lambertic9e52792020-03-11 11:42:26 +0000414 if (modelFormat.find("caffe") != std::string::npos)
415 {
Sadik Armagan232cfc22019-03-13 18:33:10 +0000416#if defined(ARMNN_CAFFE_PARSER)
Derek Lambertic9e52792020-03-11 11:42:26 +0000417 if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>())
418 {
419 ARMNN_LOG(fatal) << "Failed to load model from file";
420 return EXIT_FAILURE;
421 }
Sadik Armagan232cfc22019-03-13 18:33:10 +0000422#else
Derek Lambertic9e52792020-03-11 11:42:26 +0000423 ARMNN_LOG(fatal) << "Not built with Caffe parser support.";
424 return EXIT_FAILURE;
Sadik Armagan232cfc22019-03-13 18:33:10 +0000425#endif
Derek Lambertic9e52792020-03-11 11:42:26 +0000426 }
427 else if (modelFormat.find("onnx") != std::string::npos)
428 {
Sadik Armagan232cfc22019-03-13 18:33:10 +0000429#if defined(ARMNN_ONNX_PARSER)
Derek Lambertic9e52792020-03-11 11:42:26 +0000430 if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
431 {
432 ARMNN_LOG(fatal) << "Failed to load model from file";
433 return EXIT_FAILURE;
434 }
Sadik Armagan232cfc22019-03-13 18:33:10 +0000435#else
Derek Lambertic9e52792020-03-11 11:42:26 +0000436 ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
437 return EXIT_FAILURE;
Sadik Armagan232cfc22019-03-13 18:33:10 +0000438#endif
Derek Lambertic9e52792020-03-11 11:42:26 +0000439 }
440 else if (modelFormat.find("tensorflow") != std::string::npos)
441 {
Sadik Armagan232cfc22019-03-13 18:33:10 +0000442#if defined(ARMNN_TF_PARSER)
Derek Lambertic9e52792020-03-11 11:42:26 +0000443 if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
444 {
445 ARMNN_LOG(fatal) << "Failed to load model from file";
446 return EXIT_FAILURE;
447 }
Sadik Armagan232cfc22019-03-13 18:33:10 +0000448#else
Derek Lambertic9e52792020-03-11 11:42:26 +0000449 ARMNN_LOG(fatal) << "Not built with Tensorflow parser support.";
450 return EXIT_FAILURE;
Sadik Armagan232cfc22019-03-13 18:33:10 +0000451#endif
Derek Lambertic9e52792020-03-11 11:42:26 +0000452 }
453 else if (modelFormat.find("tflite") != std::string::npos)
454 {
Sadik Armagan232cfc22019-03-13 18:33:10 +0000455#if defined(ARMNN_TF_LITE_PARSER)
Derek Lambertic9e52792020-03-11 11:42:26 +0000456 if (!isModelBinary)
457 {
458 ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
459 for tflite files";
460 return EXIT_FAILURE;
461 }
Sadik Armagan232cfc22019-03-13 18:33:10 +0000462
Derek Lambertic9e52792020-03-11 11:42:26 +0000463 if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
464 {
465 ARMNN_LOG(fatal) << "Failed to load model from file";
466 return EXIT_FAILURE;
467 }
468#else
469 ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
470 return EXIT_FAILURE;
471#endif
472 }
473 else
Sadik Armagan232cfc22019-03-13 18:33:10 +0000474 {
Derek Lambertic9e52792020-03-11 11:42:26 +0000475 ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
Sadik Armagan232cfc22019-03-13 18:33:10 +0000476 return EXIT_FAILURE;
477 }
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000478 }
Derek Lambertic9e52792020-03-11 11:42:26 +0000479 catch(armnn::Exception& e)
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000480 {
Derek Lambertic9e52792020-03-11 11:42:26 +0000481 ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000482 return EXIT_FAILURE;
483 }
484
485 if (!converter.Serialize())
486 {
Derek Lamberti08446972019-11-26 16:38:31 +0000487 ARMNN_LOG(fatal) << "Failed to serialize model";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000488 return EXIT_FAILURE;
489 }
490
491 return EXIT_SUCCESS;
492}