# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
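#
# Rough usage sketch (illustrative only; the real compiler driver owns the nng and arch objects):
#     options = SchedulerOptions(use_cascading=True, use_ifm_streaming=True)
#     schedule_passes(nng, arch, options)  # annotates each subgraph with cascaded passes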
import copy
import enum
from functools import lru_cache

import numpy as np

from . import live_range
from . import npu_performance
from . import stats_writer
from .data_type import DataType
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
from .nn_graph import CascadedPass
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .nn_graph import SchedulingStrategy
from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
from .numeric_util import full_shape
from .operation import NpuBlockType
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
from .tensor import MemArea
from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


class ParetoMetric(enum.Enum):
    BwCycMem = 1
    BwCycMemBlkH = 2

    def __str__(self):
        return self.name


class SchedulerOptions:
    def __init__(
        self,
        use_cascading=True,
        use_ifm_ofm_overlap=True,
        verbose_schedule=False,
        verbose_pareto_frontier_schedules=False,
        use_ifm_streaming=True,
        pareto_metric=ParetoMetric.BwCycMem,
        use_nhcwb16_between_cascaded_passes=True,
    ):
        self.use_cascading = use_cascading
        self.use_ifm_ofm_overlap = use_ifm_ofm_overlap
        self.verbose_schedule = verbose_schedule
        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
        self.use_ifm_streaming = use_ifm_streaming
        self.pareto_metric = pareto_metric
        self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes

    def __str__(self):
        return type(self).__name__ + ": " + str(self.__dict__)

    __repr__ = __str__


class Strategy:
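    """One candidate schedule for a run of passes: the scheduling strategy, the passes it
    covers, their block configs and tensor rewrites, and the accumulated cost metrics
    (bandwidths, MACs, cycles, SRAM usage) used by the dynamic-programming search."""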
    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"

    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
        self.strat = strat
        self.param = param
        self.passes = passes
        self.block_configs = block_configs
        self.rewrite_list = (
            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
        )
        self.bws = bws
        self.macs = macs
        self.cycles = cycles
        self.sram_used = sram_used

    def __eq__(self, other):
        if self.strat != other.strat:
            return False
        if self.param != other.param:
            return False
        if self.block_configs != other.block_configs:
            return False
        if self.passes != other.passes:
            return False
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.sram_used != other.sram_used:
            return False
        return True

    def empty(self):
        return not self.passes

    def key(self):
        return self.passes[-1]

    def clone(self):
        return Strategy(
            self.strat,
            self.param,
            self.passes,
            self.block_configs,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    def __str__(self):
        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
            self.strat,
            self.passes,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    __repr__ = __str__


class StrategySet:
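    """Collection of completed Strategy objects, keyed by the final pass each one covers,
    plus the aggregated bandwidth/MAC/cycle totals and SRAM high-water mark for the set."""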
    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"

    def __init__(self, strats=None):
        if strats is None:
            strats = dict()
        self.strats = strats  # final pass in packed pass -> Strategy
        self.bws, self.macs, self.cycles = make_metrics_arrays()
        self.max_sram_used = 0
        self.total_sram_used = 0

    def update_statistics(self):
        self.bws = make_bandwidth_array()
        self.max_sram_used = 0
        for ps, strat in self.strats.items():
            self.bws += strat.bws
            self.macs += strat.macs
            self.cycles += strat.cycles
            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
            self.total_sram_used += strat.sram_used

    def clone_add_strategy(self, new_strat):
        key = new_strat.key()
        if key in self.strats:
            assert new_strat == self.strats[key]
            return self
        else:
            new_strats = dict(self.strats)
            new_strats[key] = new_strat
            new_set = StrategySet(new_strats)
            new_set.bws = self.bws + new_strat.bws
            new_set.macs = self.macs + new_strat.macs
            new_set.cycles = self.cycles + new_strat.cycles
            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
            return new_set

    def __eq__(self, other):
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.max_sram_used != other.max_sram_used:
            return False
        if self.total_sram_used != other.total_sram_used:
            return False
        if self.strats != other.strats:
            return False
        return True

    def __str__(self):
        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
            self.max_sram_used,
            list(ps.name for ps in self.strats),
        )

    __repr__ = __str__


empty_strategy = Strategy(
    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
)
INFINITY = 1e30

ABORT_SEARCH = []


def flatten_list_of_lists(lstlst):
    lst = []
    for v in lstlst:
        lst.extend(v)
    return lst


class DynamicProgrammingScheduler:
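    """Dynamic-programming search over scheduling strategies for one subgraph.

    For every pass it explores weight-streaming and (optionally) IFM-streaming/cascading
    alternatives, prunes candidates to a Pareto frontier of bandwidth, cycles and SRAM
    usage, and finally applies the best StrategySet to the subgraph as cascaded passes."""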
    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sram_limit = sram_limit
        self.options = copy.copy(options)
        self.use_cascading = options.use_cascading

        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM
        else:
            self.use_ifm_ofm_overlap = options.use_ifm_ofm_overlap

        self.verbose_schedule = options.verbose_schedule
        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
        self.mem_area = MemArea.Sram

        self.bandwidth_weights = arch.bandwidth_weights
        self.cycles_weight = arch.cycles_weight
        self.max_sram_used_weight = arch.max_sram_used_weight

        self.n_combinations_searched = 0

        self.feature_maps_not_in_fast_storage = (
            arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area
        )

        self.pareto_max_candidates = 16

        self.ifm_stream_npu_blocks = set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
        )

    num_pareto_metrics = 4
    view_values = ",".join(["d"] * num_pareto_metrics)
    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]

    def pareto_metric(self, candidate):
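        """Return the metric tuple the Pareto filter minimises for a (Strategy, StrategySet)
        candidate: weighted bandwidth plus cycles, the set's max SRAM, the strategy's SRAM and,
        for the BwCycMemBlkH metric, the height of the last block config."""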
        strat, strat_set = candidate
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
        bws = strat.bws + strat_set.bws
        last_block_height = 0
        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
            last_block_height = strat.block_configs[-1][0]

        return (
            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
            strat_set.max_sram_used,
            strat.sram_used,
            last_block_height,
        )

    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):
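        """Discard candidates that exceed the SRAM limit, then repeatedly keep the best remaining
        candidate (by pareto_metric ordering) and drop everything it dominates, i.e. everything
        that is not strictly better in at least one metric. The result is capped at
        pareto_max_candidates entries."""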
        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]

        if len(candidates) <= 1:
            return candidates
        assert remove_equally_good_candidates
        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
        ids = np.arange(len(candidates), dtype=np.int32)
        for idx, cand in enumerate(candidates):
            pareto_vals[idx] = self.pareto_metric(cand)

        sort_order = np.argsort(
            pareto_vals.view(DynamicProgrammingScheduler.view_values),
            order=DynamicProgrammingScheduler.order_values,
            axis=0,
            kind="stable",
        ).flatten()
        pareto_vals = pareto_vals[sort_order]
        ids = ids[sort_order]

        pareto_frontier = []
        while len(ids) > 0:
            pareto_frontier.append(candidates[ids[0]])
            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
            ids = ids[not_dominated_by_first]
            pareto_vals = pareto_vals[not_dominated_by_first]

        if len(pareto_frontier) > self.pareto_max_candidates:
            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]

        return pareto_frontier

    def candidate_metric(self, candidate):
        strat, strat_set = candidate
        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
        bws = strat.bws + strat_set.bws
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]

        return (
            max_sram_used * self.max_sram_used_weight
            + np.tensordot(bws, self.bandwidth_weights, axes=3)
            + total_cycles * self.cycles_weight
        )

    def sort_by_candidate_metric(self, candidate_list):
        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
        return sorted_list

    def best_candidate(self, candidate_list):
        if len(candidate_list) == 0:
            return ABORT_SEARCH
        if len(candidate_list) == 1:
            return candidate_list[0]
        sorted_list = self.sort_by_candidate_metric(candidate_list)
        return sorted_list[0]

    def graduate_strat(self, strat_type, sram_used, old_strat_data):
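        """Close off open strategies as strat_type (IfmStream or WeightStream): add the extra
        SRAM, optionally subtract the allowed IFM/OFM overlap, fold the finished Strategy into
        its StrategySet and return the Pareto-filtered results."""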
        res = []
        for old_strat, old_strat_set in old_strat_data:
            if old_strat.sram_used + sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            if old_strat_set.max_sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            assert old_strat.strat == SchedulingStrategy.Unknown

            new_strat = old_strat.clone()
            new_strat.strat = strat_type
            new_strat.sram_used = old_strat.sram_used + sram_used

            if self.use_ifm_ofm_overlap:
                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
                    new_strat.strat, new_strat.passes, new_strat.block_configs
                )
                new_strat.sram_used -= overlap

            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
            res.append((empty_strategy, new_strat_set))
        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)

    def append_sram(self, sram_used, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            assert old_strat.sram_used == 0
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used

            res.append((new_strat, strat_set))
        return res

    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]

            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )

            res.append((new_strat, strat_set))
        return res

    def append_sram_pass_block_config_performance_metrics_rewrite_list(
        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
    ):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.passes = old_strat.passes + [new_pass]
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def pass_to_strat(self, strat_data):
        res = {}
        for strat in strat_data[1].strats.values():
            for ps in strat.passes:
                res[ps] = strat
        return res

    def compatible_strats(self, a, b):
        intersection = a.keys() & b.keys()
        for k in intersection:
            if a[k] != b[k]:
                return False
        return True

    def collate_strats_for_passes(self, all_passes):
        if len(all_passes) == 0:
            return [(empty_strategy, StrategySet(dict()))]
        if len(all_passes) == 1:
            return all_passes[0]  # save some space in the common case
        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
        prev_combos = [dict()]
        for j, strand in enumerate(all_strands):
            new_combos = []
            for i, alt in enumerate(strand):
                for prev in prev_combos:
                    if self.compatible_strats(prev, alt):
                        cmb = dict(prev)
                        cmb.update(all_passes[j][i][1].strats)
                        new_combos.append(cmb)
            prev_combos = new_combos

        res = []
        for d in prev_combos:
            s = StrategySet(d)
            s.update_statistics()
            res.append((empty_strategy, s))
        return res

    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
        # get the rest of the predecessors
        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
        other_predecessor_data = self.search_pass_list(other_predecessors)

        # pred strat data has an incomplete strategy, which we need
        # to continue on, whereas the other ones have completed strategies.
        # we need to merge these, but keep the incomplete strategy too.

        res = []
        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
            all_strats = [
                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
                other_predecessor_data,  # this one is fine to use as-is
            ]
            collated_strat_data = self.collate_strats_for_passes(all_strats)
            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
            res.extend(strat_data)
        return res

    def calc_non_local_mem_usage(self):
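        """Build self.non_local_mem_usage: for every pass time slot, the SRAM occupied by live
        tensors that are not inputs, outputs or intermediates of that pass. The DP search does
        not see these tensors, so their size is added in separately when checking the SRAM limit."""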
        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
        range_set = live_range.extract_live_ranges_from_passes(
            self.sg,
            self.mem_area,
            mark_output_tensors_overlapping_with_input_tensors=True,
            ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
        )
        range_dict = range_set.ranges

        # find which ranges overlap passes but aren't input/outputs of the passes.
        # these won't be counted by the dynamic programming search and must be counted in manually.
        end_pos = max(ps.time for ps in self.sg.passes) + 2
        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)

        for tens, rng in range_dict.items():
            storage_size = tens.storage_size()
            assert tens.mem_area == self.mem_area
            mem_usage[rng.start_time : rng.end_time] += storage_size

        for ps in self.sg.passes:
            local_mem_usage = 0
            for tens in ps.inputs + ps.outputs + ps.intermediates:
                if tens.mem_area != self.mem_area:
                    continue

                local_mem_usage += tens.storage_size()

            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage

        self.non_local_mem_usage = non_local_mem_usage

    def search(self):
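        """Entry point of the DP search: find the best StrategySet covering all passes of the
        subgraph, starting from the output passes (those without successors), and optionally
        print the Pareto-frontier candidates."""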
        self.calc_non_local_mem_usage()
        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        strat_data = self.search_pass_list(starting_passes)

        _, best_set = self.best_candidate(strat_data)

        if self.verbose_pareto_frontier_schedules:
            print(
                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
                % (self.n_combinations_searched, len(strat_data,))
            )
            for idx, (_, strat_set) in enumerate(strat_data):
                extra = ""
                if strat_set == best_set:
                    extra = "(Best candidate)"
                print("Candidate", idx, extra)
                memory_used = {MemArea.Sram: strat_set.max_sram_used}
                stats_writer.print_performance_metrics_for_strat(
                    self.arch,
                    "",
                    strat_set.cycles,
                    strat_set.macs,
                    strat_set.bws,
                    self.nng.batch_size,
                    memory_used,
                    len(self.sg.passes),
                    len(strat_set.strats),
                )

        return best_set

    def search_pass_list(self, pass_list):
        all_strats = []
        for ps in pass_list:
            strat = self.search_output(ps)
            all_strats.append(strat)
        strat_data = self.collate_strats_for_passes(all_strats)
        for strd in strat_data:
            for ps in pass_list:
                assert ps in strd[1].strats  # should have strategies for everything we asked to search
        return strat_data

    def search_predecessors(self, ps):

        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
        # we have strats for all passes

        pass_list = ps.dag_predecessors
        strat_data = self.search_pass_list(pass_list)

        return strat_data

    @lru_cache(maxsize=None)
    def search_output(self, ps):
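        """Memoized per-pass search: cost weight-streaming and, if enabled, IFM-streaming
        strategies that produce the outputs of pass ps, and return the Pareto frontier of
        (Strategy, StrategySet) candidates. Falls back to the predecessors' strategies if
        nothing fits within the SRAM limit."""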
        assert ps in self.sg.passes
        candidate_list = []

        candidate_list.extend(self.search_weight_streaming_output(ps))

        if self.options.use_ifm_streaming:
            candidate_list.extend(self.search_ifm_streaming_output(ps))

        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)

        if not best:
            print(
                "Warning: Dynamic programming search algorithm failed for pass %s, invoking fallback strategy"
                % (ps.name,)
            )
            return self.search_predecessors(ps)

        return best

    def search_ifm_streaming_output(self, ps):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        strat_data = self.search_ifm_streaming_body(ps, False)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        ifm_input_search_results = self.search_ifm_streaming_input(ps)
        res = []

        base_sram_used = 0
        for tens in ps.intermediates:
            if tens.mem_area == self.mem_area:
                if tens.purpose == TensorPurpose.Weights:
                    base_sram_used = tens.storage_size(self.arch.weight_estimation_scaling)
                else:
                    base_sram_used += tens.storage_size()

        all_block_configs = self.get_block_configs(ps)
        for block_config in all_block_configs:
            all_strats = []

            if self.use_cascading:
                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))

            all_strats.extend(ifm_input_search_results)

            rewrite_list = []
            sram_used = base_sram_used

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch,
                ps,
                block_config,
                rewrite_list=rewrite_list,
                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def avoid_for_spilling(self, pred_candidate):
        if self.arch.feature_map_storage_mem_area == self.arch.fast_storage_mem_area:
            return False

        # For SRAM spilling, concat op is avoided as predecessor
        for op in pred_candidate.ops:
            if op.type == "ConcatSliceWrite":
                return True
            if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
                # The op has consumers in other subgraphs
                return True
        return False

    def search_ifm_streaming_partial(self, ps, block_config):
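        """Try to cascade pass ps with a predecessor that produces its IFM: the predecessor's
        OFM is kept in SRAM as a rolling buffer (RollingBufferY sub-purpose) instead of a full
        feature map, so both passes can stream over the IFM together."""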
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH

        if len(ps.inputs) < 1:
            return ABORT_SEARCH

        ifm_tensor = ps.ifm_tensor

        if ifm_tensor is None:
            return ABORT_SEARCH
        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
            return ABORT_SEARCH
        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
            return ABORT_SEARCH

        pred_pass_list = []
        for pred_candidate in ps.dag_predecessors:
            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                # we found a predecessor that produces this IFM tensor
                if not ifm_tensor.avoid_NHCWB16:
                    # and NHCWB16 format is not to be avoided
                    if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                        # and it only has one successor, namely us
                        if pred_candidate.placement == PassPlacement.Npu:
                            if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                                # and it is on the Npu
                                if not self.avoid_for_spilling(pred_candidate):
                                    # and fusable - it's a candidate
                                    pred_pass_list.append(pred_candidate)

        if not pred_pass_list:
            return ABORT_SEARCH

        all_candidates = []
        for pred_pass in pred_pass_list:
            # recurse into the next pass
            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage)

            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
            for strat_opt in strat_data:

                pred_pass_block_config = strat_opt[0].block_configs[-1]
                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
                )
                if rolling_buffer_dims is None:
                    continue  # this does not pack properly, skip it.

                sram_used = 0
                for tens in ps.inputs:
                    if tens != ifm_tensor:
                        if tens.mem_area == self.mem_area:
                            sram_used += tens.storage_size()

                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims

                rewrite_list = [
                    (
                        SchedulerRewrite.ChangeTensorSubPurpose,
                        ifm_tensor,
                        TensorSubPurpose.RollingBufferY,
                        rolling_buffer_y,
                        None,
                        ps,
                    )
                ]
                sram_used += ifm_tensor.storage_size_for_sub_purpose(
                    self.arch, TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
                )

                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))

        self.n_combinations_searched += len(all_candidates)
        return all_candidates

    def get_block_configs(self, ps):
        if ps.placement != PassPlacement.Npu:
            return [(1, 1, 1, 1)]  # default

        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)

        # Take a limited number of the largest blocks
        if self.arch.block_config_limit > 0:
            # Sort by block area, followed by depth
            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
            bound = min(len(block_configs), self.arch.block_config_limit)
            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
            tmp = block_configs[:bound]
            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
            block_configs = tmp

        return block_configs

    def search_ifm_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def search_weight_streaming_output(self, ps):
        strat_data = self.search_weight_streaming_body(ps)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_weight_streaming_body(self, ps):

        strat_data = self.search_weight_streaming_input(ps)

        res = []

        all_block_configs = self.get_block_configs(ps)

        for block_config in all_block_configs:

            sram_used = 0
            rewrite_list = []

            for tens in ps.intermediates:
                if tens.mem_area == self.mem_area:
                    if tens.purpose == TensorPurpose.Weights:
                        sram_used += tens.storage_size_for_sub_purpose(
                            self.arch, TensorSubPurpose.DoubleBuffer, block_config[3]
                        )
                        rewrite_list.append(
                            (
                                SchedulerRewrite.ChangeTensorSubPurpose,
                                tens,
                                TensorSubPurpose.DoubleBuffer,
                                block_config[3],
                                None,
                                ps,
                            )
                        )
                    else:
                        sram_used += tens.storage_size()

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch, ps, block_config, rewrite_list=rewrite_list
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_weight_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def apply_result(self, strat_set, arch):
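        """Commit the chosen StrategySet to the subgraph: apply tensor rewrites (rolling
        buffers, double-buffered weights), group passes into CascadedPass objects, order the
        cascaded passes, and select NHCWB16 format between cascaded passes where possible."""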
        pass_to_cascaded_pass = dict()
        for _, strat in strat_set.strats.items():
            # rewrite the tensors that need this first. e.g. make rolling buffers
            inputs = []
            intermediates = []
            outputs = []

            for ps in strat.passes:
                inputs += ps.inputs
                intermediates += ps.intermediates
                outputs += ps.outputs

            for tens in set(inputs) & set(outputs):
                # tensors that are in both sets are intermediates

                # find pass with input/output tensor, and check if they are both placed on NPU
                input_placement = None
                output_placement = None
                for ps in strat.passes:
                    if tens in ps.inputs:
                        input_placement = ps.placement
                    if tens in ps.outputs:
                        output_placement = ps.placement
                if input_placement == output_placement == PassPlacement.Npu:
                    tens.set_format(TensorFormat.NHCWB16, arch)

                intermediates.append(tens)
                inputs.remove(tens)
                outputs.remove(tens)

            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                    tens.mem_area = self.arch.fast_storage_mem_area
                    tens.mem_type = MemType.Scratch_fast
                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                else:
                    assert 0, "unknown rewrite_op " + str(rewrite_op)

            is_element_wise = True
            for ps in strat.passes:
                assert ps.placement == strat.passes[0].placement
                if not ps.is_element_wise:
                    is_element_wise = False
                    break

            cascaded_pass = CascadedPass(
                strat.passes[0].name,
                strat.strat,
                inputs,
                intermediates,
                outputs,
                strat.passes,
                strat.passes[0].placement,
                is_element_wise,
            )
            assert strat.sram_used >= 0
            cascaded_pass.sram_used = strat.sram_used

            for idx, ps in enumerate(strat.passes):
                assert ps not in pass_to_cascaded_pass
                pass_to_cascaded_pass[ps] = cascaded_pass
                ps.cascade = cascaded_pass
                ps.block_config = strat.block_configs[idx]

                if ps.placement == PassPlacement.Npu:
                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
                        self.arch, ps, ps.block_config
                    )
                    assert ps.shared_buffer is not None

                for op in ps.ops:
                    subgraph = op.attrs.get("subgraph")
                    if subgraph:
                        subgraph.base_sram_used = cascaded_pass.sram_used

        # all passes should have a cascaded pass now
        if len(pass_to_cascaded_pass) != len(self.sg.passes):
            print(
                "mismatch: we have %d passes, but only %d have cascaded passes associated"
                % (len(self.sg.passes), len(pass_to_cascaded_pass))
            )
            for ps in self.sg.passes:
                if ps not in pass_to_cascaded_pass:
                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))

            assert len(pass_to_cascaded_pass) == len(self.sg.passes)

        cascaded_passes = []
        if self.sg.placement == PassPlacement.Cpu:
            # Retain the pass order for CPU subgraph
            cascaded_passes = [ps.cascade for ps in self.sg.passes]
        else:
            # we have all the passes, but we need to put them in order and build predecessor/successor links.
            visit_pass_set = set()

            def visit_pass(ps):
                if ps in visit_pass_set:
                    return
                visit_pass_set.add(ps)

                cps = ps.cascade
                dont_traverse = set(cps.passes)

                for ps in cps.passes:
                    for pred in ps.predecessors:
                        if pred in dont_traverse:
                            continue
                        visit_pass(pred)

                cascaded_passes.append(cps)

            starting_passes = [ps for ps in self.sg.passes if not ps.successors]
            for ps in starting_passes:
                visit_pass(ps)

        # reorder so startup init cascaded passes come first
        def is_startup_cascaded_pass(cps):
            if not cps.passes:
                return False
            return cps.placement == PassPlacement.StartupInit

        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
        ]

        self.sg.cascaded_passes = cascaded_passes
        self.sg.build_cascaded_pass_links()

        if self.options.use_nhcwb16_between_cascaded_passes:
            # Check if NHCWB16 can be used in between cascaded passes
            # (NHCWB16 within cascaded passes has been handled earlier in this function)
            if self.sg.placement == PassPlacement.Npu:
                last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
                for ps in self.sg.cascaded_passes:
                    if ps.placement != PassPlacement.Npu:
                        continue
                    for output in ps.outputs:
                        if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
                            continue

                        use_NHCWB16 = True
                        rewrites = []
                        for op in output.consumer_list:
                            if op is None or (op.type == "ReduceSum" and output.dtype == DataType.int32):
                                use_NHCWB16 = False
                            elif op.type == "Reshape":
                                # Detect no-op reshapes by comparing their full input and output tensor shapes.
                                inshape = full_shape(4, op.inputs[0].shape, 1)
                                outshape = full_shape(4, op.outputs[0].shape, 1)
                                # Using NHCWB16 format for a no-op reshape is only an option if subsequent
                                # consumers do not also need to perform a reshape or if the OFM is going to
                                # be processed by CPU operations. No-op reshape consumers with empty lists
                                # (those that have no consumers, or null-consumers used as list terminators)
                                # must use normal NHWC output.
                                incompatible_consumers = [
                                    (
                                        not consumer.run_on_npu
                                        or consumer.type == "Reshape"
                                        or (consumer is last_op_in_subgraph)
                                    )
                                    for consumer in op.outputs[0].consumer_list
                                    if consumer is not None
                                ]
                                if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
                                    rewrites.append(op)
                                else:
                                    use_NHCWB16 = False
                            else:
                                use_NHCWB16 &= op.run_on_npu

                        if use_NHCWB16:
                            output.set_format(TensorFormat.NHCWB16, arch)
                            for rewrite_op in rewrites:
                                rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)


def schedule_passes(nng, arch, options: SchedulerOptions):
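    """Run the dynamic-programming scheduler on every subgraph of the network and write the
    chosen schedule back as cascaded passes.

    Illustrative call (sketch; nng, arch and options come from the surrounding compiler):
        schedule_passes(nng, arch, SchedulerOptions(verbose_schedule=True))
    """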
    for sg in nng.subgraphs:
        sg.base_sram_used = 0

    for sg in nng.subgraphs:
        # re-entering the same nodes from different contexts requires us to
        # build a simplified directed acyclic (DAG) version of the graph to
        # use for traversal, rather than using a visit dictionary. this avoids
        # recursing infinitely due to loops.
        sg.build_pass_dag_predecessors()

        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)

        strat_set = dps.search()

        dps.apply_result(strat_set, arch)

        if options.verbose_schedule:
            sg.print_cascaded_passes()