blob: 5673c2df39a6b323f101999e0b90b6fc0976b9b2 [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Packs a subgraph with Neural Network Operations into Passes. Each Pass has one or more Operations.
Diego Russoea6111a2020-04-14 18:41:58 +010018import collections
Diego Russoe8a10452020-04-21 17:39:10 +010019import enum
Diego Russoea6111a2020-04-14 18:41:58 +010020
Diego Russoe8a10452020-04-21 17:39:10 +010021from .nn_graph import Pass
22from .nn_graph import PassPlacement
Michael McGeagh8dbf8cf2020-09-08 11:09:48 +010023from .operation import create_avgpool_nop
Diego Russoe8a10452020-04-21 17:39:10 +010024from .operation import NpuBlockType
Louis Verhaardaee5d752020-09-30 09:01:52 +020025from .operation import Op
Diego Russoea6111a2020-04-14 18:41:58 +010026from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010027
28
class PassFlags(enum.Flag):
    """Property flags of a pass under construction.

    test_sequence (below) uses these both to describe what an operation
    contributes to a pass (flags_to_set) and to decide whether an operation
    is incompatible with the pass built so far (incompatible_pack_flags).
    """

    Empty = 0  # no properties; also used as the "set/clear nothing" value
    Pre = 1  # pass contains an NPU pre-operation (e.g. SplitSliceRead)
    Main = 2  # pass contains a main operation
    Post = 4  # pass contains a fusable post-operation (activation)
    Mac = 8  # pass contains a MAC-engine operation
    Dma = 32  # pass contains a DMA operation
    ElementWise = 256  # pass contains an element-wise operation
    Npu = 512  # pass is placed on the NPU
    Cpu = 1024  # pass falls back to the CPU
    StartupInit = 2048  # startup initialisation (constants/placeholders)
    MemoryOnly = 4096  # memory-only operation (e.g. reshape/squeeze)
    PostFusingLimited = 8192  # post-op that must not fuse with main/elementwise ops
    # NOTE(review): bit values 16, 64 and 128 are unused here, presumably
    # left over from removed flags - confirm before reusing those bits.
43
# Operation-type sets used by test_sequence (below) to classify operations
# when packing them into passes.

# NPU pre-operations: run before the main operation within a pass
npu_pre_ops = {Op.SplitSliceRead}

# Operations that run as the main op of a pass on the NPU MAC engine:
# convolutions, depth-wise convolutions, FC/matmul layers, RNN/LSTM/GRU,
# pooling, reduction and deconvolution (resize)
mac_main_ops = {
    Op.Conv2DBias,
    Op.Conv2D,
    Op.QuantizedConv2D,
    Op.Conv2DBackpropInputSwitchedBias,
    Op.DepthwiseConv2DBias,
    Op.QuantizedMatMul,
    Op.MatMul,
    Op.FullyConnected,
    Op.BlockLSTM,
    Op.QuantizedMaxPool,
    Op.QuantizedAvgPool,
    Op.AvgPool,
    Op.MaxPool,
    Op.ReduceSum,
    Op.ResizeBilinear,
}

# Element-wise operations, split by arity
binary_elem_wise_main_ops = Op.op_set(Op.is_binary_elementwise_op)
unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops

# Post-operations (activations) that may be fused onto a main operation
activation_ops = Op.op_set(Op.is_relu_op)
npu_post_ops = activation_ops

# Set of post operators that should not be fused with main/elementwise ops
npu_post_fuse_limited_ops = {Op.ConcatSliceWrite, Op.Sigmoid, Op.Tanh, Op.Quantize}

# All operations the element-wise unit can execute
elem_wise_ops = elem_wise_main_ops | activation_ops | {Op.Sigmoid, Op.Tanh}

# Operations that have no NPU implementation and fall back to the CPU
quantization_ops = {Op.Dequantize, Op.Max, Op.Min}
cpu_ops = {Op.Softmax, Op.LRN, Op.Shape, Op.Pad, Op.AddN} | quantization_ops

npu_dma_ops = {Op.DMA}
startup_init_ops = {Op.Const, Op.Placeholder, Op.SubgraphInput}
memory_only_ops = {Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims}
Tim Hall79d07d22020-04-27 18:20:16 +010095
96
# Packing rule table, evaluated in order for every operation encountered during
# the backwards traversal in build_pass(); the FIRST matching entry is applied,
# so entry order matters (the final entry with ops_set=None is the catch-all
# CPU fallback for unrecognised operations). Each entry is a 4-tuple:
#   (ops_set, incompatible_pack_flags, flags_to_set, flags_to_clear)
#   - ops_set: operation types the entry applies to (None matches any type)
#   - incompatible_pack_flags: if the pass already carries any of these flags,
#     the op cannot join it and its tensor becomes a pass input instead
#   - flags_to_set / flags_to_clear: PassFlags updates applied when packing
test_sequence = [
    (
        # ops_set
        npu_post_ops,
        # incompatible_pack_flags
        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.Pre | PassFlags.Main,
        # flags_to_set
        PassFlags.Npu | PassFlags.Post,
        # flags_to_clear
        PassFlags.Empty,
    ),
    (
        # ops_set
        npu_post_fuse_limited_ops,
        # incompatible_pack_flags
        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.Pre | PassFlags.Main,
        # flags_to_set
        PassFlags.Npu | PassFlags.PostFusingLimited,
        # flags_to_clear
        PassFlags.Empty,
    ),
    (
        # ops_set
        mac_main_ops,
        # incompatible_pack_flags
        PassFlags.Cpu
        | PassFlags.MemoryOnly
        | PassFlags.ElementWise
        | PassFlags.Pre
        | PassFlags.Main
        | PassFlags.PostFusingLimited,
        # flags_to_set
        PassFlags.Npu | PassFlags.Mac | PassFlags.Main,
        # flags_to_clear
        PassFlags.Empty,
    ),
    (
        # ops_set
        elem_wise_main_ops,
        # incompatible_pack_flags
        PassFlags.Cpu
        | PassFlags.MemoryOnly
        | PassFlags.Mac
        | PassFlags.Pre
        | PassFlags.Main
        | PassFlags.PostFusingLimited,
        # flags_to_set
        PassFlags.Npu | PassFlags.ElementWise | PassFlags.Main,
        # flags_to_clear
        PassFlags.Empty,
    ),
    (
        # ops_set
        npu_pre_ops,
        # incompatible_pack_flags
        PassFlags.Cpu | PassFlags.MemoryOnly,
        # flags_to_set
        PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise,
        # flags_to_clear
        PassFlags.Empty,
    ),
    (
        # ops_set
        npu_dma_ops,
        # incompatible_pack_flags
        PassFlags.Cpu | PassFlags.MemoryOnly,
        # flags_to_set
        PassFlags.Npu | PassFlags.Dma,
        # flags_to_clear
        PassFlags.Empty,
    ),
    (
        # ops_set
        startup_init_ops,
        # incompatible_pack_flags
        PassFlags.Npu | PassFlags.Cpu | PassFlags.MemoryOnly,
        # flags_to_set
        PassFlags.StartupInit | PassFlags.Main,
        # flags_to_clear
        PassFlags.Empty,
    ),
    (
        # ops_set
        memory_only_ops,
        # incompatible_pack_flags
        PassFlags.Npu | PassFlags.Cpu,
        # flags_to_set
        PassFlags.MemoryOnly | PassFlags.Main,
        # flags_to_clear
        PassFlags.Empty,
    ),
    (
        # ops_set
        cpu_ops,
        # incompatible_pack_flags
        PassFlags.Npu | PassFlags.MemoryOnly | PassFlags.Main,
        # flags_to_set
        PassFlags.Cpu | PassFlags.Main,
        # flags_to_clear
        PassFlags.Empty,
    ),
    (  # This last one is a fallback for unrecognised operations
        # ops_set
        None,
        # incompatible_pack_flags
        PassFlags.Npu | PassFlags.MemoryOnly | PassFlags.Main,
        # flags_to_set
        PassFlags.Cpu | PassFlags.Main,
        # flags_to_clear
        PassFlags.Empty,
    ),
]

# Some sanity checking: a rule must never both set and clear the same flag
for (operation_set, incompatible_pack_flags, flags_to_set, flags_to_clear) in test_sequence:
    assert not flags_to_clear & flags_to_set
213
Tim Hall79d07d22020-04-27 18:20:16 +0100214
def pack_into_passes(nng, arch, verbose_packing=False):
    """Pack each subgraph of nng into Passes of one or more operations.

    Walks every subgraph backwards from its output tensors and greedily fuses
    chains of compatible operations (as decided by test_sequence) into Pass
    objects. On return, each subgraph has sg.passes populated in execution
    order and its pass links built; nng is returned for chaining.

    Args:
        nng: neural network graph; its subgraphs are modified in place.
        arch: architecture features (not used directly in this function).
        verbose_packing: when True, print the resulting passes.
    """

    def visit_op(op, ignored):
        # Called once per fully-visited output tensor; `ignored` is the tensor
        # through which the op was reached (unused here).
        visit_op_refcount[op] += 1

        if visit_op_refcount[op] == 1:  # First-time visit, go and fix up unused output tensors
            # Outputs without consumers would never be visited via
            # visit_tensor, so count them as visited up front.
            for tens in op.outputs:
                if len(tens.consumers()) == 0:
                    visit_op_refcount[op] += 1

        assert visit_op_refcount[op] <= len(op.outputs)
        if visit_op_refcount[op] == len(op.outputs):
            # All outputs accounted for - the op is ready to be packed
            if op.type in startup_init_ops:
                # Constants/placeholders are collected into a single startup pass
                startup_list.append(op)
            else:
                ofm_tensor = op.ofm
                if ofm_tensor is None:
                    ofm_tensor = op.outputs[0]
                build_pass((op,), ofm_tensor)

    def build_pass(start_ops_to_process, ofm_tensor=None):
        # Build one Pass starting from start_ops_to_process, pulling in as
        # many upstream producer ops as test_sequence allows; returns the Pass.
        reverse_ops_list = []  # packed ops, in reverse execution order
        curr_flags = PassFlags.Empty  # accumulated flags of the packed ops
        npu_block_type = NpuBlockType.Default

        reverse_intermediates = []  # DMA result tensors consumed inside the pass
        input_set = set()  # tensors entering the pass from outside
        ifm_tensor = None
        primary_op = None

        to_process = collections.deque()
        for start_op in start_ops_to_process:
            # Work items are (op, tensor through which the op was reached)
            to_process.append((start_op, None))

        while to_process:
            curr_op, tens = to_process.popleft()

            if curr_op in reverse_ops_list:
                continue

            # Find the first packing rule that accepts curr_op into this pass
            for operation_set, incompatible_pack_flags, flags_to_set, flags_to_clear in test_sequence:
                if operation_set is None or curr_op.type in operation_set:
                    if not (curr_flags & incompatible_pack_flags):
                        if flags_to_set & PassFlags.Npu:
                            if not curr_op.run_on_npu:
                                # Not runnable on the NPU: try the remaining
                                # rules (ends at the CPU fallback entry)
                                continue

                        reverse_ops_list.append(curr_op)
                        new_block_type = curr_op.type.npu_block_type
                        if new_block_type != NpuBlockType.Default:
                            assert npu_block_type == NpuBlockType.Default
                            npu_block_type = new_block_type  # Only one major block type per pass
                            assert primary_op is None
                            primary_op = curr_op

                        curr_flags &= ~flags_to_clear
                        curr_flags |= flags_to_set

                        if flags_to_set & PassFlags.Npu:
                            if flags_to_set & (
                                PassFlags.Mac | PassFlags.ElementWise | PassFlags.Post | PassFlags.PostFusingLimited
                            ):
                                assert len(curr_op.inputs) >= 1
                                ifm_tensor = curr_op.ifm
                                assert ifm_tensor is not None, "IFM missing in {}".format(curr_op)
                                assert ifm_tensor.purpose == TensorPurpose.FeatureMap

                        if flags_to_set & PassFlags.Dma:
                            # DMAs are special - Output buffers need to be preserved as intermediates,
                            # if the pass consumes the results
                            if tens is not None:
                                reverse_intermediates.append(tens)

                        if operation_set is None:
                            print("Warning:", curr_op.type, "operation is unknown or unsupported, placing on CPU")

                        # Try to pull each producer of curr_op's inputs into the
                        # same pass; only possible when curr_op is the sole
                        # consumer of everything that producer outputs.
                        for inp in reversed(curr_op.inputs):
                            if inp is None:
                                continue
                            can_pack = True
                            if len(inp.ops) == 1:
                                next_op = inp.ops[0]
                                for outp in next_op.outputs:
                                    consumers = outp.consumers()
                                    if len(consumers) > 1 or (len(consumers) == 1 and consumers[0] != curr_op):
                                        can_pack = False
                                        break
                            else:
                                # Zero or several producer ops: cannot fuse
                                can_pack = False

                            if can_pack:
                                to_process.append((next_op, inp))
                            else:
                                assert inp is not None  # redundant: None inputs were skipped above
                                input_set.add(inp)

                        break

            else:
                # This operation is not compatible with already packed operations, just register the tensor as an input
                assert tens is not None
                input_set.add(tens)

        if curr_flags & PassFlags.Npu and not curr_flags & (PassFlags.ElementWise | PassFlags.Mac):
            # Make the choice that if we don't have a mac operation, the ambidextrous operations go on the
            # element wise unit
            curr_flags |= PassFlags.ElementWise

        # The pass is element-wise only if every packed op is element-wise/DMA
        is_element_wise = True
        for op in reverse_ops_list:
            if op.type not in elem_wise_ops and op.type not in npu_dma_ops:
                is_element_wise = False
                break

        # Exactly one placement flag must have been accumulated
        placement = PassPlacement.Unknown
        if curr_flags & PassFlags.Npu:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.Npu
        if curr_flags & PassFlags.Cpu:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.Cpu
        if curr_flags & PassFlags.MemoryOnly:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.MemoryOnly
        if curr_flags & PassFlags.StartupInit:
            assert placement == PassPlacement.Unknown
            placement = PassPlacement.StartupInit
        assert placement != PassPlacement.Unknown

        ops_list = list(reversed(reverse_ops_list))  # now in execution order
        intermediates = list(reversed(reverse_intermediates))

        if primary_op is None:
            # No main op was packed; NPU passes of only pre/post ops get a
            # 1x1 AvgPool no-op inserted as primary op
            primary_op = create_primary_op(ops_list)
            if primary_op is not None:
                # The avgpool nop adds an extra consumer to its input tensor,
                # so bump that tensor's visit refcount to match
                visit_tensor_refcount[primary_op.inputs[0]] += 1
                npu_block_type = primary_op.type.npu_block_type
                for input_tens in primary_op.inputs:
                    if input_tens not in input_set:
                        input_set.add(input_tens)

        ordered_input_list = []
        # Keep LUT-s in a separate list and add as inputs at the end
        # to avoid that they would accidentally be assigned as ifm or ifm2
        lut_list = []
        input_refcounts = collections.defaultdict(int)
        input_ops_list = ops_list.copy()

        # Check primary_op first
        if primary_op is not None:
            for inp in primary_op.inputs:
                if inp is None:
                    continue
                if len(inp.ops) == 1 and inp.ops[0].type == Op.DMA and inp.purpose == TensorPurpose.FeatureMap:
                    # Feature map DMA'd within this pass: use the DMA source
                    # tensor as the pass input instead
                    src_op = inp.ops[0]
                    if src_op in input_ops_list:
                        inp = src_op.inputs[0]
                        input_ops_list.remove(src_op)
                add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list)
            input_ops_list.remove(primary_op)

        # Check rest of the list
        for op in input_ops_list:
            for inp in op.inputs:
                add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list)

        # Name the pass after its first non-DMA op, if there is one
        name = ops_list[0].name
        non_dma_ops = [op for op in ops_list if op.type != Op.DMA]
        if non_dma_ops:
            name = non_dma_ops[0].name
        ps = Pass(name, placement, is_element_wise, npu_block_type)
        ps.ops = ops_list
        ps.primary_op = primary_op
        ps.inputs = ordered_input_list
        ps.intermediates = intermediates
        ps.outputs = list(ops_list[-1].outputs)

        # ElementWise operation, 2 IFMs
        if ps.primary_op and ps.primary_op.type in binary_elem_wise_main_ops:
            ps.ifm_tensor = ps.inputs[0]
            ps.ifm2_tensor = ps.inputs[-1]

            if len(ps.inputs) > 2:
                # NOTE(review): with more than two inputs the last two are
                # taken as IFM/IFM2, presumably because extra inputs precede
                # the feature maps in the list - confirm against callers
                ps.ifm_tensor = ps.inputs[-2]
        else:
            ps.ifm_tensor = ifm_tensor
            ps.ifm2_tensor = None

        ps.ofm_tensor = ofm_tensor
        assert ps.placement != PassPlacement.Npu or ps.ofm_tensor is not None
        ps.weight_tensor = ps.get_primary_op_ifm_weights()[1]
        ps.scale_tensor = ps.get_primary_op_ifm_weights_biases_ofm()[2]
        ps.lut_tensor = ps.get_primary_op_lut()
        ps.inputs.extend(lut_list)  # LUTs go last (see note above)

        for op in ps.ops:
            op.scheduled_pass = ps

        reverse_pass_list.append(ps)

        # Continue the backwards traversal through this pass's inputs
        for inp, refcount in input_refcounts.items():
            for _ in range(refcount):
                visit_tensor(inp)

        return ps

    def visit_tensor(tens):
        # Visit the tensor's producer ops once all its consumers have been seen
        visit_tensor_refcount[tens] += 1
        assert visit_tensor_refcount[tens] <= len(tens.consumers())
        if visit_tensor_refcount[tens] == len(tens.consumers()):
            for op in reversed(tens.ops):
                visit_op(op, tens)

    def create_primary_op(op_list):
        # For a pass made up only of NPU pre/post(-fuse-limited) ops, insert a
        # 1x1 AvgPool no-op at the front to serve as the primary operation.
        # Returns the new op, or None if no primary op is needed.
        if any(op.type in (npu_pre_ops | npu_post_ops | npu_post_fuse_limited_ops) and op.run_on_npu for op in op_list):
            # Configure a 1x1 AvgPool and attach the op onto it
            op = op_list[0]
            inp = op.inputs[0]

            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(inp)
            avgpool_out = inp.clone("_avgpooled")
            avgpool_out.consumer_list.append(op)
            avgpool_op.set_output_tensor(avgpool_out)

            op.inputs[0] = avgpool_out
            op_list.insert(0, avgpool_op)

            return avgpool_op

        return None

    def add_input_list(inp_to_add, inp_set, inp_refcnts, lut_list, ordered_inp_list):
        # Append inp_to_add to ordered_inp_list (or lut_list for LUT tensors)
        # the first time it is seen; always count the reference.
        if inp_to_add in inp_set:
            if inp_refcnts[inp_to_add] == 0:
                if inp_to_add.purpose == TensorPurpose.LUT:
                    lut_list.append(inp_to_add)
                else:
                    ordered_inp_list.append(inp_to_add)
            inp_refcnts[inp_to_add] += 1

    for sg in nng.subgraphs:
        # Per-subgraph traversal state, shared by the nested helpers above
        reverse_pass_list = []
        visit_op_refcount = collections.defaultdict(int)
        visit_tensor_refcount = collections.defaultdict(int)

        startup_list = []

        # Walk backwards from the subgraph outputs
        for tens in sg.output_tensors:
            visit_tensor(tens)

        if startup_list:
            startup_ps = build_pass(startup_list)
            startup_ps.outputs = [op.outputs[0] for op in startup_list]  # Need to fixup the outputs
            startup_ps.name = "startup_weight_initialisation"

        sg.passes = list(reversed(reverse_pass_list))
        sg.build_pass_links()

    if verbose_packing:
        nng.print_passes()

    return nng