blob: 161481f2cd5ba7949a78b90c7f89e8332492997c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#include "InferenceTest.hpp"
6
7#include <boost/algorithm/string.hpp>
8#include <boost/numeric/conversion/cast.hpp>
9#include <boost/log/trivial.hpp>
10#include <boost/filesystem/path.hpp>
11#include <boost/assert.hpp>
12#include <boost/format.hpp>
13#include <boost/program_options.hpp>
14#include <boost/filesystem/operations.hpp>
15
16#include <fstream>
17#include <iostream>
18#include <iomanip>
19#include <array>
20
21using namespace std;
22using namespace std::chrono;
23using namespace armnn::test;
24
25namespace armnn
26{
27namespace test
28{
29
30/// Parse the command line of an ArmNN (or referencetests) inference test program.
31/// \return false if any error occurred during options processing, otherwise true
32bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
33 InferenceTestOptions& outParams)
34{
35 namespace po = boost::program_options;
36
37 std::string computeDeviceStr;
38
39 po::options_description desc("Options");
40
41 try
42 {
43 // Add generic options needed for all inference tests
44 desc.add_options()
45 ("help", "Display help messages")
46 ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
47 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
48 ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
49 "If non-empty, each individual inference time will be recorded and output to this file");
50
51 // Add options specific to the ITestCaseProvider
52 testCaseProvider.AddCommandLineOptions(desc);
53 }
54 catch (const std::exception& e)
55 {
56 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
57 // and that desc.add_options() can throw boost::io::too_few_args.
58 // They really won't in any of these cases.
59 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
60 std::cerr << "Fatal internal error: " << e.what() << std::endl;
61 return false;
62 }
63
64 po::variables_map vm;
65
66 try
67 {
68 po::store(po::parse_command_line(argc, argv, desc), vm);
69
70 if (vm.count("help"))
71 {
72 std::cout << desc << std::endl;
73 return false;
74 }
75
76 po::notify(vm);
77 }
78 catch (po::error& e)
79 {
80 std::cerr << e.what() << std::endl << std::endl;
81 std::cerr << desc << std::endl;
82 return false;
83 }
84
85 if (!testCaseProvider.ProcessCommandLineOptions())
86 {
87 return false;
88 }
89
90 return true;
91}
92
93bool ValidateDirectory(std::string& dir)
94{
95 if (dir[dir.length() - 1] != '/')
96 {
97 dir += "/";
98 }
99
100 if (!boost::filesystem::exists(dir))
101 {
102 std::cerr << "Given directory " << dir << " does not exist" << std::endl;
103 return false;
104 }
105
106 return true;
107}
108
109bool InferenceTest(const InferenceTestOptions& params,
110 const std::vector<unsigned int>& defaultTestCaseIds,
111 IInferenceTestCaseProvider& testCaseProvider)
112{
113#if !defined (NDEBUG)
114 if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn
115 {
116 BOOST_LOG_TRIVIAL(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
117 }
118#endif
119
120 double totalTime = 0;
121 unsigned int nbProcessed = 0;
122 bool success = true;
123
124 // Open the file to write inference times to, if needed
125 ofstream inferenceTimesFile;
126 const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
127 if (recordInferenceTimes)
128 {
129 inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
130 if (!inferenceTimesFile.good())
131 {
132 BOOST_LOG_TRIVIAL(error) << "Failed to open inference times file for writing: "
133 << params.m_InferenceTimesFile;
134 return false;
135 }
136 }
137
138 // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
139 std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
140 if (warmupTestCase == nullptr)
141 {
142 BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
143 return false;
144 }
145
146 try
147 {
148 warmupTestCase->Run();
149 }
150 catch (const TestFrameworkException& testError)
151 {
152 BOOST_LOG_TRIVIAL(error) << testError.what();
153 return false;
154 }
155
156 const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
surmeh013537c2c2018-05-18 16:31:43 +0100157 : static_cast<unsigned int>(defaultTestCaseIds.size());
telsoa014fcda012018-03-09 14:13:49 +0000158
159 for (; nbProcessed < nbTotalToProcess; nbProcessed++)
160 {
161 const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
162 std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
163
164 if (testCase == nullptr)
165 {
166 BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
167 return false;
168 }
169
170 time_point<high_resolution_clock> predictStart;
171 time_point<high_resolution_clock> predictEnd;
172
173 TestCaseResult result = TestCaseResult::Ok;
174
175 try
176 {
177 predictStart = high_resolution_clock::now();
178
179 testCase->Run();
180
181 predictEnd = high_resolution_clock::now();
182
183 // duration<double> will convert the time difference into seconds as a double by default.
184 double timeTakenS = duration<double>(predictEnd - predictStart).count();
185 totalTime += timeTakenS;
186
187 // Output inference times if needed
188 if (recordInferenceTimes)
189 {
190 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
191 }
192
193 result = testCase->ProcessResult(params);
194
195 }
196 catch (const TestFrameworkException& testError)
197 {
198 BOOST_LOG_TRIVIAL(error) << testError.what();
199 result = TestCaseResult::Abort;
200 }
201
202 switch (result)
203 {
204 case TestCaseResult::Ok:
205 break;
206 case TestCaseResult::Abort:
207 return false;
208 case TestCaseResult::Failed:
209 // This test failed so we will fail the entire program eventually, but keep going for now.
210 success = false;
211 break;
212 default:
213 BOOST_ASSERT_MSG(false, "Unexpected TestCaseResult");
214 return false;
215 }
216 }
217
218 const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
219
220 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
221 "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
222 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
223 "Average time per test case: " << averageTimePerTestCaseMs << " ms";
224
225 if (!success)
226 {
227 BOOST_LOG_TRIVIAL(error) << "One or more test cases failed";
228 return false;
229 }
230
231 return testCaseProvider.OnInferenceTestFinished();
232}
233
234} // namespace test
235
236} // namespace armnn