blob: e554111a46e060e7fd29772f039134d7b38e447c [file] [log] [blame]
//
// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "../InferenceTest.hpp"
#include "../ImagePreprocessor.hpp"
#include "armnnTfLiteParser/ITfLiteParser.hpp"

#include <cxxopts/cxxopts.hpp>

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace armnnTfLiteParser;
13
Pablo Tello507f39d2019-04-15 15:44:39 +010014std::vector<ImageSet> ParseDataset(const std::string& filename)
15{
16 std::ifstream read(filename);
17 std::vector<ImageSet> imageSet;
18 if (read.is_open())
19 {
20 // Get the images and the correct corresponding label from the given file
21 for (std::string line; std::getline(read, line);)
22 {
23 stringstream ss(line);
24 std::string image_name;
25 std::string label;
26 getline(ss, image_name, ' ');
27 getline(ss, label, ' ');
28 imageSet.push_back(ImageSet(image_name, std::stoi(label)));
29 }
30 }
31 else
32 {
33 // Use the default images
34 imageSet.push_back(ImageSet("Dog.jpg", 209));
35 // top five predictions in tensorflow:
36 // -----------------------------------
37 // 209:Labrador retriever 0.949995
38 // 160:Rhodesian ridgeback 0.0270182
39 // 208:golden retriever 0.0192866
40 // 853:tennis ball 0.000470382
41 // 239:Greater Swiss Mountain dog 0.000464451
42 imageSet.push_back(ImageSet("Cat.jpg", 283));
43 // top five predictions in tensorflow:
44 // -----------------------------------
45 // 283:tiger cat 0.579016
46 // 286:Egyptian cat 0.319676
47 // 282:tabby, tabby cat 0.0873346
48 // 288:lynx, catamount 0.011163
49 // 289:leopard, Panthera pardus 0.000856755
50 imageSet.push_back(ImageSet("shark.jpg", 3));
51 // top five predictions in tensorflow:
52 // -----------------------------------
53 // 3:great white shark, white shark, ... 0.996926
54 // 4:tiger shark, Galeocerdo cuvieri 0.00270528
55 // 149:killer whale, killer, orca, ... 0.000121848
56 // 395:sturgeon 7.78977e-05
57 // 5:hammerhead, hammerhead shark 6.44127e-055
58 };
59 return imageSet;
60}
61
62std::string GetLabelsFilenameFromOptions(int argc, char* argv[])
63{
Matthew Sloyan0029cd62020-09-28 12:58:14 +010064 cxxopts::Options options("TfLiteMobilenetQuantized-Armnn","Validation Options");
65
66 std::string fileName;
67 try
Pablo Tello507f39d2019-04-15 15:44:39 +010068 {
Matthew Sloyan0029cd62020-09-28 12:58:14 +010069 options
70 .allow_unrecognised_options()
71 .add_options()
72 ("l,labels",
73 "Filename of a text file where in each line contains an image "
74 "filename and the correct label the network should predict when fed that image",
75 cxxopts::value<std::string>(fileName));
76
77 auto result = options.parse(argc, argv);
Pablo Tello507f39d2019-04-15 15:44:39 +010078 }
Jim Flynn357add22023-04-10 23:26:40 +010079 catch (const cxxopts::exceptions::exception& e)
Matthew Sloyan0029cd62020-09-28 12:58:14 +010080 {
81 std::cerr << e.what() << std::endl;
82 exit(EXIT_FAILURE);
83 }
84 catch (const std::exception& e)
85 {
86 std::cerr << "Fatal internal error: [" << e.what() << "]" << std::endl;
87 exit(EXIT_FAILURE);
88 }
89
90 return fileName;
Pablo Tello507f39d2019-04-15 15:44:39 +010091}
92
93
telsoa01c577f2c2018-08-31 09:22:23 +010094int main(int argc, char* argv[])
95{
96 int retVal = EXIT_FAILURE;
97 try
98 {
99 // Coverity fix: The following code may throw an exception of type std::length_error.
Pablo Tello507f39d2019-04-15 15:44:39 +0100100 const std::string labels_file = GetLabelsFilenameFromOptions(argc,argv);
101 std::vector<ImageSet> imageSet = ParseDataset(labels_file);
102 std::vector<unsigned int> indices(imageSet.size());
103 std::generate(indices.begin(), indices.end(), [n = 0] () mutable { return n++; });
telsoa01c577f2c2018-08-31 09:22:23 +0100104
105 armnn::TensorShape inputTensorShape({ 1, 224, 224, 3 });
106
107 using DataType = uint8_t;
108 using DatabaseType = ImagePreprocessor<DataType>;
109 using ParserType = armnnTfLiteParser::ITfLiteParser;
110 using ModelType = InferenceModel<ParserType, DataType>;
111
112 // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions.
113 retVal = armnn::test::ClassifierInferenceTestMain<DatabaseType,
114 ParserType>(
115 argc, argv,
116 "mobilenet_v1_1.0_224_quant.tflite", // model name
117 true, // model is binary
118 "input", // input tensor name
119 "MobilenetV1/Predictions/Reshape_1", // output tensor name
Pablo Tello507f39d2019-04-15 15:44:39 +0100120 indices, // vector of indices to select which images to validate
Derek Lambertieb1fce02019-12-10 21:20:10 +0000121 [&imageSet](const char* dataDir, const ModelType &) {
telsoa01c577f2c2018-08-31 09:22:23 +0100122 // we need to get the input quantization parameters from
123 // the parsed model
telsoa01c577f2c2018-08-31 09:22:23 +0100124 return DatabaseType(
125 dataDir,
126 224,
127 224,
128 imageSet,
FinnWilliamsArmaf8b72d2019-05-22 14:50:55 +0100129 1);
telsoa01c577f2c2018-08-31 09:22:23 +0100130 },
131 &inputTensorShape);
132 }
133 catch (const std::exception& e)
134 {
135 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
136 // exception of type std::length_error.
137 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
138 std::cerr << "WARNING: " << *argv << ": An error has occurred when running "
139 "the classifier inference tests: " << e.what() << std::endl;
140 }
141 return retVal;
142}