# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Neural network graph classes and enums.
# Pass - A packed pass containing one or more Operations.
# CascadedPass - A scheduled pass containing one or more Passes, as well as a scheduling strategy and block
#                configurations.
# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
# Graph - A full neural network graph with one or more Subgraphs.
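#
# The containment hierarchy, as a rough sketch (attribute names as defined below):
#
#   Graph.subgraphs          -> [Subgraph]
#   Subgraph.cascaded_passes -> [CascadedPass]
#   CascadedPass.passes      -> [Pass]
#   Pass.ops                 -> [Operation]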
import enum
from typing import List

from .operation import Op
from .shape4d import Shape4D


class PassPlacement(enum.Enum):
    Unknown = 0
    Cpu = 1
    Npu = 2
    MemoryOnly = 3
    StartupInit = 4


class TensorAllocator(enum.Enum):
    LinearAlloc = 1
    Greedy = 2
    HillClimb = 3

    def __str__(self):
        return self.name


class Pass:
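    """A packed pass containing one or more Operations that are executed together."""
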
    def __init__(self, name, placement, is_element_wise, npu_block_type):
        self.inputs = []
        self.intermediates = []
        self.outputs = []
        self.ops = []
        self.primary_op = None
        self.ifm_tensor = None
        self.ifm2_tensor = None
        self.ofm_tensor = None
        self.weight_tensor = None
        self.scale_tensor = None
        self.lut_tensor = None
        self.name = name
        self.cascade = None
        self.placement = placement
        self.ifm_shapes: List[Shape4D] = []
        self.ofm_shapes: List[Shape4D] = []

        # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
        # allocation and requires that the OFM and IFM have exactly the same address, i.e. complete overlap.
        self.is_element_wise = is_element_wise
        self.npu_block_type = npu_block_type
        self.block_config = None  # will be filled in by scheduler
        self.shared_buffer = None  # will be filled in by scheduler

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.Pass '%s', %s, ops=%s>" % (self.name, self.placement, [op.type for op in self.ops])

    __repr__ = __str__

    def get_primary_op_ifm_weights(self):
        if not self.primary_op:
            return None, None
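        # get_ifm_ifm2_weights_ofm() returns (ifm, ifm2, weights, ofm); the [::2] stride keeps (ifm, weights)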
        return self.primary_op.get_ifm_ifm2_weights_ofm()[::2]

    def get_primary_op_ifm_ifm2_weights_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()

    def get_primary_op_ifm_weights_biases_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_weights_biases_ofm()

    def get_primary_op_lut(self):
        if not self.primary_op:
            return None
        return self.primary_op.activation_lut


class SchedulingStrategy(enum.Enum):
    Unknown = -1
    IfmStream = 0
    WeightStream = 1


class SchedulerRewrite(enum.Enum):
    Nop = 0
    ChangeTensorSubPurpose = 1


class CascadedPass:
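    """A scheduled pass containing one or more Passes, plus a scheduling strategy and block configurations."""
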
    def __init__(self, name, strat, inputs, intermediates, outputs, passes, placement, is_element_wise):
        self.name = name
        self.strategy = strat
        self.inputs = inputs
        self.intermediates = intermediates
        self.outputs = outputs
        self.passes = passes
        self.placement = placement
        self.is_element_wise = is_element_wise

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.CascadedPass strategy=%s x %s '%s', passes=%s, block_configs=%s>" % (
            self.strategy,
            len(self.passes),
            self.name,
            [ps.name for ps in self.passes],
            [ps.block_config for ps in self.passes],
        )

    __repr__ = __str__


class Subgraph:
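    """A neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses."""
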
    def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
        self.output_tensors = []
        self.input_tensors = []
        self.original_inputs = []  # Preserve the original input order
        self.passes = []
        self.cascaded_passes = []
        self.name = name
        self.high_level_command_stream = []
        self.placement = placement
        self.command_stream_tensor = None
        self.flash_tensor = None
        # Scratch information locally used in the scheduler
        self.scheduling_info = {}
        self.generated_stream_id = None

        self.memory_used = {}
        self.memory_used_per_type = {}
        self.min_mem_usage = 0

    def __str__(self):
        return "<nng.Subgraph '%s', n_passes=%d, n_cascaded_passes=%d>" % (
            self.name,
            len(self.passes),
            len(self.cascaded_passes),
        )

    __repr__ = __str__

    def update_consumers(self):
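        """Rebuild every tensor's consumer_list by walking the graph from the output tensors.

        Also repopulates self.input_tensors from Placeholder/SubgraphInput ops; a None entry in a
        consumer list marks a tensor that the graph itself consumes as an output.
        """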
        visit_op_set = set()
        visit_tensor_set = set()
        self.input_tensors = []

        print_visit = False

        def visit_op(op):
            if op in visit_op_set:
                return

            visit_op_set.add(op)
            for inp in op.inputs:
                if not inp:
                    continue
                if print_visit:
                    print(inp, "adding consumer", op)
                visit_tensor(inp)
                inp.consumer_list.append(op)

            if op.type in (Op.Placeholder, Op.SubgraphInput):
                assert len(op.outputs) == 1
                self.input_tensors.append(op.outputs[0])

            for out in op.outputs:
                if out not in visit_tensor_set:
                    out.consumer_list = []  # reset unvisited output, just in case

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            tens.consumer_list = []
            for op in tens.ops:
                visit_op(op)

        for ps in self.passes:
            for tens in ps.outputs + ps.inputs:
                if not tens:
                    continue
                tens.consumer_list = []  # reset unvisited tensors to start with

        for tens in self.output_tensors:
            visit_tensor(tens)
            tens.consumer_list.append(None)  # special op to indicate that the graph consumes the result

        print_visit = True
        for ps in self.passes:
            for op in ps.ops:
                visit_op(op)
            for tens in ps.inputs:
                visit_tensor(tens)

    def build_pass_links(self):
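        """Assign each pass an even timestamp (time = 2 * idx) and rebuild pass
        predecessor/successor links from tensor producer/consumer relationships."""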
        for idx, ps in enumerate(self.passes):
            ps.time = 2 * idx
            ps.predecessors = []
            ps.successors = []

        for ps in self.passes:
            for tens in ps.inputs:
                for op in tens.ops:
                    pred_pass = op.scheduled_pass
                    assert pred_pass.time < ps.time
                    if ps not in pred_pass.successors:
                        pred_pass.successors.append(ps)

                    if pred_pass not in ps.predecessors:
                        ps.predecessors.append(pred_pass)

                    assert tens in pred_pass.outputs

    def build_pass_dag_predecessors(self):
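        """Build an acyclic dag_predecessors list for every pass.

        Depth-first traversal from the sink passes with three-state marking
        (NotVisited/BeingVisited/Visited); reaching a pass that is BeingVisited
        means the edge closes a cycle, so that predecessor link is dropped.
        """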
        for ps in self.passes:
            ps.dag_predecessors = []

        class State(enum.Enum):
            NotVisited = 0
            BeingVisited = 1
            Visited = 2

        pass_visit_dict = {}

        def visit_pass(ps):
            state = pass_visit_dict.get(ps, State.NotVisited)
            if state == State.Visited:
                return True
            elif state == State.BeingVisited:
                return False  # this is a loop, need to remove this link
            elif state == State.NotVisited:
                pass_visit_dict[ps] = State.BeingVisited

                ps.dag_predecessors = []
                for pred in ps.predecessors:
                    if visit_pass(pred):
                        ps.dag_predecessors.append(pred)

                pass_visit_dict[ps] = State.Visited
                return True

        for ps in self.passes:
            if not ps.successors:
                visit_pass(ps)

    def build_cascaded_pass_links(self):
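        """Rebuild CascadedPass predecessor/successor links from tensor producers."""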
        for cps in self.cascaded_passes:
            cps.predecessors = []
            cps.successors = []

        for cps in self.cascaded_passes:
            for tens in cps.inputs:
                for op in tens.ops:
                    pred_cpass = op.scheduled_pass.cascade
                    if cps not in pred_cpass.successors:
                        pred_cpass.successors.append(cps)

                    if pred_cpass not in cps.predecessors:
                        cps.predecessors.append(pred_cpass)

                    assert tens in pred_cpass.outputs

    def refresh_after_modification(self):
        self.update_consumers()

    def prune_startup_init_pass(self):
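        """Drop outputs of the startup-init pass that have no consumers, along with the ops producing them."""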
        assert len(self.passes) >= 1
        ps = self.passes[0]
        assert ps.placement == PassPlacement.StartupInit

        ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
        ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]

    def get_all_ops(self):
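        """Return all ops reachable from the subgraph's output tensors, producers before consumers."""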
        all_ops = []
        visit_op_set = set()
        visit_tensor_set = set()

        def visit_op(op):
            if op in visit_op_set:
                return
            visit_op_set.add(op)
            for inp in op.inputs:
                visit_tensor(inp)

            all_ops.append(op)

        def visit_tensor(tens):
            if tens is None or tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            for op in tens.ops:
                visit_op(op)

        for tens in self.output_tensors:
            visit_tensor(tens)

        return all_ops

    def print_operators(self):
        print("print_operators()", self.name)
        all_ops = self.get_all_ops()
        unique_ops = []
        for op in all_ops:
            if op.type in (Op.Const, Op.Identity, Op.Placeholder):
                continue

            attrs = op.attrs.copy()
            if op.type in (Op.Conv2D, Op.Conv2DBias, Op.DepthwiseConv2DBias):
                kshape = op.inputs[1].shape
                attrs["kshape"] = [kshape[0], kshape[1]]
            attrs["type"] = op.type.name
            attrs.pop("use_cudnn_on_gpu", None)
            custom_options = attrs.pop("custom_options", None)
            if attrs not in unique_ops:
                unique_ops.append(attrs)
                # print attributes in human readable format
                a = attrs.copy()
                if custom_options is not None:
                    a["custom_options"] = custom_options
                s = a.pop("type")
                data_format = a.pop("data_format", None)
                if data_format and data_format != b"NHWC":
                    s += " " + str(data_format)
                t = a.pop("T", None)
                if t:
                    s += " " + str(t)[9:-2]
                srct = a.pop("SrcT", None)
                if srct:
                    s += " " + str(srct)[9:-2]
                dstt = a.pop("DstT", None)
                if dstt:
                    s += "->" + str(dstt)[9:-2]
                print(s + " " + str(a))

    def print_graph(self, label=None):
        if label:
            print(f"\n[ {label} ]")
        print("print_graph()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)

    def print_graph_with_tensors(self):
        print("print_graph_with_tensors()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                print(
                    "    Input  %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            for idx, tens in enumerate(op.outputs):
                print(
                    "    Output %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            print()

    def print_graph_with_tensor_quantization(self):
        print("print_graph_with_tensor_quantization()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                q = tens.quantization
                if q is None:
                    print("    Input  %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "    Input  %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            for idx, tens in enumerate(op.outputs):
                q = tens.quantization
                if q is None:
                    print("    Output %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "    Output %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            print()

    def print_passes(self):
        print("print_passes()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%03d %s" % (idx * 2, ps))

    def print_passes_with_tensors(self):
        print("print_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%3d %s" % (idx * 2, ps))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes(self):
        print("print_cascaded_passes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))

    def print_cascaded_passes_with_tensors(self):
        print("print_cascaded_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes_with_tensor_sizes(self):
        print("print_cascaded_passes_with_tensor_sizes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            print()

    def print_high_level_command_stream(self):
        print("print_high_level_command_stream()", self.name)
        for idx, cmd in enumerate(self.high_level_command_stream):
            print("%3d %s" % (idx, cmd))


class Graph:
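    """A full neural network graph with one or more Subgraphs; subgraphs[0] is the root subgraph."""
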
    def __init__(self, name="<unnamed>", batch_size=1):
        self.name = name
        self.batch_size = batch_size
        self.subgraphs = []
        self.metadata = []
        self.memory_used = {}
        self.total_original_weights = 0
        self.total_npu_weights = 0
        self.total_npu_encoded_weights = 0
        self.weight_cache = None  # See CompressedWeightCache

    def get_root_subgraph(self):
        return self.subgraphs[0]

    def prune_startup_init_pass(self):
        for sg in self.subgraphs:
            sg.prune_startup_init_pass()

    def update_consumers(self):
        for sg in self.subgraphs:
            sg.update_consumers()

    def refresh_after_modification(self):
        for sg in self.subgraphs:
            sg.refresh_after_modification()

    def print_operators(self):
        for sg in self.subgraphs:
            sg.print_operators()

    def print_graph(self, label=None):
        for sg in self.subgraphs:
            sg.print_graph(label)

    def print_graph_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensors()

    def print_graph_with_tensor_quantization(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensor_quantization()

    def print_passes(self):
        for sg in self.subgraphs:
            sg.print_passes()

    def print_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_passes_with_tensors()

    def print_cascaded_passes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes()

    def print_cascaded_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensors()

    def print_cascaded_passes_with_tensor_sizes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensor_sizes()

    def print_high_level_command_stream(self):
        for sg in self.subgraphs:
            sg.print_high_level_command_stream()
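
# A minimal usage sketch (hypothetical; Tensor/Operation construction and scheduling elided):
#
#   nng = Graph("my_network", batch_size=1)
#   sg = Subgraph("main", PassPlacement.Npu)
#   nng.subgraphs.append(sg)
#   ...frontend populates sg.output_tensors and the scheduler fills sg.passes...
#   nng.refresh_after_modification()  # rebuild tensor consumer lists after graph edits
#   nng.print_passes()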