blob: edc7e1cc3335e176206b0c57bfa50de39c989f4a [file] [log] [blame]
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
SiCong Li39f46392019-06-21 12:00:04 +01006#include "../ImageTensorGenerator/ImageTensorGenerator.hpp"
7#include "../InferenceTest.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01008#include "ModelAccuracyChecker.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01009#include "armnnDeserializer/IDeserializer.hpp"
Francis Murtagh532a29d2020-06-29 11:50:01 +010010#include <Filesystem.hpp>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010011
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010012#include <boost/program_options/variables_map.hpp>
SiCong Li39f46392019-06-21 12:00:04 +010013#include <boost/range/iterator_range.hpp>
SiCong Li39f46392019-06-21 12:00:04 +010014#include <map>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010015
16using namespace armnn::test;
17
SiCong Li898a3242019-06-24 16:03:33 +010018/** Load image names and ground-truth labels from the image directory and the ground truth label file
19 *
20 * @pre \p validationLabelPath exists and is valid regular file
21 * @pre \p imageDirectoryPath exists and is valid directory
22 * @pre labels in validation file correspond to images which are in lexicographical order with the image name
23 * @pre image index starts at 1
24 * @pre \p begIndex and \p endIndex are end-inclusive
25 *
26 * @param[in] validationLabelPath Path to validation label file
27 * @param[in] imageDirectoryPath Path to directory containing validation images
28 * @param[in] begIndex Begin index of images to be loaded. Inclusive
29 * @param[in] endIndex End index of images to be loaded. Inclusive
30 * @param[in] blacklistPath Path to blacklist file
31 * @return A map mapping image file names to their corresponding ground-truth labels
32 */
33map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
34 const string& imageDirectoryPath,
35 size_t begIndex = 0,
36 size_t endIndex = 0,
37 const string& blacklistPath = "");
38
39/** Load model output labels from file
40 *
41 * @pre \p modelOutputLabelsPath exists and is a regular file
42 *
43 * @param[in] modelOutputLabelsPath path to model output labels file
44 * @return A vector of labels, which in turn is described by a list of category names
45 */
46std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010047
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010048int main(int argc, char* argv[])
49{
50 try
51 {
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010052 armnn::LogSeverity level = armnn::LogSeverity::Debug;
53 armnn::ConfigureLogging(true, true, level);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010054
55 // Set-up program Options
56 namespace po = boost::program_options;
57
58 std::vector<armnn::BackendId> computeDevice;
59 std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
60 std::string modelPath;
SiCong Li39f46392019-06-21 12:00:04 +010061 std::string modelFormat;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010062 std::string dataDir;
63 std::string inputName;
SiCong Li39f46392019-06-21 12:00:04 +010064 std::string inputLayout;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010065 std::string outputName;
SiCong Li898a3242019-06-24 16:03:33 +010066 std::string modelOutputLabelsPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010067 std::string validationLabelPath;
SiCong Li898a3242019-06-24 16:03:33 +010068 std::string validationRange;
69 std::string blacklistPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010070
71 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
72 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
73
74 po::options_description desc("Options");
75 try
76 {
77 // Adds generic options needed to run Accuracy Tool.
78 desc.add_options()
Conor Kennedy30562022019-05-13 14:48:58 +010079 ("help,h", "Display help messages")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010080 ("model-path,m", po::value<std::string>(&modelPath)->required(), "Path to armnn format model file")
SiCong Li39f46392019-06-21 12:00:04 +010081 ("model-format,f", po::value<std::string>(&modelFormat)->required(),
82 "The model format. Supported values: caffe, tensorflow, tflite")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010083 ("input-name,i", po::value<std::string>(&inputName)->required(),
84 "Identifier of the input tensors in the network separated by comma.")
85 ("output-name,o", po::value<std::string>(&outputName)->required(),
86 "Identifier of the output tensors in the network separated by comma.")
SiCong Li39f46392019-06-21 12:00:04 +010087 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
88 "Path to directory containing the ImageNet test data")
SiCong Li898a3242019-06-24 16:03:33 +010089 ("model-output-labels,p", po::value<std::string>(&modelOutputLabelsPath)->required(),
90 "Path to model output labels file.")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010091 ("validation-labels-path,v", po::value<std::string>(&validationLabelPath)->required(),
SiCong Li39f46392019-06-21 12:00:04 +010092 "Path to ImageNet Validation Label file")
93 ("data-layout,l", po::value<std::string>(&inputLayout)->default_value("NHWC"),
SiCong Li23700bb2019-07-25 14:54:39 +010094 "Data layout. Supported value: NHWC, NCHW. Default: NHWC")
SiCong Li39f46392019-06-21 12:00:04 +010095 ("compute,c", po::value<std::vector<armnn::BackendId>>(&computeDevice)->default_value(defaultBackends),
SiCong Li898a3242019-06-24 16:03:33 +010096 backendsMessage.c_str())
97 ("validation-range,r", po::value<std::string>(&validationRange)->default_value("1:0"),
98 "The range of the images to be evaluated. Specified in the form <begin index>:<end index>."
99 "The index starts at 1 and the range is inclusive."
100 "By default the evaluation will be performed on all images.")
101 ("blacklist-path,b", po::value<std::string>(&blacklistPath)->default_value(""),
102 "Path to a blacklist file where each line denotes the index of an image to be "
103 "excluded from evaluation.");
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100104 }
105 catch (const std::exception& e)
106 {
107 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
108 // and that desc.add_options() can throw boost::io::too_few_args.
109 // They really won't in any of these cases.
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100110 ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100111 std::cerr << "Fatal internal error: " << e.what() << std::endl;
112 return 1;
113 }
114
115 po::variables_map vm;
116 try
117 {
118 po::store(po::parse_command_line(argc, argv, desc), vm);
119
120 if (vm.count("help"))
121 {
122 std::cout << desc << std::endl;
123 return 1;
124 }
125 po::notify(vm);
126 }
127 catch (po::error& e)
128 {
129 std::cerr << e.what() << std::endl << std::endl;
130 std::cerr << desc << std::endl;
131 return 1;
132 }
133
134 // Check if the requested backend are all valid
135 std::string invalidBackends;
136 if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
137 {
Derek Lamberti08446972019-11-26 16:38:31 +0000138 ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
139 << invalidBackends;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100140 return EXIT_FAILURE;
141 }
142 armnn::Status status;
143
144 // Create runtime
145 armnn::IRuntime::CreationOptions options;
146 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
147 std::ifstream file(modelPath);
148
149 // Create Parser
150 using IParser = armnnDeserializer::IDeserializer;
151 auto armnnparser(IParser::Create());
152
153 // Create a network
154 armnn::INetworkPtr network = armnnparser->CreateNetworkFromBinary(file);
155
156 // Optimizes the network.
157 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
158 try
159 {
160 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
161 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000162 catch (const armnn::Exception& e)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100163 {
164 std::stringstream message;
165 message << "armnn::Exception (" << e.what() << ") caught from optimize.";
Derek Lamberti08446972019-11-26 16:38:31 +0000166 ARMNN_LOG(fatal) << message.str();
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100167 return 1;
168 }
169
170 // Loads the network into the runtime.
171 armnn::NetworkId networkId;
172 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
173 if (status == armnn::Status::Failure)
174 {
Derek Lamberti08446972019-11-26 16:38:31 +0000175 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100176 return 1;
177 }
178
179 // Set up Network
180 using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
181
182 const armnnDeserializer::BindingPointInfo&
183 inputBindingInfo = armnnparser->GetNetworkInputBindingInfo(0, inputName);
184
185 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
186 m_InputBindingInfo(inputBindingInfo.m_BindingId, inputBindingInfo.m_TensorInfo);
187 std::vector<BindingPointInfo> inputBindings = { m_InputBindingInfo };
188
189 const armnnDeserializer::BindingPointInfo&
190 outputBindingInfo = armnnparser->GetNetworkOutputBindingInfo(0, outputName);
191
192 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
193 m_OutputBindingInfo(outputBindingInfo.m_BindingId, outputBindingInfo.m_TensorInfo);
194 std::vector<BindingPointInfo> outputBindings = { m_OutputBindingInfo };
195
SiCong Li898a3242019-06-24 16:03:33 +0100196 // Load model output labels
Francis Murtagh532a29d2020-06-29 11:50:01 +0100197 if (modelOutputLabelsPath.empty() || !fs::exists(modelOutputLabelsPath) ||
198 !fs::is_regular_file(modelOutputLabelsPath))
SiCong Li898a3242019-06-24 16:03:33 +0100199 {
Derek Lamberti08446972019-11-26 16:38:31 +0000200 ARMNN_LOG(fatal) << "Invalid model output labels path at " << modelOutputLabelsPath;
SiCong Li898a3242019-06-24 16:03:33 +0100201 }
202 const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
203 LoadModelOutputLabels(modelOutputLabelsPath);
204
205 // Parse begin and end image indices
206 std::vector<std::string> imageIndexStrs = armnnUtils::SplitBy(validationRange, ":");
207 size_t imageBegIndex;
208 size_t imageEndIndex;
209 if (imageIndexStrs.size() != 2)
210 {
Derek Lamberti08446972019-11-26 16:38:31 +0000211 ARMNN_LOG(fatal) << "Invalid validation range specification: Invalid format " << validationRange;
SiCong Li898a3242019-06-24 16:03:33 +0100212 return 1;
213 }
214 try
215 {
216 imageBegIndex = std::stoul(imageIndexStrs[0]);
217 imageEndIndex = std::stoul(imageIndexStrs[1]);
218 }
219 catch (const std::exception& e)
220 {
Derek Lamberti08446972019-11-26 16:38:31 +0000221 ARMNN_LOG(fatal) << "Invalid validation range specification: " << validationRange;
SiCong Li898a3242019-06-24 16:03:33 +0100222 return 1;
223 }
224
225 // Validate blacklist file if it's specified
226 if (!blacklistPath.empty() &&
Francis Murtagh532a29d2020-06-29 11:50:01 +0100227 !(fs::exists(blacklistPath) && fs::is_regular_file(blacklistPath)))
SiCong Li898a3242019-06-24 16:03:33 +0100228 {
Derek Lamberti08446972019-11-26 16:38:31 +0000229 ARMNN_LOG(fatal) << "Invalid path to blacklist file at " << blacklistPath;
SiCong Li898a3242019-06-24 16:03:33 +0100230 return 1;
231 }
232
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100233 path pathToDataDir(dataDir);
SiCong Li898a3242019-06-24 16:03:33 +0100234 const map<std::string, std::string> imageNameToLabel = LoadValidationImageFilenamesAndLabels(
235 validationLabelPath, pathToDataDir.string(), imageBegIndex, imageEndIndex, blacklistPath);
236 armnnUtils::ModelAccuracyChecker checker(imageNameToLabel, modelOutputLabels);
James Ward6d9f5c52020-09-28 11:56:35 +0100237 using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<uint8_t>>;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100238
SiCong Li39f46392019-06-21 12:00:04 +0100239 if (ValidateDirectory(dataDir))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100240 {
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100241 InferenceModel<armnnDeserializer::IDeserializer, float>::Params params;
SiCong Li39f46392019-06-21 12:00:04 +0100242 params.m_ModelPath = modelPath;
243 params.m_IsModelBinary = true;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100244 params.m_ComputeDevices = computeDevice;
245 params.m_InputBindings.push_back(inputName);
246 params.m_OutputBindings.push_back(outputName);
247
248 using TParser = armnnDeserializer::IDeserializer;
249 InferenceModel<TParser, float> model(params, false);
SiCong Li39f46392019-06-21 12:00:04 +0100250 // Get input tensor information
251 const armnn::TensorInfo& inputTensorInfo = model.GetInputBindingInfo().second;
252 const armnn::TensorShape& inputTensorShape = inputTensorInfo.GetShape();
253 const armnn::DataType& inputTensorDataType = inputTensorInfo.GetDataType();
254 armnn::DataLayout inputTensorDataLayout;
255 if (inputLayout == "NCHW")
256 {
257 inputTensorDataLayout = armnn::DataLayout::NCHW;
258 }
259 else if (inputLayout == "NHWC")
260 {
261 inputTensorDataLayout = armnn::DataLayout::NHWC;
262 }
263 else
264 {
Derek Lamberti08446972019-11-26 16:38:31 +0000265 ARMNN_LOG(fatal) << "Invalid Data layout: " << inputLayout;
SiCong Li39f46392019-06-21 12:00:04 +0100266 return 1;
267 }
268 const unsigned int inputTensorWidth =
269 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[3] : inputTensorShape[2];
270 const unsigned int inputTensorHeight =
271 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[2] : inputTensorShape[1];
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100272 // Get output tensor info
273 const unsigned int outputNumElements = model.GetOutputSize();
SiCong Li898a3242019-06-24 16:03:33 +0100274 // Check output tensor shape is valid
275 if (modelOutputLabels.size() != outputNumElements)
276 {
Derek Lamberti08446972019-11-26 16:38:31 +0000277 ARMNN_LOG(fatal) << "Number of output elements: " << outputNumElements
SiCong Li898a3242019-06-24 16:03:33 +0100278 << " , mismatches the number of output labels: " << modelOutputLabels.size();
279 return 1;
280 }
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100281
SiCong Li39f46392019-06-21 12:00:04 +0100282 const unsigned int batchSize = 1;
283 // Get normalisation parameters
284 SupportedFrontend modelFrontend;
285 if (modelFormat == "caffe")
286 {
287 modelFrontend = SupportedFrontend::Caffe;
288 }
289 else if (modelFormat == "tensorflow")
290 {
291 modelFrontend = SupportedFrontend::TensorFlow;
292 }
293 else if (modelFormat == "tflite")
294 {
295 modelFrontend = SupportedFrontend::TFLite;
296 }
297 else
298 {
Derek Lamberti08446972019-11-26 16:38:31 +0000299 ARMNN_LOG(fatal) << "Unsupported frontend: " << modelFormat;
SiCong Li39f46392019-06-21 12:00:04 +0100300 return 1;
301 }
302 const NormalizationParameters& normParams = GetNormalizationParameters(modelFrontend, inputTensorDataType);
SiCong Li898a3242019-06-24 16:03:33 +0100303 for (const auto& imageEntry : imageNameToLabel)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100304 {
SiCong Li898a3242019-06-24 16:03:33 +0100305 const std::string imageName = imageEntry.first;
306 std::cout << "Processing image: " << imageName << "\n";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100307
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100308 vector<TContainer> inputDataContainers;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100309 vector<TContainer> outputDataContainers;
310
Francis Murtagh532a29d2020-06-29 11:50:01 +0100311 auto imagePath = pathToDataDir / fs::path(imageName);
SiCong Li39f46392019-06-21 12:00:04 +0100312 switch (inputTensorDataType)
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100313 {
SiCong Li39f46392019-06-21 12:00:04 +0100314 case armnn::DataType::Signed32:
315 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100316 PrepareImageTensor<int>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100317 inputTensorWidth, inputTensorHeight,
318 normParams,
319 batchSize,
320 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100321 outputDataContainers = { vector<int>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100322 break;
Derek Lambertif90c56d2020-01-10 17:14:08 +0000323 case armnn::DataType::QAsymmU8:
SiCong Li39f46392019-06-21 12:00:04 +0100324 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100325 PrepareImageTensor<uint8_t>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100326 inputTensorWidth, inputTensorHeight,
327 normParams,
328 batchSize,
329 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100330 outputDataContainers = { vector<uint8_t>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100331 break;
332 case armnn::DataType::Float32:
333 default:
334 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100335 PrepareImageTensor<float>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100336 inputTensorWidth, inputTensorHeight,
337 normParams,
338 batchSize,
339 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100340 outputDataContainers = { vector<float>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100341 break;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100342 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100343
344 status = runtime->EnqueueWorkload(networkId,
345 armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
346 armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
347
348 if (status == armnn::Status::Failure)
349 {
Derek Lamberti08446972019-11-26 16:38:31 +0000350 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload for image: " << imageName;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100351 }
352
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100353 checker.AddImageResult<TContainer>(imageName, outputDataContainers);
354 }
355 }
356 else
357 {
358 return 1;
359 }
360
361 for(unsigned int i = 1; i <= 5; ++i)
362 {
363 std::cout << "Top " << i << " Accuracy: " << checker.GetAccuracy(i) << "%" << "\n";
364 }
365
Derek Lamberti08446972019-11-26 16:38:31 +0000366 ARMNN_LOG(info) << "Accuracy Tool ran successfully!";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100367 return 0;
368 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000369 catch (const armnn::Exception& e)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100370 {
371 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
372 // exception of type std::length_error.
373 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
374 std::cerr << "Armnn Error: " << e.what() << std::endl;
375 return 1;
376 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000377 catch (const std::exception& e)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100378 {
379 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
380 std::cerr << "WARNING: ModelAccuracyTool-Armnn: An error has occurred when running the "
381 "Accuracy Tool: " << e.what() << std::endl;
382 return 1;
383 }
384}
385
SiCong Li898a3242019-06-24 16:03:33 +0100386map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
387 const string& imageDirectoryPath,
388 size_t begIndex,
389 size_t endIndex,
390 const string& blacklistPath)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100391{
SiCong Li898a3242019-06-24 16:03:33 +0100392 // Populate imageFilenames with names of all .JPEG, .PNG images
393 std::vector<std::string> imageFilenames;
394 for (const auto& imageEntry :
Francis Murtagh532a29d2020-06-29 11:50:01 +0100395 boost::make_iterator_range(fs::directory_iterator(fs::path(imageDirectoryPath))))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100396 {
Francis Murtagh532a29d2020-06-29 11:50:01 +0100397 fs::path imagePath = imageEntry.path();
SiCong Li898a3242019-06-24 16:03:33 +0100398 std::string imageExtension = boost::to_upper_copy<std::string>(imagePath.extension().string());
Francis Murtagh532a29d2020-06-29 11:50:01 +0100399 if (fs::is_regular_file(imagePath) && (imageExtension == ".JPEG" || imageExtension == ".PNG"))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100400 {
SiCong Li898a3242019-06-24 16:03:33 +0100401 imageFilenames.push_back(imagePath.filename().string());
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100402 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100403 }
SiCong Li898a3242019-06-24 16:03:33 +0100404 if (imageFilenames.empty())
405 {
406 throw armnn::Exception("No image file (JPEG, PNG) found at " + imageDirectoryPath);
407 }
408
409 // Sort the image filenames lexicographically
410 std::sort(imageFilenames.begin(), imageFilenames.end());
411
412 std::cout << imageFilenames.size() << " images found at " << imageDirectoryPath << std::endl;
413
414 // Get default end index
415 if (begIndex < 1 || endIndex > imageFilenames.size())
416 {
417 throw armnn::Exception("Invalid image index range");
418 }
419 endIndex = endIndex == 0 ? imageFilenames.size() : endIndex;
420 if (begIndex > endIndex)
421 {
422 throw armnn::Exception("Invalid image index range");
423 }
424
425 // Load blacklist if there is one
426 std::vector<unsigned int> blacklist;
427 if (!blacklistPath.empty())
428 {
429 std::ifstream blacklistFile(blacklistPath);
430 unsigned int index;
431 while (blacklistFile >> index)
432 {
433 blacklist.push_back(index);
434 }
435 }
436
437 // Load ground truth labels and pair them with corresponding image names
438 std::string classification;
439 map<std::string, std::string> imageNameToLabel;
440 ifstream infile(validationLabelPath);
441 size_t imageIndex = begIndex;
442 size_t blacklistIndexCount = 0;
443 while (std::getline(infile, classification))
444 {
445 if (imageIndex > endIndex)
446 {
447 break;
448 }
449 // If current imageIndex is included in blacklist, skip the current image
450 if (blacklistIndexCount < blacklist.size() && imageIndex == blacklist[blacklistIndexCount])
451 {
452 ++imageIndex;
453 ++blacklistIndexCount;
454 continue;
455 }
456 imageNameToLabel.insert(std::pair<std::string, std::string>(imageFilenames[imageIndex - 1], classification));
457 ++imageIndex;
458 }
459 std::cout << blacklistIndexCount << " images blacklisted" << std::endl;
460 std::cout << imageIndex - begIndex - blacklistIndexCount << " images to be loaded" << std::endl;
461 return imageNameToLabel;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100462}
SiCong Li898a3242019-06-24 16:03:33 +0100463
464std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath)
465{
466 std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels;
467 ifstream modelOutputLablesFile(modelOutputLabelsPath);
468 std::string line;
469 while (std::getline(modelOutputLablesFile, line))
470 {
471 armnnUtils::LabelCategoryNames tokens = armnnUtils::SplitBy(line, ":");
472 armnnUtils::LabelCategoryNames predictionCategoryNames = armnnUtils::SplitBy(tokens.back(), ",");
473 std::transform(predictionCategoryNames.begin(), predictionCategoryNames.end(), predictionCategoryNames.begin(),
474 [](const std::string& category) { return armnnUtils::Strip(category); });
475 modelOutputLabels.push_back(predictionCategoryNames);
476 }
477 return modelOutputLabels;
478}