# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# Neural network graph classes and enums.
# Pass - A packed pass containing one or more Operations.
# CascadedPass - A scheduled pass containing one or more Passes, as well as a scheduling strategy and block
#                configurations.
# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
# Graph - A full neural network graph with one or more Subgraphs.

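# Rough usage sketch (illustrative only; assumes the Tensors and Operations have been created elsewhere,
# e.g. by a network reader, and that ofm_tens is such an output Tensor):
#
#     nng = Graph("my_network", batch_size=1)
#     sg = Subgraph("main")
#     nng.subgraphs.append(sg)
#     sg.output_tensors.append(ofm_tens)
#     sg.refresh_after_modification()  # rebuild consumer lists by walking back from the outputs
#     nng.print_graph()
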
import enum
from .data_type import BaseType, DataType
from .tensor import MemArea, TensorPurpose, TensorSubPurpose, TensorFormat, Tensor
from .operation import Operation, NpuBlockType


class PassPlacement(enum.Enum):
    Unknown = 0
    Cpu = 1
    Npu = 2
    MemoryOnly = 3
    StartupInit = 4


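# Selects the tensor allocation algorithm. A brief, non-authoritative summary: LinearAlloc places tensors
# one after another without reuse, while Greedy reuses space once a tensor's live range has ended.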
class TensorAllocator(enum.Enum):
    LinearAlloc = 1
    Greedy = 2

    def __str__(self):
        return self.name


class Pass:
    def __init__(self, name, placement, is_element_wise, npu_block_type):
        self.inputs = []
        self.intermediates = []
        self.outputs = []
        self.ops = []
        self.primary_op = None
        self.ifm_tensor = None
        self.ifm2_tensor = None
        self.ofm_tensor = None
        self.weight_tensor = None
        self.scale_tensor = None
        self.name = name
        self.cascade = None
        self.placement = placement

        # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
        # allocation and requires that the OFM and IFM have the exact same address. Essentially complete overlap.
        self.is_element_wise = is_element_wise
        self.npu_block_type = npu_block_type
        self.block_config = None  # will be filled in by scheduler
        self.shared_buffer = None  # will be filled in by scheduler

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.Pass '%s', %s, ops=%s>" % (self.name, self.placement, [op.type for op in self.ops])

    __repr__ = __str__

    def get_primary_op_ifm_weights(self):
        if not self.primary_op:
            return None, None
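        # get_ifm_ifm2_weights_ofm() returns (ifm, ifm2, weights, ofm); taking every second element gives (ifm, weights)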
        return self.primary_op.get_ifm_ifm2_weights_ofm()[::2]

    def get_primary_op_ifm_ifm2_weights_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_ifm2_weights_ofm()

    def get_primary_op_ifm_weights_biases_ofm(self):
        if not self.primary_op:
            return None, None, None, None
        return self.primary_op.get_ifm_weights_biases_ofm()


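# How the scheduler streams data for a (cascaded) pass. Roughly: IfmStream processes the IFM in stripes
# while the weights stay resident, WeightStream streams the weights instead. See the scheduler for details.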
class SchedulingStrategy(enum.Enum):
    Unknown = -1
    IfmStream = 0
    WeightStream = 1


class SchedulerRewrite(enum.Enum):
    Nop = 0
    ChangeTensorSubPurpose = 1


class CascadedPass:
    def __init__(self, name, strat, inputs, intermediates, outputs, passes, placement, is_element_wise):
        self.name = name
        self.strategy = strat
        self.inputs = inputs
        self.intermediates = intermediates
        self.outputs = outputs
        self.passes = passes
        self.placement = placement
        self.is_element_wise = is_element_wise

        self.predecessors = []
        self.successors = []

    def __str__(self):
        return "<nng.CascadedPass strategy=%s x %s '%s', passes=%s, block_configs=%s>" % (
            self.strategy,
            len(self.passes),
            self.name,
            [ps.name for ps in self.passes],
            [ps.block_config for ps in self.passes],
        )

    __repr__ = __str__


class Subgraph:
    def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
        self.output_tensors = []
        self.input_tensors = []
        self.original_inputs = []  # Preserve the original input order
        self.passes = []
        self.cascaded_passes = []
        self.name = name
        self.high_level_command_stream = []
        self.placement = placement
        self.command_stream_tensor = None
        self.flash_tensor = None

        self.memory_used = {}

    def __str__(self):
        return "<nng.Subgraph '%s', n_passes=%d, n_cascaded_passes=%d>" % (
            self.name,
            len(self.passes),
            len(self.cascaded_passes),
        )

    __repr__ = __str__

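    # Rebuild every tensor's consumer_list (and this subgraph's input_tensors) by walking backwards from the
    # output tensors and over the ops of each pass. A trailing None entry marks a tensor that is consumed by
    # the graph itself, i.e. a subgraph output.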
    def update_consumers(self):
        visit_op_set = set()
        visit_tensor_set = set()
        self.input_tensors = []

        print_visit = False

        def visit_op(op):
            if op in visit_op_set:
                return

            visit_op_set.add(op)
            for inp in op.inputs:
                if print_visit:
                    print(inp, "adding consumer", op)
                visit_tensor(inp)
                inp.consumer_list.append(op)

            if op.type in set(("Placeholder", "SubgraphInput")):
                assert len(op.outputs) == 1
                self.input_tensors.append(op.outputs[0])

            for out in op.outputs:
                if out not in visit_tensor_set:
                    out.consumer_list = []  # reset unvisited output, just in case

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            tens.consumer_list = []
            for op in tens.ops:
                visit_op(op)

        for ps in self.passes:
            for tens in ps.outputs + ps.inputs:
                tens.consumer_list = []  # reset unvisited tensors to start with

        for tens in self.output_tensors:
            visit_tensor(tens)
            tens.consumer_list.append(None)  # special op to indicate that the graph consumes the result

        print_visit = True
        for ps in self.passes:
            for op in ps.ops:
                visit_op(op)
            for tens in ps.inputs:
                visit_tensor(tens)

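    # Give each pass an even "time" stamp (2 * its index) and link passes into predecessor/successor lists
    # based on which pass produced each of their input tensors.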
    def build_pass_links(self):
        for idx, ps in enumerate(self.passes):
            ps.time = 2 * idx
            ps.predecessors = []
            ps.successors = []

        for ps in self.passes:
            for tens in ps.inputs:
                for op in tens.ops:
                    pred_pass = op.scheduled_pass
                    assert pred_pass.time < ps.time
                    if ps not in pred_pass.successors:
                        pred_pass.successors.append(ps)

                    if pred_pass not in ps.predecessors:
                        ps.predecessors.append(pred_pass)

                    assert tens in pred_pass.outputs

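    # Populate ps.dag_predecessors with only those predecessor links that do not form a cycle. A depth-first
    # visit with NotVisited/BeingVisited/Visited states drops any back edge, so the result is a DAG.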
    def build_pass_dag_predecessors(self):
        for ps in self.passes:
            ps.dag_predecessors = []

        class State(enum.Enum):
            NotVisited = 0
            BeingVisited = 1
            Visited = 2

        pass_visit_dict = {}

        def visit_pass(ps):
            state = pass_visit_dict.get(ps, State.NotVisited)
            if state == State.Visited:
                return True
            elif state == State.BeingVisited:
                return False  # this is a loop, need to remove this link
            elif state == State.NotVisited:
                pass_visit_dict[ps] = State.BeingVisited

                ps.dag_predecessors = []
                for pred in ps.predecessors:
                    if visit_pass(pred):
                        ps.dag_predecessors.append(pred)

                pass_visit_dict[ps] = State.Visited
                return True

        for ps in self.passes:
            if not ps.successors:
                visit_pass(ps)

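    # Same linking as build_pass_links, but between CascadedPasses (via the cascade of each producing pass).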
    def build_cascaded_pass_links(self):
        for cps in self.cascaded_passes:
            cps.predecessors = []
            cps.successors = []

        for cps in self.cascaded_passes:
            for tens in cps.inputs:
                for op in tens.ops:
                    pred_cpass = op.scheduled_pass.cascade
                    if cps not in pred_cpass.successors:
                        pred_cpass.successors.append(cps)

                    if pred_cpass not in cps.predecessors:
                        cps.predecessors.append(pred_cpass)

                    assert tens in pred_cpass.outputs

    def refresh_after_modification(self):
        self.update_consumers()

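    # Drop outputs of the StartupInit pass that no longer have any consumers, along with the ops that
    # produced them.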
    def prune_startup_init_pass(self):
        assert len(self.passes) >= 1
        ps = self.passes[0]
        assert ps.placement == PassPlacement.StartupInit

        ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
        ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]

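    # Collect every operation reachable by walking backwards from the output tensors; inputs are visited
    # before the op that consumes them, so producers appear before consumers in the returned list.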
    def get_all_ops(self):
        all_ops = []
        visit_op_set = set()
        visit_tensor_set = set()

        def visit_op(op):
            if op in visit_op_set:
                return
            visit_op_set.add(op)
            for inp in op.inputs:
                visit_tensor(inp)

            all_ops.append(op)

        def visit_tensor(tens):
            if tens in visit_tensor_set:
                return
            visit_tensor_set.add(tens)
            for op in tens.ops:
                visit_op(op)

        for tens in self.output_tensors:
            visit_tensor(tens)

        return all_ops

    def print_operators(self):
        all_ops = self.get_all_ops()
        unique_ops = []
        print("print_operators")
        for op in all_ops:
            if op.type in set(("Const", "Identity", "Placeholder")):
                continue

            attrs = op.attrs
            if (
                op.type == "Conv2D"
                or op.type == "DepthwiseConv2dNative"
                or op.type == "Conv2DBiasAct"
                or op.type == "DepthwiseConv2dBiasAct"
            ):
                kshape = op.inputs[1].shape
                attrs["kshape"] = [kshape[0], kshape[1]]
            attrs["type"] = op.type
            attrs.pop("use_cudnn_on_gpu", None)
            if attrs not in unique_ops:
                unique_ops.append(attrs)
                # print attributes in human readable format
                a = attrs.copy()
                s = a.pop("type")
                data_format = a.pop("data_format", None)
                if data_format and data_format != b"NHWC":
                    s += " " + str(data_format)
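                # the [9:-2] slices below strip the wrapper from a dtype repr of the assumed form
                # "<dtype: 'float32'>", leaving just the inner type name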
                t = a.pop("T", None)
                if t:
                    s += " " + str(t)[9:-2]
                srct = a.pop("SrcT", None)
                if srct:
                    s += " " + str(srct)[9:-2]
                dstt = a.pop("DstT", None)
                if dstt:
                    s += "->" + str(dstt)[9:-2]
                print(s + " " + str(a))

    def print_graph(self):
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)

    def print_graph_with_tensors(self):
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                print("   Input  %02d %20s %20s %s" % (idx, tens.purpose.name, tens.mem_area.name, tens))
            for idx, tens in enumerate(op.outputs):
                print("   Output %02d %20s %20s %s" % (idx, tens.purpose.name, tens.mem_area.name, tens))
            print()

    def print_graph_with_tensor_quantization(self):
        all_ops = self.get_all_ops()
        for idx, op in enumerate(all_ops):
            print(idx, op.type, op.name)
            for idx, tens in enumerate(op.inputs):
                q = tens.quantization
                if q is None:
                    print("   Input  %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "   Input  %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            for idx, tens in enumerate(op.outputs):
                q = tens.quantization
                if q is None:
                    print("   Output %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
                else:
                    print(
                        "   Output %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
                        % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
                    )
            print()

    def print_passes(self):
        for idx, ps in enumerate(self.passes):
            print("%03d %s" % (idx * 2, ps))

    def print_passes_with_tensors(self):
        for idx, ps in enumerate(self.passes):
            print("%3d %s" % (idx * 2, ps))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "   Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "   Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "   Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes(self):
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))

    def print_cascaded_passes_with_tensors(self):
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "   Input        %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "   Intermediate %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "   Output       %2d %-15s %-15s %-15s %s"
                    % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
                )
            print()

    def print_cascaded_passes_with_tensor_sizes(self):
        for idx, ps in enumerate(self.cascaded_passes):
            print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
            for idx, tens in enumerate(ps.inputs):
                print(
                    "   Input        %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.intermediates):
                print(
                    "   Intermediate %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            for idx, tens in enumerate(ps.outputs):
                print(
                    "   Output       %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
                    % (
                        idx,
                        tens.storage_size() / 1024,
                        tens.storage_shape,
                        tens.mem_area.name,
                        tens.purpose.name,
                        tens.format.name,
                        tens.name,
                    )
                )
            print()

    def print_high_level_command_stream(self):
        for idx, cmd in enumerate(self.high_level_command_stream):
            print("%3d %s" % (idx, cmd))


class Graph:
    def __init__(self, name="<unnamed>", batch_size=1):
        self.name = name
        self.batch_size = batch_size
        self.subgraphs = []

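        # Summary dictionaries (e.g. keyed per memory area); these start empty and are filled in by later
        # stages of the compiler.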
        self.memory_used = {}
        self.bits_per_element = {}
        self.total_size = {}
        self.total_elements = {}

    def get_root_subgraph(self):
        return self.subgraphs[0]

    def prune_startup_init_pass(self):
        for sg in self.subgraphs:
            sg.prune_startup_init_pass()

    def update_consumers(self):
        for sg in self.subgraphs:
            sg.update_consumers()

    def refresh_after_modification(self):
        for sg in self.subgraphs:
            sg.refresh_after_modification()

    def print_operators(self):
        for sg in self.subgraphs:
            sg.print_operators()

    def print_graph(self):
        for sg in self.subgraphs:
            sg.print_graph()

    def print_graph_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensors()

    def print_graph_with_tensor_quantization(self):
        for sg in self.subgraphs:
            sg.print_graph_with_tensor_quantization()

    def print_passes(self):
        for sg in self.subgraphs:
            sg.print_passes()

    def print_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_passes_with_tensors()

    def print_cascaded_passes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes()

    def print_cascaded_passes_with_tensors(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensors()

    def print_cascaded_passes_with_tensor_sizes(self):
        for sg in self.subgraphs:
            sg.print_cascaded_passes_with_tensor_sizes()

    def print_high_level_command_stream(self):
        for sg in self.subgraphs:
            sg.print_high_level_command_stream()