blob: 1f537745b486956171f8fd8f46177062383b7e69 [file] [log] [blame]
Sadik Armagan8271f812019-04-19 09:55:06 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "../InferenceTestImage.hpp"

#include <boost/filesystem.hpp>
#include <boost/filesystem/operations.hpp>
#include <boost/filesystem/path.hpp>
#include <boost/log/trivial.hpp>
#include <boost/program_options.hpp>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>
18
19namespace
20{
21
22// parses the command line to extract
23// * the input image file -i the input image file path (must exist)
24// * the layout -l the data layout output generated with (optional - default value is NHWC)
25// * the output file -o the output raw tensor file path (must not already exist)
26class CommandLineProcessor
27{
28public:
29 bool ValidateInputFile(const std::string& inputFileName)
30 {
31 if (inputFileName.empty())
32 {
33 std::cerr << "No input file name specified" << std::endl;
34 return false;
35 }
36
37 if (!boost::filesystem::exists(inputFileName))
38 {
39 std::cerr << "Input file [" << inputFileName << "] does not exist" << std::endl;
40 return false;
41 }
42
43 if (boost::filesystem::is_directory(inputFileName))
44 {
45 std::cerr << "Input file [" << inputFileName << "] is a directory" << std::endl;
46 return false;
47 }
48
49 return true;
50 }
51
52 bool ValidateLayout(const std::string& layout)
53 {
54 if (layout.empty())
55 {
56 std::cerr << "No layout specified" << std::endl;
57 return false;
58 }
59
60 std::vector<std::string> supportedLayouts = {
61 "NHWC",
62 "NCHW"
63 };
64
65 auto iterator = std::find(supportedLayouts.begin(), supportedLayouts.end(), layout);
66 if (iterator == supportedLayouts.end())
67 {
68 std::cerr << "Layout [" << layout << "] is not supported" << std::endl;
69 return false;
70 }
71
72 return true;
73 }
74
75 bool ValidateOutputFile(std::string& outputFileName)
76 {
77 if (outputFileName.empty())
78 {
79 std::cerr << "No output file name specified" << std::endl;
80 return false;
81 }
82
83 if (boost::filesystem::exists(outputFileName))
84 {
85 std::cerr << "Output file [" << outputFileName << "] already exists" << std::endl;
86 return false;
87 }
88
89 if (boost::filesystem::is_directory(outputFileName))
90 {
91 std::cerr << "Output file [" << outputFileName << "] is a directory" << std::endl;
92 return false;
93 }
94
95 boost::filesystem::path outputPath(outputFileName);
96 if (!boost::filesystem::exists(outputPath.parent_path()))
97 {
98 std::cerr << "Output directory [" << outputPath.parent_path().c_str() << "] does not exist" << std::endl;
99 return false;
100 }
101
102 return true;
103 }
104
105 bool ProcessCommandLine(int argc, char* argv[])
106 {
107 namespace po = boost::program_options;
108
109 po::options_description desc("Options");
110 try
111 {
112 desc.add_options()
113 ("help,h", "Display help messages")
114 ("infile,i", po::value<std::string>(&m_InputFileName)->required(),
115 "Input image file to generate tensor from")
116 ("layout,l", po::value<std::string>(&m_Layout)->default_value("NHWC"),
117 "Output data layout, \"NHWC\" or \"NCHW\", default value NHWC")
118 ("outfile,o", po::value<std::string>(&m_OutputFileName)->required(),
119 "Output raw tensor file path");
120 }
121 catch (const std::exception& e)
122 {
123 std::cerr << "Fatal internal error: [" << e.what() << "]" << std::endl;
124 return false;
125 }
126
127 po::variables_map vm;
128
129 try
130 {
131 po::store(po::parse_command_line(argc, argv, desc), vm);
132
133 if (vm.count("help"))
134 {
135 std::cout << desc << std::endl;
136 return false;
137 }
138
139 po::notify(vm);
140 }
141 catch (const po::error& e)
142 {
143 std::cerr << e.what() << std::endl << std::endl;
144 std::cerr << desc << std::endl;
145 return false;
146 }
147
148 if (!ValidateInputFile(m_InputFileName))
149 {
150 return false;
151 }
152
153 if (!ValidateLayout(m_Layout))
154 {
155 return false;
156 }
157
158 if (!ValidateOutputFile(m_OutputFileName))
159 {
160 return false;
161 }
162
163 return true;
164 }
165
166 std::string GetInputFileName() {return m_InputFileName;}
167 std::string GetLayout() {return m_Layout;}
168 std::string GetOutputFileName() {return m_OutputFileName;}
169
170private:
171 std::string m_InputFileName;
172 std::string m_Layout;
173 std::string m_OutputFileName;
174};
175
176} // namespace anonymous
177
178int main(int argc, char* argv[])
179{
180 CommandLineProcessor cmdline;
181 if (!cmdline.ProcessCommandLine(argc, argv))
182 {
183 return -1;
184 }
185
186 const std::string imagePath(cmdline.GetInputFileName());
187 const std::string outputPath(cmdline.GetOutputFileName());
188
189 // generate image tensor
190 std::vector<float> imageData;
191 try
192 {
193 InferenceTestImage testImage(imagePath.c_str());
194 imageData = cmdline.GetLayout() == "NHWC"
195 ? GetImageDataAsNormalizedFloats(ImageChannelLayout::Rgb, testImage)
196 : GetImageDataInArmNnLayoutAsNormalizedFloats(ImageChannelLayout::Rgb, testImage);
197 }
198 catch (const InferenceTestImageException& e)
199 {
200 BOOST_LOG_TRIVIAL(fatal) << "Failed to load image file " << imagePath << " with error: " << e.what();
201 return -1;
202 }
203
204 std::ofstream imageTensorFile;
205 imageTensorFile.open(outputPath, std::ofstream::out);
206 if (imageTensorFile.is_open())
207 {
208 std::copy(imageData.begin(), imageData.end(), std::ostream_iterator<float>(imageTensorFile, " "));
209 if (!imageTensorFile)
210 {
211 BOOST_LOG_TRIVIAL(fatal) << "Failed to write to output file" << outputPath;
212 imageTensorFile.close();
213 return -1;
214 }
215 imageTensorFile.close();
216 }
217 else
218 {
219 BOOST_LOG_TRIVIAL(fatal) << "Failed to open output file" << outputPath;
220 return -1;
221 }
222
223 return 0;
224}