# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Neural network graph classes and enums.
# Pass - A packed pass containing one or more Operations.
# CascadedPass - A scheduled pass containing one or more Passes, as well as a scheduling strategy and block
#                configurations.
# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
# Graph - A full neural network graph with one or more Subgraphs.
import enum


class PassPlacement(enum.Enum):
    Unknown = 0
    Cpu = 1
    Npu = 2
    MemoryOnly = 3
    StartupInit = 4


class TensorAllocator(enum.Enum):
    LinearAlloc = 1
    Greedy = 2

    def __str__(self):
        return self.name


class Pass:
    def __init__(self, name, placement, is_element_wise, npu_block_type):
        self.inputs = []
        self.intermediates = []
        self.outputs = []
        self.ops = []
        self.primary_op = None
        self.ifm_tensor = None
        self.ifm2_tensor = None
        self.ofm_tensor = None
        self.weight_tensor = None
        self.scale_tensor = None
        self.name = name
        self.cascade = None
        self.placement = placement

        # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
        # allocation and requires that the OFM and IFM have the exact same address. Essentially complete overlap.
        self.is_element_wise = is_element_wise
        self.npu_block_type = npu_block_type
        self.block_config = None  # will be filled in by scheduler
        self.shared_buffer = None  # will be filled in by scheduler

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.Pass '%s', %s, ops=%s>" % (self.name, self.placement, [op.type for op in self.ops])

    __repr__ = __str__

    def get_primary_op_ifm_weights(self):
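        """Return the primary op's (ifm, weights) tensors, or (None, None) if the pass has no primary op.

        The [::2] slice below picks every second element of (ifm, ifm2, weights, ofm), i.e. ifm and weights.
        """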
        if not self.primary_op:
            return None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()[::2]

    def get_primary_op_ifm_ifm2_weights_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()

    def get_primary_op_ifm_weights_biases_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_weights_biases_ofm()


class SchedulingStrategy(enum.Enum):
    Unknown = -1
    IfmStream = 0
    WeightStream = 1


class SchedulerRewrite(enum.Enum):
    Nop = 0
    ChangeTensorSubPurpose = 1


class CascadedPass:
    def __init__(self, name, strat, inputs, intermediates, outputs, passes, placement, is_element_wise):
        self.name = name
        self.strategy = strat
        self.inputs = inputs
        self.intermediates = intermediates
        self.outputs = outputs
        self.passes = passes
        self.placement = placement
        self.is_element_wise = is_element_wise

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.CascadedPass strategy=%s x %s '%s', passes=%s, block_configs=%s>" % (
            self.strategy,
            len(self.passes),
            self.name,
            [ps.name for ps in self.passes],
            [ps.block_config for ps in self.passes],
        )

    __repr__ = __str__


class Subgraph:
    def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
        self.output_tensors = []
        self.input_tensors = []
        self.original_inputs = []  # Preserve the original input order
        self.passes = []
        self.cascaded_passes = []
        self.name = name
        self.high_level_command_stream = []
        self.placement = placement
        self.command_stream_tensor = None
        self.flash_tensor = None

        self.memory_used = {}
        self.memory_used_per_type = {}

    def __str__(self):
        return "<nng.Subgraph '%s', n_passes=%d, n_cascaded_passes=%d>" % (
            self.name,
            len(self.passes),
            len(self.cascaded_passes),
        )

    __repr__ = __str__

    def update_consumers(self):
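        """Rebuild consumer_list for all reachable tensors and recompute self.input_tensors.

        Traverses backwards from the output tensors, then sweeps the scheduled passes to pick up
        anything not reachable from an output. Outputs of Placeholder/SubgraphInput ops become the
        subgraph inputs, and a None entry marks tensors consumed by the graph itself.
        """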
        visit_op_set = set()
        visit_tensor_set = set()
        self.input_tensors = []

        print_visit = False

        def visit_op(op):
            if op in visit_op_set:
                return

            visit_op_set.add(op)
            for inp in op.inputs:
                if print_visit:
                    print(inp, "adding consumer", op)
                visit_tensor(inp)
                inp.consumer_list.append(op)

            if op.type in set(("Placeholder", "SubgraphInput")):
                assert len(op.outputs) == 1
                self.input_tensors.append(op.outputs[0])

            for out in op.outputs:
                if out not in visit_tensor_set:
                    out.consumer_list = []  # reset unvisited output, just in case

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            tens.consumer_list = []
            for op in tens.ops:
                visit_op(op)

        for ps in self.passes:
            for tens in ps.outputs + ps.inputs:
                tens.consumer_list = []  # reset unvisited tensors to start with

        for tens in self.output_tensors:
            visit_tensor(tens)
            tens.consumer_list.append(None)  # special op to indicate that the graph consumes the result

        print_visit = True
        for ps in self.passes:
            for op in ps.ops:
                visit_op(op)
            for tens in ps.inputs:
                visit_tensor(tens)

    def build_pass_links(self):
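        """Assign an even time index to each pass and link each pass to its predecessors and successors.

        The pass producing one of this pass's input tensors must have been scheduled earlier
        (asserted via its time index), and the link is recorded in both directions.
        """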
        for idx, ps in enumerate(self.passes):
            ps.time = 2 * idx
            ps.predecessors = []
            ps.successors = []

        for ps in self.passes:
            for tens in ps.inputs:
                for op in tens.ops:
                    pred_pass = op.scheduled_pass
                    assert pred_pass.time < ps.time
                    if ps not in pred_pass.successors:
                        pred_pass.successors.append(ps)

                    if pred_pass not in ps.predecessors:
                        ps.predecessors.append(pred_pass)

                    assert tens in pred_pass.outputs

    def build_pass_dag_predecessors(self):
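        """Populate ps.dag_predecessors with the subset of predecessor links that form a DAG.

        Depth-first traversal starting from the passes without successors; a predecessor that is
        still being visited closes a cycle, so that link is dropped from the DAG view.
        """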
        for ps in self.passes:
            ps.dag_predecessors = []

        class State(enum.Enum):
            NotVisited = 0
            BeingVisited = 1
            Visited = 2

        pass_visit_dict = {}

        def visit_pass(ps):
            state = pass_visit_dict.get(ps, State.NotVisited)
            if state == State.Visited:
                return True
            elif state == State.BeingVisited:
                return False  # this is a loop, need to remove this link
            elif state == State.NotVisited:
                pass_visit_dict[ps] = State.BeingVisited

                ps.dag_predecessors = []
                for pred in ps.predecessors:
                    if visit_pass(pred):
                        ps.dag_predecessors.append(pred)

                pass_visit_dict[ps] = State.Visited
                return True

        for ps in self.passes:
            if not ps.successors:
                visit_pass(ps)

    def build_cascaded_pass_links(self):
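        """Link each CascadedPass to its predecessor and successor cascaded passes via the producers of its inputs."""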
        for cps in self.cascaded_passes:
            cps.predecessors = []
            cps.successors = []

        for cps in self.cascaded_passes:
            for tens in cps.inputs:
                for op in tens.ops:
                    pred_cpass = op.scheduled_pass.cascade
                    if cps not in pred_cpass.successors:
                        pred_cpass.successors.append(cps)

                    if pred_cpass not in cps.predecessors:
                        cps.predecessors.append(pred_cpass)

                    assert tens in pred_cpass.outputs

    def refresh_after_modification(self):
        self.update_consumers()

    def prune_startup_init_pass(self):
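        """Drop unconsumed outputs, and the ops producing them, from the StartupInit pass at self.passes[0]."""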
        assert len(self.passes) >= 1
        ps = self.passes[0]
        assert ps.placement == PassPlacement.StartupInit

        ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
        ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]

    def get_all_ops(self):
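        """Return all operations reachable from the subgraph's output tensors, collected by a depth-first walk."""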
        all_ops = []
        visit_op_set = set()
        visit_tensor_set = set()

        def visit_op(op):
            if op in visit_op_set:
                return
            visit_op_set.add(op)
            for inp in op.inputs:
                visit_tensor(inp)

            all_ops.append(op)

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            for op in tens.ops:
                visit_op(op)

        for tens in self.output_tensors:
            visit_tensor(tens)

        return all_ops

    def print_operators(self):
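        """Print one line per unique operator configuration, skipping Const/Identity/Placeholder ops."""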
        print("print_operators()", self.name)
        all_ops = self.get_all_ops()
        unique_ops = []
        for op in all_ops:
            if op.type in set(("Const", "Identity", "Placeholder")):
                continue

            attrs = op.attrs
            if (
                op.type == "Conv2D"
                or op.type == "DepthwiseConv2dNative"
                or op.type == "Conv2DBiasAct"
                or op.type == "DepthwiseConv2dBiasAct"
            ):
                kshape = op.inputs[1].shape
                attrs["kshape"] = [kshape[0], kshape[1]]
            attrs["type"] = op.type
            attrs.pop("use_cudnn_on_gpu", None)
            if attrs not in unique_ops:
                unique_ops.append(attrs)
                # print attributes in human readable format
                a = attrs.copy()
                s = a.pop("type")
                data_format = a.pop("data_format", None)
                if data_format and data_format != b"NHWC":
                    s += " " + str(data_format)
                t = a.pop("T", None)
                if t:
                    s += " " + str(t)[9:-2]
                srct = a.pop("SrcT", None)
                if srct:
                    s += " " + str(srct)[9:-2]
                dstt = a.pop("DstT", None)
                if dstt:
                    s += "->" + str(dstt)[9:-2]
                print(s + " " + str(a))

    def print_graph(self):
        print("print_graph()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)

    def print_graph_with_tensors(self):
        print("print_graph_with_tensors()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                print(
                    "   Input  %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            for idx, tens in enumerate(op.outputs):
                print(
                    "   Output %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            print()

    def print_graph_with_tensor_quantization(self):
        print("print_graph_with_tensor_quantization()", self.name)
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                q = tens.quantization
                if q is None:
                    print("   Input  %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "   Input  %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            for idx, tens in enumerate(op.outputs):
                q = tens.quantization
                if q is None:
                    print("   Output %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "   Output %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            print()

    def print_passes(self):
        print("print_passes()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%03d %s" % (idx * 2, ps))

    def print_passes_with_tensors(self):
        print("print_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.passes):
            print("%3d %s" % (idx * 2, ps))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes(self):
        print("print_cascaded_passes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))

    def print_cascaded_passes_with_tensors(self):
        print("print_cascaded_passes_with_tensors()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes_with_tensor_sizes(self):
        print("print_cascaded_passes_with_tensor_sizes()", self.name)
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "    Input        %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "    Intermediate %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "    Output       %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            print()

    def print_high_level_command_stream(self):
        print("print_high_level_command_stream()", self.name)
        for idx, cmd in enumerate(self.high_level_command_stream):
            print("%3d %s" % (idx, cmd))


class Graph:
    def __init__(self, name="<unnamed>", batch_size=1):
        self.name = name
        self.batch_size = batch_size
        self.subgraphs = []

        self.memory_used = {}
        self.bits_per_element = {}
        self.total_size = {}
        self.total_elements = {}
        self.weight_cache = None  # See CompressedWeightCache

    def get_root_subgraph(self):
        return self.subgraphs[0]

    def prune_startup_init_pass(self):
        for sg in self.subgraphs:
            sg.prune_startup_init_pass()

    def update_consumers(self):
        for sg in self.subgraphs:
            sg.update_consumers()

    def refresh_after_modification(self):
        for sg in self.subgraphs:
            sg.refresh_after_modification()

    def print_operators(self):
        for sg in self.subgraphs:
            sg.print_operators()

    def print_graph(self):
        for sg in self.subgraphs:
            sg.print_graph()

    def print_graph_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensors()

    def print_graph_with_tensor_quantization(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensor_quantization()

    def print_passes(self):
        for sg in self.subgraphs:
            sg.print_passes()

    def print_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_passes_with_tensors()

    def print_cascaded_passes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes()

    def print_cascaded_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensors()

    def print_cascaded_passes_with_tensor_sizes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensor_sizes()

    def print_high_level_command_stream(self):
        for sg in self.subgraphs:
            sg.print_high_level_command_stream()