//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <armnn/Logging.hpp>

#if defined(ARMNN_CAFFE_PARSER)
#include <armnnCaffeParser/ICaffeParser.hpp>
#endif
#if defined(ARMNN_ONNX_PARSER)
#include <armnnOnnxParser/IOnnxParser.hpp>
#endif
#if defined(ARMNN_SERIALIZER)
#include <armnnSerializer/ISerializer.hpp>
#endif
#if defined(ARMNN_TF_PARSER)
#include <armnnTfParser/ITfParser.hpp>
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include <armnnTfLiteParser/ITfLiteParser.hpp>
#endif

#include <HeapProfiling.hpp>
#include <armnn/utility/NumericCast.hpp>
#include "armnn/utility/StringUtils.hpp"

#include <boost/format.hpp>
#include <boost/program_options.hpp>

#include <cstdlib>
#include <fstream>
#include <iostream>

namespace
{

namespace po = boost::program_options;

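// Reads comma-separated dimension values from 'stream' and builds an armnn::TensorShape.
// Tokens that cannot be parsed as numbers are logged and skipped.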
armnn::TensorShape ParseTensorShape(std::istream& stream)
{
    std::vector<unsigned int> result;
    std::string line;

    while (std::getline(stream, line))
    {
        std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, ",");
        for (const std::string& token : tokens)
        {
            if (!token.empty())
            {
                try
                {
                    result.push_back(armnn::numeric_cast<unsigned int>(std::stoi((token))));
                }
                catch (const std::exception&)
                {
                    ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
                }
            }
        }
    }

    return armnn::TensorShape(armnn::numeric_cast<unsigned int>(result.size()), result.data());
}

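// Returns true if the given option was supplied on the command line.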
bool CheckOption(const po::variables_map& vm,
                 const char* option)
{
    if (option == nullptr)
    {
        return false;
    }

    // Check whether 'option' is provided.
    return vm.find(option) != vm.end();
}

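// Throws po::error if 'option' was supplied without 'required' also being supplied.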
void CheckOptionDependency(const po::variables_map& vm,
                           const char* option,
                           const char* required)
{
    if (option == nullptr || required == nullptr)
    {
        throw po::error("Invalid option to check dependency for");
    }

    // Check that if 'option' is provided, 'required' is also provided.
    if (CheckOption(vm, option) && !vm[option].defaulted())
    {
        if (!CheckOption(vm, required) || vm[required].defaulted())
        {
            throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
        }
    }
}

void CheckOptionDependencies(const po::variables_map& vm)
{
    CheckOptionDependency(vm, "model-path", "model-format");
    CheckOptionDependency(vm, "model-path", "input-name");
    CheckOptionDependency(vm, "model-path", "output-name");
    CheckOptionDependency(vm, "input-tensor-shape", "model-path");
}

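// Parses the command line and fills the given output arguments.
// Prints usage and exits for --help; otherwise returns EXIT_SUCCESS or EXIT_FAILURE.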
int ParseCommandLineArgs(int argc, const char* argv[],
                         std::string& modelFormat,
                         std::string& modelPath,
                         std::vector<std::string>& inputNames,
                         std::vector<std::string>& inputTensorShapeStrs,
                         std::vector<std::string>& outputNames,
                         std::string& outputPath, bool& isModelBinary)
{
    po::options_description desc("Options");

    desc.add_options()
        ("help", "Display usage information")
        ("model-format,f", po::value(&modelFormat)->required(), "Format of the model file"
#if defined(ARMNN_CAFFE_PARSER)
         ", caffe-binary, caffe-text"
#endif
#if defined(ARMNN_ONNX_PARSER)
         ", onnx-binary, onnx-text"
#endif
#if defined(ARMNN_TF_PARSER)
         ", tensorflow-binary, tensorflow-text"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
         ", tflite-binary"
#endif
         ".")
        ("model-path,m", po::value(&modelPath)->required(), "Path to model file.")
        ("input-name,i", po::value<std::vector<std::string>>()->multitoken(),
         "Identifier of the input tensors in the network, separated by whitespace.")
        ("input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
         "The shape of the input tensor in the network as a flat array of integers, separated by comma."
         " Multiple shapes are separated by whitespace."
         " This parameter is optional, depending on the network.")
        ("output-name,o", po::value<std::vector<std::string>>()->multitoken(),
         "Identifier of the output tensor in the network.")
        ("output-path,p", po::value(&outputPath)->required(), "Path to serialize the network to.");

    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(argc, argv, desc), vm);

        if (CheckOption(vm, "help") || argc <= 1)
        {
            std::cout << "Convert a neural network model from provided file to ArmNN format." << std::endl;
            std::cout << std::endl;
            std::cout << desc << std::endl;
            exit(EXIT_SUCCESS);
        }
        po::notify(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    try
    {
        CheckOptionDependencies(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    if (modelFormat.find("bin") != std::string::npos)
    {
        isModelBinary = true;
    }
    else if (modelFormat.find("text") != std::string::npos)
    {
        isModelBinary = false;
    }
    else
    {
        ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
        return EXIT_FAILURE;
    }

    if (!vm["input-tensor-shape"].empty())
    {
        inputTensorShapeStrs = vm["input-tensor-shape"].as<std::vector<std::string>>();
    }

    inputNames = vm["input-name"].as<std::vector<std::string>>();
    outputNames = vm["output-name"].as<std::vector<std::string>>();

    return EXIT_SUCCESS;
}

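// Tag type used to dispatch to the CreateNetwork overload for a particular parser.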
template<typename T>
struct ParserType
{
    typedef T parserType;
};

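// Parses a model file with the selected front-end parser and serializes the
// resulting armnn::INetwork to the ArmNN flatbuffer format.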
class ArmnnConverter
{
public:
    ArmnnConverter(const std::string& modelPath,
                   const std::vector<std::string>& inputNames,
                   const std::vector<armnn::TensorShape>& inputShapes,
                   const std::vector<std::string>& outputNames,
                   const std::string& outputPath,
                   bool isModelBinary)
    : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
      m_ModelPath(modelPath),
      m_InputNames(inputNames),
      m_InputShapes(inputShapes),
      m_OutputNames(outputNames),
      m_OutputPath(outputPath),
      m_IsModelBinary(isModelBinary) {}

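    // Serializes the parsed network to m_OutputPath; returns false if no network has been
    // created or if writing the serialized stream fails.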
    bool Serialize()
    {
        if (m_NetworkPtr.get() == nullptr)
        {
            return false;
        }

        auto serializer(armnnSerializer::ISerializer::Create());

        serializer->Serialize(*m_NetworkPtr);

        std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);

        bool retVal = serializer->SaveSerializedToStream(file);

        return retVal;
    }

    template <typename IParser>
    bool CreateNetwork ()
    {
        return CreateNetwork (ParserType<IParser>());
    }

private:
    armnn::INetworkPtr m_NetworkPtr;
    std::string m_ModelPath;
    std::vector<std::string> m_InputNames;
    std::vector<armnn::TensorShape> m_InputShapes;
    std::vector<std::string> m_OutputNames;
    std::string m_OutputPath;
    bool m_IsModelBinary;

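    // Generic parser path: passes the input-name-to-shape map and the requested outputs
    // to the parser. This is the overload used by the Caffe and TensorFlow parsers.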
    template <typename IParser>
    bool CreateNetwork (ParserType<IParser>)
    {
        // Create a network from a file on disk
        auto parser(IParser::Create());

        std::map<std::string, armnn::TensorShape> inputShapes;
        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes   = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }

            for (size_t i = 0; i < numInputShapes; i++)
            {
                inputShapes[m_InputNames[i]] = m_InputShapes[i];
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
        }

        return m_NetworkPtr.get() != nullptr;
    }

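    // TfLite models embed their tensor shapes, so only the file path is passed to the parser;
    // any explicit input shapes are only checked against the number of input names.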
#if defined(ARMNN_TF_LITE_PARSER)
    bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
    {
        // Create a network from a file on disk
        auto parser(armnnTfLiteParser::ITfLiteParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes   = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif

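    // ONNX models also carry their own input shapes, so only the file path is passed to the parser.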
#if defined(ARMNN_ONNX_PARSER)
    bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
    {
        // Create a network from a file on disk
        auto parser(armnnOnnxParser::IOnnxParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes   = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif

};

} // anonymous namespace

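// Example invocation (file and tensor names are hypothetical), assuming a build with the
// TfLite parser and the serializer enabled:
//   ArmnnConverter -f tflite-binary -m model.tflite -i input_tensor -o output_tensor -p model.armnn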
int main(int argc, const char* argv[])
{

#if (!defined(ARMNN_CAFFE_PARSER) \
     && !defined(ARMNN_ONNX_PARSER) \
     && !defined(ARMNN_TF_PARSER) \
     && !defined(ARMNN_TF_LITE_PARSER))
    ARMNN_LOG(fatal) << "Not built with any of the supported parsers, Caffe, Onnx, Tensorflow, or TfLite.";
    return EXIT_FAILURE;
#endif

#if !defined(ARMNN_SERIALIZER)
    ARMNN_LOG(fatal) << "Not built with Serializer support.";
    return EXIT_FAILURE;
#endif

#ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
#else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
#endif

    armnn::ConfigureLogging(true, true, level);

    std::string modelFormat;
    std::string modelPath;

    std::vector<std::string> inputNames;
    std::vector<std::string> inputTensorShapeStrs;
    std::vector<armnn::TensorShape> inputTensorShapes;

    std::vector<std::string> outputNames;
    std::string outputPath;

    bool isModelBinary = true;

    if (ParseCommandLineArgs(
        argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
        != EXIT_SUCCESS)
    {
        return EXIT_FAILURE;
    }

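    // Convert each comma-separated shape string into an armnn::TensorShape.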
    for (const std::string& shapeStr : inputTensorShapeStrs)
    {
        if (!shapeStr.empty())
        {
            std::stringstream ss(shapeStr);

            try
            {
                armnn::TensorShape shape = ParseTensorShape(ss);
                inputTensorShapes.push_back(shape);
            }
            catch (const armnn::InvalidArgumentException& e)
            {
                ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
                return EXIT_FAILURE;
            }
        }
    }

    ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);

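    // Dispatch on the requested model format to the matching parser.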
    try
    {
        if (modelFormat.find("caffe") != std::string::npos)
        {
#if defined(ARMNN_CAFFE_PARSER)
            if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with Caffe parser support.";
            return EXIT_FAILURE;
#endif
        }
        else if (modelFormat.find("onnx") != std::string::npos)
        {
#if defined(ARMNN_ONNX_PARSER)
            if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
            return EXIT_FAILURE;
#endif
        }
        else if (modelFormat.find("tensorflow") != std::string::npos)
        {
#if defined(ARMNN_TF_PARSER)
            if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with Tensorflow parser support.";
            return EXIT_FAILURE;
#endif
        }
        else if (modelFormat.find("tflite") != std::string::npos)
        {
#if defined(ARMNN_TF_LITE_PARSER)
            if (!isModelBinary)
            {
                ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat
                                 << "'. Only 'binary' format supported for tflite files";
                return EXIT_FAILURE;
            }

            if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
            return EXIT_FAILURE;
#endif
        }
        else
        {
            ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
            return EXIT_FAILURE;
        }
    }
    catch (armnn::Exception& e)
    {
        ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
        return EXIT_FAILURE;
    }

    if (!converter.Serialize())
    {
        ARMNN_LOG(fatal) << "Failed to serialize model";
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}