blob: f9a3c68e77c6ee102ac97f12848f1aff5a462ceb [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "InferenceTest.hpp"
6
telsoa01c577f2c2018-08-31 09:22:23 +01007#include "../src/armnn/Profiling.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008#include <boost/algorithm/string.hpp>
9#include <boost/numeric/conversion/cast.hpp>
10#include <boost/log/trivial.hpp>
11#include <boost/filesystem/path.hpp>
12#include <boost/assert.hpp>
13#include <boost/format.hpp>
14#include <boost/program_options.hpp>
15#include <boost/filesystem/operations.hpp>
16
17#include <fstream>
18#include <iostream>
19#include <iomanip>
20#include <array>
21
22using namespace std;
23using namespace std::chrono;
24using namespace armnn::test;
25
26namespace armnn
27{
28namespace test
29{
telsoa014fcda012018-03-09 14:13:49 +000030/// Parse the command line of an ArmNN (or referencetests) inference test program.
31/// \return false if any error occurred during options processing, otherwise true
32bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
33 InferenceTestOptions& outParams)
34{
35 namespace po = boost::program_options;
36
37 std::string computeDeviceStr;
38
39 po::options_description desc("Options");
40
41 try
42 {
telsoa01c577f2c2018-08-31 09:22:23 +010043 // Adds generic options needed for all inference tests.
telsoa014fcda012018-03-09 14:13:49 +000044 desc.add_options()
45 ("help", "Display help messages")
46 ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
47 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
48 ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
telsoa01c577f2c2018-08-31 09:22:23 +010049 "If non-empty, each individual inference time will be recorded and output to this file")
50 ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
51 "Enables built in profiler. If unset, defaults to off.");
telsoa014fcda012018-03-09 14:13:49 +000052
telsoa01c577f2c2018-08-31 09:22:23 +010053 // Adds options specific to the ITestCaseProvider.
telsoa014fcda012018-03-09 14:13:49 +000054 testCaseProvider.AddCommandLineOptions(desc);
55 }
56 catch (const std::exception& e)
57 {
58 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
59 // and that desc.add_options() can throw boost::io::too_few_args.
60 // They really won't in any of these cases.
61 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
62 std::cerr << "Fatal internal error: " << e.what() << std::endl;
63 return false;
64 }
65
66 po::variables_map vm;
67
68 try
69 {
70 po::store(po::parse_command_line(argc, argv, desc), vm);
71
72 if (vm.count("help"))
73 {
74 std::cout << desc << std::endl;
75 return false;
76 }
77
78 po::notify(vm);
79 }
80 catch (po::error& e)
81 {
82 std::cerr << e.what() << std::endl << std::endl;
83 std::cerr << desc << std::endl;
84 return false;
85 }
86
87 if (!testCaseProvider.ProcessCommandLineOptions())
88 {
89 return false;
90 }
91
92 return true;
93}
94
95bool ValidateDirectory(std::string& dir)
96{
97 if (dir[dir.length() - 1] != '/')
98 {
99 dir += "/";
100 }
101
102 if (!boost::filesystem::exists(dir))
103 {
104 std::cerr << "Given directory " << dir << " does not exist" << std::endl;
105 return false;
106 }
107
108 return true;
109}
110
111bool InferenceTest(const InferenceTestOptions& params,
112 const std::vector<unsigned int>& defaultTestCaseIds,
113 IInferenceTestCaseProvider& testCaseProvider)
114{
115#if !defined (NDEBUG)
telsoa01c577f2c2018-08-31 09:22:23 +0100116 if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
telsoa014fcda012018-03-09 14:13:49 +0000117 {
118 BOOST_LOG_TRIVIAL(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
119 }
120#endif
121
122 double totalTime = 0;
123 unsigned int nbProcessed = 0;
124 bool success = true;
125
telsoa01c577f2c2018-08-31 09:22:23 +0100126 // Opens the file to write inference times too, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000127 ofstream inferenceTimesFile;
128 const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
129 if (recordInferenceTimes)
130 {
131 inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
132 if (!inferenceTimesFile.good())
133 {
134 BOOST_LOG_TRIVIAL(error) << "Failed to open inference times file for writing: "
135 << params.m_InferenceTimesFile;
136 return false;
137 }
138 }
139
telsoa01c577f2c2018-08-31 09:22:23 +0100140 // Create a profiler and register it for the current thread.
141 std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
142 ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
143
144 // Enable profiling if requested.
145 profiler->EnableProfiling(params.m_EnableProfiling);
146
telsoa014fcda012018-03-09 14:13:49 +0000147 // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
148 std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
149 if (warmupTestCase == nullptr)
150 {
151 BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
152 return false;
153 }
154
155 try
156 {
157 warmupTestCase->Run();
158 }
159 catch (const TestFrameworkException& testError)
160 {
161 BOOST_LOG_TRIVIAL(error) << testError.what();
162 return false;
163 }
164
165 const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
surmeh013537c2c2018-05-18 16:31:43 +0100166 : static_cast<unsigned int>(defaultTestCaseIds.size());
telsoa014fcda012018-03-09 14:13:49 +0000167
168 for (; nbProcessed < nbTotalToProcess; nbProcessed++)
169 {
170 const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
171 std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
172
173 if (testCase == nullptr)
174 {
175 BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
176 return false;
177 }
178
179 time_point<high_resolution_clock> predictStart;
180 time_point<high_resolution_clock> predictEnd;
181
182 TestCaseResult result = TestCaseResult::Ok;
183
184 try
185 {
186 predictStart = high_resolution_clock::now();
187
188 testCase->Run();
189
190 predictEnd = high_resolution_clock::now();
191
192 // duration<double> will convert the time difference into seconds as a double by default.
193 double timeTakenS = duration<double>(predictEnd - predictStart).count();
194 totalTime += timeTakenS;
195
telsoa01c577f2c2018-08-31 09:22:23 +0100196 // Outputss inference times, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000197 if (recordInferenceTimes)
198 {
199 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
200 }
201
202 result = testCase->ProcessResult(params);
203
204 }
205 catch (const TestFrameworkException& testError)
206 {
207 BOOST_LOG_TRIVIAL(error) << testError.what();
208 result = TestCaseResult::Abort;
209 }
210
211 switch (result)
212 {
213 case TestCaseResult::Ok:
214 break;
215 case TestCaseResult::Abort:
216 return false;
217 case TestCaseResult::Failed:
218 // This test failed so we will fail the entire program eventually, but keep going for now.
219 success = false;
220 break;
221 default:
222 BOOST_ASSERT_MSG(false, "Unexpected TestCaseResult");
223 return false;
224 }
225 }
226
227 const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
228
229 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
230 "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
231 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
232 "Average time per test case: " << averageTimePerTestCaseMs << " ms";
233
234 if (!success)
235 {
236 BOOST_LOG_TRIVIAL(error) << "One or more test cases failed";
237 return false;
238 }
239
240 return testCaseProvider.OnInferenceTestFinished();
241}
242
243} // namespace test
244
245} // namespace armnn