/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/graph/mutators/NodeFusionMutator.h"

#include "arm_compute/graph/GraphBuilder.h"
#include "arm_compute/graph/Logger.h"
#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
#include "arm_compute/graph/nodes/FusedConvolutionBatchNormalizationNode.h"
#include "arm_compute/graph/nodes/FusedConvolutionBatchNormalizationWithPostOpsNode.h"
#include "arm_compute/graph/nodes/FusedConvolutionWithPostOpNode.h"
#include "arm_compute/graph/nodes/Nodes.h"

#include "src/graph/mutators/MutatorUtils.h"

#include "support/Cast.h"

#include <algorithm> // for std::find_if, used by the post op fusion below
#include <list>
#include <set>

namespace arm_compute
{
namespace graph
{
namespace detail
{
void transfer_driving_nodes_and_remove_old_node(Graph &g, INode *new_node, INode *old_node, bool add_output_tensor)
{
    if(new_node == nullptr || old_node == nullptr)
    {
        return;
    }

    // Get driving nodes of last fusable node
    std::vector<NodeIdxPair> last_driving_nodes = get_driving_nodes(*old_node);

    // Extract last fusable node accessor if any
    if(old_node->output(0) == nullptr)
    {
        return;
    }
    auto old_node_accessor = old_node->output(0)->extract_accessor();

    // Remove node
    g.remove_node(old_node->id());

    // Update fused node outputs
    for(auto &driving_node : last_driving_nodes)
    {
        g.add_connection(new_node->id(), 0, driving_node.node_id, driving_node.index);
        if(add_output_tensor)
        {
            configure_tensor(new_node->output(0));
        }
    }

    // Update accessor to fused node
    new_node->output(0)->set_accessor(std::move(old_node_accessor));
}

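// Illustrative sketch of transfer_driving_nodes_and_remove_old_node (assuming a
// simple linear graph; the "driving nodes" are the consumers of old_node's output):
//
//   before: ... -> old_node -> consumers
//   after : ... -> new_node -> consumers
//
// old_node's output accessor, if any, is moved onto new_node's output tensor and
// old_node is removed from the graph.
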
void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(output_edge->producer());
    auto *bn_node   = arm_compute::utils::cast::polymorphic_downcast<BatchNormalizationLayerNode *>(output_edge->consumer());

    // Not fusing if number of groups is greater than 1
    if(conv_node->num_groups() > 1)
    {
        return;
    }

    ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing convolution node with ID : " << output_edge->producer_id()
                                  << " with BatchNormalization Layer node with ID : " << output_edge->consumer_id() << std::endl);

    // Prevent fusion if fused node has an output accessor
    if(conv_node->output(0)->accessor() == nullptr)
    {
        const Target assigned_target = conv_node->assigned_target();

        // Extract conv inputs
        const auto   conv_input_id   = conv_node->input_edge(0)->producer_id();
        const auto   conv_weights_id = conv_node->input_edge(1)->producer_id();
        const auto   conv_info       = conv_node->convolution_info();
        const auto   conv_method     = conv_node->convolution_method();
        const auto   num_groups      = conv_node->num_groups();
        const auto   act_info        = bn_node->fused_activation();
        FastMathHint fast_math_hint  = conv_node->fast_math_hint();

        // Extract bn inputs
        const auto bn_mean_id = bn_node->input_edge(1)->producer_id();
        const auto bn_var_id  = bn_node->input_edge(2)->producer_id();

        const auto epsilon = bn_node->epsilon();

        // Create the fused node
        const NodeID fused_id = g.add_node<FusedConvolutionBatchNormalizationNode>(epsilon, conv_info, num_groups, conv_method, fast_math_hint, act_info);

        if(conv_node->input_edge(2) != nullptr)
        {
            auto conv_bias_id = conv_node->input_edge(2)->producer_id();
            g.add_connection(conv_bias_id, 0, fused_id, 2);
        }

        // Add connections from the conv/batch_norm inputs to the fused node
        g.add_connection(conv_input_id, 0, fused_id, 0);
        g.add_connection(conv_weights_id, 0, fused_id, 1);
        g.add_connection(bn_mean_id, 0, fused_id, 3);
        g.add_connection(bn_var_id, 0, fused_id, 4);

        if(bn_node->input_edge(3) != nullptr)
        {
            const auto bn_beta_id = bn_node->input_edge(3)->producer_id();
            g.add_connection(bn_beta_id, 0, fused_id, 5);
        }

        if(bn_node->input_edge(4) != nullptr)
        {
            const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
            g.add_connection(bn_gamma_id, 0, fused_id, 6);
        }

        auto fused_node   = g.node(fused_id);
        auto bn_node_name = bn_node->name();

        transfer_driving_nodes_and_remove_old_node(g, fused_node, bn_node, true);

        fused_node->set_assigned_target(assigned_target);
        fused_node->set_common_node_parameters(NodeParams{ conv_node->name() + "+" + bn_node_name, assigned_target });

        // Remove convolution node
        g.remove_node(conv_node->id());
    }
    else
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution with batch normalization due to the presence of an output accessor\n");
    }
}

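// Sketch of the arithmetic the backend later performs for the fused node created
// above (the folding itself happens in the backend FusedConvolutionBatchNormalization
// function, not in this mutator; this is the standard conv + batch-norm identity):
//
//   w' = w * gamma / sqrt(var + epsilon)
//   b' = (b - mean) * gamma / sqrt(var + epsilon) + beta
//
// The fused node's input ports, as wired above, are: 0: input, 1: weights,
// 2: bias (optional), 3: mean, 4: var, 5: beta (optional), 6: gamma (optional).
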
void fuse_depthwise_convolution_with_batch_normalization(Graph &g, const Edge *output_edge)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *depth_conv_node = arm_compute::utils::cast::polymorphic_downcast<DepthwiseConvolutionLayerNode *>(output_edge->producer());
    auto *bn_node         = arm_compute::utils::cast::polymorphic_downcast<BatchNormalizationLayerNode *>(output_edge->consumer());

    ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing depthwise convolution node with ID : " << output_edge->producer_id()
                                  << " with BatchNormalization Layer node with ID : " << output_edge->consumer_id() << std::endl);

    // Prevent fusion if fused node has an output accessor
    if(depth_conv_node->output(0)->accessor() == nullptr)
    {
        const Target assigned_target = depth_conv_node->assigned_target();

        // Extract conv inputs
        const auto depth_conv_input_id = depth_conv_node->input_edge(0)->producer_id();
        const auto conv_weights_id     = depth_conv_node->input_edge(1)->producer_id();
        const auto conv_info           = depth_conv_node->convolution_info();
        const auto depth_conv_method   = depth_conv_node->depthwise_convolution_method();
        const auto depth_multiplier    = depth_conv_node->depth_multiplier();
        const auto act_info            = bn_node->fused_activation();

        // Extract bn inputs
        const auto bn_mean_id  = bn_node->input_edge(1)->producer_id();
        const auto bn_var_id   = bn_node->input_edge(2)->producer_id();
        const auto bn_beta_id  = bn_node->input_edge(3)->producer_id();
        const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
        const auto epsilon     = bn_node->epsilon();

        // Create the fused node
        const NodeID fused_id = g.add_node<FusedDepthwiseConvolutionBatchNormalizationNode>(epsilon, conv_info, depth_multiplier, depth_conv_method, act_info);

        if(depth_conv_node->input_edge(2) != nullptr)
        {
            const auto conv_bias_id = depth_conv_node->input_edge(2)->producer_id();
            g.add_connection(conv_bias_id, 0, fused_id, 2);
        }

        // Add connections from the conv/batch_norm inputs to the fused node
        g.add_connection(depth_conv_input_id, 0, fused_id, 0);
        g.add_connection(conv_weights_id, 0, fused_id, 1);
        g.add_connection(bn_mean_id, 0, fused_id, 3);
        g.add_connection(bn_var_id, 0, fused_id, 4);
        g.add_connection(bn_beta_id, 0, fused_id, 5);
        g.add_connection(bn_gamma_id, 0, fused_id, 6);

        auto fused_node   = g.node(fused_id);
        auto bn_node_name = bn_node->name();

        transfer_driving_nodes_and_remove_old_node(g, fused_node, bn_node, true);

        fused_node->set_assigned_target(assigned_target);
        fused_node->set_common_node_parameters(NodeParams{ depth_conv_node->name() + "+" + bn_node_name, assigned_target });

        // Remove convolution node
        g.remove_node(depth_conv_node->id());
    }
    else
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of depthwise convolution with batch normalization due to the presence of an output accessor\n");
    }
}

template <typename N>
void fuse_node_with_activation(Graph &g, const Edge *output_edge, const std::set<Activation> &supported_fused_activations)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *n_node   = arm_compute::utils::cast::polymorphic_downcast<N *>(output_edge->producer());
    auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(output_edge->consumer());

    ARM_COMPUTE_ERROR_ON(act_node->output(0) == nullptr || n_node->output(0) == nullptr);

    // Check if activation is supported for fusion
    if(supported_fused_activations.count(act_node->activation_info().activation()) == 0)
    {
        return;
    }

    // EltwiseLayerNode can only be fused when the data type is float
    if(n_node->type() == NodeType::EltwiseLayer && !is_data_type_float(n_node->output(0)->desc().data_type))
    {
        return;
    }

    ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing node with ID : " << output_edge->producer_id()
                                  << " with Activation Layer node with ID : " << output_edge->consumer_id() << std::endl);

    // Prevent fusion if fused node has an output accessor
    if(n_node->output(0)->accessor() == nullptr)
    {
        // Set activation info to fused node
        n_node->set_fused_activation(act_node->activation_info());

        transfer_driving_nodes_and_remove_old_node(g, n_node, act_node, false);
    }
    else
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of node with activation due to the presence of an output accessor\n");
    }
}

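// Note on fuse_node_with_activation: the activation does not survive as a separate
// node. The producer (convolution, batch normalization, eltwise, ...) records the
// ActivationLayerInfo and the backend applies it as part of the main kernel.
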
template <typename N>
void fuse_pad_with_convolution(Graph &g, const Edge *output_edge)
{
    auto *pad_node  = arm_compute::utils::cast::polymorphic_downcast<PadLayerNode *>(output_edge->producer());
    auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<N *>(output_edge->consumer());

    const Edge *input_edge = pad_node->input_edge(0);
    if(input_edge != nullptr && input_edge->tensor() != nullptr && pad_node->output(0)->accessor() == nullptr
       && pad_node->pad_value().get<float>() == 0.0)
    {
        const DataLayout  layout       = input_edge->tensor()->desc().layout;
        const PaddingList padding_list = pad_node->padding();

        const unsigned int height_index = get_dimension_idx(layout, DataLayoutDimension::HEIGHT);
        const unsigned int width_index  = get_dimension_idx(layout, DataLayoutDimension::WIDTH);

        const PaddingInfo pad_w = width_index < padding_list.size() ? padding_list[width_index] : PaddingInfo(0, 0);
        const PaddingInfo pad_h = height_index < padding_list.size() ? padding_list[height_index] : PaddingInfo(0, 0);

        if(is_padding_in_height_or_width(layout, padding_list))
        {
            // Add paddings to the convolution node
            const PadStrideInfo conv_info = conv_node->convolution_info();
            const PadStrideInfo new_conv_info(
                conv_info.stride().first,
                conv_info.stride().second,
                conv_info.pad_left() + pad_w.first,
                conv_info.pad_right() + pad_w.second,
                conv_info.pad_top() + pad_h.first,
                conv_info.pad_bottom() + pad_h.second,
                conv_info.round());
            conv_node->set_convolution_info(new_conv_info);

            // Update drivers of the convolution node
            std::vector<NodeIdxPair> pad_driver_nodes = get_driver_nodes(*pad_node);
            g.remove_node(pad_node->id());

            // Update fused node inputs
            for(auto &driver_node : pad_driver_nodes)
            {
                g.add_connection(driver_node.node_id, driver_node.index, conv_node->id(), 0);
            }
        }
    }
}

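// Worked example for fuse_pad_with_convolution (assuming NHWC, where dimension
// index 1 is width and 2 is height): a zero-value PadLayerNode with
// padding_list = { { 0, 0 }, { 1, 1 }, { 1, 1 } } feeding a convolution with no
// padding is removed, and the convolution's PadStrideInfo becomes
// pad_left = pad_right = pad_top = pad_bottom = 1.
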
template <typename N1, typename N2, typename F, typename... Args>
void fuse_layer(Graph &g, std::function<bool(INode &)> const &prec, const F fuse_fcn, Args &&... optional_arguments)
{
    // Note that fused nodes may be added to the end of the node list.
    // Instead of only looping over the original list of nodes, we loop over the current node list which could be growing.
    // This is intentional as it probes the newly added fused nodes for further fusing opportunities.
    for(unsigned int i = 0; i < g.nodes().size(); ++i)
    {
        auto node = g.node(i);
        // Check if the node is of type N1 and not a branching node
        if(node && node->type() == N1::node_type && node->output_edges().size() == 1)
        {
            const auto output_edge_id = *node->output_edges().begin();
            const auto output_edge    = g.edge(output_edge_id);

            // Check if the following node is of type N2
            if((output_edge != nullptr) && (output_edge->consumer() != nullptr) && (output_edge->consumer()->type() == N2::node_type) && prec(*output_edge->producer()))
            {
                fuse_fcn(g, output_edge, optional_arguments...);
            }
        }
    }
}

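// Illustrative instantiation of the pairwise fuse_layer overload (this exact call
// appears in NodeFusionMutator::mutate below):
//
//   detail::fuse_layer<ConvolutionLayerNode, ActivationLayerNode>(
//       g, empty_prec, detail::fuse_node_with_activation<ConvolutionLayerNode>, supported_fused_activations);
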
/** Check valid combinations:
 *
 * | Main operator | Post operators  |
 * |:--------------|:----------------|
 * | conv          | add             |
 * | conv          | act + add       |
 * | conv          | add + act       |
 * | conv          | act + add + act |
 *
*/
#define MAX_VALID_COMBINATION 4
#define MAX_POST_OP_NUM 3
NodeType valid_post_op_type[MAX_VALID_COMBINATION][MAX_POST_OP_NUM] = { { EltwiseLayerNode::node_type },
    { EltwiseLayerNode::node_type, ActivationLayerNode::node_type },
    { ActivationLayerNode::node_type, EltwiseLayerNode::node_type },
    { ActivationLayerNode::node_type, EltwiseLayerNode::node_type, ActivationLayerNode::node_type }
};

bool check_post_op_type(NodeType *post_op_type, int len)
{
    if(len > MAX_POST_OP_NUM || len <= 0)
    {
        return false;
    }

    bool found = false;
    for(int i = 0; i < MAX_VALID_COMBINATION; ++i)
    {
        for(int j = 0; j < len; ++j)
        {
            if(post_op_type[j] != valid_post_op_type[i][j])
            {
                found = false;
                break;
            }
            found = true;
        }
        if(found)
            break;
    }

    return found;
}

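// Illustrative results of check_post_op_type (hypothetical sequences, not taken
// from a real graph):
//
//   NodeType seq_add[]     = { EltwiseLayerNode::node_type };                                    // add
//   NodeType seq_act_add[] = { ActivationLayerNode::node_type, EltwiseLayerNode::node_type };    // act + add
//   NodeType seq_act_act[] = { ActivationLayerNode::node_type, ActivationLayerNode::node_type }; // act + act
//
//   check_post_op_type(seq_add, 1);     // true  ("add" combination)
//   check_post_op_type(seq_act_add, 2); // true  ("act + add" combination)
//   check_post_op_type(seq_act_act, 2); // false (no matching combination)
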
void fuse_convolution_with_post_op(Graph &g, INode *fused_node, std::list<INode *> post_op_node_list, int prev_op_dst_pos)
{
    unsigned int op_idx = 0;
    // Fuse post operators with conv
    for(const auto &post_op : post_op_node_list)
    {
        switch(post_op->type())
        {
            case EltwiseLayerNode::node_type:
            {
                auto *eltwise_node = arm_compute::utils::cast::polymorphic_downcast<EltwiseLayerNode *>(post_op);
                ARM_COMPUTE_ERROR_ON(eltwise_node->output(0) == nullptr);

                fused_node->post_op_info_list().push_back(std::make_unique<ConvPostOpInfoEltwiseAdd>(prev_op_dst_pos, eltwise_node->convert_policy()));
                ARM_COMPUTE_LOG_GRAPH_VERBOSE(" with Elementwise Layer node with ID : " << post_op->id());
                break;
            }
            case ActivationLayerNode::node_type:
            {
                auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(post_op);
                ARM_COMPUTE_ERROR_ON(act_node->output(0) == nullptr);

                fused_node->post_op_info_list().push_back(std::make_unique<ConvPostOpInfoActivation>(act_node->activation_info()));
                ARM_COMPUTE_LOG_GRAPH_VERBOSE(" with Activation Layer node with ID : " << post_op->id());
                break;
            }
            default:
            {
                break;
            }
        }

        if(op_idx == post_op_node_list.size() - 1) // last fusable node
        {
            transfer_driving_nodes_and_remove_old_node(g, fused_node, post_op, true);
        }
        else
        {
            // Remove node
            g.remove_node(post_op->id());
        }
        op_idx++;
    }
}

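// Note on fuse_convolution_with_post_op: only the last post op hands its driving
// nodes and output accessor over to the fused node; the intermediate post ops are
// simply removed, as their outputs become internal to the fused kernel.
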
std::list<INode *> get_post_op_list(Graph &g, int &eltwise_operand_id, int &prev_op_dst_pos, unsigned int conv_node_id, const std::set<Activation> &supported_fused_activations)
{
    std::list<INode *> post_op_node_list    = {};
    NodeID             prev_op_dst_id       = conv_node_id;
    NodeType           post_op_type_list[3] = { NodeType::Dummy, NodeType::Dummy, NodeType::Dummy };
    int                post_op_idx          = 0;

    // Get list of the connected nodes
    auto current_node = g.node(conv_node_id);

    while(post_op_node_list.size() < 3)
    {
        // This convolution node must have only one output edge, otherwise this function would not have been called

        auto current_output_edge_id = current_node->output_edges().begin();
        auto current_output_edge    = g.edge(*current_output_edge_id);
        auto post_op_node           = current_output_edge->consumer();

        bool fusable_post_op = false;
        if(post_op_node != nullptr && post_op_node->output_edges().size() > 0)
        {
            switch(post_op_node->type())
            {
                case EltwiseLayerNode::node_type:
                {
                    auto *eltwise_node = arm_compute::utils::cast::polymorphic_downcast<EltwiseLayerNode *>(post_op_node);
                    ARM_COMPUTE_ERROR_ON(eltwise_node->output(0) == nullptr);
                    if(eltwise_node->output(0)->accessor() == nullptr)
                    {
                        post_op_node_list.push_back(post_op_node);
                        fusable_post_op                  = true;
                        post_op_type_list[post_op_idx++] = eltwise_node->type();

                        // Extract elementwise inputs
                        const auto eltwise_input_id_0 = eltwise_node->input_edge(0)->producer_id();
                        const auto eltwise_input_id_1 = eltwise_node->input_edge(1)->producer_id();
                        if(eltwise_input_id_0 == prev_op_dst_id)
                        {
                            eltwise_operand_id = eltwise_input_id_1;
                            prev_op_dst_pos    = 0;
                        }
                        else if(eltwise_input_id_1 == prev_op_dst_id)
                        {
                            eltwise_operand_id = eltwise_input_id_0;
                            prev_op_dst_pos    = 1;
                        }
                    }
                    else
                    {
                        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with elementwise due to the presence of an output accessor\n");
                    }
                    break;
                }
                case ActivationLayerNode::node_type:
                {
                    auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(post_op_node);
                    ARM_COMPUTE_ERROR_ON(act_node->output(0) == nullptr);
                    // Check if activation is supported for fusion
                    if(supported_fused_activations.count(act_node->activation_info().activation()) == 0)
                    {
                        break;
                    }
                    if(act_node->output(0)->accessor() == nullptr)
                    {
                        post_op_node_list.push_back(post_op_node);
                        fusable_post_op                  = true;
                        post_op_type_list[post_op_idx++] = act_node->type();
                        prev_op_dst_id                   = act_node->id();
                    }
                    else
                    {
                        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to the presence of an output accessor\n");
                    }
                    break;
                }
                default:
                {
                    break;
                }
            }

            // Check if the node is not a branching node and the current node is fusable
            if(post_op_node->output_edges().size() == 1 && fusable_post_op == true)
            {
                current_node = post_op_node;
            }
            else
            {
                break;
            }
        }
        else
        {
            // No consumer to probe; stop searching to avoid spinning on the same node forever
            break;
        }
    }

    // Check whether it's a valid post op list
    if(post_op_node_list.size() > 0)
    {
        bool fuse_with_post_op = check_post_op_type(post_op_type_list, post_op_node_list.size());
        if(!fuse_with_post_op)
        {
            post_op_node_list.clear();
        }
    }

    return post_op_node_list;
}

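// Illustrative traversal of get_post_op_list (hypothetical node IDs; assuming no
// output accessors and no branching along the chain): for
//
//   conv(0) -> act(1) -> add(2) -> act(3) -> next
//
// the walk collects { act(1), add(2), act(3) }, records the add's other operand in
// eltwise_operand_id, and check_post_op_type accepts the sequence as the
// "act + add + act" combination.
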
/** Fuse the following operator sequences:
 *
 * | Main operator | Post operators  |
 * |:--------------|:----------------|
 * | conv          | add             |
 * | conv          | act + add       |
 * | conv          | add + act       |
 * | conv          | act + add + act |
 *
 * Note: currently, only GEMM supports fusion with post operators
*/
void fuse_convolution_with_post_ops(Graph &g, const Edge *output_edge, unsigned int conv_node_id, const std::set<Activation> &supported_fused_activations)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(output_edge->producer());
    ARM_COMPUTE_ERROR_ON(conv_node->output(0) == nullptr);

    const ConvolutionMethod conv_algorithm = conv_node->convolution_method();
    if(conv_algorithm != ConvolutionMethod::GEMM)
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to non GEMM convolution\n");
        return;
    }

    // Prevent fusion if fused node has an output accessor
    if(conv_node->output(0)->accessor() == nullptr)
    {
        // If the data type is FP32/FP16, the data layout is NHWC, and the filter size is 1x1, fuse the convolution with post ops, as Conv1x1 always leads to GEMM.
        const Edge *input_edge = conv_node->input_edge(1);
        if(input_edge != nullptr && input_edge->tensor() != nullptr)
        {
            const DataLayout  data_layout  = input_edge->tensor()->desc().layout;
            const DataType    data_type    = input_edge->tensor()->desc().data_type;
            const TensorShape tensor_shape = input_edge->tensor()->desc().shape;
            if((data_layout != DataLayout::NHWC) || (is_data_type_float(data_type) == false) || (tensor_shape.y() != 1) || (tensor_shape.z() != 1))
            {
                ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to non GEMM convolution\n");
                return;
            }
        }
        else
        {
            return;
        }

        // Get post op list
        int                eltwise_operand_id = 0;
        int                prev_op_dst_pos    = 0; // Previous operator dst's position in the current operator
        std::list<INode *> post_op_node_list  = get_post_op_list(g, eltwise_operand_id, prev_op_dst_pos, conv_node_id, supported_fused_activations);

        if(post_op_node_list.size() == 0)
        {
            return;
        }
        else // Do convolution fusion with post ops if there is one (elementwise) or more operators
        {
            const Target assigned_target = conv_node->assigned_target();

            // Extract conv inputs
            const auto   conv_input_id   = conv_node->input_edge(0)->producer_id();
            const auto   conv_weights_id = conv_node->input_edge(1)->producer_id();
            const auto   conv_info       = conv_node->convolution_info();
            const auto   conv_method     = conv_node->convolution_method();
            const auto   num_groups      = conv_node->num_groups();
            FastMathHint fast_math_hint  = conv_node->fast_math_hint();

            // Create the fused node
            const NodeID fused_id = g.add_node<FusedConvolutionWithPostOpNode>(conv_info, num_groups, conv_method, fast_math_hint);
            ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing convolution node with ID : " << conv_node->id());

            // Add connections from the conv inputs to the fused node
            g.add_connection(conv_input_id, 0, fused_id, 0);
            g.add_connection(conv_weights_id, 0, fused_id, 1);
            if(conv_node->input_edge(2) != nullptr)
            {
                auto conv_bias_id = conv_node->input_edge(2)->producer_id();
                g.add_connection(conv_bias_id, 0, fused_id, 2);
            }
            // Add the elementwise operand in case one of the post ops is an elementwise operation
            auto it = std::find_if(post_op_node_list.begin(),
                                   post_op_node_list.end(),
                                   [&](const INode * nd)
            {
                return (nd->type() == graph::NodeType::EltwiseLayer);
            });

            if(it != post_op_node_list.end())
            {
                g.add_connection(eltwise_operand_id, 0, fused_id, 3);
            }
            g.remove_node(conv_node->id());

            // Update fused node outputs
            auto fused_node = g.node(fused_id);
            fused_node->set_assigned_target(assigned_target);

            // Fuse convolution with post ops
            fuse_convolution_with_post_op(g, fused_node, post_op_node_list, prev_op_dst_pos);

            post_op_node_list.clear();
            ARM_COMPUTE_LOG_GRAPH_VERBOSE(std::endl);
        }
    }
    else
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to the presence of an output accessor\n");
    }
}

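// Input port layout of the FusedConvolutionWithPostOpNode built above: 0: input,
// 1: weights, 2: bias (optional), 3: elementwise operand (only wired when one of
// the post ops is an elementwise addition).
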
void fuse_convolution_batch_normalization_with_post_ops(Graph &g, const Edge *output_edge, unsigned int conv_node_id, const std::set<Activation> &supported_fused_activations)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<FusedConvolutionBatchNormalizationNode *>(output_edge->producer());
    ARM_COMPUTE_ERROR_ON(conv_node->output(0) == nullptr);
    const ConvolutionMethod conv_algorithm = conv_node->convolution_method();
    if(conv_algorithm != ConvolutionMethod::GEMM)
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to non GEMM convolution\n");
        return;
    }

    // Prevent fusion if fused node has an output accessor
    if(conv_node->output(0)->accessor() == nullptr)
    {
        // If the data type is FP32/FP16, the data layout is NHWC, and the filter size is 1x1, fuse the convolution with post ops, as Conv1x1 always leads to GEMM.
        const Edge *input_edge = conv_node->input_edge(1);
        if(input_edge != nullptr && input_edge->tensor() != nullptr)
        {
            const DataLayout  data_layout  = input_edge->tensor()->desc().layout;
            const DataType    data_type    = input_edge->tensor()->desc().data_type;
            const TensorShape tensor_shape = input_edge->tensor()->desc().shape;
            if((data_layout != DataLayout::NHWC) || (is_data_type_float(data_type) == false) || (tensor_shape.y() != 1) || (tensor_shape.z() != 1))
            {
                ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to non GEMM convolution\n");
                return;
            }
        }
        else
        {
            return;
        }

        // Get post op list
        int                eltwise_operand_id = 0;
        int                prev_op_dst_pos    = 0; // Previous operator dst's position in the current operator
        std::list<INode *> post_op_node_list  = get_post_op_list(g, eltwise_operand_id, prev_op_dst_pos, conv_node_id, supported_fused_activations);

        if(post_op_node_list.size() == 0)
        {
            return;
        }
        else // Do convolution fusion with post ops if there is one (elementwise) or more operators
        {
            const Target assigned_target = conv_node->assigned_target();

            // Extract conv inputs
            const auto   conv_input_id   = conv_node->input_edge(0)->producer_id();
            const auto   conv_weights_id = conv_node->input_edge(1)->producer_id();
            const auto   bn_mean_id      = conv_node->input_edge(3)->producer_id();
            const auto   bn_var_id       = conv_node->input_edge(4)->producer_id();
            const auto   conv_info       = conv_node->convolution_info();
            const auto   conv_method     = conv_node->convolution_method();
            const auto   num_groups      = conv_node->num_groups();
            FastMathHint fast_math_hint  = conv_node->fast_math_hint();

            // Create the fused node
            const float  epsilon  = conv_node->epsilon();
            const NodeID fused_id = g.add_node<FusedConvolutionBatchNormalizationWithPostOpsNode>(epsilon, conv_info, num_groups, conv_method, fast_math_hint);

            ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing FusedConvolutionBatchNormalization node with ID : " << conv_node->id());

            // Add connections from the conv inputs to the fused node
            g.add_connection(conv_input_id, 0, fused_id, 0);
            g.add_connection(conv_weights_id, 0, fused_id, 1);

            if(conv_node->input_edge(2) != nullptr)
            {
                auto conv_bias_id = conv_node->input_edge(2)->producer_id();
                g.add_connection(conv_bias_id, 0, fused_id, 2);
            }
            g.add_connection(bn_mean_id, 0, fused_id, 3);
            g.add_connection(bn_var_id, 0, fused_id, 4);

            // Move connections of old FusedConvolutionBatchNormalization to the fused node
            if(conv_node->input_edge(5) != nullptr)
            {
                const auto bn_beta_id = conv_node->input_edge(5)->producer_id();
                g.add_connection(bn_beta_id, 0, fused_id, 5);
            }

            if(conv_node->input_edge(6) != nullptr)
            {
                const auto bn_gamma_id = conv_node->input_edge(6)->producer_id();
                g.add_connection(bn_gamma_id, 0, fused_id, 6);
            }

            // Add the elementwise operand in case one of the post ops is an elementwise operation
            auto it = std::find_if(post_op_node_list.begin(),
                                   post_op_node_list.end(),
                                   [&](const INode * nd)
            {
                return (nd->type() == graph::NodeType::EltwiseLayer);
            });

            if(it != post_op_node_list.end())
            {
                g.add_connection(eltwise_operand_id, 0, fused_id, 7);
            }

            // Update fused node outputs
            auto fused_node = g.node(fused_id);
            fused_node->set_assigned_target(assigned_target);

            auto conv_node_name = conv_node->name();

            // Collect the post op names
            std::string post_ops_name = "";
            for(auto &post_op : post_op_node_list)
            {
                post_ops_name += post_op->name();
            }
            fused_node->set_common_node_parameters(NodeParams{ conv_node_name + "+" + post_ops_name, assigned_target });

            // Fuse convolution with post ops
            fuse_convolution_with_post_op(g, fused_node, post_op_node_list, prev_op_dst_pos);

            post_op_node_list.clear();
            g.remove_node(conv_node->id());
            ARM_COMPUTE_LOG_GRAPH_VERBOSE(std::endl);
        }
    }
    else
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to the presence of an output accessor\n");
    }
}

template <typename N1, typename F, typename... Args>
void fuse_layer(Graph &g, std::function<bool(INode &)> const &prec, const F fuse_fcn, Args &&... optional_arguments)
{
    // Note that fused nodes may be added to the end of the node list.
    // Instead of only looping over the original list of nodes, we loop over the current node list which could be growing.
    // This is intentional as it probes the newly added fused nodes for further fusing opportunities.
    for(unsigned int i = 0; i < g.nodes().size(); ++i)
    {
        auto node = g.node(i);
        // Check if the node is of type N1 and not a branching node
        if(node && node->type() == N1::node_type && node->output_edges().size() == 1)
        {
            const auto output_edge_id = *node->output_edges().begin();
            const auto output_edge    = g.edge(output_edge_id);

            // Check that the edge is valid and the precondition (e.g. the assigned target) holds
            if((output_edge != nullptr) && (output_edge->consumer() != nullptr) && prec(*output_edge->producer()))
            {
                fuse_fcn(g, output_edge, i, optional_arguments...);
            }
        }
    }
}
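// Illustrative instantiation of the single-type fuse_layer overload (this exact
// call appears in NodeFusionMutator::mutate below); the fusion function receives
// the node index so it can walk the chain of post ops itself:
//
//   detail::fuse_layer<ConvolutionLayerNode>(
//       g, cl_target_prec, detail::fuse_convolution_with_post_ops, supported_fused_activations);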
} // namespace detail

const char *NodeFusionMutator::name()
{
    return "NodeFusionMutator";
}

IGraphMutator::MutationType NodeFusionMutator::type() const
{
    return IGraphMutator::MutationType::Backend;
}

void NodeFusionMutator::mutate(Graph &g)
{
    // Supported activations when fusing
    const std::set<Activation> supported_fused_activations = { Activation::ABS, Activation::BOUNDED_RELU, Activation::ELU,
                                                               Activation::HARD_SWISH, Activation::IDENTITY, Activation::LEAKY_RELU,
                                                               Activation::LINEAR, Activation::LOGISTIC, Activation::LU_BOUNDED_RELU,
                                                               Activation::RELU, Activation::SOFT_RELU, Activation::SQRT,
                                                               Activation::SQUARE, Activation::TANH
                                                             };

    // Preconditions
    auto empty_prec = [](INode &)
    {
        return true;
    };
    auto cl_target_prec = [](INode & n)
    {
        return n.assigned_target() == Target::CL;
    };
    auto qs8_prec = [&g](INode & n)
    {
        ARM_COMPUTE_ERROR_ON(n.output(0) == nullptr);

        const auto output_edge_id = *n.output_edges().begin();
        const auto output_edge    = g.edge(output_edge_id);
        // To perform fusion the two nodes must have the same output quantization information
        const bool same_qinfo     = n.output(0)->desc().quant_info == output_edge->producer()->output(0)->desc().quant_info;
        const bool output_qasymm8 = n.output(0)->desc().data_type == DataType::QASYMM8;

        return (output_qasymm8 && same_qinfo) || !output_qasymm8;
    };

    // Fusion mutations

    detail::fuse_layer<PadLayerNode, ConvolutionLayerNode>(g, empty_prec, detail::fuse_pad_with_convolution<ConvolutionLayerNode>);
    detail::fuse_layer<PadLayerNode, DepthwiseConvolutionLayerNode>(g, empty_prec, detail::fuse_pad_with_convolution<DepthwiseConvolutionLayerNode>);
    // The fusion of post ops into ConvolutionLayer:
    // - must occur after the fusion of PadLayer into ConvolutionLayer
    // - must occur before the fusion of a plain ActivationLayer into ConvolutionLayer, as it takes precedence
    detail::fuse_layer<ConvolutionLayerNode>(g, cl_target_prec, detail::fuse_convolution_with_post_ops, supported_fused_activations);
    detail::fuse_layer<BatchNormalizationLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<BatchNormalizationLayerNode>, supported_fused_activations);
    detail::fuse_layer<ConvolutionLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<ConvolutionLayerNode>, supported_fused_activations);
    detail::fuse_layer<DepthwiseConvolutionLayerNode, ActivationLayerNode>(g, qs8_prec, detail::fuse_node_with_activation<DepthwiseConvolutionLayerNode>, supported_fused_activations);
    detail::fuse_layer<FullyConnectedLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<FullyConnectedLayerNode>, supported_fused_activations);
    detail::fuse_layer<EltwiseLayerNode, ActivationLayerNode>(g, cl_target_prec, detail::fuse_node_with_activation<EltwiseLayerNode>, supported_fused_activations);
    // The fusion of BatchNormalizationLayer must occur after the fusion of ActivationLayer, because
    // FusedConvolutionBatchNormalizationNode assumes the batch normalization is already fused with an activation, if any
    detail::fuse_layer<ConvolutionLayerNode, BatchNormalizationLayerNode>(g, empty_prec, detail::fuse_convolution_with_batch_normalization);
    detail::fuse_layer<DepthwiseConvolutionLayerNode, BatchNormalizationLayerNode>(g, empty_prec, detail::fuse_depthwise_convolution_with_batch_normalization);
    detail::fuse_layer<FusedConvolutionBatchNormalizationNode>(g, cl_target_prec, detail::fuse_convolution_batch_normalization_with_post_ops, supported_fused_activations);
}
} // namespace graph
} // namespace arm_compute