# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Neural network graph classes and enums.
# Pass - A packed pass containing one or more Operations.
# CascadedPass - A scheduled pass containing one or more Passes, as well as a scheduling strategy and block
#                configurations.
# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
# Graph - A full neural network graph with one or more Subgraphs.
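#
# Rough usage sketch (illustrative only; the reader and scheduler steps mentioned
# in the comments live in other modules):
#
#     nng = Graph("my_network", batch_size=1)
#     sg = Subgraph("main", PassPlacement.Npu)
#     nng.subgraphs.append(sg)
#     # ... a frontend reader populates sg with Tensors and Operations ...
#     nng.refresh_after_modification()  # rebuild tensor consumer lists
#     nng.print_graph()                 # debug dump of all operations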
import enum


class PassPlacement(enum.Enum):
    Unknown = 0
    Cpu = 1
    Npu = 2
    MemoryOnly = 3
    StartupInit = 4


class TensorAllocator(enum.Enum):
    LinearAlloc = 1
    Greedy = 2

    def __str__(self):
        return self.name


class Pass:
    def __init__(self, name, placement, is_element_wise, npu_block_type):
        self.inputs = []
        self.intermediates = []
        self.outputs = []
        self.ops = []
        self.primary_op = None
        self.ifm_tensor = None
        self.ifm2_tensor = None
        self.ofm_tensor = None
        self.weight_tensor = None
        self.scale_tensor = None
        self.name = name
        self.cascade = None
        self.placement = placement

        # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
        # allocator and requires that the OFM and IFM have exactly the same address, i.e. that they overlap completely.
        self.is_element_wise = is_element_wise
        self.npu_block_type = npu_block_type
        self.block_config = None  # will be filled in by scheduler
        self.shared_buffer = None  # will be filled in by scheduler

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.Pass '%s', %s, ops=%s>" % (self.name, self.placement, [op.type for op in self.ops])

    __repr__ = __str__

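    # The convenience accessors below return all-None tuples when there is no primary op.
    # get_primary_op_ifm_weights() takes every second element of the primary op's
    # (ifm, ifm2, weights, ofm) tuple, i.e. (ifm, weights).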
    def get_primary_op_ifm_weights(self):
        if not self.primary_op:
            return None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()[::2]

    def get_primary_op_ifm_ifm2_weights_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()

    def get_primary_op_ifm_weights_biases_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_weights_biases_ofm()


class SchedulingStrategy(enum.Enum):
    Unknown = -1
    IfmStream = 0
    WeightStream = 1


class SchedulerRewrite(enum.Enum):
    Nop = 0
    ChangeTensorSubPurpose = 1


class CascadedPass:
    def __init__(self, name, strat, inputs, intermediates, outputs, passes, placement, is_element_wise):
        self.name = name
        self.strategy = strat
        self.inputs = inputs
        self.intermediates = intermediates
        self.outputs = outputs
        self.passes = passes
        self.placement = placement
        self.is_element_wise = is_element_wise

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.CascadedPass strategy=%s x %s '%s', passes=%s, block_configs=%s>" % (
            self.strategy,
            len(self.passes),
            self.name,
            [ps.name for ps in self.passes],
            [ps.block_config for ps in self.passes],
        )

    __repr__ = __str__


class Subgraph:
    def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
        self.output_tensors = []
        self.input_tensors = []
        self.original_inputs = []  # Preserve the original input order
        self.passes = []
        self.cascaded_passes = []
        self.name = name
        self.high_level_command_stream = []
        self.placement = placement
        self.command_stream_tensor = None
        self.flash_tensor = None

        self.memory_used = {}
        self.memory_used_per_type = {}

    def __str__(self):
        return "<nng.Subgraph '%s', n_passes=%d, n_cascaded_passes=%d>" % (
            self.name,
            len(self.passes),
            len(self.cascaded_passes),
        )

    __repr__ = __str__

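    # Rebuild every tensor's consumer_list with a depth-first walk from the subgraph
    # outputs, append a None sentinel to mark graph outputs, and recollect
    # input_tensors from Placeholder/SubgraphInput ops. A second sweep over all
    # passes catches anything not reachable from the outputs.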
    def update_consumers(self):
        visit_op_set = set()
        visit_tensor_set = set()
        self.input_tensors = []

        print_visit = False

        def visit_op(op):
            if op in visit_op_set:
                return

            visit_op_set.add(op)
            for inp in op.inputs:
                if print_visit:
                    print(inp, "adding consumer", op)
                visit_tensor(inp)
                inp.consumer_list.append(op)

            if op.type in set(("Placeholder", "SubgraphInput")):
                assert len(op.outputs) == 1
                self.input_tensors.append(op.outputs[0])

            for out in op.outputs:
                if out not in visit_tensor_set:
                    out.consumer_list = []  # reset unvisited output, just in case

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            tens.consumer_list = []
            for op in tens.ops:
                visit_op(op)

        for ps in self.passes:
            for tens in ps.outputs + ps.inputs:
                tens.consumer_list = []  # reset unvisited tensors to start with

        for tens in self.output_tensors:
            visit_tensor(tens)
            tens.consumer_list.append(None)  # special op to indicate that the graph consumes the result

        print_visit = True
        for ps in self.passes:
            for op in ps.ops:
                visit_op(op)
            for tens in ps.inputs:
                visit_tensor(tens)

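    # Stamp each pass with an even 'time' in list order and link passes to the
    # producers of their input tensors; the asserts enforce that producers come
    # before consumers and actually produce the tensor being consumed.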
    def build_pass_links(self):
        for idx, ps in enumerate(self.passes):
            ps.time = 2 * idx
            ps.predecessors = []
            ps.successors = []

        for ps in self.passes:
            for tens in ps.inputs:
                for op in tens.ops:
                    pred_pass = op.scheduled_pass
                    assert pred_pass.time < ps.time
                    if ps not in pred_pass.successors:
                        pred_pass.successors.append(ps)

                    if pred_pass not in ps.predecessors:
                        ps.predecessors.append(pred_pass)

                    assert tens in pred_pass.outputs

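    # Derive dag_predecessors, an acyclic subset of the predecessor links: a
    # depth-first walk from the passes without successors drops any edge that leads
    # back into a pass that is still being visited (i.e. part of a cycle).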
    def build_pass_dag_predecessors(self):
        for ps in self.passes:
            ps.dag_predecessors = []

        class State(enum.Enum):
            NotVisited = 0
            BeingVisited = 1
            Visited = 2

        pass_visit_dict = {}

        def visit_pass(ps):
            state = pass_visit_dict.get(ps, State.NotVisited)
            if state == State.Visited:
                return True
            elif state == State.BeingVisited:
                return False  # this is a loop, need to remove this link
            elif state == State.NotVisited:
                pass_visit_dict[ps] = State.BeingVisited

            ps.dag_predecessors = []
            for pred in ps.predecessors:
                if visit_pass(pred):
                    ps.dag_predecessors.append(pred)

            pass_visit_dict[ps] = State.Visited
            return True

        for ps in self.passes:
            if not ps.successors:
                visit_pass(ps)

    def build_cascaded_pass_links(self):
        for cps in self.cascaded_passes:
            cps.predecessors = []
            cps.successors = []

        for cps in self.cascaded_passes:
            for tens in cps.inputs:
                for op in tens.ops:
                    pred_cpass = op.scheduled_pass.cascade
                    if cps not in pred_cpass.successors:
                        pred_cpass.successors.append(cps)

                    if pred_cpass not in cps.predecessors:
                        cps.predecessors.append(pred_cpass)

                    assert tens in pred_cpass.outputs

    def refresh_after_modification(self):
        self.update_consumers()

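    # Remove outputs of the StartupInit pass (always the first pass) that have no
    # consumers left, along with the ops that produced them.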
    def prune_startup_init_pass(self):
        assert len(self.passes) >= 1
        ps = self.passes[0]
        assert ps.placement == PassPlacement.StartupInit

        ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
        ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]

    def get_all_ops(self):
        all_ops = []
        visit_op_set = set()
        visit_tensor_set = set()

        def visit_op(op):
            if op in visit_op_set:
                return
            visit_op_set.add(op)
            for inp in op.inputs:
                visit_tensor(inp)

            all_ops.append(op)

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            for op in tens.ops:
                visit_op(op)

        for tens in self.output_tensors:
            visit_tensor(tens)

        return all_ops

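    # The print_* methods below are debug helpers; they dump operators, tensors,
    # passes, cascaded passes and the high level command stream to stdout.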
    def print_operators(self):
        all_ops = self.get_all_ops()
        unique_ops = []
        print("print_operators")
        for op in all_ops:
            if op.type in set(("Const", "Identity", "Placeholder")):
                continue

            attrs = op.attrs
            if (
                op.type == "Conv2D"
                or op.type == "DepthwiseConv2dNative"
                or op.type == "Conv2DBiasAct"
                or op.type == "DepthwiseConv2dBiasAct"
            ):
                kshape = op.inputs[1].shape
                attrs["kshape"] = [kshape[0], kshape[1]]
            attrs["type"] = op.type
            attrs.pop("use_cudnn_on_gpu", None)
            if attrs not in unique_ops:
                unique_ops.append(attrs)
                # print attributes in human-readable format
                a = attrs.copy()
                s = a.pop("type")
                data_format = a.pop("data_format", None)
                if data_format and data_format != b"NHWC":
                    s += " " + str(data_format)
                # str() of a dtype attribute looks like "<dtype: 'float32'>";
                # the [9:-2] slice keeps just the dtype name
                t = a.pop("T", None)
                if t:
                    s += " " + str(t)[9:-2]
                srct = a.pop("SrcT", None)
                if srct:
                    s += " " + str(srct)[9:-2]
                dstt = a.pop("DstT", None)
                if dstt:
                    s += "->" + str(dstt)[9:-2]
                print(s + " " + str(a))

    def print_graph(self):
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)

    def print_graph_with_tensors(self):
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                print(
                    "    Input  %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            for idx, tens in enumerate(op.outputs):
                print(
                    "    Output %02d %20s %20s %20s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
                )
            print()

    def print_graph_with_tensor_quantization(self):
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                q = tens.quantization
                if q is None:
                    print("    Input  %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "    Input  %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            for idx, tens in enumerate(op.outputs):
                q = tens.quantization
                if q is None:
                    print("    Output %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "    Output %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            print()

    def print_passes(self):
        for idx, ps in enumerate(self.passes):
            print("%03d %s" % (idx * 2, ps))

    def print_passes_with_tensors(self):
        for idx, ps in enumerate(self.passes):
            print("%3d %s" % (idx * 2, ps))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "   Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "   Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "   Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes(self):
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))

    def print_cascaded_passes_with_tensors(self):
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "   Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "   Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "   Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes_with_tensor_sizes(self):
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "   Input        %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "   Intermediate %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "   Output       %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            print()

    def print_high_level_command_stream(self):
        for idx, cmd in enumerate(self.high_level_command_stream):
            print("%3d %s" % (idx, cmd))


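# Graph is the top-level container; most of its methods simply fan out to every
# Subgraph in self.subgraphs. Illustrative use only:
#
#     nng = Graph("net")
#     nng.subgraphs.append(Subgraph("main", PassPlacement.Npu))
#     assert nng.get_root_subgraph() is nng.subgraphs[0]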
class Graph:
    def __init__(self, name="<unnamed>", batch_size=1):
        self.name = name
        self.batch_size = batch_size
        self.subgraphs = []

        self.memory_used = {}
        self.bits_per_element = {}
        self.total_size = {}
        self.total_elements = {}
        self.weight_cache = None  # See CompressedWeightCache

    def get_root_subgraph(self):
        return self.subgraphs[0]

    def prune_startup_init_pass(self):
        for sg in self.subgraphs:
            sg.prune_startup_init_pass()

    def update_consumers(self):
        for sg in self.subgraphs:
            sg.update_consumers()

    def refresh_after_modification(self):
        for sg in self.subgraphs:
            sg.refresh_after_modification()

    def print_operators(self):
        for sg in self.subgraphs:
            sg.print_operators()

    def print_graph(self):
        for sg in self.subgraphs:
            sg.print_graph()

    def print_graph_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensors()

    def print_graph_with_tensor_quantization(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensor_quantization()

    def print_passes(self):
        for sg in self.subgraphs:
            sg.print_passes()

    def print_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_passes_with_tensors()

    def print_cascaded_passes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes()

    def print_cascaded_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensors()

    def print_cascaded_passes_with_tensor_sizes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensor_sizes()

    def print_high_level_command_stream(self):
        for sg in self.subgraphs:
            sg.print_high_level_command_stream()