# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Neural network graph classes and enums.
# Pass - A packed pass containing one or more Operations.
# CascadedPass - A scheduled pass containing one or more Passes, as well as a scheduling strategy and block
#                configurations.
# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
# Graph - A full neural network graph with one or more Subgraphs.
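#
# A rough sketch of how these classes nest when walking a network (illustrative only; the
# variable names below, e.g. nng and sg, are hypothetical and not part of this module):
#
#   nng = Graph("example")            # one Graph per network
#   sg = nng.get_root_subgraph()      # Graph -> Subgraph
#   for ps in sg.passes:              # Subgraph -> Pass -> Operation
#       for op in ps.ops:
#           ...
#   for cps in sg.cascaded_passes:    # Subgraph -> CascadedPass -> Pass
#       ...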
import enum
from typing import List

from .operation import Op
from .shape4d import Shape4D


class PassPlacement(enum.Enum):
    Unknown = 0
    Cpu = 1
    Npu = 2
    MemoryOnly = 3
    StartupInit = 4


class TensorAllocator(enum.Enum):
    LinearAlloc = 1
    Greedy = 2
    HillClimb = 3

    def __str__(self):
        return self.name


class NetworkType(enum.Enum):
    TFLite = 1
    TOSA = 2


class Pass:
    def __init__(self, name, placement, is_element_wise, npu_block_type):
        self.inputs = []
        self.intermediates = []
        self.outputs = []
        self.ops = []
        self.primary_op = None
        self.ifm_tensor = None
        self.ifm2_tensor = None
        self.ofm_tensor = None
        self.weight_tensor = None
        self.scale_tensor = None
        self.lut_tensor = None
        self.name = name
        self.cascade = None
        self.placement = placement
        self.ifm_shapes: List[Shape4D] = []
        self.ofm_shapes: List[Shape4D] = []

        # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
        # allocation and requires that the OFM and IFM have the exact same address. Essentially complete overlap.
        self.is_element_wise = is_element_wise
        self.npu_block_type = npu_block_type
        self.block_config = None  # will be filled in by scheduler
        self.shared_buffer = None  # will be filled in by scheduler
        self.scheduling_info = None  # will be filled in by scheduler

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.Pass '%s', %s, ops=%s>" % (self.name, self.placement, [op.type for op in self.ops])

    __repr__ = __str__

    def get_primary_op_ifm_weights(self):
        if not self.primary_op:
            return None, None
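        # get_ifm_ifm2_weights_ofm() returns (ifm, ifm2, weights, ofm); the [::2] slice below
        # keeps elements 0 and 2, i.e. the primary op's IFM and weight tensors.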
        return self.primary_op.get_ifm_ifm2_weights_ofm()[::2]

    def get_primary_op_ifm_ifm2_weights_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()

    def get_primary_op_ifm_weights_biases_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_weights_biases_ofm()

    def get_primary_op_lut(self):
        if not self.primary_op:
            return None
        return self.primary_op.activation_lut


class SchedulingStrategy(enum.Enum):
    Unknown = -1
    IfmStream = 0
    WeightStream = 1


class SchedulerRewrite(enum.Enum):
    Nop = 0
    ChangeTensorSubPurpose = 1


class CascadedPass:
    def __init__(self, name, strat, inputs, intermediates, outputs, passes, placement, is_element_wise):
        self.name = name
        self.strategy = strat
        self.inputs = inputs
        self.intermediates = intermediates
        self.outputs = outputs
        self.passes = passes
        self.placement = placement
        self.is_element_wise = is_element_wise

        self.predecessors = []
        self.successors = []
        self.sram_used = 0
        self.time = 0

    def __str__(self):
        return "<nng.CascadedPass strategy=%s x %s '%s', passes=%s, block_configs=%s>" % (
            self.strategy,
            len(self.passes),
            self.name,
            [ps.name for ps in self.passes],
            [ps.block_config for ps in self.passes],
        )

    __repr__ = __str__


class Subgraph:
    def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
        self.output_tensors = []
        self.input_tensors = []
        # Preserve the original input order
        self.original_inputs = []
        # Attach virtual outputs to resource variable ops
        # in order to be able to traverse the graph correctly
        self.virtual_outputs = []
        self.passes = []
        self.cascaded_passes = []
        self.name = name
        self.high_level_command_stream = []
        self.placement = placement
        self.command_stream_tensor = None
        self.flash_tensor = None
        # Scratch information locally used in the scheduler
        self.schedule = None
        self.sched_ops = []

        self.generated_stream_id = None

        self.memory_used = {}
        self.memory_used_per_type = {}

    def __str__(self):
        return "<nng.Subgraph '%s', n_passes=%d, n_cascaded_passes=%d>" % (
            self.name,
            len(self.passes),
            len(self.cascaded_passes),
        )

    __repr__ = __str__

    def update_consumers(self):
        visit_op_set = set()
        visit_tensor_set = set()
        sg_passes_op_set = set()
        self.input_tensors = []

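        # Collect every op that belongs to a pass in this subgraph; visit_op() below uses this
        # set to skip ops that are reached through shared tensors but live in another subgraph.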
        for ps in self.passes:
            for op in ps.ops:
                sg_passes_op_set.add(op)

        print_visit = False

        def visit_op(op):
            if op in visit_op_set or (sg_passes_op_set and op not in sg_passes_op_set):
                # Op already visited, or op is not part of a pass in this subgraph.
                # A typical case where an op is not part of this subgraph but is reached anyway is a
                # concat op that has been split up across different subgraphs (as several avgpool ops).
                # Since those avgpool ops share the same output, the ones that do not belong to this
                # subgraph would otherwise be traversed, which should be avoided.
                return

            visit_op_set.add(op)
            for inp in op.inputs:
                if not inp:
                    continue
                if print_visit:
                    print(inp, "adding consumer", op)
                visit_tensor(inp)
                inp.consumer_list.append(op)

            if op.type in (Op.Placeholder, Op.SubgraphInput):
                assert len(op.outputs) == 1
                if not op.outputs[0].is_variable:
                    self.input_tensors.append(op.outputs[0])

            for out in op.outputs:
                if out not in visit_tensor_set:
                    out.consumer_list = []  # reset unvisited output, just in case

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            tens.consumer_list = []
            for op in tens.ops:
                visit_op(op)

        for ps in self.passes:
            for tens in ps.outputs + ps.inputs:
                if not tens:
                    continue
                tens.consumer_list = []  # reset unvisited tensors to start with

        for tens in self.output_tensors:
            visit_tensor(tens)
            tens.consumer_list.append(None)  # special op to indicate that the graph consumes the result

        print_visit = True
        for ps in self.passes:
            for op in ps.ops:
                visit_op(op)
            for tens in ps.inputs:
                visit_tensor(tens)

    def build_pass_links(self):
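        # Give each pass a monotonically increasing "time" in list order so that the assert below
        # can check that every (non-self) predecessor was created earlier in the list.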
        for idx, ps in enumerate(self.passes):
            ps.time = 2 * idx
            ps.predecessors = []
            ps.successors = []

        for ps in self.passes:
            for tens in ps.inputs:
                for op in tens.ops:
                    pred_pass = op.scheduled_pass
                    # A pass with split concat ops may end up with a dependency on itself, since the
                    # output of the concat is produced by several avg pool ops.
                    # Hence pred_pass can be equal to ps.
                    assert pred_pass == ps or pred_pass.time < ps.time
                    if ps not in pred_pass.successors:
                        pred_pass.successors.append(ps)

                    if pred_pass not in ps.predecessors:
                        ps.predecessors.append(pred_pass)

                    assert tens in pred_pass.outputs

    def build_pass_dag_predecessors(self):
        for ps in self.passes:
            ps.dag_predecessors = []

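        # Three-state depth-first search: a predecessor edge that reaches a pass which is still
        # being visited closes a cycle, so that edge is left out of dag_predecessors, keeping the
        # resulting predecessor relation acyclic.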
        class State(enum.Enum):
            NotVisited = 0
            BeingVisited = 1
            Visited = 2

        pass_visit_dict = {}

        def visit_pass(ps):
            state = pass_visit_dict.get(ps, State.NotVisited)
            if state == State.Visited:
                return True
            elif state == State.BeingVisited:
                return False  # this is a loop, need to remove this link
            elif state == State.NotVisited:
                pass_visit_dict[ps] = State.BeingVisited

                ps.dag_predecessors = []
                for pred in ps.predecessors:
                    if visit_pass(pred):
                        ps.dag_predecessors.append(pred)

                pass_visit_dict[ps] = State.Visited
                return True

        for ps in self.passes:
            if not ps.successors:
                visit_pass(ps)

    def build_cascaded_pass_links(self):
        for cps in self.cascaded_passes:
            cps.predecessors = []
            cps.successors = []

        for cps in self.cascaded_passes:
            for tens in cps.inputs:
                for op in tens.ops:
                    pred_cpass = op.scheduled_pass.cascade
                    if cps not in pred_cpass.successors:
                        pred_cpass.successors.append(cps)

                    if pred_cpass not in cps.predecessors:
                        cps.predecessors.append(pred_cpass)

                    assert tens in pred_cpass.outputs

    def refresh_after_modification(self):
        try:
            self.update_consumers()
        except RecursionError as e:
            raise RecursionError(
                "Compilation failed due to exceeding the default maximum recursion depth.\n"
                'Try increasing the maximum recursion depth with the "--recursion-limit" option.'
            ) from e

    def prune_startup_init_pass(self):
        assert len(self.passes) >= 1
        ps = self.passes[0]
        assert ps.placement == PassPlacement.StartupInit

        ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
        ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]

    # get_all_ops is used when traversing the original graph
    def get_all_ops(self):
        all_ops = []
        visit_op_set = set()
        visit_tensor_set = set()

        def visit_op(op):
            if op in visit_op_set:
                return
            visit_op_set.add(op)
            for inp in op.inputs:
                visit_tensor(inp)

            all_ops.append(op)

        def visit_tensor(tens):
            if tens is None or tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            for op in tens.ops:
                visit_op(op)

        for tens in self.output_tensors:
            visit_tensor(tens)

        return all_ops

    # get_all_ops_from_passes is used by the stats writer to calculate the number of
    # CPU and NPU ops.
    # Due to a side effect, get_all_ops might not traverse the full graph after
    # extract_npu_subgraph has been called, and should therefore not be used by the stats writer:
    # the main graph might have NPU nodes with no visible outputs, so those nodes would be missed.
    def get_all_ops_from_passes(self):
        all_ops = []
        for idx, ps in enumerate(self.passes):
            for op in ps.ops:
                all_ops.append(op)

        return all_ops

    def print_operators(self, ignore_placeholder_const=True, show_attributes=True):
        print(f"Operators of Subgraph {self.name}")

        ignore_ops = (Op.Const, Op.Identity, Op.Placeholder) if ignore_placeholder_const else ()
        all_ops = [op for op in self.get_all_ops() if op.type not in ignore_ops]

        if len(all_ops) > 0:
            max_op_type_len = max([len(op.type.name) for op in all_ops])

            for idx, op in enumerate(all_ops):
                attrs_str = f" - {op.attrs}" if show_attributes else ""
                print(f"{idx:3}: {op.type:{max_op_type_len}}{attrs_str} - {op.name}")

        else:
            print("No Operators")

    def print_graph(self, label=None):
        if label:
            print(f"\n[ {label} ]")
        print("print_graph()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)

    def print_graph_with_tensors(self):
        print("print_graph_with_tensors()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                if tens:
                    print(
                        f" Input {idx:02d}"
                        f" {tens.purpose.name:>20} {tens.mem_area.name:>20} {tens.mem_type.name:>20} {tens}"
                    )
                else:
                    print(f" Input {idx:02d} {'-':>20} {'-':>20} {'-':>20} {tens}")
            for idx, tens in enumerate(op.outputs):
                print(
                    f" Output {idx:02d}"
                    f" {tens.purpose.name:>20} {tens.mem_area.name:>20} {tens.mem_type.name:>20} {tens}"
                )
            print()

    def print_graph_with_tensor_quantization(self):
        print("print_graph_with_tensor_quantization()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                if tens:
                    q = tens.quantization
                    if q is None:
                        print(f" Input {idx:02d} {tens.dtype!s:>10} NO QUANTIZATION INFO {tens.name}")
                    else:
                        print(
                            f" Input {idx:02d} {tens.dtype!s:>10}"
                            f" min={q.min} max={q.max} scale={q.scale_f32!s} zero_point={q.zero_point} {tens.name}"
                        )
                else:
                    print(f" Input {idx:02d} {'-':>10} {tens}")
            for idx, tens in enumerate(op.outputs):
                q = tens.quantization
                if q is None:
                    print(f" Output {idx:02d} {tens.dtype!s:>10} NO QUANTIZATION INFO {tens.name}")
                else:
                    print(
                        f" Output {idx:02d} {tens.dtype!s:>10}"
                        f" min={q.min} max={q.max} scale={q.scale_f32!s} zero_point={q.zero_point} {tens.name}"
                    )
            print()

    def print_passes(self):
        print("print_passes()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%03d %s" % (idx * 2, ps))

    def print_passes_with_tensors(self):
        print("print_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%3d %s" % (idx * 2, ps))
            for idx, tens in enumerate(ps.inputs):
                print(
                    " Input %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    " Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    " Output %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes(self):
        print("print_cascaded_passes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))

    def print_cascaded_passes_with_tensors(self):
        print("print_cascaded_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    " Input %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    " Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    " Output %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes_with_tensor_sizes(self):
        print("print_cascaded_passes_with_tensor_sizes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    " Input %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    " Intermediate %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    " Output %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            print()

    def print_high_level_command_stream(self):
        print("print_high_level_command_stream()", self.name)
        for idx, cmd in enumerate(self.high_level_command_stream):
            print("%3d %s" % (idx, cmd))


class Graph:
    def __init__(self, name="<unnamed>", batch_size=1):
        self.name = name
        self.batch_size = batch_size
        self.subgraphs = []
        self.metadata = []
        self.memory_used = {}
        self.total_original_weights = 0
        self.total_npu_encoded_weights = 0
        self.weight_cache = None  # See CompressedWeightCache
        self.bandwidths = 0
        self.macs = 0
        self.cycles = 0

    def get_root_subgraph(self):
        return self.subgraphs[0]

    def prune_startup_init_pass(self):
        for sg in self.subgraphs:
            sg.prune_startup_init_pass()

    def update_consumers(self):
        for sg in self.subgraphs:
            sg.update_consumers()

    def refresh_after_modification(self):
        for sg in self.subgraphs:
            sg.refresh_after_modification()

    def print_operators(self, ignore_placeholder_const=True, show_attributes=True):
        for sg in self.subgraphs:
            sg.print_operators(ignore_placeholder_const, show_attributes)

    def print_graph(self, label=None):
        for sg in self.subgraphs:
            sg.print_graph(label)

    def print_graph_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensors()

    def print_graph_with_tensor_quantization(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensor_quantization()

    def print_passes(self):
        for sg in self.subgraphs:
            sg.print_passes()

    def print_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_passes_with_tensors()

    def print_cascaded_passes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes()

    def print_cascaded_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensors()

    def print_cascaded_passes_with_tensor_sizes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensor_sizes()

    def print_high_level_command_stream(self):
        for sg in self.subgraphs:
            sg.print_high_level_command_stream()