# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
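# A minimal usage sketch (hypothetical driver code; assumes an already-prepared
# graph `nng` and an `arch` object):
#
#   options = SchedulerOptions(use_cascading=True, verbose_schedule=False)
#   schedule_passes(nng, arch, options)
#
# schedule_passes() below runs the dynamic-programming search once per subgraph
# and applies the best strategy set it finds.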
import copy
import enum
from functools import lru_cache

import numpy as np

from . import live_range
from . import npu_performance
from . import stats_writer
from .data_type import DataType
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
from .nn_graph import CascadedPass
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .nn_graph import SchedulingStrategy
from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
from .numeric_util import full_shape
from .operation import NpuBlockType
from .operation import Op
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
from .tensor import MemArea
from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


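# Pareto metrics the search can optimise over: BwCycMem combines weighted
# bandwidth, cycles and SRAM usage; BwCycMemBlkH additionally includes the
# height of the last block config (see DynamicProgrammingScheduler.pareto_metric).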
class ParetoMetric(enum.Enum):
    BwCycMem = 1
    BwCycMemBlkH = 2

    def __str__(self):
        return self.name


class SchedulerOptions:
    def __init__(
        self,
        use_cascading=True,
        use_ifm_ofm_overlap=True,
        verbose_schedule=False,
        verbose_pareto_frontier_schedules=False,
        use_ifm_streaming=True,
        pareto_metric=ParetoMetric.BwCycMem,
        use_nhcwb16_between_cascaded_passes=True,
    ):
        self.use_cascading = use_cascading
        self.use_ifm_ofm_overlap = use_ifm_ofm_overlap
        self.verbose_schedule = verbose_schedule
        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
        self.use_ifm_streaming = use_ifm_streaming
        self.pareto_metric = pareto_metric
        self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes

    def __str__(self):
        return type(self).__name__ + ": " + str(self.__dict__)

    __repr__ = __str__


class Strategy:
    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"

    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
        self.strat = strat
        self.param = param
        self.passes = passes
        self.block_configs = block_configs
        self.rewrite_list = (
            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
        )
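        # e.g. the weight-streaming search below records a double-buffering rewrite as:
        #   (SchedulerRewrite.ChangeTensorSubPurpose, tens, TensorSubPurpose.DoubleBuffer, block_config[3], None, ps)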
        self.bws = bws
        self.macs = macs
        self.cycles = cycles
        self.sram_used = sram_used

    def __eq__(self, other):
        if self.strat != other.strat:
            return False
        if self.param != other.param:
            return False
        if self.block_configs != other.block_configs:
            return False
        if self.passes != other.passes:
            return False
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.sram_used != other.sram_used:
            return False
        return True

    def empty(self):
        return not self.passes

    def key(self):
        return self.passes[-1]

    def clone(self):
        return Strategy(
            self.strat,
            self.param,
            self.passes,
            self.block_configs,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    def __str__(self):
        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
            self.strat,
            self.passes,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    __repr__ = __str__


class StrategySet:
    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"

    def __init__(self, strats=None):
        if strats is None:
            strats = dict()
        self.strats = strats  # final pass in packed pass -> Strategy
        self.bws, self.macs, self.cycles = make_metrics_arrays()
        self.max_sram_used = 0
        self.total_sram_used = 0

    def update_statistics(self):
        self.bws = make_bandwidth_array()
        self.max_sram_used = 0
        for ps, strat in self.strats.items():
            self.bws += strat.bws
            self.macs += strat.macs
            self.cycles += strat.cycles
            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
            self.total_sram_used += strat.sram_used

    def clone_add_strategy(self, new_strat):
        key = new_strat.key()
        if key in self.strats:
            assert new_strat == self.strats[key]
            return self
        else:
            new_strats = dict(self.strats)
            new_strats[key] = new_strat
            new_set = StrategySet(new_strats)
            new_set.bws = self.bws + new_strat.bws
            new_set.macs = self.macs + new_strat.macs
            new_set.cycles = self.cycles + new_strat.cycles
            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
            return new_set

    def __eq__(self, other):
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.max_sram_used != other.max_sram_used:
            return False
        if self.total_sram_used != other.total_sram_used:
            return False
        if self.strats != other.strats:
            return False
        return True

    def __str__(self):
        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
            self.max_sram_used,
            list(ps.name for ps in self.strats),
        )

    __repr__ = __str__


empty_strategy = Strategy(
    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
)
INFINITY = 1e30

ABORT_SEARCH = []
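# ABORT_SEARCH is returned wherever a list of candidate (Strategy, StrategySet)
# tuples is expected: being an empty list, it contributes no candidates and makes
# the caller prune that branch of the search.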


def flatten_list_of_lists(lstlst):
    lst = []
    for v in lstlst:
        lst.extend(v)
    return lst


class DynamicProgrammingScheduler:
    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sram_limit = sram_limit
        self.options = copy.copy(options)
        self.use_cascading = options.use_cascading

        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in SRAM
        else:
            self.use_ifm_ofm_overlap = options.use_ifm_ofm_overlap

        self.verbose_schedule = options.verbose_schedule
        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
        self.mem_area = MemArea.Sram

        self.bandwidth_weights = arch.bandwidth_weights
        self.cycles_weight = arch.cycles_weight
        self.max_sram_used_weight = arch.max_sram_used_weight

        self.n_combinations_searched = 0

        self.feature_maps_not_in_fast_storage = (
            arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area
        )

        self.pareto_max_candidates = 16

        self.ifm_stream_npu_blocks = set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
        )

    num_pareto_metrics = 4
    view_values = ",".join(["d"] * num_pareto_metrics)
    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]
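    # view_values/order_values let an (N, num_pareto_metrics) float array be viewed
    # as a NumPy structured array with float64 fields "f0".."f3", so that
    # np.argsort(arr.view("d,d,d,d"), order=["f0", "f1", "f2", "f3"]) sorts the
    # candidate metric tuples lexicographically (used in filter_pareto_frontier).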

    def pareto_metric(self, candidate):
        strat, strat_set = candidate
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
        bws = strat.bws + strat_set.bws
        last_block_height = 0
        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
            last_block_height = strat.block_configs[-1][0]

        return (
            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
            strat_set.max_sram_used,
            strat.sram_used,
            last_block_height,
        )

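    # A sketch of the dominance filtering below, using hypothetical metric tuples
    # (weighted bw+cycles, strategy-set SRAM peak, strategy SRAM, last block height):
    #   a = (1.0e6, 65536, 16384, 8)
    #   b = (1.2e6, 65536, 16384, 8)   # dominated by a: worse in one axis, better in none
    #   c = (1.2e6, 32768, 16384, 8)   # kept: worse cost than a, but a lower SRAM peak
    # Only a and c survive filter_pareto_frontier().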
    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):

        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]

        if len(candidates) <= 1:
            return candidates
        assert remove_equally_good_candidates
        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
        ids = np.arange(len(candidates), dtype=np.int32)
        for idx, cand in enumerate(candidates):
            pareto_vals[idx] = self.pareto_metric(cand)

        sort_order = np.argsort(
            pareto_vals.view(DynamicProgrammingScheduler.view_values),
            order=DynamicProgrammingScheduler.order_values,
            axis=0,
            kind="stable",
        ).flatten()
        pareto_vals = pareto_vals[sort_order]
        ids = ids[sort_order]

        pareto_frontier = []
        while len(ids) > 0:
            pareto_frontier.append(candidates[ids[0]])
            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
            ids = ids[not_dominated_by_first]
            pareto_vals = pareto_vals[not_dominated_by_first]

        if len(pareto_frontier) > self.pareto_max_candidates:
            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]

        return pareto_frontier

    def candidate_metric(self, candidate):
        strat, strat_set = candidate
        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
        bws = strat.bws + strat_set.bws
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]

        return (
            max_sram_used * self.max_sram_used_weight
            + np.tensordot(bws, self.bandwidth_weights, axes=3)
            + total_cycles * self.cycles_weight
        )

    def sort_by_candidate_metric(self, candidate_list):
        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
        return sorted_list

    def best_candidate(self, candidate_list):
        if len(candidate_list) == 0:
            return ABORT_SEARCH
        if len(candidate_list) == 1:
            return candidate_list[0]
        sorted_list = self.sort_by_candidate_metric(candidate_list)
        return sorted_list[0]

    def graduate_strat(self, strat_type, sram_used, old_strat_data):
        res = []
        for old_strat, old_strat_set in old_strat_data:
            if old_strat.sram_used + sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            if old_strat_set.max_sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            assert old_strat.strat == SchedulingStrategy.Unknown

            new_strat = old_strat.clone()
            new_strat.strat = strat_type
            new_strat.sram_used = old_strat.sram_used + sram_used

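            # The high-level command stream generator allows the OFM to overlap the
            # IFM in SRAM for streamed passes; subtract the allowed overlap so the
            # peak SRAM estimate reflects that sharing.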
            if self.use_ifm_ofm_overlap:
                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
                    new_strat.strat, new_strat.passes, new_strat.block_configs
                )
                new_strat.sram_used -= overlap

            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
            res.append((empty_strategy, new_strat_set))
        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)

    def append_sram(self, sram_used, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            assert old_strat.sram_used == 0
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used

            res.append((new_strat, strat_set))
        return res

    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]

            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )

            res.append((new_strat, strat_set))
        return res

    def append_sram_pass_block_config_performance_metrics_rewrite_list(
        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
    ):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.passes = old_strat.passes + [new_pass]
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def pass_to_strat(self, strat_data):
        res = {}
        for strat in strat_data[1].strats.values():
            for ps in strat.passes:
                res[ps] = strat
        return res

    def compatible_strats(self, a, b):
        intersection = a.keys() & b.keys()
        for k in intersection:
            if a[k] != b[k]:
                return False
        return True
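    # Example: two strategy dicts are compatible when every pass covered by both maps
    # to the same Strategy, e.g. {ps1: s1, ps2: s2} and {ps2: s2, ps3: s3} are
    # compatible, while {ps1: s1, ps2: s2} and {ps2: s4} are not.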

    def collate_strats_for_passes(self, all_passes):
        if len(all_passes) == 0:
            return [(empty_strategy, StrategySet(dict()))]
        if len(all_passes) == 1:
            return all_passes[0]  # save some space in the common case
        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
        prev_combos = [dict()]
        for j, strand in enumerate(all_strands):
            new_combos = []
            for i, alt in enumerate(strand):
                for prev in prev_combos:
                    if self.compatible_strats(prev, alt):
                        cmb = dict(prev)
                        cmb.update(all_passes[j][i][1].strats)
                        new_combos.append(cmb)
            prev_combos = new_combos

        res = []
        for d in prev_combos:
            s = StrategySet(d)
            s.update_statistics()
            res.append((empty_strategy, s))
        return res

    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
        # get the rest of the predecessors
        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
        other_predecessor_data = self.search_pass_list(other_predecessors)

        # The pred strat data holds an incomplete strategy that we need to continue
        # on, whereas the other predecessors hold completed strategies. Merge them,
        # but keep the incomplete strategy too.

        res = []
        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
            all_strats = [
                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
                other_predecessor_data,  # this one is fine to use as-is
            ]
            collated_strat_data = self.collate_strats_for_passes(all_strats)
            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
            res.extend(strat_data)
        return res

    def calc_non_local_mem_usage(self):
        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
        range_set = live_range.extract_live_ranges_from_passes(
            self.sg, self.mem_area, ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
        )
        range_dict = range_set.ranges

        # Find which ranges overlap passes but aren't inputs/outputs of those passes.
        # These won't be counted by the dynamic programming search and must be
        # accounted for manually.
        end_pos = max(ps.time for ps in self.sg.passes) + 2
        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)

        for tens, rng in range_dict.items():
            storage_size = tens.storage_size()
            assert tens.mem_area == self.mem_area
            mem_usage[rng.start_time : rng.end_time] += storage_size

        for ps in self.sg.passes:
            local_mem_usage = 0
            for tens in ps.inputs + ps.outputs + ps.intermediates:
                if tens.mem_area != self.mem_area:
                    continue

                local_mem_usage += tens.storage_size()

            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage

        self.non_local_mem_usage = non_local_mem_usage
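        # Example: if total live SRAM at time t is 100 KB and the pass scheduled at t
        # owns 60 KB of that through its own inputs/outputs/intermediates, then
        # non_local_mem_usage[t] is 40 KB, which every candidate at t must add on
        # top of its own usage.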

    def search(self):
        self.calc_non_local_mem_usage()
        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        strat_data = self.search_pass_list(starting_passes)

        _, best_set = self.best_candidate(strat_data)

        if self.verbose_pareto_frontier_schedules:
            print(
                "Scheduler searched %d combinations and found %d candidate schedules along the Pareto frontier"
                % (self.n_combinations_searched, len(strat_data))
            )
            for idx, (_, strat_set) in enumerate(strat_data):
                extra = ""
                if strat_set == best_set:
                    extra = "(Best candidate)"
                print("Candidate", idx, extra)
                memory_used = {MemArea.Sram: strat_set.max_sram_used}
                stats_writer.print_performance_metrics_for_strat(
                    self.arch,
                    "",
                    strat_set.cycles,
                    strat_set.macs,
                    strat_set.bws,
                    self.nng.batch_size,
                    memory_used,
                    len(self.sg.passes),
                    len(strat_set.strats),
                )

        return best_set

    def search_pass_list(self, pass_list):
        all_strats = []
        for ps in pass_list:
            strat = self.search_output(ps)
            all_strats.append(strat)
        strat_data = self.collate_strats_for_passes(all_strats)
        for strd in strat_data:
            for ps in pass_list:
                assert ps in strd[1].strats  # should have strategies for everything we asked to search
        return strat_data

    def search_predecessors(self, ps):

        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
        # we have strats for all passes

        pass_list = ps.dag_predecessors
        strat_data = self.search_pass_list(pass_list)

        return strat_data

    @lru_cache(maxsize=None)
    def search_output(self, ps):

        assert ps in self.sg.passes
        candidate_list = []

        candidate_list.extend(self.search_weight_streaming_output(ps))

        if self.options.use_ifm_streaming:
            candidate_list.extend(self.search_ifm_streaming_output(ps))

        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)

        if not best:
            print(
                "Warning: Dynamic programming search failed for pass %s, invoking fallback strategy"
                % (ps.name,)
            )
            return self.search_predecessors(ps)

        return best

    def search_ifm_streaming_output(self, ps):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        strat_data = self.search_ifm_streaming_body(ps, False)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        ifm_input_search_results = self.search_ifm_streaming_input(ps)
        res = []

        base_sram_used = 0
        for tens in ps.intermediates:
            if tens.mem_area == self.mem_area:
                if tens.purpose == TensorPurpose.Weights:
                    base_sram_used = tens.storage_size(self.arch.weight_estimation_scaling)
                else:
                    base_sram_used += tens.storage_size()

        all_block_configs = self.get_block_configs(ps)
        for block_config in all_block_configs:
            all_strats = []

            if self.use_cascading:
                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))

            all_strats.extend(ifm_input_search_results)

            rewrite_list = []
            sram_used = base_sram_used

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch,
                ps,
                block_config,
                rewrite_list=rewrite_list,
                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def avoid_for_cascading(self, pred_candidate):
        for op in pred_candidate.ops:
            if (
                op.type == Op.ConcatSliceWrite
                and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area
            ):
                # For SRAM spilling, concat op is avoided as predecessor
                return True
            if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
                # The op has multiple outputs, or its output has multiple consumers
                return True
        return False

    def search_ifm_streaming_partial(self, ps, block_config):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH

        if len(ps.inputs) < 1:
            return ABORT_SEARCH

        ifm_tensor = ps.ifm_tensor

        if ifm_tensor is None:
            return ABORT_SEARCH
        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
            return ABORT_SEARCH
        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
            return ABORT_SEARCH

        pred_pass_list = []
        for pred_candidate in ps.dag_predecessors:
            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                # we found a predecessor that produces this IFM tensor
                if not ifm_tensor.avoid_NHCWB16:
                    # and NHCWB16 format is not to be avoided
                    if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                        # and it only has one successor, namely us
                        if pred_candidate.placement == PassPlacement.Npu:
                            if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                                # and it is on the Npu
                                if not self.avoid_for_cascading(pred_candidate):
                                    # and fusable - it's a candidate
                                    pred_pass_list.append(pred_candidate)

        if not pred_pass_list:
            return ABORT_SEARCH

        all_candidates = []
        for pred_pass in pred_pass_list:
            # recurse into the next pass
            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage)

            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
            for strat_opt in strat_data:

                pred_pass_block_config = strat_opt[0].block_configs[-1]
                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
                )
                if rolling_buffer_dims is None:
                    continue  # this does not pack properly, skip it.

                sram_used = 0
                for tens in ps.inputs:
                    if tens != ifm_tensor:
                        if tens.mem_area == self.mem_area:
                            sram_used += tens.storage_size()

                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims

                rewrite_list = [
                    (
                        SchedulerRewrite.ChangeTensorSubPurpose,
                        ifm_tensor,
                        TensorSubPurpose.RollingBufferY,
                        rolling_buffer_y,
                        None,
                        ps,
                    )
                ]
                sram_used += ifm_tensor.storage_size_for_sub_purpose(
                    self.arch, TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
                )

                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))

        self.n_combinations_searched += len(all_candidates)
        return all_candidates

    def get_block_configs(self, ps):
        if ps.placement != PassPlacement.Npu:
            return [(1, 1, 1, 1)]  # default

        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)

        # Take a limited number of the largest blocks
        if self.arch.block_config_limit > 0:
            # Sort by block area, followed by depth
            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
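            # The key packs block area (height * width) into the high bits and depth
            # into the low 8 bits, so area dominates and depth breaks ties; e.g. for
            # cfg = (16, 8, 1, 32): (16 * 8) << 8 | 32 == 32800.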
            bound = min(len(block_configs), self.arch.block_config_limit)
            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
            tmp = block_configs[:bound]
            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
            block_configs = tmp

        return block_configs

    def search_ifm_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def search_weight_streaming_output(self, ps):
        strat_data = self.search_weight_streaming_body(ps)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_weight_streaming_body(self, ps):

        strat_data = self.search_weight_streaming_input(ps)

        res = []

        all_block_configs = self.get_block_configs(ps)

        for block_config in all_block_configs:

            sram_used = 0
            rewrite_list = []

            for tens in ps.intermediates:
                if tens.mem_area == self.mem_area:
                    if tens.purpose == TensorPurpose.Weights:
                        sram_used += tens.storage_size_for_sub_purpose(
                            self.arch, TensorSubPurpose.DoubleBuffer, block_config[3]
                        )
                        rewrite_list.append(
                            (
                                SchedulerRewrite.ChangeTensorSubPurpose,
                                tens,
                                TensorSubPurpose.DoubleBuffer,
                                block_config[3],
                                None,
                                ps,
                            )
                        )
                    else:
                        sram_used += tens.storage_size()

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch, ps, block_config, rewrite_list=rewrite_list
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_weight_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def apply_result(self, strat_set, arch):
        pass_to_cascaded_pass = dict()
        for _, strat in strat_set.strats.items():
            # rewrite the tensors that need this first. e.g. make rolling buffers
            inputs = []
            intermediates = []
            outputs = []

            for ps in strat.passes:
                inputs += ps.inputs
                intermediates += ps.intermediates
                outputs += ps.outputs

            for tens in set(inputs) & set(outputs):
                # tensors that are in both sets are intermediates

                # find pass with input/output tensor, and check if they are both placed on NPU
                input_placement = None
                output_placement = None
                for ps in strat.passes:
                    if tens in ps.inputs:
                        input_placement = ps.placement
                    if tens in ps.outputs:
                        output_placement = ps.placement
                if input_placement == output_placement == PassPlacement.Npu:
                    tens.set_format(TensorFormat.NHCWB16, arch)

                intermediates.append(tens)
                inputs.remove(tens)
                outputs.remove(tens)

            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                    tens.mem_area = self.arch.fast_storage_mem_area
                    tens.mem_type = MemType.Scratch_fast
                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                else:
                    assert 0, "unknown rewrite_op " + str(rewrite_op)

            is_element_wise = True
            for ps in strat.passes:
                assert ps.placement == strat.passes[0].placement
                if not ps.is_element_wise:
                    is_element_wise = False
                    break

            cascaded_pass = CascadedPass(
                strat.passes[0].name,
                strat.strat,
                inputs,
                intermediates,
                outputs,
                strat.passes,
                strat.passes[0].placement,
                is_element_wise,
            )
            assert strat.sram_used >= 0
            cascaded_pass.sram_used = strat.sram_used

            for idx, ps in enumerate(strat.passes):
                assert ps not in pass_to_cascaded_pass
                pass_to_cascaded_pass[ps] = cascaded_pass
                ps.cascade = cascaded_pass
                ps.block_config = strat.block_configs[idx]

                if ps.placement == PassPlacement.Npu:
                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
                        self.arch, ps, ps.block_config
                    )
                    assert ps.shared_buffer is not None

                sram_used = max(self.non_local_mem_usage[ps.time], 0)
                for op in ps.ops:
                    subgraph = op.attrs.get("subgraph")
                    if subgraph:
                        subgraph.base_sram_used = sram_used

        # all passes should have a cascaded pass now
        if len(pass_to_cascaded_pass) != len(self.sg.passes):
            print(
                "mismatch: we have %d passes, but only %d have cascaded passes associated"
                % (len(self.sg.passes), len(pass_to_cascaded_pass))
            )
            for ps in self.sg.passes:
                if ps not in pass_to_cascaded_pass:
                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))

            assert len(pass_to_cascaded_pass) == len(self.sg.passes)

        cascaded_passes = []
        if self.sg.placement == PassPlacement.Cpu:
            # Retain the pass order for CPU subgraph
            cascaded_passes = [ps.cascade for ps in self.sg.passes]
        else:
            # we have all the passes, but we need to put them in order and build predecessor/successor links.
            visit_pass_set = set()

            def visit_pass(ps):
                if ps in visit_pass_set:
                    return
                visit_pass_set.add(ps)

                cps = ps.cascade
                dont_traverse = set(cps.passes)

                for ps in cps.passes:
                    for pred in ps.predecessors:
                        if pred in dont_traverse:
                            continue
                        visit_pass(pred)

                cascaded_passes.append(cps)

            starting_passes = [ps for ps in self.sg.passes if not ps.successors]
            for ps in starting_passes:
                visit_pass(ps)

        # reorder so startup init cascaded passes come first
        def is_startup_cascaded_pass(cps):
            if not cps.passes:
                return False
            return cps.placement == PassPlacement.StartupInit

        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
        ]

        self.sg.cascaded_passes = cascaded_passes
        self.sg.build_cascaded_pass_links()

        # Check if NHCWB16 and/or fast storage can be used in between cascaded passes
        # (NHCWB16 within cascaded passes has been handled earlier in this function)
        if self.sg.placement == PassPlacement.Npu:
            # Dictionary tensor -> list of ops, containing feature maps that can be attempted
            # to be moved to fast storage
            fast_storage_tensor_rewrites = {}
            last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
            for ps in self.sg.cascaded_passes:
                if ps.placement != PassPlacement.Npu:
                    continue
                for output in ps.outputs:
                    if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
                        continue

                    use_NHCWB16 = True
                    use_fast_storage = True
                    rewrites = []
                    for op in output.consumer_list:
                        if op is None:
                            use_NHCWB16 = False
                            use_fast_storage = False
                            continue
                        if op.type == Op.ReduceSum and output.dtype == DataType.int32:
                            use_NHCWB16 = False
                        elif op.type == Op.Reshape:
                            # Detect no-op reshapes by comparing their full input and output tensor shapes.
                            inshape = full_shape(4, op.inputs[0].shape, 1)
                            outshape = full_shape(4, op.outputs[0].shape, 1)
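                            # e.g. full_shape(4, [10, 10, 64], 1) -> [1, 10, 10, 64], so a
                            # reshape from [10, 10, 64] to [1, 10, 10, 64] compares equal here.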
                            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
                            # consumers do not also need to perform a reshape or if the OFM is going to
                            # be processed by CPU operations. No-op reshape consumers with empty lists
                            # (those that have no consumers, or null-consumers used as list terminators)
                            # must use normal NHWC output.
                            incompatible_consumers = [
                                (
                                    not consumer.run_on_npu
                                    or consumer.type == Op.Reshape
                                    or (consumer is last_op_in_subgraph)
                                )
                                for consumer in op.outputs[0].consumer_list
                                if consumer is not None
                            ]
                            if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
                                rewrites.append(op)
                            else:
                                use_NHCWB16 = False
                                use_fast_storage = False
                        use_NHCWB16 &= op.run_on_npu
                        use_fast_storage &= op.run_on_npu

                    if use_fast_storage:
                        fast_storage_tensor_rewrites[output] = rewrites
                    if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
                        output.set_format(TensorFormat.NHCWB16, arch)
                        for rewrite_op in rewrites:
                            rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
            if self.feature_maps_not_in_fast_storage:
                # Remember feature maps that can be moved to fast storage for later use
                # in use_fast_storage_for_feature_maps
                self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites


def schedule_passes(nng, arch, options: SchedulerOptions):

    for sg in nng.subgraphs:
        sg.base_sram_used = 0

    for sg in nng.subgraphs:
        # re-entering the same nodes from different contexts requires us to
        # build a simplified directed acyclic (DAG) version of the graph to
        # use for traversal, rather than using a visit dictionary. this avoids
        # recursing infinitely due to loops.
        sg.build_pass_dag_predecessors()

        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)

        strat_set = dps.search()

        dps.apply_result(strat_set, arch)

        if options.verbose_schedule:
            sg.print_cascaded_passes()


def _calc_tens_to_cps(sg, tensor_rewrites):
    # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption.
    # Returns dictionary tensor -> list of cascaded passes
    # Note: if cascaded passes are A, B, C, D, and a tensor is output
    # of A and input to D, then it also consumes SRAM in passes B and C.
    if "tens_to_cps" in sg.scheduling_info:
        return sg.scheduling_info["tens_to_cps"]
    # Determine life-time of tensors
    min_index = {}
    max_index = {}
    index = 0
    cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu]
    for cps in cps_list:
        for tens in cps.inputs + cps.outputs:
            if tens in tensor_rewrites:
                min_index[tens] = min(index, min_index.get(tens, len(cps_list)))
                max_index[tens] = index
        index += 1
    # Convert to affected cps-es
    tens_to_cps = {}
    for tens in min_index:
        tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1]
    sg.scheduling_info["tens_to_cps"] = tens_to_cps
    return tens_to_cps


def use_fast_storage_for_feature_maps(sg, sram_limit, arch):
    # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes.
    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
    # Sort tensors first on life-time (smallest first), then on size (biggest first)
    tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps])
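    # Each entry sorts on (life-time, -size, name): shortest-lived tensors first and,
    # among equal life-times, the largest tensor first; the name keeps the order
    # deterministic, e.g. (2, -65536, "t0", ...) sorts before (2, -16384, "t1", ...).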
    for _, _, _, tens in tens_list:
        cps_list = tens_to_cps[tens]
        if len(cps_list) <= 1:
            continue
        sz = tens.storage_size()
        fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list])
        if fits_in_fast_storage:
            tens.mem_area = arch.fast_storage_mem_area
            tens.mem_type = MemType.Scratch_fast
            tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
            assert tens in tensor_rewrites
            # Also rewrite reshapes
            for rewrite_op in tensor_rewrites[tens]:
                tens2 = rewrite_op.outputs[0]
                tens2.mem_area = arch.fast_storage_mem_area
                tens2.mem_type = MemType.Scratch_fast
                tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
            for cps in cps_list:
                cps.sram_used += sz


def undo_use_fast_storage(sg, arch):
    # Undoes the effects of a previous call to use_fast_storage_for_feature_maps
    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
    mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap]
    for tens, cps_list in tens_to_cps.items():
        if tens.mem_type == MemType.Scratch_fast:
            sz = tens.storage_size()
            tens.mem_area = mem_area
            tens.mem_type = MemType.Scratch
            # Also undo reshapes
            for rewrite_op in tensor_rewrites[tens]:
                tens2 = rewrite_op.outputs[0]
                tens2.mem_area = mem_area
                tens2.mem_type = MemType.Scratch
            for cps in cps_list:
                cps.sram_used -= sz