blob: f5f00378ca8aeaea913007115ebc5239f80494cb [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5#pragma once
telsoa014fcda012018-03-09 14:13:49 +00006#include "armnn/ArmNN.hpp"
surmeh013537c2c2018-05-18 16:31:43 +01007#include "HeapProfiling.hpp"
telsoa014fcda012018-03-09 14:13:49 +00008
surmeh013537c2c2018-05-18 16:31:43 +01009#include <boost/exception/exception.hpp>
10#include <boost/exception/diagnostic_information.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011#include <boost/log/trivial.hpp>
12#include <boost/format.hpp>
13#include <boost/program_options.hpp>
surmeh013537c2c2018-05-18 16:31:43 +010014#include <boost/filesystem.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
16#include <map>
17#include <string>
surmeh013537c2c2018-05-18 16:31:43 +010018#include <fstream>
telsoa014fcda012018-03-09 14:13:49 +000019
20template<typename TContainer>
21inline armnn::InputTensors MakeInputTensors(const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& input,
22 const TContainer& inputTensorData)
23{
24 if (inputTensorData.size() != input.second.GetNumElements())
25 {
surmeh013537c2c2018-05-18 16:31:43 +010026 try
27 {
28 throw armnn::Exception(boost::str(boost::format("Input tensor has incorrect size. Expected %1% elements "
29 "but got %2%.") % input.second.GetNumElements() % inputTensorData.size()));
30 } catch (const boost::exception& e)
31 {
32 // Coverity fix: it should not be possible to get here but boost::str and boost::format can both
33 // throw uncaught exceptions - convert them to armnn exceptions and rethrow
34 throw armnn::Exception(diagnostic_information(e));
35 }
telsoa014fcda012018-03-09 14:13:49 +000036 }
37 return { { input.first, armnn::ConstTensor(input.second, inputTensorData.data()) } };
38}
39
40template<typename TContainer>
41inline armnn::OutputTensors MakeOutputTensors(const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& output,
42 TContainer& outputTensorData)
43{
44 if (outputTensorData.size() != output.second.GetNumElements())
45 {
46 throw armnn::Exception("Output tensor has incorrect size");
47 }
48 return { { output.first, armnn::Tensor(output.second, outputTensorData.data()) } };
49}
50
51template <typename IParser, typename TDataType>
52class InferenceModel
53{
54public:
55 using DataType = TDataType;
56
57 struct CommandLineOptions
58 {
59 std::string m_ModelDir;
60 armnn::Compute m_ComputeDevice;
surmeh013537c2c2018-05-18 16:31:43 +010061 bool m_VisualizePostOptimizationModel;
telsoa014fcda012018-03-09 14:13:49 +000062 };
63
64 static void AddCommandLineOptions(boost::program_options::options_description& desc, CommandLineOptions& options)
65 {
66 namespace po = boost::program_options;
67
68 desc.add_options()
69 ("model-dir,m", po::value<std::string>(&options.m_ModelDir)->required(),
70 "Path to directory containing model files (.caffemodel/.prototxt)")
71 ("compute,c", po::value<armnn::Compute>(&options.m_ComputeDevice)->default_value(armnn::Compute::CpuAcc),
surmeh013537c2c2018-05-18 16:31:43 +010072 "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
73 ("visualize-optimized-model,v",
74 po::value<bool>(&options.m_VisualizePostOptimizationModel)->default_value(false),
75 "Produce a dot file useful for visualizing the graph post optimization."
76 "The file will have the same name as the model with the .dot extention.");
telsoa014fcda012018-03-09 14:13:49 +000077 }
78
79 struct Params
80 {
81 std::string m_ModelPath;
82 std::string m_InputBinding;
83 std::string m_OutputBinding;
84 const armnn::TensorShape* m_InputTensorShape;
85 armnn::Compute m_ComputeDevice;
86 bool m_IsModelBinary;
surmeh013537c2c2018-05-18 16:31:43 +010087 bool m_VisualizePostOptimizationModel;
telsoa014fcda012018-03-09 14:13:49 +000088
89 Params()
90 : m_InputTensorShape(nullptr)
91 , m_ComputeDevice(armnn::Compute::CpuRef)
92 , m_IsModelBinary(true)
surmeh013537c2c2018-05-18 16:31:43 +010093 , m_VisualizePostOptimizationModel(false)
telsoa014fcda012018-03-09 14:13:49 +000094 {
95 }
96 };
97
98
99 InferenceModel(const Params& params)
100 : m_Runtime(armnn::IRuntime::Create(params.m_ComputeDevice))
101 {
102 const std::string& modelPath = params.m_ModelPath;
103
104 // Create a network from a file on disk
105 auto parser(IParser::Create());
106
107 std::map<std::string, armnn::TensorShape> inputShapes;
108 if (params.m_InputTensorShape)
109 {
110 inputShapes[params.m_InputBinding] = *params.m_InputTensorShape;
111 }
112 std::vector<std::string> requestedOutputs{ params.m_OutputBinding };
113
surmeh013537c2c2018-05-18 16:31:43 +0100114 armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
115 {
116 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
117 // Handle text and binary input differently by calling the corresponding parser function
118 network = (params.m_IsModelBinary ?
119 parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes, requestedOutputs) :
120 parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes, requestedOutputs));
121 }
telsoa014fcda012018-03-09 14:13:49 +0000122
123 m_InputBindingInfo = parser->GetNetworkInputBindingInfo(params.m_InputBinding);
124 m_OutputBindingInfo = parser->GetNetworkOutputBindingInfo(params.m_OutputBinding);
125
surmeh013537c2c2018-05-18 16:31:43 +0100126 armnn::IOptimizedNetworkPtr optNet{nullptr, [](armnn::IOptimizedNetwork *){}};
127 {
128 ARMNN_SCOPED_HEAP_PROFILING("Optimizing");
129 optNet = armnn::Optimize(*network, m_Runtime->GetDeviceSpec());
130 }
telsoa014fcda012018-03-09 14:13:49 +0000131
surmeh013537c2c2018-05-18 16:31:43 +0100132 if (params.m_VisualizePostOptimizationModel)
133 {
134 boost::filesystem::path filename = params.m_ModelPath;
135 filename.replace_extension("dot");
136 std::fstream file(filename.c_str(),file.out);
137 optNet->SerializeToDot(file);
138 }
139
140 armnn::Status ret;
141 {
142 ARMNN_SCOPED_HEAP_PROFILING("LoadNetwork");
143 ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet));
144 }
145
telsoa014fcda012018-03-09 14:13:49 +0000146 if (ret == armnn::Status::Failure)
147 {
148 throw armnn::Exception("IRuntime::LoadNetwork failed");
149 }
150 }
151
152 unsigned int GetOutputSize() const
153 {
154 return m_OutputBindingInfo.second.GetNumElements();
155 }
156
157 void Run(const std::vector<TDataType>& input, std::vector<TDataType>& output)
158 {
159 BOOST_ASSERT(output.size() == GetOutputSize());
160 armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
161 MakeInputTensors(input),
162 MakeOutputTensors(output));
163 if (ret == armnn::Status::Failure)
164 {
165 throw armnn::Exception("IRuntime::EnqueueWorkload failed");
166 }
167 }
168
169private:
170 template<typename TContainer>
171 armnn::InputTensors MakeInputTensors(const TContainer& inputTensorData)
172 {
173 return ::MakeInputTensors(m_InputBindingInfo, inputTensorData);
174 }
175
176 template<typename TContainer>
177 armnn::OutputTensors MakeOutputTensors(TContainer& outputTensorData)
178 {
179 return ::MakeOutputTensors(m_OutputBindingInfo, outputTensorData);
180 }
181
182 armnn::NetworkId m_NetworkIdentifier;
183 armnn::IRuntimePtr m_Runtime;
184
185 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
186 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
187};