# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The scheduler costs various strategies for scheduling the network in order to select the scheduling
# strategy and block configuration for each pass.
import copy
import enum
from functools import lru_cache

import numpy as np

from . import live_range
from . import npu_performance
from . import stats_writer
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
from .nn_graph import CascadedPass
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .nn_graph import SchedulingStrategy
from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
from .operation import NpuBlockType
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
from .tensor import MemArea
from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


class ParetoMetric(enum.Enum):
    BwCycMem = 1
    BwCycMemBlkH = 2

    def __str__(self):
        return self.name


class SchedulerOptions:
    def __init__(
        self,
        use_cascading=True,
        use_ifm_ofm_overlap=True,
        verbose_schedule=False,
        verbose_pareto_frontier_schedules=False,
        use_ifm_streaming=True,
        pareto_metric=ParetoMetric.BwCycMem,
        use_nhcwb16_between_cascaded_passes=True,
    ):
        self.use_cascading = use_cascading
        self.use_ifm_ofm_overlap = use_ifm_ofm_overlap
        self.verbose_schedule = verbose_schedule
        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
        self.use_ifm_streaming = use_ifm_streaming
        self.pareto_metric = pareto_metric
        self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes

    def __str__(self):
        return type(self).__name__ + ": " + str(self.__dict__)

    __repr__ = __str__


class Strategy:
    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"

    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
        self.strat = strat
        self.param = param
        self.passes = passes
        self.block_configs = block_configs
        self.rewrite_list = (
            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
        )
        self.bws = bws
        self.macs = macs
        self.cycles = cycles
        self.sram_used = sram_used

    def __eq__(self, other):
        if self.strat != other.strat:
            return False
        if self.param != other.param:
            return False
        if self.block_configs != other.block_configs:
            return False
        if self.passes != other.passes:
            return False
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.sram_used != other.sram_used:
            return False
        return True

    def empty(self):
        return not self.passes

    def key(self):
        return self.passes[-1]

    def clone(self):
        return Strategy(
            self.strat,
            self.param,
            self.passes,
            self.block_configs,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    def __str__(self):
        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
            self.strat,
            self.passes,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    __repr__ = __str__


class StrategySet:
    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"

    def __init__(self, strats=None):
        if strats is None:
            strats = dict()
        self.strats = strats  # final pass in packed pass -> Strategy
        self.bws, self.macs, self.cycles = make_metrics_arrays()
        self.max_sram_used = 0
        self.total_sram_used = 0

    def update_statistics(self):
        self.bws = make_bandwidth_array()
        self.max_sram_used = 0
        for ps, strat in self.strats.items():
            self.bws += strat.bws
            self.macs += strat.macs
            self.cycles += strat.cycles
            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
            self.total_sram_used += strat.sram_used

    def clone_add_strategy(self, new_strat):
        key = new_strat.key()
        if key in self.strats:
            assert new_strat == self.strats[key]
            return self
        else:
            new_strats = dict(self.strats)
            new_strats[key] = new_strat
            new_set = StrategySet(new_strats)
            new_set.bws = self.bws + new_strat.bws
            new_set.macs = self.macs + new_strat.macs
            new_set.cycles = self.cycles + new_strat.cycles
            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
            return new_set

    def __eq__(self, other):
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.max_sram_used != other.max_sram_used:
            return False
        if self.total_sram_used != other.total_sram_used:
            return False
        if self.strats != other.strats:
            return False
        return True

    def __str__(self):
        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
            self.max_sram_used,
            list(ps.name for ps in self.strats),
        )

    __repr__ = __str__


empty_strategy = Strategy(
    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
)
INFINITY = 1e30

ABORT_SEARCH = []


def flatten_list_of_lists(lstlst):
    lst = []
    for v in lstlst:
        lst.extend(v)
    return lst


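# The dynamic programming search below works backwards from a subgraph's output passes
# (DynamicProgrammingScheduler.search). For each pass it builds candidate
# (Strategy, StrategySet) pairs - an in-progress Strategy plus a StrategySet of completed
# strategies for the passes already covered - and repeatedly prunes the candidates with
# filter_pareto_frontier so that only Pareto-optimal combinations are carried forward.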
class DynamicProgrammingScheduler:
    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sram_limit = sram_limit
        self.options = copy.copy(options)
        self.use_cascading = options.use_cascading

        self.use_ifm_ofm_overlap = options.use_ifm_ofm_overlap
        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM

        self.verbose_schedule = options.verbose_schedule
        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
        self.mem_area = MemArea.Sram

        self.bandwidth_weights = arch.bandwidth_weights
        self.cycles_weight = arch.cycles_weight
        self.max_sram_used_weight = arch.max_sram_used_weight

        self.n_combinations_searched = 0

        self.feature_maps_not_in_fast_storage = (
            arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area
        )

        self.pareto_max_candidates = 16

        self.ifm_stream_npu_blocks = set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
        )

    num_pareto_metrics = 4
    view_values = ",".join(["d"] * num_pareto_metrics)
    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]

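    # pareto_metric returns the tuple that filter_pareto_frontier sorts and compares on:
    # (weighted bandwidth + weighted cycles, max SRAM of the completed StrategySet, SRAM of
    # the open Strategy, and - only for ParetoMetric.BwCycMemBlkH - the height of the last
    # block config). The filtering treats lower as better for every component.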
    def pareto_metric(self, candidate):
        strat, strat_set = candidate
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
        bws = strat.bws + strat_set.bws
        last_block_height = 0
        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
            last_block_height = strat.block_configs[-1][0]

        return (
            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
            strat_set.max_sram_used,
            strat.sram_used,
            last_block_height,
        )

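    # filter_pareto_frontier keeps only Pareto-optimal candidates: a candidate is dropped if
    # another candidate is at least as good on every component of pareto_metric. For example,
    # with metric tuples (cost, set_sram, strat_sram, blk_h), (10, 5, 3, 0) eliminates
    # (12, 6, 3, 0), while (10, 5, 3, 0) and (8, 7, 3, 0) are both kept. Candidates over the
    # SRAM limit are removed first, and the frontier is capped at pareto_max_candidates using
    # the scalar candidate_metric.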
    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):

        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]

        if len(candidates) <= 1:
            return candidates
        assert remove_equally_good_candidates
        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
        ids = np.arange(len(candidates), dtype=np.int32)
        for idx, cand in enumerate(candidates):
            pareto_vals[idx] = self.pareto_metric(cand)

        sort_order = np.argsort(
            pareto_vals.view(DynamicProgrammingScheduler.view_values),
            order=DynamicProgrammingScheduler.order_values,
            axis=0,
            kind="stable",
        ).flatten()
        pareto_vals = pareto_vals[sort_order]
        ids = ids[sort_order]

        pareto_frontier = []
        while len(ids) > 0:
            pareto_frontier.append(candidates[ids[0]])
            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
            ids = ids[not_dominated_by_first]
            pareto_vals = pareto_vals[not_dominated_by_first]

        if len(pareto_frontier) > self.pareto_max_candidates:
            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]

        return pareto_frontier

    def candidate_metric(self, candidate):
        strat, strat_set = candidate
        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
        bws = strat.bws + strat_set.bws
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]

        return (
            max_sram_used * self.max_sram_used_weight
            + np.tensordot(bws, self.bandwidth_weights, axes=3)
            + total_cycles * self.cycles_weight
        )

    def sort_by_candidate_metric(self, candidate_list):
        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
        return sorted_list

    def best_candidate(self, candidate_list):
        if len(candidate_list) == 0:
            return ABORT_SEARCH
        if len(candidate_list) == 1:
            return candidate_list[0]
        sorted_list = self.sort_by_candidate_metric(candidate_list)
        return sorted_list[0]

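    # graduate_strat closes an in-progress strategy: it stamps the final scheduling strategy
    # type, adds the output SRAM, optionally subtracts the allowed IFM/OFM overlap, and moves
    # the result into the StrategySet of completed strategies before Pareto filtering.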
    def graduate_strat(self, strat_type, sram_used, old_strat_data):
        res = []
        for old_strat, old_strat_set in old_strat_data:
            if old_strat.sram_used + sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            if old_strat_set.max_sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            assert old_strat.strat == SchedulingStrategy.Unknown

            new_strat = old_strat.clone()
            new_strat.strat = strat_type
            new_strat.sram_used = old_strat.sram_used + sram_used

            if self.use_ifm_ofm_overlap:
                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
                    new_strat.strat, new_strat.passes, new_strat.block_configs
                )
                new_strat.sram_used -= overlap

            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
            res.append((empty_strategy, new_strat_set))
        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)

    def append_sram(self, sram_used, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            assert old_strat.sram_used == 0
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used

            res.append((new_strat, strat_set))
        return res

    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]

            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )

            res.append((new_strat, strat_set))
        return res

    def append_sram_pass_block_config_performance_metrics_rewrite_list(
        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
    ):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.passes = old_strat.passes + [new_pass]
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def pass_to_strat(self, strat_data):
        res = {}
        for strat in strat_data[1].strats.values():
            for ps in strat.passes:
                res[ps] = strat
        return res

    def compatible_strats(self, a, b):
        intersection = a.keys() & b.keys()
        for k in intersection:
            if a[k] != b[k]:
                return False
        return True

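    # collate_strats_for_passes merges the candidate lists produced for several independent
    # pass lists ("strands") into joint StrategySets. Two partial solutions are only combined
    # when compatible_strats confirms they chose the same Strategy for every pass they share.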
    def collate_strats_for_passes(self, all_passes):
        if len(all_passes) == 0:
            return [(empty_strategy, StrategySet(dict()))]
        if len(all_passes) == 1:
            return all_passes[0]  # save some space in the common case
        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
        prev_combos = [dict()]
        for j, strand in enumerate(all_strands):
            new_combos = []
            for i, alt in enumerate(strand):
                for prev in prev_combos:
                    if self.compatible_strats(prev, alt):
                        cmb = dict(prev)
                        cmb.update(all_passes[j][i][1].strats)
                        new_combos.append(cmb)
            prev_combos = new_combos

        res = []
        for d in prev_combos:
            s = StrategySet(d)
            s.update_statistics()
            res.append((empty_strategy, s))
        return res

    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
        # get the rest of the predecessors
        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
        other_predecessor_data = self.search_pass_list(other_predecessors)

        # pred strat data has an incomplete strategy, which we need
        # to continue on, whereas the other ones have completed strategies.
        # we need to merge these, but keep the incomplete strategy too.

        res = []
        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
            all_strats = [
                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
                other_predecessor_data,  # this one is fine to use as-is
            ]
            collated_strat_data = self.collate_strats_for_passes(all_strats)
            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
            res.extend(strat_data)
        return res

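    # calc_non_local_mem_usage computes, per pass, how much SRAM is taken by tensors that are
    # live across the pass without being one of its inputs/outputs/intermediates. The search
    # adds this on top of each pass's local usage, since the recursion only accounts for
    # tensors local to the passes it visits.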
    def calc_non_local_mem_usage(self):
        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
        range_set = live_range.extract_live_ranges_from_passes(
            self.sg,
            self.mem_area,
            mark_output_tensors_overlapping_with_input_tensors=True,
            ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
        )
        range_dict = range_set.ranges

        # find which ranges overlap passes but aren't input/outputs of the passes.
        # these won't be counted by the dynamic programming search and must be counted in manually.
        end_pos = max(ps.time for ps in self.sg.passes) + 2
        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)

        for tens, rng in range_dict.items():
            storage_size = tens.storage_size()
            assert tens.mem_area == self.mem_area
            mem_usage[rng.start_time : rng.end_time] += storage_size

        for ps in self.sg.passes:
            local_mem_usage = 0
            for tens in ps.inputs + ps.outputs + ps.intermediates:
                if tens.mem_area != self.mem_area:
                    continue

                local_mem_usage += tens.storage_size()

            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage

        self.non_local_mem_usage = non_local_mem_usage

    def search(self):
        self.calc_non_local_mem_usage()
        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        strat_data = self.search_pass_list(starting_passes)

        _, best_set = self.best_candidate(strat_data)

        if self.verbose_pareto_frontier_schedules:
            print(
                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
                % (self.n_combinations_searched, len(strat_data,))
            )
            for idx, (_, strat_set) in enumerate(strat_data):
                extra = ""
                if strat_set == best_set:
                    extra = "(Best candidate)"
                print("Candidate", idx, extra)
                memory_used = {MemArea.Sram: strat_set.max_sram_used}
                stats_writer.print_performance_metrics_for_strat(
                    self.arch,
                    "",
                    strat_set.cycles,
                    strat_set.macs,
                    strat_set.bws,
                    self.nng.batch_size,
                    memory_used,
                    len(self.sg.passes),
                    len(strat_set.strats),
                )

        return best_set

    def search_pass_list(self, pass_list):
        all_strats = []
        for ps in pass_list:
            strat = self.search_output(ps)
            all_strats.append(strat)
        strat_data = self.collate_strats_for_passes(all_strats)
        for strd in strat_data:
            for ps in pass_list:
                assert ps in strd[1].strats  # should have strategies for everything we asked to search
        return strat_data

    def search_predecessors(self, ps):

        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
        # we have strats for all passes

        pass_list = ps.dag_predecessors
        strat_data = self.search_pass_list(pass_list)

        return strat_data

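    # search_output is the memoized entry point of the per-pass search: it gathers weight
    # streaming and (optionally) IFM streaming candidates, keeps their Pareto frontier, and
    # falls back to searching only the predecessors when nothing fits in the SRAM limit.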
    @lru_cache(maxsize=None)
    def search_output(self, ps):

        assert ps in self.sg.passes
        candidate_list = []

        candidate_list.extend(self.search_weight_streaming_output(ps))

        if self.options.use_ifm_streaming:
            candidate_list.extend(self.search_ifm_streaming_output(ps))

        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)

        if not best:
            print(
                "Warning: Dynamic programming search algorithm failed for pass %s, invoking fallback strategy"
                % (ps.name,)
            )
            return self.search_predecessors(ps)

        return best

    def search_ifm_streaming_output(self, ps):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        strat_data = self.search_ifm_streaming_body(ps, False)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        ifm_input_search_results = self.search_ifm_streaming_input(ps)
        res = []

        base_sram_used = 0
        for tens in ps.intermediates:
            if tens.mem_area == self.mem_area:
                base_sram_used += tens.storage_size()

        all_block_configs = self.get_block_configs(ps)
        for block_config in all_block_configs:
            all_strats = []

            if self.use_cascading:
                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))

            all_strats.extend(ifm_input_search_results)

            rewrite_list = []
            sram_used = base_sram_used

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch,
                ps,
                block_config,
                rewrite_list=rewrite_list,
                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

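    # search_ifm_streaming_partial attempts to cascade this pass with a single NPU predecessor
    # that produces its IFM: instead of holding the whole IFM in SRAM, the IFM tensor is
    # rewritten to TensorSubPurpose.RollingBufferY, sized from the two passes' block configs,
    # so producer and consumer can run in an interleaved, striped fashion.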
    def search_ifm_streaming_partial(self, ps, block_config):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH

        if len(ps.inputs) < 1:
            return ABORT_SEARCH

        ifm_tensor = ps.ifm_tensor

        if ifm_tensor is None:
            return ABORT_SEARCH
        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
            return ABORT_SEARCH
        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
            return ABORT_SEARCH

        pred_pass_list = []
        for pred_candidate in ps.dag_predecessors:
            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                # we found a predecessor that produces this IFM tensor
                if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                    # and it only has one successor, namely us
                    if pred_candidate.placement == PassPlacement.Npu:
                        if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                            # and it is on the Npu and fusable - it's a candidate
                            pred_pass_list.append(pred_candidate)

        if not pred_pass_list:
            return ABORT_SEARCH

        all_candidates = []
        for pred_pass in pred_pass_list:
            # recurse into the next pass
            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage)

            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
            for strat_opt in strat_data:

                pred_pass_block_config = strat_opt[0].block_configs[-1]
                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
                )
                if rolling_buffer_dims is None:
                    continue  # this does not pack properly, skip it.

                sram_used = 0
                for tens in ps.inputs:
                    if tens != ifm_tensor:
                        if tens.mem_area == self.mem_area:
                            sram_used += tens.storage_size()

                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims

                rewrite_list = [
                    (
                        SchedulerRewrite.ChangeTensorSubPurpose,
                        ifm_tensor,
                        TensorSubPurpose.RollingBufferY,
                        rolling_buffer_y,
                        None,
                        ps,
                    )
                ]
                sram_used += ifm_tensor.storage_size_for_sub_purpose(
                    TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
                )

                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))

        self.n_combinations_searched += len(all_candidates)
        return all_candidates

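    # get_block_configs returns every block configuration the shared buffer allows for this
    # pass; when arch.block_config_limit is set, the list is trimmed to a sample taken from
    # both ends of the size-sorted list to bound the search space.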
    def get_block_configs(self, ps):
        if ps.placement != PassPlacement.Npu:
            return [(1, 1, 1, 1)]  # default

        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)

        # Take a limited number of the largest blocks
        if self.arch.block_config_limit > 0:
            # Sort by block area, followed by depth
            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
            bound = min(len(block_configs), self.arch.block_config_limit)
            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
            tmp = block_configs[:bound]
            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
            block_configs = tmp

        return block_configs

    def search_ifm_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def search_weight_streaming_output(self, ps):
        strat_data = self.search_weight_streaming_body(ps)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)

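    # search_weight_streaming_body costs this pass as a weight-streaming pass: for each block
    # config, weight tensors held in SRAM are rewritten to TensorSubPurpose.DoubleBuffer sized
    # by the block depth (block_config[3]), so the next set of weights can be fetched while
    # the current one is in use; other intermediates are counted at full size.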
    @lru_cache(maxsize=None)
    def search_weight_streaming_body(self, ps):

        strat_data = self.search_weight_streaming_input(ps)

        res = []

        all_block_configs = self.get_block_configs(ps)

        for block_config in all_block_configs:

            sram_used = 0
            rewrite_list = []

            for tens in ps.intermediates:
                if tens.mem_area == self.mem_area:
                    if tens.purpose == TensorPurpose.Weights:
                        sram_used += tens.storage_size_for_sub_purpose(TensorSubPurpose.DoubleBuffer, block_config[3])
                        rewrite_list.append(
                            (
                                SchedulerRewrite.ChangeTensorSubPurpose,
                                tens,
                                TensorSubPurpose.DoubleBuffer,
                                block_config[3],
                                None,
                                ps,
                            )
                        )
                    else:
                        sram_used += tens.storage_size()

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch, ps, block_config, rewrite_list=rewrite_list
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_weight_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

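    # apply_result commits the chosen StrategySet to the subgraph: it applies the recorded
    # tensor rewrites (rolling buffers, double buffers), switches intermediate feature maps
    # between fused NPU passes to NHCWB16, wraps each Strategy's passes in a CascadedPass,
    # and assigns per-pass block configs and shared buffer allocations.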
    def apply_result(self, strat_set, arch):
        pass_to_cascaded_pass = dict()
        for _, strat in strat_set.strats.items():
            # rewrite the tensors that need this first. e.g. make rolling buffers
            inputs = []
            intermediates = []
            outputs = []

            for ps in strat.passes:
                inputs += ps.inputs
                intermediates += ps.intermediates
                outputs += ps.outputs

            for tens in set(inputs) & set(outputs):
                # tensors that are in both sets are intermediates

                # find pass with input/output tensor, and check if they are both placed on NPU
                input_placement = None
                output_placement = None
                for ps in strat.passes:
                    if tens in ps.inputs:
                        input_placement = ps.placement
                    if tens in ps.outputs:
                        output_placement = ps.placement
                if input_placement == output_placement == PassPlacement.Npu:
                    tens.set_format(TensorFormat.NHCWB16, arch)

                intermediates.append(tens)
                inputs.remove(tens)
                outputs.remove(tens)

            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                    tens.mem_area = self.arch.fast_storage_mem_area
                    tens.mem_type = MemType.Scratch_fast
                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                else:
                    assert 0, "unknown rewrite_op " + str(rewrite_op)

            is_element_wise = True
            for ps in strat.passes:
                assert ps.placement == strat.passes[0].placement
                if not ps.is_element_wise:
                    is_element_wise = False
                    break

            cascaded_pass = CascadedPass(
                strat.passes[0].name,
                strat.strat,
                inputs,
                intermediates,
                outputs,
                strat.passes,
                strat.passes[0].placement,
                is_element_wise,
            )
            assert strat.sram_used >= 0
            cascaded_pass.sram_used = strat.sram_used

            for idx, ps in enumerate(strat.passes):
                assert ps not in pass_to_cascaded_pass
                pass_to_cascaded_pass[ps] = cascaded_pass
                ps.cascade = cascaded_pass
                ps.block_config = strat.block_configs[idx]

                if ps.placement == PassPlacement.Npu:
                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
                        self.arch, ps, ps.block_config
                    )
                    assert ps.shared_buffer is not None

                for op in ps.ops:
                    subgraph = op.attrs.get("subgraph")
                    if subgraph:
                        subgraph.base_sram_used = cascaded_pass.sram_used

        # all passes should have a cascaded pass now
        if len(pass_to_cascaded_pass) != len(self.sg.passes):
            print(
                "mismatch: we have %d passes, but only %d have cascaded passes associated"
                % (len(self.sg.passes), len(pass_to_cascaded_pass))
            )
            for ps in self.sg.passes:
                if ps not in pass_to_cascaded_pass:
                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))

            assert len(pass_to_cascaded_pass) == len(self.sg.passes)

        cascaded_passes = []
        if self.sg.placement == PassPlacement.Cpu:
            # Retain the pass order for CPU subgraph
            cascaded_passes = [ps.cascade for ps in self.sg.passes]
        else:
            # we have all the passes, but we need to put them in order and build predecessor/successor links.
            visit_pass_set = set()

            def visit_pass(ps):
                if ps in visit_pass_set:
                    return
                visit_pass_set.add(ps)

                cps = ps.cascade
                dont_traverse = set(cps.passes)

                for ps in cps.passes:
                    for pred in ps.predecessors:
                        if pred in dont_traverse:
                            continue
                        visit_pass(pred)

                cascaded_passes.append(cps)

            starting_passes = [ps for ps in self.sg.passes if not ps.successors]
            for ps in starting_passes:
                visit_pass(ps)

        # reorder so startup init cascaded passes come first
        def is_startup_cascaded_pass(cps):
            if not cps.passes:
                return False
            return cps.placement == PassPlacement.StartupInit

        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
        ]

        self.sg.cascaded_passes = cascaded_passes
        self.sg.build_cascaded_pass_links()

        if self.options.use_nhcwb16_between_cascaded_passes:
            # Check if NHCWB16 can be used in between cascaded passes
            # (NHCWB16 within cascaded passes has been handled earlier in this function)
            if self.sg.placement == PassPlacement.Npu:
                for ps in self.sg.cascaded_passes:
                    if ps.placement != PassPlacement.Npu:
                        continue
                    for output in ps.outputs:
                        if output.purpose != TensorPurpose.FeatureMap:
                            continue

                        use_NHCWB16 = True
                        for op in output.consumer_list:
                            if op is None or op.type == "Reshape":
                                use_NHCWB16 = False
                            else:
                                use_NHCWB16 &= op.run_on_npu

                        if use_NHCWB16:
                            output.set_format(TensorFormat.NHCWB16, arch)


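# Illustrative use of the entry point below (a sketch only - the real call site and the exact
# option values live elsewhere in the compiler driver):
#
#   options = SchedulerOptions(use_cascading=True, pareto_metric=ParetoMetric.BwCycMem)
#   schedule_passes(nng, arch, options)
#
# Afterwards every subgraph in nng has sg.cascaded_passes populated and linked.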
def schedule_passes(nng, arch, options: SchedulerOptions):

    for sg in nng.subgraphs:
        sg.base_sram_used = 0

    for sg in nng.subgraphs:
        # re-entering the same nodes from different contexts requires us to
        # build a simplified directed acyclic graph (DAG) version of the graph to
        # use for traversal, rather than using a visit dictionary. this avoids
        # recursing infinitely due to loops.
        sg.build_pass_dag_predecessors()

        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)

        strat_set = dps.search()

        dps.apply_result(strat_set, arch)

        if options.verbose_schedule:
            sg.print_cascaded_passes()