blob: b3b38d138640c4fa0a2c02bfdc53e89f731f3556 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "InferenceTest.hpp"
6
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01007#include <armnn/utility/Assert.hpp>
Francis Murtagh532a29d2020-06-29 11:50:01 +01008#include <Filesystem.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01009
telsoa01c577f2c2018-08-31 09:22:23 +010010#include "../src/armnn/Profiling.hpp"
telsoa014fcda012018-03-09 14:13:49 +000011#include <boost/program_options.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
13#include <fstream>
14#include <iostream>
15#include <iomanip>
16#include <array>
17
18using namespace std;
19using namespace std::chrono;
20using namespace armnn::test;
21
22namespace armnn
23{
24namespace test
25{
telsoa014fcda012018-03-09 14:13:49 +000026/// Parse the command line of an ArmNN (or referencetests) inference test program.
27/// \return false if any error occurred during options processing, otherwise true
28bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
29 InferenceTestOptions& outParams)
30{
31 namespace po = boost::program_options;
32
telsoa014fcda012018-03-09 14:13:49 +000033 po::options_description desc("Options");
34
35 try
36 {
telsoa01c577f2c2018-08-31 09:22:23 +010037 // Adds generic options needed for all inference tests.
telsoa014fcda012018-03-09 14:13:49 +000038 desc.add_options()
39 ("help", "Display help messages")
40 ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
41 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
42 ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
telsoa01c577f2c2018-08-31 09:22:23 +010043 "If non-empty, each individual inference time will be recorded and output to this file")
44 ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
45 "Enables built in profiler. If unset, defaults to off.");
telsoa014fcda012018-03-09 14:13:49 +000046
telsoa01c577f2c2018-08-31 09:22:23 +010047 // Adds options specific to the ITestCaseProvider.
telsoa014fcda012018-03-09 14:13:49 +000048 testCaseProvider.AddCommandLineOptions(desc);
49 }
50 catch (const std::exception& e)
51 {
52 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
53 // and that desc.add_options() can throw boost::io::too_few_args.
54 // They really won't in any of these cases.
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010055 ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
telsoa014fcda012018-03-09 14:13:49 +000056 std::cerr << "Fatal internal error: " << e.what() << std::endl;
57 return false;
58 }
59
60 po::variables_map vm;
61
62 try
63 {
64 po::store(po::parse_command_line(argc, argv, desc), vm);
65
66 if (vm.count("help"))
67 {
68 std::cout << desc << std::endl;
69 return false;
70 }
71
72 po::notify(vm);
73 }
74 catch (po::error& e)
75 {
76 std::cerr << e.what() << std::endl << std::endl;
77 std::cerr << desc << std::endl;
78 return false;
79 }
80
Matthew Bentham3e68b972019-04-09 13:10:46 +010081 if (!testCaseProvider.ProcessCommandLineOptions(outParams))
telsoa014fcda012018-03-09 14:13:49 +000082 {
83 return false;
84 }
85
86 return true;
87}
88
/// Checks that @p dir names an existing directory and normalises it to end with '/'.
///
/// @param dir In: the directory path to check. Out: if non-empty, the same path with a trailing
///            '/' appended when it did not already end with one. The normalisation is applied
///            even when validation subsequently fails.
/// @return true if @p dir is non-empty, exists and is a directory. Otherwise false, with a
///         diagnostic written to std::cerr.
bool ValidateDirectory(std::string& dir)
{
    if (dir.empty())
    {
        std::cerr << "No directory specified" << std::endl;
        return false;
    }

    // Probe the filesystem with trailing separators removed, keeping a lone "/" for the root.
    // Probing "some_file/" fails with ENOTDIR, which typical std::filesystem implementations
    // report as "not found". That would make a regular file be reported as "does not exist"
    // and leave the "is not a directory" diagnostic below unreachable.
    std::string pathToCheck = dir;
    while (pathToCheck.size() > 1 && pathToCheck.back() == '/')
    {
        pathToCheck.pop_back();
    }

    // Callers receive the directory normalised with a trailing separator.
    if (dir.back() != '/')
    {
        dir += "/";
    }

    if (!fs::exists(pathToCheck))
    {
        std::cerr << "Given directory " << dir << " does not exist" << std::endl;
        return false;
    }

    if (!fs::is_directory(pathToCheck))
    {
        std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
        return false;
    }

    return true;
}
115}
116
117bool InferenceTest(const InferenceTestOptions& params,
118 const std::vector<unsigned int>& defaultTestCaseIds,
119 IInferenceTestCaseProvider& testCaseProvider)
120{
121#if !defined (NDEBUG)
telsoa01c577f2c2018-08-31 09:22:23 +0100122 if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
telsoa014fcda012018-03-09 14:13:49 +0000123 {
Derek Lamberti08446972019-11-26 16:38:31 +0000124 ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
telsoa014fcda012018-03-09 14:13:49 +0000125 }
126#endif
127
128 double totalTime = 0;
129 unsigned int nbProcessed = 0;
130 bool success = true;
131
telsoa01c577f2c2018-08-31 09:22:23 +0100132 // Opens the file to write inference times too, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000133 ofstream inferenceTimesFile;
134 const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
135 if (recordInferenceTimes)
136 {
137 inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
138 if (!inferenceTimesFile.good())
139 {
Derek Lamberti08446972019-11-26 16:38:31 +0000140 ARMNN_LOG(error) << "Failed to open inference times file for writing: "
telsoa014fcda012018-03-09 14:13:49 +0000141 << params.m_InferenceTimesFile;
142 return false;
143 }
144 }
145
telsoa01c577f2c2018-08-31 09:22:23 +0100146 // Create a profiler and register it for the current thread.
147 std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
148 ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
149
150 // Enable profiling if requested.
151 profiler->EnableProfiling(params.m_EnableProfiling);
152
telsoa014fcda012018-03-09 14:13:49 +0000153 // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
154 std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
155 if (warmupTestCase == nullptr)
156 {
Derek Lamberti08446972019-11-26 16:38:31 +0000157 ARMNN_LOG(error) << "Failed to load test case";
telsoa014fcda012018-03-09 14:13:49 +0000158 return false;
159 }
160
161 try
162 {
163 warmupTestCase->Run();
164 }
165 catch (const TestFrameworkException& testError)
166 {
Derek Lamberti08446972019-11-26 16:38:31 +0000167 ARMNN_LOG(error) << testError.what();
telsoa014fcda012018-03-09 14:13:49 +0000168 return false;
169 }
170
171 const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
surmeh013537c2c2018-05-18 16:31:43 +0100172 : static_cast<unsigned int>(defaultTestCaseIds.size());
telsoa014fcda012018-03-09 14:13:49 +0000173
174 for (; nbProcessed < nbTotalToProcess; nbProcessed++)
175 {
176 const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
177 std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
178
179 if (testCase == nullptr)
180 {
Derek Lamberti08446972019-11-26 16:38:31 +0000181 ARMNN_LOG(error) << "Failed to load test case";
telsoa014fcda012018-03-09 14:13:49 +0000182 return false;
183 }
184
185 time_point<high_resolution_clock> predictStart;
186 time_point<high_resolution_clock> predictEnd;
187
188 TestCaseResult result = TestCaseResult::Ok;
189
190 try
191 {
192 predictStart = high_resolution_clock::now();
193
194 testCase->Run();
195
196 predictEnd = high_resolution_clock::now();
197
198 // duration<double> will convert the time difference into seconds as a double by default.
199 double timeTakenS = duration<double>(predictEnd - predictStart).count();
200 totalTime += timeTakenS;
201
telsoa01c577f2c2018-08-31 09:22:23 +0100202 // Outputss inference times, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000203 if (recordInferenceTimes)
204 {
205 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
206 }
207
208 result = testCase->ProcessResult(params);
209
210 }
211 catch (const TestFrameworkException& testError)
212 {
Derek Lamberti08446972019-11-26 16:38:31 +0000213 ARMNN_LOG(error) << testError.what();
telsoa014fcda012018-03-09 14:13:49 +0000214 result = TestCaseResult::Abort;
215 }
216
217 switch (result)
218 {
219 case TestCaseResult::Ok:
220 break;
221 case TestCaseResult::Abort:
222 return false;
223 case TestCaseResult::Failed:
224 // This test failed so we will fail the entire program eventually, but keep going for now.
225 success = false;
226 break;
227 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100228 ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
telsoa014fcda012018-03-09 14:13:49 +0000229 return false;
230 }
231 }
232
233 const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
234
Derek Lamberti08446972019-11-26 16:38:31 +0000235 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
telsoa014fcda012018-03-09 14:13:49 +0000236 "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
Derek Lamberti08446972019-11-26 16:38:31 +0000237 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
telsoa014fcda012018-03-09 14:13:49 +0000238 "Average time per test case: " << averageTimePerTestCaseMs << " ms";
239
Sadik Armagan2b7a1582018-09-05 16:33:58 +0100240 // if profiling is enabled print out the results
241 if (profiler && profiler->IsProfilingEnabled())
242 {
243 profiler->Print(std::cout);
244 }
245
telsoa014fcda012018-03-09 14:13:49 +0000246 if (!success)
247 {
Derek Lamberti08446972019-11-26 16:38:31 +0000248 ARMNN_LOG(error) << "One or more test cases failed";
telsoa014fcda012018-03-09 14:13:49 +0000249 return false;
250 }
251
252 return testCaseProvider.OnInferenceTestFinished();
253}
254
255} // namespace test
256
257} // namespace armnn