blob: b0d0b47443c39757213b2e686c012064d8cfb485 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "InferenceTest.hpp"
6
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01007#include <armnn/utility/Assert.hpp>
Francis Murtagh532a29d2020-06-29 11:50:01 +01008#include <Filesystem.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01009
telsoa01c577f2c2018-08-31 09:22:23 +010010#include "../src/armnn/Profiling.hpp"
telsoa014fcda012018-03-09 14:13:49 +000011#include <boost/numeric/conversion/cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012#include <boost/format.hpp>
13#include <boost/program_options.hpp>
telsoa014fcda012018-03-09 14:13:49 +000014
15#include <fstream>
16#include <iostream>
17#include <iomanip>
18#include <array>
19
20using namespace std;
21using namespace std::chrono;
22using namespace armnn::test;
23
24namespace armnn
25{
26namespace test
27{
telsoa014fcda012018-03-09 14:13:49 +000028/// Parse the command line of an ArmNN (or referencetests) inference test program.
29/// \return false if any error occurred during options processing, otherwise true
30bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
31 InferenceTestOptions& outParams)
32{
33 namespace po = boost::program_options;
34
telsoa014fcda012018-03-09 14:13:49 +000035 po::options_description desc("Options");
36
37 try
38 {
telsoa01c577f2c2018-08-31 09:22:23 +010039 // Adds generic options needed for all inference tests.
telsoa014fcda012018-03-09 14:13:49 +000040 desc.add_options()
41 ("help", "Display help messages")
42 ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
43 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
44 ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
telsoa01c577f2c2018-08-31 09:22:23 +010045 "If non-empty, each individual inference time will be recorded and output to this file")
46 ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
47 "Enables built in profiler. If unset, defaults to off.");
telsoa014fcda012018-03-09 14:13:49 +000048
telsoa01c577f2c2018-08-31 09:22:23 +010049 // Adds options specific to the ITestCaseProvider.
telsoa014fcda012018-03-09 14:13:49 +000050 testCaseProvider.AddCommandLineOptions(desc);
51 }
52 catch (const std::exception& e)
53 {
54 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
55 // and that desc.add_options() can throw boost::io::too_few_args.
56 // They really won't in any of these cases.
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010057 ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
telsoa014fcda012018-03-09 14:13:49 +000058 std::cerr << "Fatal internal error: " << e.what() << std::endl;
59 return false;
60 }
61
62 po::variables_map vm;
63
64 try
65 {
66 po::store(po::parse_command_line(argc, argv, desc), vm);
67
68 if (vm.count("help"))
69 {
70 std::cout << desc << std::endl;
71 return false;
72 }
73
74 po::notify(vm);
75 }
76 catch (po::error& e)
77 {
78 std::cerr << e.what() << std::endl << std::endl;
79 std::cerr << desc << std::endl;
80 return false;
81 }
82
Matthew Bentham3e68b972019-04-09 13:10:46 +010083 if (!testCaseProvider.ProcessCommandLineOptions(outParams))
telsoa014fcda012018-03-09 14:13:49 +000084 {
85 return false;
86 }
87
88 return true;
89}
90
91bool ValidateDirectory(std::string& dir)
92{
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010093 if (dir.empty())
94 {
95 std::cerr << "No directory specified" << std::endl;
96 return false;
97 }
98
telsoa014fcda012018-03-09 14:13:49 +000099 if (dir[dir.length() - 1] != '/')
100 {
101 dir += "/";
102 }
103
Francis Murtagh532a29d2020-06-29 11:50:01 +0100104 if (!fs::exists(dir))
telsoa014fcda012018-03-09 14:13:49 +0000105 {
106 std::cerr << "Given directory " << dir << " does not exist" << std::endl;
107 return false;
108 }
109
Francis Murtagh532a29d2020-06-29 11:50:01 +0100110 if (!fs::is_directory(dir))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100111 {
112 std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
113 return false;
114 }
115
telsoa014fcda012018-03-09 14:13:49 +0000116 return true;
117}
118
119bool InferenceTest(const InferenceTestOptions& params,
120 const std::vector<unsigned int>& defaultTestCaseIds,
121 IInferenceTestCaseProvider& testCaseProvider)
122{
123#if !defined (NDEBUG)
telsoa01c577f2c2018-08-31 09:22:23 +0100124 if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
telsoa014fcda012018-03-09 14:13:49 +0000125 {
Derek Lamberti08446972019-11-26 16:38:31 +0000126 ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
telsoa014fcda012018-03-09 14:13:49 +0000127 }
128#endif
129
130 double totalTime = 0;
131 unsigned int nbProcessed = 0;
132 bool success = true;
133
telsoa01c577f2c2018-08-31 09:22:23 +0100134 // Opens the file to write inference times too, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000135 ofstream inferenceTimesFile;
136 const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
137 if (recordInferenceTimes)
138 {
139 inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
140 if (!inferenceTimesFile.good())
141 {
Derek Lamberti08446972019-11-26 16:38:31 +0000142 ARMNN_LOG(error) << "Failed to open inference times file for writing: "
telsoa014fcda012018-03-09 14:13:49 +0000143 << params.m_InferenceTimesFile;
144 return false;
145 }
146 }
147
telsoa01c577f2c2018-08-31 09:22:23 +0100148 // Create a profiler and register it for the current thread.
149 std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
150 ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
151
152 // Enable profiling if requested.
153 profiler->EnableProfiling(params.m_EnableProfiling);
154
telsoa014fcda012018-03-09 14:13:49 +0000155 // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
156 std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
157 if (warmupTestCase == nullptr)
158 {
Derek Lamberti08446972019-11-26 16:38:31 +0000159 ARMNN_LOG(error) << "Failed to load test case";
telsoa014fcda012018-03-09 14:13:49 +0000160 return false;
161 }
162
163 try
164 {
165 warmupTestCase->Run();
166 }
167 catch (const TestFrameworkException& testError)
168 {
Derek Lamberti08446972019-11-26 16:38:31 +0000169 ARMNN_LOG(error) << testError.what();
telsoa014fcda012018-03-09 14:13:49 +0000170 return false;
171 }
172
173 const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
surmeh013537c2c2018-05-18 16:31:43 +0100174 : static_cast<unsigned int>(defaultTestCaseIds.size());
telsoa014fcda012018-03-09 14:13:49 +0000175
176 for (; nbProcessed < nbTotalToProcess; nbProcessed++)
177 {
178 const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
179 std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
180
181 if (testCase == nullptr)
182 {
Derek Lamberti08446972019-11-26 16:38:31 +0000183 ARMNN_LOG(error) << "Failed to load test case";
telsoa014fcda012018-03-09 14:13:49 +0000184 return false;
185 }
186
187 time_point<high_resolution_clock> predictStart;
188 time_point<high_resolution_clock> predictEnd;
189
190 TestCaseResult result = TestCaseResult::Ok;
191
192 try
193 {
194 predictStart = high_resolution_clock::now();
195
196 testCase->Run();
197
198 predictEnd = high_resolution_clock::now();
199
200 // duration<double> will convert the time difference into seconds as a double by default.
201 double timeTakenS = duration<double>(predictEnd - predictStart).count();
202 totalTime += timeTakenS;
203
telsoa01c577f2c2018-08-31 09:22:23 +0100204 // Outputss inference times, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000205 if (recordInferenceTimes)
206 {
207 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
208 }
209
210 result = testCase->ProcessResult(params);
211
212 }
213 catch (const TestFrameworkException& testError)
214 {
Derek Lamberti08446972019-11-26 16:38:31 +0000215 ARMNN_LOG(error) << testError.what();
telsoa014fcda012018-03-09 14:13:49 +0000216 result = TestCaseResult::Abort;
217 }
218
219 switch (result)
220 {
221 case TestCaseResult::Ok:
222 break;
223 case TestCaseResult::Abort:
224 return false;
225 case TestCaseResult::Failed:
226 // This test failed so we will fail the entire program eventually, but keep going for now.
227 success = false;
228 break;
229 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100230 ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
telsoa014fcda012018-03-09 14:13:49 +0000231 return false;
232 }
233 }
234
235 const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
236
Derek Lamberti08446972019-11-26 16:38:31 +0000237 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
telsoa014fcda012018-03-09 14:13:49 +0000238 "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
Derek Lamberti08446972019-11-26 16:38:31 +0000239 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
telsoa014fcda012018-03-09 14:13:49 +0000240 "Average time per test case: " << averageTimePerTestCaseMs << " ms";
241
Sadik Armagan2b7a1582018-09-05 16:33:58 +0100242 // if profiling is enabled print out the results
243 if (profiler && profiler->IsProfilingEnabled())
244 {
245 profiler->Print(std::cout);
246 }
247
telsoa014fcda012018-03-09 14:13:49 +0000248 if (!success)
249 {
Derek Lamberti08446972019-11-26 16:38:31 +0000250 ARMNN_LOG(error) << "One or more test cases failed";
telsoa014fcda012018-03-09 14:13:49 +0000251 return false;
252 }
253
254 return testCaseProvider.OnInferenceTestFinished();
255}
256
257} // namespace test
258
259} // namespace armnn