blob: f9fdf8b3ea242a8c88984def806ce59b43a34e24 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include <iostream>
6#include <chrono>
7#include <vector>
8#include <array>
9#include <boost/log/trivial.hpp>
10
11#include "armnn/ArmNN.hpp"
12#include "armnn/Utils.hpp"
13#include "armnn/INetwork.hpp"
14#include "armnnCaffeParser/ICaffeParser.hpp"
15#include "../Cifar10Database.hpp"
16#include "../InferenceTest.hpp"
17#include "../InferenceModel.hpp"
18
19using namespace std;
20using namespace std::chrono;
21using namespace armnn::test;
22
23int main(int argc, char* argv[])
24{
25#ifdef NDEBUG
26 armnn::LogSeverity level = armnn::LogSeverity::Info;
27#else
28 armnn::LogSeverity level = armnn::LogSeverity::Debug;
29#endif
30
31 try
32 {
telsoa01c577f2c2018-08-31 09:22:23 +010033 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +000034 armnn::ConfigureLogging(true, true, level);
35 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
36
37 namespace po = boost::program_options;
38
David Beckf0b48452018-10-19 15:20:56 +010039 std::vector<armnn::BackendId> computeDevice;
40 std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
telsoa014fcda012018-03-09 14:13:49 +000041 std::string modelDir;
42 std::string dataDir;
43
44 po::options_description desc("Options");
45 try
46 {
telsoa01c577f2c2018-08-31 09:22:23 +010047 // Adds generic options needed for all inference tests.
telsoa014fcda012018-03-09 14:13:49 +000048 desc.add_options()
49 ("help", "Display help messages")
50 ("model-dir,m", po::value<std::string>(&modelDir)->required(),
51 "Path to directory containing the Cifar10 model file")
David Beckf0b48452018-10-19 15:20:56 +010052 ("compute,c", po::value<std::vector<armnn::BackendId>>(&computeDevice)->default_value(defaultBackends),
telsoa014fcda012018-03-09 14:13:49 +000053 "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
54 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
55 "Path to directory containing the Cifar10 test data");
56 }
57 catch (const std::exception& e)
58 {
59 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
60 // and that desc.add_options() can throw boost::io::too_few_args.
61 // They really won't in any of these cases.
62 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
63 std::cerr << "Fatal internal error: " << e.what() << std::endl;
64 return 1;
65 }
66
67 po::variables_map vm;
68
69 try
70 {
71 po::store(po::parse_command_line(argc, argv, desc), vm);
72
73 if (vm.count("help"))
74 {
75 std::cout << desc << std::endl;
76 return 1;
77 }
78
79 po::notify(vm);
80 }
81 catch (po::error& e)
82 {
83 std::cerr << e.what() << std::endl << std::endl;
84 std::cerr << desc << std::endl;
85 return 1;
86 }
87
88 if (!ValidateDirectory(modelDir))
89 {
90 return 1;
91 }
92 string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel";
93
94 // Create runtime
telsoa01c577f2c2018-08-31 09:22:23 +010095 armnn::IRuntime::CreationOptions options;
96 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
telsoa014fcda012018-03-09 14:13:49 +000097
telsoa01c577f2c2018-08-31 09:22:23 +010098 // Loads networks.
telsoa014fcda012018-03-09 14:13:49 +000099 armnn::Status status;
100 struct Net
101 {
102 Net(armnn::NetworkId netId,
103 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
104 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
105 : m_Network(netId)
106 , m_InputBindingInfo(in)
107 , m_OutputBindingInfo(out)
108 {}
109
110 armnn::NetworkId m_Network;
111 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
112 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
113 };
114 std::vector<Net> networks;
115
116 armnnCaffeParser::ICaffeParserPtr parser(armnnCaffeParser::ICaffeParser::Create());
117
118 const int networksCount = 4;
119 for (int i = 0; i < networksCount; ++i)
120 {
telsoa01c577f2c2018-08-31 09:22:23 +0100121 // Creates a network from a file on the disk.
telsoa014fcda012018-03-09 14:13:49 +0000122 armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
123
telsoa01c577f2c2018-08-31 09:22:23 +0100124 // Optimizes the network.
telsoa014fcda012018-03-09 14:13:49 +0000125 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
126 try
127 {
telsoa01c577f2c2018-08-31 09:22:23 +0100128 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
telsoa014fcda012018-03-09 14:13:49 +0000129 }
130 catch (armnn::Exception& e)
131 {
132 std::stringstream message;
133 message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
134 BOOST_LOG_TRIVIAL(fatal) << message.str();
135 return 1;
136 }
137
telsoa01c577f2c2018-08-31 09:22:23 +0100138 // Loads the network into the runtime.
telsoa014fcda012018-03-09 14:13:49 +0000139 armnn::NetworkId networkId;
140 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
141 if (status == armnn::Status::Failure)
142 {
143 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to load network";
144 return 1;
145 }
146
147 networks.emplace_back(networkId,
148 parser->GetNetworkInputBindingInfo("data"),
149 parser->GetNetworkOutputBindingInfo("prob"));
150 }
151
telsoa01c577f2c2018-08-31 09:22:23 +0100152 // Loads a test case and tests inference.
telsoa014fcda012018-03-09 14:13:49 +0000153 if (!ValidateDirectory(dataDir))
154 {
155 return 1;
156 }
157 Cifar10Database cifar10(dataDir);
158
159 for (unsigned int i = 0; i < 3; ++i)
160 {
telsoa01c577f2c2018-08-31 09:22:23 +0100161 // Loads test case data (including image data).
telsoa014fcda012018-03-09 14:13:49 +0000162 std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
163
telsoa01c577f2c2018-08-31 09:22:23 +0100164 // Tests inference.
telsoa014fcda012018-03-09 14:13:49 +0000165 std::vector<std::array<float, 10>> outputs(networksCount);
166
167 for (unsigned int k = 0; k < networksCount; ++k)
168 {
169 status = runtime->EnqueueWorkload(networks[k].m_Network,
170 MakeInputTensors(networks[k].m_InputBindingInfo, testCaseData->m_InputImage),
171 MakeOutputTensors(networks[k].m_OutputBindingInfo, outputs[k]));
172 if (status == armnn::Status::Failure)
173 {
174 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload";
175 return 1;
176 }
177 }
178
telsoa01c577f2c2018-08-31 09:22:23 +0100179 // Compares outputs.
telsoa014fcda012018-03-09 14:13:49 +0000180 for (unsigned int k = 1; k < networksCount; ++k)
181 {
182 if (!std::equal(outputs[0].begin(), outputs[0].end(), outputs[k].begin(), outputs[k].end()))
183 {
184 BOOST_LOG_TRIVIAL(error) << "Multiple networks inference failed!";
185 return 1;
186 }
187 }
188 }
189
190 BOOST_LOG_TRIVIAL(info) << "Multiple networks inference ran successfully!";
191 return 0;
192 }
193 catch (armnn::Exception const& e)
194 {
surmeh013537c2c2018-05-18 16:31:43 +0100195 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
196 // exception of type std::length_error.
197 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
198 std::cerr << "Armnn Error: " << e.what() << std::endl;
telsoa014fcda012018-03-09 14:13:49 +0000199 return 1;
200 }
surmeh013537c2c2018-05-18 16:31:43 +0100201 catch (const std::exception& e)
202 {
David Beckf0b48452018-10-19 15:20:56 +0100203 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
surmeh013537c2c2018-05-18 16:31:43 +0100204 std::cerr << "WARNING: MultipleNetworksCifar10: An error has occurred when running the "
205 "multiple networks inference tests: " << e.what() << std::endl;
206 return 1;
207 }
208}