/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef __ARM_COMPUTE_GRAPH2_LAYERS_H__
#define __ARM_COMPUTE_GRAPH2_LAYERS_H__

#include "arm_compute/graph2/GraphBuilder.h"
#include "arm_compute/graph2/Types.h"
#include "arm_compute/graph2/frontend/ILayer.h"
#include "arm_compute/graph2/frontend/IStream.h"
#include "arm_compute/graph2/frontend/SubStream.h"

#include "arm_compute/core/utils/misc/Utility.h"
#include "support/ToolchainSupport.h" // support::cpp14::make_unique, used by BranchLayer

#include <memory>
#include <string>

namespace arm_compute
{
namespace graph2
{
namespace frontend
{
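/* The layer classes below are thin wrappers around GraphBuilder: each one
 * captures its parameters and, in create_layer(), appends the corresponding
 * node to the stream's graph. A minimal usage sketch follows; it assumes the
 * Stream class from frontend/Stream.h, and the tensor shape and accessor
 * variables are illustrative placeholders, not part of this header:
 *
 *   Stream graph(0, "example");
 *   graph << InputLayer(TensorDescriptor(TensorShape(224U, 224U, 3U, 1U), DataType::F32),
 *                       std::move(input_accessor))
 *         << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
 *         << OutputLayer(std::move(output_accessor));
 */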
/** Input Layer */
class InputLayer final : public ILayer
{
public:
    /** Construct an input layer
     *
     * @param[in] desc     Tensor descriptor of the input
     * @param[in] accessor Accessor used to fill the input tensor
     */
    InputLayer(TensorDescriptor desc, ITensorAccessorUPtr accessor)
        : _desc(desc), _accessor(std::move(accessor))
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeParams common_params = { "", s.hints().target_hint };
        return GraphBuilder::add_input_node(s.graph(), common_params, _desc, std::move(_accessor));
    }

private:
    TensorDescriptor    _desc;
    ITensorAccessorUPtr _accessor;
};

/** Output Layer */
class OutputLayer final : public ILayer
{
public:
    /** Construct an output layer
     *
     * @param[in] accessor Accessor used to consume the output tensor
     */
    OutputLayer(ITensorAccessorUPtr accessor)
        : _accessor(std::move(accessor))
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeParams  common_params = { "", s.hints().target_hint };
        NodeIdxPair input         = { s.tail_node(), 0 };
        return GraphBuilder::add_output_node(s.graph(), common_params, input, std::move(_accessor));
    }

private:
    ITensorAccessorUPtr _accessor;
};

/** Activation Layer */
class ActivationLayer final : public ILayer
{
public:
    /** Construct an activation layer
     *
     * @param[in] act_info Activation information
     */
    ActivationLayer(ActivationLayerInfo act_info)
        : _act_info(act_info)
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeParams  common_params = { "", s.hints().target_hint };
        NodeIdxPair input         = { s.tail_node(), 0 };
        return GraphBuilder::add_activation_node(s.graph(), common_params, input, _act_info);
    }

private:
    ActivationLayerInfo _act_info;
};

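/* Example (a sketch; get_weights_accessor stands in for any helper returning
 * an ITensorAccessorUPtr -- mean and var are required, gamma and beta may be
 * left as nullptr):
 *
 *   graph << BatchNormalizationLayer(get_weights_accessor("bn_mean.npy"),
 *                                    get_weights_accessor("bn_var.npy"),
 *                                    get_weights_accessor("bn_gamma.npy"),
 *                                    get_weights_accessor("bn_beta.npy"),
 *                                    0.001f);
 */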
/** Batch Normalization Layer */
class BatchNormalizationLayer final : public ILayer
{
public:
    /** Construct a batch normalization layer
     *
     * @param[in] mean    Accessor to get mean tensor data from
     * @param[in] var     Accessor to get var tensor data from
     * @param[in] gamma   (Optional) Accessor to get gamma tensor data from. Default: nullptr
     * @param[in] beta    (Optional) Accessor to get beta tensor data from. Default: nullptr
     * @param[in] epsilon (Optional) Epsilon value. Default: 0.001
     */
    BatchNormalizationLayer(ITensorAccessorUPtr mean,
                            ITensorAccessorUPtr var,
                            ITensorAccessorUPtr gamma   = nullptr,
                            ITensorAccessorUPtr beta    = nullptr,
                            float               epsilon = 0.001f)
        : _mean(std::move(mean)), _var(std::move(var)), _gamma(std::move(gamma)), _beta(std::move(beta)), _epsilon(epsilon)
    {
    }

    NodeID create_layer(IStream &s) override
    {
        ARM_COMPUTE_ERROR_ON(_mean == nullptr);
        ARM_COMPUTE_ERROR_ON(_var == nullptr);

        NodeParams  common_params = { "", s.hints().target_hint };
        NodeIdxPair input         = { s.tail_node(), 0 };
        return GraphBuilder::add_batch_normalization_node(s.graph(), common_params, input, _epsilon,
                                                          std::move(_mean), std::move(_var), std::move(_beta), std::move(_gamma));
    }

private:
    ITensorAccessorUPtr _mean;
    ITensorAccessorUPtr _var;
    ITensorAccessorUPtr _gamma;
    ITensorAccessorUPtr _beta;
    float               _epsilon;
};

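/* Example: a 3x3 convolution producing 64 output feature maps, with stride 1
 * and 1-pixel padding (a sketch; the accessor helper and file names are
 * placeholders):
 *
 *   graph << ConvolutionLayer(3U, 3U, 64U,
 *                             get_weights_accessor("conv_w.npy"),
 *                             get_weights_accessor("conv_b.npy"),
 *                             PadStrideInfo(1, 1, 1, 1));
 */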
/** Convolution Layer */
class ConvolutionLayer final : public ILayer
{
public:
    /** Construct a convolution layer
     *
     * @param[in] conv_width  Convolution width
     * @param[in] conv_height Convolution height
     * @param[in] ofm         Number of output feature maps
     * @param[in] weights     Accessor to get kernel weights from
     * @param[in] bias        Accessor to get kernel bias from
     * @param[in] conv_info   Padding and stride information
     * @param[in] num_groups  (Optional) Number of groups. Default: 1. Currently not forwarded to the builder
     */
    ConvolutionLayer(unsigned int        conv_width,
                     unsigned int        conv_height,
                     unsigned int        ofm,
                     ITensorAccessorUPtr weights,
                     ITensorAccessorUPtr bias,
                     PadStrideInfo       conv_info,
                     unsigned int        num_groups = 1)
        : _conv_width(conv_width),
          _conv_height(conv_height),
          _ofm(ofm),
          _conv_info(std::move(conv_info)),
          _num_groups(num_groups),
          _weights(std::move(weights)),
          _bias(std::move(bias))
    {
    }

    NodeID create_layer(IStream &s) override
    {
        // Grouped convolution is not forwarded to the builder yet
        ARM_COMPUTE_UNUSED(_num_groups);
        NodeIdxPair input         = { s.tail_node(), 0 };
        NodeParams  common_params = { "", s.hints().target_hint };
        return GraphBuilder::add_convolution_node(s.graph(), common_params, input,
                                                  Size2D(_conv_width, _conv_height), _ofm, _conv_info,
                                                  s.hints().convolution_method_hint,
                                                  std::move(_weights), std::move(_bias));
    }

private:
    unsigned int        _conv_width;
    unsigned int        _conv_height;
    unsigned int        _ofm;
    const PadStrideInfo _conv_info;
    unsigned int        _num_groups;
    ITensorAccessorUPtr _weights;
    ITensorAccessorUPtr _bias;
};

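/* Example: a 3x3 depthwise convolution with stride 1 and 1-pixel padding
 * (a sketch; the accessor helper and file names are placeholders):
 *
 *   graph << DepthwiseConvolutionLayer(3U, 3U,
 *                                      get_weights_accessor("dwc_w.npy"),
 *                                      get_weights_accessor("dwc_b.npy"),
 *                                      PadStrideInfo(1, 1, 1, 1));
 */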
/** Depthwise Convolution Layer */
class DepthwiseConvolutionLayer final : public ILayer
{
public:
    /** Construct a depthwise convolution layer
     *
     * @param[in] conv_width  Convolution width
     * @param[in] conv_height Convolution height
     * @param[in] weights     Accessor to get kernel weights from
     * @param[in] bias        Accessor to get kernel bias from
     * @param[in] conv_info   Padding and stride information
     */
    DepthwiseConvolutionLayer(unsigned int        conv_width,
                              unsigned int        conv_height,
                              ITensorAccessorUPtr weights,
                              ITensorAccessorUPtr bias,
                              PadStrideInfo       conv_info)
        : _conv_width(conv_width),
          _conv_height(conv_height),
          _conv_info(std::move(conv_info)),
          _weights(std::move(weights)),
          _bias(std::move(bias))
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeIdxPair input         = { s.tail_node(), 0 };
        NodeParams  common_params = { "", s.hints().target_hint };
        return GraphBuilder::add_depthwise_convolution_node(s.graph(), common_params,
                                                            input, Size2D(_conv_width, _conv_height), _conv_info,
                                                            s.hints().depthwise_convolution_method_hint,
                                                            std::move(_weights), std::move(_bias));
    }

private:
    unsigned int        _conv_width;
    unsigned int        _conv_height;
    const PadStrideInfo _conv_info;
    ITensorAccessorUPtr _weights;
    ITensorAccessorUPtr _bias;
};

/** Flatten Layer */
class FlattenLayer final : public ILayer
{
public:
    FlattenLayer()
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeParams  common_params = { "", s.hints().target_hint };
        NodeIdxPair input         = { s.tail_node(), 0 };
        return GraphBuilder::add_flatten_node(s.graph(), common_params, input);
    }
};

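/* Example: a fully connected layer with 1000 outputs (a sketch; the accessor
 * helper and file names are placeholders):
 *
 *   graph << FullyConnectedLayer(1000U,
 *                                get_weights_accessor("fc_w.npy"),
 *                                get_weights_accessor("fc_b.npy"));
 */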
/** Fully Connected Layer */
class FullyConnectedLayer final : public ILayer
{
public:
    /** Construct a fully connected layer
     *
     * @param[in] num_outputs Number of outputs
     * @param[in] weights     Accessor to get weights from
     * @param[in] bias        Accessor to get bias from
     */
    FullyConnectedLayer(unsigned int        num_outputs,
                        ITensorAccessorUPtr weights,
                        ITensorAccessorUPtr bias)
        : _num_outputs(num_outputs), _weights(std::move(weights)), _bias(std::move(bias))
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeParams  common_params = { "", s.hints().target_hint };
        NodeIdxPair input         = { s.tail_node(), 0 };
        return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
                                                       std::move(_weights), std::move(_bias));
    }

private:
    unsigned int        _num_outputs;
    ITensorAccessorUPtr _weights;
    ITensorAccessorUPtr _bias;
};

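/* Example: cross-map local response normalization over 5 neighbouring maps
 * (a sketch; the parameter values are illustrative):
 *
 *   graph << NormalizationLayer(NormalizationLayerInfo(NormType::CROSS_MAP, 5, 0.0001f, 0.75f));
 */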
/** Normalization Layer */
class NormalizationLayer final : public ILayer
{
public:
    /** Construct a normalization layer
     *
     * @param[in] norm_info Normalization information
     */
    NormalizationLayer(NormalizationLayerInfo norm_info)
        : _norm_info(norm_info)
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeParams  common_params = { "", s.hints().target_hint };
        NodeIdxPair input         = { s.tail_node(), 0 };
        return GraphBuilder::add_normalization_node(s.graph(), common_params, input, _norm_info);
    }

private:
    NormalizationLayerInfo _norm_info;
};

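/* Example: 3x3 max pooling with stride 2 and no padding (a sketch; the
 * parameter values are illustrative):
 *
 *   graph << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(2, 2, 0, 0)));
 */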
/** Pooling Layer */
class PoolingLayer final : public ILayer
{
public:
    /** Construct a pooling layer
     *
     * @param[in] pool_info Pooling information
     */
    PoolingLayer(PoolingLayerInfo pool_info)
        : _pool_info(pool_info)
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeParams  common_params = { "", s.hints().target_hint };
        NodeIdxPair input         = { s.tail_node(), 0 };
        return GraphBuilder::add_pooling_node(s.graph(), common_params, input, _pool_info);
    }

private:
    PoolingLayerInfo _pool_info;
};

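/* Example: collapse a 7x7x512 activation into a single dimension ahead of a
 * fully connected layer (a sketch; the shape is illustrative):
 *
 *   graph << ReshapeLayer(TensorShape(7U * 7U * 512U));
 */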
/** Reshape Layer */
class ReshapeLayer final : public ILayer
{
public:
    /** Construct a reshape layer
     *
     * @param[in] shape Target shape
     */
    ReshapeLayer(TensorShape shape)
        : _shape(shape)
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeParams  common_params = { "", s.hints().target_hint };
        NodeIdxPair input         = { s.tail_node(), 0 };
        return GraphBuilder::add_reshape_node(s.graph(), common_params, input, _shape);
    }

private:
    TensorShape _shape;
};

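/* Example (a sketch; beta defaults to 1.0f):
 *
 *   graph << SoftmaxLayer();
 */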
/** Softmax Layer */
class SoftmaxLayer final : public ILayer
{
public:
    /** Construct a softmax layer
     *
     * @param[in] beta (Optional) Beta value. Default: 1.0
     */
    SoftmaxLayer(float beta = 1.0f)
        : _beta(beta)
    {
    }

    NodeID create_layer(IStream &s) override
    {
        NodeParams  common_params = { "", s.hints().target_hint };
        NodeIdxPair input         = { s.tail_node(), 0 };
        return GraphBuilder::add_softmax_node(s.graph(), common_params, input, _beta);
    }

private:
    float _beta;
};

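/* Example: two parallel branches merged by depth concatenation, as in an
 * Inception-style block (a sketch; it assumes a Stream named graph, and the
 * accessor helper and layer parameters are placeholders):
 *
 *   SubStream branch_a(graph);
 *   branch_a << ConvolutionLayer(1U, 1U, 64U, get_weights_accessor("a_w.npy"),
 *                                get_weights_accessor("a_b.npy"), PadStrideInfo(1, 1, 0, 0));
 *
 *   SubStream branch_b(graph);
 *   branch_b << ConvolutionLayer(3U, 3U, 64U, get_weights_accessor("b_w.npy"),
 *                                get_weights_accessor("b_b.npy"), PadStrideInfo(1, 1, 1, 1));
 *
 *   graph << BranchLayer(BranchMergeMethod::DEPTH_CONCATENATE,
 *                        std::move(branch_a), std::move(branch_b));
 */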
/** Branch Layer */
class BranchLayer final : public ILayer
{
public:
    /** Constructor
     *
     * @param[in] merge_method     Branch merging method
     * @param[in] sub_stream1      First graph branch
     * @param[in] sub_stream2      Second graph branch
     * @param[in] rest_sub_streams Rest sub-graph branches
     */
    template <typename... Ts>
    BranchLayer(BranchMergeMethod merge_method, SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams)
        : _branch_merge_method(merge_method), _sub_streams()
    {
        _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream1)));
        _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream2)));

        utility::for_each([&](SubStream && sub_stream)
        {
            _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream)));
        },
        std::move(rest_sub_streams)...);
    }
    /** Constructor
     *
     * @param[in] sub_stream Sub-stream
     */
    BranchLayer(SubStream &&sub_stream)
        : _branch_merge_method(BranchMergeMethod::DEPTH_CONCATENATE), _sub_streams()
    {
        _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream)));
    }
    NodeID create_layer(IStream &s) override
    {
        NodeID     nid           = EmptyNodeID;
        NodeParams common_params = { "", s.hints().target_hint };
        if(_sub_streams.size() == 1 && _sub_streams.at(0) != nullptr)
        {
            nid = _sub_streams[0]->tail_node();
        }
        else if(_branch_merge_method == BranchMergeMethod::DEPTH_CONCATENATE)
        {
            // Collect tail nodes and perform DepthConcatenate
            std::vector<NodeIdxPair> nodes;
            for(auto &ss : _sub_streams)
            {
                if(ss && (ss->tail_node() != EmptyNodeID))
                {
                    const auto tail_node = s.graph().node(ss->tail_node());
                    if(tail_node != nullptr && tail_node->type() != NodeType::Output)
                    {
                        nodes.push_back({ ss->tail_node(), 0 });
                    }
                }
            }
            nid = GraphBuilder::add_depth_concatenate_node(s.graph(), common_params, nodes);
        }
        else
        {
            // Element-wise addition requires exactly two branches
            ARM_COMPUTE_ERROR_ON(_sub_streams.size() != 2);
            NodeIdxPair input0 = { _sub_streams[0]->tail_node(), 0 };
            NodeIdxPair input1 = { _sub_streams[1]->tail_node(), 0 };
            nid = GraphBuilder::add_elementwise_node(s.graph(), common_params, input0, input1, EltwiseOperation::ADD);
        }
        return nid;
    }

private:
    BranchMergeMethod                       _branch_merge_method;
    std::vector<std::unique_ptr<SubStream>> _sub_streams;
};
} // namespace frontend
} // namespace graph2
} // namespace arm_compute
#endif /* __ARM_COMPUTE_GRAPH2_LAYERS_H__ */