blob: fec78ac805fefc8eb08ca89a856f51070711c788 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include <iostream>
6#include <chrono>
7#include <vector>
8#include <array>
9#include <boost/log/trivial.hpp>
10
11#include "armnn/ArmNN.hpp"
12#include "armnn/Utils.hpp"
13#include "armnn/INetwork.hpp"
14#include "armnnCaffeParser/ICaffeParser.hpp"
15#include "../Cifar10Database.hpp"
16#include "../InferenceTest.hpp"
17#include "../InferenceModel.hpp"
18
19using namespace std;
20using namespace std::chrono;
21using namespace armnn::test;
22
23int main(int argc, char* argv[])
24{
25#ifdef NDEBUG
26 armnn::LogSeverity level = armnn::LogSeverity::Info;
27#else
28 armnn::LogSeverity level = armnn::LogSeverity::Debug;
29#endif
30
31 try
32 {
telsoa01c577f2c2018-08-31 09:22:23 +010033 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +000034 armnn::ConfigureLogging(true, true, level);
35 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
36
37 namespace po = boost::program_options;
38
David Beckf0b48452018-10-19 15:20:56 +010039 std::vector<armnn::BackendId> computeDevice;
40 std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
telsoa014fcda012018-03-09 14:13:49 +000041 std::string modelDir;
42 std::string dataDir;
43
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010044 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
45 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
46
telsoa014fcda012018-03-09 14:13:49 +000047 po::options_description desc("Options");
48 try
49 {
telsoa01c577f2c2018-08-31 09:22:23 +010050 // Adds generic options needed for all inference tests.
telsoa014fcda012018-03-09 14:13:49 +000051 desc.add_options()
52 ("help", "Display help messages")
53 ("model-dir,m", po::value<std::string>(&modelDir)->required(),
54 "Path to directory containing the Cifar10 model file")
David Beckf0b48452018-10-19 15:20:56 +010055 ("compute,c", po::value<std::vector<armnn::BackendId>>(&computeDevice)->default_value(defaultBackends),
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010056 backendsMessage.c_str())
telsoa014fcda012018-03-09 14:13:49 +000057 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
58 "Path to directory containing the Cifar10 test data");
59 }
60 catch (const std::exception& e)
61 {
62 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
63 // and that desc.add_options() can throw boost::io::too_few_args.
64 // They really won't in any of these cases.
65 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
66 std::cerr << "Fatal internal error: " << e.what() << std::endl;
67 return 1;
68 }
69
70 po::variables_map vm;
71
72 try
73 {
74 po::store(po::parse_command_line(argc, argv, desc), vm);
75
76 if (vm.count("help"))
77 {
78 std::cout << desc << std::endl;
79 return 1;
80 }
81
82 po::notify(vm);
83 }
84 catch (po::error& e)
85 {
86 std::cerr << e.what() << std::endl << std::endl;
87 std::cerr << desc << std::endl;
88 return 1;
89 }
90
91 if (!ValidateDirectory(modelDir))
92 {
93 return 1;
94 }
95 string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel";
96
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010097 // Check if the requested backend are all valid
98 std::string invalidBackends;
99 if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
100 {
101 BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
102 << invalidBackends;
103 return EXIT_FAILURE;
104 }
105
telsoa014fcda012018-03-09 14:13:49 +0000106 // Create runtime
telsoa01c577f2c2018-08-31 09:22:23 +0100107 armnn::IRuntime::CreationOptions options;
108 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
telsoa014fcda012018-03-09 14:13:49 +0000109
telsoa01c577f2c2018-08-31 09:22:23 +0100110 // Loads networks.
telsoa014fcda012018-03-09 14:13:49 +0000111 armnn::Status status;
112 struct Net
113 {
114 Net(armnn::NetworkId netId,
115 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
116 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
117 : m_Network(netId)
118 , m_InputBindingInfo(in)
119 , m_OutputBindingInfo(out)
120 {}
121
122 armnn::NetworkId m_Network;
123 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
124 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
125 };
126 std::vector<Net> networks;
127
128 armnnCaffeParser::ICaffeParserPtr parser(armnnCaffeParser::ICaffeParser::Create());
129
130 const int networksCount = 4;
131 for (int i = 0; i < networksCount; ++i)
132 {
telsoa01c577f2c2018-08-31 09:22:23 +0100133 // Creates a network from a file on the disk.
telsoa014fcda012018-03-09 14:13:49 +0000134 armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
135
telsoa01c577f2c2018-08-31 09:22:23 +0100136 // Optimizes the network.
telsoa014fcda012018-03-09 14:13:49 +0000137 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
138 try
139 {
telsoa01c577f2c2018-08-31 09:22:23 +0100140 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
telsoa014fcda012018-03-09 14:13:49 +0000141 }
142 catch (armnn::Exception& e)
143 {
144 std::stringstream message;
145 message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
146 BOOST_LOG_TRIVIAL(fatal) << message.str();
147 return 1;
148 }
149
telsoa01c577f2c2018-08-31 09:22:23 +0100150 // Loads the network into the runtime.
telsoa014fcda012018-03-09 14:13:49 +0000151 armnn::NetworkId networkId;
152 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
153 if (status == armnn::Status::Failure)
154 {
155 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to load network";
156 return 1;
157 }
158
159 networks.emplace_back(networkId,
160 parser->GetNetworkInputBindingInfo("data"),
161 parser->GetNetworkOutputBindingInfo("prob"));
162 }
163
telsoa01c577f2c2018-08-31 09:22:23 +0100164 // Loads a test case and tests inference.
telsoa014fcda012018-03-09 14:13:49 +0000165 if (!ValidateDirectory(dataDir))
166 {
167 return 1;
168 }
169 Cifar10Database cifar10(dataDir);
170
171 for (unsigned int i = 0; i < 3; ++i)
172 {
telsoa01c577f2c2018-08-31 09:22:23 +0100173 // Loads test case data (including image data).
telsoa014fcda012018-03-09 14:13:49 +0000174 std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
175
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000176 // Tests inference.
Ferran Balaguerc602f292019-02-08 17:09:55 +0000177 std::vector<TContainer> outputs;
178 outputs.reserve(networksCount);
179
180 for (unsigned int j = 0; j < networksCount; ++j)
181 {
182 outputs.push_back(std::vector<float>(10));
183 }
184
telsoa014fcda012018-03-09 14:13:49 +0000185 for (unsigned int k = 0; k < networksCount; ++k)
186 {
Jim Flynnb4d7eae2019-05-01 14:44:27 +0100187 std::vector<armnn::BindingPointInfo> inputBindings = { networks[k].m_InputBindingInfo };
188 std::vector<armnn::BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo };
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000189
Ferran Balaguerc602f292019-02-08 17:09:55 +0000190 std::vector<TContainer> inputDataContainers = { testCaseData->m_InputImage };
191 std::vector<TContainer> outputDataContainers = { outputs[k] };
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000192
telsoa014fcda012018-03-09 14:13:49 +0000193 status = runtime->EnqueueWorkload(networks[k].m_Network,
Jim Flynn2fd61002019-05-03 12:54:26 +0100194 armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
195 armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
telsoa014fcda012018-03-09 14:13:49 +0000196 if (status == armnn::Status::Failure)
197 {
198 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload";
199 return 1;
200 }
201 }
202
telsoa01c577f2c2018-08-31 09:22:23 +0100203 // Compares outputs.
Ferran Balaguerc602f292019-02-08 17:09:55 +0000204 std::vector<float> output0 = boost::get<std::vector<float>>(outputs[0]);
205
telsoa014fcda012018-03-09 14:13:49 +0000206 for (unsigned int k = 1; k < networksCount; ++k)
207 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000208 std::vector<float> outputK = boost::get<std::vector<float>>(outputs[k]);
209
210 if (!std::equal(output0.begin(), output0.end(), outputK.begin(), outputK.end()))
telsoa014fcda012018-03-09 14:13:49 +0000211 {
212 BOOST_LOG_TRIVIAL(error) << "Multiple networks inference failed!";
213 return 1;
214 }
215 }
216 }
217
218 BOOST_LOG_TRIVIAL(info) << "Multiple networks inference ran successfully!";
219 return 0;
220 }
221 catch (armnn::Exception const& e)
222 {
surmeh013537c2c2018-05-18 16:31:43 +0100223 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
224 // exception of type std::length_error.
225 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
226 std::cerr << "Armnn Error: " << e.what() << std::endl;
telsoa014fcda012018-03-09 14:13:49 +0000227 return 1;
228 }
surmeh013537c2c2018-05-18 16:31:43 +0100229 catch (const std::exception& e)
230 {
David Beckf0b48452018-10-19 15:20:56 +0100231 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
surmeh013537c2c2018-05-18 16:31:43 +0100232 std::cerr << "WARNING: MultipleNetworksCifar10: An error has occurred when running the "
233 "multiple networks inference tests: " << e.what() << std::endl;
234 return 1;
235 }
236}