blob: 1b411f9b94f06959c4ea3b295eb58366dfca09a5 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "../InferenceTest.hpp"
#include "../ImagePreprocessor.hpp"
#include "armnnTfLiteParser/ITfLiteParser.hpp"

#include "boost/program_options.hpp"

#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace armnnTfLiteParser;
13
Pablo Tello507f39d2019-04-15 15:44:39 +010014std::vector<ImageSet> ParseDataset(const std::string& filename)
15{
16 std::ifstream read(filename);
17 std::vector<ImageSet> imageSet;
18 if (read.is_open())
19 {
20 // Get the images and the correct corresponding label from the given file
21 for (std::string line; std::getline(read, line);)
22 {
23 stringstream ss(line);
24 std::string image_name;
25 std::string label;
26 getline(ss, image_name, ' ');
27 getline(ss, label, ' ');
28 imageSet.push_back(ImageSet(image_name, std::stoi(label)));
29 }
30 }
31 else
32 {
33 // Use the default images
34 imageSet.push_back(ImageSet("Dog.jpg", 209));
35 // top five predictions in tensorflow:
36 // -----------------------------------
37 // 209:Labrador retriever 0.949995
38 // 160:Rhodesian ridgeback 0.0270182
39 // 208:golden retriever 0.0192866
40 // 853:tennis ball 0.000470382
41 // 239:Greater Swiss Mountain dog 0.000464451
42 imageSet.push_back(ImageSet("Cat.jpg", 283));
43 // top five predictions in tensorflow:
44 // -----------------------------------
45 // 283:tiger cat 0.579016
46 // 286:Egyptian cat 0.319676
47 // 282:tabby, tabby cat 0.0873346
48 // 288:lynx, catamount 0.011163
49 // 289:leopard, Panthera pardus 0.000856755
50 imageSet.push_back(ImageSet("shark.jpg", 3));
51 // top five predictions in tensorflow:
52 // -----------------------------------
53 // 3:great white shark, white shark, ... 0.996926
54 // 4:tiger shark, Galeocerdo cuvieri 0.00270528
55 // 149:killer whale, killer, orca, ... 0.000121848
56 // 395:sturgeon 7.78977e-05
57 // 5:hammerhead, hammerhead shark 6.44127e-055
58 };
59 return imageSet;
60}
61
62std::string GetLabelsFilenameFromOptions(int argc, char* argv[])
63{
64 namespace po = boost::program_options;
65 po::options_description desc("Validation Options");
66 std::string fn("");
67 desc.add_options()
68 ("labels", po::value<std::string>(&fn), "Filename of a text file where in each line contains an image "
69 "filename and the correct label the network should predict when fed that image");
70 po::variables_map vm;
71 po::parsed_options parsed = po::command_line_parser(argc, argv).options(desc).allow_unregistered().run();
72 po::store(parsed, vm);
73 if (vm.count("labels"))
74 {
75 fn = vm["labels"].as<std::string>();
76 }
77 return fn;
78}
79
80
telsoa01c577f2c2018-08-31 09:22:23 +010081int main(int argc, char* argv[])
82{
83 int retVal = EXIT_FAILURE;
84 try
85 {
86 // Coverity fix: The following code may throw an exception of type std::length_error.
Pablo Tello507f39d2019-04-15 15:44:39 +010087 const std::string labels_file = GetLabelsFilenameFromOptions(argc,argv);
88 std::vector<ImageSet> imageSet = ParseDataset(labels_file);
89 std::vector<unsigned int> indices(imageSet.size());
90 std::generate(indices.begin(), indices.end(), [n = 0] () mutable { return n++; });
telsoa01c577f2c2018-08-31 09:22:23 +010091
92 armnn::TensorShape inputTensorShape({ 1, 224, 224, 3 });
93
94 using DataType = uint8_t;
95 using DatabaseType = ImagePreprocessor<DataType>;
96 using ParserType = armnnTfLiteParser::ITfLiteParser;
97 using ModelType = InferenceModel<ParserType, DataType>;
98
99 // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions.
100 retVal = armnn::test::ClassifierInferenceTestMain<DatabaseType,
101 ParserType>(
102 argc, argv,
103 "mobilenet_v1_1.0_224_quant.tflite", // model name
104 true, // model is binary
105 "input", // input tensor name
106 "MobilenetV1/Predictions/Reshape_1", // output tensor name
Pablo Tello507f39d2019-04-15 15:44:39 +0100107 indices, // vector of indices to select which images to validate
telsoa01c577f2c2018-08-31 09:22:23 +0100108 [&imageSet](const char* dataDir, const ModelType & model) {
109 // we need to get the input quantization parameters from
110 // the parsed model
telsoa01c577f2c2018-08-31 09:22:23 +0100111 return DatabaseType(
112 dataDir,
113 224,
114 224,
115 imageSet,
FinnWilliamsArma723ec52019-05-22 14:50:55 +0100116 1);
telsoa01c577f2c2018-08-31 09:22:23 +0100117 },
118 &inputTensorShape);
119 }
120 catch (const std::exception& e)
121 {
122 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
123 // exception of type std::length_error.
124 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
125 std::cerr << "WARNING: " << *argv << ": An error has occurred when running "
126 "the classifier inference tests: " << e.what() << std::endl;
127 }
128 return retVal;
129}