# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
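#
# Overview: schedule_passes() runs a DynamicProgrammingScheduler per subgraph. The search starts from the
# output passes and works backwards, costing weight-streaming and (optionally cascaded) IFM-streaming
# strategies for each pass, pruning candidates to a Pareto frontier of weighted bandwidth/cycle cost and
# SRAM use, and finally applying the best StrategySet as CascadedPasses on the subgraph.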
import copy
import enum
from functools import lru_cache

import numpy as np

from . import live_range
from . import npu_performance
from . import stats_writer
from .data_type import DataType
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
from .nn_graph import CascadedPass
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .nn_graph import SchedulingStrategy
from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
from .numeric_util import full_shape
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
from .tensor import MemArea
from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


class ParetoMetric(enum.Enum):
    BwCycMem = 1
    BwCycMemBlkH = 2

    def __str__(self):
        return self.name


class SchedulerOptions:
    def __init__(
        self,
        use_cascading=True,
        verbose_schedule=False,
        verbose_pareto_frontier_schedules=False,
        use_ifm_streaming=True,
        pareto_metric=ParetoMetric.BwCycMem,
        use_nhcwb16_between_cascaded_passes=True,
        keep_scale_placement=False,
    ):
        self.use_cascading = use_cascading
        self.verbose_schedule = verbose_schedule
        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
        self.use_ifm_streaming = use_ifm_streaming
        self.pareto_metric = pareto_metric
        self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes
        self.keep_scale_placement = keep_scale_placement

    def __str__(self):
        return type(self).__name__ + ": " + str(self.__dict__)

    __repr__ = __str__


class Strategy:
    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"

    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
        self.strat = strat
        self.param = param
        self.passes = passes
        self.block_configs = block_configs
        self.rewrite_list = (
            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
        )
        self.bws = bws
        self.macs = macs
        self.cycles = cycles
        self.sram_used = sram_used

    def __eq__(self, other):
        if self.strat != other.strat:
            return False
        if self.param != other.param:
            return False
        if self.block_configs != other.block_configs:
            return False
        if self.passes != other.passes:
            return False
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.sram_used != other.sram_used:
            return False
        return True

    def empty(self):
        return not self.passes

    def key(self):
        return self.passes[-1]

    def clone(self):
        return Strategy(
            self.strat,
            self.param,
            self.passes,
            self.block_configs,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    def __str__(self):
        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
            self.strat,
            self.passes,
            self.rewrite_list,
            self.bws,
            self.macs,
            self.cycles,
            self.sram_used,
        )

    __repr__ = __str__


class StrategySet:
    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"

    def __init__(self, strats=None):
        if strats is None:
            strats = dict()
        self.strats = strats  # final pass in packed pass -> Strategy
        self.bws, self.macs, self.cycles = make_metrics_arrays()
        self.max_sram_used = 0
        self.total_sram_used = 0

    def update_statistics(self):
        self.bws = make_bandwidth_array()
        self.max_sram_used = 0
        for ps, strat in self.strats.items():
            self.bws += strat.bws
            self.macs += strat.macs
            self.cycles += strat.cycles
            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
            self.total_sram_used += strat.sram_used

    def clone_add_strategy(self, new_strat):
        key = new_strat.key()
        if key in self.strats:
            assert new_strat == self.strats[key]
            return self
        else:
            new_strats = dict(self.strats)
            new_strats[key] = new_strat
            new_set = StrategySet(new_strats)
            new_set.bws = self.bws + new_strat.bws
            new_set.macs = self.macs + new_strat.macs
            new_set.cycles = self.cycles + new_strat.cycles
            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
            return new_set

    def __eq__(self, other):
        if (self.bws != other.bws).any():
            return False
        if (self.macs != other.macs).any():
            return False
        if (self.cycles != other.cycles).any():
            return False
        if self.max_sram_used != other.max_sram_used:
            return False
        if self.total_sram_used != other.total_sram_used:
            return False
        if self.strats != other.strats:
            return False
        return True

    def __str__(self):
        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
            self.max_sram_used,
            list(ps.name for ps in self.strats),
        )

    __repr__ = __str__


empty_strategy = Strategy(
    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
)
INFINITY = 1e30

ABORT_SEARCH = []


def flatten_list_of_lists(lstlst):
    lst = []
    for v in lstlst:
        lst.extend(v)
    return lst


class DynamicProgrammingScheduler:
    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sram_limit = sram_limit
        self.options = copy.copy(options)
        self.use_cascading = options.use_cascading

        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM
        else:
            self.use_ifm_ofm_overlap = True

        self.verbose_schedule = options.verbose_schedule
        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
        self.mem_area = MemArea.Sram

        self.bandwidth_weights = arch.bandwidth_weights
        self.cycles_weight = arch.cycles_weight
        self.max_sram_used_weight = arch.max_sram_used_weight

        self.n_combinations_searched = 0

        self.feature_maps_not_in_fast_storage = (
            arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area
        )

        self.pareto_max_candidates = 16

        self.ifm_stream_npu_blocks = set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
        )

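    # The four Pareto metrics are ordered by priority: weighted bandwidth+cycle cost, max SRAM of the
    # strategy set, SRAM of the open strategy, and (optionally) last block height. view_values/order_values
    # let filter_pareto_frontier() view the metric rows as a numpy structured array so that np.argsort can
    # sort all candidates lexicographically on these fields (f0..f3).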
    num_pareto_metrics = 4
    view_values = ",".join(["d"] * num_pareto_metrics)
    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]

    def pareto_metric(self, candidate):
        strat, strat_set = candidate
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
        bws = strat.bws + strat_set.bws
        last_block_height = 0
        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
            last_block_height = strat.block_configs[-1][0]

        return (
            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
            strat_set.max_sram_used,
            strat.sram_used,
            last_block_height,
        )

    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):

        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]

        if len(candidates) <= 1:
            return candidates
        assert remove_equally_good_candidates
        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
        ids = np.arange(len(candidates), dtype=np.int32)
        for idx, cand in enumerate(candidates):
            pareto_vals[idx] = self.pareto_metric(cand)

        sort_order = np.argsort(
            pareto_vals.view(DynamicProgrammingScheduler.view_values),
            order=DynamicProgrammingScheduler.order_values,
            axis=0,
            kind="stable",
        ).flatten()
        pareto_vals = pareto_vals[sort_order]
        ids = ids[sort_order]

        # Repeatedly take the best remaining candidate, then keep only the candidates that beat it on at
        # least one metric (i.e. are not dominated by it). Equally good candidates are dropped as well.
        pareto_frontier = []
        while len(ids) > 0:
            pareto_frontier.append(candidates[ids[0]])
            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
            ids = ids[not_dominated_by_first]
            pareto_vals = pareto_vals[not_dominated_by_first]

        if len(pareto_frontier) > self.pareto_max_candidates:
            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]

        return pareto_frontier

    def candidate_metric(self, candidate):
        strat, strat_set = candidate
        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
        bws = strat.bws + strat_set.bws
        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]

        return (
            max_sram_used * self.max_sram_used_weight
            + np.tensordot(bws, self.bandwidth_weights, axes=3)
            + total_cycles * self.cycles_weight
        )

    def sort_by_candidate_metric(self, candidate_list):
        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
        return sorted_list

    def best_candidate(self, candidate_list):
        if len(candidate_list) == 0:
            return ABORT_SEARCH
        if len(candidate_list) == 1:
            return candidate_list[0]
        sorted_list = self.sort_by_candidate_metric(candidate_list)
        return sorted_list[0]

    def graduate_strat(self, strat_type, sram_used, old_strat_data):
        res = []
        for old_strat, old_strat_set in old_strat_data:
            if old_strat.sram_used + sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            if old_strat_set.max_sram_used > self.sram_limit:
                continue  # This strategy is bad, drop it
            assert old_strat.strat == SchedulingStrategy.Unknown

            new_strat = old_strat.clone()
            new_strat.strat = strat_type
            new_strat.sram_used = old_strat.sram_used + sram_used

            if self.use_ifm_ofm_overlap:
                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
                    new_strat.strat, new_strat.passes, new_strat.block_configs
                )
                new_strat.sram_used -= overlap

            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
            res.append((empty_strategy, new_strat_set))
        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)

    def append_sram(self, sram_used, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            assert old_strat.sram_used == 0
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used

            res.append((new_strat, strat_set))
        return res

    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]

            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )

            res.append((new_strat, strat_set))
        return res

    def append_sram_pass_block_config_performance_metrics_rewrite_list(
        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
    ):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            bws, macs, cycles = metrics[:3]
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.block_configs = old_strat.block_configs + [block_config]
            new_strat.bws = old_strat.bws + bws
            new_strat.macs = old_strat.macs + macs
            new_strat.cycles = old_strat.cycles + cycles
            new_strat.passes = old_strat.passes + [new_pass]
            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
            )
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
        res = []
        for old_strat, strat_set in old_strat_data:
            assert old_strat.strat == SchedulingStrategy.Unknown
            new_strat = old_strat.clone()
            new_strat.sram_used = old_strat.sram_used + sram_used
            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
            res.append((new_strat, strat_set))
        return res

    def pass_to_strat(self, strat_data):
        res = {}
        for strat in strat_data[1].strats.values():
            for ps in strat.passes:
                res[ps] = strat
        return res

    def compatible_strats(self, a, b):
        intersection = a.keys() & b.keys()
        for k in intersection:
            if a[k] != b[k]:
                return False
        return True

    def collate_strats_for_passes(self, all_passes):
        if len(all_passes) == 0:
            return [(empty_strategy, StrategySet(dict()))]
        if len(all_passes) == 1:
            return all_passes[0]  # save some space in the common case
        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
        prev_combos = [dict()]
        for j, strand in enumerate(all_strands):
            new_combos = []
            for i, alt in enumerate(strand):
                for prev in prev_combos:
                    if self.compatible_strats(prev, alt):
                        cmb = dict(prev)
                        cmb.update(all_passes[j][i][1].strats)
                        new_combos.append(cmb)
            prev_combos = new_combos

        res = []
        for d in prev_combos:
            s = StrategySet(d)
            s.update_statistics()
            res.append((empty_strategy, s))
        return res

    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
        # get the rest of the predecessors
        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
        other_predecessor_data = self.search_pass_list(other_predecessors)

        # pred strat data has an incomplete strategy, which we need
        # to continue on, whereas the other ones have completed strategies.
        # we need to merge these, but keep the incomplete strategy too.

        res = []
        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
            all_strats = [
                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
                other_predecessor_data,  # this one is fine to use as-is
            ]
            collated_strat_data = self.collate_strats_for_passes(all_strats)
            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
            res.extend(strat_data)
        return res

    def calc_non_local_mem_usage(self):
        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
        range_set = live_range.extract_live_ranges_from_passes(
            self.sg, self.mem_area, ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
        )
        range_dict = range_set.ranges

        # find which ranges overlap passes but aren't input/outputs of the passes.
        # these won't be counted by the dynamic programming search and must be counted in manually.
        end_pos = max(ps.time for ps in self.sg.passes) + 2
        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)

        for tens, rng in range_dict.items():
            storage_size = tens.storage_size()
            assert tens.mem_area == self.mem_area
            mem_usage[rng.start_time : rng.end_time] += storage_size

        for ps in self.sg.passes:
            local_mem_usage = 0
            for tens in ps.inputs + ps.outputs + ps.intermediates:
                if tens.mem_area != self.mem_area:
                    continue

                local_mem_usage += tens.storage_size()

            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage

        self.non_local_mem_usage = non_local_mem_usage

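    # The search starts from the passes that produce the subgraph outputs (those without successors) and
    # recurses backwards through dag_predecessors. search_output and search_ifm_streaming_body are memoized
    # with lru_cache so shared predecessors are only explored once.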
    def search(self):
        self.calc_non_local_mem_usage()
        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
        strat_data = self.search_pass_list(starting_passes)

        _, best_set = self.best_candidate(strat_data)

        if self.verbose_pareto_frontier_schedules:
            print(
                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
                % (self.n_combinations_searched, len(strat_data))
            )
            for idx, (_, strat_set) in enumerate(strat_data):
                extra = ""
                if strat_set == best_set:
                    extra = "(Best candidate)"
                print("Candidate", idx, extra)
                memory_used = {MemArea.Sram: strat_set.max_sram_used}
                stats_writer.print_performance_metrics_for_strat(
                    self.arch,
                    "",
                    strat_set.cycles,
                    strat_set.macs,
                    strat_set.bws,
                    self.nng.batch_size,
                    memory_used,
                    len(self.sg.passes),
                    len(strat_set.strats),
                )

        return best_set

    def search_pass_list(self, pass_list):
        all_strats = []
        for ps in pass_list:
            strat = self.search_output(ps)
            all_strats.append(strat)
        strat_data = self.collate_strats_for_passes(all_strats)
        for strd in strat_data:
            for ps in pass_list:
                assert ps in strd[1].strats  # should have strategies for everything we asked to search
        return strat_data

    def search_predecessors(self, ps):

        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
        # we have strats for all passes

        pass_list = ps.dag_predecessors
        strat_data = self.search_pass_list(pass_list)

        return strat_data

    @lru_cache(maxsize=None)
    def search_output(self, ps):

        assert ps in self.sg.passes
        candidate_list = []

        candidate_list.extend(self.search_weight_streaming_output(ps))

        if self.options.use_ifm_streaming:
            candidate_list.extend(self.search_ifm_streaming_output(ps))

        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)

        if not best:
            print(
                "Warning: Dynamic programming search algorithm failed for pass %s, invoking fallback strategy"
                % (ps.name,)
            )
            return self.search_predecessors(ps)

        return best

    def search_ifm_streaming_output(self, ps):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        strat_data = self.search_ifm_streaming_body(ps, False)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH
        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
            return ABORT_SEARCH
        ifm_input_search_results = self.search_ifm_streaming_input(ps)
        res = []

        base_sram_used = 0
        for tens in ps.intermediates:
            if tens.mem_area == self.mem_area:
                if tens.purpose == TensorPurpose.Weights:
                    base_sram_used = tens.storage_size(self.arch.weight_estimation_scaling)
                else:
                    base_sram_used += tens.storage_size()

        all_block_configs = self.get_block_configs(ps)
        for block_config in all_block_configs:
            all_strats = []

            if self.use_cascading:
                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))

            all_strats.extend(ifm_input_search_results)

            rewrite_list = []
            sram_used = base_sram_used

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch,
                ps,
                block_config,
                rewrite_list=rewrite_list,
                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def avoid_for_cascading(self, pred_candidate):
        for op in pred_candidate.ops:
            if (
                op.type == Op.ConcatSliceWrite
                and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area
            ):
                # For SRAM spilling, concat op is avoided as predecessor
                return True
            if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
                # The op has consumers in other subgraphs
                return True
        return False

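    # Try to extend an IFM-streaming cascade: if a suitable NPU predecessor produces this pass's IFM, the
    # IFM can be rewritten into a rolling buffer (TensorSubPurpose.RollingBufferY) so that only a few rows
    # of it need to be resident in SRAM at a time.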
    def search_ifm_streaming_partial(self, ps, block_config):
        if ps.placement != PassPlacement.Npu:
            return ABORT_SEARCH

        if len(ps.inputs) < 1:
            return ABORT_SEARCH

        ifm_tensor = ps.ifm_tensor

        if ifm_tensor is None:
            return ABORT_SEARCH
        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
            return ABORT_SEARCH
        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
            return ABORT_SEARCH

        pred_pass_list = []
        for pred_candidate in ps.dag_predecessors:
            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
                # we found a predecessor that produces this IFM tensor
                if not ifm_tensor.avoid_NHCWB16:
                    # and NHCWB16 format is not to be avoided
                    if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
                        # and it only has one successor, namely us
                        if pred_candidate.placement == PassPlacement.Npu:
                            if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                                # and it is on the Npu
                                if not self.avoid_for_cascading(pred_candidate):
                                    # and fusable - it's a candidate
                                    pred_pass_list.append(pred_candidate)

        if not pred_pass_list:
            return ABORT_SEARCH

        all_candidates = []
        for pred_pass in pred_pass_list:
            # recurse into the next pass
            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage)

            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
            for strat_opt in strat_data:

                pred_pass_block_config = strat_opt[0].block_configs[-1]
                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
                )
                if rolling_buffer_dims is None:
                    continue  # this does not pack properly, skip it.

                sram_used = 0
                for tens in ps.inputs:
                    if tens != ifm_tensor:
                        if tens.mem_area == self.mem_area:
                            sram_used += tens.storage_size()

                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims

                rewrite_list = [
                    (
                        SchedulerRewrite.ChangeTensorSubPurpose,
                        ifm_tensor,
                        TensorSubPurpose.RollingBufferY,
                        rolling_buffer_y,
                        None,
                        ps,
                    )
                ]
                sram_used += ifm_tensor.storage_size_for_sub_purpose(
                    self.arch, TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
                )

                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))

        self.n_combinations_searched += len(all_candidates)
        return all_candidates

    def get_block_configs(self, ps):
        if ps.placement != PassPlacement.Npu:
            return [(1, 1, 1, 1)]  # default

        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)

        # Take a limited number of the largest blocks
        if self.arch.block_config_limit > 0:
            # Sort by block area, followed by depth
            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
            bound = min(len(block_configs), self.arch.block_config_limit)
            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
            tmp = block_configs[:bound]
            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
            block_configs = tmp

        return block_configs

    def search_ifm_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def search_weight_streaming_output(self, ps):
        strat_data = self.search_weight_streaming_body(ps)

        sram_used = self.non_local_mem_usage[ps.time]
        for tens in ps.outputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)

    @lru_cache(maxsize=None)
    def search_weight_streaming_body(self, ps):

        strat_data = self.search_weight_streaming_input(ps)

        res = []

        all_block_configs = self.get_block_configs(ps)

        for block_config in all_block_configs:

            sram_used = 0
            rewrite_list = []

            for tens in ps.intermediates:
                if tens.mem_area == self.mem_area:
                    if tens.purpose == TensorPurpose.Weights:
                        sram_used += tens.storage_size_for_sub_purpose(
                            self.arch, TensorSubPurpose.DoubleBuffer, block_config[3]
                        )
                        rewrite_list.append(
                            (
                                SchedulerRewrite.ChangeTensorSubPurpose,
                                tens,
                                TensorSubPurpose.DoubleBuffer,
                                block_config[3],
                                None,
                                ps,
                            )
                        )
                    else:
                        sram_used += tens.storage_size()

            metrics = npu_performance.performance_metrics_for_pass(
                self.arch, ps, block_config, rewrite_list=rewrite_list
            )

            res.extend(
                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
                )
            )

        self.n_combinations_searched += len(res)
        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
        return res

    def search_weight_streaming_input(self, ps):
        sram_used = 0
        for tens in ps.inputs:
            if tens.mem_area == self.mem_area:
                sram_used += tens.storage_size()

        return self.append_sram(sram_used, self.search_predecessors(ps))

    def apply_result(self, strat_set, arch):
        pass_to_cascaded_pass = dict()
        for _, strat in strat_set.strats.items():
            # rewrite the tensors that need this first. e.g. make rolling buffers
            inputs = []
            intermediates = []
            outputs = []

            for ps in strat.passes:
                inputs += ps.inputs
                intermediates += ps.intermediates
                outputs += ps.outputs

            for tens in set(inputs) & set(outputs):
                # tensors that are in both sets are intermediates

                # find pass with input/output tensor, and check if they are both placed on NPU
                input_placement = None
                output_placement = None
                for ps in strat.passes:
                    if tens in ps.inputs:
                        input_placement = ps.placement
                    if tens in ps.outputs:
                        output_placement = ps.placement
                if input_placement == output_placement == PassPlacement.Npu:
                    tens.set_format(TensorFormat.NHCWB16, arch)

                intermediates.append(tens)
                inputs.remove(tens)
                outputs.remove(tens)

            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                    tens.mem_area = self.arch.fast_storage_mem_area
                    tens.mem_type = MemType.Scratch_fast
                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                else:
                    assert 0, "unknown rewrite_op " + str(rewrite_op)

            is_element_wise = True
            for ps in strat.passes:
                assert ps.placement == strat.passes[0].placement
                if not ps.is_element_wise:
                    is_element_wise = False
                    break

            cascaded_pass = CascadedPass(
                strat.passes[0].name,
                strat.strat,
                inputs,
                intermediates,
                outputs,
                strat.passes,
                strat.passes[0].placement,
                is_element_wise,
            )
            assert strat.sram_used >= 0
            cascaded_pass.sram_used = strat.sram_used

            for idx, ps in enumerate(strat.passes):
                assert ps not in pass_to_cascaded_pass
                pass_to_cascaded_pass[ps] = cascaded_pass
                ps.cascade = cascaded_pass
                ps.block_config = strat.block_configs[idx]

                if ps.placement == PassPlacement.Npu:
                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
                        self.arch, ps, ps.block_config
                    )
                    assert ps.shared_buffer is not None

                sram_used = max(self.non_local_mem_usage[ps.time], 0)
                for op in ps.ops:
                    subgraph = op.attrs.get("subgraph")
                    if subgraph:
                        subgraph.base_sram_used = sram_used

        # all passes should have a cascaded pass now
        if len(pass_to_cascaded_pass) != len(self.sg.passes):
            print(
                "mismatch: we have %d passes, but only %d have cascaded passes associated"
                % (len(self.sg.passes), len(pass_to_cascaded_pass))
            )
            for ps in self.sg.passes:
                if ps not in pass_to_cascaded_pass:
                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))

            assert len(pass_to_cascaded_pass) == len(self.sg.passes)

        cascaded_passes = []
        if self.sg.placement == PassPlacement.Cpu:
            # Retain the pass order for CPU subgraph
            cascaded_passes = [ps.cascade for ps in self.sg.passes]
        else:
            # we have all the passes, but we need to put them in order and build predecessor/successor links.
            visit_pass_set = set()

            def visit_pass(ps):
                if ps in visit_pass_set:
                    return
                visit_pass_set.add(ps)

                cps = ps.cascade
                dont_traverse = set(cps.passes)

                for ps in cps.passes:
                    for pred in ps.predecessors:
                        if pred in dont_traverse:
                            continue
                        visit_pass(pred)

                cascaded_passes.append(cps)

            starting_passes = [ps for ps in self.sg.passes if not ps.successors]
            for ps in starting_passes:
                visit_pass(ps)

        # reorder so startup init cascaded passes come first
        def is_startup_cascaded_pass(cps):
            if not cps.passes:
                return False
            return cps.placement == PassPlacement.StartupInit

        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
        ]

        self.sg.cascaded_passes = cascaded_passes
        self.sg.build_cascaded_pass_links()

        # Check if NHCWB16 and/or fast storage can be used in between cascaded passes
        # (NHCWB16 within cascaded passes has been handled earlier in this function)
        if self.sg.placement == PassPlacement.Npu:
            # Dictionary tensor -> list of ops, containing feature maps that can be attempted
            # to be moved to fast storage
            fast_storage_tensor_rewrites = {}
            last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
            # Memory only passes have no primary_op, so use the last op in ops
            if last_op_in_subgraph is None:
                last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].ops[-1]
            for ps in self.sg.cascaded_passes:
                if ps.placement != PassPlacement.Npu:
                    continue
                for output in ps.outputs:
                    if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
                        continue

                    use_NHCWB16 = True
                    use_fast_storage = True
                    rewrites = []
                    for op in output.consumer_list:
                        if op is None:
                            use_NHCWB16 = False
                            use_fast_storage = False
                            continue
                        if op.type == Op.ReduceSum and output.dtype == DataType.int32:
                            use_NHCWB16 = False
                        elif op.type == Op.Reshape:
                            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
                            # consumers do not also need to perform a reshape or if the OFM is going to
                            # be processed by CPU operations. No-op reshape consumers with empty lists
                            # (those that have no consumers, or null-consumers used as list terminators)
                            # must use normal NHWC output.
                            def incompatible_consumers(oper):
                                if oper and oper.type == Op.Reshape:
                                    for consumer in oper.outputs[0].consumer_list:
                                        yield from incompatible_consumers(consumer)
                                yield not oper or not oper.run_on_npu or oper is last_op_in_subgraph

                            if not any(incompatible_consumers(op)):

                                def get_rewrites(oper):
                                    if oper and oper.type == Op.Reshape:
                                        for consumer in oper.outputs[0].consumer_list:
                                            yield from get_rewrites(consumer)
                                    yield oper

                                rewrites.extend(get_rewrites(op))
                                # Detect no-op reshapes by comparing their full input and output tensor shapes.
                                inshape = full_shape(4, op.inputs[0].shape, 1)
                                compatible_shape = [
                                    (inshape == full_shape(4, oper.outputs[0].shape, 1)) for oper in get_rewrites(op)
                                ]
                                use_NHCWB16 = compatible_shape and all(compatible_shape)
                        else:
                            use_NHCWB16 = False
                            use_fast_storage = False
                        use_NHCWB16 &= op.run_on_npu
                        use_fast_storage &= op.run_on_npu

                    if use_fast_storage:
                        fast_storage_tensor_rewrites[output] = rewrites
                    if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
                        output.set_format(TensorFormat.NHCWB16, arch)
                        for rewrite_op in rewrites:
                            rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
            if self.feature_maps_not_in_fast_storage:
                # Remember feature maps that can be moved to fast storage for later use
                # in use_fast_storage_for_feature_maps
                self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites

    def move_scales_to_fast_storage(self, sg, arch):
        # IFM streamed ops read bias tensors several times, move these to fast storage
        for cp in sg.cascaded_passes:
            if cp.strategy == SchedulingStrategy.IfmStream:
                for ps in cp.passes:
                    if ps.scale_tensor and (cp.sram_used + ps.scale_tensor.storage_size()) <= self.sram_limit:
                        tens = ps.scale_tensor

                        # Find op using scale tensor
                        op = next((op for op in ps.ops if tens in op.inputs), None)
                        assert op

                        # Create fast storage tensor
                        new_tens = tens.clone_into_fast_storage(arch)
                        new_tens.consumer_list = tens.consumer_list.copy()
                        new_tens.purpose = TensorPurpose.FSBias

                        # Create DMA cmd
                        dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma")
                        dma_cmd.inputs = [tens]
                        dma_cmd.set_output_tensor(new_tens)
                        dma_cmd.attrs["source"] = tens.mem_area
                        dma_cmd.attrs["destination"] = new_tens.mem_area
                        dma_cmd.run_on_npu = True

                        tens.consumer_list.clear()
                        tens.consumer_list.append(dma_cmd)

                        # Replace tensor and op
                        idx = op.inputs.index(tens)
                        op.inputs[idx] = new_tens

                        ps.ops.insert(0, dma_cmd)
                        ps.scale_tensor = new_tens
                        ps.intermediates.append(new_tens)
                        ps.cascade.intermediates.append(new_tens)

                        cp.sram_used += tens.storage_size()


def schedule_passes(nng, arch, options: SchedulerOptions):

    for sg in nng.subgraphs:
        sg.base_sram_used = 0

    for sg in nng.subgraphs:
        # re-entering the same nodes from different contexts requires us to
        # build a simplified directed acyclic (DAG) version of the graph to
        # use for traversal, rather than using a visit dictionary. this avoids
        # recursing infinitely due to loops.
        sg.build_pass_dag_predecessors()

        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)

        strat_set = dps.search()

        dps.apply_result(strat_set, arch)

        if not options.keep_scale_placement:
            dps.move_scales_to_fast_storage(sg, arch)

        if options.verbose_schedule:
            sg.print_cascaded_passes()


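# The helpers below operate on the "feature_map_rewrites" recorded in sg.scheduling_info by apply_result:
# use_fast_storage_for_feature_maps tries to move feature maps shared between cascaded passes into fast
# storage, and undo_use_fast_storage reverses that move.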
def _calc_tens_to_cps(sg, tensor_rewrites):
    # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption.
    # Returns dictionary tensor -> list of cascaded passes
    # Note: if cascaded passes are A, B, C, D, and a tensor is output
    # of A and input to D, then it also consumes SRAM in passes B and C.
    if "tens_to_cps" in sg.scheduling_info:
        return sg.scheduling_info["tens_to_cps"]
    # Determine life-time of tensors
    min_index = {}
    max_index = {}
    index = 0
    cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu]
    for cps in cps_list:
        for tens in cps.inputs + cps.outputs:
            if tens in tensor_rewrites:
                min_index[tens] = min(index, min_index.get(tens, len(cps_list)))
                max_index[tens] = index
        index += 1
    # Convert to affected cps-es
    tens_to_cps = {}
    for tens in min_index:
        tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1]
    sg.scheduling_info["tens_to_cps"] = tens_to_cps
    return tens_to_cps


def use_fast_storage_for_feature_maps(sg, sram_limit, arch):
    # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes.
    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
    # Sort tensors first on life-time (smallest first), then on size (biggest first)
    tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps])
    for _, _, _, tens in tens_list:
        cps_list = tens_to_cps[tens]
        if len(cps_list) < 1:
            continue
        sz = tens.storage_size()
        fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list])
        if fits_in_fast_storage:
            tens.mem_area = arch.fast_storage_mem_area
            tens.mem_type = MemType.Scratch_fast
            tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
            assert tens in tensor_rewrites
            # Also rewrite reshapes
            for rewrite_op in tensor_rewrites[tens]:
                tens2 = rewrite_op.outputs[0]
                tens2.mem_area = arch.fast_storage_mem_area
                tens2.mem_type = MemType.Scratch_fast
                tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
            for cps in cps_list:
                cps.sram_used += sz


def undo_use_fast_storage(sg, arch):
    # Undoes the effects of a previous call to use_fast_storage_for_feature_maps
    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
    mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap]
    for tens, cps_list in tens_to_cps.items():
        if tens.mem_type == MemType.Scratch_fast:
            sz = tens.storage_size()
            tens.mem_area = mem_area
            tens.mem_type = MemType.Scratch
            # Also undo reshapes
            for rewrite_op in tensor_rewrites[tens]:
                tens2 = rewrite_op.outputs[0]
                tens2.mem_area = mem_area
                tens2.mem_type = MemType.Scratch
            for cps in cps_list:
                cps.sram_used -= sz