# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# The scheduler costs various strategies for scheduling the network in order to select the best scheduling strategy
# and block configuration for each pass.

import copy
import enum
from functools import lru_cache

import numpy as np

from . import live_range
from . import npu_performance
from . import stats_writer
from .tensor import TensorPurpose, TensorSubPurpose, TensorFormat, MemArea
from .operation import NpuBlockType
from .nn_graph import SchedulingStrategy, CascadedPass, PassPlacement, SchedulerRewrite
from .npu_performance import make_bandwidth_array, make_macs_array, make_cycles_array, make_metrics_arrays, PassCycles
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
from .shared_buffer_allocation import (
    find_block_configs_suitable_for_pass_and_shared_buffer,
    shared_buffer_allocation_for_pass_and_block_config,
)


class ParetoMetric(enum.Enum):
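    """Choice of metric tuple used to rank candidate schedules on the Pareto frontier.

    BwCycMem     - weighted bandwidth/cycle cost, then SRAM usage
    BwCycMemBlkH - as BwCycMem, with the height of the last block config as an extra tie-breaker
    """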
    BwCycMem = 1
    BwCycMemBlkH = 2

    def __str__(self):
        return self.name


class SchedulerOptions:
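    """User-tunable options for the dynamic programming scheduler.

    Illustrative construction (how the options are populated elsewhere is outside this module):

        options = SchedulerOptions(use_cascading=True, pareto_metric=ParetoMetric.BwCycMemBlkH)
    """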
    def __init__(
        self,
        use_cascading=True,
        use_ifm_ofm_overlap=True,
        verbose_schedule=False,
        verbose_pareto_frontier_schedules=False,
        use_ifm_streaming=True,
        pareto_metric=ParetoMetric.BwCycMem,
    ):
        self.use_cascading = use_cascading
        self.use_ifm_ofm_overlap = use_ifm_ofm_overlap
        self.verbose_schedule = verbose_schedule
        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
        self.use_ifm_streaming = use_ifm_streaming
        self.pareto_metric = pareto_metric

    def __str__(self):
        return type(self).__name__ + ": " + str(self.__dict__)

    __repr__ = __str__


class Strategy:
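    """A scheduling choice for a chain of passes.

    Holds the scheduling strategy type, the passes it covers, their block configs, the tensor rewrites to apply,
    and the accumulated bandwidth/MAC/cycle/SRAM cost metrics.
    """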
    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"

    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
        self.strat = strat
        self.param = param
        self.passes = passes
        self.block_configs = block_configs
        self.rewrite_list = (
            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
        )
        self.bws = bws
        self.macs = macs
        self.cycles = cycles
        self.sram_used = sram_used

    def __eq__(self, other):
        if self.strat != other.strat:
            return False
        if self.param != other.param:
            return False
        if self.block_configs != other.block_configs:
            return False
        if self.passes != other.passes:
            return False
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.sram_used != other.sram_used:
            return False
        return True

    def empty(self):
        return not self.passes

    def key(self):
        return self.passes[-1]

    def clone(self):
        return Strategy(
            self.strat,
            self.param,
            self.passes,
            self.block_configs,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    def __str__(self):
        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
            self.strat,
            self.passes,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    __repr__ = __str__


class StrategySet:
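    """A collection of Strategy objects covering disjoint chains of passes, keyed by each chain's final pass."""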
    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"

    def __init__(self, strats=None):
        if strats is None:
            strats = dict()
        self.strats = strats  # final pass in packed pass -> Strategy
        self.bws, self.macs, self.cycles = make_metrics_arrays()
        self.max_sram_used = 0
        self.total_sram_used = 0

    def update_statistics(self):
        self.bws = make_bandwidth_array()
        self.max_sram_used = 0
        for ps, strat in self.strats.items():
            self.bws += strat.bws
            self.macs += strat.macs
            self.cycles += strat.cycles
            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
            self.total_sram_used += strat.sram_used

    def clone_add_strategy(self, new_strat):
        key = new_strat.key()
        if key in self.strats:
            assert new_strat == self.strats[key]
            return self
        else:
            new_strats = dict(self.strats)
            new_strats[key] = new_strat
            new_set = StrategySet(new_strats)
            new_set.bws = self.bws + new_strat.bws
            new_set.macs = self.macs + new_strat.macs
            new_set.cycles = self.cycles + new_strat.cycles
            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
            return new_set

    def __eq__(self, other):
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.max_sram_used != other.max_sram_used:
            return False
        if self.total_sram_used != other.total_sram_used:
            return False
        if self.strats != other.strats:
            return False
        return True

    def __str__(self):
        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
            self.max_sram_used,
            list(ps.name for ps in self.strats),
        )

    __repr__ = __str__


empty_strategy = Strategy(
    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
)
INFINITY = 1e30

ABORT_SEARCH = []


def flatten_list_of_lists(lstlst):
    lst = []
    for v in lstlst:
        lst.extend(v)
    return lst


class DynamicProgrammingScheduler:
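    """Schedules a subgraph by dynamic programming over its pass DAG.

    Starting from the subgraph outputs, search_output() recursively explores weight-streaming and IFM-streaming
    (cascading) alternatives for each pass, pruning candidates with filter_pareto_frontier() so that only schedules
    that are Pareto-optimal in bandwidth/cycle cost and SRAM usage are carried forward. Memoisation via lru_cache
    keeps the search tractable.
    """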
    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sram_limit = sram_limit
        self.options = copy.copy(options)
        self.use_cascading = options.use_cascading

        self.use_ifm_ofm_overlap = options.use_ifm_ofm_overlap
        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM

        self.verbose_schedule = options.verbose_schedule
        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
        self.mem_area = MemArea.Sram

        self.bandwidth_weights = arch.bandwidth_weights
        self.cycles_weight = arch.cycles_weight
        self.max_sram_used_weight = arch.max_sram_used_weight

        self.n_combinations_searched = 0

        self.feature_maps_not_in_fast_storage = (
            arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area
        )

        self.pareto_max_candidates = 16

        self.ifm_stream_npu_blocks = set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
        )

    num_pareto_metrics = 4
    view_values = ",".join(["d"] * num_pareto_metrics)  # dtype string used to view the metric rows for argsort
    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]  # field names giving the sort priority

    def pareto_metric(self, candidate):
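        """Return the tuple used to rank a (Strategy, StrategySet) candidate: weighted bandwidth plus cycle
        cost, SRAM high-water mark of the completed strategies, SRAM of the in-progress strategy and, for
        ParetoMetric.BwCycMemBlkH, the height of the last block config as a final tie-breaker."""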
        strat, strat_set = candidate
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
        bws = strat.bws + strat_set.bws
        last_block_height = 0
        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
            last_block_height = strat.block_configs[-1][0]

        return (
            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
            strat_set.max_sram_used,
            strat.sram_used,
            last_block_height,
        )

    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):
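        """Drop candidates that exceed the SRAM limit or are dominated on every pareto metric by a better
        candidate, keeping at most pareto_max_candidates of the survivors."""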
        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]

        if len(candidates) <= 1:
            return candidates
        assert remove_equally_good_candidates
        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
        ids = np.arange(len(candidates), dtype=np.int32)
        for idx, cand in enumerate(candidates):
            pareto_vals[idx] = self.pareto_metric(cand)

        sort_order = np.argsort(
            pareto_vals.view(DynamicProgrammingScheduler.view_values),
            order=DynamicProgrammingScheduler.order_values,
            axis=0,
            kind="stable",
        ).flatten()
        pareto_vals = pareto_vals[sort_order]
        ids = ids[sort_order]

        pareto_frontier = []
        while len(ids) > 0:
            pareto_frontier.append(candidates[ids[0]])
            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
            ids = ids[not_dominated_by_first]
            pareto_vals = pareto_vals[not_dominated_by_first]

        if len(pareto_frontier) > self.pareto_max_candidates:
            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]

        return pareto_frontier

    def candidate_metric(self, candidate):
        strat, strat_set = candidate
        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
        bws = strat.bws + strat_set.bws
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]

        return (
            max_sram_used * self.max_sram_used_weight
            + np.tensordot(bws, self.bandwidth_weights, axes=3)
            + total_cycles * self.cycles_weight
        )

    def sort_by_candidate_metric(self, candidate_list):
        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
        return sorted_list

    def best_candidate(self, candidate_list):
        if len(candidate_list) == 0:
            return ABORT_SEARCH
        if len(candidate_list) == 1:
            return candidate_list[0]
        sorted_list = self.sort_by_candidate_metric(candidate_list)
        return sorted_list[0]

    def graduate_strat(self, strat_type, sram_used, old_strat_data):
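        """Finalise in-progress strategies as the given strategy type: add the extra SRAM cost, drop any
        candidate that exceeds the SRAM limit, and fold the completed strategy into its strategy set."""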
        res = []
        for old_strat, old_strat_set in old_strat_data:
            if old_strat.sram_used + sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            if old_strat_set.max_sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            assert old_strat.strat == SchedulingStrategy.Unknown

            new_strat = old_strat.clone()
            new_strat.strat = strat_type
            new_strat.sram_used = old_strat.sram_used + sram_used

            if self.use_ifm_ofm_overlap:
                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
                    new_strat.strat, new_strat.passes, new_strat.block_configs
                )
                new_strat.sram_used -= overlap

            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
            res.append((empty_strategy, new_strat_set))
        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)

    def append_sram(self, sram_used, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            assert old_strat.sram_used == 0
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used

            res.append((new_strat, strat_set))
        return res

    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]

            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )

            res.append((new_strat, strat_set))
        return res

    def append_sram_pass_block_config_performance_metrics_rewrite_list(
        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
    ):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.passes = old_strat.passes + [new_pass]
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def pass_to_strat(self, strat_data):
        res = {}
        for strat in strat_data[1].strats.values():
            for ps in strat.passes:
                res[ps] = strat
        return res

    def compatible_strats(self, a, b):
        intersection = a.keys() & b.keys()
        for k in intersection:
            if a[k] != b[k]:
                return False
        return True

    def collate_strats_for_passes(self, all_passes):
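        """Combine the strategy alternatives found for each predecessor strand into joint strategy sets,
        keeping only combinations whose per-pass strategies agree wherever the strands overlap."""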
        if len(all_passes) == 0:
            return [(empty_strategy, StrategySet(dict()))]
        if len(all_passes) == 1:
            return all_passes[0]  # save some space in the common case
        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
        prev_combos = [dict()]
        for j, strand in enumerate(all_strands):
            new_combos = []
            for i, alt in enumerate(strand):
                for prev in prev_combos:
                    if self.compatible_strats(prev, alt):
                        cmb = dict(prev)
                        cmb.update(all_passes[j][i][1].strats)
                        new_combos.append(cmb)
            prev_combos = new_combos

        res = []
        for d in prev_combos:
            s = StrategySet(d)
            s.update_statistics()
            res.append((empty_strategy, s))
        return res

    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
        # get the rest of the predecessors
        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
        other_predecessor_data = self.search_pass_list(other_predecessors)

        # pred strat data has an incomplete strategy, which we need
        # to continue on, whereas the other ones have completed strategies.
        # we need to merge these, but keep the incomplete strategy too.

        res = []
        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
            all_strats = [
                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
                other_predecessor_data,  # this one is fine to use as-is
            ]
            collated_strat_data = self.collate_strats_for_passes(all_strats)
            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
            res.extend(strat_data)
        return res

    def calc_non_local_mem_usage(self):
        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
        range_set = live_range.extract_live_ranges_from_passes(
            self.sg,
            self.mem_area,
            mark_output_tensors_overlapping_with_input_tensors=True,
            ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
        )
        range_dict = range_set.ranges

        # find which ranges overlap passes but aren't input/outputs of the passes.
        # these won't be counted by the dynamic programming search and must be counted in manually.
        end_pos = max(ps.time for ps in self.sg.passes) + 2
        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)

        for tens, rng in range_dict.items():
            storage_size = tens.storage_size()
            assert tens.mem_area == self.mem_area
            mem_usage[rng.start_time : rng.end_time] += storage_size

        for ps in self.sg.passes:
            local_mem_usage = 0
            for tens in ps.inputs + ps.outputs + ps.intermediates:
                if tens.mem_area != self.mem_area:
                    continue

                local_mem_usage += tens.storage_size()

            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage

        self.non_local_mem_usage = non_local_mem_usage

    def search(self):
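        """Run the dynamic programming search over the whole subgraph and return the best StrategySet found,
        optionally printing the candidate schedules along the Pareto frontier."""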
        self.calc_non_local_mem_usage()
        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        strat_data = self.search_pass_list(starting_passes)

        _, best_set = self.best_candidate(strat_data)

        if self.verbose_pareto_frontier_schedules:
            print(
                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
                % (self.n_combinations_searched, len(strat_data))
            )
            for idx, (_, strat_set) in enumerate(strat_data):
                extra = ""
                if strat_set == best_set:
                    extra = "(Best candidate)"
                print("Candidate", idx, extra)
                memory_used = {MemArea.Sram: strat_set.max_sram_used}
                stats_writer.print_performance_metrics_for_strat(
                    self.arch,
                    "",
                    strat_set.cycles,
                    strat_set.macs,
                    strat_set.bws,
                    self.nng.batch_size,
                    memory_used,
                    len(self.sg.passes),
                    len(strat_set.strats),
                )

        return best_set

    def search_pass_list(self, pass_list):
        all_strats = []
        for ps in pass_list:
            strat = self.search_output(ps)
            all_strats.append(strat)
        strat_data = self.collate_strats_for_passes(all_strats)
        for strd in strat_data:
            for ps in pass_list:
                assert ps in strd[1].strats  # should have strategies for everything we asked to search
        return strat_data

    def search_predecessors(self, ps):

        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
        # we have strats for all passes

        pass_list = ps.dag_predecessors
        strat_data = self.search_pass_list(pass_list)

        return strat_data

    @lru_cache(maxsize=None)
    def search_output(self, ps):

        assert ps in self.sg.passes
        candidate_list = []

        candidate_list.extend(self.search_weight_streaming_output(ps))

        if self.options.use_ifm_streaming:
            candidate_list.extend(self.search_ifm_streaming_output(ps))

        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)

        if not best:
            print(
                "Warning: Dynamic programming search algorithm failed for pass %s, invoking fallback strategy"
                % (ps.name,)
            )
            return self.search_predecessors(ps)

        return best

    def search_ifm_streaming_output(self, ps):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        strat_data = self.search_ifm_streaming_body(ps, False)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        ifm_input_search_results = self.search_ifm_streaming_input(ps)
        res = []

        base_sram_used = 0
        for tens in ps.intermediates:
            if tens.mem_area == self.mem_area:
                base_sram_used += tens.storage_size()

        all_block_configs = self.get_block_configs(ps)
        for block_config in all_block_configs:
            all_strats = []

            if self.use_cascading:
                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))

            all_strats.extend(ifm_input_search_results)

            rewrite_list = []
            sram_used = base_sram_used

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch,
                ps,
                block_config,
                rewrite_list=rewrite_list,
                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_ifm_streaming_partial(self, ps, block_config):
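        """Try to cascade this pass with a single NPU predecessor that produces its IFM, turning the IFM into a
        rolling buffer so that only a few rows of it need to be resident in SRAM at a time."""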
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH

        if len(ps.inputs) < 1:
            return ABORT_SEARCH

        ifm_tensor = ps.ifm_tensor

        if ifm_tensor is None:
            return ABORT_SEARCH
        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
            return ABORT_SEARCH
        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
            return ABORT_SEARCH

        pred_pass_list = []
        for pred_candidate in ps.dag_predecessors:
            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                # we found a predecessor that produces this IFM tensor
                if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                    # and it only has one successor, namely us
                    if pred_candidate.placement == PassPlacement.Npu:
                        if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                            # and it is on the Npu and fusable - it's a candidate
                            pred_pass_list.append(pred_candidate)

        if not pred_pass_list:
            return ABORT_SEARCH

        all_candidates = []
        for pred_pass in pred_pass_list:
            # recurse into the next pass
            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage)

            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
            for strat_opt in strat_data:

                pred_pass_block_config = strat_opt[0].block_configs[-1]
                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
                )
                if rolling_buffer_dims is None:
                    continue  # this does not pack properly, skip it.

                sram_used = 0
                for tens in ps.inputs:
                    if tens != ifm_tensor:
                        if tens.mem_area == self.mem_area:
                            sram_used += tens.storage_size()

                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims

                rewrite_list = [
                    (
                        SchedulerRewrite.ChangeTensorSubPurpose,
                        ifm_tensor,
                        TensorSubPurpose.RollingBufferY,
                        rolling_buffer_y,
                        None,
                        ps,
                    )
                ]
                sram_used += ifm_tensor.storage_size_for_sub_purpose(
                    TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
                )

                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))

        self.n_combinations_searched += len(all_candidates)
        return all_candidates

    def get_block_configs(self, ps):
        if ps.placement != PassPlacement.Npu:
            return [(1, 1, 1, 1)]  # default

        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)

        # Take a limited number of the largest blocks
        if self.arch.block_config_limit > 0:
            # Sort by block area, followed by depth
            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
            bound = min(len(block_configs), self.arch.block_config_limit)
            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
            tmp = block_configs[:bound]
            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
            block_configs = tmp

        return block_configs

    def search_ifm_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def search_weight_streaming_output(self, ps):
        strat_data = self.search_weight_streaming_body(ps)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_weight_streaming_body(self, ps):

        strat_data = self.search_weight_streaming_input(ps)

        res = []

        all_block_configs = self.get_block_configs(ps)

        for block_config in all_block_configs:

            sram_used = 0
            rewrite_list = []

            for tens in ps.intermediates:
                if tens.mem_area == self.mem_area:
                    if tens.purpose == TensorPurpose.Weights:
                        sram_used += tens.storage_size_for_sub_purpose(TensorSubPurpose.DoubleBuffer, block_config[3])
                        rewrite_list.append(
                            (
                                SchedulerRewrite.ChangeTensorSubPurpose,
                                tens,
                                TensorSubPurpose.DoubleBuffer,
                                block_config[3],
                                None,
                                ps,
                            )
                        )
                    else:
                        sram_used += tens.storage_size()

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch, ps, block_config, rewrite_list=rewrite_list
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_weight_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def apply_result(self, strat_set, arch):
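        """Apply the chosen strategy set to the subgraph: pack passes into CascadedPass objects, apply the tensor
        sub-purpose rewrites (rolling buffers, double-buffered weights), and pick NHCWB16 formats where possible."""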
        pass_to_cascaded_pass = dict()
        for _, strat in strat_set.strats.items():
            # rewrite the tensors that need this first. e.g. make rolling buffers
            inputs = []
            intermediates = []
            outputs = []

            for ps in strat.passes:
                inputs += ps.inputs
                intermediates += ps.intermediates
                outputs += ps.outputs

            for tens in set(inputs) & set(outputs):
                # tensors that are in both sets are intermediates

                # find pass with input/output tensor, and check if they are both placed on NPU
                input_placement = None
                output_placement = None
                for ps in strat.passes:
                    if tens in ps.inputs:
                        input_placement = ps.placement
                    if tens in ps.outputs:
                        output_placement = ps.placement
                if input_placement == output_placement == PassPlacement.Npu:
                    tens.set_format(TensorFormat.NHCWB16, arch)

                intermediates.append(tens)
                inputs.remove(tens)
                outputs.remove(tens)

            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                    tens.mem_area = self.arch.fast_storage_mem_area
                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                else:
                    assert 0, "unknown rewrite_op " + str(rewrite_op)

            is_element_wise = True
            for ps in strat.passes:
                assert ps.placement == strat.passes[0].placement
                if not ps.is_element_wise:
                    is_element_wise = False
                    break

            cascaded_pass = CascadedPass(
                strat.passes[0].name,
                strat.strat,
                inputs,
                intermediates,
                outputs,
                strat.passes,
                strat.passes[0].placement,
                is_element_wise,
            )
            assert strat.sram_used >= 0
            cascaded_pass.sram_used = strat.sram_used

            for idx, ps in enumerate(strat.passes):
                assert ps not in pass_to_cascaded_pass
                pass_to_cascaded_pass[ps] = cascaded_pass
                ps.cascade = cascaded_pass
                ps.block_config = strat.block_configs[idx]

                if ps.placement == PassPlacement.Npu:
                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
                        self.arch, ps, ps.block_config
                    )
                    assert ps.shared_buffer is not None

                for op in ps.ops:
                    subgraph = op.attrs.get("subgraph")
                    if subgraph:
                        subgraph.base_sram_used = cascaded_pass.sram_used

        # all passes should have a cascaded pass now
        if len(pass_to_cascaded_pass) != len(self.sg.passes):
            print(
                "mismatch: we have %d passes, but only %d have cascaded passes associated"
                % (len(self.sg.passes), len(pass_to_cascaded_pass))
            )
            for ps in self.sg.passes:
                if ps not in pass_to_cascaded_pass:
                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))

            assert len(pass_to_cascaded_pass) == len(self.sg.passes)
        # we have all the passes, but we need to put them in order and build predecessor/successor links.

        visit_pass_set = set()
        cascaded_passes = []

        def visit_pass(ps):
            if ps in visit_pass_set:
                return
            visit_pass_set.add(ps)

            cps = ps.cascade
            dont_traverse = set(cps.passes)

            for ps in cps.passes:
                for pred in ps.predecessors:
                    if pred in dont_traverse:
                        continue
                    visit_pass(pred)

            cascaded_passes.append(cps)

        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        for ps in starting_passes:
            visit_pass(ps)

        # reorder so startup init cascaded passes come first
        def is_startup_cascaded_pass(cps):
            if not cps.passes:
                return False
            return cps.placement == PassPlacement.StartupInit

        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
        ]

        self.sg.cascaded_passes = cascaded_passes
        self.sg.build_cascaded_pass_links()

        # Check if NHCWB16 can be used in between cascaded passes
        # (NHCWB16 within cascaded passes has been handled earlier in this function)
        if self.sg.placement == PassPlacement.Npu:
            for ps in self.sg.cascaded_passes:
                if ps.placement != PassPlacement.Npu:
                    continue
                for output in ps.outputs:
                    if output.purpose != TensorPurpose.FeatureMap:
                        continue

                    use_NHCWB16 = True
                    for op in output.consumer_list:
                        if op is None or op.type == "Reshape":
                            use_NHCWB16 = False
                        else:
                            use_NHCWB16 &= op.run_on_npu

                    if use_NHCWB16:
                        output.set_format(TensorFormat.NHCWB16, arch)


def schedule_passes(nng, arch, options: SchedulerOptions):
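    """Schedule each subgraph of the network with the DynamicProgrammingScheduler and apply the best result.

    Illustrative call (how nng, arch and options are created lives outside this module):

        schedule_passes(nng, arch, SchedulerOptions(use_cascading=True, verbose_schedule=False))
    """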
    for sg in nng.subgraphs:
        sg.base_sram_used = 0

    for sg in nng.subgraphs:
        # re-entering the same nodes from different contexts requires us to
        # build a simplified directed acyclic (DAG) version of the graph to
        # use for traversal, rather than using a visit dictionary. this avoids
        # recursing infinitely due to loops.
        sg.build_pass_dag_predecessors()

        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)

        strat_set = dps.search()

        dps.apply_result(strat_set, arch)

        if options.verbose_schedule:
            sg.print_cascaded_passes()