blob: 5c969c68dd312434c5374b58c83003f8bfd56feb [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include <iostream>
6#include <chrono>
7#include <vector>
8#include <array>
telsoa014fcda012018-03-09 14:13:49 +00009
10#include "armnn/ArmNN.hpp"
11#include "armnn/Utils.hpp"
12#include "armnn/INetwork.hpp"
13#include "armnnCaffeParser/ICaffeParser.hpp"
14#include "../Cifar10Database.hpp"
15#include "../InferenceTest.hpp"
16#include "../InferenceModel.hpp"
17
18using namespace std;
19using namespace std::chrono;
20using namespace armnn::test;
21
22int main(int argc, char* argv[])
23{
24#ifdef NDEBUG
25 armnn::LogSeverity level = armnn::LogSeverity::Info;
26#else
27 armnn::LogSeverity level = armnn::LogSeverity::Debug;
28#endif
29
30 try
31 {
telsoa01c577f2c2018-08-31 09:22:23 +010032 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +000033 armnn::ConfigureLogging(true, true, level);
telsoa014fcda012018-03-09 14:13:49 +000034 namespace po = boost::program_options;
35
David Beckf0b48452018-10-19 15:20:56 +010036 std::vector<armnn::BackendId> computeDevice;
37 std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
telsoa014fcda012018-03-09 14:13:49 +000038 std::string modelDir;
39 std::string dataDir;
40
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010041 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
42 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
43
telsoa014fcda012018-03-09 14:13:49 +000044 po::options_description desc("Options");
45 try
46 {
telsoa01c577f2c2018-08-31 09:22:23 +010047 // Adds generic options needed for all inference tests.
telsoa014fcda012018-03-09 14:13:49 +000048 desc.add_options()
49 ("help", "Display help messages")
50 ("model-dir,m", po::value<std::string>(&modelDir)->required(),
51 "Path to directory containing the Cifar10 model file")
David Beckf0b48452018-10-19 15:20:56 +010052 ("compute,c", po::value<std::vector<armnn::BackendId>>(&computeDevice)->default_value(defaultBackends),
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010053 backendsMessage.c_str())
telsoa014fcda012018-03-09 14:13:49 +000054 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
55 "Path to directory containing the Cifar10 test data");
56 }
57 catch (const std::exception& e)
58 {
59 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
60 // and that desc.add_options() can throw boost::io::too_few_args.
61 // They really won't in any of these cases.
62 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
63 std::cerr << "Fatal internal error: " << e.what() << std::endl;
64 return 1;
65 }
66
67 po::variables_map vm;
68
69 try
70 {
71 po::store(po::parse_command_line(argc, argv, desc), vm);
72
73 if (vm.count("help"))
74 {
75 std::cout << desc << std::endl;
76 return 1;
77 }
78
79 po::notify(vm);
80 }
81 catch (po::error& e)
82 {
83 std::cerr << e.what() << std::endl << std::endl;
84 std::cerr << desc << std::endl;
85 return 1;
86 }
87
88 if (!ValidateDirectory(modelDir))
89 {
90 return 1;
91 }
92 string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel";
93
Narumol Prangnawaratd9d61f52020-01-02 17:46:53 +000094 // Create runtime
95 // This will also load dynamic backend in case that the dynamic backend path is specified
96 armnn::IRuntime::CreationOptions options;
97 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
98
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010099 // Check if the requested backend are all valid
100 std::string invalidBackends;
101 if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
102 {
Derek Lamberti08446972019-11-26 16:38:31 +0000103 ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
104 << invalidBackends;
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100105 return EXIT_FAILURE;
106 }
107
telsoa01c577f2c2018-08-31 09:22:23 +0100108 // Loads networks.
telsoa014fcda012018-03-09 14:13:49 +0000109 armnn::Status status;
110 struct Net
111 {
112 Net(armnn::NetworkId netId,
113 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
114 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
115 : m_Network(netId)
116 , m_InputBindingInfo(in)
117 , m_OutputBindingInfo(out)
118 {}
119
120 armnn::NetworkId m_Network;
121 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
122 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
123 };
124 std::vector<Net> networks;
125
126 armnnCaffeParser::ICaffeParserPtr parser(armnnCaffeParser::ICaffeParser::Create());
127
128 const int networksCount = 4;
129 for (int i = 0; i < networksCount; ++i)
130 {
telsoa01c577f2c2018-08-31 09:22:23 +0100131 // Creates a network from a file on the disk.
telsoa014fcda012018-03-09 14:13:49 +0000132 armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
133
telsoa01c577f2c2018-08-31 09:22:23 +0100134 // Optimizes the network.
telsoa014fcda012018-03-09 14:13:49 +0000135 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
136 try
137 {
telsoa01c577f2c2018-08-31 09:22:23 +0100138 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
telsoa014fcda012018-03-09 14:13:49 +0000139 }
140 catch (armnn::Exception& e)
141 {
142 std::stringstream message;
143 message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
Derek Lamberti08446972019-11-26 16:38:31 +0000144 ARMNN_LOG(fatal) << message.str();
telsoa014fcda012018-03-09 14:13:49 +0000145 return 1;
146 }
147
telsoa01c577f2c2018-08-31 09:22:23 +0100148 // Loads the network into the runtime.
telsoa014fcda012018-03-09 14:13:49 +0000149 armnn::NetworkId networkId;
150 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
151 if (status == armnn::Status::Failure)
152 {
Derek Lamberti08446972019-11-26 16:38:31 +0000153 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
telsoa014fcda012018-03-09 14:13:49 +0000154 return 1;
155 }
156
157 networks.emplace_back(networkId,
158 parser->GetNetworkInputBindingInfo("data"),
159 parser->GetNetworkOutputBindingInfo("prob"));
160 }
161
telsoa01c577f2c2018-08-31 09:22:23 +0100162 // Loads a test case and tests inference.
telsoa014fcda012018-03-09 14:13:49 +0000163 if (!ValidateDirectory(dataDir))
164 {
165 return 1;
166 }
167 Cifar10Database cifar10(dataDir);
168
169 for (unsigned int i = 0; i < 3; ++i)
170 {
telsoa01c577f2c2018-08-31 09:22:23 +0100171 // Loads test case data (including image data).
telsoa014fcda012018-03-09 14:13:49 +0000172 std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
173
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000174 // Tests inference.
Ferran Balaguerc602f292019-02-08 17:09:55 +0000175 std::vector<TContainer> outputs;
176 outputs.reserve(networksCount);
177
178 for (unsigned int j = 0; j < networksCount; ++j)
179 {
180 outputs.push_back(std::vector<float>(10));
181 }
182
telsoa014fcda012018-03-09 14:13:49 +0000183 for (unsigned int k = 0; k < networksCount; ++k)
184 {
Jim Flynnb4d7eae2019-05-01 14:44:27 +0100185 std::vector<armnn::BindingPointInfo> inputBindings = { networks[k].m_InputBindingInfo };
186 std::vector<armnn::BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo };
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000187
Ferran Balaguerc602f292019-02-08 17:09:55 +0000188 std::vector<TContainer> inputDataContainers = { testCaseData->m_InputImage };
189 std::vector<TContainer> outputDataContainers = { outputs[k] };
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000190
telsoa014fcda012018-03-09 14:13:49 +0000191 status = runtime->EnqueueWorkload(networks[k].m_Network,
Jim Flynn2fd61002019-05-03 12:54:26 +0100192 armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
193 armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
telsoa014fcda012018-03-09 14:13:49 +0000194 if (status == armnn::Status::Failure)
195 {
Derek Lamberti08446972019-11-26 16:38:31 +0000196 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload";
telsoa014fcda012018-03-09 14:13:49 +0000197 return 1;
198 }
199 }
200
telsoa01c577f2c2018-08-31 09:22:23 +0100201 // Compares outputs.
Ferran Balaguerc602f292019-02-08 17:09:55 +0000202 std::vector<float> output0 = boost::get<std::vector<float>>(outputs[0]);
203
telsoa014fcda012018-03-09 14:13:49 +0000204 for (unsigned int k = 1; k < networksCount; ++k)
205 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000206 std::vector<float> outputK = boost::get<std::vector<float>>(outputs[k]);
207
208 if (!std::equal(output0.begin(), output0.end(), outputK.begin(), outputK.end()))
telsoa014fcda012018-03-09 14:13:49 +0000209 {
Derek Lamberti08446972019-11-26 16:38:31 +0000210 ARMNN_LOG(error) << "Multiple networks inference failed!";
telsoa014fcda012018-03-09 14:13:49 +0000211 return 1;
212 }
213 }
214 }
215
Derek Lamberti08446972019-11-26 16:38:31 +0000216 ARMNN_LOG(info) << "Multiple networks inference ran successfully!";
telsoa014fcda012018-03-09 14:13:49 +0000217 return 0;
218 }
219 catch (armnn::Exception const& e)
220 {
surmeh013537c2c2018-05-18 16:31:43 +0100221 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
222 // exception of type std::length_error.
223 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
224 std::cerr << "Armnn Error: " << e.what() << std::endl;
telsoa014fcda012018-03-09 14:13:49 +0000225 return 1;
226 }
surmeh013537c2c2018-05-18 16:31:43 +0100227 catch (const std::exception& e)
228 {
David Beckf0b48452018-10-19 15:20:56 +0100229 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
surmeh013537c2c2018-05-18 16:31:43 +0100230 std::cerr << "WARNING: MultipleNetworksCifar10: An error has occurred when running the "
231 "multiple networks inference tests: " << e.what() << std::endl;
232 return 1;
233 }
234}