# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Neural network graph classes and enums.
# Pass - A packed pass containing one or more Operations.
# CascadedPass - A scheduled pass containing one or more Passes, as well as a scheduling strategy and block
#                configurations.
# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
# Graph - A full neural network graph with one or more Subgraphs.
import enum


class PassPlacement(enum.Enum):
    Unknown = 0
    Cpu = 1
    Npu = 2
    MemoryOnly = 3
    StartupInit = 4


class TensorAllocator(enum.Enum):
    LinearAlloc = 1
    Greedy = 2

    def __str__(self):
        return self.name


class Pass:
    def __init__(self, name, placement, is_element_wise, npu_block_type):
        self.inputs = []
        self.intermediates = []
        self.outputs = []
        self.ops = []
        self.primary_op = None
        self.ifm_tensor = None
        self.ifm2_tensor = None
        self.ofm_tensor = None
        self.weight_tensor = None
        self.scale_tensor = None
        self.name = name
        self.cascade = None
        self.placement = placement

        # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
        # allocation and requires that the OFM and IFM have the exact same address, i.e. they overlap completely.
        self.is_element_wise = is_element_wise
        self.npu_block_type = npu_block_type
        self.block_config = None  # will be filled in by scheduler
        self.shared_buffer = None  # will be filled in by scheduler

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.Pass '%s', %s, ops=%s>" % (self.name, self.placement, [op.type for op in self.ops])

    __repr__ = __str__

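    # The getters below unpack Operation.get_ifm_ifm2_weights_ofm(); judging by that
    # method's name and the [::2] slice used here, it is assumed to return the tuple
    # (ifm, ifm2, weights, ofm), so [::2] selects (ifm, weights).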
    def get_primary_op_ifm_weights(self):
        if not self.primary_op:
            return None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()[::2]

    def get_primary_op_ifm_ifm2_weights_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()

    def get_primary_op_ifm_weights_biases_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_weights_biases_ofm()


class SchedulingStrategy(enum.Enum):
    Unknown = -1
    IfmStream = 0
    WeightStream = 1
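    # Hedged summary (the scheduler is the authoritative source): IfmStream cascades
    # passes by streaming the input feature map through them, while WeightStream keeps
    # the feature maps resident and streams the weights instead.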


class SchedulerRewrite(enum.Enum):
    Nop = 0
    ChangeTensorSubPurpose = 1


class CascadedPass:
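    # Built by the scheduler: groups one or more Passes that are executed together under a
    # single SchedulingStrategy. Each member Pass is expected to point back at its
    # CascadedPass via its `cascade` attribute (see Subgraph.build_cascaded_pass_links).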
    def __init__(self, name, strat, inputs, intermediates, outputs, passes, placement, is_element_wise):
        self.name = name
        self.strategy = strat
        self.inputs = inputs
        self.intermediates = intermediates
        self.outputs = outputs
        self.passes = passes
        self.placement = placement
        self.is_element_wise = is_element_wise

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.CascadedPass strategy=%s x %s '%s', passes=%s, block_configs=%s>" % (
            self.strategy,
            len(self.passes),
            self.name,
            [ps.name for ps in self.passes],
            [ps.block_config for ps in self.passes],
        )

    __repr__ = __str__


class Subgraph:
    def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
        self.output_tensors = []
        self.input_tensors = []
        self.original_inputs = []  # Preserve the original input order
        self.passes = []
        self.cascaded_passes = []
        self.name = name
        self.high_level_command_stream = []
        self.placement = placement
        self.command_stream_tensor = None
        self.flash_tensor = None

        self.memory_used = {}

    def __str__(self):
        return "<nng.Subgraph '%s', n_passes=%d, n_cascaded_passes=%d>" % (
            self.name,
            len(self.passes),
            len(self.cascaded_passes),
        )

    __repr__ = __str__

    def update_consumers(self):
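        # Rebuild every reachable tensor's consumer_list and recompute self.input_tensors.
        # Traversal starts from the subgraph's output tensors and each pass's ops;
        # Placeholder/SubgraphInput ops mark subgraph inputs, and a trailing None in a
        # consumer_list marks a tensor that is consumed by the graph itself.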
        visit_op_set = set()
        visit_tensor_set = set()
        self.input_tensors = []

        print_visit = False

        def visit_op(op):
            if op in visit_op_set:
                return

            visit_op_set.add(op)
            for inp in op.inputs:
                if print_visit:
                    print(inp, "adding consumer", op)
                visit_tensor(inp)
                inp.consumer_list.append(op)

            if op.type in set(("Placeholder", "SubgraphInput")):
                assert len(op.outputs) == 1
                self.input_tensors.append(op.outputs[0])

            for out in op.outputs:
                if out not in visit_tensor_set:
                    out.consumer_list = []  # reset unvisited output, just in case

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            tens.consumer_list = []
            for op in tens.ops:
                visit_op(op)

        for ps in self.passes:
            for tens in ps.outputs + ps.inputs:
                tens.consumer_list = []  # reset unvisited tensors to start with

        for tens in self.output_tensors:
            visit_tensor(tens)
            tens.consumer_list.append(None)  # special op to indicate that the graph consumes the result

        print_visit = True
        for ps in self.passes:
            for op in ps.ops:
                visit_op(op)
            for tens in ps.inputs:
                visit_tensor(tens)

    def build_pass_links(self):
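        # Assign each pass an even time index in schedule order and wire up
        # predecessor/successor links between passes based on their input tensors.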
        for idx, ps in enumerate(self.passes):
            ps.time = 2 * idx
            ps.predecessors = []
            ps.successors = []

        for ps in self.passes:
            for tens in ps.inputs:
                for op in tens.ops:
                    pred_pass = op.scheduled_pass
                    assert pred_pass.time < ps.time
                    if ps not in pred_pass.successors:
                        pred_pass.successors.append(ps)

                    if pred_pass not in ps.predecessors:
                        ps.predecessors.append(pred_pass)

                    assert tens in pred_pass.outputs

    def build_pass_dag_predecessors(self):
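        # Depth-first walk from the output passes using a three-state marker
        # (NotVisited/BeingVisited/Visited); a predecessor link into a pass that is still
        # being visited closes a cycle and is dropped, so dag_predecessors ends up acyclic.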
        for ps in self.passes:
            ps.dag_predecessors = []

        class State(enum.Enum):
            NotVisited = 0
            BeingVisited = 1
            Visited = 2

        pass_visit_dict = {}

        def visit_pass(ps):
            state = pass_visit_dict.get(ps, State.NotVisited)
            if state == State.Visited:
                return True
            elif state == State.BeingVisited:
                return False  # this is a loop, need to remove this link
            elif state == State.NotVisited:
                pass_visit_dict[ps] = State.BeingVisited

            ps.dag_predecessors = []
            for pred in ps.predecessors:
                if visit_pass(pred):
                    ps.dag_predecessors.append(pred)

            pass_visit_dict[ps] = State.Visited
            return True

        for ps in self.passes:
            if not ps.successors:
                visit_pass(ps)

    def build_cascaded_pass_links(self):
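        # Same link wiring as build_pass_links, but at CascadedPass granularity, using each
        # producing op's scheduled_pass.cascade back-reference.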
        for cps in self.cascaded_passes:
            cps.predecessors = []
            cps.successors = []

        for cps in self.cascaded_passes:
            for tens in cps.inputs:
                for op in tens.ops:
                    pred_cpass = op.scheduled_pass.cascade
                    if cps not in pred_cpass.successors:
                        pred_cpass.successors.append(cps)

                    if pred_cpass not in cps.predecessors:
                        cps.predecessors.append(pred_cpass)

                    assert tens in pred_cpass.outputs

    def refresh_after_modification(self):
        self.update_consumers()

    def prune_startup_init_pass(self):
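        # The first pass is expected to be the StartupInit pass; drop any of its outputs
        # that have no consumers, together with the ops that produce them.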
        assert len(self.passes) >= 1
        ps = self.passes[0]
        assert ps.placement == PassPlacement.StartupInit

        ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
        ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]

    def get_all_ops(self):
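        # Depth-first walk from the subgraph's output tensors; each op is appended after
        # its inputs have been visited, so producers appear before their consumers.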
        all_ops = []
        visit_op_set = set()
        visit_tensor_set = set()

        def visit_op(op):
            if op in visit_op_set:
                return
            visit_op_set.add(op)
            for inp in op.inputs:
                visit_tensor(inp)

            all_ops.append(op)

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            for op in tens.ops:
                visit_op(op)

        for tens in self.output_tensors:
            visit_tensor(tens)

        return all_ops

    def print_operators(self):
        all_ops = self.get_all_ops()
        unique_ops = []
        print("print_operators")
        for op in all_ops:
            if op.type in set(("Const", "Identity", "Placeholder")):
                continue

            attrs = op.attrs
            if (
                op.type == "Conv2D"
                or op.type == "DepthwiseConv2dNative"
                or op.type == "Conv2DBiasAct"
                or op.type == "DepthwiseConv2dBiasAct"
            ):
                kshape = op.inputs[1].shape
                attrs["kshape"] = [kshape[0], kshape[1]]
            attrs["type"] = op.type
            attrs.pop("use_cudnn_on_gpu", None)
            if attrs not in unique_ops:
                unique_ops.append(attrs)
                # print attributes in human readable format
                a = attrs.copy()
                s = a.pop("type")
                data_format = a.pop("data_format", None)
                if data_format and data_format != b"NHWC":
                    s += " " + str(data_format)
                t = a.pop("T", None)
                if t:
                    s += " " + str(t)[9:-2]
                srct = a.pop("SrcT", None)
                if srct:
                    s += " " + str(srct)[9:-2]
                dstt = a.pop("DstT", None)
                if dstt:
                    s += "->" + str(dstt)[9:-2]
                print(s + " " + str(a))
    def print_graph(self):
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)

    def print_graph_with_tensors(self):
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                print("   Input  %02d %20s %20s %s" % (idx, tens.purpose.name, tens.mem_area.name, tens))
            for idx, tens in enumerate(op.outputs):
                print("   Output %02d %20s %20s %s" % (idx, tens.purpose.name, tens.mem_area.name, tens))
            print()

    def print_graph_with_tensor_quantization(self):
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                q = tens.quantization
                if q is None:
                    print("   Input  %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "   Input  %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            for idx, tens in enumerate(op.outputs):
                q = tens.quantization
                if q is None:
                    print("   Output %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "   Output %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            print()

    def print_passes(self):
        for idx, ps in enumerate(self.passes):
            print("%03d %s" % (idx * 2, ps))

    def print_passes_with_tensors(self):
        for idx, ps in enumerate(self.passes):
            print("%3d %s" % (idx * 2, ps))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "   Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "   Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "   Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes(self):
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))

    def print_cascaded_passes_with_tensors(self):
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "   Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "   Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "   Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes_with_tensor_sizes(self):
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "   Input        %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "   Intermediate %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "   Output       %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            print()

    def print_high_level_command_stream(self):
        for idx, cmd in enumerate(self.high_level_command_stream):
            print("%3d %s" % (idx, cmd))


class Graph:
    def __init__(self, name="<unnamed>", batch_size=1):
        self.name = name
        self.batch_size = batch_size
        self.subgraphs = []

        self.memory_used = {}
        self.bits_per_element = {}
        self.total_size = {}
        self.total_elements = {}
        self.weight_cache = None  # See CompressedWeightCache

    def get_root_subgraph(self):
        return self.subgraphs[0]

    def prune_startup_init_pass(self):
        for sg in self.subgraphs:
            sg.prune_startup_init_pass()

    def update_consumers(self):
        for sg in self.subgraphs:
            sg.update_consumers()

    def refresh_after_modification(self):
        for sg in self.subgraphs:
            sg.refresh_after_modification()

    def print_operators(self):
        for sg in self.subgraphs:
            sg.print_operators()

    def print_graph(self):
        for sg in self.subgraphs:
            sg.print_graph()

    def print_graph_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensors()

    def print_graph_with_tensor_quantization(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensor_quantization()

    def print_passes(self):
        for sg in self.subgraphs:
            sg.print_passes()

    def print_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_passes_with_tensors()

    def print_cascaded_passes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes()

    def print_cascaded_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensors()

    def print_cascaded_passes_with_tensor_sizes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensor_sizes()

    def print_high_level_command_stream(self):
        for sg in self.subgraphs:
            sg.print_high_level_command_stream()
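
# Illustrative sketch only (not part of the compiler flow): how the containers in this
# module nest. Tensor/Operation objects come from other modules and the scheduler, so the
# values used here (e.g. npu_block_type=None) are placeholders rather than real usage.
#
#   nng = Graph(name="my_network", batch_size=1)
#   sg = Subgraph(name="main", placement=PassPlacement.Npu)
#   nng.subgraphs.append(sg)
#   ps = Pass("conv_pass", PassPlacement.Npu, is_element_wise=False, npu_block_type=None)
#   sg.passes.append(ps)
#   # After scheduling, Passes are grouped into CascadedPasses in sg.cascaded_passes,
#   # and nng.print_cascaded_passes() summarises them.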