//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once
#include "armnnCaffeParser/ICaffeParser.hpp"

#include "armnn/Types.hpp"
#include "armnn/NetworkFwd.hpp"
#include "armnn/Tensor.hpp"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Forward declarations of the Caffe protobuf message types referenced below.
// This header only names them through references and pointers, so complete
// definitions are not required here.
namespace caffe
{
class NetParameter;
class LayerParameter;
class BlobShape;
} // namespace caffe

23namespace armnnCaffeParser
24{
25
26using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
27
28class CaffeParser : public ICaffeParser
29{
30public:
31 /// Create the network from a protobuf text file on disk
32 virtual armnn::INetworkPtr CreateNetworkFromTextFile(
33 const char* graphFile,
34 const std::map<std::string, armnn::TensorShape>& inputShapes,
35 const std::vector<std::string>& requestedOutputs) override;
36
37 /// Create the network from a protobuf binary file on disk
38 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
39 const char* graphFile,
40 const std::map<std::string, armnn::TensorShape>& inputShapes,
41 const std::vector<std::string>& requestedOutputs) override;
42
43 /// Create the network directly from protobuf text in a string. Useful for debugging/testing
44 virtual armnn::INetworkPtr CreateNetworkFromString(
45 const char* protoText,
46 const std::map<std::string, armnn::TensorShape>& inputShapes,
47 const std::vector<std::string>& requestedOutputs) override;
48
49 /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
50 virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
51
52 /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
53 virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
54
55public:
56 CaffeParser();
57
58private:
59 static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName,
60 const char* bindingPointDesc,
61 const std::unordered_map<std::string, BindingPointInfo>& bindingInfos);
62
63 /// Parses a NetParameter loaded into memory from one of the other CreateNetwork*
64 armnn::INetworkPtr CreateNetworkFromNetParameter(
65 caffe::NetParameter& netParam,
66 const std::map<std::string, armnn::TensorShape>& inputShapes,
67 const std::vector<std::string>& requestedOutputs);
68
69 /// does the actual conversion from caffe::NetParameter to armnn::INetwork
70 void LoadNetParam(caffe::NetParameter& netParameter);
71
72 /// Find the Caffe layers listed as inputs (bottoms) for a given layer.
73 std::vector<const caffe::LayerParameter*> GetInputs(const caffe::LayerParameter& layerParam);
74
75 /// Modifies the Caffe network to replace "in-place" layers (whose top() and bottom() are both the same)
76 /// with regular layers. This simplifies further parsing.
77 void ResolveInPlaceLayers(caffe::NetParameter& netParameter);
78
79 /// Converts Caffe's protobuf tensor shape format to ArmNN's
80 armnn::TensorInfo BlobShapeToTensorInfo(const caffe::BlobShape& blobShape) const;
81
82 /// Adds an armnn layer to m_Network given a Caffe LayerParameter of the correct type
83 /// and is responsible for recording any newly created IOutputSlots using SetArmnnOutputSlotForCaffeTop().
84 /// @{
85 void ParseInputLayer(const caffe::LayerParameter& layerParam);
86 void ParseConvLayer(const caffe::LayerParameter& layerParam);
87 void ParsePoolingLayer(const caffe::LayerParameter& layerParam);
88 void ParseReluLayer(const caffe::LayerParameter& layerParam);
89 void ParseLRNLayer(const caffe::LayerParameter& layerParam);
90 void ParseInnerProductLayer(const caffe::LayerParameter& layerParam);
91 void ParseSoftmaxLayer(const caffe::LayerParameter& layerParam);
92 void ParseEltwiseLayer(const caffe::LayerParameter& layerParam);
93 void ParseConcatLayer(const caffe::LayerParameter& layerParam);
94 void ParseBatchNormLayer(const caffe::LayerParameter& layerParam);
95 void ParseScaleLayer(const caffe::LayerParameter& layerParam);
96 void ParseSplitLayer(const caffe::LayerParameter& layerParam);
97 void ParseDropoutLayer(const caffe::LayerParameter& layerParam);
98 /// @}
99
100 void TrackInputBinding(armnn::IConnectableLayer* layer,
101 armnn::LayerBindingId id,
102 const armnn::TensorInfo& tensorInfo);
103
104 void TrackOutputBinding(armnn::IConnectableLayer* layer,
105 armnn::LayerBindingId id,
106 const armnn::TensorInfo& tensorInfo);
107
108 static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
109 const armnn::TensorInfo& tensorInfo,
110 const char* bindingPointDesc,
111 std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
112
113 /// Retrieves the Armnn IOutputSlot representing the given Caffe top.
114 /// Throws if it cannot be found (e.g. not parsed yet).
115 armnn::IOutputSlot& GetArmnnOutputSlotForCaffeTop(const std::string& caffeTopName) const;
116
117 void SetArmnnOutputSlotForCaffeTop(const std::string& caffeTopName, armnn::IOutputSlot& armnnOutputSlot);
118
119 void Cleanup();
120
121 armnn::INetworkPtr m_Network;
122
123 std::map<std::string, const caffe::LayerParameter*> m_CaffeLayersByTopName;
124
125 using OperationParsingFunction = void(CaffeParser::*)(const caffe::LayerParameter& layerParam);
126
127 /// map of Caffe layer names to parsing member functions
128 static const std::map<std::string, OperationParsingFunction> ms_CaffeLayerNameToParsingFunctions;
129
130 std::map<std::string, armnn::TensorShape> m_InputShapes;
131 std::vector<std::string> m_RequestedOutputs;
132
133 /// As we add armnn layers we store the armnn IOutputSlot which corresponds to the Caffe tops.
134 std::unordered_map<std::string, armnn::IOutputSlot*> m_ArmnnOutputSlotForCaffeTop;
135
136 /// maps input layer names to their corresponding ids and tensor infos
137 std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
138
139 /// maps output layer names to their corresponding ids and tensor infos
140 std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
141};
142}