blob: ee8e8e4d35f17af2b5860f9ab0fa8dc1a3fec45a [file] [log] [blame]
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
SiCong Li39f46392019-06-21 12:00:04 +01006#include "../ImageTensorGenerator/ImageTensorGenerator.hpp"
7#include "../InferenceTest.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01008#include "ModelAccuracyChecker.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01009#include "armnnDeserializer/IDeserializer.hpp"
10
SiCong Li898a3242019-06-24 16:03:33 +010011#include <boost/algorithm/string.hpp>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010012#include <boost/filesystem.hpp>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010013#include <boost/program_options/variables_map.hpp>
SiCong Li39f46392019-06-21 12:00:04 +010014#include <boost/range/iterator_range.hpp>
SiCong Li39f46392019-06-21 12:00:04 +010015#include <map>
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010016
17using namespace armnn::test;
18
/**
 * Pair validation images found in a directory with their ground-truth labels.
 *
 * @pre @p validationLabelPath names an existing regular file.
 * @pre @p imageDirectoryPath names an existing directory.
 * @pre The i-th line of the validation label file holds the label of the i-th image,
 *      where images are ordered lexicographically by file name.
 * @pre Image indices are 1-based; @p begIndex and @p endIndex are both inclusive.
 *
 * @param[in] validationLabelPath Path to the ground-truth label file.
 * @param[in] imageDirectoryPath  Path to the directory holding the validation images.
 * @param[in] begIndex            First image index to load (inclusive); must be at least 1.
 * @param[in] endIndex            Last image index to load (inclusive); 0 means "up to the last image".
 * @param[in] blacklistPath       Optional path to a file of image indices to exclude; empty for none.
 * @return Map from image file name to its ground-truth label.
 * @throws armnn::Exception if no images are found or the index range is invalid.
 */
auto LoadValidationImageFilenamesAndLabels(const std::string& validationLabelPath,
                                           const std::string& imageDirectoryPath,
                                           size_t begIndex = 0,
                                           size_t endIndex = 0,
                                           const std::string& blacklistPath = "")
    -> std::map<std::string, std::string>;
39
40/** Load model output labels from file
41 *
42 * @pre \p modelOutputLabelsPath exists and is a regular file
43 *
44 * @param[in] modelOutputLabelsPath path to model output labels file
45 * @return A vector of labels, which in turn is described by a list of category names
46 */
47std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010048
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010049int main(int argc, char* argv[])
50{
51 try
52 {
53 using namespace boost::filesystem;
54 armnn::LogSeverity level = armnn::LogSeverity::Debug;
55 armnn::ConfigureLogging(true, true, level);
56 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
57
58 // Set-up program Options
59 namespace po = boost::program_options;
60
61 std::vector<armnn::BackendId> computeDevice;
62 std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
63 std::string modelPath;
SiCong Li39f46392019-06-21 12:00:04 +010064 std::string modelFormat;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010065 std::string dataDir;
66 std::string inputName;
SiCong Li39f46392019-06-21 12:00:04 +010067 std::string inputLayout;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010068 std::string outputName;
SiCong Li898a3242019-06-24 16:03:33 +010069 std::string modelOutputLabelsPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010070 std::string validationLabelPath;
SiCong Li898a3242019-06-24 16:03:33 +010071 std::string validationRange;
72 std::string blacklistPath;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010073
74 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
75 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
76
77 po::options_description desc("Options");
78 try
79 {
80 // Adds generic options needed to run Accuracy Tool.
81 desc.add_options()
Conor Kennedy30562022019-05-13 14:48:58 +010082 ("help,h", "Display help messages")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010083 ("model-path,m", po::value<std::string>(&modelPath)->required(), "Path to armnn format model file")
SiCong Li39f46392019-06-21 12:00:04 +010084 ("model-format,f", po::value<std::string>(&modelFormat)->required(),
85 "The model format. Supported values: caffe, tensorflow, tflite")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010086 ("input-name,i", po::value<std::string>(&inputName)->required(),
87 "Identifier of the input tensors in the network separated by comma.")
88 ("output-name,o", po::value<std::string>(&outputName)->required(),
89 "Identifier of the output tensors in the network separated by comma.")
SiCong Li39f46392019-06-21 12:00:04 +010090 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
91 "Path to directory containing the ImageNet test data")
SiCong Li898a3242019-06-24 16:03:33 +010092 ("model-output-labels,p", po::value<std::string>(&modelOutputLabelsPath)->required(),
93 "Path to model output labels file.")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010094 ("validation-labels-path,v", po::value<std::string>(&validationLabelPath)->required(),
SiCong Li39f46392019-06-21 12:00:04 +010095 "Path to ImageNet Validation Label file")
96 ("data-layout,l", po::value<std::string>(&inputLayout)->default_value("NHWC"),
SiCong Li23700bb2019-07-25 14:54:39 +010097 "Data layout. Supported value: NHWC, NCHW. Default: NHWC")
SiCong Li39f46392019-06-21 12:00:04 +010098 ("compute,c", po::value<std::vector<armnn::BackendId>>(&computeDevice)->default_value(defaultBackends),
SiCong Li898a3242019-06-24 16:03:33 +010099 backendsMessage.c_str())
100 ("validation-range,r", po::value<std::string>(&validationRange)->default_value("1:0"),
101 "The range of the images to be evaluated. Specified in the form <begin index>:<end index>."
102 "The index starts at 1 and the range is inclusive."
103 "By default the evaluation will be performed on all images.")
104 ("blacklist-path,b", po::value<std::string>(&blacklistPath)->default_value(""),
105 "Path to a blacklist file where each line denotes the index of an image to be "
106 "excluded from evaluation.");
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100107 }
108 catch (const std::exception& e)
109 {
110 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
111 // and that desc.add_options() can throw boost::io::too_few_args.
112 // They really won't in any of these cases.
113 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
114 std::cerr << "Fatal internal error: " << e.what() << std::endl;
115 return 1;
116 }
117
118 po::variables_map vm;
119 try
120 {
121 po::store(po::parse_command_line(argc, argv, desc), vm);
122
123 if (vm.count("help"))
124 {
125 std::cout << desc << std::endl;
126 return 1;
127 }
128 po::notify(vm);
129 }
130 catch (po::error& e)
131 {
132 std::cerr << e.what() << std::endl << std::endl;
133 std::cerr << desc << std::endl;
134 return 1;
135 }
136
137 // Check if the requested backend are all valid
138 std::string invalidBackends;
139 if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
140 {
141 BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
142 << invalidBackends;
143 return EXIT_FAILURE;
144 }
145 armnn::Status status;
146
147 // Create runtime
148 armnn::IRuntime::CreationOptions options;
149 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
150 std::ifstream file(modelPath);
151
152 // Create Parser
153 using IParser = armnnDeserializer::IDeserializer;
154 auto armnnparser(IParser::Create());
155
156 // Create a network
157 armnn::INetworkPtr network = armnnparser->CreateNetworkFromBinary(file);
158
159 // Optimizes the network.
160 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
161 try
162 {
163 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
164 }
165 catch (armnn::Exception& e)
166 {
167 std::stringstream message;
168 message << "armnn::Exception (" << e.what() << ") caught from optimize.";
169 BOOST_LOG_TRIVIAL(fatal) << message.str();
170 return 1;
171 }
172
173 // Loads the network into the runtime.
174 armnn::NetworkId networkId;
175 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
176 if (status == armnn::Status::Failure)
177 {
178 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to load network";
179 return 1;
180 }
181
182 // Set up Network
183 using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
184
185 const armnnDeserializer::BindingPointInfo&
186 inputBindingInfo = armnnparser->GetNetworkInputBindingInfo(0, inputName);
187
188 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
189 m_InputBindingInfo(inputBindingInfo.m_BindingId, inputBindingInfo.m_TensorInfo);
190 std::vector<BindingPointInfo> inputBindings = { m_InputBindingInfo };
191
192 const armnnDeserializer::BindingPointInfo&
193 outputBindingInfo = armnnparser->GetNetworkOutputBindingInfo(0, outputName);
194
195 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
196 m_OutputBindingInfo(outputBindingInfo.m_BindingId, outputBindingInfo.m_TensorInfo);
197 std::vector<BindingPointInfo> outputBindings = { m_OutputBindingInfo };
198
SiCong Li898a3242019-06-24 16:03:33 +0100199 // Load model output labels
200 if (modelOutputLabelsPath.empty() || !boost::filesystem::exists(modelOutputLabelsPath) ||
201 !boost::filesystem::is_regular_file(modelOutputLabelsPath))
202 {
203 BOOST_LOG_TRIVIAL(fatal) << "Invalid model output labels path at " << modelOutputLabelsPath;
204 }
205 const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
206 LoadModelOutputLabels(modelOutputLabelsPath);
207
208 // Parse begin and end image indices
209 std::vector<std::string> imageIndexStrs = armnnUtils::SplitBy(validationRange, ":");
210 size_t imageBegIndex;
211 size_t imageEndIndex;
212 if (imageIndexStrs.size() != 2)
213 {
214 BOOST_LOG_TRIVIAL(fatal) << "Invalid validation range specification: Invalid format " << validationRange;
215 return 1;
216 }
217 try
218 {
219 imageBegIndex = std::stoul(imageIndexStrs[0]);
220 imageEndIndex = std::stoul(imageIndexStrs[1]);
221 }
222 catch (const std::exception& e)
223 {
224 BOOST_LOG_TRIVIAL(fatal) << "Invalid validation range specification: " << validationRange;
225 return 1;
226 }
227
228 // Validate blacklist file if it's specified
229 if (!blacklistPath.empty() &&
230 !(boost::filesystem::exists(blacklistPath) && boost::filesystem::is_regular_file(blacklistPath)))
231 {
232 BOOST_LOG_TRIVIAL(fatal) << "Invalid path to blacklist file at " << blacklistPath;
233 return 1;
234 }
235
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100236 path pathToDataDir(dataDir);
SiCong Li898a3242019-06-24 16:03:33 +0100237 const map<std::string, std::string> imageNameToLabel = LoadValidationImageFilenamesAndLabels(
238 validationLabelPath, pathToDataDir.string(), imageBegIndex, imageEndIndex, blacklistPath);
239 armnnUtils::ModelAccuracyChecker checker(imageNameToLabel, modelOutputLabels);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100240 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<uint8_t>>;
241
SiCong Li39f46392019-06-21 12:00:04 +0100242 if (ValidateDirectory(dataDir))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100243 {
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100244 InferenceModel<armnnDeserializer::IDeserializer, float>::Params params;
SiCong Li39f46392019-06-21 12:00:04 +0100245 params.m_ModelPath = modelPath;
246 params.m_IsModelBinary = true;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100247 params.m_ComputeDevices = computeDevice;
248 params.m_InputBindings.push_back(inputName);
249 params.m_OutputBindings.push_back(outputName);
250
251 using TParser = armnnDeserializer::IDeserializer;
252 InferenceModel<TParser, float> model(params, false);
SiCong Li39f46392019-06-21 12:00:04 +0100253 // Get input tensor information
254 const armnn::TensorInfo& inputTensorInfo = model.GetInputBindingInfo().second;
255 const armnn::TensorShape& inputTensorShape = inputTensorInfo.GetShape();
256 const armnn::DataType& inputTensorDataType = inputTensorInfo.GetDataType();
257 armnn::DataLayout inputTensorDataLayout;
258 if (inputLayout == "NCHW")
259 {
260 inputTensorDataLayout = armnn::DataLayout::NCHW;
261 }
262 else if (inputLayout == "NHWC")
263 {
264 inputTensorDataLayout = armnn::DataLayout::NHWC;
265 }
266 else
267 {
268 BOOST_LOG_TRIVIAL(fatal) << "Invalid Data layout: " << inputLayout;
269 return 1;
270 }
271 const unsigned int inputTensorWidth =
272 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[3] : inputTensorShape[2];
273 const unsigned int inputTensorHeight =
274 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[2] : inputTensorShape[1];
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100275 // Get output tensor info
276 const unsigned int outputNumElements = model.GetOutputSize();
SiCong Li898a3242019-06-24 16:03:33 +0100277 // Check output tensor shape is valid
278 if (modelOutputLabels.size() != outputNumElements)
279 {
280 BOOST_LOG_TRIVIAL(fatal) << "Number of output elements: " << outputNumElements
281 << " , mismatches the number of output labels: " << modelOutputLabels.size();
282 return 1;
283 }
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100284
SiCong Li39f46392019-06-21 12:00:04 +0100285 const unsigned int batchSize = 1;
286 // Get normalisation parameters
287 SupportedFrontend modelFrontend;
288 if (modelFormat == "caffe")
289 {
290 modelFrontend = SupportedFrontend::Caffe;
291 }
292 else if (modelFormat == "tensorflow")
293 {
294 modelFrontend = SupportedFrontend::TensorFlow;
295 }
296 else if (modelFormat == "tflite")
297 {
298 modelFrontend = SupportedFrontend::TFLite;
299 }
300 else
301 {
302 BOOST_LOG_TRIVIAL(fatal) << "Unsupported frontend: " << modelFormat;
303 return 1;
304 }
305 const NormalizationParameters& normParams = GetNormalizationParameters(modelFrontend, inputTensorDataType);
SiCong Li898a3242019-06-24 16:03:33 +0100306 for (const auto& imageEntry : imageNameToLabel)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100307 {
SiCong Li898a3242019-06-24 16:03:33 +0100308 const std::string imageName = imageEntry.first;
309 std::cout << "Processing image: " << imageName << "\n";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100310
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100311 vector<TContainer> inputDataContainers;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100312 vector<TContainer> outputDataContainers;
313
SiCong Li898a3242019-06-24 16:03:33 +0100314 auto imagePath = pathToDataDir / boost::filesystem::path(imageName);
SiCong Li39f46392019-06-21 12:00:04 +0100315 switch (inputTensorDataType)
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100316 {
SiCong Li39f46392019-06-21 12:00:04 +0100317 case armnn::DataType::Signed32:
318 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100319 PrepareImageTensor<int>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100320 inputTensorWidth, inputTensorHeight,
321 normParams,
322 batchSize,
323 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100324 outputDataContainers = { vector<int>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100325 break;
326 case armnn::DataType::QuantisedAsymm8:
327 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100328 PrepareImageTensor<uint8_t>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100329 inputTensorWidth, inputTensorHeight,
330 normParams,
331 batchSize,
332 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100333 outputDataContainers = { vector<uint8_t>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100334 break;
335 case armnn::DataType::Float32:
336 default:
337 inputDataContainers.push_back(
SiCong Li898a3242019-06-24 16:03:33 +0100338 PrepareImageTensor<float>(imagePath.string(),
SiCong Li39f46392019-06-21 12:00:04 +0100339 inputTensorWidth, inputTensorHeight,
340 normParams,
341 batchSize,
342 inputTensorDataLayout));
SiCong Lic0ed7ba2019-06-21 16:02:40 +0100343 outputDataContainers = { vector<float>(outputNumElements) };
SiCong Li39f46392019-06-21 12:00:04 +0100344 break;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100345 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100346
347 status = runtime->EnqueueWorkload(networkId,
348 armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
349 armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
350
351 if (status == armnn::Status::Failure)
352 {
SiCong Li898a3242019-06-24 16:03:33 +0100353 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload for image: " << imageName;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100354 }
355
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100356 checker.AddImageResult<TContainer>(imageName, outputDataContainers);
357 }
358 }
359 else
360 {
361 return 1;
362 }
363
364 for(unsigned int i = 1; i <= 5; ++i)
365 {
366 std::cout << "Top " << i << " Accuracy: " << checker.GetAccuracy(i) << "%" << "\n";
367 }
368
369 BOOST_LOG_TRIVIAL(info) << "Accuracy Tool ran successfully!";
370 return 0;
371 }
372 catch (armnn::Exception const & e)
373 {
374 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
375 // exception of type std::length_error.
376 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
377 std::cerr << "Armnn Error: " << e.what() << std::endl;
378 return 1;
379 }
380 catch (const std::exception & e)
381 {
382 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
383 std::cerr << "WARNING: ModelAccuracyTool-Armnn: An error has occurred when running the "
384 "Accuracy Tool: " << e.what() << std::endl;
385 return 1;
386 }
387}
388
SiCong Li898a3242019-06-24 16:03:33 +0100389map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
390 const string& imageDirectoryPath,
391 size_t begIndex,
392 size_t endIndex,
393 const string& blacklistPath)
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100394{
SiCong Li898a3242019-06-24 16:03:33 +0100395 // Populate imageFilenames with names of all .JPEG, .PNG images
396 std::vector<std::string> imageFilenames;
397 for (const auto& imageEntry :
398 boost::make_iterator_range(boost::filesystem::directory_iterator(boost::filesystem::path(imageDirectoryPath))))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100399 {
SiCong Li898a3242019-06-24 16:03:33 +0100400 boost::filesystem::path imagePath = imageEntry.path();
401 std::string imageExtension = boost::to_upper_copy<std::string>(imagePath.extension().string());
402 if (boost::filesystem::is_regular_file(imagePath) && (imageExtension == ".JPEG" || imageExtension == ".PNG"))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100403 {
SiCong Li898a3242019-06-24 16:03:33 +0100404 imageFilenames.push_back(imagePath.filename().string());
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100405 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100406 }
SiCong Li898a3242019-06-24 16:03:33 +0100407 if (imageFilenames.empty())
408 {
409 throw armnn::Exception("No image file (JPEG, PNG) found at " + imageDirectoryPath);
410 }
411
412 // Sort the image filenames lexicographically
413 std::sort(imageFilenames.begin(), imageFilenames.end());
414
415 std::cout << imageFilenames.size() << " images found at " << imageDirectoryPath << std::endl;
416
417 // Get default end index
418 if (begIndex < 1 || endIndex > imageFilenames.size())
419 {
420 throw armnn::Exception("Invalid image index range");
421 }
422 endIndex = endIndex == 0 ? imageFilenames.size() : endIndex;
423 if (begIndex > endIndex)
424 {
425 throw armnn::Exception("Invalid image index range");
426 }
427
428 // Load blacklist if there is one
429 std::vector<unsigned int> blacklist;
430 if (!blacklistPath.empty())
431 {
432 std::ifstream blacklistFile(blacklistPath);
433 unsigned int index;
434 while (blacklistFile >> index)
435 {
436 blacklist.push_back(index);
437 }
438 }
439
440 // Load ground truth labels and pair them with corresponding image names
441 std::string classification;
442 map<std::string, std::string> imageNameToLabel;
443 ifstream infile(validationLabelPath);
444 size_t imageIndex = begIndex;
445 size_t blacklistIndexCount = 0;
446 while (std::getline(infile, classification))
447 {
448 if (imageIndex > endIndex)
449 {
450 break;
451 }
452 // If current imageIndex is included in blacklist, skip the current image
453 if (blacklistIndexCount < blacklist.size() && imageIndex == blacklist[blacklistIndexCount])
454 {
455 ++imageIndex;
456 ++blacklistIndexCount;
457 continue;
458 }
459 imageNameToLabel.insert(std::pair<std::string, std::string>(imageFilenames[imageIndex - 1], classification));
460 ++imageIndex;
461 }
462 std::cout << blacklistIndexCount << " images blacklisted" << std::endl;
463 std::cout << imageIndex - begIndex - blacklistIndexCount << " images to be loaded" << std::endl;
464 return imageNameToLabel;
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100465}
SiCong Li898a3242019-06-24 16:03:33 +0100466
467std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath)
468{
469 std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels;
470 ifstream modelOutputLablesFile(modelOutputLabelsPath);
471 std::string line;
472 while (std::getline(modelOutputLablesFile, line))
473 {
474 armnnUtils::LabelCategoryNames tokens = armnnUtils::SplitBy(line, ":");
475 armnnUtils::LabelCategoryNames predictionCategoryNames = armnnUtils::SplitBy(tokens.back(), ",");
476 std::transform(predictionCategoryNames.begin(), predictionCategoryNames.end(), predictionCategoryNames.begin(),
477 [](const std::string& category) { return armnnUtils::Strip(category); });
478 modelOutputLabels.push_back(predictionCategoryNames);
479 }
480 return modelOutputLabels;
481}