blob: f3beb81d8b0da224ef354c36e613578d12f093cf [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Colm Donelan46dee402024-05-10 16:49:39 +01002// Copyright © 2017, 2023-2024 Arm Ltd and Contributors. All rights reserved.
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +00003// SPDX-License-Identifier: MIT
4//
Matthew Benthamf48afc62020-01-15 17:55:08 +00005#include <armnn/Logging.hpp>
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +00006
Sadik Armagan232cfc22019-03-13 18:33:10 +00007#if defined(ARMNN_ONNX_PARSER)
8#include <armnnOnnxParser/IOnnxParser.hpp>
9#endif
10#if defined(ARMNN_SERIALIZER)
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000011#include <armnnSerializer/ISerializer.hpp>
Sadik Armagan232cfc22019-03-13 18:33:10 +000012#endif
Sadik Armagan232cfc22019-03-13 18:33:10 +000013#if defined(ARMNN_TF_LITE_PARSER)
14#include <armnnTfLiteParser/ITfLiteParser.hpp>
15#endif
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000016
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000017#include <HeapProfiling.hpp>
Matthew Sloyan0663d662020-09-14 11:47:26 +010018#include <armnn/utility/NumericCast.hpp>
Colm Donelanb524ca02020-10-06 15:15:33 +010019#include <armnn/utility/StringUtils.hpp>
20
/*
 * Historically we have used the ',' character to separate dimensions in a tensor shape. However, cxxopts will read
 * this as an array of values, which is fine until we have multiple tensors specified. This lumps the values of all
 * shapes together in a single array and we cannot break it up again. We'll change the vector delimiter to a '.'. We
 * do this as close as possible to the usage of cxxopts to avoid polluting other possible uses.
 */
27#define CXXOPTS_VECTOR_DELIMITER '.'
28#include <cxxopts/cxxopts.hpp>
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000029
Colm Donelan5b5c2222020-09-09 12:48:16 +010030#include <fmt/format.h>
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000031
Les Bell10e6be42019-03-28 12:26:46 +000032#include <cstdlib>
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000033#include <fstream>
Les Bell10e6be42019-03-28 12:26:46 +000034#include <iostream>
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000035
36namespace
37{
38
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000039armnn::TensorShape ParseTensorShape(std::istream& stream)
40{
41 std::vector<unsigned int> result;
42 std::string line;
43
44 while (std::getline(stream, line))
45 {
David Monahana8837bf2020-04-16 10:01:56 +010046 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, ",");
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000047 for (const std::string& token : tokens)
48 {
49 if (!token.empty())
50 {
51 try
52 {
Matthew Sloyan0663d662020-09-14 11:47:26 +010053 result.push_back(armnn::numeric_cast<unsigned int>(std::stoi((token))));
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000054 }
55 catch (const std::exception&)
56 {
Derek Lamberti08446972019-11-26 16:38:31 +000057 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000058 }
59 }
60 }
61 }
62
Matthew Sloyan0663d662020-09-14 11:47:26 +010063 return armnn::TensorShape(armnn::numeric_cast<unsigned int>(result.size()), result.data());
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000064}
65
Colm Donelanb524ca02020-10-06 15:15:33 +010066int ParseCommandLineArgs(int argc, char* argv[],
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000067 std::string& modelFormat,
68 std::string& modelPath,
69 std::vector<std::string>& inputNames,
70 std::vector<std::string>& inputTensorShapeStrs,
71 std::vector<std::string>& outputNames,
72 std::string& outputPath, bool& isModelBinary)
73{
Colm Donelanb524ca02020-10-06 15:15:33 +010074 cxxopts::Options options("ArmNNConverter", "Convert a neural network model from provided file to ArmNN format.");
75 try
76 {
77 std::string modelFormatDescription("Format of the model file");
Sadik Armagan232cfc22019-03-13 18:33:10 +000078#if defined(ARMNN_ONNX_PARSER)
Colm Donelanb524ca02020-10-06 15:15:33 +010079 modelFormatDescription += ", onnx-binary, onnx-text";
Sadik Armagan232cfc22019-03-13 18:33:10 +000080#endif
Les Bell10e6be42019-03-28 12:26:46 +000081#if defined(ARMNN_TF_PARSER)
Colm Donelanb524ca02020-10-06 15:15:33 +010082 modelFormatDescription += ", tensorflow-binary, tensorflow-text";
Sadik Armagan232cfc22019-03-13 18:33:10 +000083#endif
84#if defined(ARMNN_TF_LITE_PARSER)
Colm Donelanb524ca02020-10-06 15:15:33 +010085 modelFormatDescription += ", tflite-binary";
Sadik Armagan232cfc22019-03-13 18:33:10 +000086#endif
Colm Donelanb524ca02020-10-06 15:15:33 +010087 modelFormatDescription += ".";
88 options.add_options()
89 ("help", "Display usage information")
90 ("f,model-format", modelFormatDescription, cxxopts::value<std::string>(modelFormat))
91 ("m,model-path", "Path to model file.", cxxopts::value<std::string>(modelPath))
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +000092
Colm Donelanb524ca02020-10-06 15:15:33 +010093 ("i,input-name", "Identifier of the input tensors in the network. "
94 "Each input must be specified separately.",
95 cxxopts::value<std::vector<std::string>>(inputNames))
96 ("s,input-tensor-shape",
97 "The shape of the input tensor in the network as a flat array of integers, "
98 "separated by comma. Each input shape must be specified separately after the input name. "
99 "This parameter is optional, depending on the network.",
100 cxxopts::value<std::vector<std::string>>(inputTensorShapeStrs))
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000101
Colm Donelanb524ca02020-10-06 15:15:33 +0100102 ("o,output-name", "Identifier of the output tensor in the network.",
103 cxxopts::value<std::vector<std::string>>(outputNames))
104 ("p,output-path",
105 "Path to serialize the network to.", cxxopts::value<std::string>(outputPath));
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000106 }
Colm Donelanb524ca02020-10-06 15:15:33 +0100107 catch (const std::exception& e)
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000108 {
Colm Donelanb524ca02020-10-06 15:15:33 +0100109 std::cerr << e.what() << std::endl << options.help() << std::endl;
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000110 return EXIT_FAILURE;
111 }
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000112 try
113 {
Colm Donelanb524ca02020-10-06 15:15:33 +0100114 cxxopts::ParseResult result = options.parse(argc, argv);
115 if (result.count("help"))
116 {
117 std::cerr << options.help() << std::endl;
118 return EXIT_SUCCESS;
119 }
120 // Check for mandatory single options.
121 std::string mandatorySingleParameters[] = { "model-format", "model-path", "output-name", "output-path" };
122 bool somethingsMissing = false;
123 for (auto param : mandatorySingleParameters)
124 {
125 if (result.count(param) != 1)
126 {
127 std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
128 somethingsMissing = true;
129 }
130 }
131 // Check at least one "input-name" option.
132 if (result.count("input-name") == 0)
133 {
134 std::cerr << "Parameter \'--" << "input-name" << "\' must be specified at least once." << std::endl;
135 somethingsMissing = true;
136 }
137 // If input-tensor-shape is specified then there must be a 1:1 match with input-name.
138 if (result.count("input-tensor-shape") > 0)
139 {
140 if (result.count("input-tensor-shape") != result.count("input-name"))
141 {
142 std::cerr << "When specifying \'input-tensor-shape\' a matching number of \'input-name\' parameters "
143 "must be specified." << std::endl;
144 somethingsMissing = true;
145 }
146 }
147
148 if (somethingsMissing)
149 {
150 std::cerr << options.help() << std::endl;
151 return EXIT_FAILURE;
152 }
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000153 }
Jim Flynn357add22023-04-10 23:26:40 +0100154 catch (const cxxopts::exceptions::exception& e)
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000155 {
156 std::cerr << e.what() << std::endl << std::endl;
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000157 return EXIT_FAILURE;
158 }
159
160 if (modelFormat.find("bin") != std::string::npos)
161 {
162 isModelBinary = true;
163 }
Sadik Armagan232cfc22019-03-13 18:33:10 +0000164 else if (modelFormat.find("text") != std::string::npos)
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000165 {
166 isModelBinary = false;
167 }
168 else
169 {
Derek Lamberti08446972019-11-26 16:38:31 +0000170 ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000171 return EXIT_FAILURE;
172 }
173
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000174 return EXIT_SUCCESS;
175}
176
/// Empty tag type whose only purpose is to pick the right ArmnnConverter::CreateNetwork overload for a parser
/// interface at compile time; the wrapped type is exposed as @c parserType.
template <typename T>
struct ParserType
{
    using parserType = T;
};
182
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000183class ArmnnConverter
184{
185public:
186 ArmnnConverter(const std::string& modelPath,
187 const std::vector<std::string>& inputNames,
188 const std::vector<armnn::TensorShape>& inputShapes,
189 const std::vector<std::string>& outputNames,
190 const std::string& outputPath,
191 bool isModelBinary)
192 : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
193 m_ModelPath(modelPath),
194 m_InputNames(inputNames),
195 m_InputShapes(inputShapes),
196 m_OutputNames(outputNames),
197 m_OutputPath(outputPath),
198 m_IsModelBinary(isModelBinary) {}
199
200 bool Serialize()
201 {
202 if (m_NetworkPtr.get() == nullptr)
203 {
204 return false;
205 }
206
207 auto serializer(armnnSerializer::ISerializer::Create());
208
209 serializer->Serialize(*m_NetworkPtr);
210
211 std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
212
213 bool retVal = serializer->SaveSerializedToStream(file);
214
215 return retVal;
216 }
217
218 template <typename IParser>
219 bool CreateNetwork ()
220 {
Sadik Armagan232cfc22019-03-13 18:33:10 +0000221 return CreateNetwork (ParserType<IParser>());
222 }
223
224private:
225 armnn::INetworkPtr m_NetworkPtr;
226 std::string m_ModelPath;
227 std::vector<std::string> m_InputNames;
228 std::vector<armnn::TensorShape> m_InputShapes;
229 std::vector<std::string> m_OutputNames;
230 std::string m_OutputPath;
231 bool m_IsModelBinary;
232
233 template <typename IParser>
234 bool CreateNetwork (ParserType<IParser>)
235 {
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000236 // Create a network from a file on disk
237 auto parser(IParser::Create());
238
239 std::map<std::string, armnn::TensorShape> inputShapes;
240 if (!m_InputShapes.empty())
241 {
242 const size_t numInputShapes = m_InputShapes.size();
243 const size_t numInputBindings = m_InputNames.size();
244 if (numInputShapes < numInputBindings)
245 {
Colm Donelan5b5c2222020-09-09 12:48:16 +0100246 throw armnn::Exception(fmt::format(
247 "Not every input has its tensor shape specified: expected={0}, got={1}",
248 numInputBindings, numInputShapes));
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000249 }
250
251 for (size_t i = 0; i < numInputShapes; i++)
252 {
253 inputShapes[m_InputNames[i]] = m_InputShapes[i];
254 }
255 }
256
257 {
258 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
259 m_NetworkPtr = (m_IsModelBinary ?
260 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
261 parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
262 }
263
264 return m_NetworkPtr.get() != nullptr;
265 }
266
Sadik Armagan232cfc22019-03-13 18:33:10 +0000267#if defined(ARMNN_TF_LITE_PARSER)
268 bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
269 {
270 // Create a network from a file on disk
271 auto parser(armnnTfLiteParser::ITfLiteParser::Create());
272
273 if (!m_InputShapes.empty())
274 {
275 const size_t numInputShapes = m_InputShapes.size();
276 const size_t numInputBindings = m_InputNames.size();
277 if (numInputShapes < numInputBindings)
278 {
Colm Donelan5b5c2222020-09-09 12:48:16 +0100279 throw armnn::Exception(fmt::format(
280 "Not every input has its tensor shape specified: expected={0}, got={1}",
281 numInputBindings, numInputShapes));
Sadik Armagan232cfc22019-03-13 18:33:10 +0000282 }
283 }
284
285 {
286 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
287 m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
288 }
289
290 return m_NetworkPtr.get() != nullptr;
291 }
292#endif
293
294#if defined(ARMNN_ONNX_PARSER)
Colm Donelan46dee402024-05-10 16:49:39 +0100295ARMNN_NO_DEPRECATE_WARN_BEGIN
Sadik Armagan232cfc22019-03-13 18:33:10 +0000296 bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
297 {
298 // Create a network from a file on disk
299 auto parser(armnnOnnxParser::IOnnxParser::Create());
300
301 if (!m_InputShapes.empty())
302 {
303 const size_t numInputShapes = m_InputShapes.size();
304 const size_t numInputBindings = m_InputNames.size();
305 if (numInputShapes < numInputBindings)
306 {
Colm Donelan5b5c2222020-09-09 12:48:16 +0100307 throw armnn::Exception(fmt::format(
308 "Not every input has its tensor shape specified: expected={0}, got={1}",
309 numInputBindings, numInputShapes));
Sadik Armagan232cfc22019-03-13 18:33:10 +0000310 }
311 }
312
313 {
314 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
315 m_NetworkPtr = (m_IsModelBinary ?
316 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
317 parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
318 }
319
320 return m_NetworkPtr.get() != nullptr;
321 }
Colm Donelan46dee402024-05-10 16:49:39 +0100322ARMNN_NO_DEPRECATE_WARN_END
Sadik Armagan232cfc22019-03-13 18:33:10 +0000323#endif
324
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000325};
326
327} // anonymous namespace
328
Colm Donelanb524ca02020-10-06 15:15:33 +0100329int main(int argc, char* argv[])
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000330{
331
Nikhil Raj6dd178f2021-04-02 22:04:39 +0100332#if (!defined(ARMNN_ONNX_PARSER) \
Sadik Armagan232cfc22019-03-13 18:33:10 +0000333 && !defined(ARMNN_TF_PARSER) \
334 && !defined(ARMNN_TF_LITE_PARSER))
Nikhil Raj6dd178f2021-04-02 22:04:39 +0100335 ARMNN_LOG(fatal) << "Not built with any of the supported parsers Onnx, Tensorflow, or TfLite.";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000336 return EXIT_FAILURE;
337#endif
338
339#if !defined(ARMNN_SERIALIZER)
Derek Lamberti08446972019-11-26 16:38:31 +0000340 ARMNN_LOG(fatal) << "Not built with Serializer support.";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000341 return EXIT_FAILURE;
342#endif
343
344#ifdef NDEBUG
345 armnn::LogSeverity level = armnn::LogSeverity::Info;
346#else
347 armnn::LogSeverity level = armnn::LogSeverity::Debug;
348#endif
349
350 armnn::ConfigureLogging(true, true, level);
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000351
352 std::string modelFormat;
353 std::string modelPath;
354
355 std::vector<std::string> inputNames;
356 std::vector<std::string> inputTensorShapeStrs;
357 std::vector<armnn::TensorShape> inputTensorShapes;
358
359 std::vector<std::string> outputNames;
360 std::string outputPath;
361
362 bool isModelBinary = true;
363
364 if (ParseCommandLineArgs(
365 argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
366 != EXIT_SUCCESS)
367 {
368 return EXIT_FAILURE;
369 }
370
371 for (const std::string& shapeStr : inputTensorShapeStrs)
372 {
373 if (!shapeStr.empty())
374 {
375 std::stringstream ss(shapeStr);
376
377 try
378 {
379 armnn::TensorShape shape = ParseTensorShape(ss);
380 inputTensorShapes.push_back(shape);
381 }
382 catch (const armnn::InvalidArgumentException& e)
383 {
Derek Lamberti08446972019-11-26 16:38:31 +0000384 ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000385 return EXIT_FAILURE;
386 }
387 }
388 }
389
390 ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
391
Derek Lambertic9e52792020-03-11 11:42:26 +0000392 try
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000393 {
Nikhil Raj6dd178f2021-04-02 22:04:39 +0100394 if (modelFormat.find("onnx") != std::string::npos)
Derek Lambertic9e52792020-03-11 11:42:26 +0000395 {
Sadik Armagan232cfc22019-03-13 18:33:10 +0000396#if defined(ARMNN_ONNX_PARSER)
Derek Lambertic9e52792020-03-11 11:42:26 +0000397 if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
398 {
399 ARMNN_LOG(fatal) << "Failed to load model from file";
400 return EXIT_FAILURE;
401 }
Sadik Armagan232cfc22019-03-13 18:33:10 +0000402#else
Derek Lambertic9e52792020-03-11 11:42:26 +0000403 ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
404 return EXIT_FAILURE;
Sadik Armagan232cfc22019-03-13 18:33:10 +0000405#endif
Derek Lambertic9e52792020-03-11 11:42:26 +0000406 }
Derek Lambertic9e52792020-03-11 11:42:26 +0000407 else if (modelFormat.find("tflite") != std::string::npos)
408 {
Sadik Armagan232cfc22019-03-13 18:33:10 +0000409#if defined(ARMNN_TF_LITE_PARSER)
Derek Lambertic9e52792020-03-11 11:42:26 +0000410 if (!isModelBinary)
411 {
412 ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
413 for tflite files";
414 return EXIT_FAILURE;
415 }
Sadik Armagan232cfc22019-03-13 18:33:10 +0000416
Derek Lambertic9e52792020-03-11 11:42:26 +0000417 if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
418 {
419 ARMNN_LOG(fatal) << "Failed to load model from file";
420 return EXIT_FAILURE;
421 }
422#else
423 ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
424 return EXIT_FAILURE;
425#endif
426 }
427 else
Sadik Armagan232cfc22019-03-13 18:33:10 +0000428 {
Derek Lambertic9e52792020-03-11 11:42:26 +0000429 ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
Sadik Armagan232cfc22019-03-13 18:33:10 +0000430 return EXIT_FAILURE;
431 }
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000432 }
Derek Lambertic9e52792020-03-11 11:42:26 +0000433 catch(armnn::Exception& e)
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000434 {
Derek Lambertic9e52792020-03-11 11:42:26 +0000435 ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000436 return EXIT_FAILURE;
437 }
438
439 if (!converter.Serialize())
440 {
Derek Lamberti08446972019-11-26 16:38:31 +0000441 ARMNN_LOG(fatal) << "Failed to serialize model";
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +0000442 return EXIT_FAILURE;
443 }
444
445 return EXIT_SUCCESS;
446}