blob: 51867b6ace02a7569e6f7854b2c641f523fa0183 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once
#include "armnnCaffeParser/ICaffeParser.hpp"

#include "armnn/Types.hpp"
#include "armnn/NetworkFwd.hpp"
#include "armnn/Tensor.hpp"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
15
// Forward declarations of the Caffe types this header refers to only by
// reference or pointer, so the full Caffe definitions are not needed here.
namespace caffe
{
class BlobShape;
class LayerParameter;
class NetParameter;
} // namespace caffe
22
23namespace armnnCaffeParser
24{
25
26using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
27
telsoa01c577f2c2018-08-31 09:22:23 +010028class CaffeParserBase: public ICaffeParser
telsoa014fcda012018-03-09 14:13:49 +000029{
30public:
telsoa01c577f2c2018-08-31 09:22:23 +010031
32 // Because we haven't looked at reducing the memory usage when loading from Text/String
33 // have to retain these functions here for the moment.
telsoa014fcda012018-03-09 14:13:49 +000034 /// Create the network from a protobuf text file on disk
35 virtual armnn::INetworkPtr CreateNetworkFromTextFile(
36 const char* graphFile,
37 const std::map<std::string, armnn::TensorShape>& inputShapes,
38 const std::vector<std::string>& requestedOutputs) override;
39
telsoa014fcda012018-03-09 14:13:49 +000040
telsoa01c577f2c2018-08-31 09:22:23 +010041 /// Creates the network directly from protobuf text in a string. Useful for debugging/testing.
telsoa014fcda012018-03-09 14:13:49 +000042 virtual armnn::INetworkPtr CreateNetworkFromString(
43 const char* protoText,
44 const std::map<std::string, armnn::TensorShape>& inputShapes,
45 const std::vector<std::string>& requestedOutputs) override;
46
telsoa01c577f2c2018-08-31 09:22:23 +010047 /// Retrieves binding info (layer id and tensor info) for the network input identified by the given layer name.
telsoa014fcda012018-03-09 14:13:49 +000048 virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
49
telsoa01c577f2c2018-08-31 09:22:23 +010050 /// Retrieves binding info (layer id and tensor info) for the network output identified by the given layer name.
telsoa014fcda012018-03-09 14:13:49 +000051 virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
52
telsoa01c577f2c2018-08-31 09:22:23 +010053 CaffeParserBase();
telsoa014fcda012018-03-09 14:13:49 +000054
telsoa01c577f2c2018-08-31 09:22:23 +010055protected:
telsoa014fcda012018-03-09 14:13:49 +000056 /// Adds an armnn layer to m_Network given a Caffe LayerParameter of the correct type
57 /// and is responsible for recording any newly created IOutputSlots using SetArmnnOutputSlotForCaffeTop().
58 /// @{
59 void ParseInputLayer(const caffe::LayerParameter& layerParam);
60 void ParseConvLayer(const caffe::LayerParameter& layerParam);
61 void ParsePoolingLayer(const caffe::LayerParameter& layerParam);
62 void ParseReluLayer(const caffe::LayerParameter& layerParam);
63 void ParseLRNLayer(const caffe::LayerParameter& layerParam);
64 void ParseInnerProductLayer(const caffe::LayerParameter& layerParam);
65 void ParseSoftmaxLayer(const caffe::LayerParameter& layerParam);
66 void ParseEltwiseLayer(const caffe::LayerParameter& layerParam);
67 void ParseConcatLayer(const caffe::LayerParameter& layerParam);
68 void ParseBatchNormLayer(const caffe::LayerParameter& layerParam);
69 void ParseScaleLayer(const caffe::LayerParameter& layerParam);
70 void ParseSplitLayer(const caffe::LayerParameter& layerParam);
71 void ParseDropoutLayer(const caffe::LayerParameter& layerParam);
72 /// @}
73
telsoa01c577f2c2018-08-31 09:22:23 +010074 /// ParseConv may use these helpers depending on the group parameter
75 /// @{
76 void AddConvLayerWithSplits(const caffe::LayerParameter& layerParam,
77 const armnn::Convolution2dDescriptor & desc,
78 unsigned int kernelW,
79 unsigned int kernelH);
80 void AddConvLayerWithDepthwiseConv(const caffe::LayerParameter& layerParam,
81 const armnn::Convolution2dDescriptor & desc,
82 unsigned int kernelW,
83 unsigned int kernelH);
84 /// @}
telsoa014fcda012018-03-09 14:13:49 +000085
telsoa01c577f2c2018-08-31 09:22:23 +010086 /// Converts Caffe's protobuf tensor shape format to ArmNN's
87 armnn::TensorInfo BlobShapeToTensorInfo(const caffe::BlobShape& blobShape) const;
88
89 void TrackInputBinding(armnn::IConnectableLayer* layer,
90 armnn::LayerBindingId id,
91 const armnn::TensorInfo& tensorInfo);
telsoa014fcda012018-03-09 14:13:49 +000092
93 static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
telsoa01c577f2c2018-08-31 09:22:23 +010094 const armnn::TensorInfo& tensorInfo,
95 const char* bindingPointDesc,
96 std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
97
98 void TrackOutputBinding(armnn::IConnectableLayer* layer,
99 armnn::LayerBindingId id,
100 const armnn::TensorInfo& tensorInfo);
101
102
103 void SetArmnnOutputSlotForCaffeTop(const std::string& caffeTopName, armnn::IOutputSlot& armnnOutputSlot);
telsoa014fcda012018-03-09 14:13:49 +0000104
105 /// Retrieves the Armnn IOutputSlot representing the given Caffe top.
106 /// Throws if it cannot be found (e.g. not parsed yet).
107 armnn::IOutputSlot& GetArmnnOutputSlotForCaffeTop(const std::string& caffeTopName) const;
108
telsoa01c577f2c2018-08-31 09:22:23 +0100109 static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(
110 const std::string& layerName,
111 const char* bindingPointDesc,
112 const std::unordered_map<std::string, BindingPointInfo>& bindingInfos);
113
telsoa014fcda012018-03-09 14:13:49 +0000114
115 void Cleanup();
116
telsoa01c577f2c2018-08-31 09:22:23 +0100117 using OperationParsingFunction = void(CaffeParserBase::*)(const caffe::LayerParameter& layerParam);
telsoa014fcda012018-03-09 14:13:49 +0000118
telsoa01c577f2c2018-08-31 09:22:23 +0100119 /// Maps Caffe layer names to parsing member functions.
telsoa014fcda012018-03-09 14:13:49 +0000120 static const std::map<std::string, OperationParsingFunction> ms_CaffeLayerNameToParsingFunctions;
121
telsoa014fcda012018-03-09 14:13:49 +0000122 /// maps input layer names to their corresponding ids and tensor infos
123 std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
124
125 /// maps output layer names to their corresponding ids and tensor infos
126 std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100127
128 armnn::INetworkPtr m_Network;
129
130 std::map<std::string, armnn::TensorShape> m_InputShapes;
131
132 /// As we add armnn layers we store the armnn IOutputSlot which corresponds to the Caffe tops.
133 std::unordered_map<std::string, armnn::IOutputSlot*> m_ArmnnOutputSlotForCaffeTop;
134
135 std::vector<std::string> m_RequestedOutputs;
136
137
138 // Stuff which has gone to base class simply because we haven't done any
139 // memory optimisation on the text/string format. If we move this to a layer
140 // by layer parse as well these can move to the CaffeParser class.
141 std::map<std::string, const caffe::LayerParameter*> m_CaffeLayersByTopName;
142
143 /// Parses a NetParameter loaded into memory from one of the other CreateNetwork*
144 armnn::INetworkPtr CreateNetworkFromNetParameter(
145 caffe::NetParameter& netParam,
146 const std::map<std::string, armnn::TensorShape>& inputShapes,
147 const std::vector<std::string>& requestedOutputs);
148
149 /// does the actual conversion from caffe::NetParameter to armnn::INetwork
150 void LoadNetParam(caffe::NetParameter& netParameter);
151
152 /// Find the Caffe layers listed as inputs (bottoms) for a given layer.
153 std::vector<const caffe::LayerParameter*> GetInputs(const caffe::LayerParameter& layerParam);
154
155 /// Modifies the Caffe network to replace "in-place" layers (whose top() and bottom() are both the same)
156 /// with regular layers. This simplifies further parsing.
157 void ResolveInPlaceLayers(caffe::NetParameter& netParameter);
158
159};
160
161class CaffeParser : public CaffeParserBase
162{
163public:
164
165 /// Create the network from a protobuf binary file on disk
166 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
167 const char* graphFile,
168 const std::map<std::string, armnn::TensorShape>& inputShapes,
169 const std::vector<std::string>& requestedOutputs) override;
170
171public:
172 CaffeParser();
173
telsoa014fcda012018-03-09 14:13:49 +0000174};
175}