blob: c6e5011ae435d4d4e0818d7096012c447ecaa4c1 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "InferenceTest.hpp"
6
telsoa01c577f2c2018-08-31 09:22:23 +01007#include "../src/armnn/Profiling.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008#include <boost/algorithm/string.hpp>
9#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010#include <boost/filesystem/path.hpp>
11#include <boost/assert.hpp>
12#include <boost/format.hpp>
13#include <boost/program_options.hpp>
14#include <boost/filesystem/operations.hpp>
15
16#include <fstream>
17#include <iostream>
18#include <iomanip>
19#include <array>
20
21using namespace std;
22using namespace std::chrono;
23using namespace armnn::test;
24
25namespace armnn
26{
27namespace test
28{
telsoa014fcda012018-03-09 14:13:49 +000029/// Parse the command line of an ArmNN (or referencetests) inference test program.
30/// \return false if any error occurred during options processing, otherwise true
31bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
32 InferenceTestOptions& outParams)
33{
34 namespace po = boost::program_options;
35
telsoa014fcda012018-03-09 14:13:49 +000036 po::options_description desc("Options");
37
38 try
39 {
telsoa01c577f2c2018-08-31 09:22:23 +010040 // Adds generic options needed for all inference tests.
telsoa014fcda012018-03-09 14:13:49 +000041 desc.add_options()
42 ("help", "Display help messages")
43 ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
44 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
45 ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
telsoa01c577f2c2018-08-31 09:22:23 +010046 "If non-empty, each individual inference time will be recorded and output to this file")
47 ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
48 "Enables built in profiler. If unset, defaults to off.");
telsoa014fcda012018-03-09 14:13:49 +000049
telsoa01c577f2c2018-08-31 09:22:23 +010050 // Adds options specific to the ITestCaseProvider.
telsoa014fcda012018-03-09 14:13:49 +000051 testCaseProvider.AddCommandLineOptions(desc);
52 }
53 catch (const std::exception& e)
54 {
55 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
56 // and that desc.add_options() can throw boost::io::too_few_args.
57 // They really won't in any of these cases.
58 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
59 std::cerr << "Fatal internal error: " << e.what() << std::endl;
60 return false;
61 }
62
63 po::variables_map vm;
64
65 try
66 {
67 po::store(po::parse_command_line(argc, argv, desc), vm);
68
69 if (vm.count("help"))
70 {
71 std::cout << desc << std::endl;
72 return false;
73 }
74
75 po::notify(vm);
76 }
77 catch (po::error& e)
78 {
79 std::cerr << e.what() << std::endl << std::endl;
80 std::cerr << desc << std::endl;
81 return false;
82 }
83
Matthew Bentham3e68b972019-04-09 13:10:46 +010084 if (!testCaseProvider.ProcessCommandLineOptions(outParams))
telsoa014fcda012018-03-09 14:13:49 +000085 {
86 return false;
87 }
88
89 return true;
90}
91
92bool ValidateDirectory(std::string& dir)
93{
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010094 if (dir.empty())
95 {
96 std::cerr << "No directory specified" << std::endl;
97 return false;
98 }
99
telsoa014fcda012018-03-09 14:13:49 +0000100 if (dir[dir.length() - 1] != '/')
101 {
102 dir += "/";
103 }
104
105 if (!boost::filesystem::exists(dir))
106 {
107 std::cerr << "Given directory " << dir << " does not exist" << std::endl;
108 return false;
109 }
110
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100111 if (!boost::filesystem::is_directory(dir))
112 {
113 std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
114 return false;
115 }
116
telsoa014fcda012018-03-09 14:13:49 +0000117 return true;
118}
119
120bool InferenceTest(const InferenceTestOptions& params,
121 const std::vector<unsigned int>& defaultTestCaseIds,
122 IInferenceTestCaseProvider& testCaseProvider)
123{
124#if !defined (NDEBUG)
telsoa01c577f2c2018-08-31 09:22:23 +0100125 if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
telsoa014fcda012018-03-09 14:13:49 +0000126 {
Derek Lamberti08446972019-11-26 16:38:31 +0000127 ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
telsoa014fcda012018-03-09 14:13:49 +0000128 }
129#endif
130
131 double totalTime = 0;
132 unsigned int nbProcessed = 0;
133 bool success = true;
134
telsoa01c577f2c2018-08-31 09:22:23 +0100135 // Opens the file to write inference times too, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000136 ofstream inferenceTimesFile;
137 const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
138 if (recordInferenceTimes)
139 {
140 inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
141 if (!inferenceTimesFile.good())
142 {
Derek Lamberti08446972019-11-26 16:38:31 +0000143 ARMNN_LOG(error) << "Failed to open inference times file for writing: "
telsoa014fcda012018-03-09 14:13:49 +0000144 << params.m_InferenceTimesFile;
145 return false;
146 }
147 }
148
telsoa01c577f2c2018-08-31 09:22:23 +0100149 // Create a profiler and register it for the current thread.
150 std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
151 ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
152
153 // Enable profiling if requested.
154 profiler->EnableProfiling(params.m_EnableProfiling);
155
telsoa014fcda012018-03-09 14:13:49 +0000156 // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
157 std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
158 if (warmupTestCase == nullptr)
159 {
Derek Lamberti08446972019-11-26 16:38:31 +0000160 ARMNN_LOG(error) << "Failed to load test case";
telsoa014fcda012018-03-09 14:13:49 +0000161 return false;
162 }
163
164 try
165 {
166 warmupTestCase->Run();
167 }
168 catch (const TestFrameworkException& testError)
169 {
Derek Lamberti08446972019-11-26 16:38:31 +0000170 ARMNN_LOG(error) << testError.what();
telsoa014fcda012018-03-09 14:13:49 +0000171 return false;
172 }
173
174 const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
surmeh013537c2c2018-05-18 16:31:43 +0100175 : static_cast<unsigned int>(defaultTestCaseIds.size());
telsoa014fcda012018-03-09 14:13:49 +0000176
177 for (; nbProcessed < nbTotalToProcess; nbProcessed++)
178 {
179 const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
180 std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
181
182 if (testCase == nullptr)
183 {
Derek Lamberti08446972019-11-26 16:38:31 +0000184 ARMNN_LOG(error) << "Failed to load test case";
telsoa014fcda012018-03-09 14:13:49 +0000185 return false;
186 }
187
188 time_point<high_resolution_clock> predictStart;
189 time_point<high_resolution_clock> predictEnd;
190
191 TestCaseResult result = TestCaseResult::Ok;
192
193 try
194 {
195 predictStart = high_resolution_clock::now();
196
197 testCase->Run();
198
199 predictEnd = high_resolution_clock::now();
200
201 // duration<double> will convert the time difference into seconds as a double by default.
202 double timeTakenS = duration<double>(predictEnd - predictStart).count();
203 totalTime += timeTakenS;
204
telsoa01c577f2c2018-08-31 09:22:23 +0100205 // Outputss inference times, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000206 if (recordInferenceTimes)
207 {
208 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
209 }
210
211 result = testCase->ProcessResult(params);
212
213 }
214 catch (const TestFrameworkException& testError)
215 {
Derek Lamberti08446972019-11-26 16:38:31 +0000216 ARMNN_LOG(error) << testError.what();
telsoa014fcda012018-03-09 14:13:49 +0000217 result = TestCaseResult::Abort;
218 }
219
220 switch (result)
221 {
222 case TestCaseResult::Ok:
223 break;
224 case TestCaseResult::Abort:
225 return false;
226 case TestCaseResult::Failed:
227 // This test failed so we will fail the entire program eventually, but keep going for now.
228 success = false;
229 break;
230 default:
231 BOOST_ASSERT_MSG(false, "Unexpected TestCaseResult");
232 return false;
233 }
234 }
235
236 const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
237
Derek Lamberti08446972019-11-26 16:38:31 +0000238 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
telsoa014fcda012018-03-09 14:13:49 +0000239 "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
Derek Lamberti08446972019-11-26 16:38:31 +0000240 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
telsoa014fcda012018-03-09 14:13:49 +0000241 "Average time per test case: " << averageTimePerTestCaseMs << " ms";
242
Sadik Armagan2b7a1582018-09-05 16:33:58 +0100243 // if profiling is enabled print out the results
244 if (profiler && profiler->IsProfilingEnabled())
245 {
246 profiler->Print(std::cout);
247 }
248
telsoa014fcda012018-03-09 14:13:49 +0000249 if (!success)
250 {
Derek Lamberti08446972019-11-26 16:38:31 +0000251 ARMNN_LOG(error) << "One or more test cases failed";
telsoa014fcda012018-03-09 14:13:49 +0000252 return false;
253 }
254
255 return testCaseProvider.OnInferenceTestFinished();
256}
257
258} // namespace test
259
260} // namespace armnn