//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once

#include "armnn/ArmNN.hpp"

#include <boost/log/trivial.hpp>
#include <boost/format.hpp>
#include <boost/program_options.hpp>
#include <map>
#include <string>
#include <utility>
#include <vector>

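// Wraps caller-owned input data in an armnn::ConstTensor keyed by the layer binding id.
// Throws armnn::Exception if the container's element count does not match the tensor info.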
template<typename TContainer>
inline armnn::InputTensors MakeInputTensors(const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& input,
                                            const TContainer& inputTensorData)
{
    if (inputTensorData.size() != input.second.GetNumElements())
    {
        throw armnn::Exception(boost::str(boost::format("Input tensor has incorrect size. Expected %1% elements "
            "but got %2%.") % input.second.GetNumElements() % inputTensorData.size()));
    }
    return { { input.first, armnn::ConstTensor(input.second, inputTensorData.data()) } };
}

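// Wraps a caller-owned, writable output buffer in an armnn::Tensor keyed by the layer binding id.
// Throws armnn::Exception if the container's element count does not match the tensor info.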
template<typename TContainer>
inline armnn::OutputTensors MakeOutputTensors(const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& output,
                                              TContainer& outputTensorData)
{
    if (outputTensorData.size() != output.second.GetNumElements())
    {
        throw armnn::Exception(boost::str(boost::format("Output tensor has incorrect size. Expected %1% elements "
            "but got %2%.") % output.second.GetNumElements() % outputTensorData.size()));
    }
    return { { output.first, armnn::Tensor(output.second, outputTensorData.data()) } };
}

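// Convenience wrapper that parses, optimizes, loads and runs a network through a single object.
// IParser is any ArmNN parser type exposing Create() and CreateNetworkFrom{Binary,Text}File(),
// e.g. the Caffe parser suggested by the command-line help below.
//
// Minimal usage sketch. The parser type, binding names and model path here are illustrative
// placeholders, not values mandated by this header:
//
//     using Model = InferenceModel<armnnCaffeParser::ICaffeParser, float>;
//     Model::Params params;
//     params.m_ModelPath     = "lenet.caffemodel";  // hypothetical model file
//     params.m_InputBinding  = "data";              // hypothetical input layer name
//     params.m_OutputBinding = "prob";              // hypothetical output layer name
//     Model model(params);
//
//     std::vector<float> input(/* number of input elements */);
//     std::vector<float> output(model.GetOutputSize());
//     model.Run(input, output);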
template <typename IParser, typename TDataType>
class InferenceModel
{
public:
    using DataType = TDataType;

    struct CommandLineOptions
    {
        std::string m_ModelDir;
        armnn::Compute m_ComputeDevice;
    };

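    // Registers the --model-dir/-m (required) and --compute/-c options on the given description,
    // writing parsed values into 'options'. A hypothetical invocation of a program that uses
    // these options might look like:
    //
    //     ./MyProgram --model-dir models/lenet --compute GpuAcc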
    static void AddCommandLineOptions(boost::program_options::options_description& desc, CommandLineOptions& options)
    {
        namespace po = boost::program_options;

        desc.add_options()
            ("model-dir,m", po::value<std::string>(&options.m_ModelDir)->required(),
             "Path to directory containing model files (.caffemodel/.prototxt)")
            ("compute,c", po::value<armnn::Compute>(&options.m_ComputeDevice)->default_value(armnn::Compute::CpuAcc),
             "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc");
    }

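    // Everything needed to construct an InferenceModel. m_InputTensorShape is optional; when set
    // it overrides the input shape declared in the model file. Defaults: CpuRef device, binary
    // model format.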
    struct Params
    {
        std::string m_ModelPath;
        std::string m_InputBinding;
        std::string m_OutputBinding;
        const armnn::TensorShape* m_InputTensorShape;
        armnn::Compute m_ComputeDevice;
        bool m_IsModelBinary;

        Params()
            : m_InputTensorShape(nullptr)
            , m_ComputeDevice(armnn::Compute::CpuRef)
            , m_IsModelBinary(true)
        {
        }
    };

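    // Parses the model file, optimizes the resulting network for the runtime's device spec and
    // loads it into the runtime. Throws armnn::Exception if loading fails.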
    InferenceModel(const Params& params)
        : m_Runtime(armnn::IRuntime::Create(params.m_ComputeDevice))
    {
        const std::string& modelPath = params.m_ModelPath;

        // Create a network from a file on disk.
        auto parser(IParser::Create());

        std::map<std::string, armnn::TensorShape> inputShapes;
        if (params.m_InputTensorShape)
        {
            inputShapes[params.m_InputBinding] = *params.m_InputTensorShape;
        }
        std::vector<std::string> requestedOutputs{ params.m_OutputBinding };

        // Handle text and binary input differently by calling the corresponding parser function.
        armnn::INetworkPtr network = (params.m_IsModelBinary ?
            parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes, requestedOutputs) :
            parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes, requestedOutputs));

        // Look up the binding ids and tensor infos for the requested input and output layers.
        m_InputBindingInfo = parser->GetNetworkInputBindingInfo(params.m_InputBinding);
        m_OutputBindingInfo = parser->GetNetworkOutputBindingInfo(params.m_OutputBinding);

        // Optimize the network for the device spec reported by the runtime.
        armnn::IOptimizedNetworkPtr optNet =
            armnn::Optimize(*network, m_Runtime->GetDeviceSpec());

        // Load the network into the runtime.
        armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet));
        if (ret == armnn::Status::Failure)
        {
            throw armnn::Exception("IRuntime::LoadNetwork failed");
        }
    }

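    // Number of elements the output container passed to Run() must hold.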
    unsigned int GetOutputSize() const
    {
        return m_OutputBindingInfo.second.GetNumElements();
    }

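    // Runs one inference. 'input' is read, 'output' is overwritten with the result;
    // output.size() must equal GetOutputSize().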
    void Run(const std::vector<TDataType>& input, std::vector<TDataType>& output)
    {
        BOOST_ASSERT(output.size() == GetOutputSize());
        armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
                                                       MakeInputTensors(input),
                                                       MakeOutputTensors(output));
        if (ret == armnn::Status::Failure)
        {
            throw armnn::Exception("IRuntime::EnqueueWorkload failed");
        }
    }

private:
    template<typename TContainer>
    armnn::InputTensors MakeInputTensors(const TContainer& inputTensorData)
    {
        return ::MakeInputTensors(m_InputBindingInfo, inputTensorData);
    }

    template<typename TContainer>
    armnn::OutputTensors MakeOutputTensors(TContainer& outputTensorData)
    {
        return ::MakeOutputTensors(m_OutputBindingInfo, outputTensorData);
    }

    armnn::NetworkId m_NetworkIdentifier;
    armnn::IRuntimePtr m_Runtime;

    std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
    std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
};