blob: 660ca87d136b8e74876da0ce6ed7db27da34f598 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
7#include <armnn/DescriptorsFwd.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01008#include <armnn/LstmParams.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/TensorFwd.hpp>
10#include <armnn/Types.hpp>
11
12#include <armnn/INetwork.hpp>
13
14#include <string>
15#include <vector>
16#include <memory>
17
18#include "Layer.hpp"
19
20namespace armnn
21{
22class Graph;
23
telsoa01c577f2c2018-08-31 09:22:23 +010024/// Private implementation of INetwork.
telsoa014fcda012018-03-09 14:13:49 +000025class Network final : public INetwork
26{
27public:
28 Network();
29 ~Network();
30
31 const Graph& GetGraph() const { return *m_Graph; }
32
33 Status PrintGraph() override;
34
35 IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
36
Éanna Ó Catháin4e1e1362018-11-12 11:36:34 +000037 IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
38 const char* name = nullptr) override;
39
telsoa014fcda012018-03-09 14:13:49 +000040 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
Aron Virginas-Tarad402702019-02-22 17:03:44 +000041 const ConstTensor& weights,
42 const Optional<ConstTensor>& biases,
43 const char* name = nullptr) override;
44
45 /// @deprecated
46 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
47 const ConstTensor& weights,
48 const char* name = nullptr) override;
49
50 /// @deprecated
51 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
52 const ConstTensor& weights,
53 const ConstTensor& biases,
54 const char* name = nullptr) override;
55
56 IConnectableLayer* AddDepthwiseConvolution2dLayer(
57 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
58 const ConstTensor& weights,
59 const Optional<ConstTensor>& biases,
60 const char* name = nullptr) override;
61
62 /// @deprecated
63 IConnectableLayer* AddDepthwiseConvolution2dLayer(
64 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
telsoa014fcda012018-03-09 14:13:49 +000065 const ConstTensor& weights,
66 const char* name = nullptr) override;
67
Aron Virginas-Tarad402702019-02-22 17:03:44 +000068 /// @deprecated
69 IConnectableLayer* AddDepthwiseConvolution2dLayer(
70 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
telsoa014fcda012018-03-09 14:13:49 +000071 const ConstTensor& weights,
72 const ConstTensor& biases,
73 const char* name = nullptr) override;
74
Nattapat Chaimanowonge4294fd2019-03-28 09:56:53 +000075 IConnectableLayer* AddDequantizeLayer(const char* name = nullptr) override;
76
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +000077 IConnectableLayer* AddDetectionPostProcessLayer(
78 const DetectionPostProcessDescriptor& descriptor,
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +000079 const ConstTensor& anchors,
Narumol Prangnawarat94dd5d82019-01-23 18:06:26 +000080 const char* name = nullptr) override;
81
telsoa014fcda012018-03-09 14:13:49 +000082 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
Aron Virginas-Tarad402702019-02-22 17:03:44 +000083 const ConstTensor& weights,
84 const Optional<ConstTensor>& biases,
85 const char* name = nullptr) override;
telsoa014fcda012018-03-09 14:13:49 +000086
Aron Virginas-Tarad402702019-02-22 17:03:44 +000087 /// @deprecated
telsoa014fcda012018-03-09 14:13:49 +000088 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
Aron Virginas-Tarad402702019-02-22 17:03:44 +000089 const ConstTensor& weights,
90 const char* name = nullptr) override;
91
92 /// @deprecated
93 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
94 const ConstTensor& weights,
95 const ConstTensor& biases,
96 const char* name = nullptr) override;
telsoa014fcda012018-03-09 14:13:49 +000097
narpra01b89b05f2019-01-16 09:53:09 +000098 IConnectableLayer* AddGatherLayer(const char* name = nullptr) override;
99
telsoa014fcda012018-03-09 14:13:49 +0000100 IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
101 const char* name = nullptr) override;
102
103 IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
104 const char* name = nullptr) override;
105
106 IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
107 const char* name = nullptr) override;
108
109 IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
110 const char* name = nullptr) override;
111
112 IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
113 const char* name = nullptr) override;
114
115 IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
116 const char* name = nullptr) override;
117
118 IConnectableLayer* AddMergerLayer(const OriginsDescriptor& mergerDescriptor,
119 const char* name = nullptr) override;
120
121 IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
122
123 IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
124
125 IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
126 const ConstTensor& mean,
127 const ConstTensor& variance,
128 const ConstTensor& beta,
129 const ConstTensor& gamma,
130 const char* name = nullptr) override;
131
132 IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
133 const char* name = nullptr) override;
134
Matteo Martincighbcd3c852018-09-28 14:14:12 +0100135 IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
136 const char* name = nullptr) override;
telsoa014fcda012018-03-09 14:13:49 +0000137
138 IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
139
140 IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
141 const char* name = nullptr) override;
142
Nattapat Chaimanowong207ef9a2018-11-02 10:57:25 +0000143 IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
144 const char* name = nullptr) override;
145
telsoa014fcda012018-03-09 14:13:49 +0000146 IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
147
148 IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
149
telsoa01c577f2c2018-08-31 09:22:23 +0100150 IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
151 const LstmInputParams& params,
152 const char* name = nullptr) override;
153
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100154 IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override;
155
David Beck19526222018-09-12 16:00:08 +0100156 IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override;
157
Nattapat Chaimanowong5a4304a2018-11-28 10:44:37 +0000158 IConnectableLayer* AddMaximumLayer(const char* name = nullptr) override;
159
narpra0132b90462018-09-13 11:07:48 +0100160 IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) override;
161
Mohamed Nour Abouelseoud5662c202018-09-24 13:30:09 +0100162 IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr) override;
163
Derek Lambertia9cca6a2019-03-25 15:41:58 +0000164 IConnectableLayer* AddQuantizeLayer(const char* name = nullptr) override;
165
Conor Kennedy430b5d82018-11-14 15:28:28 +0000166 IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
167 const char* name = nullptr) override;
168
kevmay0190539692018-11-29 08:40:19 +0000169 IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override;
170
Matteo Martincigh59a950c2018-12-13 12:48:25 +0000171 IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override;
172
FrancisMurtagh20995952018-12-17 12:11:36 +0000173 IConnectableLayer* AddEqualLayer(const char* name = nullptr) override;
174
Mohamed Nour Abouelseouda1d3c6a2018-12-27 12:39:16 +0000175 IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
176
Nattapat Chaimanowong1f886302019-04-05 13:37:19 +0100177 IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
178
Sadik Armaganeff363d2019-04-05 15:25:46 +0100179 IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
180
Mike Kelly8c1701a2019-02-11 17:01:27 +0000181 void Accept(ILayerVisitor& visitor) const override;
182
telsoa014fcda012018-03-09 14:13:49 +0000183private:
184 IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
Aron Virginas-Tarad402702019-02-22 17:03:44 +0000185 const ConstTensor& weights,
186 const Optional<ConstTensor>& biases,
187 const char* name);
telsoa014fcda012018-03-09 14:13:49 +0000188
189 IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
Aron Virginas-Tarad402702019-02-22 17:03:44 +0000190 const ConstTensor& weights,
191 const Optional<ConstTensor>& biases,
192 const char* name);
telsoa014fcda012018-03-09 14:13:49 +0000193
194 IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
195 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
196 const ConstTensor& weights,
Aron Virginas-Tarad402702019-02-22 17:03:44 +0000197 const Optional<ConstTensor>& biases,
telsoa014fcda012018-03-09 14:13:49 +0000198 const char* name);
199
200 std::unique_ptr<Graph> m_Graph;
201};
202
203class OptimizedNetwork final : public IOptimizedNetwork
204{
205public:
206 OptimizedNetwork(std::unique_ptr<Graph> graph);
207 ~OptimizedNetwork();
208
209 Status PrintGraph() override;
surmeh01bceff2f2018-03-29 16:29:27 +0100210 Status SerializeToDot(std::ostream& stream) const override;
telsoa014fcda012018-03-09 14:13:49 +0000211
212 Graph& GetGraph() { return *m_Graph; }
213
214private:
215 std::unique_ptr<Graph> m_Graph;
216};
217
218} // namespace armnn