# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
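#
# At a high level: the dynamic programming search below walks the pass DAG from the subgraph outputs back
# through their predecessors, costing IFM-streaming and weight-streaming strategies per pass, pruning
# candidates with a pareto-frontier filter on bandwidth, cycles and SRAM usage, and finally turning the
# chosen strategies into CascadedPasses.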
import copy
import enum
from functools import lru_cache

import numpy as np

from . import live_range
from . import npu_performance
from . import stats_writer
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
from .nn_graph import CascadedPass
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .nn_graph import SchedulingStrategy
from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
from .operation import NpuBlockType
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
from .tensor import MemArea
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


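# Metric combination used when pruning candidate schedules to a pareto frontier (see
# DynamicProgrammingScheduler.pareto_metric): BwCycMem weighs bandwidth/cycles against SRAM usage, while
# BwCycMemBlkH additionally takes the height of the last block config into account.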
class ParetoMetric(enum.Enum):
    BwCycMem = 1
    BwCycMemBlkH = 2

    def __str__(self):
        return self.name


class SchedulerOptions:
    def __init__(
        self,
        use_cascading=True,
        use_ifm_ofm_overlap=True,
        verbose_schedule=False,
        verbose_pareto_frontier_schedules=False,
        use_ifm_streaming=True,
        pareto_metric=ParetoMetric.BwCycMem,
        use_nhcwb16_between_cascaded_passes=True,
    ):
        self.use_cascading = use_cascading
        self.use_ifm_ofm_overlap = use_ifm_ofm_overlap
        self.verbose_schedule = verbose_schedule
        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
        self.use_ifm_streaming = use_ifm_streaming
        self.pareto_metric = pareto_metric
        self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes

    def __str__(self):
        return type(self).__name__ + ": " + str(self.__dict__)

    __repr__ = __str__

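# SchedulerOptions is what schedule_passes() below receives from its caller; an illustrative construction
# (defaults as in __init__ above) would be SchedulerOptions(pareto_metric=ParetoMetric.BwCycMemBlkH).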

class Strategy:
    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"

    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
        self.strat = strat
        self.param = param
        self.passes = passes
        self.block_configs = block_configs
        self.rewrite_list = (
            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
        )
        self.bws = bws
        self.macs = macs
        self.cycles = cycles
        self.sram_used = sram_used

    def __eq__(self, other):
        if self.strat != other.strat:
            return False
        if self.param != other.param:
            return False
        if self.block_configs != other.block_configs:
            return False
        if self.passes != other.passes:
            return False
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.sram_used != other.sram_used:
            return False
        return True

    def empty(self):
        return not self.passes

    def key(self):
        return self.passes[-1]

    def clone(self):
        return Strategy(
            self.strat,
            self.param,
            self.passes,
            self.block_configs,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    def __str__(self):
        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
            self.strat,
            self.passes,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    __repr__ = __str__


class StrategySet:
    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"

    def __init__(self, strats=None):
        if strats is None:
            strats = dict()
        self.strats = strats  # final pass in packed pass -> Strategy
        self.bws, self.macs, self.cycles = make_metrics_arrays()
        self.max_sram_used = 0
        self.total_sram_used = 0

    def update_statistics(self):
        self.bws = make_bandwidth_array()
        self.max_sram_used = 0
        for ps, strat in self.strats.items():
            self.bws += strat.bws
            self.macs += strat.macs
            self.cycles += strat.cycles
            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
            self.total_sram_used += strat.sram_used

    def clone_add_strategy(self, new_strat):
        key = new_strat.key()
        if key in self.strats:
            assert new_strat == self.strats[key]
            return self
        else:
            new_strats = dict(self.strats)
            new_strats[key] = new_strat
            new_set = StrategySet(new_strats)
            new_set.bws = self.bws + new_strat.bws
            new_set.macs = self.macs + new_strat.macs
            new_set.cycles = self.cycles + new_strat.cycles
            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
            return new_set

    def __eq__(self, other):
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.max_sram_used != other.max_sram_used:
            return False
        if self.total_sram_used != other.total_sram_used:
            return False
        if self.strats != other.strats:
            return False
        return True

    def __str__(self):
        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
            self.max_sram_used,
            list(ps.name for ps in self.strats),
        )

    __repr__ = __str__


empty_strategy = Strategy(
    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
)
INFINITY = 1e30

ABORT_SEARCH = []
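# ABORT_SEARCH (an empty candidate list) is returned by the search helpers when a branch cannot produce a
# valid strategy; extending it simply yields no candidates.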


def flatten_list_of_lists(lstlst):
    lst = []
    for v in lstlst:
        lst.extend(v)
    return lst


class DynamicProgrammingScheduler:
    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sram_limit = sram_limit
        self.options = copy.copy(options)
        self.use_cascading = options.use_cascading

        self.use_ifm_ofm_overlap = options.use_ifm_ofm_overlap
        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
            # Force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM; this must come after the
            # option assignment above so that it is not overwritten.
            self.use_ifm_ofm_overlap = False

        self.verbose_schedule = options.verbose_schedule
        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
        self.mem_area = MemArea.Sram

        self.bandwidth_weights = arch.bandwidth_weights
        self.cycles_weight = arch.cycles_weight
        self.max_sram_used_weight = arch.max_sram_used_weight

        self.n_combinations_searched = 0

        self.feature_maps_not_in_fast_storage = (
            arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area
        )

        self.pareto_max_candidates = 16

        self.ifm_stream_npu_blocks = set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
        )

    num_pareto_metrics = 4
    view_values = ",".join(["d"] * num_pareto_metrics)
    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]

    def pareto_metric(self, candidate):
        strat, strat_set = candidate
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
        bws = strat.bws + strat_set.bws
        last_block_height = 0
        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
            last_block_height = strat.block_configs[-1][0]

        return (
            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
            strat_set.max_sram_used,
            strat.sram_used,
            last_block_height,
        )

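    # filter_pareto_frontier below first drops candidates above the SRAM limit, then keeps only candidates
    # that are not dominated on the components of pareto_metric(): the candidates are sorted
    # lexicographically, the best one is kept, anything that is worse-or-equal in every metric is discarded,
    # and the loop repeats on the survivors. If more than pareto_max_candidates remain, the list is
    # truncated using candidate_metric() as a tie-breaker.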
    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):

        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]

        if len(candidates) <= 1:
            return candidates
        assert remove_equally_good_candidates
        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
        ids = np.arange(len(candidates), dtype=np.int32)
        for idx, cand in enumerate(candidates):
            pareto_vals[idx] = self.pareto_metric(cand)

        sort_order = np.argsort(
            pareto_vals.view(DynamicProgrammingScheduler.view_values),
            order=DynamicProgrammingScheduler.order_values,
            axis=0,
            kind="stable",
        ).flatten()
        pareto_vals = pareto_vals[sort_order]
        ids = ids[sort_order]

        pareto_frontier = []
        while len(ids) > 0:
            pareto_frontier.append(candidates[ids[0]])
            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
            ids = ids[not_dominated_by_first]
            pareto_vals = pareto_vals[not_dominated_by_first]

        if len(pareto_frontier) > self.pareto_max_candidates:
            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]

        return pareto_frontier

    def candidate_metric(self, candidate):
        strat, strat_set = candidate
        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
        bws = strat.bws + strat_set.bws
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]

        return (
            max_sram_used * self.max_sram_used_weight
            + np.tensordot(bws, self.bandwidth_weights, axes=3)
            + total_cycles * self.cycles_weight
        )

    def sort_by_candidate_metric(self, candidate_list):
        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
        return sorted_list

    def best_candidate(self, candidate_list):
        if len(candidate_list) == 0:
            return ABORT_SEARCH
        if len(candidate_list) == 1:
            return candidate_list[0]
        sorted_list = self.sort_by_candidate_metric(candidate_list)
        return sorted_list[0]

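    # "Graduating" a strategy: an in-progress Strategy (strat == Unknown) is finalised as an IfmStream or
    # WeightStream cascade, its SRAM cost (optionally reduced by the allowed IFM/OFM overlap) is fixed, and
    # it is folded into the accumulated StrategySet. Candidates that exceed the SRAM limit are dropped.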
    def graduate_strat(self, strat_type, sram_used, old_strat_data):
        res = []
        for old_strat, old_strat_set in old_strat_data:
            if old_strat.sram_used + sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            if old_strat_set.max_sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            assert old_strat.strat == SchedulingStrategy.Unknown

            new_strat = old_strat.clone()
            new_strat.strat = strat_type
            new_strat.sram_used = old_strat.sram_used + sram_used

            if self.use_ifm_ofm_overlap:
                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
                    new_strat.strat, new_strat.passes, new_strat.block_configs
                )
                new_strat.sram_used -= overlap

            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
            res.append((empty_strategy, new_strat_set))
        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)

    def append_sram(self, sram_used, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            assert old_strat.sram_used == 0
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used

            res.append((new_strat, strat_set))
        return res

    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]

            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )

            res.append((new_strat, strat_set))
        return res

    def append_sram_pass_block_config_performance_metrics_rewrite_list(
        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
    ):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.passes = old_strat.passes + [new_pass]
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def pass_to_strat(self, strat_data):
        res = {}
        for strat in strat_data[1].strats.values():
            for ps in strat.passes:
                res[ps] = strat
        return res

    def compatible_strats(self, a, b):
        intersection = a.keys() & b.keys()
        for k in intersection:
            if a[k] != b[k]:
                return False
        return True

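    # Combine the per-predecessor search results: each combination of one candidate per predecessor is kept
    # only if the strategies agree on all passes they share (compatible_strats), and every surviving
    # combination becomes a fresh StrategySet with its statistics recomputed.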
    def collate_strats_for_passes(self, all_passes):
        if len(all_passes) == 0:
            return [(empty_strategy, StrategySet(dict()))]
        if len(all_passes) == 1:
            return all_passes[0]  # save some space in the common case
        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
        prev_combos = [dict()]
        for j, strand in enumerate(all_strands):
            new_combos = []
            for i, alt in enumerate(strand):
                for prev in prev_combos:
                    if self.compatible_strats(prev, alt):
                        cmb = dict(prev)
                        cmb.update(all_passes[j][i][1].strats)
                        new_combos.append(cmb)
            prev_combos = new_combos

        res = []
        for d in prev_combos:
            s = StrategySet(d)
            s.update_statistics()
            res.append((empty_strategy, s))
        return res

    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
        # get the rest of the predecessors
        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
        other_predecessor_data = self.search_pass_list(other_predecessors)

        # pred_pass_data holds an incomplete strategy that we need to keep building on, whereas the other
        # predecessors come back with completed strategies. Merge them, but keep the incomplete strategy too.

        res = []
        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
            all_strats = [
                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
                other_predecessor_data,  # this one is fine to use as-is
            ]
            collated_strat_data = self.collate_strats_for_passes(all_strats)
            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
            res.extend(strat_data)
        return res

    def calc_non_local_mem_usage(self):
        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
        range_set = live_range.extract_live_ranges_from_passes(
            self.sg,
            self.mem_area,
            mark_output_tensors_overlapping_with_input_tensors=True,
            ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
        )
        range_dict = range_set.ranges

        # Find which ranges overlap passes but aren't inputs/outputs of those passes.
        # These won't be counted by the dynamic programming search and must be counted in manually.
        end_pos = max(ps.time for ps in self.sg.passes) + 2
        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)

        for tens, rng in range_dict.items():
            storage_size = tens.storage_size()
            assert tens.mem_area == self.mem_area
            mem_usage[rng.start_time : rng.end_time] += storage_size

        for ps in self.sg.passes:
            local_mem_usage = 0
            for tens in ps.inputs + ps.outputs + ps.intermediates:
                if tens.mem_area != self.mem_area:
                    continue

                local_mem_usage += tens.storage_size()

            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage

        self.non_local_mem_usage = non_local_mem_usage

    def search(self):
        self.calc_non_local_mem_usage()
        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        strat_data = self.search_pass_list(starting_passes)

        _, best_set = self.best_candidate(strat_data)

        if self.verbose_pareto_frontier_schedules:
            print(
                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
                % (self.n_combinations_searched, len(strat_data))
            )
            for idx, (_, strat_set) in enumerate(strat_data):
                extra = ""
                if strat_set == best_set:
                    extra = "(Best candidate)"
                print("Candidate", idx, extra)
                memory_used = {MemArea.Sram: strat_set.max_sram_used}
                stats_writer.print_performance_metrics_for_strat(
                    self.arch,
                    "",
                    strat_set.cycles,
                    strat_set.macs,
                    strat_set.bws,
                    self.nng.batch_size,
                    memory_used,
                    len(self.sg.passes),
                    len(strat_set.strats),
                )

        return best_set

    def search_pass_list(self, pass_list):
        all_strats = []
        for ps in pass_list:
            strat = self.search_output(ps)
            all_strats.append(strat)
        strat_data = self.collate_strats_for_passes(all_strats)
        for strd in strat_data:
            for ps in pass_list:
                assert ps in strd[1].strats  # should have strategies for everything we asked to search
        return strat_data

    def search_predecessors(self, ps):

        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
        # we have strats for all passes

        pass_list = ps.dag_predecessors
        strat_data = self.search_pass_list(pass_list)

        return strat_data

    @lru_cache(maxsize=None)
    def search_output(self, ps):

        assert ps in self.sg.passes
        candidate_list = []

        candidate_list.extend(self.search_weight_streaming_output(ps))

        if self.options.use_ifm_streaming:
            candidate_list.extend(self.search_ifm_streaming_output(ps))

        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)

        if not best:
            print(
                "Warning: Dynamic programming search algorithm failed for pass %s, invoking fallback strategy"
                % (ps.name,)
            )
            return self.search_predecessors(ps)

        return best

    def search_ifm_streaming_output(self, ps):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        strat_data = self.search_ifm_streaming_body(ps, False)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        ifm_input_search_results = self.search_ifm_streaming_input(ps)
        res = []

        base_sram_used = 0
        for tens in ps.intermediates:
            if tens.mem_area == self.mem_area:
                base_sram_used += tens.storage_size()

        all_block_configs = self.get_block_configs(ps)
        for block_config in all_block_configs:
            all_strats = []

            if self.use_cascading:
                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))

            all_strats.extend(ifm_input_search_results)

            rewrite_list = []
            sram_used = base_sram_used

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch,
                ps,
                block_config,
                rewrite_list=rewrite_list,
                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

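    # Try to extend an IFM-streamed cascade one pass upwards: if this pass's IFM comes from a single NPU
    # predecessor that only feeds this pass, the IFM can be kept in SRAM as a rolling buffer
    # (TensorSubPurpose.RollingBufferY) whose height is derived from the two passes' block configs, instead
    # of storing the full feature map.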
    def search_ifm_streaming_partial(self, ps, block_config):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH

        if len(ps.inputs) < 1:
            return ABORT_SEARCH

        ifm_tensor = ps.ifm_tensor

        if ifm_tensor is None:
            return ABORT_SEARCH
        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
            return ABORT_SEARCH
        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
            return ABORT_SEARCH

        pred_pass_list = []
        for pred_candidate in ps.dag_predecessors:
            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                # we found a predecessor that produces this IFM tensor
                if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                    # and it only has one successor, namely us
                    if pred_candidate.placement == PassPlacement.Npu:
                        if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                            # and it is on the Npu and fusable - it's a candidate
                            pred_pass_list.append(pred_candidate)

        if not pred_pass_list:
            return ABORT_SEARCH

        all_candidates = []
        for pred_pass in pred_pass_list:
            # recurse into the next pass
            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage)

            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
            for strat_opt in strat_data:

                pred_pass_block_config = strat_opt[0].block_configs[-1]
                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
                )
                if rolling_buffer_dims is None:
                    continue  # this does not pack properly, skip it.

                sram_used = 0
                for tens in ps.inputs:
                    if tens != ifm_tensor:
                        if tens.mem_area == self.mem_area:
                            sram_used += tens.storage_size()

                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims

                rewrite_list = [
                    (
                        SchedulerRewrite.ChangeTensorSubPurpose,
                        ifm_tensor,
                        TensorSubPurpose.RollingBufferY,
                        rolling_buffer_y,
                        None,
                        ps,
                    )
                ]
                sram_used += ifm_tensor.storage_size_for_sub_purpose(
                    TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
                )

                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))

        self.n_combinations_searched += len(all_candidates)
        return all_candidates

    def get_block_configs(self, ps):
        if ps.placement != PassPlacement.Npu:
            return [(1, 1, 1, 1)]  # default

        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)

        # Take a limited number of the largest blocks
        if self.arch.block_config_limit > 0:
            # Sort by block area, followed by depth
            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
            bound = min(len(block_configs), self.arch.block_config_limit)
            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
            tmp = block_configs[:bound]
            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
            block_configs = tmp

        return block_configs

    def search_ifm_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def search_weight_streaming_output(self, ps):
        strat_data = self.search_weight_streaming_body(ps)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_weight_streaming_body(self, ps):

        strat_data = self.search_weight_streaming_input(ps)

        res = []

        all_block_configs = self.get_block_configs(ps)

        for block_config in all_block_configs:

            sram_used = 0
            rewrite_list = []

            for tens in ps.intermediates:
                if tens.mem_area == self.mem_area:
                    if tens.purpose == TensorPurpose.Weights:
                        sram_used += tens.storage_size_for_sub_purpose(TensorSubPurpose.DoubleBuffer, block_config[3])
                        rewrite_list.append(
                            (
                                SchedulerRewrite.ChangeTensorSubPurpose,
                                tens,
                                TensorSubPurpose.DoubleBuffer,
                                block_config[3],
                                None,
                                ps,
                            )
                        )
                    else:
                        sram_used += tens.storage_size()

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch, ps, block_config, rewrite_list=rewrite_list
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_weight_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def apply_result(self, strat_set, arch):
        pass_to_cascaded_pass = dict()
        for _, strat in strat_set.strats.items():
            # rewrite the tensors that need this first. e.g. make rolling buffers
            inputs = []
            intermediates = []
            outputs = []

            for ps in strat.passes:
                inputs += ps.inputs
                intermediates += ps.intermediates
                outputs += ps.outputs

            for tens in set(inputs) & set(outputs):
                # tensors that are in both sets are intermediates

                # find pass with input/output tensor, and check if they are both placed on NPU
                input_placement = None
                output_placement = None
                for ps in strat.passes:
                    if tens in ps.inputs:
                        input_placement = ps.placement
                    if tens in ps.outputs:
                        output_placement = ps.placement
                if input_placement == output_placement == PassPlacement.Npu:
                    tens.set_format(TensorFormat.NHCWB16, arch)

                intermediates.append(tens)
                inputs.remove(tens)
                outputs.remove(tens)

            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                    tens.mem_area = self.arch.fast_storage_mem_area
                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                else:
                    assert 0, "unknown rewrite_op " + str(rewrite_op)

            is_element_wise = True
            for ps in strat.passes:
                assert ps.placement == strat.passes[0].placement
                if not ps.is_element_wise:
                    is_element_wise = False
                    break

            cascaded_pass = CascadedPass(
                strat.passes[0].name,
                strat.strat,
                inputs,
                intermediates,
                outputs,
                strat.passes,
                strat.passes[0].placement,
                is_element_wise,
            )
            assert strat.sram_used >= 0
            cascaded_pass.sram_used = strat.sram_used

            for idx, ps in enumerate(strat.passes):
                assert ps not in pass_to_cascaded_pass
                pass_to_cascaded_pass[ps] = cascaded_pass
                ps.cascade = cascaded_pass
                ps.block_config = strat.block_configs[idx]

                if ps.placement == PassPlacement.Npu:
                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
                        self.arch, ps, ps.block_config
                    )
                    assert ps.shared_buffer is not None

                for op in ps.ops:
                    subgraph = op.attrs.get("subgraph")
                    if subgraph:
                        subgraph.base_sram_used = cascaded_pass.sram_used

        # all passes should have a cascaded pass now
        if len(pass_to_cascaded_pass) != len(self.sg.passes):
            print(
                "mismatch: we have %d passes, but only %d have cascaded passes associated"
                % (len(self.sg.passes), len(pass_to_cascaded_pass))
            )
            for ps in self.sg.passes:
                if ps not in pass_to_cascaded_pass:
                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))

            assert len(pass_to_cascaded_pass) == len(self.sg.passes)

        cascaded_passes = []
        if self.sg.placement == PassPlacement.Cpu:
            # Retain the pass order for CPU subgraph
            cascaded_passes = [ps.cascade for ps in self.sg.passes]
        else:
            # we have all the passes, but we need to put them in order and build predecessor/successor links.
            visit_pass_set = set()

            def visit_pass(ps):
                if ps in visit_pass_set:
                    return
                visit_pass_set.add(ps)

                cps = ps.cascade
                dont_traverse = set(cps.passes)

                for ps in cps.passes:
                    for pred in ps.predecessors:
                        if pred in dont_traverse:
                            continue
                        visit_pass(pred)

                cascaded_passes.append(cps)

            starting_passes = [ps for ps in self.sg.passes if not ps.successors]
            for ps in starting_passes:
                visit_pass(ps)

        # reorder so startup init cascaded passes come first
        def is_startup_cascaded_pass(cps):
            if not cps.passes:
                return False
            return cps.placement == PassPlacement.StartupInit

        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
        ]

        self.sg.cascaded_passes = cascaded_passes
        self.sg.build_cascaded_pass_links()

        if self.options.use_nhcwb16_between_cascaded_passes:
            # Check if NHCWB16 can be used in between cascaded passes
            # (NHCWB16 within cascaded passes has been handled earlier in this function)
            if self.sg.placement == PassPlacement.Npu:
                for ps in self.sg.cascaded_passes:
                    if ps.placement != PassPlacement.Npu:
                        continue
                    for output in ps.outputs:
                        if output.purpose != TensorPurpose.FeatureMap:
                            continue

                        use_NHCWB16 = True
                        for op in output.consumer_list:
                            if op is None or op.type == "Reshape":
                                use_NHCWB16 = False
                            else:
                                use_NHCWB16 &= op.run_on_npu

                        if use_NHCWB16:
                            output.set_format(TensorFormat.NHCWB16, arch)


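# Module entry point: builds DAG predecessors for each subgraph, runs the dynamic programming search and
# applies the best strategy set. Illustrative call: schedule_passes(nng, arch, SchedulerOptions()).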
def schedule_passes(nng, arch, options: SchedulerOptions):

    for sg in nng.subgraphs:
        sg.base_sram_used = 0

    for sg in nng.subgraphs:
        # Re-entering the same nodes from different contexts requires us to build a simplified directed
        # acyclic graph (DAG) version of the graph to use for traversal, rather than using a visit
        # dictionary. This avoids recursing infinitely due to loops.
        sg.build_pass_dag_predecessors()

        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)

        strat_set = dps.search()

        dps.apply_result(strat_set, arch)

        if options.verbose_schedule:
            sg.print_cascaded_passes()