blob: 5e9e6532cf82043fc333be1f439183ff27412a94 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#include "armnn/ArmNN.hpp"
#if defined(ARMNN_CAFFE_PARSER)
#include "armnnCaffeParser/ICaffeParser.hpp"
#endif
#include "Logging.hpp"
#include "../InferenceTest.hpp"

#include <boost/program_options.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/classification.hpp>

#include <cstddef>
#include <cstdio>
#include <exception>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

19namespace
20{
21
22template<typename T, typename TParseElementFunc>
23std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc)
24{
25 std::vector<T> result;
26 // Process line-by-line
27 std::string line;
28 while (std::getline(stream, line))
29 {
30 std::vector<std::string> tokens;
31 boost::split(tokens, line, boost::algorithm::is_any_of("\t ,;:"), boost::token_compress_on);
32 for (const std::string& token : tokens)
33 {
34 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
35 {
36 try
37 {
38 result.push_back(parseElementFunc(token));
39 }
40 catch (const std::exception&)
41 {
42 BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
43 }
44 }
45 }
46 }
47
48 return result;
49}
50
51}
52
53template<typename T>
54std::vector<T> ParseArray(std::istream& stream);
55
56template<>
57std::vector<float> ParseArray(std::istream& stream)
58{
59 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
60}
61
62template<>
63std::vector<unsigned int> ParseArray(std::istream& stream)
64{
65 return ParseArrayImpl<unsigned int>(stream,
66 [](const std::string& s) { return boost::numeric_cast<unsigned int>(std::stoi(s)); });
67}
68
/// Writes every element of @p v to stdout, each formatted with printf("%f ") (trailing space
/// included), followed by a single newline.
void PrintArray(const std::vector<float>& v)
{
    for (const float element : v)
    {
        printf("%f ", element);
    }
    printf("\n");
}

78template<typename TParser, typename TDataType>
79int MainImpl(const char* modelPath, bool isModelBinary, armnn::Compute computeDevice,
80 const char* inputName, const armnn::TensorShape* inputTensorShape, const char* inputTensorDataFilePath,
81 const char* outputName)
82{
83 // Load input tensor
84 std::vector<TDataType> input;
85 {
86 std::ifstream inputTensorFile(inputTensorDataFilePath);
87 if (!inputTensorFile.good())
88 {
89 BOOST_LOG_TRIVIAL(fatal) << "Failed to load input tensor data file from " << inputTensorDataFilePath;
90 return 1;
91 }
92 input = ParseArray<TDataType>(inputTensorFile);
93 }
94
95 try
96 {
97 // Create an InferenceModel, which will parse the model and load it into an IRuntime
98 typename InferenceModel<TParser, TDataType>::Params params;
99 params.m_ModelPath = modelPath;
100 params.m_IsModelBinary = isModelBinary;
101 params.m_ComputeDevice = computeDevice;
102 params.m_InputBinding = inputName;
103 params.m_InputTensorShape = inputTensorShape;
104 params.m_OutputBinding = outputName;
105 InferenceModel<TParser, TDataType> model(params);
106
107 // Execute the model
108 std::vector<TDataType> output(model.GetOutputSize());
109 model.Run(input, output);
110
111 // Print the output tensor
112 PrintArray(output);
113 }
114 catch (armnn::Exception const& e)
115 {
116 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
117 return 1;
118 }
119
120 return 0;
121}
122
123int main(int argc, char* argv[])
124{
125 // Configure logging for both the ARMNN library and this test program
126#ifdef NDEBUG
127 armnn::LogSeverity level = armnn::LogSeverity::Info;
128#else
129 armnn::LogSeverity level = armnn::LogSeverity::Debug;
130#endif
131 armnn::ConfigureLogging(true, true, level);
132 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
133
134 // Configure boost::program_options for command-line parsing
135 namespace po = boost::program_options;
136
137 std::string modelFormat;
138 std::string modelPath;
139 std::string inputName;
140 std::string inputTensorShapeStr;
141 std::string inputTensorDataFilePath;
142 std::string outputName;
143 armnn::Compute computeDevice;
144
145 po::options_description desc("Options");
146 try
147 {
148 desc.add_options()
149 ("help", "Display usage information")
150 ("model-format,f", po::value(&modelFormat)->required(),
151 "caffe-binary, caffe-text, tensorflow-binary or tensorflow-text.")
152 ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt")
153 ("compute,c", po::value<armnn::Compute>(&computeDevice)->required(),
154 "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
155 ("input-name,i", po::value(&inputName)->required(), "Identifier of the input tensor in the network.")
156 ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
157 "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
158 "This parameter is optional, depending on the network.")
159 ("input-tensor-data,d", po::value(&inputTensorDataFilePath)->required(),
160 "Path to a file containing the input data as a flat array separated by whitespace.")
161 ("output-name,o", po::value(&outputName)->required(), "Identifier of the output tensor in the network.");
162 }
163 catch (const std::exception& e)
164 {
165 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
166 // and that desc.add_options() can throw boost::io::too_few_args.
167 // They really won't in any of these cases.
168 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
169 BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
170 return 1;
171 }
172
173 // Parse the command-line
174 po::variables_map vm;
175 try
176 {
177 po::store(po::parse_command_line(argc, argv, desc), vm);
178
179 if (vm.count("help") || argc <= 1)
180 {
181 std::cout << "Executes a neural network model using the provided input tensor. " << std::endl;
182 std::cout << "Prints the resulting output tensor." << std::endl;
183 std::cout << std::endl;
184 std::cout << desc << std::endl;
185 return 1;
186 }
187
188 po::notify(vm);
189 }
190 catch (po::error& e)
191 {
192 std::cerr << e.what() << std::endl << std::endl;
193 std::cerr << desc << std::endl;
194 return 1;
195 }
196
197 // Parse model binary flag from the model-format string we got from the command-line
198 bool isModelBinary;
199 if (modelFormat.find("bin") != std::string::npos)
200 {
201 isModelBinary = true;
202 }
203 else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
204 {
205 isModelBinary = false;
206 }
207 else
208 {
209 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
210 return 1;
211 }
212
213 // Parse input tensor shape from the string we got from the command-line.
214 std::unique_ptr<armnn::TensorShape> inputTensorShape;
215 if (!inputTensorShapeStr.empty())
216 {
217 std::stringstream ss(inputTensorShapeStr);
218 std::vector<unsigned int> dims = ParseArray<unsigned int>(ss);
219 inputTensorShape = std::make_unique<armnn::TensorShape>(dims.size(), dims.data());
220 }
221
222 // Forward to implementation based on the parser type
223 if (modelFormat.find("caffe") != std::string::npos)
224 {
225#if defined(ARMNN_CAFFE_PARSER)
226 return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
227 inputName.c_str(), inputTensorShape.get(), inputTensorDataFilePath.c_str(), outputName.c_str());
228#else
229 BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
230 return 1;
231#endif
232 }
233 else if (modelFormat.find("tensorflow") != std::string::npos)
234 {
235 BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
236 return 1;
237 }
238 else
239 {
240 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
241 "'. Please include 'caffe' or 'tensorflow'";
242 return 1;
243 }
244}