blob: f8494a872f865d1d3fa6245cc786e2329dc0d117 [file] [log] [blame]
Georgios Pinitas2a2db592018-08-15 12:14:46 +01001/*
Matthew Bentham758b5ba2020-03-05 23:37:48 +00002 * Copyright (c) 2018-2020 ARM Limited.
Georgios Pinitas2a2db592018-08-15 12:14:46 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/graph/mutators/GroupedConvolutionMutator.h"
25
26#include "arm_compute/graph/Graph.h"
27#include "arm_compute/graph/GraphBuilder.h"
28#include "arm_compute/graph/Logger.h"
29#include "arm_compute/graph/Utils.h"
30#include "arm_compute/graph/backends/BackendRegistry.h"
31#include "arm_compute/graph/nodes/Nodes.h"
32
33#include "arm_compute/core/utils/misc/Cast.h"
34
Matthew Bentham758b5ba2020-03-05 23:37:48 +000035#include "support/StringSupport.h"
36
Georgios Pinitas2a2db592018-08-15 12:14:46 +010037#include <set>
38
39namespace arm_compute
40{
41namespace graph
42{
43namespace
44{
45NodeID create_grouped_convolution(Graph &g, const NodeParams &params, NodeIdxPair input, NodeID weights, NodeID bias,
Georgios Pinitas1c32bf3962018-11-12 18:36:19 +000046 PadStrideInfo conv_info, ConvolutionMethod method, ActivationLayerInfo fused_act, FastMathHint fast_math_hint, unsigned int num_groups)
Georgios Pinitas2a2db592018-08-15 12:14:46 +010047{
48 bool has_bias = (bias != EmptyNodeID);
49
50 // Split input
51 const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
Georgios Pinitas9e4824c2019-04-12 13:15:58 +010052 const unsigned int input_idx = get_dimension_idx(input_tensor_desc.layout, DataLayoutDimension::CHANNEL);
Georgios Pinitas2a2db592018-08-15 12:14:46 +010053 NodeID input_split = GraphBuilder::add_split_node(g, params, input, num_groups, input_idx);
54
55 // Split weights
56 const TensorDescriptor weights_tensor_desc = get_tensor_descriptor(g, g.node(weights)->outputs()[0]);
Georgios Pinitas9e4824c2019-04-12 13:15:58 +010057 const unsigned int batch_idx = get_dimension_idx(weights_tensor_desc.layout, DataLayoutDimension::BATCHES);
Georgios Pinitas2a2db592018-08-15 12:14:46 +010058 NodeID weights_split = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, batch_idx);
59
60 // Split bias
61 NodeID bias_split = EmptyNodeID;
62 if(has_bias)
63 {
64 // Split bias
65 bias_split = GraphBuilder::add_split_node(g, params, { bias, 0 }, num_groups, 0);
66 }
67
68 std::vector<NodeIdxPair> convolution_outputs;
69 for(unsigned int i = 0; i < num_groups; ++i)
70 {
71 NodeParams group_params = params;
72 NodeID conv_nid = g.add_node<ConvolutionLayerNode>(conv_info, 1, method, fast_math_hint);
73 g.add_connection(input_split, i, conv_nid, 0);
74 g.add_connection(weights_split, i, conv_nid, 1);
75 if(has_bias)
76 {
77 g.add_connection(bias_split, i, conv_nid, 2);
78 }
79
80 // Add group name
81 if(!group_params.name.empty())
82 {
83 group_params.name.append("_g" + arm_compute::support::cpp11::to_string(i));
84 }
85
86 // Set node parameters
87 INode *node = g.node(conv_nid);
88 ARM_COMPUTE_ERROR_ON(node == nullptr);
89 node->set_common_node_parameters(group_params);
90
Georgios Pinitas1c32bf3962018-11-12 18:36:19 +000091 // Down-cast node
92 auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node);
93 conv_node->set_fused_activation(fused_act);
94
Georgios Pinitas2a2db592018-08-15 12:14:46 +010095 convolution_outputs.push_back({ conv_nid, 0 });
96 }
97
98 // Depth concatenate output
99 return GraphBuilder::add_concatenate_node(g, params, convolution_outputs, DataLayoutDimension::CHANNEL);
100}
101} // namespace
102
103const char *GroupedConvolutionMutator::name()
104{
105 return "GroupedConvolutionMutator";
106}
107
Georgios Pinitasf4261ad2019-12-02 11:58:19 +0000108IGraphMutator::MutationType GroupedConvolutionMutator::type() const
109{
110 return IGraphMutator::MutationType::Backend;
111}
112
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100113void GroupedConvolutionMutator::mutate(Graph &g)
114{
115 // Early exit if no Convolution layers exist in graph
116 if(g.nodes(NodeType::ConvolutionLayer).empty())
117 {
118 return;
119 }
120
121 // Total nodes
122 size_t total_nodes = g.nodes().size();
123
124 // Iterate over convolution nodes
125 for(unsigned int i = 0; i < total_nodes; ++i)
126 {
127 INode *node = g.node(i);
128 if(node != nullptr && node->type() == NodeType::ConvolutionLayer && arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node)->num_groups() != 1)
129 {
130 // Validate node
Anthony Barbier890ad1b2018-08-22 13:44:36 +0100131 backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target());
132 Status status = backend.validate_node(*node);
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100133
134 // If grouped convolution is not supported
135 if(!bool(status))
136 {
137 // Down-cast node
138 auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node);
139
140 // Get internal convolution info
Georgios Pinitas1c32bf3962018-11-12 18:36:19 +0000141 // TODO (geopin01) : Create a descriptor or a clone interface
142 const PadStrideInfo conv_info = conv_node->convolution_info();
143 const ConvolutionMethod conv_method = conv_node->convolution_method();
144 const ActivationLayerInfo fused_act_info = conv_node->fused_activation();
145 const FastMathHint fast_math_hint = conv_node->fast_math_hint();
146 const unsigned int num_groups = conv_node->num_groups();
147 const NodeParams params = conv_node->common_node_params();
148 const Target assigned_target = conv_node->assigned_target();
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100149
150 // Extract node ids
Georgios Pinitas1c32bf3962018-11-12 18:36:19 +0000151 ARM_COMPUTE_ERROR_ON(conv_node->input_edge(0) == nullptr || conv_node->input_edge(1) == nullptr);
152 const NodeID input_id = conv_node->input_edge(0)->producer()->id();
153 const NodeID weights_id = conv_node->input_edge(1)->producer()->id();
154 const NodeID bias_id = (conv_node->input_edge(2) != nullptr) ? conv_node->input_edge(2)->producer()->id() : EmptyNodeID;
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100155
156 // Get driving nodes
157 std::vector<NodeIdxPair> driving_nodes = get_driving_nodes(*node);
158
159 // Extract activation node accessor if any
160 auto node_accessor = conv_node->output(0)->extract_accessor();
161
162 // Current max tensor and node id
163 TensorID latest_tid = g.tensors().size();
164 NodeID latest_nid = g.nodes().size();
165
166 // Create grouped convolution node
167 NodeID grouped_conv_id = create_grouped_convolution(g, params, { input_id, 0 }, weights_id, bias_id,
Georgios Pinitas1c32bf3962018-11-12 18:36:19 +0000168 conv_info, conv_method, fused_act_info, fast_math_hint, num_groups);
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100169
170 // Remove convolution node
171 g.remove_node(node->id());
172
173 // Update batch normalization node outputs
174 for(auto &driving_node : driving_nodes)
175 {
176 g.add_connection(grouped_conv_id, 0, driving_node.node_id, driving_node.index);
177 }
178
179 // Update accessor to batch normalization node
180 g.node(grouped_conv_id)->output(0)->set_accessor(std::move(node_accessor));
181
182 // Configure new tensors and nodes
183 std::for_each(g.tensors().begin() + latest_tid, g.tensors().end(), [](std::unique_ptr<Tensor> &t)
184 {
185 configure_tensor(t.get());
186 });
187 std::for_each(g.nodes().begin() + latest_nid, g.nodes().end(), [&assigned_target](std::unique_ptr<INode> &n)
188 {
189 if(n != nullptr)
190 {
191 n->set_assigned_target(assigned_target);
192 }
193 });
194 }
195 }
196 }
197}
198} // namespace graph
199} // namespace arm_compute