blob: aec4d702714bb6a4ba962993fda0599daff429f3 [file] [log] [blame]
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ModelAccuracyChecker.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +01007#include "../ImagePreprocessor.hpp"
8#include "armnnDeserializer/IDeserializer.hpp"
Francis Murtaghbee4bc92019-06-18 12:30:37 +01009#include "../NetworkExecutionUtils/NetworkExecutionUtils.hpp"
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010010
11#include <boost/filesystem.hpp>
12#include <boost/range/iterator_range.hpp>
13#include <boost/program_options/variables_map.hpp>
14
15using namespace armnn::test;
16
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010017map<std::string, int> LoadValidationLabels(const string & validationLabelPath);
18
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010019int main(int argc, char* argv[])
20{
21 try
22 {
23 using namespace boost::filesystem;
24 armnn::LogSeverity level = armnn::LogSeverity::Debug;
25 armnn::ConfigureLogging(true, true, level);
26 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
27
28 // Set-up program Options
29 namespace po = boost::program_options;
30
31 std::vector<armnn::BackendId> computeDevice;
32 std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
33 std::string modelPath;
34 std::string dataDir;
Francis Murtaghbee4bc92019-06-18 12:30:37 +010035 std::string inputType = "float";
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010036 std::string inputName;
37 std::string outputName;
38 std::string validationLabelPath;
39
40 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
41 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
42
43 po::options_description desc("Options");
44 try
45 {
46 // Adds generic options needed to run Accuracy Tool.
47 desc.add_options()
Conor Kennedy30562022019-05-13 14:48:58 +010048 ("help,h", "Display help messages")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010049 ("model-path,m", po::value<std::string>(&modelPath)->required(), "Path to armnn format model file")
50 ("compute,c", po::value<std::vector<armnn::BackendId>>(&computeDevice)->default_value(defaultBackends),
51 backendsMessage.c_str())
52 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
53 "Path to directory containing the ImageNet test data")
Francis Murtaghbee4bc92019-06-18 12:30:37 +010054 ("input-type,y", po::value(&inputType), "The data type of the input tensors."
55 "If unset, defaults to \"float\" for all defined inputs. "
56 "Accepted values (float, int or qasymm8)")
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010057 ("input-name,i", po::value<std::string>(&inputName)->required(),
58 "Identifier of the input tensors in the network separated by comma.")
59 ("output-name,o", po::value<std::string>(&outputName)->required(),
60 "Identifier of the output tensors in the network separated by comma.")
61 ("validation-labels-path,v", po::value<std::string>(&validationLabelPath)->required(),
62 "Path to ImageNet Validation Label file");
63 }
64 catch (const std::exception& e)
65 {
66 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
67 // and that desc.add_options() can throw boost::io::too_few_args.
68 // They really won't in any of these cases.
69 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
70 std::cerr << "Fatal internal error: " << e.what() << std::endl;
71 return 1;
72 }
73
74 po::variables_map vm;
75 try
76 {
77 po::store(po::parse_command_line(argc, argv, desc), vm);
78
79 if (vm.count("help"))
80 {
81 std::cout << desc << std::endl;
82 return 1;
83 }
84 po::notify(vm);
85 }
86 catch (po::error& e)
87 {
88 std::cerr << e.what() << std::endl << std::endl;
89 std::cerr << desc << std::endl;
90 return 1;
91 }
92
93 // Check if the requested backend are all valid
94 std::string invalidBackends;
95 if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
96 {
97 BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
98 << invalidBackends;
99 return EXIT_FAILURE;
100 }
101 armnn::Status status;
102
103 // Create runtime
104 armnn::IRuntime::CreationOptions options;
105 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
106 std::ifstream file(modelPath);
107
108 // Create Parser
109 using IParser = armnnDeserializer::IDeserializer;
110 auto armnnparser(IParser::Create());
111
112 // Create a network
113 armnn::INetworkPtr network = armnnparser->CreateNetworkFromBinary(file);
114
115 // Optimizes the network.
116 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
117 try
118 {
119 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
120 }
121 catch (armnn::Exception& e)
122 {
123 std::stringstream message;
124 message << "armnn::Exception (" << e.what() << ") caught from optimize.";
125 BOOST_LOG_TRIVIAL(fatal) << message.str();
126 return 1;
127 }
128
129 // Loads the network into the runtime.
130 armnn::NetworkId networkId;
131 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
132 if (status == armnn::Status::Failure)
133 {
134 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to load network";
135 return 1;
136 }
137
138 // Set up Network
139 using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
140
141 const armnnDeserializer::BindingPointInfo&
142 inputBindingInfo = armnnparser->GetNetworkInputBindingInfo(0, inputName);
143
144 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
145 m_InputBindingInfo(inputBindingInfo.m_BindingId, inputBindingInfo.m_TensorInfo);
146 std::vector<BindingPointInfo> inputBindings = { m_InputBindingInfo };
147
148 const armnnDeserializer::BindingPointInfo&
149 outputBindingInfo = armnnparser->GetNetworkOutputBindingInfo(0, outputName);
150
151 std::pair<armnn::LayerBindingId, armnn::TensorInfo>
152 m_OutputBindingInfo(outputBindingInfo.m_BindingId, outputBindingInfo.m_TensorInfo);
153 std::vector<BindingPointInfo> outputBindings = { m_OutputBindingInfo };
154
155 path pathToDataDir(dataDir);
156 map<string, int> validationLabels = LoadValidationLabels(validationLabelPath);
157 armnnUtils::ModelAccuracyChecker checker(validationLabels);
158 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<uint8_t>>;
159
160 if(ValidateDirectory(dataDir))
161 {
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100162 InferenceModel<armnnDeserializer::IDeserializer, float>::Params params;
163 params.m_ModelPath = modelPath;
164 params.m_IsModelBinary = true;
165 params.m_ComputeDevices = computeDevice;
166 params.m_InputBindings.push_back(inputName);
167 params.m_OutputBindings.push_back(outputName);
168
169 using TParser = armnnDeserializer::IDeserializer;
170 InferenceModel<TParser, float> model(params, false);
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100171 for (auto & imageEntry : boost::make_iterator_range(directory_iterator(pathToDataDir), {}))
172 {
173 cout << "Processing image: " << imageEntry << "\n";
174
175 std::ifstream inputTensorFile(imageEntry.path().string());
176 vector<TContainer> inputDataContainers;
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100177 vector<TContainer> outputDataContainers;
178
179 if (inputType.compare("float") == 0)
180 {
181 inputDataContainers.push_back(
182 ParseDataArray<armnn::DataType::Float32>(inputTensorFile));
183 outputDataContainers = {vector<float>(1001)};
184 }
185 else if (inputType.compare("int") == 0)
186 {
187 inputDataContainers.push_back(
188 ParseDataArray<armnn::DataType::Signed32>(inputTensorFile));
189 outputDataContainers = {vector<int>(1001)};
190 }
191 else if (inputType.compare("qasymm8") == 0)
192 {
193 auto inputBinding = model.GetInputBindingInfo();
194 inputDataContainers.push_back(
195 ParseDataArray<armnn::DataType::QuantisedAsymm8>(
196 inputTensorFile,
197 inputBinding.second.GetQuantizationScale(),
198 inputBinding.second.GetQuantizationOffset()));
199 outputDataContainers = {vector<uint8_t >(1001)};
200 }
201 else
202 {
203 BOOST_LOG_TRIVIAL(fatal) << "Unsupported tensor data type \"" << inputType << "\". ";
204 return EXIT_FAILURE;
205 }
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100206
207 status = runtime->EnqueueWorkload(networkId,
208 armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
209 armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
210
211 if (status == armnn::Status::Failure)
212 {
213 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload for image: " << imageEntry;
214 }
215
216 const std::string imageName = imageEntry.path().filename().string();
217 checker.AddImageResult<TContainer>(imageName, outputDataContainers);
218 }
219 }
220 else
221 {
222 return 1;
223 }
224
225 for(unsigned int i = 1; i <= 5; ++i)
226 {
227 std::cout << "Top " << i << " Accuracy: " << checker.GetAccuracy(i) << "%" << "\n";
228 }
229
230 BOOST_LOG_TRIVIAL(info) << "Accuracy Tool ran successfully!";
231 return 0;
232 }
233 catch (armnn::Exception const & e)
234 {
235 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
236 // exception of type std::length_error.
237 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
238 std::cerr << "Armnn Error: " << e.what() << std::endl;
239 return 1;
240 }
241 catch (const std::exception & e)
242 {
243 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
244 std::cerr << "WARNING: ModelAccuracyTool-Armnn: An error has occurred when running the "
245 "Accuracy Tool: " << e.what() << std::endl;
246 return 1;
247 }
248}
249
// Reads the validation label file at validationLabelPath.
//
// Each record is an image file name followed by an integer class label, separated by whitespace
// (illustrative example: "image_0001.JPEG 65"). The returned map is keyed by the file name with its
// final extension removed ("image_0001"). Names that contain no '.' are used unchanged. If a name
// appears more than once, the first label read is kept. Reading stops at the first record that fails
// to parse, and a file that cannot be opened yields an empty map.
std::map<std::string, int> LoadValidationLabels(const std::string& validationLabelPath)
{
    std::map<std::string, int> validationLabels;
    std::ifstream labelFile(validationLabelPath);

    std::string imageName;
    int classification = 0;
    while (labelFile >> imageName >> classification)
    {
        // Strip only the last extension; without a '.', keep the whole name rather than an empty key.
        const std::string::size_type extensionPos = imageName.find_last_of('.');
        std::string stem = (extensionPos == std::string::npos) ? imageName
                                                              : imageName.substr(0, extensionPos);
        // emplace does not overwrite an existing key, so the first label for a name wins.
        validationLabels.emplace(std::move(stem), classification);
    }
    return validationLabels;
}