# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Neural network graph classes and enums.
# Pass - A packed pass containing one or more Operations.
# CascadedPass - A scheduled pass containing one or more Passes, as well as a scheduling strategy and block
#                configurations.
# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
# Graph - A full neural network graph with one or more Subgraphs.
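#
# A rough sketch of how the classes nest once the graph has been scheduled (illustrative only;
# the variable names below are hypothetical and not part of this module):
#
#     nng = Graph("my_network")
#     for sg in nng.subgraphs:               # Subgraph
#         for cps in sg.cascaded_passes:     # CascadedPass, with a SchedulingStrategy
#             for ps in cps.passes:          # Pass
#                 print(ps.name, [op.type for op in ps.ops])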
import enum


class PassPlacement(enum.Enum):
    Unknown = 0
    Cpu = 1
    Npu = 2
    MemoryOnly = 3
    StartupInit = 4


class TensorAllocator(enum.Enum):
    LinearAlloc = 1
    Greedy = 2

    def __str__(self):
        return self.name


class Pass:
    def __init__(self, name, placement, is_element_wise, npu_block_type):
        self.inputs = []
        self.intermediates = []
        self.outputs = []
        self.ops = []
        self.primary_op = None
        self.ifm_tensor = None
        self.ifm2_tensor = None
        self.ofm_tensor = None
        self.weight_tensor = None
        self.scale_tensor = None
        self.lut_tensor = None
        self.name = name
        self.cascade = None
        self.placement = placement

        # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by tensor
        # allocation and requires that the OFM and IFM have the exact same address. Essentially complete overlap.
        self.is_element_wise = is_element_wise
        self.npu_block_type = npu_block_type
        self.block_config = None  # will be filled in by scheduler
        self.shared_buffer = None  # will be filled in by scheduler

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.Pass '%s', %s, ops=%s>" % (self.name, self.placement, [op.type for op in self.ops])

    __repr__ = __str__

    def get_primary_op_ifm_weights(self):
        if not self.primary_op:
            return None, None
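        # get_ifm_ifm2_weights_ofm() returns (ifm, ifm2, weights, ofm); the [::2] slice below keeps ifm and weights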
        return self.primary_op.get_ifm_ifm2_weights_ofm()[::2]

    def get_primary_op_ifm_ifm2_weights_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()

    def get_primary_op_ifm_weights_biases_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_weights_biases_ofm()

    def get_primary_op_lut(self):
        if not self.primary_op:
            return None
        return self.primary_op.activation_lut


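# Strategy used when scheduling a cascaded pass: IfmStream streams the input feature map through
# the cascade, WeightStream streams the weights. The choice is made by the scheduler.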
class SchedulingStrategy(enum.Enum):
    Unknown = -1
    IfmStream = 0
    WeightStream = 1


class SchedulerRewrite(enum.Enum):
    Nop = 0
    ChangeTensorSubPurpose = 1


class CascadedPass:
    def __init__(self, name, strat, inputs, intermediates, outputs, passes, placement, is_element_wise):
        self.name = name
        self.strategy = strat
        self.inputs = inputs
        self.intermediates = intermediates
        self.outputs = outputs
        self.passes = passes
        self.placement = placement
        self.is_element_wise = is_element_wise

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.CascadedPass strategy=%s x %s '%s', passes=%s, block_configs=%s>" % (
            self.strategy,
            len(self.passes),
            self.name,
            [ps.name for ps in self.passes],
            [ps.block_config for ps in self.passes],
        )

    __repr__ = __str__


class Subgraph:
    def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
        self.output_tensors = []
        self.input_tensors = []
        self.original_inputs = []  # Preserve the original input order
        self.passes = []
        self.cascaded_passes = []
        self.name = name
        self.high_level_command_stream = []
        self.placement = placement
        self.command_stream_tensor = None
        self.flash_tensor = None
        # Scratch information locally used in the scheduler
        self.scheduling_info = {}

        self.memory_used = {}
        self.memory_used_per_type = {}

    def __str__(self):
        return "<nng.Subgraph '%s', n_passes=%d, n_cascaded_passes=%d>" % (
            self.name,
            len(self.passes),
            len(self.cascaded_passes),
        )

    __repr__ = __str__

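    # Rebuild consumer_list on every reachable tensor by walking producer ops back from the subgraph
    # outputs, and re-collect input_tensors from Placeholder/SubgraphInput ops along the way.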
    def update_consumers(self):
        visit_op_set = set()
        visit_tensor_set = set()
        self.input_tensors = []

        print_visit = False

        def visit_op(op):
            if op in visit_op_set:
                return

            visit_op_set.add(op)
            for inp in op.inputs:
                if not inp:
                    continue
                if print_visit:
                    print(inp, "adding consumer", op)
                visit_tensor(inp)
                inp.consumer_list.append(op)

            if op.type in set(("Placeholder", "SubgraphInput")):
                assert len(op.outputs) == 1
                self.input_tensors.append(op.outputs[0])

            for out in op.outputs:
                if out not in visit_tensor_set:
                    out.consumer_list = []  # reset unvisited output, just in case

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            tens.consumer_list = []
            for op in tens.ops:
                visit_op(op)

        for ps in self.passes:
            for tens in ps.outputs + ps.inputs:
                if not tens:
                    continue
                tens.consumer_list = []  # reset unvisited tensors to start with

        for tens in self.output_tensors:
            visit_tensor(tens)
            tens.consumer_list.append(None)  # special op to indicate that the graph consumes the result

        print_visit = True
        for ps in self.passes:
            for op in ps.ops:
                visit_op(op)
            for tens in ps.inputs:
                visit_tensor(tens)

    def build_pass_links(self):
        for idx, ps in enumerate(self.passes):
            ps.time = 2 * idx
            ps.predecessors = []
            ps.successors = []

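        # Link each pass to the pass that produced each of its input tensors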
        for ps in self.passes:
            for tens in ps.inputs:
                for op in tens.ops:
                    pred_pass = op.scheduled_pass
                    assert pred_pass.time < ps.time
                    if ps not in pred_pass.successors:
                        pred_pass.successors.append(ps)

                    if pred_pass not in ps.predecessors:
                        ps.predecessors.append(pred_pass)

                    assert tens in pred_pass.outputs

    def build_pass_dag_predecessors(self):
        for ps in self.passes:
            ps.dag_predecessors = []

        class State(enum.Enum):
            NotVisited = 0
            BeingVisited = 1
            Visited = 2

        pass_visit_dict = {}

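        # Depth-first walk using three-state marking; reaching a pass that is still BeingVisited
        # closes a cycle, so that edge is left out of dag_predecessors.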
        def visit_pass(ps):
            state = pass_visit_dict.get(ps, State.NotVisited)
            if state == State.Visited:
                return True
            elif state == State.BeingVisited:
                return False  # this is a loop, need to remove this link
            elif state == State.NotVisited:
                pass_visit_dict[ps] = State.BeingVisited

            ps.dag_predecessors = []
            for pred in ps.predecessors:
                if visit_pass(pred):
                    ps.dag_predecessors.append(pred)

            pass_visit_dict[ps] = State.Visited
            return True

        for ps in self.passes:
            if not ps.successors:
                visit_pass(ps)

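    # Same linking as build_pass_links, but at CascadedPass granularity via each op's scheduled_pass.cascade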
    def build_cascaded_pass_links(self):
        for cps in self.cascaded_passes:
            cps.predecessors = []
            cps.successors = []

        for cps in self.cascaded_passes:
            for tens in cps.inputs:
                for op in tens.ops:
                    pred_cpass = op.scheduled_pass.cascade
                    if cps not in pred_cpass.successors:
                        pred_cpass.successors.append(cps)

                    if pred_cpass not in cps.predecessors:
                        cps.predecessors.append(pred_cpass)

                    assert tens in pred_cpass.outputs

    def refresh_after_modification(self):
        self.update_consumers()

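    # Drop outputs of the StartupInit pass that no longer have any consumers, together with the ops that produce them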
    def prune_startup_init_pass(self):
        assert len(self.passes) >= 1
        ps = self.passes[0]
        assert ps.placement == PassPlacement.StartupInit

        ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
        ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]

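    # Depth-first walk from the output tensors; an op is appended only after its inputs have been
    # visited, so producers appear before their consumers in the returned list.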
    def get_all_ops(self):
        all_ops = []
        visit_op_set = set()
        visit_tensor_set = set()

        def visit_op(op):
            if op in visit_op_set:
                return
            visit_op_set.add(op)
            for inp in op.inputs:
                visit_tensor(inp)

            all_ops.append(op)

        def visit_tensor(tens):
            if tens is None or tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            for op in tens.ops:
                visit_op(op)

        for tens in self.output_tensors:
            visit_tensor(tens)

        return all_ops

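    # Print one line per unique operator attribute set, skipping Const/Identity/Placeholder ops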
    def print_operators(self):
        print("print_operators()", self.name)
        all_ops = self.get_all_ops()
        unique_ops = []
        for op in all_ops:
            if op.type in set(("Const", "Identity", "Placeholder")):
                continue

            attrs = op.attrs
            if (
                op.type == "Conv2D"
                or op.type == "DepthwiseConv2dNative"
                or op.type == "Conv2DBiasAct"
                or op.type == "DepthwiseConv2dBiasAct"
            ):
                kshape = op.inputs[1].shape
                attrs["kshape"] = [kshape[0], kshape[1]]
            attrs["type"] = op.type
            attrs.pop("use_cudnn_on_gpu", None)
            if attrs not in unique_ops:
                unique_ops.append(attrs)
                # print attributes in human readable format
                a = attrs.copy()
                s = a.pop("type")
                data_format = a.pop("data_format", None)
                if data_format and data_format != b"NHWC":
                    s += " " + str(data_format)
                t = a.pop("T", None)
                if t:
                    s += " " + str(t)[9:-2]
                srct = a.pop("SrcT", None)
                if srct:
                    s += " " + str(srct)[9:-2]
                dstt = a.pop("DstT", None)
                if dstt:
                    s += "->" + str(dstt)[9:-2]
                print(s + " " + str(a))

    def print_graph(self):
        print("print_graph()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)

    def print_graph_with_tensors(self):
        print("print_graph_with_tensors()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                print(
                    "   Input  %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            for idx, tens in enumerate(op.outputs):
                print(
                    "   Output %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            print()

    def print_graph_with_tensor_quantization(self):
        print("print_graph_with_tensor_quantization()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                q = tens.quantization
                if q is None:
                    print("   Input  %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "   Input  %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            for idx, tens in enumerate(op.outputs):
                q = tens.quantization
                if q is None:
                    print("   Output %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "   Output %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            print()

    def print_passes(self):
        print("print_passes()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%03d %s" % (idx * 2, ps))

    def print_passes_with_tensors(self):
        print("print_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%3d %s" % (idx * 2, ps))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %-15s %-15s %-15s     %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %-15s %-15s %-15s     %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %-15s %-15s %-15s     %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

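    # Note: sram_used on each CascadedPass is expected to have been filled in by the scheduler
    # (similar to block_config on Pass) before these print helpers are called.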
    def print_cascaded_passes(self):
        print("print_cascaded_passes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))

    def print_cascaded_passes_with_tensors(self):
        print("print_cascaded_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %-15s %-15s %-15s     %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %-15s %-15s %-15s     %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %-15s %-15s %-15s     %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes_with_tensor_sizes(self):
        print("print_cascaded_passes_with_tensor_sizes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            print()

    def print_high_level_command_stream(self):
        print("print_high_level_command_stream()", self.name)
        for idx, cmd in enumerate(self.high_level_command_stream):
            print("%3d %s" % (idx, cmd))


class Graph:
    def __init__(self, name="<unnamed>", batch_size=1):
        self.name = name
        self.batch_size = batch_size
        self.subgraphs = []
        self.metadata = []
        self.memory_used = {}
        self.bits_per_element = {}
        self.total_size = {}
        self.total_elements = {}
        self.weight_cache = None  # See CompressedWeightCache

    def get_root_subgraph(self):
        return self.subgraphs[0]

    def prune_startup_init_pass(self):
        for sg in self.subgraphs:
            sg.prune_startup_init_pass()

    def update_consumers(self):
        for sg in self.subgraphs:
            sg.update_consumers()

    def refresh_after_modification(self):
        for sg in self.subgraphs:
            sg.refresh_after_modification()

    def print_operators(self):
        for sg in self.subgraphs:
            sg.print_operators()

    def print_graph(self):
        for sg in self.subgraphs:
            sg.print_graph()

    def print_graph_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensors()

    def print_graph_with_tensor_quantization(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensor_quantization()

    def print_passes(self):
        for sg in self.subgraphs:
            sg.print_passes()

    def print_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_passes_with_tensors()

    def print_cascaded_passes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes()

    def print_cascaded_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensors()

    def print_cascaded_passes_with_tensor_sizes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensor_sizes()

    def print_high_level_command_stream(self):
        for sg in self.subgraphs:
            sg.print_high_level_command_stream()