blob: 18061909716d07c9ae8847478f0d300b759370b0 [file] [log] [blame]
/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef __ARM_COMPUTE_GRAPH_CONVOLUTION_LAYER_H__
25#define __ARM_COMPUTE_GRAPH_CONVOLUTION_LAYER_H__
26
Georgios Pinitasff421f22017-10-04 16:53:58 +010027#include "arm_compute/graph/GraphContext.h"
Anthony Barbier2a07e182017-08-04 18:20:27 +010028#include "arm_compute/graph/INode.h"
Georgios Pinitase2c82fe2017-10-02 18:51:47 +010029#include "arm_compute/graph/ITensorObject.h"
Georgios Pinitas6f669f02017-09-26 12:32:57 +010030#include "arm_compute/graph/SubTensor.h"
Anthony Barbier2a07e182017-08-04 18:20:27 +010031#include "arm_compute/graph/Tensor.h"
32#include "arm_compute/graph/Types.h"
Georgios Pinitas6f669f02017-09-26 12:32:57 +010033#include "arm_compute/runtime/IFunction.h"
34
35#include <memory>
Anthony Barbier2a07e182017-08-04 18:20:27 +010036
37namespace arm_compute
38{
39namespace graph
40{
41/** Convolution layer node */
Georgios Pinitas6f669f02017-09-26 12:32:57 +010042class ConvolutionLayer final : public INode
Anthony Barbier2a07e182017-08-04 18:20:27 +010043{
44public:
45 /** Default Constructor
46 *
Giorgio Arenaa66eaa22017-12-21 19:50:06 +000047 * @param[in] conv_width Convolution width
48 * @param[in] conv_height Convolution height
49 * @param[in] ofm Output feature map
50 * @param[in] weights Weights of the convolution layer
51 * @param[in] biases Bias of the convolution layer
52 * @param[in] conv_info Convolution information
53 * @param[in] num_groups (Optional) Number of groups, default = 1
54 * @param[in] weights_info (Optional) Weights information
55 * @param[in] weights_quant_info (Optional) Weights quantization information
56 * @param[in] out_quant_info (Optional) Output quantization info
Anthony Barbier2a07e182017-08-04 18:20:27 +010057 */
58 template <typename AccessorTypeWeights, typename AccessorTypeBiases>
Giorgio Arenaa66eaa22017-12-21 19:50:06 +000059 ConvolutionLayer(unsigned int conv_width,
60 unsigned int conv_height,
61 unsigned int ofm,
Georgios Pinitas6f669f02017-09-26 12:32:57 +010062 AccessorTypeWeights &&weights,
Giorgio Arenaa66eaa22017-12-21 19:50:06 +000063 AccessorTypeBiases &&biases,
64 const PadStrideInfo conv_info,
65 unsigned int num_groups = 1,
66 const WeightsInfo weights_info = WeightsInfo(),
67 const QuantizationInfo weights_quant_info = QuantizationInfo(),
68 const QuantizationInfo out_quant_info = QuantizationInfo())
Georgios Pinitas6f669f02017-09-26 12:32:57 +010069 : _conv_width(conv_width),
70 _conv_height(conv_height),
71 _ofm(ofm),
72 _weights(std::move(weights)),
73 _biases(std::move(biases)),
74 _conv_info(std::move(conv_info)),
75 _num_groups(num_groups),
76 _weights_info(std::move(weights_info)),
Giorgio Arenaa66eaa22017-12-21 19:50:06 +000077 _weights_quant_info(std::move(weights_quant_info)),
78 _out_quant_info(std::move(out_quant_info)),
Georgios Pinitas6f669f02017-09-26 12:32:57 +010079 _is(nullptr),
80 _os(nullptr),
81 _ws(nullptr),
82 _bs(nullptr)
Anthony Barbier2a07e182017-08-04 18:20:27 +010083 {
84 }
85
86 // Inherited methods overriden:
Georgios Pinitase2c82fe2017-10-02 18:51:47 +010087 std::unique_ptr<arm_compute::IFunction> instantiate_node(GraphContext &ctx, ITensorObject *input, ITensorObject *output) override;
Anthony Barbier2a07e182017-08-04 18:20:27 +010088
89private:
Georgios Pinitas6f669f02017-09-26 12:32:57 +010090 /** Instantiates a non-grouped convolution
91 *
Michalis Spyroue4720822017-10-02 17:44:52 +010092 * @param[in] input Input tensor
93 * @param[in] output Output tensor
Georgios Pinitas6f669f02017-09-26 12:32:57 +010094 * @param[in] conv_method_hint Hint that specifies which convolution layer method to use
95 *
96 * @return Convolution function
97 */
Michalis Spyroue4720822017-10-02 17:44:52 +010098 std::unique_ptr<arm_compute::IFunction> instantiate_convolution(ITensor *input, ITensor *output, ConvolutionMethodHint conv_method_hint);
Georgios Pinitas6f669f02017-09-26 12:32:57 +010099 /** Instantiates a grouped convolution
100 *
Michalis Spyroue4720822017-10-02 17:44:52 +0100101 * @param[in] input Input tensor
102 * @param[in] output Output tensor
Georgios Pinitas6f669f02017-09-26 12:32:57 +0100103 * @param[in] conv_method_hint Hint that specifies which convolution layer method to use
104 *
105 * @return Grouped Convolution function
106 */
Michalis Spyroue4720822017-10-02 17:44:52 +0100107 std::unique_ptr<arm_compute::IFunction> instantiate_grouped_convolution(ITensor *input, ITensor *output, ConvolutionMethodHint conv_method_hint);
Georgios Pinitas6f669f02017-09-26 12:32:57 +0100108
109private:
Giorgio Arenaa66eaa22017-12-21 19:50:06 +0000110 unsigned int _conv_width; /**< Convolution width */
111 unsigned int _conv_height; /**< Convolution height */
112 unsigned int _ofm; /**< Output feature maps */
113 Tensor _weights; /**< Weights tensor */
114 Tensor _biases; /**< Biases tensor */
115 const PadStrideInfo _conv_info; /**< Convolution layer information */
116 unsigned int _num_groups; /**< Number of groups */
117 const WeightsInfo _weights_info; /**< Convolution layer weights information */
118 const QuantizationInfo _weights_quant_info; /**< Output quantization information */
119 const QuantizationInfo _out_quant_info; /**< Output quantization information */
Georgios Pinitas6f669f02017-09-26 12:32:57 +0100120
121 std::unique_ptr<SubTensor[]> _is; /**< Input tensor sub-tensors used for grouped convolution */
122 std::unique_ptr<SubTensor[]> _os; /**< Output tensor sub-tensors used for grouped convolution */
123 std::unique_ptr<SubTensor[]> _ws; /**< Weights tensor sub-tensors used for grouped convolution */
124 std::unique_ptr<SubTensor[]> _bs; /**< Biases tensor sub-tensors used for grouped convolution */
Anthony Barbier2a07e182017-08-04 18:20:27 +0100125};
126} // namespace graph
127} // namespace arm_compute
128#endif /* __ARM_COMPUTE_GRAPH_CONVOLUTION_LAYER_H__ */