/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010024#ifndef __ARM_COMPUTE_GRAPH_LAYERS_H__
25#define __ARM_COMPUTE_GRAPH_LAYERS_H__
Georgios Pinitasd8734b52017-12-22 15:27:52 +000026
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010027#include "arm_compute/graph/GraphBuilder.h"
28#include "arm_compute/graph/Types.h"
29#include "arm_compute/graph/frontend/ILayer.h"
30#include "arm_compute/graph/frontend/IStream.h"
31#include "arm_compute/graph/frontend/SubStream.h"
Georgios Pinitasd8734b52017-12-22 15:27:52 +000032
33#include "arm_compute/core/utils/misc/Utility.h"
34
35#include <memory>
36#include <string>
37
38namespace arm_compute
39{
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010040namespace graph
Georgios Pinitasd8734b52017-12-22 15:27:52 +000041{
42namespace frontend
43{
44/** Input Layer */
45class InputLayer final : public ILayer
46{
47public:
Alex Gildayc357c472018-03-21 13:54:09 +000048 /** Construct an input layer.
49 *
50 * @param[in] desc Description of input tensor.
51 * @param[in] accessor Accessor to get input tensor data from.
52 */
Georgios Pinitasd8734b52017-12-22 15:27:52 +000053 InputLayer(TensorDescriptor desc, ITensorAccessorUPtr accessor)
54 : _desc(desc), _accessor(std::move(accessor))
55 {
56 }
57
58 NodeID create_layer(IStream &s) override
59 {
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +010060 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +000061 return GraphBuilder::add_input_node(s.graph(), common_params, _desc, std::move(_accessor));
62 }
63
64private:
65 TensorDescriptor _desc;
66 ITensorAccessorUPtr _accessor;
67};
68
69/** Output Layer */
70class OutputLayer final : public ILayer
71{
72public:
Alex Gildayc357c472018-03-21 13:54:09 +000073 /** Construct an output layer.
74 *
75 * @param[in] accessor Accessor to give output tensor data to.
76 */
Georgios Pinitasd8734b52017-12-22 15:27:52 +000077 OutputLayer(ITensorAccessorUPtr accessor)
78 : _accessor(std::move(accessor))
79 {
80 }
81
82 NodeID create_layer(IStream &s) override
83 {
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +010084 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +000085 NodeIdxPair input = { s.tail_node(), 0 };
86 return GraphBuilder::add_output_node(s.graph(), common_params, input, std::move(_accessor));
87 }
88
89private:
90 ITensorAccessorUPtr _accessor;
91};
92
93/** Activation Layer */
94class ActivationLayer final : public ILayer
95{
96public:
Alex Gildayc357c472018-03-21 13:54:09 +000097 /** Construct an activation layer.
98 *
99 * @param[in] act_info Activation information
100 */
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000101 ActivationLayer(ActivationLayerInfo act_info)
102 : _act_info(act_info)
103 {
104 }
105
106 NodeID create_layer(IStream &s) override
107 {
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100108 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000109 NodeIdxPair input = { s.tail_node(), 0 };
110 return GraphBuilder::add_activation_node(s.graph(), common_params, input, _act_info);
111 }
112
113private:
114 ActivationLayerInfo _act_info;
115};
116
117/** Batchnormalization Layer */
118class BatchNormalizationLayer final : public ILayer
119{
120public:
Alex Gildayc357c472018-03-21 13:54:09 +0000121 /** Construct a batch normalization layer.
122 *
123 * @param[in] mean Accessor to get mean tensor data from.
124 * @param[in] var Accessor to get var tensor data from.
125 * @param[in] gamma (Optional) Accessor to get gamma tensor data from. Default: nullptr.
126 * @param[in] beta (Optional) Accessor to get beta tensor data from. Default: nullptr.
127 * @param[in] epsilon (Optional) Epsilon value. Default: 0.001.
128 */
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000129 BatchNormalizationLayer(ITensorAccessorUPtr mean,
130 ITensorAccessorUPtr var,
131 ITensorAccessorUPtr gamma = nullptr,
132 ITensorAccessorUPtr beta = nullptr,
133 float epsilon = 0.001f)
134 : _mean(std::move(mean)), _var(std::move(var)), _gamma(std::move(gamma)), _beta(std::move(beta)), _epsilon(epsilon)
135 {
136 }
137
138 NodeID create_layer(IStream &s) override
139 {
140 ARM_COMPUTE_ERROR_ON(_mean == nullptr);
141 ARM_COMPUTE_ERROR_ON(_var == nullptr);
142
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100143 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000144 NodeIdxPair input = { s.tail_node(), 0 };
145 return GraphBuilder::add_batch_normalization_node(s.graph(), common_params, input, _epsilon,
146 std::move(_mean), std::move(_var), std::move(_beta), std::move(_gamma));
147 }
148
149private:
150 ITensorAccessorUPtr _mean;
151 ITensorAccessorUPtr _var;
152 ITensorAccessorUPtr _gamma;
153 ITensorAccessorUPtr _beta;
154 float _epsilon;
155};
156
Georgios Pinitas087eaf62018-05-16 15:52:35 +0100157/** Channel Shuffle Layer */
158class ChannelShuffleLayer final : public ILayer
159{
160public:
161 /** Construct a Channel Shuffle layer.
162 *
163 * @param[in] num_groups Number of groups
164 */
165 ChannelShuffleLayer(unsigned int num_groups)
166 : _num_groups(num_groups)
167 {
168 }
169
170 NodeID create_layer(IStream &s) override
171 {
172 NodeParams common_params = { name(), s.hints().target_hint };
173 NodeIdxPair input = { s.tail_node(), 0 };
174 return GraphBuilder::add_channel_shuffle_node(s.graph(), common_params, input, _num_groups);
175 }
176
177private:
178 unsigned int _num_groups;
179};
180
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000181/** Convolution Layer */
182class ConvolutionLayer final : public ILayer
183{
184public:
Alex Gildayc357c472018-03-21 13:54:09 +0000185 /** Construct a convolution layer.
186 *
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100187 * @param[in] conv_width Convolution width.
188 * @param[in] conv_height Convolution height.
189 * @param[in] ofm Output feature map.
190 * @param[in] weights Accessor to get kernel weights from.
191 * @param[in] bias Accessor to get kernel bias from.
192 * @param[in] conv_info Padding and stride information.
193 * @param[in] num_groups (Optional) Number of groups. Default: 1.
194 * @param[in] weights_quant_info (Optional) Weights quantization information
195 * @param[in] out_quant_info (Optional) Output quantization info
Alex Gildayc357c472018-03-21 13:54:09 +0000196 */
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100197 ConvolutionLayer(unsigned int conv_width,
198 unsigned int conv_height,
199 unsigned int ofm,
200 ITensorAccessorUPtr weights,
201 ITensorAccessorUPtr bias,
202 PadStrideInfo conv_info,
203 unsigned int num_groups = 1,
204 const QuantizationInfo weights_quant_info = QuantizationInfo(),
205 const QuantizationInfo out_quant_info = QuantizationInfo())
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000206 : _conv_width(conv_width),
207 _conv_height(conv_height),
208 _ofm(ofm),
209 _conv_info(std::move(conv_info)),
210 _num_groups(num_groups),
211 _weights(std::move(weights)),
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100212 _bias(std::move(bias)),
213 _weights_quant_info(std::move(weights_quant_info)),
214 _out_quant_info(std::move(out_quant_info))
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000215 {
216 }
217
218 NodeID create_layer(IStream &s) override
219 {
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000220 NodeIdxPair input = { s.tail_node(), 0 };
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100221 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000222 return GraphBuilder::add_convolution_node(s.graph(), common_params, input,
Georgios Pinitasee33ea52018-03-08 16:01:29 +0000223 Size2D(_conv_width, _conv_height), _ofm, _conv_info, _num_groups,
Giorgio Arena59631a12018-05-02 13:59:04 +0100224 s.hints().convolution_method_hint, s.hints().fast_math_hint,
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100225 std::move(_weights), std::move(_bias), std::move(_weights_quant_info), std::move(_out_quant_info));
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000226 }
227
228private:
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100229 unsigned int _conv_width;
230 unsigned int _conv_height;
231 unsigned int _ofm;
232 const PadStrideInfo _conv_info;
233 unsigned int _num_groups;
234 ITensorAccessorUPtr _weights;
235 ITensorAccessorUPtr _bias;
236 const QuantizationInfo _weights_quant_info;
237 const QuantizationInfo _out_quant_info;
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000238};
239
Georgios Pinitas087eaf62018-05-16 15:52:35 +0100240/** Deconvolution Layer */
241class DeconvolutionLayer final : public ILayer
242{
243public:
244 /** Construct a convolution layer.
245 *
246 * @param[in] conv_width Convolution width.
247 * @param[in] conv_height Convolution height.
248 * @param[in] ofm Output feature map.
249 * @param[in] weights Accessor to get kernel weights from.
250 * @param[in] bias Accessor to get kernel bias from.
251 * @param[in] deconv_info Padding and stride information.
252 * @param[in] inner_border Inner border padding (right, top)
253 */
254 DeconvolutionLayer(unsigned int conv_width,
255 unsigned int conv_height,
256 unsigned int ofm,
257 ITensorAccessorUPtr weights,
258 ITensorAccessorUPtr bias,
259 PadStrideInfo deconv_info,
260 Size2D inner_border)
261 : _conv_width(conv_width),
262 _conv_height(conv_height),
263 _ofm(ofm),
264 _deconv_info(std::move(deconv_info)),
265 _inner_border(inner_border),
266 _weights(std::move(weights)),
267 _bias(std::move(bias))
268 {
269 }
270
271 NodeID create_layer(IStream &s) override
272 {
273 NodeIdxPair input = { s.tail_node(), 0 };
274 NodeParams common_params = { name(), s.hints().target_hint };
275 return GraphBuilder::add_deconvolution_node(s.graph(), common_params, input,
276 Size2D(_conv_width, _conv_height), _ofm, _deconv_info, _inner_border,
277 std::move(_weights), std::move(_bias));
278 }
279
280private:
281 unsigned int _conv_width;
282 unsigned int _conv_height;
283 unsigned int _ofm;
284 const PadStrideInfo _deconv_info;
285 Size2D _inner_border;
286 ITensorAccessorUPtr _weights;
287 ITensorAccessorUPtr _bias;
288};
289
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000290/** Depthwise Convolution Layer */
291class DepthwiseConvolutionLayer final : public ILayer
292{
293public:
Alex Gildayc357c472018-03-21 13:54:09 +0000294 /** Construct a depthwise convolution layer.
295 *
296 * @param[in] conv_width Convolution width.
297 * @param[in] conv_height Convolution height.
298 * @param[in] weights Accessor to get kernel weights from.
299 * @param[in] bias Accessor to get kernel bias from.
300 * @param[in] conv_info Padding and stride information.
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100301 * @param[in] quant_info (Optional) Quantization info used for weights
Alex Gildayc357c472018-03-21 13:54:09 +0000302 */
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100303 DepthwiseConvolutionLayer(unsigned int conv_width,
304 unsigned int conv_height,
305 ITensorAccessorUPtr weights,
306 ITensorAccessorUPtr bias,
307 PadStrideInfo conv_info,
308 const QuantizationInfo quant_info = QuantizationInfo())
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000309 : _conv_width(conv_width),
310 _conv_height(conv_height),
311 _conv_info(std::move(conv_info)),
312 _weights(std::move(weights)),
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100313 _bias(std::move(bias)),
314 _quant_info(std::move(quant_info))
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000315 {
316 }
317
318 NodeID create_layer(IStream &s) override
319 {
320 NodeIdxPair input = { s.tail_node(), 0 };
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100321 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000322 return GraphBuilder::add_depthwise_convolution_node(s.graph(), common_params,
323 input, Size2D(_conv_width, _conv_height), _conv_info,
324 s.hints().depthwise_convolution_method_hint,
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100325 std::move(_weights), std::move(_bias), std::move(_quant_info));
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000326 }
327
328private:
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100329 unsigned int _conv_width;
330 unsigned int _conv_height;
331 const PadStrideInfo _conv_info;
332 ITensorAccessorUPtr _weights;
333 ITensorAccessorUPtr _bias;
334 const QuantizationInfo _quant_info;
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000335};
336
Georgios Pinitas087eaf62018-05-16 15:52:35 +0100337/** Dummy Layer */
338class DummyLayer final : public ILayer
339{
340public:
341 /** Construct an input layer.
342 *
343 * @param[in] shape Output shape
344 */
345 DummyLayer(TensorShape shape)
346 : _shape(shape)
347 {
348 }
349
350 NodeID create_layer(IStream &s) override
351 {
352 NodeParams common_params = { name(), s.hints().target_hint };
353 NodeIdxPair input = { s.tail_node(), 0 };
354 return GraphBuilder::add_dummy_node(s.graph(), common_params, input, _shape);
355 }
356
357private:
358 TensorShape _shape;
359};
360
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000361/** Flatten Layer */
362class FlattenLayer final : public ILayer
363{
364public:
Alex Gildayc357c472018-03-21 13:54:09 +0000365 /** Construct a flatten layer. */
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000366 FlattenLayer()
367 {
368 }
369
370 NodeID create_layer(IStream &s) override
371 {
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100372 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000373 NodeIdxPair input = { s.tail_node(), 0 };
374 return GraphBuilder::add_flatten_node(s.graph(), common_params, input);
375 }
376};
377
378/** Fully Connected Layer */
379class FullyConnectedLayer final : public ILayer
380{
381public:
Alex Gildayc357c472018-03-21 13:54:09 +0000382 /** Construct a fully connected layer.
383 *
384 * @param[in] num_outputs Number of outputs.
385 * @param[in] weights Accessor to get weights from.
386 * @param[in] bias Accessor to get bias from.
387 */
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000388 FullyConnectedLayer(unsigned int num_outputs,
389 ITensorAccessorUPtr weights,
390 ITensorAccessorUPtr bias)
391 : _num_outputs(num_outputs), _weights(std::move(weights)), _bias(std::move(bias))
392 {
393 }
394
395 NodeID create_layer(IStream &s) override
396 {
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100397 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000398 NodeIdxPair input = { s.tail_node(), 0 };
399 return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
400 std::move(_weights), std::move(_bias));
401 }
402
403private:
404 unsigned int _num_outputs;
405 ITensorAccessorUPtr _weights;
406 ITensorAccessorUPtr _bias;
407};
408
409/** Normalization Layer */
410class NormalizationLayer final : public ILayer
411{
412public:
Alex Gildayc357c472018-03-21 13:54:09 +0000413 /** Construct a normalization layer.
414 *
415 * @param[in] norm_info Normalization information.
416 */
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000417 NormalizationLayer(NormalizationLayerInfo norm_info)
418 : _norm_info(norm_info)
419 {
420 }
421
422 NodeID create_layer(IStream &s) override
423 {
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100424 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000425 NodeIdxPair input = { s.tail_node(), 0 };
426 return GraphBuilder::add_normalization_node(s.graph(), common_params, input, _norm_info);
427 }
428
429private:
430 NormalizationLayerInfo _norm_info;
431};
432
433/** Pooling Layer */
434class PoolingLayer final : public ILayer
435{
436public:
Alex Gildayc357c472018-03-21 13:54:09 +0000437 /** Construct a pooling layer.
438 *
439 * @param[in] pool_info Pooling information.
440 */
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000441 PoolingLayer(PoolingLayerInfo pool_info)
442 : _pool_info(pool_info)
443 {
444 }
445
446 NodeID create_layer(IStream &s) override
447 {
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100448 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000449 NodeIdxPair input = { s.tail_node(), 0 };
450 return GraphBuilder::add_pooling_node(s.graph(), common_params, input, _pool_info);
451 }
452
453private:
454 PoolingLayerInfo _pool_info;
455};
456
457/** Reshape Layer */
458class ReshapeLayer final : public ILayer
459{
460public:
Alex Gildayc357c472018-03-21 13:54:09 +0000461 /** Construct a reshape layer.
462 *
463 * @param[in] shape Target shape.
464 */
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000465 ReshapeLayer(TensorShape shape)
466 : _shape(shape)
467 {
468 }
469
470 NodeID create_layer(IStream &s) override
471 {
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100472 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000473 NodeIdxPair input = { s.tail_node(), 0 };
474 return GraphBuilder::add_reshape_node(s.graph(), common_params, input, _shape);
475 }
476
477private:
478 TensorShape _shape;
479};
480
Georgios Pinitas087eaf62018-05-16 15:52:35 +0100481/** Resize Layer */
482class ResizeLayer final : public ILayer
483{
484public:
485 ResizeLayer(InterpolationPolicy policy, float width_scale, float height_scale)
486 : _policy(policy), _width_scale(width_scale), _height_scale(height_scale)
487 {
488 }
489
490 NodeID create_layer(IStream &s) override
491 {
492 NodeParams common_params = { name(), s.hints().target_hint };
493 NodeIdxPair input = { s.tail_node(), 0 };
494 return GraphBuilder::add_resize_node(s.graph(), common_params, input, _policy, _width_scale, _height_scale);
495 }
496
497private:
498 InterpolationPolicy _policy;
499 float _width_scale;
500 float _height_scale;
501};
502
Isabella Gottardi88d5b222018-04-06 12:24:55 +0100503/** Scale Layer */
504class ScaleLayer final : public ILayer
505{
506public:
507 /** Construct a scale layer.
508 *
509 * @param[in] mul_w Accessor to get mul weight from.
510 * @param[in] add_w Accessor to get add weight from.
511 */
512 ScaleLayer(ITensorAccessorUPtr mul_w,
513 ITensorAccessorUPtr add_w)
514 : _mul_w(std::move(mul_w)), _add_w(std::move(add_w))
515 {
516 }
517
518 NodeID create_layer(IStream &s) override
519 {
520 NodeParams common_params = { name(), s.hints().target_hint };
521 NodeIdxPair input = { s.tail_node(), 0 };
522 return GraphBuilder::add_scale_layer(s.graph(), common_params, input, std::move(_mul_w), std::move(_add_w));
523 }
524
525private:
526 ITensorAccessorUPtr _mul_w;
527 ITensorAccessorUPtr _add_w;
528};
529
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000530/** Softmax Layer */
531class SoftmaxLayer final : public ILayer
532{
533public:
Alex Gildayc357c472018-03-21 13:54:09 +0000534 /** Construct a softmax layer.
535 *
536 * @param[in] beta (Optional) Beta value. Default 1.0.
537 */
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000538 SoftmaxLayer(float beta = 1.0f)
539 : _beta(beta)
540 {
541 }
542
543 NodeID create_layer(IStream &s) override
544 {
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100545 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000546 NodeIdxPair input = { s.tail_node(), 0 };
547 return GraphBuilder::add_softmax_node(s.graph(), common_params, input, _beta);
548 }
549
550private:
551 float _beta;
552};
553
554/** Branch Layer */
555class BranchLayer final : public ILayer
556{
557public:
Alex Gildayc357c472018-03-21 13:54:09 +0000558 /** Construct a branch layer
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000559 *
560 * @param[in] merge_method Branch merging method
561 * @param[in] sub_stream1 First graph branch
562 * @param[in] sub_stream2 Second graph branch
563 * @param[in] rest_sub_streams Rest sub-graph branches
564 */
565 template <typename... Ts>
566 BranchLayer(BranchMergeMethod merge_method, SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams)
567 : _branch_merge_method(merge_method), _sub_streams()
568 {
569 _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream1)));
570 _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream2)));
571
572 utility::for_each([&](SubStream && sub_stream)
573 {
574 _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream)));
575 },
576 std::move(rest_sub_streams)...);
577 }
Alex Gildayc357c472018-03-21 13:54:09 +0000578 /** Construct a branch layer
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000579 *
580 * @param[in] sub_stream Sub-stream
581 */
582 template <typename... Ts>
583 BranchLayer(SubStream &&sub_stream)
584 : _branch_merge_method(BranchMergeMethod::DEPTH_CONCATENATE), _sub_streams()
585 {
586 _sub_streams.push_back(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream)));
587 }
588 NodeID create_layer(IStream &s) override
589 {
590 NodeID nid = EmptyNodeID;
Georgios Pinitas5c2fb3f2018-05-01 15:26:20 +0100591 NodeParams common_params = { name(), s.hints().target_hint };
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000592 if(_sub_streams.size() == 1 && _sub_streams.at(0) != nullptr)
593 {
594 nid = _sub_streams[0]->tail_node();
595 }
596 else if(_branch_merge_method == BranchMergeMethod::DEPTH_CONCATENATE)
597 {
Georgios Pinitase2220552018-07-20 13:23:44 +0100598 // Collect tail nodes and concatenate
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000599 std::vector<NodeIdxPair> nodes;
600 for(auto &ss : _sub_streams)
601 {
602 if(ss && (ss->tail_node() != EmptyNodeID))
603 {
604 const auto tail_node = s.graph().node(ss->tail_node());
605 if(tail_node != nullptr && tail_node->type() != NodeType::Output)
606 {
607 nodes.push_back({ ss->tail_node(), 0 });
608 }
609 }
610 }
Georgios Pinitase2220552018-07-20 13:23:44 +0100611 nid = GraphBuilder::add_concatenate_node(s.graph(), common_params, nodes, DataLayoutDimension::CHANNEL);
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000612 }
613 else
614 {
615 ARM_COMPUTE_ERROR_ON(_sub_streams.size() != 2);
616 NodeIdxPair input0 = { _sub_streams[0]->tail_node(), 0 };
617 NodeIdxPair input1 = { _sub_streams[1]->tail_node(), 0 };
Georgios Pinitase2220552018-07-20 13:23:44 +0100618 nid = GraphBuilder::add_elementwise_node(s.graph(), common_params, input0, input1, EltwiseOperation::Add);
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000619 }
620 return nid;
621 }
622
623private:
624 BranchMergeMethod _branch_merge_method;
625 std::vector<std::unique_ptr<SubStream>> _sub_streams;
626};
627} // namespace frontend
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100628} // namespace graph
Georgios Pinitasd8734b52017-12-22 15:27:52 +0000629} // namespace arm_compute
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100630#endif /* __ARM_COMPUTE_GRAPH_LAYERS_H__ */