# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
import copy
import enum
from functools import lru_cache

import numpy as np

from . import live_range
from . import npu_performance
from . import stats_writer
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
from .nn_graph import CascadedPass
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .nn_graph import SchedulingStrategy
from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
from .operation import NpuBlockType
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
from .tensor import MemArea
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


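# Pareto metric variants used when pruning scheduling candidates (see pareto_metric() below):
# BwCycMem     - weighted bandwidth + cycles, and SRAM usage
# BwCycMemBlkH - as above, plus the height of the last block config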
class ParetoMetric(enum.Enum):
    BwCycMem = 1
    BwCycMemBlkH = 2

    def __str__(self):
        return self.name


class SchedulerOptions:
    def __init__(
        self,
        use_cascading=True,
        use_ifm_ofm_overlap=True,
        verbose_schedule=False,
        verbose_pareto_frontier_schedules=False,
        use_ifm_streaming=True,
        pareto_metric=ParetoMetric.BwCycMem,
    ):
        self.use_cascading = use_cascading
        self.use_ifm_ofm_overlap = use_ifm_ofm_overlap
        self.verbose_schedule = verbose_schedule
        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
        self.use_ifm_streaming = use_ifm_streaming
        self.pareto_metric = pareto_metric

    def __str__(self):
        return type(self).__name__ + ": " + str(self.__dict__)

    __repr__ = __str__


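# A Strategy records one in-progress or completed scheduling choice for a chain of passes:
# the scheduling strategy type, the passes it covers, their block configs, any tensor
# rewrites to apply, and the accumulated bandwidth/MAC/cycle/SRAM cost estimates.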
class Strategy:
    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"

    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
        self.strat = strat
        self.param = param
        self.passes = passes
        self.block_configs = block_configs
        self.rewrite_list = (
            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
        )
        self.bws = bws
        self.macs = macs
        self.cycles = cycles
        self.sram_used = sram_used

    def __eq__(self, other):
        if self.strat != other.strat:
            return False
        if self.param != other.param:
            return False
        if self.block_configs != other.block_configs:
            return False
        if self.passes != other.passes:
            return False
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.sram_used != other.sram_used:
            return False
        return True

    def empty(self):
        return not self.passes

    def key(self):
        return self.passes[-1]

    def clone(self):
        return Strategy(
            self.strat,
            self.param,
            self.passes,
            self.block_configs,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    def __str__(self):
        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
            self.strat,
            self.passes,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    __repr__ = __str__


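# A StrategySet maps the final pass of each packed (cascaded) pass to the Strategy chosen for
# it, and keeps running totals of the combined bandwidth/MAC/cycle counts and SRAM usage.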
class StrategySet:
    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"

    def __init__(self, strats=None):
        if strats is None:
            strats = dict()
        self.strats = strats  # final pass in packed pass -> Strategy
        self.bws, self.macs, self.cycles = make_metrics_arrays()
        self.max_sram_used = 0
        self.total_sram_used = 0

    def update_statistics(self):
        self.bws = make_bandwidth_array()
        self.max_sram_used = 0
        for ps, strat in self.strats.items():
            self.bws += strat.bws
            self.macs += strat.macs
            self.cycles += strat.cycles
            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
            self.total_sram_used += strat.sram_used

    def clone_add_strategy(self, new_strat):
        key = new_strat.key()
        if key in self.strats:
            assert new_strat == self.strats[key]
            return self
        else:
            new_strats = dict(self.strats)
            new_strats[key] = new_strat
            new_set = StrategySet(new_strats)
            new_set.bws = self.bws + new_strat.bws
            new_set.macs = self.macs + new_strat.macs
            new_set.cycles = self.cycles + new_strat.cycles
            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
            return new_set

    def __eq__(self, other):
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.max_sram_used != other.max_sram_used:
            return False
        if self.total_sram_used != other.total_sram_used:
            return False
        if self.strats != other.strats:
            return False
        return True

    def __str__(self):
        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
            self.max_sram_used,
            list(ps.name for ps in self.strats),
        )

    __repr__ = __str__


empty_strategy = Strategy(
    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
)
INFINITY = 1e30

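# An empty candidate list doubles as a sentinel that a search branch should be abandoned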
ABORT_SEARCH = []


def flatten_list_of_lists(lstlst):
    lst = []
    for v in lstlst:
        lst.extend(v)
    return lst


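# Dynamic-programming style scheduler: starting from the subgraph outputs it recursively
# searches per-pass scheduling alternatives (weight streaming and, optionally, IFM streaming
# with cascading), memoising intermediate results with lru_cache and pruning candidates to a
# Pareto frontier of bandwidth/cycles/SRAM so the number of combinations stays manageable.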
class DynamicProgrammingScheduler:
    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sram_limit = sram_limit
        self.options = copy.copy(options)
        self.use_cascading = options.use_cascading

        self.use_ifm_ofm_overlap = options.use_ifm_ofm_overlap
        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM

        self.verbose_schedule = options.verbose_schedule
        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
        self.mem_area = MemArea.Sram

        self.bandwidth_weights = arch.bandwidth_weights
        self.cycles_weight = arch.cycles_weight
        self.max_sram_used_weight = arch.max_sram_used_weight

        self.n_combinations_searched = 0

        self.feature_maps_not_in_fast_storage = (
            arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area
        )

        self.pareto_max_candidates = 16

        self.ifm_stream_npu_blocks = set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
        )

    num_pareto_metrics = 4
    view_values = ",".join(["d"] * num_pareto_metrics)
    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]

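    # Metric tuple used for Pareto sorting and dominance tests; each element is treated as
    # lower-is-better: (weighted bandwidth + cycles, SRAM high-water mark of the strategy set,
    # SRAM of the in-progress strategy, and optionally the height of the last block config).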
    def pareto_metric(self, candidate):
        strat, strat_set = candidate
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
        bws = strat.bws + strat_set.bws
        last_block_height = 0
        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
            last_block_height = strat.block_configs[-1][0]

        return (
            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
            strat_set.max_sram_used,
            strat.sram_used,
            last_block_height,
        )

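    # Prune a candidate list down to its Pareto frontier: drop anything over the SRAM limit,
    # then keep only candidates not dominated by a better one; if the frontier is still larger
    # than pareto_max_candidates, keep the best ones by the scalar candidate metric.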
    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):

        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]

        if len(candidates) <= 1:
            return candidates
        assert remove_equally_good_candidates
        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
        ids = np.arange(len(candidates), dtype=np.int32)
        for idx, cand in enumerate(candidates):
            pareto_vals[idx] = self.pareto_metric(cand)

        sort_order = np.argsort(
            pareto_vals.view(DynamicProgrammingScheduler.view_values),
            order=DynamicProgrammingScheduler.order_values,
            axis=0,
            kind="stable",
        ).flatten()
        pareto_vals = pareto_vals[sort_order]
        ids = ids[sort_order]

        pareto_frontier = []
        while len(ids) > 0:
            pareto_frontier.append(candidates[ids[0]])
            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
            ids = ids[not_dominated_by_first]
            pareto_vals = pareto_vals[not_dominated_by_first]

        if len(pareto_frontier) > self.pareto_max_candidates:
            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]

        return pareto_frontier

    def candidate_metric(self, candidate):
        strat, strat_set = candidate
        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
        bws = strat.bws + strat_set.bws
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]

        return (
            max_sram_used * self.max_sram_used_weight
            + np.tensordot(bws, self.bandwidth_weights, axes=3)
            + total_cycles * self.cycles_weight
        )

    def sort_by_candidate_metric(self, candidate_list):
        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
        return sorted_list

    def best_candidate(self, candidate_list):
        if len(candidate_list) == 0:
            return ABORT_SEARCH
        if len(candidate_list) == 1:
            return candidate_list[0]
        sorted_list = self.sort_by_candidate_metric(candidate_list)
        return sorted_list[0]

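    # Promote an in-progress strategy to a completed one of the given type: add the output
    # SRAM cost, optionally subtract the allowed IFM/OFM overlap, fold it into the strategy
    # set, and Pareto-prune the results.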
    def graduate_strat(self, strat_type, sram_used, old_strat_data):
        res = []
        for old_strat, old_strat_set in old_strat_data:
            if old_strat.sram_used + sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            if old_strat_set.max_sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            assert old_strat.strat == SchedulingStrategy.Unknown

            new_strat = old_strat.clone()
            new_strat.strat = strat_type
            new_strat.sram_used = old_strat.sram_used + sram_used

            if self.use_ifm_ofm_overlap:
                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
                    new_strat.strat, new_strat.passes, new_strat.block_configs
                )
                new_strat.sram_used -= overlap

            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
            res.append((empty_strategy, new_strat_set))
        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)

    def append_sram(self, sram_used, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            assert old_strat.sram_used == 0
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used

            res.append((new_strat, strat_set))
        return res

    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]

            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )

            res.append((new_strat, strat_set))
        return res

    def append_sram_pass_block_config_performance_metrics_rewrite_list(
        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
    ):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.passes = old_strat.passes + [new_pass]
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def pass_to_strat(self, strat_data):
        res = {}
        for strat in strat_data[1].strats.values():
            for ps in strat.passes:
                res[ps] = strat
        return res

    def compatible_strats(self, a, b):
        intersection = a.keys() & b.keys()
        for k in intersection:
            if a[k] != b[k]:
                return False
        return True

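    # Combine the strategy sets found for several predecessor passes into joint strategy sets,
    # keeping only combinations where passes shared between strands were scheduled the same way.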
    def collate_strats_for_passes(self, all_passes):
        if len(all_passes) == 0:
            return [(empty_strategy, StrategySet(dict()))]
        if len(all_passes) == 1:
            return all_passes[0]  # save some space in the common case
        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
        prev_combos = [dict()]
        for j, strand in enumerate(all_strands):
            new_combos = []
            for i, alt in enumerate(strand):
                for prev in prev_combos:
                    if self.compatible_strats(prev, alt):
                        cmb = dict(prev)
                        cmb.update(all_passes[j][i][1].strats)
                        new_combos.append(cmb)
            prev_combos = new_combos

        res = []
        for d in prev_combos:
            s = StrategySet(d)
            s.update_statistics()
            res.append((empty_strategy, s))
        return res

    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
        # get the rest of the predecessors
        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
        other_predecessor_data = self.search_pass_list(other_predecessors)

        # pred strat data has an incomplete strategy, which we need
        # to continue on, whereas the other ones have completed strategies.
        # we need to merge these, but keep the incomplete strategy too.

        res = []
        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
            all_strats = [
                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
                other_predecessor_data,  # this one is fine to use as-is
            ]
            collated_strat_data = self.collate_strats_for_passes(all_strats)
            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
            res.extend(strat_data)
        return res

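    # Work out, per pass, how much SRAM is occupied by live tensors that are not that pass's own
    # inputs/outputs/intermediates; the DP search does not see those, so they are added on top.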
    def calc_non_local_mem_usage(self):
        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
        range_set = live_range.extract_live_ranges_from_passes(
            self.sg,
            self.mem_area,
            mark_output_tensors_overlapping_with_input_tensors=True,
            ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
        )
        range_dict = range_set.ranges

        # find which ranges overlap passes but aren't input/outputs of the passes.
        # these won't be counted by the dynamic programming search and must be counted in manually.
        end_pos = max(ps.time for ps in self.sg.passes) + 2
        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)

        for tens, rng in range_dict.items():
            storage_size = tens.storage_size()
            assert tens.mem_area == self.mem_area
            mem_usage[rng.start_time : rng.end_time] += storage_size

        for ps in self.sg.passes:
            local_mem_usage = 0
            for tens in ps.inputs + ps.outputs + ps.intermediates:
                if tens.mem_area != self.mem_area:
                    continue

                local_mem_usage += tens.storage_size()

            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage

        self.non_local_mem_usage = non_local_mem_usage

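    # Entry point of the search: start from the output passes (those without successors),
    # search backwards through their predecessors and pick the best complete strategy set.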
    def search(self):
        self.calc_non_local_mem_usage()
        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        strat_data = self.search_pass_list(starting_passes)

        _, best_set = self.best_candidate(strat_data)

        if self.verbose_pareto_frontier_schedules:
            print(
                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
                % (self.n_combinations_searched, len(strat_data))
            )
            for idx, (_, strat_set) in enumerate(strat_data):
                extra = ""
                if strat_set == best_set:
                    extra = "(Best candidate)"
                print("Candidate", idx, extra)
                memory_used = {MemArea.Sram: strat_set.max_sram_used}
                stats_writer.print_performance_metrics_for_strat(
                    self.arch,
                    "",
                    strat_set.cycles,
                    strat_set.macs,
                    strat_set.bws,
                    self.nng.batch_size,
                    memory_used,
                    len(self.sg.passes),
                    len(strat_set.strats),
                )

        return best_set

    def search_pass_list(self, pass_list):
        all_strats = []
        for ps in pass_list:
            strat = self.search_output(ps)
            all_strats.append(strat)
        strat_data = self.collate_strats_for_passes(all_strats)
        for strd in strat_data:
            for ps in pass_list:
                assert ps in strd[1].strats  # should have strategies for everything we asked to search
        return strat_data

    def search_predecessors(self, ps):

        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
        # we have strats for all passes

        pass_list = ps.dag_predecessors
        strat_data = self.search_pass_list(pass_list)

        return strat_data

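    # Memoised per-pass search: try weight streaming (and IFM streaming where enabled) for this
    # pass, keep the Pareto frontier, and fall back to just the predecessors' strategies with a
    # warning if nothing fits within the SRAM limit.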
    @lru_cache(maxsize=None)
    def search_output(self, ps):

        assert ps in self.sg.passes
        candidate_list = []

        candidate_list.extend(self.search_weight_streaming_output(ps))

        if self.options.use_ifm_streaming:
            candidate_list.extend(self.search_ifm_streaming_output(ps))

        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)

        if not best:
            print(
                "Warning: Dynamic programming search failed for pass %s, invoking fallback strategy"
                % (ps.name,)
            )
            return self.search_predecessors(ps)

        return best

    def search_ifm_streaming_output(self, ps):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        strat_data = self.search_ifm_streaming_body(ps, False)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)

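    # Cost one IFM-streamed pass: for every candidate block config, combine it with either a
    # cascaded predecessor (see search_ifm_streaming_partial) or plain predecessor strategies,
    # attach the per-pass performance metrics, and Pareto-prune the result.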
    @lru_cache(maxsize=None)
    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        ifm_input_search_results = self.search_ifm_streaming_input(ps)
        res = []

        base_sram_used = 0
        for tens in ps.intermediates:
            if tens.mem_area == self.mem_area:
                base_sram_used += tens.storage_size()

        all_block_configs = self.get_block_configs(ps)
        for block_config in all_block_configs:
            all_strats = []

            if self.use_cascading:
                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))

            all_strats.extend(ifm_input_search_results)

            rewrite_list = []
            sram_used = base_sram_used

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch,
                ps,
                block_config,
                rewrite_list=rewrite_list,
                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

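    # Try to cascade this pass with a predecessor that produces its IFM: the predecessor must be
    # an NPU pass of a streamable block type whose only consumer is this pass. If the block
    # configs pack together, the IFM is turned into a rolling buffer (via a SchedulerRewrite) so
    # the whole IFM does not have to be resident in SRAM at once.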
    def search_ifm_streaming_partial(self, ps, block_config):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH

        if len(ps.inputs) < 1:
            return ABORT_SEARCH

        ifm_tensor = ps.ifm_tensor

        if ifm_tensor is None:
            return ABORT_SEARCH
        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
            return ABORT_SEARCH
        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
            return ABORT_SEARCH

        pred_pass_list = []
        for pred_candidate in ps.dag_predecessors:
            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                # we found a predecessor that produces this IFM tensor
                if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                    # and it only has one successor, namely us
                    if pred_candidate.placement == PassPlacement.Npu:
                        if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                            # and it is on the Npu and fusable - it's a candidate
                            pred_pass_list.append(pred_candidate)

        if not pred_pass_list:
            return ABORT_SEARCH

        all_candidates = []
        for pred_pass in pred_pass_list:
            # recurse into the next pass
            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage)

            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
            for strat_opt in strat_data:

                pred_pass_block_config = strat_opt[0].block_configs[-1]
                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
                )
                if rolling_buffer_dims is None:
                    continue  # this does not pack properly, skip it.

                sram_used = 0
                for tens in ps.inputs:
                    if tens != ifm_tensor:
                        if tens.mem_area == self.mem_area:
                            sram_used += tens.storage_size()

                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims

                rewrite_list = [
                    (
                        SchedulerRewrite.ChangeTensorSubPurpose,
                        ifm_tensor,
                        TensorSubPurpose.RollingBufferY,
                        rolling_buffer_y,
                        None,
                        ps,
                    )
                ]
                sram_used += ifm_tensor.storage_size_for_sub_purpose(
                    TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
                )

                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))

        self.n_combinations_searched += len(all_candidates)
        return all_candidates

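    # Return the candidate block configs for a pass; when block_config_limit is set, only a
    # selection from the largest and the smallest blocks is kept to bound the search.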
    def get_block_configs(self, ps):
        if ps.placement != PassPlacement.Npu:
            return [(1, 1, 1, 1)]  # default

        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)

        # Take a limited number of the largest blocks
        if self.arch.block_config_limit > 0:
            # Sort by block area, followed by depth
            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
            bound = min(len(block_configs), self.arch.block_config_limit)
            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
            tmp = block_configs[:bound]
            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
            block_configs = tmp

        return block_configs

    def search_ifm_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def search_weight_streaming_output(self, ps):
        strat_data = self.search_weight_streaming_body(ps)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)

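    # Cost one weight-streamed pass: weight tensors held in SRAM are costed with a double-buffered
    # sub purpose (recorded as a rewrite) for each candidate block config, the per-pass metrics
    # are attached, and the result is Pareto-pruned.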
    @lru_cache(maxsize=None)
    def search_weight_streaming_body(self, ps):

        strat_data = self.search_weight_streaming_input(ps)

        res = []

        all_block_configs = self.get_block_configs(ps)

        for block_config in all_block_configs:

            sram_used = 0
            rewrite_list = []

            for tens in ps.intermediates:
                if tens.mem_area == self.mem_area:
                    if tens.purpose == TensorPurpose.Weights:
                        sram_used += tens.storage_size_for_sub_purpose(TensorSubPurpose.DoubleBuffer, block_config[3])
                        rewrite_list.append(
                            (
                                SchedulerRewrite.ChangeTensorSubPurpose,
                                tens,
                                TensorSubPurpose.DoubleBuffer,
                                block_config[3],
                                None,
                                ps,
                            )
                        )
                    else:
                        sram_used += tens.storage_size()

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch, ps, block_config, rewrite_list=rewrite_list
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_weight_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

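    # Apply the chosen strategy set to the graph: wrap each strategy's passes in a CascadedPass,
    # apply the recorded tensor rewrites (rolling buffers, double buffering), pick NHCWB16 for
    # NPU-internal intermediates, assign block configs and shared buffers, and finally order the
    # cascaded passes (startup-init ones first) and rebuild the pass links.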
    def apply_result(self, strat_set, arch):
        pass_to_cascaded_pass = dict()
        for _, strat in strat_set.strats.items():
            # rewrite the tensors that need this first. e.g. make rolling buffers
            inputs = []
            intermediates = []
            outputs = []

            for ps in strat.passes:
                inputs += ps.inputs
                intermediates += ps.intermediates
                outputs += ps.outputs

            for tens in set(inputs) & set(outputs):
                # tensors that are in both sets are intermediates

                # find pass with input/output tensor, and check if they are both placed on NPU
                input_placement = None
                output_placement = None
                for ps in strat.passes:
                    if tens in ps.inputs:
                        input_placement = ps.placement
                    if tens in ps.outputs:
                        output_placement = ps.placement
                if input_placement == output_placement == PassPlacement.Npu:
                    tens.set_format(TensorFormat.NHCWB16, arch)

                intermediates.append(tens)
                inputs.remove(tens)
                outputs.remove(tens)

            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                    tens.mem_area = self.arch.fast_storage_mem_area
                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                else:
                    assert 0, "unknown rewrite_op " + str(rewrite_op)

            is_element_wise = True
            for ps in strat.passes:
                assert ps.placement == strat.passes[0].placement
                if not ps.is_element_wise:
                    is_element_wise = False
                    break

            cascaded_pass = CascadedPass(
                strat.passes[0].name,
                strat.strat,
                inputs,
                intermediates,
                outputs,
                strat.passes,
                strat.passes[0].placement,
                is_element_wise,
            )
            assert strat.sram_used >= 0
            cascaded_pass.sram_used = strat.sram_used

            for idx, ps in enumerate(strat.passes):
                assert ps not in pass_to_cascaded_pass
                pass_to_cascaded_pass[ps] = cascaded_pass
                ps.cascade = cascaded_pass
                ps.block_config = strat.block_configs[idx]

                if ps.placement == PassPlacement.Npu:
                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
                        self.arch, ps, ps.block_config
                    )
                    assert ps.shared_buffer is not None

                for op in ps.ops:
                    subgraph = op.attrs.get("subgraph")
                    if subgraph:
                        subgraph.base_sram_used = cascaded_pass.sram_used

        # all passes should have a cascaded pass now
        if len(pass_to_cascaded_pass) != len(self.sg.passes):
            print(
                "mismatch: we have %d passes, but only %d have cascaded passes associated"
                % (len(self.sg.passes), len(pass_to_cascaded_pass))
            )
            for ps in self.sg.passes:
                if ps not in pass_to_cascaded_pass:
                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))

            assert len(pass_to_cascaded_pass) == len(self.sg.passes)
        # we have all the passes, but we need to put them in order and build predecessor/successor links.

        visit_pass_set = set()
        cascaded_passes = []

        def visit_pass(ps):
            if ps in visit_pass_set:
                return
            visit_pass_set.add(ps)

            cps = ps.cascade
            dont_traverse = set(cps.passes)

            for ps in cps.passes:
                for pred in ps.predecessors:
                    if pred in dont_traverse:
                        continue
                    visit_pass(pred)

            cascaded_passes.append(cps)

        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        for ps in starting_passes:
            visit_pass(ps)

        # reorder so startup init cascaded passes come first
        def is_startup_cascaded_pass(cps):
            if not cps.passes:
                return False
            return cps.placement == PassPlacement.StartupInit

        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
        ]

        self.sg.cascaded_passes = cascaded_passes
        self.sg.build_cascaded_pass_links()

        # Check if NHCWB16 can be used in between cascaded passes
        # (NHCWB16 within cascaded passes has been handled earlier in this function)
        if self.sg.placement == PassPlacement.Npu:
            for ps in self.sg.cascaded_passes:
                if ps.placement != PassPlacement.Npu:
                    continue
                for output in ps.outputs:
                    if output.purpose != TensorPurpose.FeatureMap:
                        continue

                    use_NHCWB16 = True
                    for op in output.consumer_list:
                        if op is None or op.type == "Reshape":
                            use_NHCWB16 = False
                        else:
                            use_NHCWB16 &= op.run_on_npu

                    if use_NHCWB16:
                        output.set_format(TensorFormat.NHCWB16, arch)


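# Top-level entry point: run the dynamic programming scheduler on every subgraph and apply
# the best strategy set it finds, turning each subgraph's passes into cascaded passes.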
def schedule_passes(nng, arch, options: SchedulerOptions):

    for sg in nng.subgraphs:
        sg.base_sram_used = 0

    for sg in nng.subgraphs:
        # re-entering the same nodes from different contexts requires us to
        # build a simplified directed acyclic (DAG) version of the graph to
        # use for traversal, rather than using a visit dictionary. this avoids
        # recursing infinitely due to loops.
        sg.build_pass_dag_predecessors()

        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)

        strat_set = dps.search()

        dps.apply_result(strat_set, arch)

        if options.verbose_schedule:
            sg.print_cascaded_passes()