blob: 46c51b6cfdbb2f91e33e6f39d05cf918cb75ccc8 [file] [log] [blame]
Moritz Pflanzeree493ae2017-07-05 10:52:21 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Moritz Pflanzerd03b00a2017-07-17 13:50:12 +010024#include "framework/DatasetModes.h"
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010025#include "framework/Macros.h"
26#include "framework/command_line/CommandLineOptions.h"
27#include "framework/command_line/CommandLineParser.h"
28#include "framework/instruments/Instruments.h"
29#include "framework/printers/Printers.h"
30#include "support/ToolchainSupport.h"
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010031#include "tests/TensorLibrary.h"
32
Anthony Barbier15d5ac82017-07-17 15:22:17 +010033#ifdef ARM_COMPUTE_CL
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010034#include "arm_compute/runtime/CL/CLScheduler.h"
Anthony Barbier15d5ac82017-07-17 15:22:17 +010035#endif /* ARM_COMPUTE_CL */
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010036#include "arm_compute/runtime/Scheduler.h"
37
38#include <fstream>
39#include <initializer_list>
40#include <iostream>
41#include <memory>
42#include <random>
43#include <utility>
44
45using namespace arm_compute;
46using namespace arm_compute::test;
47
48namespace arm_compute
49{
50namespace test
51{
52std::unique_ptr<TensorLibrary> library;
53} // namespace test
54} // namespace arm_compute
55
56int main(int argc, char **argv)
57{
Anthony Barbier15d5ac82017-07-17 15:22:17 +010058#ifdef ARM_COMPUTE_CL
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010059 CLScheduler::get().default_init();
Anthony Barbier15d5ac82017-07-17 15:22:17 +010060#endif /* ARM_COMPUTE_CL */
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010061
62 framework::Framework &framework = framework::Framework::get();
63
64 framework::CommandLineParser parser;
65
66 std::set<framework::InstrumentType> allowed_instruments
67 {
68 framework::InstrumentType::ALL,
69 framework::InstrumentType::NONE,
70 };
71
72 for(const auto &type : framework.available_instruments())
73 {
74 allowed_instruments.insert(type);
75 }
76
Moritz Pflanzerd03b00a2017-07-17 13:50:12 +010077 std::set<framework::DatasetMode> allowed_modes
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010078 {
Moritz Pflanzerd03b00a2017-07-17 13:50:12 +010079 framework::DatasetMode::PRECOMMIT,
80 framework::DatasetMode::NIGHTLY,
81 framework::DatasetMode::ALL
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010082 };
83
84 std::set<framework::LogFormat> supported_log_formats
85 {
86 framework::LogFormat::NONE,
87 framework::LogFormat::PRETTY,
88 framework::LogFormat::JSON,
89 };
90
91 auto help = parser.add_option<framework::ToggleOption>("help");
92 help->set_help("Show this help message");
Moritz Pflanzerd03b00a2017-07-17 13:50:12 +010093 auto dataset_mode = parser.add_option<framework::EnumOption<framework::DatasetMode>>("mode", allowed_modes, framework::DatasetMode::ALL);
Moritz Pflanzeree493ae2017-07-05 10:52:21 +010094 dataset_mode->set_help("For managed datasets select which group to use");
95 auto instruments = parser.add_option<framework::EnumListOption<framework::InstrumentType>>("instruments", allowed_instruments, std::initializer_list<framework::InstrumentType> { framework::InstrumentType::ALL });
96 instruments->set_help("Set the profiling instruments to use");
97 auto iterations = parser.add_option<framework::SimpleOption<int>>("iterations", 1);
98 iterations->set_help("Number of iterations per test case");
99 auto threads = parser.add_option<framework::SimpleOption<int>>("threads", 1);
100 threads->set_help("Number of threads to use");
101 auto log_format = parser.add_option<framework::EnumOption<framework::LogFormat>>("log-format", supported_log_formats, framework::LogFormat::PRETTY);
102 log_format->set_help("Output format for measurements and failures");
103 auto filter = parser.add_option<framework::SimpleOption<std::string>>("filter", ".*");
104 filter->set_help("Regular expression to select test cases");
105 auto filter_id = parser.add_option<framework::SimpleOption<std::string>>("filter-id", ".*");
106 filter_id->set_help("Regular expression to select test cases by id");
107 auto log_file = parser.add_option<framework::SimpleOption<std::string>>("log-file");
108 log_file->set_help("Write output to file instead of to the console");
109 auto throw_errors = parser.add_option<framework::ToggleOption>("throw-errors");
110 throw_errors->set_help("Don't catch errors (useful for debugging)");
111 auto seed = parser.add_option<framework::SimpleOption<std::random_device::result_type>>("seed", std::random_device()());
112 seed->set_help("Global seed for random number generation");
113 auto color_output = parser.add_option<framework::ToggleOption>("color-output", true);
114 color_output->set_help("Produce colored output on the console");
115 auto list_tests = parser.add_option<framework::ToggleOption>("list-tests", false);
116 list_tests->set_help("List all test names");
117 auto assets = parser.add_positional_option<framework::SimpleOption<std::string>>("assets");
118 assets->set_help("Path to the assets directory");
119 assets->set_required(true);
120
121 try
122 {
123 parser.parse(argc, argv);
124
125 if(help->is_set() && help->value())
126 {
127 parser.print_help(argv[0]);
128 return 0;
129 }
130
131 if(!parser.validate())
132 {
133 return 1;
134 }
135
136 std::unique_ptr<framework::Printer> printer;
137 std::ofstream log_stream;
138
139 switch(log_format->value())
140 {
141 case framework::LogFormat::JSON:
142 printer = support::cpp14::make_unique<framework::JSONPrinter>();
143 break;
144 case framework::LogFormat::NONE:
145 break;
146 case framework::LogFormat::PRETTY:
147 default:
148 {
149 auto pretty_printer = support::cpp14::make_unique<framework::PrettyPrinter>();
150 pretty_printer->set_color_output(color_output->value());
151 printer = std::move(pretty_printer);
152 break;
153 }
154 }
155
156 if(printer != nullptr)
157 {
158 if(log_file->is_set())
159 {
160 log_stream.open(log_file->value());
161 printer->set_stream(log_stream);
162 }
163 }
164
Moritz Pflanzeree493ae2017-07-05 10:52:21 +0100165 library = support::cpp14::make_unique<TensorLibrary>(assets->value(), seed->value());
166 Scheduler::get().set_num_threads(threads->value());
167
168 printer->print_global_header();
169 printer->print_entry("Seed", support::cpp11::to_string(seed->value()));
170 printer->print_entry("Iterations", support::cpp11::to_string(iterations->value()));
171 printer->print_entry("Threads", support::cpp11::to_string(threads->value()));
172 {
173 using support::cpp11::to_string;
174 printer->print_entry("Dataset mode", to_string(dataset_mode->value()));
175 }
176
Moritz Pflanzerd03b00a2017-07-17 13:50:12 +0100177 framework.init(instruments->value(), iterations->value(), dataset_mode->value(), filter->value(), filter_id->value());
Moritz Pflanzeree493ae2017-07-05 10:52:21 +0100178 framework.set_printer(printer.get());
179 framework.set_throw_errors(throw_errors->value());
180
181 bool success = true;
182
183 if(list_tests->value())
184 {
185 for(const auto &id : framework.test_ids())
186 {
Moritz Pflanzerd03b00a2017-07-17 13:50:12 +0100187 std::cout << "[" << std::get<0>(id) << ", " << std::get<2>(id) << "] " << std::get<1>(id) << "\n";
Moritz Pflanzeree493ae2017-07-05 10:52:21 +0100188 }
189 }
190 else
191 {
192 success = framework.run();
193 }
194
195 printer->print_global_footer();
196
197 return (success ? 0 : 1);
198 }
199 catch(const std::exception &error)
200 {
201 std::cerr << error.what() << "\n";
202
203 if(throw_errors->value())
204 {
205 throw;
206 }
207
208 return 1;
209 }
210
211 return 0;
212}