blob: 663520fc8ab2083e8cb050f26fe195914f6d12f1 [file] [log] [blame]
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
# Description:
# Packs the Neural Network Operations of each subgraph into Passes. Each Pass contains one or more Operations.
20
21from .nn_graph import Operation, Pass, PassPlacement, TensorPurpose, NpuBlockType, Tensor
22import collections
23import enum
24from .data_type import BaseType, DataType
25
26
class PassFlags(enum.Flag):
    """Properties accumulated by a Pass while operations are being packed into it.

    Every member except Empty is a single bit, so members can be combined with | and
    tested with &. Bit positions 4, 6 and 7 are not used.
    """

    Empty = 0
    Pre = 1 << 0
    Main = 1 << 1
    Post = 1 << 2
    Mac = 1 << 3
    Dma = 1 << 5
    ElementWise = 1 << 8
    Npu = 1 << 9
    Cpu = 1 << 10
    StartupInit = 1 << 11
    MemoryOnly = 1 << 12
    PostFusingLimited = 1 << 13
40
41
# Operation-type sets used by the packing rules in test_sequence.
# Set literals are used instead of set((...,)) so that a single name can never be
# decomposed into its characters by a missing trailing comma.

# Pre-operations (packed with the Npu | Mac | Pre | ElementWise flags).
npu_pre_ops = {"QuantizedResizeBilinear", "SplitSliceRead"}

# Main operations packed with the Npu | Mac | Main flags.
mac_main_ops = {
    # convolutions
    "Conv2DBiasAct",
    "Conv2D",
    "QuantizedConv2D",
    "Conv2DBackpropInputSwitched",
    # depth-wise convolutions
    "DepthwiseConv2dBiasAct",
    "DepthwiseConv2dNative",
    "QuantizedDepthwiseConv2D",
    # FC layers
    "QuantizedMatMul",
    "MatMul",
    "FullyConnectedAct",
    # RNN/LSTM/GRU
    "BlockLSTM",
    # pooling
    "QuantizedMaxPool",
    "QuantizedAvgPool",
    "AvgPool",
    "MaxPool",
    "AvgPoolAct",
    "MaxPoolAct",
}

# Binary element-wise operations (two IFMs).
binary_elem_wise_main_ops = {
    "AddAct",
    "MulAct",
    "SubAct",
    "QuantizedAdd",
    "QuantizedSub",
    "QuantizedMul",
    "Mul",
    "Add",
    "Sub",
    "Minimum",
    "Maximum",
}

# Unary element-wise operations.
unary_elem_wise_main_ops = {"LeakyRelu", "Abs"}

# Main operations packed with the Npu | ElementWise | Main flags.
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops

activation_ops = {"QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1"}

# Post-operations (packed with the Npu | Post flags).
# NOTE: "Mul" and "Add" are also in binary_elem_wise_main_ops; the post-op rule comes
# first in test_sequence, so it is tried first for them.
npu_post_ops = activation_ops | {
    # Bias-add operations: Get rid of these once we have rewrites from Conv2D + BiasAdd + Activation to Conv2DBiasAct.
    "Mul",
    "Add",
    "QuantizedBiasAdd",
    "Requantize",
    "QuantizedBatchNorm",
    "BiasAdd",
    "FusedBatchNorm",
}

# Post operators that should not be fused with main/element-wise ops.
npu_post_fuse_limited_ops = {"ConcatSliceWrite", "Sigmoid", "Tanh"}

# Operations that make a pass element-wise (see is_element_wise in build_pass).
elem_wise_ops = elem_wise_main_ops | activation_ops | {"Sigmoid", "Tanh"}
104
105
# Quantization operations; these are placed on the CPU (they are part of cpu_ops).
quantization_ops = {"Dequantize", "QuantizeV2", "Max", "Min"}
# Operations packed with the Cpu | Main flags.
cpu_ops = {"Softmax", "QuantizedSoftmax", "LRN", "Shape", "QuantizedPad", "Pad", "AddN"} | quantization_ops

# Set literals avoid the set(("DMA")) pitfall, which would yield {"D", "M", "A"}.
npu_dma_ops = {"DMA"}
# Operations collected into the per-subgraph startup pass.
startup_init_ops = {"Const", "VariableV2", "Placeholder", "SubgraphInput"}
# Operations packed with the MemoryOnly | Main flags.
memory_only_ops = {"Squeeze", "Reshape", "QuantizedReshape", "ExpandDims"}
115
116
# Packing rules, tried in order for every operation that build_pass() considers.
# Each rule is a tuple (ops_set, incompatible_pack_flags, flags_to_set, flags_to_clear):
#   ops_set                 - operation types the rule applies to; None matches any type
#   incompatible_pack_flags - the rule is skipped if the pass already has any of these flags
#   flags_to_set            - flags added to the pass when the operation is packed by this rule
#   flags_to_clear          - flags removed from the pass when the operation is packed by this rule
# Rules whose flags_to_set contain PassFlags.Npu are also skipped for operations whose
# run_on_npu is false. The first applicable rule packs the operation, so order matters.
test_sequence = [
    # NPU post-operations (activations, bias-adds, ...)
    (
        npu_post_ops,
        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.Pre | PassFlags.Main,
        PassFlags.Npu | PassFlags.Post,
        PassFlags.Empty,
    ),
    # NPU post-operations that should not be fused with main/element-wise operations
    (
        npu_post_fuse_limited_ops,
        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.Pre | PassFlags.Main,
        PassFlags.Npu | PassFlags.PostFusingLimited,
        PassFlags.Empty,
    ),
    # MAC main operations
    (
        mac_main_ops,
        PassFlags.Cpu
        | PassFlags.MemoryOnly
        | PassFlags.ElementWise
        | PassFlags.Pre
        | PassFlags.Main
        | PassFlags.PostFusingLimited,
        PassFlags.Npu | PassFlags.Mac | PassFlags.Main,
        PassFlags.Empty,
    ),
    # Element-wise main operations
    (
        elem_wise_main_ops,
        PassFlags.Cpu
        | PassFlags.MemoryOnly
        | PassFlags.Mac
        | PassFlags.Pre
        | PassFlags.Main
        | PassFlags.PostFusingLimited,
        PassFlags.Npu | PassFlags.ElementWise | PassFlags.Main,
        PassFlags.Empty,
    ),
    # NPU pre-operations
    (
        npu_pre_ops,
        PassFlags.Cpu | PassFlags.MemoryOnly,
        PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise,
        PassFlags.Empty,
    ),
    # DMA operations
    (
        npu_dma_ops,
        PassFlags.Cpu | PassFlags.MemoryOnly,
        PassFlags.Npu | PassFlags.Dma,
        PassFlags.Empty,
    ),
    # Startup initialisation operations
    (
        startup_init_ops,
        PassFlags.Npu | PassFlags.Cpu | PassFlags.MemoryOnly,
        PassFlags.StartupInit | PassFlags.Main,
        PassFlags.Empty,
    ),
    # Memory-only operations
    (
        memory_only_ops,
        PassFlags.Npu | PassFlags.Cpu,
        PassFlags.MemoryOnly | PassFlags.Main,
        PassFlags.Empty,
    ),
    # CPU operations
    (
        cpu_ops,
        PassFlags.Npu | PassFlags.MemoryOnly | PassFlags.Main,
        PassFlags.Cpu | PassFlags.Main,
        PassFlags.Empty,
    ),
    # Fallback for unrecognised operations: place them on the CPU
    (
        None,
        PassFlags.Npu | PassFlags.MemoryOnly | PassFlags.Main,
        PassFlags.Cpu | PassFlags.Main,
        PassFlags.Empty,
    ),
]
229
# Import-time sanity checks on the packing rules:
#  - a rule must never set and clear the same flag;
#  - every name in an ops_set must be longer than one character, which catches a
#    string literal having been decomposed into single characters.
for operation_set, incompatible_pack_flags, flags_to_set, flags_to_clear in test_sequence:
    assert not (flags_to_clear & flags_to_set)
    if operation_set is None:
        continue
    for op in operation_set:
        assert len(op) > 1
237
238
def pack_into_passes(nng, arch, verbose_packing=False):
    """Pack the operations of every subgraph in ``nng`` into Passes.

    Each subgraph is traversed backwards, starting from its output tensors. Operations are
    greedily grouped into Passes according to the rules in ``test_sequence``. The resulting
    passes are stored in ``sg.passes`` in execution order and linked with
    ``sg.build_pass_links()``. The startup-initialisation operations (``startup_init_ops``)
    of a subgraph are collected into one pass named "startup_weight_initialisation".

    Args:
        nng: The graph whose subgraphs are packed; it is modified in place.
        arch: Not used by this function.
        verbose_packing: If True, print the resulting passes via ``nng.print_passes()``.

    Returns:
        The same ``nng`` object.
    """

    def visit_op(op, ignored):
        """Count a visit to ``op`` and pack it once all of its outputs have been visited.

        ``ignored`` is the tensor through which ``op`` was reached; it is not used.
        """
        visit_op_refcount[op] += 1

        if visit_op_refcount[op] == 1:
            # First-time visit: output tensors with no consumers are counted as visited
            # up front, so the op is packed once all of its consumed outputs are visited.
            for tens in op.outputs:
                if len(tens.consumers()) == 0:
                    visit_op_refcount[op] += 1

        assert visit_op_refcount[op] <= len(op.outputs)
        if visit_op_refcount[op] == len(op.outputs):
            if op.type in startup_init_ops:
                # Startup ops are packed together into one pass after the traversal
                startup_list.append(op)
            else:
                _, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
                if ofm_tensor is None:
                    ofm_tensor = op.outputs[0]
                build_pass((op,), ofm_tensor)

    def build_pass(start_ops_to_process, ofm_tensor=None):
        """Greedily pack ``start_ops_to_process`` and compatible producer ops into a new Pass.

        Operations are processed breadth-first. Each one is matched against ``test_sequence``.
        When a rule applies and is compatible with the flags accumulated so far, the operation
        is added to the pass. The producer of each of its inputs is then queued for packing
        if all of that producer's outputs feed only this operation; otherwise the input tensor
        becomes a pass input. After the Pass is created and appended to ``reverse_pass_list``,
        its input tensors are visited (once per use) to continue the backwards traversal.

        Args:
            start_ops_to_process: Operations to start packing from.
            ofm_tensor: Output feature map recorded on the pass; may be None (e.g. startup pass).

        Returns:
            The new Pass.
        """
        reverse_ops_list = []
        curr_flags = PassFlags.Empty
        npu_block_type = NpuBlockType.Default

        reverse_intermediates = []
        input_set = set()
        ifm_tensor = None
        primary_op = None

        # Queue of (operation, tensor through which it was reached); start ops have no tensor
        to_process = collections.deque((start_op, None) for start_op in start_ops_to_process)

        while to_process:
            curr_op, tens = to_process.popleft()

            if curr_op in reverse_ops_list:
                continue

            for operation_set, incompatible_pack_flags, flags_to_set, flags_to_clear in test_sequence:
                # Skip rules that do not apply to this op or clash with the pass so far
                if operation_set is not None and curr_op.type not in operation_set:
                    continue
                if curr_flags & incompatible_pack_flags:
                    continue
                if flags_to_set & PassFlags.Npu and not curr_op.run_on_npu:
                    continue

                reverse_ops_list.append(curr_op)
                new_block_type = curr_op.attrs.get("npu_block_type", NpuBlockType.Default)
                if new_block_type != NpuBlockType.Default:
                    # Only one major block type (and so one primary op) per pass
                    assert npu_block_type == NpuBlockType.Default
                    npu_block_type = new_block_type
                    assert primary_op is None
                    primary_op = curr_op

                curr_flags &= ~flags_to_clear
                curr_flags |= flags_to_set

                if flags_to_set & PassFlags.Npu and flags_to_set & (
                    PassFlags.Mac | PassFlags.ElementWise | PassFlags.Post | PassFlags.PostFusingLimited
                ):
                    assert len(curr_op.inputs) >= 1
                    # The feature-map input of BlockLSTM is its fourth input
                    if curr_op.type == "BlockLSTM":
                        ifm_tensor = curr_op.inputs[3]
                    else:
                        ifm_tensor = curr_op.inputs[0]
                    assert ifm_tensor.purpose == TensorPurpose.FeatureMap

                if flags_to_set & PassFlags.Dma:
                    # DMAs are special - Output buffers need to be preserved as intermediates,
                    # if the pass consumes the results
                    if tens is not None:
                        reverse_intermediates.append(tens)

                if operation_set is None:
                    print("Warning:", curr_op.type, "operation is unknown or unsupported, placing on CPU")

                for inp in curr_op.inputs:
                    # The producer of inp may join this pass only if it is the sole producer
                    # and each of its outputs has at most one consumer, which is curr_op.
                    can_pack = True
                    if len(inp.ops) == 1:
                        next_op = inp.ops[0]
                        for outp in next_op.outputs:
                            consumers = outp.consumers()
                            if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op):
                                can_pack = False
                                break
                    else:
                        can_pack = False

                    if can_pack:
                        to_process.append((next_op, inp))
                    else:
                        assert inp is not None
                        input_set.add(inp)

                break
            else:
                # This operation is not compatible with already packed operations, just register the tensor as an input
                assert tens is not None
                input_set.add(tens)

        if (curr_flags & PassFlags.Npu) and not (curr_flags & (PassFlags.ElementWise | PassFlags.Mac)):
            # Make the choice that if we don't have a mac operation, the ambidextrous operations go on the
            # element wise unit
            curr_flags |= PassFlags.ElementWise

        is_element_wise = all(op.type in elem_wise_ops or op.type in npu_dma_ops for op in reverse_ops_list)

        # The asserts require exactly one placement flag to be set
        placement = PassPlacement.Unknown
        if curr_flags & PassFlags.Npu:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.Npu
        if curr_flags & PassFlags.Cpu:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.Cpu
        if curr_flags & PassFlags.MemoryOnly:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.MemoryOnly
        if curr_flags & PassFlags.StartupInit:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.StartupInit
        assert placement != PassPlacement.Unknown

        ops_list = list(reversed(reverse_ops_list))
        intermediates = list(reversed(reverse_intermediates))

        if primary_op is None:
            # No packed op declared a major block type; create_primary_op may insert a
            # 1x1 AvgPool at the front of ops_list (in place) to act as the primary op.
            primary_op = create_primary_op(ops_list)
            if primary_op is not None:
                # NOTE(review): the AvgPool was appended as an extra consumer of this tensor;
                # this bump appears to account for that extra consumer - confirm.
                visit_tensor_refcount[primary_op.inputs[0]] += 1
                npu_block_type = primary_op.attrs["npu_block_type"]
                input_set.update(primary_op.inputs)

        # Pass inputs in order of first use, together with their number of uses
        ordered_input_list = []
        input_refcounts = collections.defaultdict(int)
        for op in ops_list:
            for inp in op.inputs:
                if inp in input_set:
                    if input_refcounts[inp] == 0:
                        ordered_input_list.append(inp)
                    input_refcounts[inp] += 1

        # Name the pass after its first non-DMA op, falling back to its first op
        non_dma_ops = [op for op in ops_list if op.type != "DMA"]
        name = non_dma_ops[0].name if non_dma_ops else ops_list[0].name

        ps = Pass(name, placement, is_element_wise, npu_block_type)
        ps.ops = ops_list
        ps.primary_op = primary_op
        ps.inputs = ordered_input_list
        ps.intermediates = intermediates
        ps.outputs = list(ops_list[-1].outputs)

        if ps.primary_op and ps.primary_op.type in binary_elem_wise_main_ops:
            # ElementWise operation, 2 IFMs; with only 1 input, IFM and IFM2 are the same tensor
            ps.ifm_tensor = ps.inputs[0]
            ps.ifm2_tensor = ps.inputs[0] if len(ps.inputs) == 1 else ps.inputs[1]
        else:
            ps.ifm_tensor = ifm_tensor
            ps.ifm2_tensor = None

        ps.ofm_tensor = ofm_tensor
        assert ps.placement != PassPlacement.Npu or ps.ofm_tensor is not None
        ps.weight_tensor = ps.get_primary_op_ifm_weights()[1]
        ps.scale_tensor = ps.get_primary_op_ifm_weights_biases_ofm()[2]

        for op in ps.ops:
            op.scheduled_pass = ps

        reverse_pass_list.append(ps)

        # Visit every input once per use so its visit count can reach its consumer count
        for inp, refcount in input_refcounts.items():
            for _ in range(refcount):
                visit_tensor(inp)

        return ps

    def visit_tensor(tens):
        """Count a visit to ``tens``; when the count reaches its consumer count, visit its producers."""
        visit_tensor_refcount[tens] += 1
        assert visit_tensor_refcount[tens] <= len(tens.consumers())
        if visit_tensor_refcount[tens] == len(tens.consumers()):
            for op in reversed(tens.ops):
                visit_op(op, tens)

    def create_primary_op(ops_list):
        """Insert a 1x1 AvgPool in front of ``ops_list`` if it contains pre/post-operations.

        If any op in ``ops_list`` is in ``npu_pre_ops``, ``npu_post_ops`` or
        ``npu_post_fuse_limited_ops``, an AvgPool (1x1 kernel, stride 1, VALID padding) is
        created on the first input of ``ops_list[0]``. That input of ``ops_list[0]`` is
        replaced by the AvgPool's output, and the AvgPool is inserted at the front of
        ``ops_list`` (in place).

        Returns:
            The new AvgPool operation, or None if no op required one.
        """
        fusable_ops = npu_pre_ops | npu_post_ops | npu_post_fuse_limited_ops
        if not any(op.type in fusable_ops for op in ops_list):
            return None

        # Configure a 1x1 AvgPool and attach the op onto it
        op = ops_list[0]
        inp = op.inputs[0]
        avgpool_op = Operation("AvgPool", op.name + "_avgpool")
        avgpool_op.inputs = [inp]
        # NOTE(review): op is not removed from inp.consumer_list here
        avgpool_op.inputs[0].consumer_list.append(avgpool_op)
        avgpool_op.attrs["padding"] = b"VALID"
        avgpool_op.attrs["npu_block_type"] = NpuBlockType.Pooling
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]
        avgpool_op.attrs["skirt"] = [0, 0, 0, 0]
        avgpool_op.attrs["explicit_padding"] = [0, 0, 0, 0]
        avgpool_out = inp.clone("_avgpooled")
        avgpool_out.consumer_list.append(op)
        avgpool_out.ops = [avgpool_op]
        avgpool_op.outputs = [avgpool_out]

        op.inputs[0] = avgpool_out
        ops_list.insert(0, avgpool_op)

        return avgpool_op

    for sg in nng.subgraphs:
        # Per-subgraph traversal state, read by the nested functions above
        reverse_pass_list = []
        visit_op_refcount = collections.defaultdict(int)
        visit_tensor_refcount = collections.defaultdict(int)
        startup_list = []

        for tens in sg.output_tensors:
            visit_tensor(tens)

        if startup_list:
            startup_ps = build_pass(startup_list)
            # build_pass records only the last op's outputs; use the first output of every startup op
            startup_ps.outputs = [op.outputs[0] for op in startup_list]
            startup_ps.name = "startup_weight_initialisation"

        # Passes were appended while walking backwards from the outputs
        sg.passes = list(reversed(reverse_pass_list))
        sg.build_pass_links()

    if verbose_packing:
        nng.print_passes()

    return nng
489 return nng