blob: 3392f6ea517e9252f30e285c20a36ffc4f603489 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "InferenceTest.hpp"
6
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01007#include <armnn/utility/Assert.hpp>
Francis Murtagh532a29d2020-06-29 11:50:01 +01008#include <Filesystem.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01009
telsoa01c577f2c2018-08-31 09:22:23 +010010#include "../src/armnn/Profiling.hpp"
James Wardc89829f2020-10-12 14:17:36 +010011#include <cxxopts/cxxopts.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
13#include <fstream>
14#include <iostream>
15#include <iomanip>
16#include <array>
17
18using namespace std;
19using namespace std::chrono;
20using namespace armnn::test;
21
22namespace armnn
23{
24namespace test
25{
telsoa014fcda012018-03-09 14:13:49 +000026/// Parse the command line of an ArmNN (or referencetests) inference test program.
27/// \return false if any error occurred during options processing, otherwise true
28bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
29 InferenceTestOptions& outParams)
30{
James Wardc89829f2020-10-12 14:17:36 +010031 cxxopts::Options options("InferenceTest", "Inference iteration parameters");
telsoa014fcda012018-03-09 14:13:49 +000032
33 try
34 {
telsoa01c577f2c2018-08-31 09:22:23 +010035 // Adds generic options needed for all inference tests.
James Wardc89829f2020-10-12 14:17:36 +010036 options
37 .allow_unrecognised_options()
38 .add_options()
39 ("h,help", "Display help messages")
40 ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
41 cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
42 ("inference-times-file",
43 "If non-empty, each individual inference time will be recorded and output to this file",
44 cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
45 ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
46 cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));
47
48 std::vector<std::string> required; //to be passed as reference to derived inference tests
telsoa014fcda012018-03-09 14:13:49 +000049
telsoa01c577f2c2018-08-31 09:22:23 +010050 // Adds options specific to the ITestCaseProvider.
James Wardc89829f2020-10-12 14:17:36 +010051 testCaseProvider.AddCommandLineOptions(options, required);
52
53 auto result = options.parse(argc, argv);
54
55 if (result.count("help"))
56 {
57 std::cout << options.help() << std::endl;
58 return false;
59 }
60
61 CheckRequiredOptions(result, required);
62
63 }
64 catch (const cxxopts::OptionException& e)
65 {
66 std::cerr << e.what() << std::endl << options.help() << std::endl;
67 return false;
telsoa014fcda012018-03-09 14:13:49 +000068 }
69 catch (const std::exception& e)
70 {
71 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
72 // and that desc.add_options() can throw boost::io::too_few_args.
73 // They really won't in any of these cases.
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010074 ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
telsoa014fcda012018-03-09 14:13:49 +000075 std::cerr << "Fatal internal error: " << e.what() << std::endl;
76 return false;
77 }
78
Matthew Bentham3e68b972019-04-09 13:10:46 +010079 if (!testCaseProvider.ProcessCommandLineOptions(outParams))
telsoa014fcda012018-03-09 14:13:49 +000080 {
81 return false;
82 }
83
84 return true;
85}
86
87bool ValidateDirectory(std::string& dir)
88{
Éanna Ó Catháina4247d52019-05-08 14:00:45 +010089 if (dir.empty())
90 {
91 std::cerr << "No directory specified" << std::endl;
92 return false;
93 }
94
telsoa014fcda012018-03-09 14:13:49 +000095 if (dir[dir.length() - 1] != '/')
96 {
97 dir += "/";
98 }
99
Francis Murtagh532a29d2020-06-29 11:50:01 +0100100 if (!fs::exists(dir))
telsoa014fcda012018-03-09 14:13:49 +0000101 {
102 std::cerr << "Given directory " << dir << " does not exist" << std::endl;
103 return false;
104 }
105
Francis Murtagh532a29d2020-06-29 11:50:01 +0100106 if (!fs::is_directory(dir))
Éanna Ó Catháina4247d52019-05-08 14:00:45 +0100107 {
108 std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
109 return false;
110 }
111
telsoa014fcda012018-03-09 14:13:49 +0000112 return true;
113}
114
115bool InferenceTest(const InferenceTestOptions& params,
116 const std::vector<unsigned int>& defaultTestCaseIds,
117 IInferenceTestCaseProvider& testCaseProvider)
118{
119#if !defined (NDEBUG)
telsoa01c577f2c2018-08-31 09:22:23 +0100120 if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
telsoa014fcda012018-03-09 14:13:49 +0000121 {
Derek Lamberti08446972019-11-26 16:38:31 +0000122 ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
telsoa014fcda012018-03-09 14:13:49 +0000123 }
124#endif
125
126 double totalTime = 0;
127 unsigned int nbProcessed = 0;
128 bool success = true;
129
telsoa01c577f2c2018-08-31 09:22:23 +0100130 // Opens the file to write inference times too, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000131 ofstream inferenceTimesFile;
132 const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
133 if (recordInferenceTimes)
134 {
135 inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
136 if (!inferenceTimesFile.good())
137 {
Derek Lamberti08446972019-11-26 16:38:31 +0000138 ARMNN_LOG(error) << "Failed to open inference times file for writing: "
telsoa014fcda012018-03-09 14:13:49 +0000139 << params.m_InferenceTimesFile;
140 return false;
141 }
142 }
143
telsoa01c577f2c2018-08-31 09:22:23 +0100144 // Create a profiler and register it for the current thread.
145 std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
146 ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
147
148 // Enable profiling if requested.
149 profiler->EnableProfiling(params.m_EnableProfiling);
150
telsoa014fcda012018-03-09 14:13:49 +0000151 // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
152 std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
153 if (warmupTestCase == nullptr)
154 {
Derek Lamberti08446972019-11-26 16:38:31 +0000155 ARMNN_LOG(error) << "Failed to load test case";
telsoa014fcda012018-03-09 14:13:49 +0000156 return false;
157 }
158
159 try
160 {
161 warmupTestCase->Run();
162 }
163 catch (const TestFrameworkException& testError)
164 {
Derek Lamberti08446972019-11-26 16:38:31 +0000165 ARMNN_LOG(error) << testError.what();
telsoa014fcda012018-03-09 14:13:49 +0000166 return false;
167 }
168
169 const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
surmeh013537c2c2018-05-18 16:31:43 +0100170 : static_cast<unsigned int>(defaultTestCaseIds.size());
telsoa014fcda012018-03-09 14:13:49 +0000171
172 for (; nbProcessed < nbTotalToProcess; nbProcessed++)
173 {
174 const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
175 std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
176
177 if (testCase == nullptr)
178 {
Derek Lamberti08446972019-11-26 16:38:31 +0000179 ARMNN_LOG(error) << "Failed to load test case";
telsoa014fcda012018-03-09 14:13:49 +0000180 return false;
181 }
182
183 time_point<high_resolution_clock> predictStart;
184 time_point<high_resolution_clock> predictEnd;
185
186 TestCaseResult result = TestCaseResult::Ok;
187
188 try
189 {
190 predictStart = high_resolution_clock::now();
191
192 testCase->Run();
193
194 predictEnd = high_resolution_clock::now();
195
196 // duration<double> will convert the time difference into seconds as a double by default.
197 double timeTakenS = duration<double>(predictEnd - predictStart).count();
198 totalTime += timeTakenS;
199
telsoa01c577f2c2018-08-31 09:22:23 +0100200 // Outputss inference times, if needed.
telsoa014fcda012018-03-09 14:13:49 +0000201 if (recordInferenceTimes)
202 {
203 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
204 }
205
206 result = testCase->ProcessResult(params);
207
208 }
209 catch (const TestFrameworkException& testError)
210 {
Derek Lamberti08446972019-11-26 16:38:31 +0000211 ARMNN_LOG(error) << testError.what();
telsoa014fcda012018-03-09 14:13:49 +0000212 result = TestCaseResult::Abort;
213 }
214
215 switch (result)
216 {
217 case TestCaseResult::Ok:
218 break;
219 case TestCaseResult::Abort:
220 return false;
221 case TestCaseResult::Failed:
222 // This test failed so we will fail the entire program eventually, but keep going for now.
223 success = false;
224 break;
225 default:
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100226 ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
telsoa014fcda012018-03-09 14:13:49 +0000227 return false;
228 }
229 }
230
231 const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
232
Derek Lamberti08446972019-11-26 16:38:31 +0000233 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
telsoa014fcda012018-03-09 14:13:49 +0000234 "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
Derek Lamberti08446972019-11-26 16:38:31 +0000235 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
telsoa014fcda012018-03-09 14:13:49 +0000236 "Average time per test case: " << averageTimePerTestCaseMs << " ms";
237
Sadik Armagan2b7a1582018-09-05 16:33:58 +0100238 // if profiling is enabled print out the results
239 if (profiler && profiler->IsProfilingEnabled())
240 {
241 profiler->Print(std::cout);
242 }
243
telsoa014fcda012018-03-09 14:13:49 +0000244 if (!success)
245 {
Derek Lamberti08446972019-11-26 16:38:31 +0000246 ARMNN_LOG(error) << "One or more test cases failed";
telsoa014fcda012018-03-09 14:13:49 +0000247 return false;
248 }
249
250 return testCaseProvider.OnInferenceTestFinished();
251}
252
253} // namespace test
254
255} // namespace armnn