blob: 3c75ed7f24d346f1331090cf4e8b7679aa19d8bb [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#include <algorithm>
#include <array>
#include <chrono>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <boost/assert.hpp>
#include <boost/log/trivial.hpp>
#include <boost/program_options.hpp>

#include "armnn/ArmNN.hpp"
#include "armnn/Utils.hpp"
#include "armnn/INetwork.hpp"
#include "armnnCaffeParser/ICaffeParser.hpp"
#include "../Cifar10Database.hpp"
#include "../InferenceTest.hpp"
#include "../InferenceModel.hpp"

using namespace std;
using namespace std::chrono;
using namespace armnn::test;
22
23int main(int argc, char* argv[])
24{
25#ifdef NDEBUG
26 armnn::LogSeverity level = armnn::LogSeverity::Info;
27#else
28 armnn::LogSeverity level = armnn::LogSeverity::Debug;
29#endif
30
31 try
32 {
33 // Configure logging for both the ARMNN library and this test program
34 armnn::ConfigureLogging(true, true, level);
35 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
36
37 namespace po = boost::program_options;
38
39 armnn::Compute computeDevice;
40 std::string modelDir;
41 std::string dataDir;
42
43 po::options_description desc("Options");
44 try
45 {
46 // Add generic options needed for all inference tests
47 desc.add_options()
48 ("help", "Display help messages")
49 ("model-dir,m", po::value<std::string>(&modelDir)->required(),
50 "Path to directory containing the Cifar10 model file")
51 ("compute,c", po::value<armnn::Compute>(&computeDevice)->default_value(armnn::Compute::CpuAcc),
52 "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
53 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
54 "Path to directory containing the Cifar10 test data");
55 }
56 catch (const std::exception& e)
57 {
58 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
59 // and that desc.add_options() can throw boost::io::too_few_args.
60 // They really won't in any of these cases.
61 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
62 std::cerr << "Fatal internal error: " << e.what() << std::endl;
63 return 1;
64 }
65
66 po::variables_map vm;
67
68 try
69 {
70 po::store(po::parse_command_line(argc, argv, desc), vm);
71
72 if (vm.count("help"))
73 {
74 std::cout << desc << std::endl;
75 return 1;
76 }
77
78 po::notify(vm);
79 }
80 catch (po::error& e)
81 {
82 std::cerr << e.what() << std::endl << std::endl;
83 std::cerr << desc << std::endl;
84 return 1;
85 }
86
87 if (!ValidateDirectory(modelDir))
88 {
89 return 1;
90 }
91 string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel";
92
93 // Create runtime
94 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(computeDevice));
95
96 // Load networks
97 armnn::Status status;
98 struct Net
99 {
100 Net(armnn::NetworkId netId,
101 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
102 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
103 : m_Network(netId)
104 , m_InputBindingInfo(in)
105 , m_OutputBindingInfo(out)
106 {}
107
108 armnn::NetworkId m_Network;
109 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
110 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
111 };
112 std::vector<Net> networks;
113
114 armnnCaffeParser::ICaffeParserPtr parser(armnnCaffeParser::ICaffeParser::Create());
115
116 const int networksCount = 4;
117 for (int i = 0; i < networksCount; ++i)
118 {
119 // Create a network from a file on disk
120 armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
121
122 // optimize the network
123 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
124 try
125 {
126 optimizedNet = armnn::Optimize(*network, runtime->GetDeviceSpec());
127 }
128 catch (armnn::Exception& e)
129 {
130 std::stringstream message;
131 message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
132 BOOST_LOG_TRIVIAL(fatal) << message.str();
133 return 1;
134 }
135
136 // Load the network into the runtime
137 armnn::NetworkId networkId;
138 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
139 if (status == armnn::Status::Failure)
140 {
141 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to load network";
142 return 1;
143 }
144
145 networks.emplace_back(networkId,
146 parser->GetNetworkInputBindingInfo("data"),
147 parser->GetNetworkOutputBindingInfo("prob"));
148 }
149
150 // Load a test case and test inference
151 if (!ValidateDirectory(dataDir))
152 {
153 return 1;
154 }
155 Cifar10Database cifar10(dataDir);
156
157 for (unsigned int i = 0; i < 3; ++i)
158 {
159 // Load test case data (including image data)
160 std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
161
162 // Test inference
163 std::vector<std::array<float, 10>> outputs(networksCount);
164
165 for (unsigned int k = 0; k < networksCount; ++k)
166 {
167 status = runtime->EnqueueWorkload(networks[k].m_Network,
168 MakeInputTensors(networks[k].m_InputBindingInfo, testCaseData->m_InputImage),
169 MakeOutputTensors(networks[k].m_OutputBindingInfo, outputs[k]));
170 if (status == armnn::Status::Failure)
171 {
172 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload";
173 return 1;
174 }
175 }
176
177 // Compare outputs
178 for (unsigned int k = 1; k < networksCount; ++k)
179 {
180 if (!std::equal(outputs[0].begin(), outputs[0].end(), outputs[k].begin(), outputs[k].end()))
181 {
182 BOOST_LOG_TRIVIAL(error) << "Multiple networks inference failed!";
183 return 1;
184 }
185 }
186 }
187
188 BOOST_LOG_TRIVIAL(info) << "Multiple networks inference ran successfully!";
189 return 0;
190 }
191 catch (armnn::Exception const& e)
192 {
193 BOOST_LOG_TRIVIAL(fatal) <<"Armnn Error: "<< e.what();
194 return 1;
195 }
196}