blob: 823cbbc50abdcdcf59b644a124ec71c9c3f50466 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Jim Flynn357add22023-04-10 23:26:40 +01002// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
telsoa014fcda012018-03-09 14:13:49 +00005
6#include "armnn/ArmNN.hpp"
7#include "armnn/Utils.hpp"
8#include "armnn/INetwork.hpp"
Nikhil Raj5d955cf2021-04-19 16:59:48 +01009#include "armnnTfLiteParser/TfLiteParser.hpp"
telsoa014fcda012018-03-09 14:13:49 +000010#include "../Cifar10Database.hpp"
11#include "../InferenceTest.hpp"
12#include "../InferenceModel.hpp"
13
James Ward7a1966c2020-10-05 17:11:23 +010014#include <cxxopts/cxxopts.hpp>
15
16#include <iostream>
17#include <chrono>
18#include <vector>
19#include <array>
20
21
telsoa014fcda012018-03-09 14:13:49 +000022using namespace std;
23using namespace std::chrono;
24using namespace armnn::test;
25
26int main(int argc, char* argv[])
27{
28#ifdef NDEBUG
29 armnn::LogSeverity level = armnn::LogSeverity::Info;
30#else
31 armnn::LogSeverity level = armnn::LogSeverity::Debug;
32#endif
33
34 try
35 {
telsoa01c577f2c2018-08-31 09:22:23 +010036 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +000037 armnn::ConfigureLogging(true, true, level);
telsoa014fcda012018-03-09 14:13:49 +000038
David Beckf0b48452018-10-19 15:20:56 +010039 std::vector<armnn::BackendId> computeDevice;
telsoa014fcda012018-03-09 14:13:49 +000040 std::string modelDir;
41 std::string dataDir;
42
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010043 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
44 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
45
James Ward7a1966c2020-10-05 17:11:23 +010046 cxxopts::Options in_options("MultipleNetworksCifar10",
47 "Run multiple networks inference tests using Cifar-10 data.");
48
telsoa014fcda012018-03-09 14:13:49 +000049 try
50 {
telsoa01c577f2c2018-08-31 09:22:23 +010051 // Adds generic options needed for all inference tests.
James Ward7a1966c2020-10-05 17:11:23 +010052 in_options.add_options()
53 ("h,help", "Display help messages")
54 ("m,model-dir", "Path to directory containing the Cifar10 model file",
55 cxxopts::value<std::string>(modelDir))
56 ("c,compute", backendsMessage.c_str(),
57 cxxopts::value<std::vector<armnn::BackendId>>(computeDevice)->default_value("CpuAcc,CpuRef"))
58 ("d,data-dir", "Path to directory containing the Cifar10 test data",
59 cxxopts::value<std::string>(dataDir));
telsoa014fcda012018-03-09 14:13:49 +000060
James Ward7a1966c2020-10-05 17:11:23 +010061 auto result = in_options.parse(argc, argv);
telsoa014fcda012018-03-09 14:13:49 +000062
James Ward7a1966c2020-10-05 17:11:23 +010063 if(result.count("help") > 0)
telsoa014fcda012018-03-09 14:13:49 +000064 {
James Ward7a1966c2020-10-05 17:11:23 +010065 std::cout << in_options.help() << std::endl;
66 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +000067 }
68
James Ward7a1966c2020-10-05 17:11:23 +010069 //ensure mandatory parameters given
70 std::string mandatorySingleParameters[] = {"model-dir", "data-dir"};
71 for (auto param : mandatorySingleParameters)
72 {
73 if(result.count(param) > 0)
74 {
75 std::string dir = result[param].as<std::string>();
76
77 if(!ValidateDirectory(dir)) {
78 return EXIT_FAILURE;
79 }
80 } else {
81 std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
82 return EXIT_FAILURE;
83 }
84 }
telsoa014fcda012018-03-09 14:13:49 +000085 }
Jim Flynn357add22023-04-10 23:26:40 +010086 catch (const cxxopts::exceptions::exception& e)
telsoa014fcda012018-03-09 14:13:49 +000087 {
James Ward7a1966c2020-10-05 17:11:23 +010088 std::cerr << e.what() << std::endl << in_options.help() << std::endl;
89 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +000090 }
91
Nikhil Raj6dd178f2021-04-02 22:04:39 +010092 fs::path modelPath = fs::path(modelDir + "/cifar10_tf.prototxt");
telsoa014fcda012018-03-09 14:13:49 +000093
Narumol Prangnawaratd9d61f52020-01-02 17:46:53 +000094 // Create runtime
95 // This will also load dynamic backend in case that the dynamic backend path is specified
96 armnn::IRuntime::CreationOptions options;
97 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
98
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010099 // Check if the requested backend are all valid
100 std::string invalidBackends;
101 if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
102 {
Derek Lamberti08446972019-11-26 16:38:31 +0000103 ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
104 << invalidBackends;
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100105 return EXIT_FAILURE;
106 }
107
telsoa01c577f2c2018-08-31 09:22:23 +0100108 // Loads networks.
telsoa014fcda012018-03-09 14:13:49 +0000109 armnn::Status status;
110 struct Net
111 {
112 Net(armnn::NetworkId netId,
113 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
114 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
115 : m_Network(netId)
116 , m_InputBindingInfo(in)
117 , m_OutputBindingInfo(out)
118 {}
119
120 armnn::NetworkId m_Network;
121 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
122 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
123 };
124 std::vector<Net> networks;
125
Nikhil Raj5d955cf2021-04-19 16:59:48 +0100126 armnnTfLiteParser::ITfLiteParserPtr parser(armnnTfLiteParser::ITfLiteParserPtr::Create());
telsoa014fcda012018-03-09 14:13:49 +0000127
128 const int networksCount = 4;
129 for (int i = 0; i < networksCount; ++i)
130 {
telsoa01c577f2c2018-08-31 09:22:23 +0100131 // Creates a network from a file on the disk.
telsoa014fcda012018-03-09 14:13:49 +0000132 armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
133
telsoa01c577f2c2018-08-31 09:22:23 +0100134 // Optimizes the network.
telsoa014fcda012018-03-09 14:13:49 +0000135 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
136 try
137 {
telsoa01c577f2c2018-08-31 09:22:23 +0100138 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
telsoa014fcda012018-03-09 14:13:49 +0000139 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000140 catch (const armnn::Exception& e)
telsoa014fcda012018-03-09 14:13:49 +0000141 {
142 std::stringstream message;
143 message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
Derek Lamberti08446972019-11-26 16:38:31 +0000144 ARMNN_LOG(fatal) << message.str();
James Ward7a1966c2020-10-05 17:11:23 +0100145 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000146 }
147
telsoa01c577f2c2018-08-31 09:22:23 +0100148 // Loads the network into the runtime.
telsoa014fcda012018-03-09 14:13:49 +0000149 armnn::NetworkId networkId;
150 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
151 if (status == armnn::Status::Failure)
152 {
Derek Lamberti08446972019-11-26 16:38:31 +0000153 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
James Ward7a1966c2020-10-05 17:11:23 +0100154 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000155 }
156
157 networks.emplace_back(networkId,
158 parser->GetNetworkInputBindingInfo("data"),
159 parser->GetNetworkOutputBindingInfo("prob"));
160 }
161
telsoa01c577f2c2018-08-31 09:22:23 +0100162 // Loads a test case and tests inference.
telsoa014fcda012018-03-09 14:13:49 +0000163 if (!ValidateDirectory(dataDir))
164 {
James Ward7a1966c2020-10-05 17:11:23 +0100165 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000166 }
167 Cifar10Database cifar10(dataDir);
168
169 for (unsigned int i = 0; i < 3; ++i)
170 {
telsoa01c577f2c2018-08-31 09:22:23 +0100171 // Loads test case data (including image data).
telsoa014fcda012018-03-09 14:13:49 +0000172 std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
173
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000174 // Tests inference.
Ferran Balaguerc602f292019-02-08 17:09:55 +0000175 std::vector<TContainer> outputs;
176 outputs.reserve(networksCount);
177
178 for (unsigned int j = 0; j < networksCount; ++j)
179 {
180 outputs.push_back(std::vector<float>(10));
181 }
182
telsoa014fcda012018-03-09 14:13:49 +0000183 for (unsigned int k = 0; k < networksCount; ++k)
184 {
Jim Flynnb4d7eae2019-05-01 14:44:27 +0100185 std::vector<armnn::BindingPointInfo> inputBindings = { networks[k].m_InputBindingInfo };
186 std::vector<armnn::BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo };
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000187
Ferran Balaguerc602f292019-02-08 17:09:55 +0000188 std::vector<TContainer> inputDataContainers = { testCaseData->m_InputImage };
189 std::vector<TContainer> outputDataContainers = { outputs[k] };
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000190
telsoa014fcda012018-03-09 14:13:49 +0000191 status = runtime->EnqueueWorkload(networks[k].m_Network,
Jim Flynn2fd61002019-05-03 12:54:26 +0100192 armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
193 armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
telsoa014fcda012018-03-09 14:13:49 +0000194 if (status == armnn::Status::Failure)
195 {
Derek Lamberti08446972019-11-26 16:38:31 +0000196 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload";
James Ward7a1966c2020-10-05 17:11:23 +0100197 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000198 }
199 }
200
telsoa01c577f2c2018-08-31 09:22:23 +0100201 // Compares outputs.
James Ward6d9f5c52020-09-28 11:56:35 +0100202 std::vector<float> output0 = mapbox::util::get<std::vector<float>>(outputs[0]);
Ferran Balaguerc602f292019-02-08 17:09:55 +0000203
telsoa014fcda012018-03-09 14:13:49 +0000204 for (unsigned int k = 1; k < networksCount; ++k)
205 {
James Ward6d9f5c52020-09-28 11:56:35 +0100206 std::vector<float> outputK = mapbox::util::get<std::vector<float>>(outputs[k]);
Ferran Balaguerc602f292019-02-08 17:09:55 +0000207
208 if (!std::equal(output0.begin(), output0.end(), outputK.begin(), outputK.end()))
telsoa014fcda012018-03-09 14:13:49 +0000209 {
Derek Lamberti08446972019-11-26 16:38:31 +0000210 ARMNN_LOG(error) << "Multiple networks inference failed!";
James Ward7a1966c2020-10-05 17:11:23 +0100211 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000212 }
213 }
214 }
215
Derek Lamberti08446972019-11-26 16:38:31 +0000216 ARMNN_LOG(info) << "Multiple networks inference ran successfully!";
James Ward7a1966c2020-10-05 17:11:23 +0100217 return EXIT_SUCCESS;
telsoa014fcda012018-03-09 14:13:49 +0000218 }
Pavel Macenauer855a47b2020-05-26 10:54:22 +0000219 catch (const armnn::Exception& e)
telsoa014fcda012018-03-09 14:13:49 +0000220 {
surmeh013537c2c2018-05-18 16:31:43 +0100221 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
222 // exception of type std::length_error.
223 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
224 std::cerr << "Armnn Error: " << e.what() << std::endl;
James Ward7a1966c2020-10-05 17:11:23 +0100225 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000226 }
surmeh013537c2c2018-05-18 16:31:43 +0100227 catch (const std::exception& e)
228 {
David Beckf0b48452018-10-19 15:20:56 +0100229 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
surmeh013537c2c2018-05-18 16:31:43 +0100230 std::cerr << "WARNING: MultipleNetworksCifar10: An error has occurred when running the "
231 "multiple networks inference tests: " << e.what() << std::endl;
James Ward7a1966c2020-10-05 17:11:23 +0100232 return EXIT_FAILURE;
surmeh013537c2c2018-05-18 16:31:43 +0100233 }
234}