blob: 345a0fed98652a4cbb06b1d51e36da40ef59c525 [file] [log] [blame]
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
SiCong Li39f46392019-06-21 12:00:04 +01006#include "../ImageTensorGenerator/ImageTensorGenerator.hpp"
7#include "../InferenceTest.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01008#include "ModelAccuracyChecker.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01009#include "armnnDeserializer/IDeserializer.hpp"
Francis Murtagh532a29d2020-06-29 11:50:01 +010010#include <Filesystem.hpp>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010011
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010012#include <cxxopts/cxxopts.hpp>
SiCong Li39f46392019-06-21 12:00:04 +010013#include <map>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010014
15using namespace armnn::test;
16
SiCong Li898a3242019-06-24 16:03:33 +010017/** Load image names and ground-truth labels from the image directory and the ground truth label file
18 *
19 * @pre \p validationLabelPath exists and is valid regular file
20 * @pre \p imageDirectoryPath exists and is valid directory
21 * @pre labels in validation file correspond to images which are in lexicographical order with the image name
22 * @pre image index starts at 1
23 * @pre \p begIndex and \p endIndex are end-inclusive
24 *
25 * @param[in] validationLabelPath Path to validation label file
26 * @param[in] imageDirectoryPath Path to directory containing validation images
27 * @param[in] begIndex Begin index of images to be loaded. Inclusive
28 * @param[in] endIndex End index of images to be loaded. Inclusive
29 * @param[in] blacklistPath Path to blacklist file
30 * @return A map mapping image file names to their corresponding ground-truth labels
31 */
32map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
33 const string& imageDirectoryPath,
34 size_t begIndex = 0,
35 size_t endIndex = 0,
36 const string& blacklistPath = "");
37
38/** Load model output labels from file
39 *
40 * @pre \p modelOutputLabelsPath exists and is a regular file
41 *
42 * @param[in] modelOutputLabelsPath path to model output labels file
43 * @return A vector of labels, which in turn is described by a list of category names
44 */
45std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010046
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010047int main(int argc, char* argv[])
48{
49 try
50 {
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010051 armnn::LogSeverity level = armnn::LogSeverity::Debug;
52 armnn::ConfigureLogging(true, true, level);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010053
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010054 std::string modelPath;
SiCong Li39f46392019-06-21 12:00:04 +010055 std::string modelFormat;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010056 std::vector<std::string> inputNames;
57 std::vector<std::string> outputNames;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010058 std::string dataDir;
SiCong Li898a3242019-06-24 16:03:33 +010059 std::string modelOutputLabelsPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010060 std::string validationLabelPath;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010061 std::string inputLayout;
62 std::vector<armnn::BackendId> computeDevice;
SiCong Li898a3242019-06-24 16:03:33 +010063 std::string validationRange;
64 std::string blacklistPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010065
66 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
67 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
68
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010069 try
70 {
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010071 cxxopts::Options options("ModeAccuracyTool-Armnn","Options");
72
73 options.add_options()
74 ("h,help", "Display help messages")
75 ("m,model-path",
76 "Path to armnn format model file",
77 cxxopts::value<std::string>(modelPath))
78 ("f,model-format",
Nikhil Raj6dd178f2021-04-02 22:04:39 +010079 "The model format. Supported values: tensorflow, tflite",
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010080 cxxopts::value<std::string>(modelFormat))
81 ("i,input-name",
82 "Identifier of the input tensors in the network separated by comma with no space.",
83 cxxopts::value<std::vector<std::string>>(inputNames))
84 ("o,output-name",
85 "Identifier of the output tensors in the network separated by comma with no space.",
86 cxxopts::value<std::vector<std::string>>(outputNames))
87 ("d,data-dir",
88 "Path to directory containing the ImageNet test data",
89 cxxopts::value<std::string>(dataDir))
90 ("p,model-output-labels",
91 "Path to model output labels file.",
92 cxxopts::value<std::string>(modelOutputLabelsPath))
93 ("v,validation-labels-path",
94 "Path to ImageNet Validation Label file",
95 cxxopts::value<std::string>(validationLabelPath))
96 ("l,data-layout",
97 "Data layout. Supported value: NHWC, NCHW. Default: NHWC",
98 cxxopts::value<std::string>(inputLayout)->default_value("NHWC"))
99 ("c,compute",
100 backendsMessage.c_str(),
101 cxxopts::value<std::vector<armnn::BackendId>>(computeDevice)->default_value("CpuAcc,CpuRef"))
102 ("r,validation-range",
103 "The range of the images to be evaluated. Specified in the form <begin index>:<end index>."
104 "The index starts at 1 and the range is inclusive."
105 "By default the evaluation will be performed on all images.",
106 cxxopts::value<std::string>(validationRange)->default_value("1:0"))
107 ("b,blacklist-path",
108 "Path to a blacklist file where each line denotes the index of an image to be "
109 "excluded from evaluation.",
110 cxxopts::value<std::string>(blacklistPath)->default_value(""));
111
112 auto result = options.parse(argc, argv);
113
114 if (result.count("help") > 0)
115 {
116 std::cout << options.help() << std::endl;
117 return EXIT_FAILURE;
118 }
119
120 // Check for mandatory single options.
121 std::string mandatorySingleParameters[] = { "model-path", "model-format", "input-name", "output-name",
122 "data-dir", "model-output-labels", "validation-labels-path" };
123 for (auto param : mandatorySingleParameters)
124 {
125 if (result.count(param) != 1)
126 {
127 std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
128 return EXIT_FAILURE;
129 }
130 }
131 }
132 catch (const cxxopts::OptionException& e)
133 {
134 std::cerr << e.what() << std::endl << std::endl;
135 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100136 }
137 catch (const std::exception& e)
138 {
139 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
140 // and that desc.add_options() can throw boost::io::too_few_args.
141 // They really won't in any of these cases.
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100142 ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100143 std::cerr << "Fatal internal error: " << e.what() << std::endl;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100144 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100145 }
146
147 // Check if the requested backend are all valid
148 std::string invalidBackends;
149 if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
150 {
Derek Lamberti08446972019-11-26 16:38:31 +0000151 ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
152 << invalidBackends;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100153 return EXIT_FAILURE;
154 }
155 armnn::Status status;
156
157 // Create runtime
158 armnn::IRuntime::CreationOptions options;
159 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
160 std::ifstream file(modelPath);
161
162 // Create Parser
163 using IParser = armnnDeserializer::IDeserializer;
164 auto armnnparser(IParser::Create());
165
166 // Create a network
167 armnn::INetworkPtr network = armnnparser->CreateNetworkFromBinary(file);
168
169 // Optimizes the network.
170 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
171 try
172 {
173 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
174 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000175 catch (const armnn::Exception& e)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100176 {
177 std::stringstream message;
178 message << "armnn::Exception (" << e.what() << ") caught from optimize.";
Derek Lamberti08446972019-11-26 16:38:31 +0000179 ARMNN_LOG(fatal) << message.str();
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100180 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100181 }
182
183 // Loads the network into the runtime.
184 armnn::NetworkId networkId;
185 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
186 if (status == armnn::Status::Failure)
187 {
Derek Lamberti08446972019-11-26 16:38:31 +0000188 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100189 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100190 }
191
192 // Set up Network
193 using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
194
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100195 // Handle inputNames and outputNames, there can be multiple.
196 std::vector<BindingPointInfo> inputBindings;
197 for(auto& input: inputNames)
198 {
199 const armnnDeserializer::BindingPointInfo&
200 inputBindingInfo = armnnparser->GetNetworkInputBindingInfo(0, input);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100201
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100202 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
203 m_InputBindingInfo(inputBindingInfo.m_BindingId, inputBindingInfo.m_TensorInfo);
204 inputBindings.push_back(m_InputBindingInfo);
205 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100206
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100207 std::vector<BindingPointInfo> outputBindings;
208 for(auto& output: outputNames)
209 {
210 const armnnDeserializer::BindingPointInfo&
211 outputBindingInfo = armnnparser->GetNetworkOutputBindingInfo(0, output);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100212
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100213 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
214 m_OutputBindingInfo(outputBindingInfo.m_BindingId, outputBindingInfo.m_TensorInfo);
215 outputBindings.push_back(m_OutputBindingInfo);
216 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100217
SiCong Li898a3242019-06-24 16:03:33 +0100218 // Load model output labels
Francis Murtagh532a29d2020-06-29 11:50:01 +0100219 if (modelOutputLabelsPath.empty() || !fs::exists(modelOutputLabelsPath) ||
220 !fs::is_regular_file(modelOutputLabelsPath))
SiCong Li898a3242019-06-24 16:03:33 +0100221 {
Derek Lamberti08446972019-11-26 16:38:31 +0000222 ARMNN_LOG(fatal) << "Invalid model output labels path at " << modelOutputLabelsPath;
SiCong Li898a3242019-06-24 16:03:33 +0100223 }
224 const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
225 LoadModelOutputLabels(modelOutputLabelsPath);
226
227 // Parse begin and end image indices
228 std::vector<std::string> imageIndexStrs = armnnUtils::SplitBy(validationRange, ":");
229 size_t imageBegIndex;
230 size_t imageEndIndex;
231 if (imageIndexStrs.size() != 2)
232 {
Derek Lamberti08446972019-11-26 16:38:31 +0000233 ARMNN_LOG(fatal) << "Invalid validation range specification: Invalid format " << validationRange;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100234 return EXIT_FAILURE;
SiCong Li898a3242019-06-24 16:03:33 +0100235 }
236 try
237 {
238 imageBegIndex = std::stoul(imageIndexStrs[0]);
239 imageEndIndex = std::stoul(imageIndexStrs[1]);
240 }
241 catch (const std::exception& e)
242 {
Derek Lamberti08446972019-11-26 16:38:31 +0000243 ARMNN_LOG(fatal) << "Invalid validation range specification: " << validationRange;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100244 return EXIT_FAILURE;
SiCong Li898a3242019-06-24 16:03:33 +0100245 }
246
247 // Validate blacklist file if it's specified
248 if (!blacklistPath.empty() &&
Francis Murtagh532a29d2020-06-29 11:50:01 +0100249 !(fs::exists(blacklistPath) && fs::is_regular_file(blacklistPath)))
SiCong Li898a3242019-06-24 16:03:33 +0100250 {
Derek Lamberti08446972019-11-26 16:38:31 +0000251 ARMNN_LOG(fatal) << "Invalid path to blacklist file at " << blacklistPath;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100252 return EXIT_FAILURE;
SiCong Li898a3242019-06-24 16:03:33 +0100253 }
254
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100255 fs::path pathToDataDir(dataDir);
SiCong Li898a3242019-06-24 16:03:33 +0100256 const map<std::string, std::string> imageNameToLabel = LoadValidationImageFilenamesAndLabels(
257 validationLabelPath, pathToDataDir.string(), imageBegIndex, imageEndIndex, blacklistPath);
258 armnnUtils::ModelAccuracyChecker checker(imageNameToLabel, modelOutputLabels);
James Ward6d9f5c52020-09-28 11:56:35 +0100259 using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<uint8_t>>;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100260
SiCong Li39f46392019-06-21 12:00:04 +0100261 if (ValidateDirectory(dataDir))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100262 {
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100263 InferenceModel<armnnDeserializer::IDeserializer, float>::Params params;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100264
SiCong Li39f46392019-06-21 12:00:04 +0100265 params.m_ModelPath = modelPath;
266 params.m_IsModelBinary = true;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100267 params.m_ComputeDevices = computeDevice;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100268 // Insert inputNames and outputNames into params vector
269 params.m_InputBindings.insert(std::end(params.m_InputBindings),
270 std::begin(inputNames),
271 std::end(inputNames));
272 params.m_OutputBindings.insert(std::end(params.m_OutputBindings),
273 std::begin(outputNames),
274 std::end(outputNames));
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100275
276 using TParser = armnnDeserializer::IDeserializer;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100277 // If dynamicBackends is empty it will be disabled by default.
278 InferenceModel<TParser, float> model(params, false, "");
279
SiCong Li39f46392019-06-21 12:00:04 +0100280 // Get input tensor information
281 const armnn::TensorInfo& inputTensorInfo = model.GetInputBindingInfo().second;
282 const armnn::TensorShape& inputTensorShape = inputTensorInfo.GetShape();
283 const armnn::DataType& inputTensorDataType = inputTensorInfo.GetDataType();
284 armnn::DataLayout inputTensorDataLayout;
285 if (inputLayout == "NCHW")
286 {
287 inputTensorDataLayout = armnn::DataLayout::NCHW;
288 }
289 else if (inputLayout == "NHWC")
290 {
291 inputTensorDataLayout = armnn::DataLayout::NHWC;
292 }
293 else
294 {
Derek Lamberti08446972019-11-26 16:38:31 +0000295 ARMNN_LOG(fatal) << "Invalid Data layout: " << inputLayout;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100296 return EXIT_FAILURE;
SiCong Li39f46392019-06-21 12:00:04 +0100297 }
298 const unsigned int inputTensorWidth =
299 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[3] : inputTensorShape[2];
300 const unsigned int inputTensorHeight =
301 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[2] : inputTensorShape[1];
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100302 // Get output tensor info
303 const unsigned int outputNumElements = model.GetOutputSize();
SiCong Li898a3242019-06-24 16:03:33 +0100304 // Check output tensor shape is valid
305 if (modelOutputLabels.size() != outputNumElements)
306 {
Derek Lamberti08446972019-11-26 16:38:31 +0000307 ARMNN_LOG(fatal) << "Number of output elements: " << outputNumElements
SiCong Li898a3242019-06-24 16:03:33 +0100308 << " , mismatches the number of output labels: " << modelOutputLabels.size();
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100309 return EXIT_FAILURE;
SiCong Li898a3242019-06-24 16:03:33 +0100310 }
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100311
SiCong Li39f46392019-06-21 12:00:04 +0100312 const unsigned int batchSize = 1;
313 // Get normalisation parameters
314 SupportedFrontend modelFrontend;
Nikhil Raj6dd178f2021-04-02 22:04:39 +0100315 if (modelFormat == "tensorflow")
SiCong Li39f46392019-06-21 12:00:04 +0100316 {
317 modelFrontend = SupportedFrontend::TensorFlow;
318 }
319 else if (modelFormat == "tflite")
320 {
321 modelFrontend = SupportedFrontend::TFLite;
322 }
323 else
324 {
Derek Lamberti08446972019-11-26 16:38:31 +0000325 ARMNN_LOG(fatal) << "Unsupported frontend: " << modelFormat;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100326 return EXIT_FAILURE;
SiCong Li39f46392019-06-21 12:00:04 +0100327 }
328 const NormalizationParameters& normParams = GetNormalizationParameters(modelFrontend, inputTensorDataType);
SiCong Li898a3242019-06-24 16:03:33 +0100329 for (const auto& imageEntry : imageNameToLabel)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100330 {
SiCong Li898a3242019-06-24 16:03:33 +0100331 const std::string imageName = imageEntry.first;
332 std::cout << "Processing image: " << imageName << "\n";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100333
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100334 vector<TContainer> inputDataContainers;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100335 vector<TContainer> outputDataContainers;
336
Francis Murtagh532a29d2020-06-29 11:50:01 +0100337 auto imagePath = pathToDataDir / fs::path(imageName);
SiCong Li39f46392019-06-21 12:00:04 +0100338 switch (inputTensorDataType)
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100339 {
SiCong Li39f46392019-06-21 12:00:04 +0100340 case armnn::DataType::Signed32:
341 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100342 PrepareImageTensor<int>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100343 inputTensorWidth, inputTensorHeight,
344 normParams,
345 batchSize,
346 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100347 outputDataContainers = { vector<int>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100348 break;
Derek Lambertif90c56d2020-01-10 17:14:08 +0000349 case armnn::DataType::QAsymmU8:
SiCong Li39f46392019-06-21 12:00:04 +0100350 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100351 PrepareImageTensor<uint8_t>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100352 inputTensorWidth, inputTensorHeight,
353 normParams,
354 batchSize,
355 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100356 outputDataContainers = { vector<uint8_t>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100357 break;
358 case armnn::DataType::Float32:
359 default:
360 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100361 PrepareImageTensor<float>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100362 inputTensorWidth, inputTensorHeight,
363 normParams,
364 batchSize,
365 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100366 outputDataContainers = { vector<float>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100367 break;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100368 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100369
370 status = runtime->EnqueueWorkload(networkId,
371 armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
372 armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
373
374 if (status == armnn::Status::Failure)
375 {
Derek Lamberti08446972019-11-26 16:38:31 +0000376 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload for image: " << imageName;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100377 }
378
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100379 checker.AddImageResult<TContainer>(imageName, outputDataContainers);
380 }
381 }
382 else
383 {
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100384 return EXIT_SUCCESS;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100385 }
386
387 for(unsigned int i = 1; i <= 5; ++i)
388 {
389 std::cout << "Top " << i << " Accuracy: " << checker.GetAccuracy(i) << "%" << "\n";
390 }
391
Derek Lamberti08446972019-11-26 16:38:31 +0000392 ARMNN_LOG(info) << "Accuracy Tool ran successfully!";
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100393 return EXIT_SUCCESS;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100394 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000395 catch (const armnn::Exception& e)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100396 {
397 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
398 // exception of type std::length_error.
399 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
400 std::cerr << "Armnn Error: " << e.what() << std::endl;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100401 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100402 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000403 catch (const std::exception& e)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100404 {
405 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
406 std::cerr << "WARNING: ModelAccuracyTool-Armnn: An error has occurred when running the "
407 "Accuracy Tool: " << e.what() << std::endl;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100408 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100409 }
410}
411
SiCong Li898a3242019-06-24 16:03:33 +0100412map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
413 const string& imageDirectoryPath,
414 size_t begIndex,
415 size_t endIndex,
416 const string& blacklistPath)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100417{
SiCong Li898a3242019-06-24 16:03:33 +0100418 // Populate imageFilenames with names of all .JPEG, .PNG images
419 std::vector<std::string> imageFilenames;
Matthew Sloyan2b428032020-10-06 10:45:32 +0100420 for (const auto& imageEntry : fs::directory_iterator(fs::path(imageDirectoryPath)))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100421 {
Francis Murtagh532a29d2020-06-29 11:50:01 +0100422 fs::path imagePath = imageEntry.path();
Matthew Sloyan2b428032020-10-06 10:45:32 +0100423
424 // Get extension and convert to uppercase
425 std::string imageExtension = imagePath.extension().string();
426 std::transform(imageExtension.begin(), imageExtension.end(), imageExtension.begin(), ::toupper);
427
Francis Murtagh532a29d2020-06-29 11:50:01 +0100428 if (fs::is_regular_file(imagePath) && (imageExtension == ".JPEG" || imageExtension == ".PNG"))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100429 {
SiCong Li898a3242019-06-24 16:03:33 +0100430 imageFilenames.push_back(imagePath.filename().string());
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100431 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100432 }
SiCong Li898a3242019-06-24 16:03:33 +0100433 if (imageFilenames.empty())
434 {
435 throw armnn::Exception("No image file (JPEG, PNG) found at " + imageDirectoryPath);
436 }
437
438 // Sort the image filenames lexicographically
439 std::sort(imageFilenames.begin(), imageFilenames.end());
440
441 std::cout << imageFilenames.size() << " images found at " << imageDirectoryPath << std::endl;
442
443 // Get default end index
444 if (begIndex < 1 || endIndex > imageFilenames.size())
445 {
446 throw armnn::Exception("Invalid image index range");
447 }
448 endIndex = endIndex == 0 ? imageFilenames.size() : endIndex;
449 if (begIndex > endIndex)
450 {
451 throw armnn::Exception("Invalid image index range");
452 }
453
454 // Load blacklist if there is one
455 std::vector<unsigned int> blacklist;
456 if (!blacklistPath.empty())
457 {
458 std::ifstream blacklistFile(blacklistPath);
459 unsigned int index;
460 while (blacklistFile >> index)
461 {
462 blacklist.push_back(index);
463 }
464 }
465
466 // Load ground truth labels and pair them with corresponding image names
467 std::string classification;
468 map<std::string, std::string> imageNameToLabel;
469 ifstream infile(validationLabelPath);
470 size_t imageIndex = begIndex;
471 size_t blacklistIndexCount = 0;
472 while (std::getline(infile, classification))
473 {
474 if (imageIndex > endIndex)
475 {
476 break;
477 }
478 // If current imageIndex is included in blacklist, skip the current image
479 if (blacklistIndexCount < blacklist.size() && imageIndex == blacklist[blacklistIndexCount])
480 {
481 ++imageIndex;
482 ++blacklistIndexCount;
483 continue;
484 }
485 imageNameToLabel.insert(std::pair<std::string, std::string>(imageFilenames[imageIndex - 1], classification));
486 ++imageIndex;
487 }
488 std::cout << blacklistIndexCount << " images blacklisted" << std::endl;
489 std::cout << imageIndex - begIndex - blacklistIndexCount << " images to be loaded" << std::endl;
490 return imageNameToLabel;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100491}
SiCong Li898a3242019-06-24 16:03:33 +0100492
493std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath)
494{
495 std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels;
496 ifstream modelOutputLablesFile(modelOutputLabelsPath);
497 std::string line;
498 while (std::getline(modelOutputLablesFile, line))
499 {
500 armnnUtils::LabelCategoryNames tokens = armnnUtils::SplitBy(line, ":");
501 armnnUtils::LabelCategoryNames predictionCategoryNames = armnnUtils::SplitBy(tokens.back(), ",");
502 std::transform(predictionCategoryNames.begin(), predictionCategoryNames.end(), predictionCategoryNames.begin(),
503 [](const std::string& category) { return armnnUtils::Strip(category); });
504 modelOutputLabels.push_back(predictionCategoryNames);
505 }
506 return modelOutputLabels;
507}