blob: fbec1449a827184ec77f70e6ded9e5ef45feda2f [file] [log] [blame]
Nattapat Chaimanowong4fbae332019-02-14 15:28:02 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#include <armnn/ArmNN.hpp>
6
7#include <armnnSerializer/ISerializer.hpp>
8#include <armnnTfParser/ITfParser.hpp>
9
10#include <Logging.hpp>
11#include <HeapProfiling.hpp>
12
13#include <boost/format.hpp>
14#include <boost/algorithm/string/split.hpp>
15#include <boost/algorithm/string/classification.hpp>
16#include <boost/program_options.hpp>
17
18#include <iostream>
19#include <fstream>
20
21namespace
22{
23
24namespace po = boost::program_options;
25
26armnn::TensorShape ParseTensorShape(std::istream& stream)
27{
28 std::vector<unsigned int> result;
29 std::string line;
30
31 while (std::getline(stream, line))
32 {
33 std::vector<std::string> tokens;
34 try
35 {
36 // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
37 boost::split(tokens, line, boost::algorithm::is_any_of(","), boost::token_compress_on);
38 }
39 catch (const std::exception& e)
40 {
41 BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
42 continue;
43 }
44 for (const std::string& token : tokens)
45 {
46 if (!token.empty())
47 {
48 try
49 {
50 result.push_back(boost::numeric_cast<unsigned int>(std::stoi((token))));
51 }
52 catch (const std::exception&)
53 {
54 BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
55 }
56 }
57 }
58 }
59
60 return armnn::TensorShape(boost::numeric_cast<unsigned int>(result.size()), result.data());
61}
62
63bool CheckOption(const po::variables_map& vm,
64 const char* option)
65{
66 if (option == nullptr)
67 {
68 return false;
69 }
70
71 // Check whether 'option' is provided.
72 return vm.find(option) != vm.end();
73}
74
75void CheckOptionDependency(const po::variables_map& vm,
76 const char* option,
77 const char* required)
78{
79 if (option == nullptr || required == nullptr)
80 {
81 throw po::error("Invalid option to check dependency for");
82 }
83
84 // Check that if 'option' is provided, 'required' is also provided.
85 if (CheckOption(vm, option) && !vm[option].defaulted())
86 {
87 if (CheckOption(vm, required) == 0 || vm[required].defaulted())
88 {
89 throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
90 }
91 }
92}
93
94void CheckOptionDependencies(const po::variables_map& vm)
95{
96 CheckOptionDependency(vm, "model-path", "model-format");
97 CheckOptionDependency(vm, "model-path", "input-name");
98 CheckOptionDependency(vm, "model-path", "output-name");
99 CheckOptionDependency(vm, "input-tensor-shape", "model-path");
100}
101
102int ParseCommandLineArgs(int argc, const char* argv[],
103 std::string& modelFormat,
104 std::string& modelPath,
105 std::vector<std::string>& inputNames,
106 std::vector<std::string>& inputTensorShapeStrs,
107 std::vector<std::string>& outputNames,
108 std::string& outputPath, bool& isModelBinary)
109{
110 po::options_description desc("Options");
111
112 desc.add_options()
113 ("help", "Display usage information")
114 ("model-format,f", po::value(&modelFormat)->required(),"tensorflow-binary or tensorflow-text.")
115 ("model-path,m", po::value(&modelPath)->required(), "Path to model file")
116 ("input-name,i", po::value<std::vector<std::string>>()->multitoken(),
117 "Identifier of the input tensors in the network separated by whitespace")
118 ("input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
119 "The shape of the input tensor in the network as a flat array of integers separated by comma"
120 "Multiple shapes are separated by whitespace"
121 "This parameter is optional, depending on the network.")
122 ("output-name,o", po::value<std::vector<std::string>>()->multitoken(),
123 "Identifier of the output tensor in the network.")
124 ("output-path,p", po::value(&outputPath)->required(), "Path to serialize the network to.");
125
126 po::variables_map vm;
127 try
128 {
129 po::store(po::parse_command_line(argc, argv, desc), vm);
130
131 if (CheckOption(vm, "help") || argc <= 1)
132 {
133 std::cout << "Convert a neural network model from provided file to ArmNN format " << std::endl;
134 std::cout << std::endl;
135 std::cout << desc << std::endl;
136 return EXIT_SUCCESS;
137 }
138
139 po::notify(vm);
140 }
141 catch (const po::error& e)
142 {
143 std::cerr << e.what() << std::endl << std::endl;
144 std::cerr << desc << std::endl;
145 return EXIT_FAILURE;
146 }
147
148 try
149 {
150 CheckOptionDependencies(vm);
151 }
152 catch (const po::error& e)
153 {
154 std::cerr << e.what() << std::endl << std::endl;
155 std::cerr << desc << std::endl;
156 return EXIT_FAILURE;
157 }
158
159 if (modelFormat.find("bin") != std::string::npos)
160 {
161 isModelBinary = true;
162 }
163 else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
164 {
165 isModelBinary = false;
166 }
167 else
168 {
169 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
170 return EXIT_FAILURE;
171 }
172
173 inputNames = vm["input-name"].as<std::vector<std::string>>();
174 inputTensorShapeStrs = vm["input-tensor-shape"].as<std::vector<std::string>>();
175 outputNames = vm["output-name"].as<std::vector<std::string>>();
176
177 return EXIT_SUCCESS;
178}
179
180class ArmnnConverter
181{
182public:
183 ArmnnConverter(const std::string& modelPath,
184 const std::vector<std::string>& inputNames,
185 const std::vector<armnn::TensorShape>& inputShapes,
186 const std::vector<std::string>& outputNames,
187 const std::string& outputPath,
188 bool isModelBinary)
189 : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
190 m_ModelPath(modelPath),
191 m_InputNames(inputNames),
192 m_InputShapes(inputShapes),
193 m_OutputNames(outputNames),
194 m_OutputPath(outputPath),
195 m_IsModelBinary(isModelBinary) {}
196
197 bool Serialize()
198 {
199 if (m_NetworkPtr.get() == nullptr)
200 {
201 return false;
202 }
203
204 auto serializer(armnnSerializer::ISerializer::Create());
205
206 serializer->Serialize(*m_NetworkPtr);
207
208 std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
209
210 bool retVal = serializer->SaveSerializedToStream(file);
211
212 return retVal;
213 }
214
215 template <typename IParser>
216 bool CreateNetwork ()
217 {
218 // Create a network from a file on disk
219 auto parser(IParser::Create());
220
221 std::map<std::string, armnn::TensorShape> inputShapes;
222 if (!m_InputShapes.empty())
223 {
224 const size_t numInputShapes = m_InputShapes.size();
225 const size_t numInputBindings = m_InputNames.size();
226 if (numInputShapes < numInputBindings)
227 {
228 throw armnn::Exception(boost::str(boost::format(
229 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
230 % numInputBindings % numInputShapes));
231 }
232
233 for (size_t i = 0; i < numInputShapes; i++)
234 {
235 inputShapes[m_InputNames[i]] = m_InputShapes[i];
236 }
237 }
238
239 {
240 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
241 m_NetworkPtr = (m_IsModelBinary ?
242 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
243 parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
244 }
245
246 return m_NetworkPtr.get() != nullptr;
247 }
248
249private:
250 armnn::INetworkPtr m_NetworkPtr;
251 std::string m_ModelPath;
252 std::vector<std::string> m_InputNames;
253 std::vector<armnn::TensorShape> m_InputShapes;
254 std::vector<std::string> m_OutputNames;
255 std::string m_OutputPath;
256 bool m_IsModelBinary;
257};
258
259} // anonymous namespace
260
261int main(int argc, const char* argv[])
262{
263
264#if !defined(ARMNN_TF_PARSER)
265 BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
266 return EXIT_FAILURE;
267#endif
268
269#if !defined(ARMNN_SERIALIZER)
270 BOOST_LOG_TRIVIAL(fatal) << "Not built with Serializer support.";
271 return EXIT_FAILURE;
272#endif
273
274#ifdef NDEBUG
275 armnn::LogSeverity level = armnn::LogSeverity::Info;
276#else
277 armnn::LogSeverity level = armnn::LogSeverity::Debug;
278#endif
279
280 armnn::ConfigureLogging(true, true, level);
281 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
282
283 std::string modelFormat;
284 std::string modelPath;
285
286 std::vector<std::string> inputNames;
287 std::vector<std::string> inputTensorShapeStrs;
288 std::vector<armnn::TensorShape> inputTensorShapes;
289
290 std::vector<std::string> outputNames;
291 std::string outputPath;
292
293 bool isModelBinary = true;
294
295 if (ParseCommandLineArgs(
296 argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
297 != EXIT_SUCCESS)
298 {
299 return EXIT_FAILURE;
300 }
301
302 for (const std::string& shapeStr : inputTensorShapeStrs)
303 {
304 if (!shapeStr.empty())
305 {
306 std::stringstream ss(shapeStr);
307
308 try
309 {
310 armnn::TensorShape shape = ParseTensorShape(ss);
311 inputTensorShapes.push_back(shape);
312 }
313 catch (const armnn::InvalidArgumentException& e)
314 {
315 BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what();
316 return EXIT_FAILURE;
317 }
318 }
319 }
320
321 ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
322
323 if (modelFormat.find("tensorflow") != std::string::npos)
324 {
325 if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
326 {
327 BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
328 return EXIT_FAILURE;
329 }
330 }
331 else
332 {
333 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat;
334 return EXIT_FAILURE;
335 }
336
337 if (!converter.Serialize())
338 {
339 BOOST_LOG_TRIVIAL(fatal) << "Failed to serialize model";
340 return EXIT_FAILURE;
341 }
342
343 return EXIT_SUCCESS;
344}