blob: c14a70bed4ce5f307c8ac14da552dc790b16b6ee [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Packs a subgraph with Neural Network Operations into Passes. Each Pass has one or more Operations.
Diego Russoea6111a2020-04-14 18:41:58 +010018import collections
Diego Russoe8a10452020-04-21 17:39:10 +010019import enum
Diego Russoea6111a2020-04-14 18:41:58 +010020
Diego Russoe8a10452020-04-21 17:39:10 +010021from .nn_graph import Pass
22from .nn_graph import PassPlacement
23from .operation import NpuBlockType
24from .operation import Operation
Diego Russoea6111a2020-04-14 18:41:58 +010025from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010026
27
class PassFlags(enum.Flag):
    """Bit flags accumulated by a pass while operations are being packed into it.

    Each member is spelled as an explicit bit position. Bits 4, 6 and 7 are
    intentionally unused.
    """

    Empty = 0
    Pre = 1 << 0  # 1
    Main = 1 << 1  # 2
    Post = 1 << 2  # 4
    Mac = 1 << 3  # 8
    Dma = 1 << 5  # 32
    ElementWise = 1 << 8  # 256
    Npu = 1 << 9  # 512
    Cpu = 1 << 10  # 1024
    StartupInit = 1 << 11  # 2048
    MemoryOnly = 1 << 12  # 4096
    PostFusingLimited = 1 << 13  # 8192
41
42
# Operations that may be packed in front of the main operation of an NPU pass.
npu_pre_ops = {"QuantizedResizeBilinear", "SplitSliceRead"}

# Main operations of a pass that are packed with the Mac flag.
mac_main_ops = {
    # convolutions
    "Conv2DBiasAct",
    "Conv2D",
    "QuantizedConv2D",
    "Conv2DBackpropInputSwitchedBias",
    # depth-wise convolutions
    "DepthwiseConv2dBiasAct",
    "DepthwiseConv2dNative",
    "QuantizedDepthwiseConv2D",
    # FC layers
    "QuantizedMatMul",
    "MatMul",
    "FullyConnectedAct",
    # RNN/LSTM/GRU
    "BlockLSTM",
    # pooling
    "QuantizedMaxPool",
    "QuantizedAvgPool",
    "AvgPool",
    "MaxPool",
    "AvgPoolAct",
    "MaxPoolAct",
    # resizing
    "ResizeBilinear",
}

# Element-wise operations taking two inputs (IFM and IFM2).
binary_elem_wise_main_ops = {
    "AddAct",
    "MulAct",
    "SubAct",
    "QuantizedAdd",
    "QuantizedSub",
    "QuantizedMul",
    "Mul",
    "Add",
    "Sub",
    "Minimum",
    "Maximum",
}

# Element-wise operations taking a single input.
unary_elem_wise_main_ops = {"LeakyRelu", "Abs"}

# Main operations of a pass that are packed with the ElementWise flag.
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops

activation_ops = {"QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1"}

# Operations that may be fused after the main operation of an NPU pass.
# Bias-add operations: Get rid of these once we have rewrites from Conv2D + BiasAdd + Activation to Conv2DBiasAct.
# ("Mul" and "Add" also appear in binary_elem_wise_main_ops.)
npu_post_ops = activation_ops | {"Mul", "Add", "QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm"}

# Set of post operators that should not be fused with main/elementwise ops
npu_post_fuse_limited_ops = {"ConcatSliceWrite", "Sigmoid", "Tanh", "Quantize"}

# Every operation type that counts as element-wise when deciding Pass.is_element_wise.
elem_wise_ops = elem_wise_main_ops | activation_ops | {"Sigmoid", "Tanh"}


quantization_ops = {"Dequantize", "QuantizeV2", "Max", "Min"}
# Operations that are always placed on the CPU.
cpu_ops = {"Softmax", "QuantizedSoftmax", "LRN", "Shape", "QuantizedPad", "Pad", "AddN"} | quantization_ops

npu_dma_ops = {"DMA"}
# Operations collected into the single startup pass of each subgraph.
startup_init_ops = {"Const", "VariableV2", "Placeholder", "SubgraphInput"}
# Operations that only reinterpret memory and get their own MemoryOnly passes.
memory_only_ops = {"Squeeze", "Reshape", "QuantizedReshape", "ExpandDims"}
115
116
# Ordered packing rules consulted by build_pass(). Every entry is a 4-tuple:
#     (ops_set, incompatible_pack_flags, flags_to_set, flags_to_clear)
# An operation is packed by the first entry that satisfies all of the following:
#   - its ops_set contains the operation type (None matches any type);
#   - its incompatible_pack_flags do not overlap the flags the pass has accumulated so far;
#   - for entries that set PassFlags.Npu, the operation runs on the NPU.
test_sequence = [
    # NPU post-operations (activations and bias-add style operations)
    (
        npu_post_ops,
        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.Pre | PassFlags.Main,
        PassFlags.Npu | PassFlags.Post,
        PassFlags.Empty,
    ),
    # NPU post-operations that must not be fused with a main/elementwise operation
    (
        npu_post_fuse_limited_ops,
        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.Pre | PassFlags.Main,
        PassFlags.Npu | PassFlags.PostFusingLimited,
        PassFlags.Empty,
    ),
    # Main operations packed with the Mac flag
    (
        mac_main_ops,
        (
            PassFlags.Cpu
            | PassFlags.MemoryOnly
            | PassFlags.ElementWise
            | PassFlags.Pre
            | PassFlags.Main
            | PassFlags.PostFusingLimited
        ),
        PassFlags.Npu | PassFlags.Mac | PassFlags.Main,
        PassFlags.Empty,
    ),
    # Main operations packed with the ElementWise flag
    (
        elem_wise_main_ops,
        (
            PassFlags.Cpu
            | PassFlags.MemoryOnly
            | PassFlags.Mac
            | PassFlags.Pre
            | PassFlags.Main
            | PassFlags.PostFusingLimited
        ),
        PassFlags.Npu | PassFlags.ElementWise | PassFlags.Main,
        PassFlags.Empty,
    ),
    # NPU pre-operations
    (
        npu_pre_ops,
        PassFlags.Cpu | PassFlags.MemoryOnly,
        PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise,
        PassFlags.Empty,
    ),
    # DMA transfers
    (
        npu_dma_ops,
        PassFlags.Cpu | PassFlags.MemoryOnly,
        PassFlags.Npu | PassFlags.Dma,
        PassFlags.Empty,
    ),
    # Startup initialisation (constants, placeholders, ...)
    (
        startup_init_ops,
        PassFlags.Npu | PassFlags.Cpu | PassFlags.MemoryOnly,
        PassFlags.StartupInit | PassFlags.Main,
        PassFlags.Empty,
    ),
    # Memory-only reinterpretations (reshapes and similar)
    (
        memory_only_ops,
        PassFlags.Npu | PassFlags.Cpu,
        PassFlags.MemoryOnly | PassFlags.Main,
        PassFlags.Empty,
    ),
    # Operations that always run on the CPU
    (
        cpu_ops,
        PassFlags.Npu | PassFlags.MemoryOnly | PassFlags.Main,
        PassFlags.Cpu | PassFlags.Main,
        PassFlags.Empty,
    ),
    # Fallback for unrecognised operations: place them on the CPU
    (
        None,
        PassFlags.Npu | PassFlags.MemoryOnly | PassFlags.Main,
        PassFlags.Cpu | PassFlags.Main,
        PassFlags.Empty,
    ),
]
229
# Import-time sanity checks on the packing table above
for (operation_set, incompatible_pack_flags, flags_to_set, flags_to_clear) in test_sequence:
    # An entry must never both set and clear the same flag
    assert not (flags_to_clear & flags_to_set)

    # Every op name must be longer than one character; this catches a bare string
    # (e.g. set("DMA")) being decomposed into single characters
    assert operation_set is None or all(len(op) > 1 for op in operation_set)
237
238
def pack_into_passes(nng, arch, verbose_packing=False):
    """Pack the operations of every subgraph of nng into Passes.

    Each subgraph is walked backwards from its output tensors. An operation is
    visited only once all of its output tensors have been visited. At that point
    a new Pass is started from it and greedily extended with producer operations,
    following the rules in test_sequence. Startup-initialisation operations are
    collected and packed into a single extra pass per subgraph.

    Args:
        nng: the network graph; each subgraph's ``passes`` is (re)assigned.
        arch: architecture description (not referenced by this function).
        verbose_packing: if True, print the resulting passes via nng.print_passes().

    Returns:
        The same nng object, modified in place.
    """

    def visit_op(op, ignored):
        # Invoked once for each visited output tensor of op; op is packed only once
        # every one of its outputs has been visited.
        visit_op_refcount[op] += 1

        if visit_op_refcount[op] == 1:  # First-time visit, go and fix up unused output tensors
            # Outputs nobody consumes will never be visited, so count them up-front
            for tens in op.outputs:
                if len(tens.consumers()) == 0:
                    visit_op_refcount[op] += 1

        assert visit_op_refcount[op] <= len(op.outputs)
        if visit_op_refcount[op] == len(op.outputs):
            if op.type in startup_init_ops:
                # Packed together into one startup pass after the subgraph walk
                startup_list.append(op)
            else:
                _, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
                if ofm_tensor is None:
                    ofm_tensor = op.outputs[0]
                build_pass((op,), ofm_tensor)

    def broadcast_input_check(ps):
        # Swap the two inputs of a binary element-wise pass when the first input is the one
        # being broadcast, so the broadcast tensor becomes the second input (IFM2).
        # NOTE(review): the swap is also applied to non-commutative ops (Sub variants);
        # confirm that later stages compensate for the reversed operand order.
        if len(ps.inputs) < 2 or ps.inputs[0].shape == ps.inputs[1].shape:
            return

        shape0 = ps.inputs[0].shape
        shape1 = ps.inputs[1].shape
        if shape0 == [] or shape1 == []:
            return

        # Align the shapes on their trailing dimensions by padding the shorter one with
        # leading 1s. Without this, inputs of different rank either raised IndexError or
        # were compared by leading prefix.
        rank = max(len(shape0), len(shape1))
        padded0 = [1] * (rank - len(shape0)) + list(shape0)
        padded1 = [1] * (rank - len(shape1)) + list(shape1)
        if padded0 == padded1:
            # Equivalent shapes: nothing is broadcast, keep the original order
            return

        for dim0, dim1 in zip(padded0, padded1):
            if dim0 != dim1 and dim0 != 1:
                # The first input is not just broadcast along this dimension
                return

        ps.inputs[0], ps.inputs[1] = ps.inputs[1], ps.inputs[0]
        ps.primary_op.inputs[0], ps.primary_op.inputs[1] = ps.primary_op.inputs[1], ps.primary_op.inputs[0]

    def build_pass(start_ops_to_process, ofm_tensor=None):
        # Grow a pass backwards from start_ops_to_process. Each op is matched against
        # test_sequence; a producer is queued for packing only when its outputs feed
        # nothing but the current op. Otherwise, the tensor becomes a pass input.
        reverse_ops_list = []
        curr_flags = PassFlags.Empty
        npu_block_type = NpuBlockType.Default

        reverse_intermediates = []
        input_set = set()
        ifm_tensor = None
        primary_op = None

        to_process = collections.deque()
        for start_op in start_ops_to_process:
            to_process.append((start_op, None))

        while to_process:
            curr_op, tens = to_process.popleft()

            if curr_op in reverse_ops_list:
                continue

            for operation_set, incompatible_pack_flags, flags_to_set, flags_to_clear in test_sequence:
                if operation_set is None or curr_op.type in operation_set:
                    if not (curr_flags & incompatible_pack_flags):
                        if flags_to_set & PassFlags.Npu:
                            if not curr_op.run_on_npu:
                                continue

                        reverse_ops_list.append(curr_op)
                        new_block_type = curr_op.attrs.get("npu_block_type", NpuBlockType.Default)
                        if new_block_type != NpuBlockType.Default:
                            assert npu_block_type == NpuBlockType.Default
                            npu_block_type = new_block_type  # Only one major block type per pass
                            assert primary_op is None
                            primary_op = curr_op

                        curr_flags &= ~flags_to_clear
                        curr_flags |= flags_to_set

                        if flags_to_set & PassFlags.Npu:
                            if flags_to_set & (
                                PassFlags.Mac | PassFlags.ElementWise | PassFlags.Post | PassFlags.PostFusingLimited
                            ):
                                assert len(curr_op.inputs) >= 1
                                if curr_op.type == "BlockLSTM":
                                    ifm_tensor = curr_op.inputs[3]
                                else:
                                    ifm_tensor = curr_op.inputs[0]
                                assert ifm_tensor.purpose == TensorPurpose.FeatureMap

                        if flags_to_set & PassFlags.Dma:
                            # DMAs are special - Output buffers need to be preserved as intermediates,
                            # if the pass consumes the results
                            if tens is not None:
                                reverse_intermediates.append(tens)

                        if operation_set is None:
                            print("Warning:", curr_op.type, "operation is unknown or unsupported, placing on CPU")

                        for inp in reversed(curr_op.inputs):
                            can_pack = True
                            if len(inp.ops) == 1:
                                next_op = inp.ops[0]
                                for outp in next_op.outputs:
                                    consumers = outp.consumers()
                                    if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op):
                                        can_pack = False
                                        break
                            else:
                                can_pack = False

                            if can_pack:
                                to_process.append((next_op, inp))
                            else:
                                assert inp is not None
                                input_set.add(inp)

                        break

            else:
                # This operation is not compatible with already packed operations, just register the tensor as an input
                assert tens is not None
                input_set.add(tens)

        if curr_flags & PassFlags.Npu and not curr_flags & (PassFlags.ElementWise | PassFlags.Mac):
            # Make the choice that if we don't have a mac operation, the ambidextrous operations go on the
            # element wise unit
            curr_flags |= PassFlags.ElementWise

        is_element_wise = all(op.type in elem_wise_ops or op.type in npu_dma_ops for op in reverse_ops_list)

        # Exactly one placement flag is expected to have been set
        placement = PassPlacement.Unknown
        if curr_flags & PassFlags.Npu:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.Npu
        if curr_flags & PassFlags.Cpu:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.Cpu
        if curr_flags & PassFlags.MemoryOnly:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.MemoryOnly
        if curr_flags & PassFlags.StartupInit:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.StartupInit
        assert placement != PassPlacement.Unknown

        ops_list = list(reversed(reverse_ops_list))
        intermediates = list(reversed(reverse_intermediates))

        if primary_op is None:
            primary_op = create_primary_op(ops_list)
            if primary_op is not None:
                # create_primary_op registered the new AvgPool as an extra consumer of its
                # input; presumably this keeps visit_tensor's refcount in step with consumers()
                visit_tensor_refcount[primary_op.inputs[0]] += 1
                npu_block_type = primary_op.attrs["npu_block_type"]
                input_set.update(primary_op.inputs)

        # Order pass inputs by first use and count how often each is consumed in the pass
        ordered_input_list = []
        input_refcounts = collections.defaultdict(int)
        for op in ops_list:
            for inp in op.inputs:
                if inp in input_set:
                    if input_refcounts[inp] == 0:
                        ordered_input_list.append(inp)
                    input_refcounts[inp] += 1

        name = ops_list[0].name
        non_dma_ops = [op for op in ops_list if op.type != "DMA"]
        if non_dma_ops:
            name = non_dma_ops[0].name
        ps = Pass(name, placement, is_element_wise, npu_block_type)
        ps.ops = ops_list
        ps.primary_op = primary_op
        ps.inputs = ordered_input_list
        ps.intermediates = intermediates
        ps.outputs = list(ops_list[-1].outputs)
        ps.ifm_tensor = ifm_tensor

        # ElementWise operation, 2 IFMs
        if ps.primary_op and ps.primary_op.type in binary_elem_wise_main_ops:
            # Swap broadcast input if applicable
            broadcast_input_check(ps)

            ps.ifm_tensor = ps.inputs[0]

            if len(ps.inputs) == 1:
                # Only 1 input, IFM and IFM2 are the same tensor
                ps.ifm2_tensor = ps.inputs[0]
            else:
                ps.ifm2_tensor = ps.inputs[1]
        else:
            ps.ifm_tensor = ifm_tensor
            ps.ifm2_tensor = None

        ps.ofm_tensor = ofm_tensor
        assert ps.placement != PassPlacement.Npu or ps.ofm_tensor is not None
        ps.weight_tensor = ps.get_primary_op_ifm_weights()[1]
        ps.scale_tensor = ps.get_primary_op_ifm_weights_biases_ofm()[2]

        for op in ps.ops:
            op.scheduled_pass = ps

        reverse_pass_list.append(ps)

        # Continue the backwards walk through this pass's inputs, once per use
        for inp, refcount in input_refcounts.items():
            for _ in range(refcount):
                visit_tensor(inp)

        return ps

    def visit_tensor(tens):
        # Visit the producers of tens once every consumer of tens has been visited
        visit_tensor_refcount[tens] += 1
        assert visit_tensor_refcount[tens] <= len(tens.consumers())
        if visit_tensor_refcount[tens] == len(tens.consumers()):
            for op in reversed(tens.ops):
                visit_op(op, tens)

    def create_primary_op(ops_list):
        # A pass made up only of pre/post operations has no main operation. Give it one by
        # inserting a 1x1, stride-1 AvgPool in front of the first op. Only ops that run on
        # the NPU qualify, so a CPU pass (e.g. an unsupported activation placed by the
        # fallback rule) never gets an NPU pooling op attached.
        fusable_ops = npu_pre_ops | npu_post_ops | npu_post_fuse_limited_ops
        if any(op.type in fusable_ops and op.run_on_npu for op in ops_list):
            # Configure a 1x1 AvgPool and attach the op onto it
            op = ops_list[0]
            inp = op.inputs[0]
            avgpool_name = op.name + "_avgpool"
            avgpool_op = Operation("AvgPool", avgpool_name)
            avgpool_op.inputs = [inp]
            avgpool_op.inputs[0].consumer_list.append(avgpool_op)
            avgpool_op.attrs["padding"] = b"VALID"
            avgpool_op.attrs["npu_block_type"] = NpuBlockType.Pooling
            avgpool_op.attrs["stride_w"] = 1
            avgpool_op.attrs["stride_h"] = 1
            avgpool_op.attrs["filter_width"] = 1
            avgpool_op.attrs["filter_height"] = 1
            avgpool_op.attrs["strides"] = [1, 1, 1, 1]
            avgpool_op.attrs["ksize"] = [1, 1, 1, 1]
            avgpool_op.attrs["skirt"] = [0, 0, 0, 0]
            avgpool_op.attrs["explicit_padding"] = [0, 0, 0, 0]
            avgpool_out = inp.clone("_avgpooled")
            avgpool_out.consumer_list.append(op)
            avgpool_out.ops = [avgpool_op]
            avgpool_op.outputs = [avgpool_out]

            op.inputs[0] = avgpool_out
            ops_list.insert(0, avgpool_op)

            return avgpool_op

        return None

    for sg in nng.subgraphs:
        # Per-subgraph state read by the nested helpers above
        reverse_pass_list = []
        visit_op_refcount = collections.defaultdict(int)
        visit_tensor_refcount = collections.defaultdict(int)

        startup_list = []

        for tens in sg.output_tensors:
            visit_tensor(tens)

        if startup_list:
            startup_ps = build_pass(startup_list)
            startup_ps.outputs = [op.outputs[0] for op in startup_list]  # Need to fixup the outputs
            startup_ps.name = "startup_weight_initialisation"

        # Passes were created from the outputs backwards; store them in execution order
        sg.passes = list(reversed(reverse_pass_list))
        sg.build_pass_links()

    if verbose_packing:
        nng.print_passes()

    return nng