blob: 0d7d7689e31762f8fa16103fd562067921d39c70 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "../ImageTensorGenerator/ImageTensorGenerator.hpp"
#include "../InferenceTest.hpp"
#include "ModelAccuracyChecker.hpp"
#include "armnnDeserializer/IDeserializer.hpp"

#include <boost/algorithm/string.hpp>
#include <boost/filesystem.hpp>
#include <boost/program_options/variables_map.hpp>
#include <boost/range/iterator_range.hpp>

#include <algorithm>
#include <fstream>
#include <map>

using namespace armnn::test;
18
/** Load image names and ground-truth labels from the image directory and the ground truth label file
 *
 * @pre \p validationLabelPath exists and is valid regular file
 * @pre \p imageDirectoryPath exists and is valid directory
 * @pre labels in validation file correspond to images which are in lexicographical order with the image name
 * @pre image index starts at 1
 * @pre \p begIndex and \p endIndex are end-inclusive
 *
 * @param[in] validationLabelPath Path to validation label file
 * @param[in] imageDirectoryPath  Path to directory containing validation images
 * @param[in] begIndex            Begin index of images to be loaded. Inclusive, 1-based.
 *                                Defaults to 1 (the first image); values below 1 are rejected
 * @param[in] endIndex            End index of images to be loaded. Inclusive.
 *                                0 (the default) means "up to the last image found in the directory"
 * @param[in] blacklistPath       Path to blacklist file. An empty string means no blacklist is applied
 * @return A map mapping image file names to their corresponding ground-truth labels
 * @throws armnn::Exception if no JPEG/PNG image is found or the index range is invalid
 */
std::map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const std::string& validationLabelPath,
                                                                         const std::string& imageDirectoryPath,
                                                                         size_t begIndex = 1,
                                                                         size_t endIndex = 0,
                                                                         const std::string& blacklistPath = "");
39
40/** Load model output labels from file
41 *
42 * @pre \p modelOutputLabelsPath exists and is a regular file
43 *
44 * @param[in] modelOutputLabelsPath path to model output labels file
45 * @return A vector of labels, which in turn is described by a list of category names
46 */
47std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010048
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010049int main(int argc, char* argv[])
50{
51 try
52 {
53 using namespace boost::filesystem;
54 armnn::LogSeverity level = armnn::LogSeverity::Debug;
55 armnn::ConfigureLogging(true, true, level);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010056
57 // Set-up program Options
58 namespace po = boost::program_options;
59
60 std::vector<armnn::BackendId> computeDevice;
61 std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
62 std::string modelPath;
SiCong Li39f46392019-06-21 12:00:04 +010063 std::string modelFormat;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010064 std::string dataDir;
65 std::string inputName;
SiCong Li39f46392019-06-21 12:00:04 +010066 std::string inputLayout;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010067 std::string outputName;
SiCong Li898a3242019-06-24 16:03:33 +010068 std::string modelOutputLabelsPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010069 std::string validationLabelPath;
SiCong Li898a3242019-06-24 16:03:33 +010070 std::string validationRange;
71 std::string blacklistPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010072
73 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
74 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
75
76 po::options_description desc("Options");
77 try
78 {
79 // Adds generic options needed to run Accuracy Tool.
80 desc.add_options()
Conor Kennedy30562022019-05-13 14:48:58 +010081 ("help,h", "Display help messages")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010082 ("model-path,m", po::value<std::string>(&modelPath)->required(), "Path to armnn format model file")
SiCong Li39f46392019-06-21 12:00:04 +010083 ("model-format,f", po::value<std::string>(&modelFormat)->required(),
84 "The model format. Supported values: caffe, tensorflow, tflite")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010085 ("input-name,i", po::value<std::string>(&inputName)->required(),
86 "Identifier of the input tensors in the network separated by comma.")
87 ("output-name,o", po::value<std::string>(&outputName)->required(),
88 "Identifier of the output tensors in the network separated by comma.")
SiCong Li39f46392019-06-21 12:00:04 +010089 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
90 "Path to directory containing the ImageNet test data")
SiCong Li898a3242019-06-24 16:03:33 +010091 ("model-output-labels,p", po::value<std::string>(&modelOutputLabelsPath)->required(),
92 "Path to model output labels file.")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010093 ("validation-labels-path,v", po::value<std::string>(&validationLabelPath)->required(),
SiCong Li39f46392019-06-21 12:00:04 +010094 "Path to ImageNet Validation Label file")
95 ("data-layout,l", po::value<std::string>(&inputLayout)->default_value("NHWC"),
SiCong Li23700bb2019-07-25 14:54:39 +010096 "Data layout. Supported value: NHWC, NCHW. Default: NHWC")
SiCong Li39f46392019-06-21 12:00:04 +010097 ("compute,c", po::value<std::vector<armnn::BackendId>>(&computeDevice)->default_value(defaultBackends),
SiCong Li898a3242019-06-24 16:03:33 +010098 backendsMessage.c_str())
99 ("validation-range,r", po::value<std::string>(&validationRange)->default_value("1:0"),
100 "The range of the images to be evaluated. Specified in the form <begin index>:<end index>."
101 "The index starts at 1 and the range is inclusive."
102 "By default the evaluation will be performed on all images.")
103 ("blacklist-path,b", po::value<std::string>(&blacklistPath)->default_value(""),
104 "Path to a blacklist file where each line denotes the index of an image to be "
105 "excluded from evaluation.");
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100106 }
107 catch (const std::exception& e)
108 {
109 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
110 // and that desc.add_options() can throw boost::io::too_few_args.
111 // They really won't in any of these cases.
112 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
113 std::cerr << "Fatal internal error: " << e.what() << std::endl;
114 return 1;
115 }
116
117 po::variables_map vm;
118 try
119 {
120 po::store(po::parse_command_line(argc, argv, desc), vm);
121
122 if (vm.count("help"))
123 {
124 std::cout << desc << std::endl;
125 return 1;
126 }
127 po::notify(vm);
128 }
129 catch (po::error& e)
130 {
131 std::cerr << e.what() << std::endl << std::endl;
132 std::cerr << desc << std::endl;
133 return 1;
134 }
135
136 // Check if the requested backend are all valid
137 std::string invalidBackends;
138 if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
139 {
Derek Lamberti08446972019-11-26 16:38:31 +0000140 ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
141 << invalidBackends;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100142 return EXIT_FAILURE;
143 }
144 armnn::Status status;
145
146 // Create runtime
147 armnn::IRuntime::CreationOptions options;
148 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
149 std::ifstream file(modelPath);
150
151 // Create Parser
152 using IParser = armnnDeserializer::IDeserializer;
153 auto armnnparser(IParser::Create());
154
155 // Create a network
156 armnn::INetworkPtr network = armnnparser->CreateNetworkFromBinary(file);
157
158 // Optimizes the network.
159 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
160 try
161 {
162 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
163 }
164 catch (armnn::Exception& e)
165 {
166 std::stringstream message;
167 message << "armnn::Exception (" << e.what() << ") caught from optimize.";
Derek Lamberti08446972019-11-26 16:38:31 +0000168 ARMNN_LOG(fatal) << message.str();
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100169 return 1;
170 }
171
172 // Loads the network into the runtime.
173 armnn::NetworkId networkId;
174 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
175 if (status == armnn::Status::Failure)
176 {
Derek Lamberti08446972019-11-26 16:38:31 +0000177 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100178 return 1;
179 }
180
181 // Set up Network
182 using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
183
184 const armnnDeserializer::BindingPointInfo&
185 inputBindingInfo = armnnparser->GetNetworkInputBindingInfo(0, inputName);
186
187 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
188 m_InputBindingInfo(inputBindingInfo.m_BindingId, inputBindingInfo.m_TensorInfo);
189 std::vector<BindingPointInfo> inputBindings = { m_InputBindingInfo };
190
191 const armnnDeserializer::BindingPointInfo&
192 outputBindingInfo = armnnparser->GetNetworkOutputBindingInfo(0, outputName);
193
194 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
195 m_OutputBindingInfo(outputBindingInfo.m_BindingId, outputBindingInfo.m_TensorInfo);
196 std::vector<BindingPointInfo> outputBindings = { m_OutputBindingInfo };
197
SiCong Li898a3242019-06-24 16:03:33 +0100198 // Load model output labels
199 if (modelOutputLabelsPath.empty() || !boost::filesystem::exists(modelOutputLabelsPath) ||
200 !boost::filesystem::is_regular_file(modelOutputLabelsPath))
201 {
Derek Lamberti08446972019-11-26 16:38:31 +0000202 ARMNN_LOG(fatal) << "Invalid model output labels path at " << modelOutputLabelsPath;
SiCong Li898a3242019-06-24 16:03:33 +0100203 }
204 const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
205 LoadModelOutputLabels(modelOutputLabelsPath);
206
207 // Parse begin and end image indices
208 std::vector<std::string> imageIndexStrs = armnnUtils::SplitBy(validationRange, ":");
209 size_t imageBegIndex;
210 size_t imageEndIndex;
211 if (imageIndexStrs.size() != 2)
212 {
Derek Lamberti08446972019-11-26 16:38:31 +0000213 ARMNN_LOG(fatal) << "Invalid validation range specification: Invalid format " << validationRange;
SiCong Li898a3242019-06-24 16:03:33 +0100214 return 1;
215 }
216 try
217 {
218 imageBegIndex = std::stoul(imageIndexStrs[0]);
219 imageEndIndex = std::stoul(imageIndexStrs[1]);
220 }
221 catch (const std::exception& e)
222 {
Derek Lamberti08446972019-11-26 16:38:31 +0000223 ARMNN_LOG(fatal) << "Invalid validation range specification: " << validationRange;
SiCong Li898a3242019-06-24 16:03:33 +0100224 return 1;
225 }
226
227 // Validate blacklist file if it's specified
228 if (!blacklistPath.empty() &&
229 !(boost::filesystem::exists(blacklistPath) && boost::filesystem::is_regular_file(blacklistPath)))
230 {
Derek Lamberti08446972019-11-26 16:38:31 +0000231 ARMNN_LOG(fatal) << "Invalid path to blacklist file at " << blacklistPath;
SiCong Li898a3242019-06-24 16:03:33 +0100232 return 1;
233 }
234
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100235 path pathToDataDir(dataDir);
SiCong Li898a3242019-06-24 16:03:33 +0100236 const map<std::string, std::string> imageNameToLabel = LoadValidationImageFilenamesAndLabels(
237 validationLabelPath, pathToDataDir.string(), imageBegIndex, imageEndIndex, blacklistPath);
238 armnnUtils::ModelAccuracyChecker checker(imageNameToLabel, modelOutputLabels);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100239 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<uint8_t>>;
240
SiCong Li39f46392019-06-21 12:00:04 +0100241 if (ValidateDirectory(dataDir))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100242 {
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100243 InferenceModel<armnnDeserializer::IDeserializer, float>::Params params;
SiCong Li39f46392019-06-21 12:00:04 +0100244 params.m_ModelPath = modelPath;
245 params.m_IsModelBinary = true;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100246 params.m_ComputeDevices = computeDevice;
247 params.m_InputBindings.push_back(inputName);
248 params.m_OutputBindings.push_back(outputName);
249
250 using TParser = armnnDeserializer::IDeserializer;
251 InferenceModel<TParser, float> model(params, false);
SiCong Li39f46392019-06-21 12:00:04 +0100252 // Get input tensor information
253 const armnn::TensorInfo& inputTensorInfo = model.GetInputBindingInfo().second;
254 const armnn::TensorShape& inputTensorShape = inputTensorInfo.GetShape();
255 const armnn::DataType& inputTensorDataType = inputTensorInfo.GetDataType();
256 armnn::DataLayout inputTensorDataLayout;
257 if (inputLayout == "NCHW")
258 {
259 inputTensorDataLayout = armnn::DataLayout::NCHW;
260 }
261 else if (inputLayout == "NHWC")
262 {
263 inputTensorDataLayout = armnn::DataLayout::NHWC;
264 }
265 else
266 {
Derek Lamberti08446972019-11-26 16:38:31 +0000267 ARMNN_LOG(fatal) << "Invalid Data layout: " << inputLayout;
SiCong Li39f46392019-06-21 12:00:04 +0100268 return 1;
269 }
270 const unsigned int inputTensorWidth =
271 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[3] : inputTensorShape[2];
272 const unsigned int inputTensorHeight =
273 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[2] : inputTensorShape[1];
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100274 // Get output tensor info
275 const unsigned int outputNumElements = model.GetOutputSize();
SiCong Li898a3242019-06-24 16:03:33 +0100276 // Check output tensor shape is valid
277 if (modelOutputLabels.size() != outputNumElements)
278 {
Derek Lamberti08446972019-11-26 16:38:31 +0000279 ARMNN_LOG(fatal) << "Number of output elements: " << outputNumElements
SiCong Li898a3242019-06-24 16:03:33 +0100280 << " , mismatches the number of output labels: " << modelOutputLabels.size();
281 return 1;
282 }
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100283
SiCong Li39f46392019-06-21 12:00:04 +0100284 const unsigned int batchSize = 1;
285 // Get normalisation parameters
286 SupportedFrontend modelFrontend;
287 if (modelFormat == "caffe")
288 {
289 modelFrontend = SupportedFrontend::Caffe;
290 }
291 else if (modelFormat == "tensorflow")
292 {
293 modelFrontend = SupportedFrontend::TensorFlow;
294 }
295 else if (modelFormat == "tflite")
296 {
297 modelFrontend = SupportedFrontend::TFLite;
298 }
299 else
300 {
Derek Lamberti08446972019-11-26 16:38:31 +0000301 ARMNN_LOG(fatal) << "Unsupported frontend: " << modelFormat;
SiCong Li39f46392019-06-21 12:00:04 +0100302 return 1;
303 }
304 const NormalizationParameters& normParams = GetNormalizationParameters(modelFrontend, inputTensorDataType);
SiCong Li898a3242019-06-24 16:03:33 +0100305 for (const auto& imageEntry : imageNameToLabel)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100306 {
SiCong Li898a3242019-06-24 16:03:33 +0100307 const std::string imageName = imageEntry.first;
308 std::cout << "Processing image: " << imageName << "\n";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100309
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100310 vector<TContainer> inputDataContainers;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100311 vector<TContainer> outputDataContainers;
312
SiCong Li898a3242019-06-24 16:03:33 +0100313 auto imagePath = pathToDataDir / boost::filesystem::path(imageName);
SiCong Li39f46392019-06-21 12:00:04 +0100314 switch (inputTensorDataType)
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100315 {
SiCong Li39f46392019-06-21 12:00:04 +0100316 case armnn::DataType::Signed32:
317 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100318 PrepareImageTensor<int>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100319 inputTensorWidth, inputTensorHeight,
320 normParams,
321 batchSize,
322 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100323 outputDataContainers = { vector<int>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100324 break;
325 case armnn::DataType::QuantisedAsymm8:
326 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100327 PrepareImageTensor<uint8_t>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100328 inputTensorWidth, inputTensorHeight,
329 normParams,
330 batchSize,
331 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100332 outputDataContainers = { vector<uint8_t>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100333 break;
334 case armnn::DataType::Float32:
335 default:
336 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100337 PrepareImageTensor<float>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100338 inputTensorWidth, inputTensorHeight,
339 normParams,
340 batchSize,
341 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100342 outputDataContainers = { vector<float>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100343 break;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100344 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100345
346 status = runtime->EnqueueWorkload(networkId,
347 armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
348 armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
349
350 if (status == armnn::Status::Failure)
351 {
Derek Lamberti08446972019-11-26 16:38:31 +0000352 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload for image: " << imageName;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100353 }
354
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100355 checker.AddImageResult<TContainer>(imageName, outputDataContainers);
356 }
357 }
358 else
359 {
360 return 1;
361 }
362
363 for(unsigned int i = 1; i <= 5; ++i)
364 {
365 std::cout << "Top " << i << " Accuracy: " << checker.GetAccuracy(i) << "%" << "\n";
366 }
367
Derek Lamberti08446972019-11-26 16:38:31 +0000368 ARMNN_LOG(info) << "Accuracy Tool ran successfully!";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100369 return 0;
370 }
371 catch (armnn::Exception const & e)
372 {
373 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
374 // exception of type std::length_error.
375 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
376 std::cerr << "Armnn Error: " << e.what() << std::endl;
377 return 1;
378 }
379 catch (const std::exception & e)
380 {
381 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
382 std::cerr << "WARNING: ModelAccuracyTool-Armnn: An error has occurred when running the "
383 "Accuracy Tool: " << e.what() << std::endl;
384 return 1;
385 }
386}
387
SiCong Li898a3242019-06-24 16:03:33 +0100388map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
389 const string& imageDirectoryPath,
390 size_t begIndex,
391 size_t endIndex,
392 const string& blacklistPath)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100393{
SiCong Li898a3242019-06-24 16:03:33 +0100394 // Populate imageFilenames with names of all .JPEG, .PNG images
395 std::vector<std::string> imageFilenames;
396 for (const auto& imageEntry :
397 boost::make_iterator_range(boost::filesystem::directory_iterator(boost::filesystem::path(imageDirectoryPath))))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100398 {
SiCong Li898a3242019-06-24 16:03:33 +0100399 boost::filesystem::path imagePath = imageEntry.path();
400 std::string imageExtension = boost::to_upper_copy<std::string>(imagePath.extension().string());
401 if (boost::filesystem::is_regular_file(imagePath) && (imageExtension == ".JPEG" || imageExtension == ".PNG"))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100402 {
SiCong Li898a3242019-06-24 16:03:33 +0100403 imageFilenames.push_back(imagePath.filename().string());
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100404 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100405 }
SiCong Li898a3242019-06-24 16:03:33 +0100406 if (imageFilenames.empty())
407 {
408 throw armnn::Exception("No image file (JPEG, PNG) found at " + imageDirectoryPath);
409 }
410
411 // Sort the image filenames lexicographically
412 std::sort(imageFilenames.begin(), imageFilenames.end());
413
414 std::cout << imageFilenames.size() << " images found at " << imageDirectoryPath << std::endl;
415
416 // Get default end index
417 if (begIndex < 1 || endIndex > imageFilenames.size())
418 {
419 throw armnn::Exception("Invalid image index range");
420 }
421 endIndex = endIndex == 0 ? imageFilenames.size() : endIndex;
422 if (begIndex > endIndex)
423 {
424 throw armnn::Exception("Invalid image index range");
425 }
426
427 // Load blacklist if there is one
428 std::vector<unsigned int> blacklist;
429 if (!blacklistPath.empty())
430 {
431 std::ifstream blacklistFile(blacklistPath);
432 unsigned int index;
433 while (blacklistFile >> index)
434 {
435 blacklist.push_back(index);
436 }
437 }
438
439 // Load ground truth labels and pair them with corresponding image names
440 std::string classification;
441 map<std::string, std::string> imageNameToLabel;
442 ifstream infile(validationLabelPath);
443 size_t imageIndex = begIndex;
444 size_t blacklistIndexCount = 0;
445 while (std::getline(infile, classification))
446 {
447 if (imageIndex > endIndex)
448 {
449 break;
450 }
451 // If current imageIndex is included in blacklist, skip the current image
452 if (blacklistIndexCount < blacklist.size() && imageIndex == blacklist[blacklistIndexCount])
453 {
454 ++imageIndex;
455 ++blacklistIndexCount;
456 continue;
457 }
458 imageNameToLabel.insert(std::pair<std::string, std::string>(imageFilenames[imageIndex - 1], classification));
459 ++imageIndex;
460 }
461 std::cout << blacklistIndexCount << " images blacklisted" << std::endl;
462 std::cout << imageIndex - begIndex - blacklistIndexCount << " images to be loaded" << std::endl;
463 return imageNameToLabel;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100464}
SiCong Li898a3242019-06-24 16:03:33 +0100465
466std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath)
467{
468 std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels;
469 ifstream modelOutputLablesFile(modelOutputLabelsPath);
470 std::string line;
471 while (std::getline(modelOutputLablesFile, line))
472 {
473 armnnUtils::LabelCategoryNames tokens = armnnUtils::SplitBy(line, ":");
474 armnnUtils::LabelCategoryNames predictionCategoryNames = armnnUtils::SplitBy(tokens.back(), ",");
475 std::transform(predictionCategoryNames.begin(), predictionCategoryNames.end(), predictionCategoryNames.begin(),
476 [](const std::string& category) { return armnnUtils::Strip(category); });
477 modelOutputLabels.push_back(predictionCategoryNames);
478 }
479 return modelOutputLabels;
480}