blob: 105204e923c256493bf81b6e2a1ab7ad2a3ac17e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6#include "armnnCaffeParser/ICaffeParser.hpp"
7
8#include "armnn/Types.hpp"
9#include "armnn/NetworkFwd.hpp"
10#include "armnn/Tensor.hpp"
11
12#include <memory>
13#include <vector>
14#include <unordered_map>
15
// Forward declarations of the Caffe protobuf message types referenced by the parser
// declarations in this header, so that the full protobuf definitions do not need to be
// included here.
namespace caffe
{
class BlobShape;
class LayerParameter;
class NetParameter;
}
22
23namespace armnnCaffeParser
24{
25
telsoa01c577f2c2018-08-31 09:22:23 +010026class CaffeParserBase: public ICaffeParser
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
telsoa01c577f2c2018-08-31 09:22:23 +010029
30 // Because we haven't looked at reducing the memory usage when loading from Text/String
31 // have to retain these functions here for the moment.
telsoa014fcda012018-03-09 14:13:49 +000032 /// Create the network from a protobuf text file on disk
33 virtual armnn::INetworkPtr CreateNetworkFromTextFile(
34 const char* graphFile,
35 const std::map<std::string, armnn::TensorShape>& inputShapes,
36 const std::vector<std::string>& requestedOutputs) override;
37
telsoa014fcda012018-03-09 14:13:49 +000038
telsoa01c577f2c2018-08-31 09:22:23 +010039 /// Creates the network directly from protobuf text in a string. Useful for debugging/testing.
telsoa014fcda012018-03-09 14:13:49 +000040 virtual armnn::INetworkPtr CreateNetworkFromString(
41 const char* protoText,
42 const std::map<std::string, armnn::TensorShape>& inputShapes,
43 const std::vector<std::string>& requestedOutputs) override;
44
telsoa01c577f2c2018-08-31 09:22:23 +010045 /// Retrieves binding info (layer id and tensor info) for the network input identified by the given layer name.
telsoa014fcda012018-03-09 14:13:49 +000046 virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
47
telsoa01c577f2c2018-08-31 09:22:23 +010048 /// Retrieves binding info (layer id and tensor info) for the network output identified by the given layer name.
telsoa014fcda012018-03-09 14:13:49 +000049 virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
50
telsoa01c577f2c2018-08-31 09:22:23 +010051 CaffeParserBase();
telsoa014fcda012018-03-09 14:13:49 +000052
telsoa01c577f2c2018-08-31 09:22:23 +010053protected:
telsoa014fcda012018-03-09 14:13:49 +000054 /// Adds an armnn layer to m_Network given a Caffe LayerParameter of the correct type
55 /// and is responsible for recording any newly created IOutputSlots using SetArmnnOutputSlotForCaffeTop().
56 /// @{
57 void ParseInputLayer(const caffe::LayerParameter& layerParam);
58 void ParseConvLayer(const caffe::LayerParameter& layerParam);
59 void ParsePoolingLayer(const caffe::LayerParameter& layerParam);
60 void ParseReluLayer(const caffe::LayerParameter& layerParam);
61 void ParseLRNLayer(const caffe::LayerParameter& layerParam);
62 void ParseInnerProductLayer(const caffe::LayerParameter& layerParam);
63 void ParseSoftmaxLayer(const caffe::LayerParameter& layerParam);
64 void ParseEltwiseLayer(const caffe::LayerParameter& layerParam);
65 void ParseConcatLayer(const caffe::LayerParameter& layerParam);
66 void ParseBatchNormLayer(const caffe::LayerParameter& layerParam);
67 void ParseScaleLayer(const caffe::LayerParameter& layerParam);
68 void ParseSplitLayer(const caffe::LayerParameter& layerParam);
69 void ParseDropoutLayer(const caffe::LayerParameter& layerParam);
70 /// @}
71
telsoa01c577f2c2018-08-31 09:22:23 +010072 /// ParseConv may use these helpers depending on the group parameter
73 /// @{
74 void AddConvLayerWithSplits(const caffe::LayerParameter& layerParam,
75 const armnn::Convolution2dDescriptor & desc,
76 unsigned int kernelW,
77 unsigned int kernelH);
78 void AddConvLayerWithDepthwiseConv(const caffe::LayerParameter& layerParam,
79 const armnn::Convolution2dDescriptor & desc,
80 unsigned int kernelW,
81 unsigned int kernelH);
82 /// @}
telsoa014fcda012018-03-09 14:13:49 +000083
telsoa01c577f2c2018-08-31 09:22:23 +010084 /// Converts Caffe's protobuf tensor shape format to ArmNN's
85 armnn::TensorInfo BlobShapeToTensorInfo(const caffe::BlobShape& blobShape) const;
86
87 void TrackInputBinding(armnn::IConnectableLayer* layer,
88 armnn::LayerBindingId id,
89 const armnn::TensorInfo& tensorInfo);
telsoa014fcda012018-03-09 14:13:49 +000090
91 static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
telsoa01c577f2c2018-08-31 09:22:23 +010092 const armnn::TensorInfo& tensorInfo,
93 const char* bindingPointDesc,
94 std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
95
96 void TrackOutputBinding(armnn::IConnectableLayer* layer,
97 armnn::LayerBindingId id,
98 const armnn::TensorInfo& tensorInfo);
99
100
101 void SetArmnnOutputSlotForCaffeTop(const std::string& caffeTopName, armnn::IOutputSlot& armnnOutputSlot);
telsoa014fcda012018-03-09 14:13:49 +0000102
103 /// Retrieves the Armnn IOutputSlot representing the given Caffe top.
104 /// Throws if it cannot be found (e.g. not parsed yet).
105 armnn::IOutputSlot& GetArmnnOutputSlotForCaffeTop(const std::string& caffeTopName) const;
106
telsoa01c577f2c2018-08-31 09:22:23 +0100107 static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(
108 const std::string& layerName,
109 const char* bindingPointDesc,
110 const std::unordered_map<std::string, BindingPointInfo>& bindingInfos);
111
telsoa014fcda012018-03-09 14:13:49 +0000112
113 void Cleanup();
114
telsoa01c577f2c2018-08-31 09:22:23 +0100115 using OperationParsingFunction = void(CaffeParserBase::*)(const caffe::LayerParameter& layerParam);
telsoa014fcda012018-03-09 14:13:49 +0000116
telsoa01c577f2c2018-08-31 09:22:23 +0100117 /// Maps Caffe layer names to parsing member functions.
telsoa014fcda012018-03-09 14:13:49 +0000118 static const std::map<std::string, OperationParsingFunction> ms_CaffeLayerNameToParsingFunctions;
119
telsoa014fcda012018-03-09 14:13:49 +0000120 /// maps input layer names to their corresponding ids and tensor infos
121 std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
122
123 /// maps output layer names to their corresponding ids and tensor infos
124 std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100125
126 armnn::INetworkPtr m_Network;
127
128 std::map<std::string, armnn::TensorShape> m_InputShapes;
129
130 /// As we add armnn layers we store the armnn IOutputSlot which corresponds to the Caffe tops.
131 std::unordered_map<std::string, armnn::IOutputSlot*> m_ArmnnOutputSlotForCaffeTop;
132
133 std::vector<std::string> m_RequestedOutputs;
134
135
136 // Stuff which has gone to base class simply because we haven't done any
137 // memory optimisation on the text/string format. If we move this to a layer
138 // by layer parse as well these can move to the CaffeParser class.
139 std::map<std::string, const caffe::LayerParameter*> m_CaffeLayersByTopName;
140
141 /// Parses a NetParameter loaded into memory from one of the other CreateNetwork*
142 armnn::INetworkPtr CreateNetworkFromNetParameter(
143 caffe::NetParameter& netParam,
144 const std::map<std::string, armnn::TensorShape>& inputShapes,
145 const std::vector<std::string>& requestedOutputs);
146
147 /// does the actual conversion from caffe::NetParameter to armnn::INetwork
148 void LoadNetParam(caffe::NetParameter& netParameter);
149
150 /// Find the Caffe layers listed as inputs (bottoms) for a given layer.
151 std::vector<const caffe::LayerParameter*> GetInputs(const caffe::LayerParameter& layerParam);
152
153 /// Modifies the Caffe network to replace "in-place" layers (whose top() and bottom() are both the same)
154 /// with regular layers. This simplifies further parsing.
155 void ResolveInPlaceLayers(caffe::NetParameter& netParameter);
156
157};
158
159class CaffeParser : public CaffeParserBase
160{
161public:
162
163 /// Create the network from a protobuf binary file on disk
164 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
165 const char* graphFile,
166 const std::map<std::string, armnn::TensorShape>& inputShapes,
167 const std::vector<std::string>& requestedOutputs) override;
168
169public:
170 CaffeParser();
171
telsoa014fcda012018-03-09 14:13:49 +0000172};
173}