# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
import copy
import enum
from functools import lru_cache

import numpy as np

from . import live_range
from . import npu_performance
from . import stats_writer
from .data_type import DataType
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
from .nn_graph import CascadedPass
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .nn_graph import SchedulingStrategy
from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
from .tensor import MemArea
from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


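# Pareto metric variants used when pruning candidate schedules:
# - BwCycMem: weighted bandwidth + cycles, SRAM high-water mark and per-strategy SRAM.
# - BwCycMemBlkH: as BwCycMem, with the height of the last block config added as a tie-breaker.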
class ParetoMetric(enum.Enum):
    BwCycMem = 1
    BwCycMemBlkH = 2

    def __str__(self):
        return self.name


class SchedulerOptions:
    def __init__(
        self,
        use_cascading=True,
        verbose_schedule=False,
        verbose_pareto_frontier_schedules=False,
        use_ifm_streaming=True,
        pareto_metric=ParetoMetric.BwCycMem,
        use_nhcwb16_between_cascaded_passes=True,
        cache_bias_scale_tensor=True,
    ):
        self.use_cascading = use_cascading
        self.verbose_schedule = verbose_schedule
        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
        self.use_ifm_streaming = use_ifm_streaming
        self.pareto_metric = pareto_metric
        self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes
        self.cache_bias_scale_tensor = cache_bias_scale_tensor

    def __str__(self):
        return type(self).__name__ + ": " + str(self.__dict__)

    __repr__ = __str__


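# A Strategy records one (possibly still incomplete) way of scheduling a chain of passes:
# the scheduling strategy, the passes it covers, their block configs, the tensor rewrites
# to apply, and the accumulated bandwidth/MAC/cycle/SRAM cost of that choice.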
class Strategy:
    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"

    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
        self.strat = strat
        self.param = param
        self.passes = passes
        self.block_configs = block_configs
        self.rewrite_list = (
            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
        )
        self.bws = bws
        self.macs = macs
        self.cycles = cycles
        self.sram_used = sram_used

    def __eq__(self, other):
        if self.strat != other.strat:
            return False
        if self.param != other.param:
            return False
        if self.block_configs != other.block_configs:
            return False
        if self.passes != other.passes:
            return False
        if (self.bws != other.bws).any():
            return False
        if self.macs != other.macs:
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.sram_used != other.sram_used:
            return False
        return True

    def empty(self):
        return not self.passes

    def key(self):
        return self.passes[-1]

    def clone(self):
        return Strategy(
            self.strat,
            self.param,
            self.passes,
            self.block_configs,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    def __str__(self):
        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
            self.strat,
            self.passes,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    __repr__ = __str__


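# A StrategySet maps the final pass of each packed (cascaded) pass to the Strategy chosen
# for it, and keeps the aggregated cost statistics for the whole set.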
class StrategySet:
    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"

    def __init__(self, strats=None):
        if strats is None:
            strats = dict()
        self.strats = strats  # final pass in packed pass -> Strategy
        self.bws, self.macs, self.cycles = make_metrics_arrays()
        self.max_sram_used = 0
        self.total_sram_used = 0

    def update_statistics(self):
        self.bws = make_bandwidth_array()
        self.max_sram_used = 0
        for ps, strat in self.strats.items():
            self.bws += strat.bws
            self.macs += strat.macs
            self.cycles += strat.cycles
            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
            self.total_sram_used += strat.sram_used

    def clone_add_strategy(self, new_strat):
        key = new_strat.key()
        if key in self.strats:
            assert new_strat == self.strats[key]
            return self
        else:
            new_strats = dict(self.strats)
            new_strats[key] = new_strat
            new_set = StrategySet(new_strats)
            new_set.bws = self.bws + new_strat.bws
            new_set.macs = self.macs + new_strat.macs
            new_set.cycles = self.cycles + new_strat.cycles
            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
            return new_set

    def __eq__(self, other):
        if (self.bws != other.bws).any():
            return False
        if self.macs != other.macs:
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.max_sram_used != other.max_sram_used:
            return False
        if self.total_sram_used != other.total_sram_used:
            return False
        if self.strats != other.strats:
            return False
        return True

    def __str__(self):
        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
            self.max_sram_used,
            list(ps.name for ps in self.strats),
        )

    __repr__ = __str__


empty_strategy = Strategy(
    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), 0, make_cycles_array(), 0
)
INFINITY = 1e30

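# An empty candidate list is used to signal that a search branch is not viable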
ABORT_SEARCH = []


def flatten_list_of_lists(lstlst):
    lst = []
    for v in lstlst:
        lst.extend(v)
    return lst


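# Schedules one subgraph by costing weight-streaming and IFM-streaming (cascading)
# strategies for every pass, pruning the candidates to a pareto frontier while the search
# walks the pass DAG from the subgraph outputs towards its inputs.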
class DynamicProgrammingScheduler:
    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sram_limit = sram_limit
        self.options = copy.copy(options)
        self.use_cascading = options.use_cascading

        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM
        else:
            self.use_ifm_ofm_overlap = True

        self.verbose_schedule = options.verbose_schedule
        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
        self.mem_area = MemArea.Sram

        self.bandwidth_weights = arch.bandwidth_weights
        self.cycles_weight = arch.cycles_weight
        self.max_sram_used_weight = arch.max_sram_used_weight

        self.n_combinations_searched = 0

        self.pareto_max_candidates = 16

        self.ifm_stream_npu_blocks = set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
        )

    num_pareto_metrics = 4
    view_values = ",".join(["d"] * num_pareto_metrics)
    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]

    def pareto_metric(self, candidate):
        strat, strat_set = candidate
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
        bws = strat.bws + strat_set.bws
        last_block_height = 0
        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
            last_block_height = strat.block_configs[-1][0]

        return (
            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
            strat_set.max_sram_used,
            strat.sram_used,
            last_block_height,
        )

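    # Reduce a candidate list to its pareto frontier: drop candidates that exceed the SRAM
    # limit or are dominated on every metric by another candidate, and cap the result at
    # pareto_max_candidates entries.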
    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):

        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]

        if len(candidates) <= 1:
            return candidates
        assert remove_equally_good_candidates
        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
        ids = np.arange(len(candidates), dtype=np.int32)
        for idx, cand in enumerate(candidates):
            pareto_vals[idx] = self.pareto_metric(cand)

        sort_order = np.argsort(
            pareto_vals.view(DynamicProgrammingScheduler.view_values),
            order=DynamicProgrammingScheduler.order_values,
            axis=0,
            kind="stable",
        ).flatten()
        pareto_vals = pareto_vals[sort_order]
        ids = ids[sort_order]

        pareto_frontier = []
        while len(ids) > 0:
            pareto_frontier.append(candidates[ids[0]])
            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
            ids = ids[not_dominated_by_first]
            pareto_vals = pareto_vals[not_dominated_by_first]

        if len(pareto_frontier) > self.pareto_max_candidates:
            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]

        return pareto_frontier

    def candidate_metric(self, candidate):
        strat, strat_set = candidate
        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
        bws = strat.bws + strat_set.bws
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]

        return (
            max_sram_used * self.max_sram_used_weight
            + np.tensordot(bws, self.bandwidth_weights, axes=3)
            + total_cycles * self.cycles_weight
        )

    def sort_by_candidate_metric(self, candidate_list):
        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
        return sorted_list

    def best_candidate(self, candidate_list):
        if len(candidate_list) == 0:
            return ABORT_SEARCH
        if len(candidate_list) == 1:
            return candidate_list[0]
        sorted_list = self.sort_by_candidate_metric(candidate_list)
        return sorted_list[0]

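    # Turn open (SchedulingStrategy.Unknown) strategies into completed strategies of the given
    # type, adding the SRAM passed in by the caller and subtracting any allowed IFM/OFM overlap.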
    def graduate_strat(self, strat_type, sram_used, old_strat_data):
        res = []
        for old_strat, old_strat_set in old_strat_data:
            if old_strat.sram_used + sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            if old_strat_set.max_sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            assert old_strat.strat == SchedulingStrategy.Unknown

            new_strat = old_strat.clone()
            new_strat.strat = strat_type
            new_strat.sram_used = old_strat.sram_used + sram_used

            if self.use_ifm_ofm_overlap:
                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
                    new_strat.strat, new_strat.passes, new_strat.block_configs
                )
                new_strat.sram_used -= overlap

            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
            res.append((empty_strategy, new_strat_set))
        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)

    def append_sram(self, sram_used, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            assert old_strat.sram_used == 0
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used

            res.append((new_strat, strat_set))
        return res

    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]

            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )

            res.append((new_strat, strat_set))
        return res

    def append_sram_pass_block_config_performance_metrics_rewrite_list(
        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
    ):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.passes = old_strat.passes + [new_pass]
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def pass_to_strat(self, strat_data):
        res = {}
        for strat in strat_data[1].strats.values():
            for ps in strat.passes:
                res[ps] = strat
        return res

    def compatible_strats(self, a, b):
        intersection = a.keys() & b.keys()
        for k in intersection:
            if a[k] != b[k]:
                return False
        return True

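    # Combine the strategy data of several independently searched pass lists, keeping only
    # combinations whose per-pass strategy choices agree with each other.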
    def collate_strats_for_passes(self, all_passes):
        if len(all_passes) == 0:
            return [(empty_strategy, StrategySet(dict()))]
        if len(all_passes) == 1:
            return all_passes[0]  # save some space in the common case
        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
        prev_combos = [dict()]
        for j, strand in enumerate(all_strands):
            new_combos = []
            for i, alt in enumerate(strand):
                for prev in prev_combos:
                    if self.compatible_strats(prev, alt):
                        cmb = dict(prev)
                        cmb.update(all_passes[j][i][1].strats)
                        new_combos.append(cmb)
            prev_combos = new_combos

        res = []
        for d in prev_combos:
            s = StrategySet(d)
            s.update_statistics()
            res.append((empty_strategy, s))
        return res

    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
        # get the rest of the predecessors
        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
        other_predecessor_data = self.search_pass_list(other_predecessors)

        # pred strat data has an incomplete strategy, which we need
        # to continue on, whereas the other ones have completed strategies.
        # we need to merge these, but keep the incomplete strategy too.

        res = []
        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
            all_strats = [
                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
                other_predecessor_data,  # this one is fine to use as-is
            ]
            collated_strat_data = self.collate_strats_for_passes(all_strats)
            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
            res.extend(strat_data)
        return res

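    # For each pass time slot, work out how much SRAM is consumed by tensors that are live
    # across the pass but are not among its direct inputs/outputs/intermediates; the search
    # does not see these, so they are added back in as a fixed cost per time slot.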
    def calc_non_local_mem_usage(self):
        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
        range_set = live_range.extract_live_ranges_from_passes(
            self.sg, self.mem_area, ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
        )
        range_dict = range_set.ranges

        # find which ranges overlap passes but aren't input/outputs of the passes.
        # these won't be counted by the dynamic programming search and must be counted in manually.
        end_pos = max(ps.time for ps in self.sg.passes) + 2
        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)

        for tens, rng in range_dict.items():
            storage_size = tens.storage_size()
            assert tens.mem_area == self.mem_area
            mem_usage[rng.start_time : rng.end_time] += storage_size

        for ps in self.sg.passes:
            local_mem_usage = 0
            for tens in ps.inputs + ps.outputs + ps.intermediates:
                if tens.mem_area != self.mem_area:
                    continue

                local_mem_usage += tens.storage_size()

            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage

        self.non_local_mem_usage = non_local_mem_usage

    def search(self):
        self.calc_non_local_mem_usage()
        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        strat_data = self.search_pass_list(starting_passes)

        _, best_set = self.best_candidate(strat_data)

        if self.verbose_pareto_frontier_schedules:
            print(
                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
                % (self.n_combinations_searched, len(strat_data))
            )
            for idx, (_, strat_set) in enumerate(strat_data):
                extra = ""
                if strat_set == best_set:
                    extra = "(Best candidate)"
                print("Candidate", idx, extra)
                memory_used = {MemArea.Sram: strat_set.max_sram_used}
                stats_writer.print_performance_metrics_for_strat(
                    self.arch,
                    "",
                    strat_set.cycles,
                    strat_set.macs,
                    strat_set.bws,
                    self.nng.batch_size,
                    memory_used,
                    self.sg.min_mem_usage,
                    len(self.sg.passes),
                    len(strat_set.strats),
                )

        return best_set

    def search_pass_list(self, pass_list):
        all_strats = []
        for ps in pass_list:
            strat = self.search_output(ps)
            all_strats.append(strat)
        strat_data = self.collate_strats_for_passes(all_strats)
        for strd in strat_data:
            for ps in pass_list:
                assert ps in strd[1].strats  # should have strategies for everything we asked to search
        return strat_data

    def search_predecessors(self, ps):

        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
        # we have strats for all passes

        pass_list = ps.dag_predecessors
        strat_data = self.search_pass_list(pass_list)

        return strat_data

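    # Memoized entry point of the search for a single pass: try weight streaming and
    # (optionally) IFM streaming, then keep only the pareto-optimal candidates.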
    @lru_cache(maxsize=None)
    def search_output(self, ps):

        assert ps in self.sg.passes
        candidate_list = []

        candidate_list.extend(self.search_weight_streaming_output(ps))

        if self.options.use_ifm_streaming:
            candidate_list.extend(self.search_ifm_streaming_output(ps))

        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)

        if not best:
            print(
                "Warning: Dynamic programming search algorithm failed for pass %s, invoking fallback strategy"
                % (ps.name,)
            )
            return self.search_predecessors(ps)

        return best

    def search_ifm_streaming_output(self, ps):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        strat_data = self.search_ifm_streaming_body(ps, False)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        ifm_input_search_results = self.search_ifm_streaming_input(ps)
        res = []

        base_sram_used = 0
        for tens in ps.intermediates:
            if tens.mem_area == self.mem_area:
                if tens.purpose == TensorPurpose.Weights:
                    base_sram_used = tens.storage_size(self.arch.weight_estimation_scaling)
                else:
                    base_sram_used += tens.storage_size()

        all_block_configs = self.get_block_configs(ps)
        for block_config in all_block_configs:
            all_strats = []

            if self.use_cascading:
                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))

            all_strats.extend(ifm_input_search_results)

            rewrite_list = []
            sram_used = base_sram_used

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch,
                ps,
                block_config,
                rewrite_list=rewrite_list,
                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def avoid_for_cascading(self, pred_candidate):
        for op in pred_candidate.ops:
            if (
                op.memory_function == Op.ConcatSliceWrite
                and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area
            ):
                # For SRAM spilling, concat op is avoided as predecessor
                return True
            if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
                # The op has consumers in other subgraphs
                return True
        return False

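    # Try to cascade this pass with a suitable NPU predecessor that produces its IFM, turning
    # the IFM into a rolling buffer so that only part of it needs to live in SRAM at a time.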
    def search_ifm_streaming_partial(self, ps, block_config):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH

        if len(ps.inputs) < 1:
            return ABORT_SEARCH

        ifm_tensor = ps.ifm_tensor

        if ifm_tensor is None:
            return ABORT_SEARCH
        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
            return ABORT_SEARCH
        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
            return ABORT_SEARCH

        pred_pass_list = []
        for pred_candidate in ps.dag_predecessors:
            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                # we found a predecessor that produces this IFM tensor
                if not ifm_tensor.needs_linear_format:
                    # and NHCWB16 can be used
                    if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                        # and it only has one successor, namely us
                        if pred_candidate.placement == PassPlacement.Npu:
                            if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                                # and it is on the Npu
                                if not self.avoid_for_cascading(pred_candidate):
                                    # and fusable - it's a candidate
                                    pred_pass_list.append(pred_candidate)

        if not pred_pass_list:
            return ABORT_SEARCH

        all_candidates = []
        for pred_pass in pred_pass_list:
            # recurse into the next pass
            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.arch.is_spilling_enabled())

            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
            for strat_opt in strat_data:

                pred_pass_block_config = strat_opt[0].block_configs[-1]
                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
                )
                if rolling_buffer_dims is None:
                    continue  # this does not pack properly, skip it.

                sram_used = 0
                for tens in ps.inputs:
                    if tens != ifm_tensor:
                        if tens.mem_area == self.mem_area:
                            sram_used += tens.storage_size()

                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims

                rewrite_list = [
                    (
                        SchedulerRewrite.ChangeTensorSubPurpose,
                        ifm_tensor,
                        TensorSubPurpose.RollingBufferY,
                        rolling_buffer_y,
                        None,
                        ps,
                    )
                ]
                sram_used += ifm_tensor.storage_size_for_sub_purpose(
                    self.arch, TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
                )

                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))

        self.n_combinations_searched += len(all_candidates)
        return all_candidates

    def get_block_configs(self, ps):
        if ps.placement != PassPlacement.Npu:
            return [(1, 1, 1, 1)]  # default

        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)

        # Take a limited number of the largest blocks
        if self.arch.block_config_limit > 0:
            # Sort by block area, followed by depth
            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
            bound = min(len(block_configs), self.arch.block_config_limit)
            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
            tmp = block_configs[:bound]
            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
            block_configs = tmp

        return block_configs

    def search_ifm_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def search_weight_streaming_output(self, ps):
        strat_data = self.search_weight_streaming_body(ps)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)

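    # Cost weight streaming for this pass: weight tensors are given the DoubleBuffer sub purpose,
    # sized from the block config depth, so that weight fetches can overlap with compute.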
    @lru_cache(maxsize=None)
    def search_weight_streaming_body(self, ps):

        strat_data = self.search_weight_streaming_input(ps)

        res = []

        all_block_configs = self.get_block_configs(ps)

        for block_config in all_block_configs:

            sram_used = 0
            rewrite_list = []

            for tens in ps.intermediates:
                if tens.mem_area == self.mem_area:
                    if tens.purpose == TensorPurpose.Weights:
                        sram_used += tens.storage_size_for_sub_purpose(
                            self.arch, TensorSubPurpose.DoubleBuffer, block_config[3]
                        )
                        rewrite_list.append(
                            (
                                SchedulerRewrite.ChangeTensorSubPurpose,
                                tens,
                                TensorSubPurpose.DoubleBuffer,
                                block_config[3],
                                None,
                                ps,
                            )
                        )
                    else:
                        sram_used += tens.storage_size()

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch, ps, block_config, rewrite_list=rewrite_list
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_weight_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

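    # Materialize the chosen strategy set: apply the tensor rewrites, group passes into
    # CascadedPass objects, assign block configs and shared buffers, and decide where NHCWB16
    # and fast storage can be used between cascaded passes.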
    def apply_result(self, strat_set, arch):
        pass_to_cascaded_pass = dict()
        for _, strat in strat_set.strats.items():
            # rewrite the tensors that need this first. e.g. make rolling buffers
            inputs = []
            intermediates = []
            outputs = []

            for ps in strat.passes:
                inputs += ps.inputs
                intermediates += ps.intermediates
                outputs += ps.outputs

            for tens in set(inputs) & set(outputs):
                # tensors that are in both sets are intermediates

                # find pass with input/output tensor, and check if they are both placed on NPU
                input_placement = None
                output_placement = None
                for ps in strat.passes:
                    if tens in ps.inputs:
                        input_placement = ps.placement
                    if tens in ps.outputs:
                        output_placement = ps.placement
                if input_placement == output_placement == PassPlacement.Npu:
                    tens.set_format(TensorFormat.NHCWB16, arch)

                intermediates.append(tens)
                inputs.remove(tens)
                outputs.remove(tens)

            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                    tens.mem_area = self.arch.fast_storage_mem_area
                    tens.mem_type = MemType.Scratch_fast
                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                else:
                    assert 0, "unknown rewrite_op " + str(rewrite_op)

            is_element_wise = True
            for ps in strat.passes:
                assert ps.placement == strat.passes[0].placement
                if not ps.is_element_wise:
                    is_element_wise = False
                    break

            cascaded_pass = CascadedPass(
                strat.passes[0].name,
                strat.strat,
                inputs,
                intermediates,
                outputs,
                strat.passes,
                strat.passes[0].placement,
                is_element_wise,
            )
            assert strat.sram_used >= 0
            cascaded_pass.sram_used = strat.sram_used

            for idx, ps in enumerate(strat.passes):
                assert ps not in pass_to_cascaded_pass
                pass_to_cascaded_pass[ps] = cascaded_pass
                ps.cascade = cascaded_pass
                ps.block_config = strat.block_configs[idx]

                if ps.placement == PassPlacement.Npu:
                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
                        self.arch, ps, ps.block_config
                    )
                    assert ps.shared_buffer is not None

                sram_used = max(self.non_local_mem_usage[ps.time], 0)
                for op in ps.ops:
                    subgraph = op.attrs.get("subgraph")
                    if subgraph:
                        subgraph.base_sram_used = sram_used

        # all passes should have a cascaded pass now
        if len(pass_to_cascaded_pass) != len(self.sg.passes):
            print(
                "mismatch: we have %d passes, but only %d have cascaded passes associated"
                % (len(self.sg.passes), len(pass_to_cascaded_pass))
            )
            for ps in self.sg.passes:
                if ps not in pass_to_cascaded_pass:
                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))

            assert len(pass_to_cascaded_pass) == len(self.sg.passes)

        cascaded_passes = []
        if self.sg.placement == PassPlacement.Cpu:
            # Retain the pass order for CPU subgraph
            cascaded_passes = [ps.cascade for ps in self.sg.passes]
        else:
            # we have all the passes, but we need to put them in order and build predecessor/successor links.
            visit_pass_set = set()

            def visit_pass(ps):
                if ps in visit_pass_set:
                    return
                visit_pass_set.add(ps)

                cps = ps.cascade
                dont_traverse = set(cps.passes)

                for ps in cps.passes:
                    for pred in ps.predecessors:
                        if pred in dont_traverse:
                            continue
                        visit_pass(pred)

                cascaded_passes.append(cps)

            starting_passes = [ps for ps in self.sg.passes if not ps.successors]
            for ps in starting_passes:
                visit_pass(ps)

        # reorder so startup init cascaded passes come first
        def is_startup_cascaded_pass(cps):
            if not cps.passes:
                return False
            return cps.placement == PassPlacement.StartupInit

        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
        ]

        self.sg.cascaded_passes = cascaded_passes
        self.sg.build_cascaded_pass_links()

        # Check if NHCWB16 and/or fast storage can be used in between cascaded passes
        # (NHCWB16 within cascaded passes has been handled earlier in this function)
        if self.sg.placement == PassPlacement.Npu:
            # Dictionary tensor -> list of ops, containing feature maps that can be attempted
            # to be moved to fast storage
            fast_storage_tensor_rewrites = {}
            last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
            # Memory only passes have no primary_op, so use the last op in ops
            if last_op_in_subgraph is None:
                last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].ops[-1]
            for ps in self.sg.cascaded_passes:
                if ps.placement != PassPlacement.Npu:
                    continue
                for output in ps.outputs:
                    if output.purpose != TensorPurpose.FeatureMap:
                        continue

                    use_NHCWB16 = not output.needs_linear_format
                    use_fast_storage = True
                    rewrites = []
                    for op in output.consumer_list:
                        if op is None:
                            use_NHCWB16 = False
                            use_fast_storage = False
                            continue
                        if op.type == Op.ReduceSum and output.dtype == DataType.int32:
                            use_NHCWB16 = False
                        elif op.type == Op.Reshape:
                            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
                            # consumers do not also need to perform a reshape or if the OFM is going to
                            # be processed by CPU operations. No-op reshape consumers with empty lists
                            # (those that have no consumers, or null-consumers used as list terminators)
                            # must use normal NHWC output.
                            def incompatible_consumers(oper):
                                if oper and oper.type == Op.Reshape:
                                    for consumer in oper.outputs[0].consumer_list:
                                        yield from incompatible_consumers(consumer)
                                yield not oper or not oper.run_on_npu or oper is last_op_in_subgraph

                            if not any(incompatible_consumers(op)):

                                def get_rewrites(oper):
                                    if oper and oper.type == Op.Reshape:
                                        for consumer in oper.outputs[0].consumer_list:
                                            yield from get_rewrites(consumer)
                                        yield oper

                                rewrites.extend(get_rewrites(op))
                                # Detect no-op reshapes by comparing their full input and output tensor shapes.
                                inshape = op.ifm_shapes[0]
                                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
                                use_NHCWB16 &= compatible_shape and all(compatible_shape)
                        else:
                            use_NHCWB16 = False
                            use_fast_storage = False
                        use_NHCWB16 &= op.run_on_npu
                        use_fast_storage &= op.run_on_npu

                    if use_fast_storage:
                        fast_storage_tensor_rewrites[output] = rewrites
                    if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
                        output.set_format(TensorFormat.NHCWB16, arch)
                        for rewrite_op in rewrites:
                            rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
            if arch.is_spilling_enabled():
                # Remember feature maps that can be moved to fast storage for later use
                # in use_fast_storage_for_feature_maps
                self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites


def move_scales_to_fast_storage(nng, arch):
    for sg in nng.subgraphs:
        # IFM streamed ops read bias tensors several times, move these to fast storage
        for cp in sg.cascaded_passes:
            if cp.strategy == SchedulingStrategy.IfmStream:
                # Calculate SRAM usage
                new_size = 0
                all_tens = []
                for ps in cp.passes:
                    pass_tens = np.array([ps.ifm_tensor, ps.ifm2_tensor, ps.ofm_tensor, ps.weight_tensor])
                    pass_tens = np.append(pass_tens, ps.intermediates)
                    for tens in pass_tens:
                        if tens and tens.mem_area == MemArea.Sram and tens not in all_tens:
                            all_tens.append(tens)
                            new_size += tens.storage_size()

                cp.sram_used = new_size

                for ps in cp.passes:
                    if ps.scale_tensor:
                        tens = ps.scale_tensor

                        # Find op using scale tensor
                        op = next((op for op in ps.ops if tens in op.inputs), None)
                        assert op

                        # Create fast storage tensor
                        new_tens = tens.clone_into_fast_storage(arch)
                        new_tens.consumer_list = tens.consumer_list.copy()
                        new_tens.purpose = TensorPurpose.FSBias
                        new_tens_size = new_tens.storage_size()

                        if (cp.sram_used + new_tens_size) <= arch.sram_size:
                            # Create DMA cmd
                            dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma")
                            dma_cmd.inputs = [tens]
                            dma_cmd.set_output_tensor(new_tens)
                            dma_cmd.attrs["source"] = tens.mem_area
                            dma_cmd.attrs["destination"] = new_tens.mem_area
                            dma_cmd.run_on_npu = True

                            tens.consumer_list.clear()
                            tens.consumer_list.append(dma_cmd)

                            # Replace tensor and op
                            idx = op.inputs.index(tens)
                            op.inputs[idx] = new_tens

                            ps.ops.insert(0, dma_cmd)
                            ps.scale_tensor = new_tens
                            ps.intermediates.append(new_tens)
                            ps.cascade.intermediates.append(new_tens)

                            cp.sram_used += new_tens_size


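# Entry point: run the dynamic programming scheduler on every subgraph and apply the best
# strategy set that was found.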
def schedule_passes(nng, arch, options: SchedulerOptions):

    for sg in nng.subgraphs:
        sg.base_sram_used = 0

    for sg in nng.subgraphs:
        # re-entering the same nodes from different contexts requires us to
        # build a simplified directed acyclic (DAG) version of the graph to
        # use for traversal, rather than using a visit dictionary. this avoids
        # recursing infinitely due to loops.
        sg.build_pass_dag_predecessors()

        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)

        strat_set = dps.search()

        dps.apply_result(strat_set, arch)

        if options.verbose_schedule:
            sg.print_cascaded_passes()


def _calc_tens_to_cps(sg, tensor_rewrites):
    # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption.
    # Returns dictionary tensor -> list of cascaded passes
    # Note: if cascaded passes are A, B, C, D, and a tensor is output
    # of A and input to D, then it also consumes SRAM in passes B and C.
    if "tens_to_cps" in sg.scheduling_info:
        return sg.scheduling_info["tens_to_cps"]
    # Determine life-time of tensors
    min_index = {}
    max_index = {}
    index = 0
    cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu]
    for cps in cps_list:
        for tens in cps.inputs + cps.outputs:
            if tens in tensor_rewrites:
                min_index[tens] = min(index, min_index.get(tens, len(cps_list)))
                max_index[tens] = index
        index += 1
    # Convert to affected cps-es
    tens_to_cps = {}
    for tens in min_index:
        tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1]
    sg.scheduling_info["tens_to_cps"] = tens_to_cps
    return tens_to_cps


def use_fast_storage_for_feature_maps(sg, sram_limit, arch):
    # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes.
    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
    # Sort tensors first on life-time (smallest first), then on size (biggest first)
    tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps])
    for _, _, _, tens in tens_list:
        cps_list = tens_to_cps[tens]
        if len(cps_list) < 1:
            continue
        sz = tens.storage_size()
        fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list])
        if fits_in_fast_storage:
            tens.mem_area = arch.fast_storage_mem_area
            tens.mem_type = MemType.Scratch_fast
            tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
            assert tens in tensor_rewrites
            # Also rewrite reshapes
            for rewrite_op in tensor_rewrites[tens]:
                tens2 = rewrite_op.outputs[0]
                tens2.mem_area = arch.fast_storage_mem_area
                tens2.mem_type = MemType.Scratch_fast
                tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
            for cps in cps_list:
                cps.sram_used += sz


def undo_use_fast_storage(sg, arch):
    # Undoes the effects of a previous call to use_fast_storage_for_feature_maps
    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
    mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap]
    for tens, cps_list in tens_to_cps.items():
        if tens.mem_type == MemType.Scratch_fast:
            sz = tens.storage_size()
            tens.mem_area = mem_area
            tens.mem_type = MemType.Scratch
            # Also undo reshapes
            for rewrite_op in tensor_rewrites[tens]:
                tens2 = rewrite_op.outputs[0]
                tens2.mem_area = mem_area
                tens2.mem_type = MemType.Scratch
            for cps in cps_list:
                cps.sram_used -= sz