# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Neural network graph classes and enums.
# Pass - A packed pass containing one or more Operations.
# CascadedPass - A scheduled pass containing one or more Passes, as well as a scheduling strategy and block
#                configurations.
# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
# Graph - A full neural network graph with one or more Subgraphs.
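#
# A rough navigation sketch of that hierarchy (hypothetical variable names):
#
#     nng = Graph("net")                # nng.subgraphs -> [Subgraph, ...]
#     sg = nng.get_root_subgraph()
#     for cps in sg.cascaded_passes:    # CascadedPass.passes -> [Pass, ...]
#         for ps in cps.passes:
#             for op in ps.ops:         # Pass.ops -> the packed Operations
#                 print(op.type, op.name)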
import enum

from .operation import Op


class PassPlacement(enum.Enum):
    Unknown = 0
    Cpu = 1
    Npu = 2
    MemoryOnly = 3
    StartupInit = 4


class TensorAllocator(enum.Enum):
    LinearAlloc = 1
    Greedy = 2

    def __str__(self):
        return self.name


class Pass:
    def __init__(self, name, placement, is_element_wise, npu_block_type):
        self.inputs = []
        self.intermediates = []
        self.outputs = []
        self.ops = []
        self.primary_op = None
        self.ifm_tensor = None
        self.ifm2_tensor = None
        self.ofm_tensor = None
        self.weight_tensor = None
        self.scale_tensor = None
        self.lut_tensor = None
        self.name = name
        self.cascade = None
        self.placement = placement

        # TODO: rename is_element_wise, because it is not the same as an ElementWise operator. It is used by tensor
        # allocation and requires that the OFM and IFM have exactly the same address, i.e. complete overlap.
        self.is_element_wise = is_element_wise
        self.npu_block_type = npu_block_type
        self.block_config = None  # will be filled in by scheduler
        self.shared_buffer = None  # will be filled in by scheduler

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.Pass '%s', %s, ops=%s>" % (self.name, self.placement, [op.type for op in self.ops])

    __repr__ = __str__

    def get_primary_op_ifm_weights(self):
        if not self.primary_op:
            return None, None
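        # get_ifm_ifm2_weights_ofm() returns the tuple (ifm, ifm2, weights, ofm);
        # striding by two below therefore selects just (ifm, weights).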
        return self.primary_op.get_ifm_ifm2_weights_ofm()[::2]

    def get_primary_op_ifm_ifm2_weights_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()

    def get_primary_op_ifm_weights_biases_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_weights_biases_ofm()

    def get_primary_op_lut(self):
        if not self.primary_op:
            return None
        return self.primary_op.activation_lut


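# Scheduling strategies for a CascadedPass: IfmStream cascades passes by
# streaming stripes of the input feature map, while WeightStream streams the
# weights instead, splitting the output feature map in depth. (Rough summary;
# the scheduler holds the authoritative definitions.)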
class SchedulingStrategy(enum.Enum):
    Unknown = -1
    IfmStream = 0
    WeightStream = 1


class SchedulerRewrite(enum.Enum):
    Nop = 0
    ChangeTensorSubPurpose = 1


class CascadedPass:
    def __init__(self, name, strat, inputs, intermediates, outputs, passes, placement, is_element_wise):
        self.name = name
        self.strategy = strat
        self.inputs = inputs
        self.intermediates = intermediates
        self.outputs = outputs
        self.passes = passes
        self.placement = placement
        self.is_element_wise = is_element_wise

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.CascadedPass strategy=%s x %s '%s', passes=%s, block_configs=%s>" % (
            self.strategy,
            len(self.passes),
            self.name,
            [ps.name for ps in self.passes],
            [ps.block_config for ps in self.passes],
        )

    __repr__ = __str__


class Subgraph:
    def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
        self.output_tensors = []
        self.input_tensors = []
        self.original_inputs = []  # Preserve the original input order
        self.passes = []
        self.cascaded_passes = []
        self.name = name
        self.high_level_command_stream = []
        self.placement = placement
        self.command_stream_tensor = None
        self.flash_tensor = None
        # Scratch information locally used in the scheduler
        self.scheduling_info = {}

        self.memory_used = {}
        self.memory_used_per_type = {}

    def __str__(self):
        return "<nng.Subgraph '%s', n_passes=%d, n_cascaded_passes=%d>" % (
            self.name,
            len(self.passes),
            len(self.cascaded_passes),
        )

    __repr__ = __str__

    def update_consumers(self):
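        # Rebuilds every tensor's consumer_list by walking backwards from the
        # subgraph's output tensors, and repopulates self.input_tensors from any
        # Placeholder/SubgraphInput ops encountered. Output tensors get a None
        # entry appended to mark that the graph itself consumes them.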
        visit_op_set = set()
        visit_tensor_set = set()
        self.input_tensors = []

        print_visit = False

        def visit_op(op):
            if op in visit_op_set:
                return

            visit_op_set.add(op)
            for inp in op.inputs:
                if not inp:
                    continue
                if print_visit:
                    print(inp, "adding consumer", op)
                visit_tensor(inp)
                inp.consumer_list.append(op)

            if op.type in set((Op.Placeholder, Op.SubgraphInput)):
                assert len(op.outputs) == 1
                self.input_tensors.append(op.outputs[0])

            for out in op.outputs:
                if out not in visit_tensor_set:
                    out.consumer_list = []  # reset unvisited output, just in case

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            tens.consumer_list = []
            for op in tens.ops:
                visit_op(op)

        for ps in self.passes:
            for tens in ps.outputs + ps.inputs:
                if not tens:
                    continue
                tens.consumer_list = []  # reset unvisited tensors to start with

        for tens in self.output_tensors:
            visit_tensor(tens)
            tens.consumer_list.append(None)  # special op to indicate that the graph consumes the result

        print_visit = True
        for ps in self.passes:
            for op in ps.ops:
                visit_op(op)
            for tens in ps.inputs:
                visit_tensor(tens)

    def build_pass_links(self):
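        # Stamp each pass with an even "time" in list order, then derive
        # predecessor/successor links from the ops that produce each pass's
        # input tensors. (The stride of 2 presumably leaves odd slots free for
        # items ordered between passes.)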
        for idx, ps in enumerate(self.passes):
            ps.time = 2 * idx
            ps.predecessors = []
            ps.successors = []

        for ps in self.passes:
            for tens in ps.inputs:
                for op in tens.ops:
                    pred_pass = op.scheduled_pass
                    assert pred_pass.time < ps.time
                    if ps not in pred_pass.successors:
                        pred_pass.successors.append(ps)

                    if pred_pass not in ps.predecessors:
                        ps.predecessors.append(pred_pass)

                    assert tens in pred_pass.outputs

    def build_pass_dag_predecessors(self):
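        # Build dag_predecessors: the predecessor lists with loop-closing edges
        # removed. This is a depth-first search from the terminal passes using
        # three-state marking (NotVisited/BeingVisited/Visited); meeting a pass
        # that is still BeingVisited means the edge closes a cycle, so it is
        # dropped rather than followed.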
        for ps in self.passes:
            ps.dag_predecessors = []

        class State(enum.Enum):
            NotVisited = 0
            BeingVisited = 1
            Visited = 2

        pass_visit_dict = {}

        def visit_pass(ps):
            state = pass_visit_dict.get(ps, State.NotVisited)
            if state == State.Visited:
                return True
            elif state == State.BeingVisited:
                return False  # this is a loop, need to remove this link
            elif state == State.NotVisited:
                pass_visit_dict[ps] = State.BeingVisited

            ps.dag_predecessors = []
            for pred in ps.predecessors:
                if visit_pass(pred):
                    ps.dag_predecessors.append(pred)

            pass_visit_dict[ps] = State.Visited
            return True

        for ps in self.passes:
            if not ps.successors:
                visit_pass(ps)

    def build_cascaded_pass_links(self):
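        # Same linking as build_pass_links(), but at CascadedPass granularity:
        # the producing op's scheduled pass has a .cascade attribute that gives
        # the predecessor CascadedPass.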
        for cps in self.cascaded_passes:
            cps.predecessors = []
            cps.successors = []

        for cps in self.cascaded_passes:
            for tens in cps.inputs:
                for op in tens.ops:
                    pred_cpass = op.scheduled_pass.cascade
                    if cps not in pred_cpass.successors:
                        pred_cpass.successors.append(cps)

                    if pred_cpass not in cps.predecessors:
                        cps.predecessors.append(pred_cpass)

                    assert tens in pred_cpass.outputs

    def refresh_after_modification(self):
        self.update_consumers()

    def prune_startup_init_pass(self):
        assert len(self.passes) >= 1
        ps = self.passes[0]
        assert ps.placement == PassPlacement.StartupInit

        ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
        ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]

    def get_all_ops(self):
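        # Depth-first walk from the output tensors; returns every reachable op,
        # appended post-order (an op is added after its inputs' producers).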
        all_ops = []
        visit_op_set = set()
        visit_tensor_set = set()

        def visit_op(op):
            if op in visit_op_set:
                return
            visit_op_set.add(op)
            for inp in op.inputs:
                visit_tensor(inp)

            all_ops.append(op)

        def visit_tensor(tens):
            if tens is None or tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            for op in tens.ops:
                visit_op(op)

        for tens in self.output_tensors:
            visit_tensor(tens)

        return all_ops

    def print_operators(self):
        print("print_operators()", self.name)
        all_ops = self.get_all_ops()
        unique_ops = []
        for op in all_ops:
            if op.type in set((Op.Const, Op.Identity, Op.Placeholder)):
                continue

            attrs = op.attrs.copy()
            if op.type in (Op.Conv2D, Op.Conv2DBias, Op.DepthwiseConv2DBias):
                kshape = op.inputs[1].shape
                attrs["kshape"] = [kshape[0], kshape[1]]
            attrs["type"] = op.type.name
            attrs.pop("use_cudnn_on_gpu", None)
            if attrs not in unique_ops:
                unique_ops.append(attrs)
                # print attributes in human readable format
                a = attrs.copy()
                s = a.pop("type")
                data_format = a.pop("data_format", None)
                if data_format and data_format != b"NHWC":
                    s += " " + str(data_format)
                t = a.pop("T", None)
                if t:
                    s += " " + str(t)[9:-2]
                srct = a.pop("SrcT", None)
                if srct:
                    s += " " + str(srct)[9:-2]
                dstt = a.pop("DstT", None)
                if dstt:
                    s += "->" + str(dstt)[9:-2]
                print(s + " " + str(a))

    def print_graph(self):
        print("print_graph()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)

    def print_graph_with_tensors(self):
        print("print_graph_with_tensors()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                print(
                    "    Input  %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            for idx, tens in enumerate(op.outputs):
                print(
                    "    Output %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            print()

    def print_graph_with_tensor_quantization(self):
        print("print_graph_with_tensor_quantization()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                q = tens.quantization
                if q is None:
                    print("    Input  %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "    Input  %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            for idx, tens in enumerate(op.outputs):
                q = tens.quantization
                if q is None:
                    print("    Output %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "    Output %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            print()

    def print_passes(self):
        print("print_passes()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%03d %s" % (idx * 2, ps))

    def print_passes_with_tensors(self):
        print("print_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%3d %s" % (idx * 2, ps))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes(self):
        print("print_cascaded_passes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))

    def print_cascaded_passes_with_tensors(self):
        print("print_cascaded_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes_with_tensor_sizes(self):
        print("print_cascaded_passes_with_tensor_sizes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            print()

    def print_high_level_command_stream(self):
        print("print_high_level_command_stream()", self.name)
        for idx, cmd in enumerate(self.high_level_command_stream):
            print("%3d %s" % (idx, cmd))


class Graph:
    def __init__(self, name="<unnamed>", batch_size=1):
        self.name = name
        self.batch_size = batch_size
        self.subgraphs = []
        self.metadata = []
        self.memory_used = {}
        self.bits_per_element = {}
        self.total_size = {}
        self.total_elements = {}
        self.weight_cache = None  # See CompressedWeightCache

    def get_root_subgraph(self):
        return self.subgraphs[0]

    def prune_startup_init_pass(self):
        for sg in self.subgraphs:
            sg.prune_startup_init_pass()

    def update_consumers(self):
        for sg in self.subgraphs:
            sg.update_consumers()

    def refresh_after_modification(self):
        for sg in self.subgraphs:
            sg.refresh_after_modification()

    def print_operators(self):
        for sg in self.subgraphs:
            sg.print_operators()

    def print_graph(self):
        for sg in self.subgraphs:
            sg.print_graph()

    def print_graph_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensors()

    def print_graph_with_tensor_quantization(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensor_quantization()

    def print_passes(self):
        for sg in self.subgraphs:
            sg.print_passes()

    def print_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_passes_with_tensors()

    def print_cascaded_passes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes()

    def print_cascaded_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensors()

    def print_cascaded_passes_with_tensor_sizes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensor_sizes()

    def print_high_level_command_stream(self):
        for sg in self.subgraphs:
            sg.print_high_level_command_stream()