blob: b6b8548f085f955f21ff609adf45a9990b3b0473 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include <armnn/DescriptorsFwd.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01008#include <armnn/LstmParams.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009#include <armnn/TensorFwd.hpp>
10#include <armnn/Types.hpp>
11
12#include <armnn/INetwork.hpp>
13
14#include <string>
15#include <vector>
16#include <memory>
17
18#include "Layer.hpp"
19
20namespace armnn
21{
22class Graph;
23
telsoa01c577f2c2018-08-31 09:22:23 +010024/// Private implementation of INetwork.
telsoa014fcda012018-03-09 14:13:49 +000025class Network final : public INetwork
26{
27public:
28 Network();
29 ~Network();
30
31 const Graph& GetGraph() const { return *m_Graph; }
32
33 Status PrintGraph() override;
34
35 IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
36
37 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
38 const ConstTensor& weights,
39 const char* name = nullptr) override;
40
41 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
42 const ConstTensor& weights,
43 const ConstTensor& biases,
44 const char* name = nullptr) override;
45
46 IConnectableLayer* AddDepthwiseConvolution2dLayer(
47 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
48 const ConstTensor& weights,
49 const char* name = nullptr) override;
50
51 IConnectableLayer* AddDepthwiseConvolution2dLayer(
52 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
53 const ConstTensor& weights,
54 const ConstTensor& biases,
55 const char* name = nullptr) override;
56
57 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
58 const ConstTensor& weights,
59 const char* name = nullptr) override;
60
61 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
62 const ConstTensor& weights,
63 const ConstTensor& biases,
64 const char* name = nullptr) override;
65
66 IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
67 const char* name = nullptr) override;
68
69 IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
70 const char* name = nullptr) override;
71
72 IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
73 const char* name = nullptr) override;
74
75 IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
76 const char* name = nullptr) override;
77
78 IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
79 const char* name = nullptr) override;
80
81 IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
82 const char* name = nullptr) override;
83
84 IConnectableLayer* AddMergerLayer(const OriginsDescriptor& mergerDescriptor,
85 const char* name = nullptr) override;
86
87 IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
88
89 IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
90
91 IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
92 const ConstTensor& mean,
93 const ConstTensor& variance,
94 const ConstTensor& beta,
95 const ConstTensor& gamma,
96 const char* name = nullptr) override;
97
98 IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
99 const char* name = nullptr) override;
100
101 IConnectableLayer* AddL2NormalizationLayer(const char* name = nullptr) override;
102
103 IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
104
105 IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
106 const char* name = nullptr) override;
107
108 IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
109
110 IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
111
telsoa01c577f2c2018-08-31 09:22:23 +0100112 IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
113 const LstmInputParams& params,
114 const char* name = nullptr) override;
115
Francis Murtaghe7a86a42018-08-29 12:42:10 +0100116 IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override;
117
David Beck19526222018-09-12 16:00:08 +0100118 IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override;
119
telsoa014fcda012018-03-09 14:13:49 +0000120private:
121 IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
122 const ConstTensor& weights,
123 const ConstTensor* biases,
124 const char* name);
125
126 IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
127 const ConstTensor& weights,
128 const ConstTensor* biases,
129 const char* name);
130
131 IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
132 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
133 const ConstTensor& weights,
134 const ConstTensor* biases,
135 const char* name);
136
137 std::unique_ptr<Graph> m_Graph;
138};
139
140class OptimizedNetwork final : public IOptimizedNetwork
141{
142public:
143 OptimizedNetwork(std::unique_ptr<Graph> graph);
144 ~OptimizedNetwork();
145
146 Status PrintGraph() override;
surmeh01bceff2f2018-03-29 16:29:27 +0100147 Status SerializeToDot(std::ostream& stream) const override;
telsoa014fcda012018-03-09 14:13:49 +0000148
149 Graph& GetGraph() { return *m_Graph; }
150
151private:
152 std::unique_ptr<Graph> m_Graph;
153};
154
155} // namespace armnn