blob: cf97459ddc99a58c2e7b354074c19e46cdc3d51b [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "InferenceTest.hpp"
6
telsoa01c577f2c2018-08-31 09:22:23 +01007#include "../src/armnn/Profiling.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008#include <boost/algorithm/string.hpp>
9#include <boost/numeric/conversion/cast.hpp>
10#include <boost/log/trivial.hpp>
11#include <boost/filesystem/path.hpp>
12#include <boost/assert.hpp>
13#include <boost/format.hpp>
14#include <boost/program_options.hpp>
15#include <boost/filesystem/operations.hpp>
16
17#include <fstream>
18#include <iostream>
19#include <iomanip>
20#include <array>
21
22using namespace std;
23using namespace std::chrono;
24using namespace armnn::test;
25
26namespace armnn
27{
28namespace test
29{
telsoa014fcda012018-03-09 14:13:49 +000030/// Parse the command line of an ArmNN (or referencetests) inference test program.
31/// \return false if any error occurred during options processing, otherwise true
32bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
33 InferenceTestOptions& outParams)
34{
35 namespace po = boost::program_options;
36
telsoa014fcda012018-03-09 14:13:49 +000037 po::options_description desc("Options");
38
39 try
40 {
telsoa01c577f2c2018-08-31 09:22:23 +010041 // Adds generic options needed for all inference tests.
telsoa014fcda012018-03-09 14:13:49 +000042 desc.add_options()
43 ("help", "Display help messages")
44 ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
45 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
46 ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
telsoa01c577f2c2018-08-31 09:22:23 +010047 "If non-empty, each individual inference time will be recorded and output to this file")
48 ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
49 "Enables built in profiler. If unset, defaults to off.");
telsoa014fcda012018-03-09 14:13:49 +000050
telsoa01c577f2c2018-08-31 09:22:23 +010051 // Adds options specific to the ITestCaseProvider.
telsoa014fcda012018-03-09 14:13:49 +000052 testCaseProvider.AddCommandLineOptions(desc);
53 }
54 catch (const std::exception& e)
55 {
56 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
57 // and that desc.add_options() can throw boost::io::too_few_args.
58 // They really won't in any of these cases.
59 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
60 std::cerr << "Fatal internal error: " << e.what() << std::endl;
61 return false;
62 }
63
64 po::variables_map vm;
65
66 try
67 {
68 po::store(po::parse_command_line(argc, argv, desc), vm);
69
70 if (vm.count("help"))
71 {
72 std::cout << desc << std::endl;
73 return false;
74 }
75
76 po::notify(vm);
77 }
78 catch (po::error& e)
79 {
80 std::cerr << e.what() << std::endl << std::endl;
81 std::cerr << desc << std::endl;
82 return false;
83 }
84
Matthew Bentham3e68b972019-04-09 13:10:46 +010085 if (!testCaseProvider.ProcessCommandLineOptions(outParams))
telsoa014fcda012018-03-09 14:13:49 +000086 {
87 return false;
88 }
89
90 return true;
91}
92
93bool ValidateDirectory(std::string& dir)
94{
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010095 if (dir.empty())
96 {
97 std::cerr << "No directory specified" << std::endl;
98 return false;
99 }
100
telsoa014fcda012018-03-09 14:13:49 +0000101 if (dir[dir.length() - 1] != '/')
102 {
103 dir += "/";
104 }
105
106 if (!boost::filesystem::exists(dir))
107 {
108 std::cerr << "Given directory " << dir << " does not exist" << std::endl;
109 return false;
110 }
111
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100112 if (!boost::filesystem::is_directory(dir))
113 {
114 std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
115 return false;
116 }
117
telsoa014fcda012018-03-09 14:13:49 +0000118 return true;
119}
120
121bool InferenceTest(const InferenceTestOptions& params,
122 const std::vector<unsigned int>& defaultTestCaseIds,
123 IInferenceTestCaseProvider& testCaseProvider)
124{
125#if !defined (NDEBUG)
telsoa01c577f2c2018-08-31 09:22:23 +0100126 if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
telsoa014fcda012018-03-09 14:13:49 +0000127 {
128 BOOST_LOG_TRIVIAL(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
129 }
130#endif
131
132 double totalTime = 0;
133 unsigned int nbProcessed = 0;
134 bool success = true;
135
telsoa01c577f2c2018-08-31 09:22:23 +0100136 // Opens the file to write inference times too, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000137 ofstream inferenceTimesFile;
138 const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
139 if (recordInferenceTimes)
140 {
141 inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
142 if (!inferenceTimesFile.good())
143 {
144 BOOST_LOG_TRIVIAL(error) << "Failed to open inference times file for writing: "
145 << params.m_InferenceTimesFile;
146 return false;
147 }
148 }
149
telsoa01c577f2c2018-08-31 09:22:23 +0100150 // Create a profiler and register it for the current thread.
151 std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
152 ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
153
154 // Enable profiling if requested.
155 profiler->EnableProfiling(params.m_EnableProfiling);
156
telsoa014fcda012018-03-09 14:13:49 +0000157 // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
158 std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
159 if (warmupTestCase == nullptr)
160 {
161 BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
162 return false;
163 }
164
165 try
166 {
167 warmupTestCase->Run();
168 }
169 catch (const TestFrameworkException& testError)
170 {
171 BOOST_LOG_TRIVIAL(error) << testError.what();
172 return false;
173 }
174
175 const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
surmeh013537c2c2018-05-18 16:31:43 +0100176 : static_cast<unsigned int>(defaultTestCaseIds.size());
telsoa014fcda012018-03-09 14:13:49 +0000177
178 for (; nbProcessed < nbTotalToProcess; nbProcessed++)
179 {
180 const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
181 std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
182
183 if (testCase == nullptr)
184 {
185 BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
186 return false;
187 }
188
189 time_point<high_resolution_clock> predictStart;
190 time_point<high_resolution_clock> predictEnd;
191
192 TestCaseResult result = TestCaseResult::Ok;
193
194 try
195 {
196 predictStart = high_resolution_clock::now();
197
198 testCase->Run();
199
200 predictEnd = high_resolution_clock::now();
201
202 // duration<double> will convert the time difference into seconds as a double by default.
203 double timeTakenS = duration<double>(predictEnd - predictStart).count();
204 totalTime += timeTakenS;
205
telsoa01c577f2c2018-08-31 09:22:23 +0100206 // Outputss inference times, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000207 if (recordInferenceTimes)
208 {
209 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
210 }
211
212 result = testCase->ProcessResult(params);
213
214 }
215 catch (const TestFrameworkException& testError)
216 {
217 BOOST_LOG_TRIVIAL(error) << testError.what();
218 result = TestCaseResult::Abort;
219 }
220
221 switch (result)
222 {
223 case TestCaseResult::Ok:
224 break;
225 case TestCaseResult::Abort:
226 return false;
227 case TestCaseResult::Failed:
228 // This test failed so we will fail the entire program eventually, but keep going for now.
229 success = false;
230 break;
231 default:
232 BOOST_ASSERT_MSG(false, "Unexpected TestCaseResult");
233 return false;
234 }
235 }
236
237 const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
238
239 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
240 "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
241 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
242 "Average time per test case: " << averageTimePerTestCaseMs << " ms";
243
Sadik Armagan2b7a1582018-09-05 16:33:58 +0100244 // if profiling is enabled print out the results
245 if (profiler && profiler->IsProfilingEnabled())
246 {
247 profiler->Print(std::cout);
248 }
249
telsoa014fcda012018-03-09 14:13:49 +0000250 if (!success)
251 {
252 BOOST_LOG_TRIVIAL(error) << "One or more test cases failed";
253 return false;
254 }
255
256 return testCaseProvider.OnInferenceTestFinished();
257}
258
259} // namespace test
260
261} // namespace armnn