//
// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <armnn/Logging.hpp>

#if defined(ARMNN_ONNX_PARSER)
#include <armnnOnnxParser/IOnnxParser.hpp>
#endif
#if defined(ARMNN_SERIALIZER)
#include <armnnSerializer/ISerializer.hpp>
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include <armnnTfLiteParser/ITfLiteParser.hpp>
#endif

#include <HeapProfiling.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/utility/StringUtils.hpp>

/*
 * Historically we have used the ',' character to separate dimensions in a tensor shape. However, cxxopts will read
 * this as an array of values, which is fine until we have multiple tensors specified. This lumps the values of all
 * shapes together in a single array and we cannot break it up again. We'll change the vector delimiter to a '.'.
 * We do this as close as possible to the usage of cxxopts to avoid polluting other possible uses.
 */
#define CXXOPTS_VECTOR_DELIMITER '.'
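// With '.' as the cxxopts vector delimiter, a comma-separated shape such as "1,3,224,224" (values are
// illustrative) passed to --input-tensor-shape reaches us as a single string and is split on ',' later
// by ParseTensorShape.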
#include <cxxopts/cxxopts.hpp>

#include <fmt/format.h>

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>

namespace
{

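// Reads comma-separated dimension values from the stream and builds an armnn::TensorShape from them.
// For example, a stream containing "1,3,224,224" (illustrative values) yields a four-dimensional shape.
// Tokens that are not valid numbers are logged and ignored.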
armnn::TensorShape ParseTensorShape(std::istream& stream)
{
    std::vector<unsigned int> result;
    std::string line;

    while (std::getline(stream, line))
    {
        std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, ",");
        for (const std::string& token : tokens)
        {
            if (!token.empty())
            {
                try
                {
                    result.push_back(armnn::numeric_cast<unsigned int>(std::stoi(token)));
                }
                catch (const std::exception&)
                {
                    ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
                }
            }
        }
    }

    return armnn::TensorShape(armnn::numeric_cast<unsigned int>(result.size()), result.data());
}

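// Parses the command-line arguments into the output parameters and validates that the mandatory options are
// present. Returns EXIT_SUCCESS when parsing succeeds, EXIT_FAILURE otherwise.
// Example invocation (file paths and tensor names are illustrative):
//   ArmnnConverter -f tflite-binary -m model.tflite -i input -o output -p model.armnn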
int ParseCommandLineArgs(int argc, char* argv[],
                         std::string& modelFormat,
                         std::string& modelPath,
                         std::vector<std::string>& inputNames,
                         std::vector<std::string>& inputTensorShapeStrs,
                         std::vector<std::string>& outputNames,
                         std::string& outputPath, bool& isModelBinary)
{
    cxxopts::Options options("ArmnnConverter",
                             "Convert a neural network model from a provided file to ArmNN format.");
    try
    {
        std::string modelFormatDescription("Format of the model file");
#if defined(ARMNN_ONNX_PARSER)
        modelFormatDescription += ", onnx-binary, onnx-text";
#endif
#if defined(ARMNN_TF_PARSER)
        modelFormatDescription += ", tensorflow-binary, tensorflow-text";
#endif
#if defined(ARMNN_TF_LITE_PARSER)
        modelFormatDescription += ", tflite-binary";
#endif
        modelFormatDescription += ".";
        options.add_options()
            ("help", "Display usage information")
            ("f,model-format", modelFormatDescription, cxxopts::value<std::string>(modelFormat))
            ("m,model-path", "Path to model file.", cxxopts::value<std::string>(modelPath))

            ("i,input-name", "Identifier of the input tensors in the network. "
                             "Each input must be specified separately.",
             cxxopts::value<std::vector<std::string>>(inputNames))
            ("s,input-tensor-shape",
             "The shape of the input tensor in the network as a flat array of integers, "
             "separated by commas. Each input shape must be specified separately after the input name. "
             "This parameter is optional, depending on the network.",
             cxxopts::value<std::vector<std::string>>(inputTensorShapeStrs))

            ("o,output-name", "Identifier of the output tensors in the network.",
             cxxopts::value<std::vector<std::string>>(outputNames))
            ("p,output-path",
             "Path to serialize the network to.", cxxopts::value<std::string>(outputPath));
    }
    catch (const std::exception& e)
    {
        std::cerr << e.what() << std::endl << options.help() << std::endl;
        return EXIT_FAILURE;
    }
    try
    {
        cxxopts::ParseResult result = options.parse(argc, argv);
        if (result.count("help"))
        {
            std::cerr << options.help() << std::endl;
            return EXIT_SUCCESS;
        }
        // Check for mandatory single options.
        std::string mandatorySingleParameters[] = { "model-format", "model-path", "output-name", "output-path" };
        bool somethingsMissing = false;
        for (auto param : mandatorySingleParameters)
        {
            if (result.count(param) != 1)
            {
                std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
                somethingsMissing = true;
            }
        }
        // Check for at least one "input-name" option.
        if (result.count("input-name") == 0)
        {
            std::cerr << "Parameter \'--" << "input-name" << "\' must be specified at least once." << std::endl;
            somethingsMissing = true;
        }
        // If input-tensor-shape is specified then there must be a 1:1 match with input-name.
        if (result.count("input-tensor-shape") > 0)
        {
            if (result.count("input-tensor-shape") != result.count("input-name"))
            {
                std::cerr << "When specifying \'input-tensor-shape\' a matching number of \'input-name\' parameters "
                             "must be specified." << std::endl;
                somethingsMissing = true;
            }
        }

        if (somethingsMissing)
        {
            std::cerr << options.help() << std::endl;
            return EXIT_FAILURE;
        }
    }
    catch (const cxxopts::exceptions::exception& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        return EXIT_FAILURE;
    }

    if (modelFormat.find("bin") != std::string::npos)
    {
        isModelBinary = true;
    }
    else if (modelFormat.find("text") != std::string::npos)
    {
        isModelBinary = false;
    }
    else
    {
        ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'.";
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

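// Tag type used to select the parser-specific CreateNetwork overload at compile time.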
template<typename T>
struct ParserType
{
    typedef T parserType;
};

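// Loads a model with the requested parser and serializes the resulting armnn::INetwork to the ArmNN file format.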
class ArmnnConverter
{
public:
    ArmnnConverter(const std::string& modelPath,
                   const std::vector<std::string>& inputNames,
                   const std::vector<armnn::TensorShape>& inputShapes,
                   const std::vector<std::string>& outputNames,
                   const std::string& outputPath,
                   bool isModelBinary)
    : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
      m_ModelPath(modelPath),
      m_InputNames(inputNames),
      m_InputShapes(inputShapes),
      m_OutputNames(outputNames),
      m_OutputPath(outputPath),
      m_IsModelBinary(isModelBinary) {}

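    // Serializes the parsed network to m_OutputPath; returns false if no network has been created or the
    // serialized stream could not be saved.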
    bool Serialize()
    {
        if (m_NetworkPtr.get() == nullptr)
        {
            return false;
        }

        auto serializer(armnnSerializer::ISerializer::Create());

        serializer->Serialize(*m_NetworkPtr);

        std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);

        bool retVal = serializer->SaveSerializedToStream(file);

        return retVal;
    }

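    // Public entry point: dispatches to the overload matching the requested parser via the ParserType tag.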
    template <typename IParser>
    bool CreateNetwork ()
    {
        return CreateNetwork (ParserType<IParser>());
    }

private:
    armnn::INetworkPtr m_NetworkPtr;
    std::string m_ModelPath;
    std::vector<std::string> m_InputNames;
    std::vector<armnn::TensorShape> m_InputShapes;
    std::vector<std::string> m_OutputNames;
    std::string m_OutputPath;
    bool m_IsModelBinary;

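    // Generic overload for parsers whose CreateNetworkFrom*File() functions take the input shapes and
    // output names explicitly.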
    template <typename IParser>
    bool CreateNetwork (ParserType<IParser>)
    {
        // Create a network from a file on disk
        auto parser(IParser::Create());

        std::map<std::string, armnn::TensorShape> inputShapes;
        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(fmt::format(
                    "Not every input has its tensor shape specified: expected={0}, got={1}",
                    numInputBindings, numInputShapes));
            }

            for (size_t i = 0; i < numInputShapes; i++)
            {
                inputShapes[m_InputNames[i]] = m_InputShapes[i];
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
        }

        return m_NetworkPtr.get() != nullptr;
    }

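    // TfLite models are FlatBuffer binaries: only binary files are parsed, and any explicitly supplied
    // input shapes are validated against the input names but not passed on to the parser.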
#if defined(ARMNN_TF_LITE_PARSER)
    bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
    {
        // Create a network from a file on disk
        auto parser(armnnTfLiteParser::ITfLiteParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(fmt::format(
                    "Not every input has its tensor shape specified: expected={0}, got={1}",
                    numInputBindings, numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif

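    // ONNX models can be supplied as binary or text files; any explicitly supplied input shapes are only
    // validated against the input names, not passed on to the parser.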
#if defined(ARMNN_ONNX_PARSER)
    bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
    {
        // Create a network from a file on disk
        auto parser(armnnOnnxParser::IOnnxParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(fmt::format(
                    "Not every input has its tensor shape specified: expected={0}, got={1}",
                    numInputBindings, numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif

};

} // anonymous namespace

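// Parses the command-line arguments, builds the network with the parser matching the requested model format,
// and serializes it to the requested output path.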
int main(int argc, char* argv[])
{

#if (!defined(ARMNN_ONNX_PARSER) \
  && !defined(ARMNN_TF_PARSER) \
  && !defined(ARMNN_TF_LITE_PARSER))
    ARMNN_LOG(fatal) << "Not built with any of the supported parsers: Onnx, Tensorflow, or TfLite.";
    return EXIT_FAILURE;
#endif

#if !defined(ARMNN_SERIALIZER)
    ARMNN_LOG(fatal) << "Not built with Serializer support.";
    return EXIT_FAILURE;
#endif

#ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
#else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
#endif

    armnn::ConfigureLogging(true, true, level);

    std::string modelFormat;
    std::string modelPath;

    std::vector<std::string> inputNames;
    std::vector<std::string> inputTensorShapeStrs;
    std::vector<armnn::TensorShape> inputTensorShapes;

    std::vector<std::string> outputNames;
    std::string outputPath;

    bool isModelBinary = true;

    if (ParseCommandLineArgs(
        argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
        != EXIT_SUCCESS)
    {
        return EXIT_FAILURE;
    }

    for (const std::string& shapeStr : inputTensorShapeStrs)
    {
        if (!shapeStr.empty())
        {
            std::stringstream ss(shapeStr);

            try
            {
                armnn::TensorShape shape = ParseTensorShape(ss);
                inputTensorShapes.push_back(shape);
            }
            catch (const armnn::InvalidArgumentException& e)
            {
                ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
                return EXIT_FAILURE;
            }
        }
    }

    ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);

    try
    {
        if (modelFormat.find("onnx") != std::string::npos)
        {
#if defined(ARMNN_ONNX_PARSER)
            if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
            return EXIT_FAILURE;
#endif
        }
        else if (modelFormat.find("tflite") != std::string::npos)
        {
#if defined(ARMNN_TF_LITE_PARSER)
            if (!isModelBinary)
            {
                ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat
                                 << "'. Only 'binary' format supported for tflite files";
                return EXIT_FAILURE;
            }

            if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
            return EXIT_FAILURE;
#endif
        }
        else
        {
            ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
            return EXIT_FAILURE;
        }
    }
    catch (const armnn::Exception& e)
    {
        ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
        return EXIT_FAILURE;
    }

    if (!converter.Serialize())
    {
        ARMNN_LOG(fatal) << "Failed to serialize model";
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}