blob: 49efbbf9289943a5dab7e0a932d44fb3ab81be9c [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
SiCong Li39f46392019-06-21 12:00:04 +01006#include "../ImageTensorGenerator/ImageTensorGenerator.hpp"
7#include "../InferenceTest.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01008#include "ModelAccuracyChecker.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01009#include "armnnDeserializer/IDeserializer.hpp"
David Monahan6bb47a72021-10-22 12:57:28 +010010
Rob Hughes9542f902021-07-14 09:48:54 +010011#include <armnnUtils/Filesystem.hpp>
Francis Murtagh40d27412021-10-28 11:11:35 +010012#include <armnnUtils/TContainer.hpp>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010013
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010014#include <cxxopts/cxxopts.hpp>
SiCong Li39f46392019-06-21 12:00:04 +010015#include <map>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010016
17using namespace armnn::test;
18
SiCong Li898a3242019-06-24 16:03:33 +010019/** Load image names and ground-truth labels from the image directory and the ground truth label file
20 *
21 * @pre \p validationLabelPath exists and is valid regular file
22 * @pre \p imageDirectoryPath exists and is valid directory
23 * @pre labels in validation file correspond to images which are in lexicographical order with the image name
24 * @pre image index starts at 1
25 * @pre \p begIndex and \p endIndex are end-inclusive
26 *
27 * @param[in] validationLabelPath Path to validation label file
28 * @param[in] imageDirectoryPath Path to directory containing validation images
29 * @param[in] begIndex Begin index of images to be loaded. Inclusive
30 * @param[in] endIndex End index of images to be loaded. Inclusive
Teresa Charlin2b30f162021-11-17 11:46:25 +000031 * @param[in] excludelistPath Path to excludelist file
SiCong Li898a3242019-06-24 16:03:33 +010032 * @return A map mapping image file names to their corresponding ground-truth labels
33 */
34map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
35 const string& imageDirectoryPath,
36 size_t begIndex = 0,
37 size_t endIndex = 0,
Teresa Charlin2b30f162021-11-17 11:46:25 +000038 const string& excludelistPath = "");
SiCong Li898a3242019-06-24 16:03:33 +010039
40/** Load model output labels from file
Teresa Charlin2b30f162021-11-17 11:46:25 +000041 *
SiCong Li898a3242019-06-24 16:03:33 +010042 * @pre \p modelOutputLabelsPath exists and is a regular file
43 *
44 * @param[in] modelOutputLabelsPath path to model output labels file
45 * @return A vector of labels, which in turn is described by a list of category names
46 */
47std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010048
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010049int main(int argc, char* argv[])
50{
51 try
52 {
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010053 armnn::LogSeverity level = armnn::LogSeverity::Debug;
54 armnn::ConfigureLogging(true, true, level);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010055
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010056 std::string modelPath;
SiCong Li39f46392019-06-21 12:00:04 +010057 std::string modelFormat;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010058 std::vector<std::string> inputNames;
59 std::vector<std::string> outputNames;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010060 std::string dataDir;
SiCong Li898a3242019-06-24 16:03:33 +010061 std::string modelOutputLabelsPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010062 std::string validationLabelPath;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010063 std::string inputLayout;
64 std::vector<armnn::BackendId> computeDevice;
SiCong Li898a3242019-06-24 16:03:33 +010065 std::string validationRange;
Teresa Charlin2b30f162021-11-17 11:46:25 +000066 std::string excludelistPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010067
68 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
69 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
70
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010071 try
72 {
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010073 cxxopts::Options options("ModeAccuracyTool-Armnn","Options");
74
75 options.add_options()
76 ("h,help", "Display help messages")
77 ("m,model-path",
78 "Path to armnn format model file",
79 cxxopts::value<std::string>(modelPath))
80 ("f,model-format",
Nikhil Raj5d955cf2021-04-19 16:59:48 +010081 "The model format. Supported values: tflite",
Matthew Sloyane7ba17e2020-10-06 10:03:21 +010082 cxxopts::value<std::string>(modelFormat))
83 ("i,input-name",
84 "Identifier of the input tensors in the network separated by comma with no space.",
85 cxxopts::value<std::vector<std::string>>(inputNames))
86 ("o,output-name",
87 "Identifier of the output tensors in the network separated by comma with no space.",
88 cxxopts::value<std::vector<std::string>>(outputNames))
89 ("d,data-dir",
90 "Path to directory containing the ImageNet test data",
91 cxxopts::value<std::string>(dataDir))
92 ("p,model-output-labels",
93 "Path to model output labels file.",
94 cxxopts::value<std::string>(modelOutputLabelsPath))
95 ("v,validation-labels-path",
96 "Path to ImageNet Validation Label file",
97 cxxopts::value<std::string>(validationLabelPath))
98 ("l,data-layout",
99 "Data layout. Supported value: NHWC, NCHW. Default: NHWC",
100 cxxopts::value<std::string>(inputLayout)->default_value("NHWC"))
101 ("c,compute",
102 backendsMessage.c_str(),
103 cxxopts::value<std::vector<armnn::BackendId>>(computeDevice)->default_value("CpuAcc,CpuRef"))
104 ("r,validation-range",
105 "The range of the images to be evaluated. Specified in the form <begin index>:<end index>."
106 "The index starts at 1 and the range is inclusive."
107 "By default the evaluation will be performed on all images.",
108 cxxopts::value<std::string>(validationRange)->default_value("1:0"))
Teresa Charlin2b30f162021-11-17 11:46:25 +0000109 ("e,excludelist-path",
110 "Path to a excludelist file where each line denotes the index of an image to be "
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100111 "excluded from evaluation.",
Teresa Charlin2b30f162021-11-17 11:46:25 +0000112 cxxopts::value<std::string>(excludelistPath)->default_value(""));
113 ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This b,blacklist-path command is deprecated", "22.08")
114 ("b,blacklist-path",
115 "Path to a blacklist file where each line denotes the index of an image to be "
116 "excluded from evaluation. This command will be deprecated in favor of: --excludelist-path ",
117 cxxopts::value<std::string>(excludelistPath)->default_value(""));
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100118
119 auto result = options.parse(argc, argv);
120
121 if (result.count("help") > 0)
122 {
123 std::cout << options.help() << std::endl;
124 return EXIT_FAILURE;
125 }
126
127 // Check for mandatory single options.
128 std::string mandatorySingleParameters[] = { "model-path", "model-format", "input-name", "output-name",
129 "data-dir", "model-output-labels", "validation-labels-path" };
130 for (auto param : mandatorySingleParameters)
131 {
132 if (result.count(param) != 1)
133 {
134 std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
135 return EXIT_FAILURE;
136 }
137 }
138 }
139 catch (const cxxopts::OptionException& e)
140 {
141 std::cerr << e.what() << std::endl << std::endl;
142 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100143 }
144 catch (const std::exception& e)
145 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100146 ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100147 std::cerr << "Fatal internal error: " << e.what() << std::endl;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100148 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100149 }
150
151 // Check if the requested backend are all valid
152 std::string invalidBackends;
153 if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
154 {
Derek Lamberti08446972019-11-26 16:38:31 +0000155 ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
156 << invalidBackends;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100157 return EXIT_FAILURE;
158 }
159 armnn::Status status;
160
161 // Create runtime
162 armnn::IRuntime::CreationOptions options;
163 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
164 std::ifstream file(modelPath);
165
166 // Create Parser
167 using IParser = armnnDeserializer::IDeserializer;
168 auto armnnparser(IParser::Create());
169
170 // Create a network
171 armnn::INetworkPtr network = armnnparser->CreateNetworkFromBinary(file);
172
173 // Optimizes the network.
174 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
175 try
176 {
177 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
178 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000179 catch (const armnn::Exception& e)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100180 {
181 std::stringstream message;
182 message << "armnn::Exception (" << e.what() << ") caught from optimize.";
Derek Lamberti08446972019-11-26 16:38:31 +0000183 ARMNN_LOG(fatal) << message.str();
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100184 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100185 }
186
187 // Loads the network into the runtime.
188 armnn::NetworkId networkId;
189 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
190 if (status == armnn::Status::Failure)
191 {
Derek Lamberti08446972019-11-26 16:38:31 +0000192 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100193 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100194 }
195
196 // Set up Network
197 using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
198
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100199 // Handle inputNames and outputNames, there can be multiple.
200 std::vector<BindingPointInfo> inputBindings;
201 for(auto& input: inputNames)
202 {
203 const armnnDeserializer::BindingPointInfo&
204 inputBindingInfo = armnnparser->GetNetworkInputBindingInfo(0, input);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100205
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100206 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
207 m_InputBindingInfo(inputBindingInfo.m_BindingId, inputBindingInfo.m_TensorInfo);
208 inputBindings.push_back(m_InputBindingInfo);
209 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100210
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100211 std::vector<BindingPointInfo> outputBindings;
212 for(auto& output: outputNames)
213 {
214 const armnnDeserializer::BindingPointInfo&
215 outputBindingInfo = armnnparser->GetNetworkOutputBindingInfo(0, output);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100216
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100217 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
218 m_OutputBindingInfo(outputBindingInfo.m_BindingId, outputBindingInfo.m_TensorInfo);
219 outputBindings.push_back(m_OutputBindingInfo);
220 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100221
SiCong Li898a3242019-06-24 16:03:33 +0100222 // Load model output labels
Francis Murtagh532a29d2020-06-29 11:50:01 +0100223 if (modelOutputLabelsPath.empty() || !fs::exists(modelOutputLabelsPath) ||
224 !fs::is_regular_file(modelOutputLabelsPath))
SiCong Li898a3242019-06-24 16:03:33 +0100225 {
Derek Lamberti08446972019-11-26 16:38:31 +0000226 ARMNN_LOG(fatal) << "Invalid model output labels path at " << modelOutputLabelsPath;
SiCong Li898a3242019-06-24 16:03:33 +0100227 }
228 const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
229 LoadModelOutputLabels(modelOutputLabelsPath);
230
231 // Parse begin and end image indices
232 std::vector<std::string> imageIndexStrs = armnnUtils::SplitBy(validationRange, ":");
233 size_t imageBegIndex;
234 size_t imageEndIndex;
235 if (imageIndexStrs.size() != 2)
236 {
Derek Lamberti08446972019-11-26 16:38:31 +0000237 ARMNN_LOG(fatal) << "Invalid validation range specification: Invalid format " << validationRange;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100238 return EXIT_FAILURE;
SiCong Li898a3242019-06-24 16:03:33 +0100239 }
240 try
241 {
242 imageBegIndex = std::stoul(imageIndexStrs[0]);
243 imageEndIndex = std::stoul(imageIndexStrs[1]);
244 }
245 catch (const std::exception& e)
246 {
Derek Lamberti08446972019-11-26 16:38:31 +0000247 ARMNN_LOG(fatal) << "Invalid validation range specification: " << validationRange;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100248 return EXIT_FAILURE;
SiCong Li898a3242019-06-24 16:03:33 +0100249 }
250
Teresa Charlin2b30f162021-11-17 11:46:25 +0000251 // Validate excludelist file if it's specified
252 if (!excludelistPath.empty() &&
253 !(fs::exists(excludelistPath) && fs::is_regular_file(excludelistPath)))
SiCong Li898a3242019-06-24 16:03:33 +0100254 {
Teresa Charlin2b30f162021-11-17 11:46:25 +0000255 ARMNN_LOG(fatal) << "Invalid path to excludelist file at " << excludelistPath;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100256 return EXIT_FAILURE;
SiCong Li898a3242019-06-24 16:03:33 +0100257 }
258
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100259 fs::path pathToDataDir(dataDir);
SiCong Li898a3242019-06-24 16:03:33 +0100260 const map<std::string, std::string> imageNameToLabel = LoadValidationImageFilenamesAndLabels(
Teresa Charlin2b30f162021-11-17 11:46:25 +0000261 validationLabelPath, pathToDataDir.string(), imageBegIndex, imageEndIndex, excludelistPath);
SiCong Li898a3242019-06-24 16:03:33 +0100262 armnnUtils::ModelAccuracyChecker checker(imageNameToLabel, modelOutputLabels);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100263
SiCong Li39f46392019-06-21 12:00:04 +0100264 if (ValidateDirectory(dataDir))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100265 {
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100266 InferenceModel<armnnDeserializer::IDeserializer, float>::Params params;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100267
SiCong Li39f46392019-06-21 12:00:04 +0100268 params.m_ModelPath = modelPath;
269 params.m_IsModelBinary = true;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100270 params.m_ComputeDevices = computeDevice;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100271 // Insert inputNames and outputNames into params vector
272 params.m_InputBindings.insert(std::end(params.m_InputBindings),
273 std::begin(inputNames),
274 std::end(inputNames));
275 params.m_OutputBindings.insert(std::end(params.m_OutputBindings),
276 std::begin(outputNames),
277 std::end(outputNames));
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100278
279 using TParser = armnnDeserializer::IDeserializer;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100280 // If dynamicBackends is empty it will be disabled by default.
281 InferenceModel<TParser, float> model(params, false, "");
282
SiCong Li39f46392019-06-21 12:00:04 +0100283 // Get input tensor information
284 const armnn::TensorInfo& inputTensorInfo = model.GetInputBindingInfo().second;
285 const armnn::TensorShape& inputTensorShape = inputTensorInfo.GetShape();
286 const armnn::DataType& inputTensorDataType = inputTensorInfo.GetDataType();
287 armnn::DataLayout inputTensorDataLayout;
288 if (inputLayout == "NCHW")
289 {
290 inputTensorDataLayout = armnn::DataLayout::NCHW;
291 }
292 else if (inputLayout == "NHWC")
293 {
294 inputTensorDataLayout = armnn::DataLayout::NHWC;
295 }
296 else
297 {
Derek Lamberti08446972019-11-26 16:38:31 +0000298 ARMNN_LOG(fatal) << "Invalid Data layout: " << inputLayout;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100299 return EXIT_FAILURE;
SiCong Li39f46392019-06-21 12:00:04 +0100300 }
301 const unsigned int inputTensorWidth =
302 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[3] : inputTensorShape[2];
303 const unsigned int inputTensorHeight =
304 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[2] : inputTensorShape[1];
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100305 // Get output tensor info
306 const unsigned int outputNumElements = model.GetOutputSize();
SiCong Li898a3242019-06-24 16:03:33 +0100307 // Check output tensor shape is valid
308 if (modelOutputLabels.size() != outputNumElements)
309 {
Derek Lamberti08446972019-11-26 16:38:31 +0000310 ARMNN_LOG(fatal) << "Number of output elements: " << outputNumElements
SiCong Li898a3242019-06-24 16:03:33 +0100311 << " , mismatches the number of output labels: " << modelOutputLabels.size();
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100312 return EXIT_FAILURE;
SiCong Li898a3242019-06-24 16:03:33 +0100313 }
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100314
SiCong Li39f46392019-06-21 12:00:04 +0100315 const unsigned int batchSize = 1;
316 // Get normalisation parameters
317 SupportedFrontend modelFrontend;
Nikhil Raj5d955cf2021-04-19 16:59:48 +0100318 if (modelFormat == "tflite")
SiCong Li39f46392019-06-21 12:00:04 +0100319 {
320 modelFrontend = SupportedFrontend::TFLite;
321 }
322 else
323 {
Derek Lamberti08446972019-11-26 16:38:31 +0000324 ARMNN_LOG(fatal) << "Unsupported frontend: " << modelFormat;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100325 return EXIT_FAILURE;
SiCong Li39f46392019-06-21 12:00:04 +0100326 }
327 const NormalizationParameters& normParams = GetNormalizationParameters(modelFrontend, inputTensorDataType);
SiCong Li898a3242019-06-24 16:03:33 +0100328 for (const auto& imageEntry : imageNameToLabel)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100329 {
SiCong Li898a3242019-06-24 16:03:33 +0100330 const std::string imageName = imageEntry.first;
331 std::cout << "Processing image: " << imageName << "\n";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100332
Francis Murtagh40d27412021-10-28 11:11:35 +0100333 vector<armnnUtils::TContainer> inputDataContainers;
334 vector<armnnUtils::TContainer> outputDataContainers;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100335
Francis Murtagh532a29d2020-06-29 11:50:01 +0100336 auto imagePath = pathToDataDir / fs::path(imageName);
SiCong Li39f46392019-06-21 12:00:04 +0100337 switch (inputTensorDataType)
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100338 {
SiCong Li39f46392019-06-21 12:00:04 +0100339 case armnn::DataType::Signed32:
340 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100341 PrepareImageTensor<int>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100342 inputTensorWidth, inputTensorHeight,
343 normParams,
344 batchSize,
345 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100346 outputDataContainers = { vector<int>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100347 break;
Derek Lambertif90c56d2020-01-10 17:14:08 +0000348 case armnn::DataType::QAsymmU8:
SiCong Li39f46392019-06-21 12:00:04 +0100349 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100350 PrepareImageTensor<uint8_t>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100351 inputTensorWidth, inputTensorHeight,
352 normParams,
353 batchSize,
354 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100355 outputDataContainers = { vector<uint8_t>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100356 break;
357 case armnn::DataType::Float32:
358 default:
359 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100360 PrepareImageTensor<float>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100361 inputTensorWidth, inputTensorHeight,
362 normParams,
363 batchSize,
364 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100365 outputDataContainers = { vector<float>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100366 break;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100367 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100368
369 status = runtime->EnqueueWorkload(networkId,
370 armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
371 armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
372
373 if (status == armnn::Status::Failure)
374 {
Derek Lamberti08446972019-11-26 16:38:31 +0000375 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload for image: " << imageName;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100376 }
377
Francis Murtagh40d27412021-10-28 11:11:35 +0100378 checker.AddImageResult<armnnUtils::TContainer>(imageName, outputDataContainers);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100379 }
380 }
381 else
382 {
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100383 return EXIT_SUCCESS;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100384 }
385
386 for(unsigned int i = 1; i <= 5; ++i)
387 {
388 std::cout << "Top " << i << " Accuracy: " << checker.GetAccuracy(i) << "%" << "\n";
389 }
390
Derek Lamberti08446972019-11-26 16:38:31 +0000391 ARMNN_LOG(info) << "Accuracy Tool ran successfully!";
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100392 return EXIT_SUCCESS;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100393 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000394 catch (const armnn::Exception& e)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100395 {
396 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
397 // exception of type std::length_error.
398 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
399 std::cerr << "Armnn Error: " << e.what() << std::endl;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100400 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100401 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000402 catch (const std::exception& e)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100403 {
404 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
405 std::cerr << "WARNING: ModelAccuracyTool-Armnn: An error has occurred when running the "
406 "Accuracy Tool: " << e.what() << std::endl;
Matthew Sloyane7ba17e2020-10-06 10:03:21 +0100407 return EXIT_FAILURE;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100408 }
409}
410
SiCong Li898a3242019-06-24 16:03:33 +0100411map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
412 const string& imageDirectoryPath,
413 size_t begIndex,
414 size_t endIndex,
Teresa Charlin2b30f162021-11-17 11:46:25 +0000415 const string& excludelistPath)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100416{
SiCong Li898a3242019-06-24 16:03:33 +0100417 // Populate imageFilenames with names of all .JPEG, .PNG images
418 std::vector<std::string> imageFilenames;
Matthew Sloyan2b428032020-10-06 10:45:32 +0100419 for (const auto& imageEntry : fs::directory_iterator(fs::path(imageDirectoryPath)))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100420 {
Francis Murtagh532a29d2020-06-29 11:50:01 +0100421 fs::path imagePath = imageEntry.path();
Matthew Sloyan2b428032020-10-06 10:45:32 +0100422
423 // Get extension and convert to uppercase
424 std::string imageExtension = imagePath.extension().string();
425 std::transform(imageExtension.begin(), imageExtension.end(), imageExtension.begin(), ::toupper);
426
Francis Murtagh532a29d2020-06-29 11:50:01 +0100427 if (fs::is_regular_file(imagePath) && (imageExtension == ".JPEG" || imageExtension == ".PNG"))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100428 {
SiCong Li898a3242019-06-24 16:03:33 +0100429 imageFilenames.push_back(imagePath.filename().string());
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100430 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100431 }
SiCong Li898a3242019-06-24 16:03:33 +0100432 if (imageFilenames.empty())
433 {
434 throw armnn::Exception("No image file (JPEG, PNG) found at " + imageDirectoryPath);
435 }
436
437 // Sort the image filenames lexicographically
438 std::sort(imageFilenames.begin(), imageFilenames.end());
439
440 std::cout << imageFilenames.size() << " images found at " << imageDirectoryPath << std::endl;
441
442 // Get default end index
443 if (begIndex < 1 || endIndex > imageFilenames.size())
444 {
445 throw armnn::Exception("Invalid image index range");
446 }
447 endIndex = endIndex == 0 ? imageFilenames.size() : endIndex;
448 if (begIndex > endIndex)
449 {
450 throw armnn::Exception("Invalid image index range");
451 }
452
Teresa Charlin2b30f162021-11-17 11:46:25 +0000453 // Load excludelist if there is one
454 std::vector<unsigned int> excludelist;
455 if (!excludelistPath.empty())
SiCong Li898a3242019-06-24 16:03:33 +0100456 {
Teresa Charlin2b30f162021-11-17 11:46:25 +0000457 std::ifstream excludelistFile(excludelistPath);
SiCong Li898a3242019-06-24 16:03:33 +0100458 unsigned int index;
Teresa Charlin2b30f162021-11-17 11:46:25 +0000459 while (excludelistFile >> index)
SiCong Li898a3242019-06-24 16:03:33 +0100460 {
Teresa Charlin2b30f162021-11-17 11:46:25 +0000461 excludelist.push_back(index);
SiCong Li898a3242019-06-24 16:03:33 +0100462 }
463 }
464
465 // Load ground truth labels and pair them with corresponding image names
466 std::string classification;
467 map<std::string, std::string> imageNameToLabel;
468 ifstream infile(validationLabelPath);
469 size_t imageIndex = begIndex;
Teresa Charlin2b30f162021-11-17 11:46:25 +0000470 size_t excludelistIndexCount = 0;
SiCong Li898a3242019-06-24 16:03:33 +0100471 while (std::getline(infile, classification))
472 {
473 if (imageIndex > endIndex)
474 {
475 break;
476 }
Teresa Charlin2b30f162021-11-17 11:46:25 +0000477 // If current imageIndex is included in excludelist, skip the current image
478 if (excludelistIndexCount < excludelist.size() && imageIndex == excludelist[excludelistIndexCount])
SiCong Li898a3242019-06-24 16:03:33 +0100479 {
480 ++imageIndex;
Teresa Charlin2b30f162021-11-17 11:46:25 +0000481 ++excludelistIndexCount;
SiCong Li898a3242019-06-24 16:03:33 +0100482 continue;
483 }
484 imageNameToLabel.insert(std::pair<std::string, std::string>(imageFilenames[imageIndex - 1], classification));
485 ++imageIndex;
486 }
Teresa Charlin2b30f162021-11-17 11:46:25 +0000487 std::cout << excludelistIndexCount << " images in excludelist" << std::endl;
488 std::cout << imageIndex - begIndex - excludelistIndexCount << " images to be loaded" << std::endl;
SiCong Li898a3242019-06-24 16:03:33 +0100489 return imageNameToLabel;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100490}
SiCong Li898a3242019-06-24 16:03:33 +0100491
492std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath)
493{
494 std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels;
495 ifstream modelOutputLablesFile(modelOutputLabelsPath);
496 std::string line;
497 while (std::getline(modelOutputLablesFile, line))
498 {
499 armnnUtils::LabelCategoryNames tokens = armnnUtils::SplitBy(line, ":");
500 armnnUtils::LabelCategoryNames predictionCategoryNames = armnnUtils::SplitBy(tokens.back(), ",");
501 std::transform(predictionCategoryNames.begin(), predictionCategoryNames.end(), predictionCategoryNames.begin(),
502 [](const std::string& category) { return armnnUtils::Strip(category); });
503 modelOutputLabels.push_back(predictionCategoryNames);
504 }
505 return modelOutputLabels;
506}