# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
import copy
import enum
from functools import lru_cache

import numpy as np

from . import live_range
from . import npu_performance
from . import stats_writer
from .data_type import DataType
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
from .nn_graph import CascadedPass
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .nn_graph import SchedulingStrategy
from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
from .numeric_util import full_shape
from .operation import NpuBlockType
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
from .tensor import MemArea
from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


class ParetoMetric(enum.Enum):
    BwCycMem = 1
    BwCycMemBlkH = 2

    def __str__(self):
        return self.name
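
# A note on the two metrics (matching the tuples built in
# DynamicProgrammingScheduler.pareto_metric() below): BwCycMem ranks candidates on
# (weighted bandwidth + weighted cycles, SRAM used by the completed strategies,
# SRAM used by the open strategy), while BwCycMemBlkH appends the height of the
# last block config as a final tie-breaker between otherwise equal schedules.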


class SchedulerOptions:
    def __init__(
        self,
        use_cascading=True,
        use_ifm_ofm_overlap=True,
        verbose_schedule=False,
        verbose_pareto_frontier_schedules=False,
        use_ifm_streaming=True,
        pareto_metric=ParetoMetric.BwCycMem,
        use_nhcwb16_between_cascaded_passes=True,
    ):
        self.use_cascading = use_cascading
        self.use_ifm_ofm_overlap = use_ifm_ofm_overlap
        self.verbose_schedule = verbose_schedule
        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
        self.use_ifm_streaming = use_ifm_streaming
        self.pareto_metric = pareto_metric
        self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes

    def __str__(self):
        return type(self).__name__ + ": " + str(self.__dict__)

    __repr__ = __str__


class Strategy:
    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"

    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
        self.strat = strat
        self.param = param
        self.passes = passes
        self.block_configs = block_configs
        # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
        self.rewrite_list = rewrite_list
        self.bws = bws
        self.macs = macs
        self.cycles = cycles
        self.sram_used = sram_used

    def __eq__(self, other):
        if self.strat != other.strat:
            return False
        if self.param != other.param:
            return False
        if self.block_configs != other.block_configs:
            return False
        if self.passes != other.passes:
            return False
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.sram_used != other.sram_used:
            return False
        return True

    def empty(self):
        return not self.passes

    def key(self):
        return self.passes[-1]

    def clone(self):
        return Strategy(
            self.strat,
            self.param,
            self.passes,
            self.block_configs,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    def __str__(self):
        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
            self.strat,
            self.passes,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    __repr__ = __str__


class StrategySet:
    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"

    def __init__(self, strats=None):
        if strats is None:
            strats = dict()
        self.strats = strats  # final pass in packed pass -> Strategy
        self.bws, self.macs, self.cycles = make_metrics_arrays()
        self.max_sram_used = 0
        self.total_sram_used = 0

    def update_statistics(self):
        # Note: only bws and max_sram_used are reset here; macs, cycles and
        # total_sram_used accumulate on top of their current values, so this is
        # only safe to call on a freshly constructed StrategySet.
        self.bws = make_bandwidth_array()
        self.max_sram_used = 0
        for ps, strat in self.strats.items():
            self.bws += strat.bws
            self.macs += strat.macs
            self.cycles += strat.cycles
            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
            self.total_sram_used += strat.sram_used

    def clone_add_strategy(self, new_strat):
        key = new_strat.key()
        if key in self.strats:
            assert new_strat == self.strats[key]
            return self
        else:
            new_strats = dict(self.strats)
            new_strats[key] = new_strat
            new_set = StrategySet(new_strats)
            new_set.bws = self.bws + new_strat.bws
            new_set.macs = self.macs + new_strat.macs
            new_set.cycles = self.cycles + new_strat.cycles
            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
            return new_set

    def __eq__(self, other):
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.max_sram_used != other.max_sram_used:
            return False
        if self.total_sram_used != other.total_sram_used:
            return False
        if self.strats != other.strats:
            return False
        return True

    def __str__(self):
        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
            self.max_sram_used,
            list(ps.name for ps in self.strats),
        )

    __repr__ = __str__


empty_strategy = Strategy(
    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
)
INFINITY = 1e30

ABORT_SEARCH = []
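
# Search-state convention used throughout DynamicProgrammingScheduler: candidates
# travel as lists of (Strategy, StrategySet) pairs, where the Strategy is the
# still-open strategy being grown (its strat field stays SchedulingStrategy.Unknown
# until graduate_strat() closes it) and the StrategySet holds the completed
# strategies found so far, keyed by their final pass. ABORT_SEARCH is simply an
# empty candidate list, signalling that a search branch found nothing viable.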


def flatten_list_of_lists(lstlst):
    lst = []
    for v in lstlst:
        lst.extend(v)
    return lst


class DynamicProgrammingScheduler:
    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sram_limit = sram_limit
        self.options = copy.copy(options)
        self.use_cascading = options.use_cascading

        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM
        else:
            self.use_ifm_ofm_overlap = options.use_ifm_ofm_overlap

        self.verbose_schedule = options.verbose_schedule
        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
        self.mem_area = MemArea.Sram

        self.bandwidth_weights = arch.bandwidth_weights
        self.cycles_weight = arch.cycles_weight
        self.max_sram_used_weight = arch.max_sram_used_weight

        self.n_combinations_searched = 0

        self.feature_maps_not_in_fast_storage = (
            arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area
        )

        self.pareto_max_candidates = 16

        self.ifm_stream_npu_blocks = set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
        )

    num_pareto_metrics = 4
    view_values = ",".join(["d"] * num_pareto_metrics)
    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]

    def pareto_metric(self, candidate):
        strat, strat_set = candidate
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
        bws = strat.bws + strat_set.bws
        last_block_height = 0
        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
            last_block_height = strat.block_configs[-1][0]

        return (
            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
            strat_set.max_sram_used,
            strat.sram_used,
            last_block_height,
        )

    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):
        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]

        if len(candidates) <= 1:
            return candidates
        assert remove_equally_good_candidates
        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
        ids = np.arange(len(candidates), dtype=np.int32)
        for idx, cand in enumerate(candidates):
            pareto_vals[idx] = self.pareto_metric(cand)

        sort_order = np.argsort(
            pareto_vals.view(DynamicProgrammingScheduler.view_values),
            order=DynamicProgrammingScheduler.order_values,
            axis=0,
            kind="stable",
        ).flatten()
        pareto_vals = pareto_vals[sort_order]
        ids = ids[sort_order]

        pareto_frontier = []
        while len(ids) > 0:
            pareto_frontier.append(candidates[ids[0]])
            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
            ids = ids[not_dominated_by_first]
            pareto_vals = pareto_vals[not_dominated_by_first]

        if len(pareto_frontier) > self.pareto_max_candidates:
            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]

        return pareto_frontier
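
    # A worked sketch of the dominance filter above, using two metrics instead of
    # four for brevity: given candidates with metric tuples (10, 5), (12, 4) and
    # (13, 6), the lexicographic sort puts (10, 5) first; (12, 4) survives the
    # first round because it is strictly better in the second metric, while
    # (13, 6) is dominated (worse in both) and is dropped, as is any candidate
    # exactly equal to (10, 5). The next round then keeps (12, 4), leaving the
    # frontier [(10, 5), (12, 4)].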

    def candidate_metric(self, candidate):
        strat, strat_set = candidate
        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
        bws = strat.bws + strat_set.bws
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]

        return (
            max_sram_used * self.max_sram_used_weight
            + np.tensordot(bws, self.bandwidth_weights, axes=3)
            + total_cycles * self.cycles_weight
        )

    def sort_by_candidate_metric(self, candidate_list):
        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
        return sorted_list

    def best_candidate(self, candidate_list):
        if len(candidate_list) == 0:
            return ABORT_SEARCH
        if len(candidate_list) == 1:
            return candidate_list[0]
        sorted_list = self.sort_by_candidate_metric(candidate_list)
        return sorted_list[0]

    def graduate_strat(self, strat_type, sram_used, old_strat_data):
        res = []
        for old_strat, old_strat_set in old_strat_data:
            if old_strat.sram_used + sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            if old_strat_set.max_sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            assert old_strat.strat == SchedulingStrategy.Unknown

            new_strat = old_strat.clone()
            new_strat.strat = strat_type
            new_strat.sram_used = old_strat.sram_used + sram_used

            if self.use_ifm_ofm_overlap:
                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
                    new_strat.strat, new_strat.passes, new_strat.block_configs
                )
                new_strat.sram_used -= overlap

            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
            res.append((empty_strategy, new_strat_set))
        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
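
    # "Graduating" closes the open strategy: the accumulated passes are stamped
    # with their final scheduling strategy (IfmStream or WeightStream), the output
    # SRAM passed in by the caller is charged (minus any allowed IFM/OFM overlap),
    # the finished Strategy is committed into the StrategySet, and the open slot
    # is reset to empty_strategy, ready for the next cascade to grow.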

    def append_sram(self, sram_used, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            assert old_strat.sram_used == 0
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used

            res.append((new_strat, strat_set))
        return res

    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]

            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )

            res.append((new_strat, strat_set))
        return res

    def append_sram_pass_block_config_performance_metrics_rewrite_list(
        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
    ):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.passes = old_strat.passes + [new_pass]
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def pass_to_strat(self, strat_data):
        res = {}
        for strat in strat_data[1].strats.values():
            for ps in strat.passes:
                res[ps] = strat
        return res

    def compatible_strats(self, a, b):
        intersection = a.keys() & b.keys()
        for k in intersection:
            if a[k] != b[k]:
                return False
        return True

    def collate_strats_for_passes(self, all_passes):
        if len(all_passes) == 0:
            return [(empty_strategy, StrategySet(dict()))]
        if len(all_passes) == 1:
            return all_passes[0]  # save some space in the common case
        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
        prev_combos = [dict()]
        for j, strand in enumerate(all_strands):
            new_combos = []
            for i, alt in enumerate(strand):
                for prev in prev_combos:
                    if self.compatible_strats(prev, alt):
                        cmb = dict(prev)
                        cmb.update(all_passes[j][i][1].strats)
                        new_combos.append(cmb)
            prev_combos = new_combos

        res = []
        for d in prev_combos:
            s = StrategySet(d)
            s.update_statistics()
            res.append((empty_strategy, s))
        return res
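
    # Collation builds the cross product of the per-pass candidate lists while
    # enforcing consistency: two candidates are only combined if every pass they
    # share maps to the identical Strategy (compatible_strats). For example, when
    # two predecessor branches both depend on a common upstream pass, a combined
    # candidate is only produced where both branches scheduled that pass the same
    # way; incompatible pairings are silently discarded.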

    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
        # get the rest of the predecessors
        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
        other_predecessor_data = self.search_pass_list(other_predecessors)

        # pred_pass_data carries an incomplete strategy that we need to continue
        # growing, whereas the other predecessors have completed strategies.
        # We need to merge these, but keep the incomplete strategy too.

        res = []
        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
            all_strats = [
                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
                other_predecessor_data,  # this one is fine to use as-is
            ]
            collated_strat_data = self.collate_strats_for_passes(all_strats)
            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
            res.extend(strat_data)
        return res

    def calc_non_local_mem_usage(self):
        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
        range_set = live_range.extract_live_ranges_from_passes(
            self.sg,
            self.mem_area,
            mark_output_tensors_overlapping_with_input_tensors=True,
            ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
        )
        range_dict = range_set.ranges

        # Find the ranges that overlap passes without being inputs/outputs of those passes.
        # These won't be seen by the dynamic programming search and must be accounted for manually.
        end_pos = max(ps.time for ps in self.sg.passes) + 2
        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)

        for tens, rng in range_dict.items():
            storage_size = tens.storage_size()
            assert tens.mem_area == self.mem_area
            mem_usage[rng.start_time : rng.end_time] += storage_size

        for ps in self.sg.passes:
            local_mem_usage = 0
            for tens in ps.inputs + ps.outputs + ps.intermediates:
                if tens.mem_area != self.mem_area:
                    continue

                local_mem_usage += tens.storage_size()

            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage

        self.non_local_mem_usage = non_local_mem_usage
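
    # Worked example: if at time step t the live tensors in SRAM total 96 KB but
    # the pass at t only accounts for 80 KB through its own inputs, outputs and
    # intermediates, then non_local_mem_usage[t] = 16 KB. The search_*_output
    # methods add this figure to every candidate's SRAM usage so that the
    # sram_limit check reflects tensors that merely live across the pass.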

    def search(self):
        self.calc_non_local_mem_usage()
        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        strat_data = self.search_pass_list(starting_passes)

        _, best_set = self.best_candidate(strat_data)

        if self.verbose_pareto_frontier_schedules:
            print(
                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
                % (self.n_combinations_searched, len(strat_data))
            )
            for idx, (_, strat_set) in enumerate(strat_data):
                extra = ""
                if strat_set == best_set:
                    extra = "(Best candidate)"
                print("Candidate", idx, extra)
                memory_used = {MemArea.Sram: strat_set.max_sram_used}
                stats_writer.print_performance_metrics_for_strat(
                    self.arch,
                    "",
                    strat_set.cycles,
                    strat_set.macs,
                    strat_set.bws,
                    self.nng.batch_size,
                    memory_used,
                    len(self.sg.passes),
                    len(strat_set.strats),
                )

        return best_set

    def search_pass_list(self, pass_list):
        all_strats = []
        for ps in pass_list:
            strat = self.search_output(ps)
            all_strats.append(strat)
        strat_data = self.collate_strats_for_passes(all_strats)
        for strd in strat_data:
            for ps in pass_list:
                assert ps in strd[1].strats  # should have strategies for everything we asked to search
        return strat_data

    def search_predecessors(self, ps):
        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
        # we have strats for all passes
        pass_list = ps.dag_predecessors
        strat_data = self.search_pass_list(pass_list)

        return strat_data

    @lru_cache(maxsize=None)
    def search_output(self, ps):
        assert ps in self.sg.passes
        candidate_list = []

        candidate_list.extend(self.search_weight_streaming_output(ps))

        if self.options.use_ifm_streaming:
            candidate_list.extend(self.search_ifm_streaming_output(ps))

        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)

        if not best:
            print(
                "Warning: Dynamic programming search algorithm failed for pass %s, invoking fallback strategy"
                % (ps.name,)
            )
            return self.search_predecessors(ps)

        return best

    def search_ifm_streaming_output(self, ps):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        strat_data = self.search_ifm_streaming_body(ps, False)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        ifm_input_search_results = self.search_ifm_streaming_input(ps)
        res = []

        base_sram_used = 0
        for tens in ps.intermediates:
            if tens.mem_area == self.mem_area:
                if tens.purpose == TensorPurpose.Weights:
                    base_sram_used = tens.storage_size(self.arch.weight_estimation_scaling)
                else:
                    base_sram_used += tens.storage_size()

        all_block_configs = self.get_block_configs(ps)
        for block_config in all_block_configs:
            all_strats = []

            if self.use_cascading:
                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))

            all_strats.extend(ifm_input_search_results)

            rewrite_list = []
            sram_used = base_sram_used

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch,
                ps,
                block_config,
                rewrite_list=rewrite_list,
                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def avoid_for_spilling(self, pred_candidate):
        if self.arch.feature_map_storage_mem_area == self.arch.fast_storage_mem_area:
            return False

        # For SRAM spilling, concat op is avoided as predecessor
        for op in pred_candidate.ops:
            if op.type == "ConcatSliceWrite":
                return True
        return False

    def search_ifm_streaming_partial(self, ps, block_config):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH

        if len(ps.inputs) < 1:
            return ABORT_SEARCH

        ifm_tensor = ps.ifm_tensor

        if ifm_tensor is None:
            return ABORT_SEARCH
        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
            return ABORT_SEARCH
        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
            return ABORT_SEARCH

        pred_pass_list = []
        for pred_candidate in ps.dag_predecessors:
            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                # we found a predecessor that produces this IFM tensor
                if not ifm_tensor.avoid_NHCWB16:
                    # and NHCWB16 format is not to be avoided
                    if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                        # and it only has one successor, namely us
                        if pred_candidate.placement == PassPlacement.Npu:
                            # and it is on the NPU
                            if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                                # and it is an IFM-streamable block type
                                if not self.avoid_for_spilling(pred_candidate):
                                    # and fusable - it's a candidate
                                    pred_pass_list.append(pred_candidate)

        if not pred_pass_list:
            return ABORT_SEARCH

        all_candidates = []
        for pred_pass in pred_pass_list:
            # recurse into the next pass
            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage)

            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
            for strat_opt in strat_data:
                pred_pass_block_config = strat_opt[0].block_configs[-1]
                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
                )
                if rolling_buffer_dims is None:
                    continue  # this does not pack properly, skip it.

                sram_used = 0
                for tens in ps.inputs:
                    if tens != ifm_tensor:
                        if tens.mem_area == self.mem_area:
                            sram_used += tens.storage_size()

                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims

                rewrite_list = [
                    (
                        SchedulerRewrite.ChangeTensorSubPurpose,
                        ifm_tensor,
                        TensorSubPurpose.RollingBufferY,
                        rolling_buffer_y,
                        None,
                        ps,
                    )
                ]
                sram_used += ifm_tensor.storage_size_for_sub_purpose(
                    self.arch, TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
                )

                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))

        self.n_combinations_searched += len(all_candidates)
        return all_candidates
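
    # Cascading two IFM-streamed passes replaces the producer's full OFM in SRAM
    # with a rolling buffer a few rows high (TensorSubPurpose.RollingBufferY):
    # the consumer reads stripes as they are produced, so only the buffer rows
    # are charged against the SRAM budget instead of the whole intermediate
    # feature map.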

    def get_block_configs(self, ps):
        if ps.placement != PassPlacement.Npu:
            return [(1, 1, 1, 1)]  # default

        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)

        # Take a limited number of the largest blocks
        if self.arch.block_config_limit > 0:
            # Sort by block area, followed by depth
            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
            bound = min(len(block_configs), self.arch.block_config_limit)
            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
            tmp = block_configs[:bound]
            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
            block_configs = tmp

        return block_configs
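
    # For example, with block_config_limit = 4 and ten configs sorted largest
    # first, this keeps configs 0-3 (the fat end) plus configs 6-9 (the thin
    # end), eight in total; the max() guard stops the two slices overlapping
    # when the list is short.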

    def search_ifm_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def search_weight_streaming_output(self, ps):
        strat_data = self.search_weight_streaming_body(ps)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_weight_streaming_body(self, ps):
        strat_data = self.search_weight_streaming_input(ps)

        res = []

        all_block_configs = self.get_block_configs(ps)

        for block_config in all_block_configs:
            sram_used = 0
            rewrite_list = []

            for tens in ps.intermediates:
                if tens.mem_area == self.mem_area:
                    if tens.purpose == TensorPurpose.Weights:
                        sram_used += tens.storage_size_for_sub_purpose(
                            self.arch, TensorSubPurpose.DoubleBuffer, block_config[3]
                        )
                        rewrite_list.append(
                            (
                                SchedulerRewrite.ChangeTensorSubPurpose,
                                tens,
                                TensorSubPurpose.DoubleBuffer,
                                block_config[3],
                                None,
                                ps,
                            )
                        )
                    else:
                        sram_used += tens.storage_size()

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch, ps, block_config, rewrite_list=rewrite_list
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_weight_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def apply_result(self, strat_set, arch):
        pass_to_cascaded_pass = dict()
        for _, strat in strat_set.strats.items():
            # rewrite the tensors that need this first. e.g. make rolling buffers
            inputs = []
            intermediates = []
            outputs = []

            for ps in strat.passes:
                inputs += ps.inputs
                intermediates += ps.intermediates
                outputs += ps.outputs

            for tens in set(inputs) & set(outputs):
                # tensors that are in both sets are intermediates

                # find pass with input/output tensor, and check if they are both placed on NPU
                input_placement = None
                output_placement = None
                for ps in strat.passes:
                    if tens in ps.inputs:
                        input_placement = ps.placement
                    if tens in ps.outputs:
                        output_placement = ps.placement
                if input_placement == output_placement == PassPlacement.Npu:
                    tens.set_format(TensorFormat.NHCWB16, arch)

                intermediates.append(tens)
                inputs.remove(tens)
                outputs.remove(tens)

            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                    tens.mem_area = self.arch.fast_storage_mem_area
                    tens.mem_type = MemType.Scratch_fast
                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                else:
                    assert 0, "unknown rewrite_op " + str(rewrite_op)

            is_element_wise = True
            for ps in strat.passes:
                assert ps.placement == strat.passes[0].placement
                if not ps.is_element_wise:
                    is_element_wise = False
                    break

            cascaded_pass = CascadedPass(
                strat.passes[0].name,
                strat.strat,
                inputs,
                intermediates,
                outputs,
                strat.passes,
                strat.passes[0].placement,
                is_element_wise,
            )
            assert strat.sram_used >= 0
            cascaded_pass.sram_used = strat.sram_used

            for idx, ps in enumerate(strat.passes):
                assert ps not in pass_to_cascaded_pass
                pass_to_cascaded_pass[ps] = cascaded_pass
                ps.cascade = cascaded_pass
                ps.block_config = strat.block_configs[idx]

                if ps.placement == PassPlacement.Npu:
                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
                        self.arch, ps, ps.block_config
                    )
                    assert ps.shared_buffer is not None

                for op in ps.ops:
                    subgraph = op.attrs.get("subgraph")
                    if subgraph:
                        subgraph.base_sram_used = cascaded_pass.sram_used

        # all passes should have a cascaded pass now
        if len(pass_to_cascaded_pass) != len(self.sg.passes):
            print(
                "mismatch: we have %d passes, but only %d have cascaded passes associated"
                % (len(self.sg.passes), len(pass_to_cascaded_pass))
            )
            for ps in self.sg.passes:
                if ps not in pass_to_cascaded_pass:
                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))

            assert len(pass_to_cascaded_pass) == len(self.sg.passes)

        cascaded_passes = []
        if self.sg.placement == PassPlacement.Cpu:
            # Retain the pass order for CPU subgraph
            cascaded_passes = [ps.cascade for ps in self.sg.passes]
        else:
            # we have all the passes, but we need to put them in order and build predecessor/successor links.
            visit_pass_set = set()

            def visit_pass(ps):
                if ps in visit_pass_set:
                    return
                visit_pass_set.add(ps)

                cps = ps.cascade
                dont_traverse = set(cps.passes)

                for ps in cps.passes:
                    for pred in ps.predecessors:
                        if pred in dont_traverse:
                            continue
                        visit_pass(pred)

                cascaded_passes.append(cps)

            starting_passes = [ps for ps in self.sg.passes if not ps.successors]
            for ps in starting_passes:
                visit_pass(ps)

        # reorder so startup init cascaded passes come first
        def is_startup_cascaded_pass(cps):
            if not cps.passes:
                return False
            return cps.placement == PassPlacement.StartupInit

        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
        ]

        self.sg.cascaded_passes = cascaded_passes
        self.sg.build_cascaded_pass_links()

        if self.options.use_nhcwb16_between_cascaded_passes:
            # Check if NHCWB16 can be used in between cascaded passes
            # (NHCWB16 within cascaded passes has been handled earlier in this function)
            if self.sg.placement == PassPlacement.Npu:
                for ps in self.sg.cascaded_passes:
                    if ps.placement != PassPlacement.Npu:
                        continue
                    for output in ps.outputs:
                        if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
                            continue

                        use_NHCWB16 = True
                        rewrites = []
                        for op in output.consumer_list:
                            if op is None or (op.type == "ReduceSum" and output.dtype == DataType.int32):
                                use_NHCWB16 = False
                            elif op.type == "Reshape":
                                # Detect no-op reshapes by comparing their full input and output tensor shapes.
                                inshape = full_shape(4, op.inputs[0].shape, 1)
                                outshape = full_shape(4, op.outputs[0].shape, 1)
                                # Using NHCWB16 format for a no-op reshape is only an option if subsequent
                                # consumers do not also need to perform a reshape or if the OFM is going to
                                # be processed by CPU operations. No-op reshape consumers with empty lists
                                # (those that have no consumers, or null-consumers used as list terminators)
                                # must use normal NHWC output.
                                incompatible_consumers = [
                                    (not consumer.run_on_npu or consumer.type == "Reshape")
                                    for consumer in op.outputs[0].consumer_list
                                    if consumer is not None
                                ]
                                if outshape == inshape and incompatible_consumers and not any(incompatible_consumers):
                                    rewrites.append(op)
                                else:
                                    use_NHCWB16 = False
                            else:
                                use_NHCWB16 &= op.run_on_npu

                        if use_NHCWB16:
                            output.set_format(TensorFormat.NHCWB16, arch)
                        for rewrite_op in rewrites:
                            rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)


def schedule_passes(nng, arch, options: SchedulerOptions):
    for sg in nng.subgraphs:
        sg.base_sram_used = 0

    for sg in nng.subgraphs:
        # re-entering the same nodes from different contexts requires us to
        # build a simplified directed acyclic (DAG) version of the graph to
        # use for traversal, rather than using a visit dictionary. this avoids
        # recursing infinitely due to loops.
        sg.build_pass_dag_predecessors()

        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)

        strat_set = dps.search()

        dps.apply_result(strat_set, arch)

        if options.verbose_schedule:
            sg.print_cascaded_passes()
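

# Typical invocation, as a sketch (nng, the internal network graph, and arch, the
# ArchitectureFeatures object describing the target accelerator, are built by the
# surrounding compiler driver and are not defined in this file):
#
#     options = SchedulerOptions(use_cascading=True, pareto_metric=ParetoMetric.BwCycMem)
#     schedule_passes(nng, arch, options)
#
# After this returns, each subgraph's cascaded_passes list holds the chosen
# schedule, with block configs and tensor format/sub-purpose rewrites applied.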