# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# The scheduler creates and searches for an optimal plan for the network, selecting block configurations and
# subdivisions for the Operators
# Allow class name forward references in the type annotations (see PEP 563).
from __future__ import annotations

import copy
from collections import namedtuple
from enum import auto
from enum import IntEnum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING

# Import needed for Type annotations. Only import for Type checking to avoid run-time errors due to cyclic import.
if TYPE_CHECKING:
    from .npu_performance import CycleCost

import numpy as np

from . import live_range
from . import npu_performance
from . import tensor_allocation
from . import weight_compressor
from .architecture_allocator import ArchitectureBlockConfig
from .architecture_allocator import find_block_config
from .architecture_allocator import get_ifm_area_required
from .architecture_allocator import to_upscale
from .architecture_allocator import is_nearest
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
from .cascade_builder import CascadeBuilder
from .cascade_builder import CascadeInfo
from .data_type import DataType
from .nn_graph import CascadedPass
from .nn_graph import Graph
from .nn_graph import Pass
from .nn_graph import PassPlacement
from .nn_graph import SchedulingStrategy
from .nn_graph import Subgraph
from .live_range import ofm_can_reuse_ifm
from .numeric_util import round_down
from .numeric_util import round_up
from .operation import Op
from .shape4d import Shape4D
from .tensor import MemArea
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
from .weight_compressor import NpuWeightTensor


def shape_for_format(shape: Shape4D, tensor_format: TensorFormat) -> Shape4D:
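    # NHCWB16 stores feature maps in 16-channel bricks, so the allocated depth is rounded
    # up to the next multiple of 16, e.g. a Shape4D(1, 8, 8, 20) feature map is stored
    # with depth 32 (illustrative example)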
    if tensor_format == TensorFormat.NHCWB16:
        return shape.with_depth(round_up(shape.depth, 16))

    return shape


class OptimizationStrategy(IntEnum):
    """Enum defining the different optimization strategies for the Scheduler"""

    Size = auto()
    Performance = auto()

    def __str__(self):
        return self.name


class SchedulerOpInfo:
    """Contains metadata about a SchedulerOperation that is unique to one Schedule"""

    def __init__(
        self,
        block_config: ArchitectureBlockConfig,
        weights_size: int,
        stripe_input: Shape4D,
        stripe_input2: Optional[Shape4D],
        stripe: Shape4D,
    ):
        self.block_config = block_config
        self.weights_size = weights_size
        self.stripe_input = stripe_input
        self.stripe_input2 = stripe_input2
        self.stripe = stripe
        self.cascade = 0  # Assigned by CascadeBuilder. 0 means not part of a cascade
        self.time_index = None  # Set by update_op_memory_snapshot
        self.ofm_depth_slices: List[int] = [0, stripe.depth]
        self.npu_weights_tensor: Optional[NpuWeightTensor] = None
        self.npu_scales_tensor: Optional[NpuWeightTensor] = None
        self.buffered_weight_tensors: List[Tensor] = []
        self.cycles: Optional[CycleCost] = None
        self.slack_buffering_cycles = 0
        self.slack_buffering_memory = 0
        self.full_weight_transfer_cycles = 0

    def copy(self):
        res = SchedulerOpInfo(
            self.block_config,
            self.weights_size,
            self.stripe_input,
            self.stripe_input2,
            self.stripe,
        )
        res.cascade = self.cascade
        return res

    def __str__(self):
        res = f"\t\tBlock Config = {self.block_config}\n"
        res += f"\t\tOFM Block = {self.block_config.ofm_block}\n"
        res += f"\t\tIFM Stripe = {self.stripe_input}\n"
        res += f"\t\tIFM2 Stripe = {self.stripe_input2}\n"
        res += f"\t\tOFM Stripe = {self.stripe}\n"
        res += f"\t\tEncoded Weights = {self.npu_weights_tensor and len(self.npu_weights_tensor.buffer)} bytes\n"
        for idx, tens in enumerate(self.buffered_weight_tensors):
            res += f"\t\tWeight buffer{idx + 1} = {tens.storage_size()} bytes\n"
        res += f"\t\tDepth slices = {self.ofm_depth_slices}\n"
        res += f"\t\tAssigned Cascade = {self.cascade}"
        return res


class SchedulerOptions:
    """Contains options for the Scheduler"""

    def __init__(
        self,
        optimization_strategy,
        sram_target,
        verbose_schedule,
    ):
        self.optimization_strategy = optimization_strategy
        self.optimization_sram_limit = sram_target
        self.verbose_schedule = verbose_schedule

    def __str__(self) -> str:
        return f"{type(self).__name__}: {str(self.__dict__)}"

    __repr__ = __str__


class SchedulerTensor:
    def __init__(self, shape, dt, mem_area, _format):
        self.dtype = dt
        self.mem_area = mem_area
        self.shape = shape
        self.format = _format
        self.connection = None


class SchedulerOperation:
    """Scheduler internal representation of 'Operation'
    This class can be seen as a node within the Scheduler Graph representation
    """

    def __init__(self, ps: Pass, arch: ArchitectureFeatures, nng: Graph):
        self.arch = arch
        self.parent_ps = ps
        self.parent_op = ps.primary_op
        self.name = ps.primary_op.name
        self.op_type = ps.primary_op.type
        self.activation = ps.primary_op.activation
        self.kernel = ps.primary_op.kernel
        self.resampling_mode = ps.primary_op.ifm_resampling_mode
        self.reversed_operands = False
        self.uses_scalar = ps.primary_op.ifm2 is not None and (
            ps.primary_op.ifm.shape == [] or ps.primary_op.ifm2.shape == []
        )

        self.ifm_ublock = arch.ifm_ublock

        self.ifm = SchedulerTensor(
            ps.ifm_shapes[0],
            ps.ifm_tensor.dtype,
            ps.ifm_tensor.mem_area,
            ps.ifm_tensor.format,
        )

        self.ifm2 = None
        if ps.ifm2_tensor:
            self.ifm2 = SchedulerTensor(
                ps.ifm_shapes[1],
                ps.ifm2_tensor.dtype,
                ps.ifm2_tensor.mem_area,
                ps.ifm2_tensor.format,
            )

        self.ofm = SchedulerTensor(
            ps.ofm_shapes[0],
            ps.ofm_tensor.dtype,
            ps.ofm_tensor.mem_area,
            ps.ofm_tensor.format,
        )

        # LUT must be placed in shram area. The copy is done by DMA
        # generated by the high level command stream generator.
        for idx, tens in enumerate(self.parent_op.inputs):
            if tens.purpose == TensorPurpose.LUT:
                new_tens = tens.clone_into_shram(self.arch)
                new_tens.consumer_list.append(self.parent_op)
                self.parent_op.inputs[idx] = new_tens

        # Input volume width and height required to produce the smallest possible stripe
        self.min_stripe_input_w, self.min_stripe_input_h = self._calculate_min_stripe_input()

        # Flags that mark whether this SchedulerOperation requires full IFM/OFM
        self.requires_full_ifm = False
        self.requires_full_ifm2 = False
        self.requires_full_ofm = False

        self.evicted_fms_size = 0

        self.index = 0

        # Perform an IFM swap for certain binary elementwise operators
        # in order to enable cascading, if the SchedOp conforms to
        # Elementwise cascading rules.
        # The non-constant/non-scalar/non-broadcast IFM should be the primary input
        if self.op_type.is_binary_elementwise_op():
            ifm = self.parent_op.ifm
            ifm2 = self.parent_op.ifm2
            ofm = self.parent_op.ofm

            ifm_can_swap = ifm.is_const or ifm.is_scalar
            ifm2_can_be_primary = not (ifm2.is_const or ifm2.is_scalar or ifm2.is_broadcast(ofm))

            if ifm_can_swap and ifm2_can_be_primary:
                # IFM2 is the primary input
                self.reversed_operands = True
                self.ifm, self.ifm2 = self.ifm2, self.ifm

                self.parent_ps.ifm_shapes = self.parent_ps.ifm_shapes[::-1]
                self.parent_ps.inputs = self.parent_ps.inputs[::-1]
                self.parent_ps.ifm_tensor, self.parent_ps.ifm2_tensor = (
                    self.parent_ps.ifm2_tensor,
                    self.parent_ps.ifm_tensor,
                )

    def add_ifm_connection(self, conn: "Connection"):
        """Add input connection to another SchedulerOperation or Subgraph Input"""
        conn.consumers.append(self)
        self.ifm.connection = conn

    def add_ifm2_connection(self, conn: "Connection"):
        """Add input connection to another SchedulerOperation or Subgraph Input"""
        if self.ifm2:
            conn.consumers.append(self)
            self.ifm2.connection = conn
        else:
            assert False, f"Trying to set an IFM2 Connection to {self} which has no IFM2"

    def add_ofm_connection(self, conn: "Connection"):
        """Add output connection to another SchedulerOperation or Subgraph Output"""
        conn.producers.append(self)
        self.ofm.connection = conn

    def get_dependants(self):
        """Returns a list of the Ops that depend on this Operation's OFM"""
        return self.ofm.connection.consumers

    def ifm_size_in_bytes(self) -> int:
        """Returns size of the IFM in bytes"""
        ifm_storage_shape = shape_for_format(self.ifm.shape, self.ifm.format)
        return round_up(ifm_storage_shape.elements() * self.ifm.dtype.size_in_bytes(), Tensor.AllocationQuantum)

    def ifm2_size_in_bytes(self) -> int:
        """Returns size of the IFM2 in bytes"""
        if self.ifm2:
            ifm2_storage_shape = shape_for_format(self.ifm2.shape, self.ifm2.format)
            return round_up(ifm2_storage_shape.elements() * self.ifm2.dtype.size_in_bytes(), Tensor.AllocationQuantum)

        return 0

    def ofm_size_in_bytes(self) -> int:
        """Returns size of the OFM in bytes"""
        ofm_storage_shape = shape_for_format(self.ofm.shape, self.ofm.format)
        return round_up(ofm_storage_shape.elements() * self.ofm.dtype.size_in_bytes(), Tensor.AllocationQuantum)

    def create_scheduler_info(self, nng: Graph, stripe: Shape4D) -> SchedulerOpInfo:
        """Returns schedule info about this SchedulerOperation based on how many ofm elements it should produce"""
        ifm_shape = self.ifm.shape
        ifm2_shape = self.ifm2.shape if self.ifm2 is not None else None
        ofm_shape = stripe

        if ofm_shape != self.ofm.shape:
            # Striped Op - Need to calculate stripe input volume
            stripe_input_w, stripe_input_h = self._get_stripe_input_requirement(stripe)
            # Ensure stripe input volume is within the full IFM volume
            stripe_input_h = min(stripe_input_h, self.ifm.shape.height)
            stripe_input_w = min(stripe_input_w, self.ifm.shape.width)
            ifm_shape = ifm_shape.with_hw(stripe_input_h, stripe_input_w)

            if self.ifm2:
                stripe_input2_h = min(stripe_input_h, self.ifm2.shape.height)
                stripe_input2_w = min(stripe_input_w, self.ifm2.shape.width)
                ifm2_shape = ifm2_shape.with_hw(stripe_input2_h, stripe_input2_w)

        block_config = self._get_block_config(ifm_shape, ifm2_shape, self.uses_scalar, ofm_shape)

        scheduler_op_info = SchedulerOpInfo(block_config, 0, ifm_shape, ifm2_shape, ofm_shape)
        if self.parent_op.weights:
            # Default full-depth weight encoding with no buffering
            (
                scheduler_op_info.npu_weights_tensor,
                scheduler_op_info.npu_scales_tensor,
            ) = weight_compressor.encode_weight_and_scale_tensor(
                self.arch,
                self.parent_op,
                self.parent_op.weights,
                self.parent_op.bias,
                self.kernel,
                block_config,
                [0, self.ofm.shape.depth],
            )

        self.parent_ps.block_config = block_config.old_style_representation()
        return scheduler_op_info

    def _get_stripe_input_requirement(self, stripe_shape: Shape4D) -> Tuple[int, int]:
        """Returns the amount of IFM required to produce the stripe with shape:'stripe_shape'"""
        ofm_shape_to_produce = Block.from_shape(stripe_shape.as_list())

        return get_ifm_area_required(ofm_shape_to_produce, self.kernel, self.resampling_mode)

    def _calculate_min_stripe_input(self) -> Tuple[int, int]:
        # Calculate the input volume height and width required for the smallest possible stripe (h,w = 1,1)
        min_stripe = self.ofm.shape.with_hw(1, 1)
        return self._get_stripe_input_requirement(min_stripe)

    def _get_block_config(
        self, ifm_shape: Shape4D, ifm2_shape: Optional[Shape4D], uses_scalar: bool, ofm_shape: Shape4D
    ) -> Optional[ArchitectureBlockConfig]:
        # Returns a block config and SHRAM layout
        lut_banks = 2 if self.parent_op.activation_lut else 0
        return find_block_config(
            self.arch,
            self.op_type.npu_block_type,
            ofm_shape,
            ifm_shape,
            ifm2_shape,
            uses_scalar,
            self.ifm.dtype.size_in_bits(),
            self.kernel,
            lut_banks,
            self.parent_op.has_scaling(),
            self.resampling_mode,
        )


class Connection:
    """Scheduler internal representation of a Tensor that connects two SchedulerOperations
    This class can be seen as an edge within the Scheduler Graph representation
    """

    def __init__(self, tensor: Tensor):
        self.parent_tens = tensor

        # SchedulerOperation relationships
        self.producers: List[SchedulerOperation] = []
        self.consumers: List[SchedulerOperation] = []

    def __str__(self):
        return f"<Connection {self.parent_tens.name}>"

    __repr__ = __str__


class Schedule:
    """Class that contains a solution of how to schedule an NPU subgraph and its cost"""

    def __init__(self, sg: Subgraph, label: str):
        self.sg = sg
        self.label = label
        self.cost_map: Dict[SchedulerOperation, SchedulerOpInfo] = {}
        self.cascades: Dict[int, CascadeInfo] = {}
        self.fast_storage_peak_usage = 0
        self.memory_snapshot: Optional[List[int]] = None

    @property
    def name(self):
        return f"{self.sg.name}_{self.label}"


class Scheduler:
    """Main class of the Vela Scheduling"""

    def __init__(self, nng: Graph, sg: Subgraph, arch: ArchitectureFeatures, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sched_ops: List[SchedulerOperation] = []
        self.max_schedule: Optional[Schedule] = None
        self.scheduler_options = options

        self.scratched_fms: Dict[Tensor, Any] = {}
        self.evicted_fms: List[live_range.LiveRange] = []

    def avoid_nhcwb16_for_ofm(self, tens, ps, arch):
        """For elementwise ops, when the ifm is in nhwc format and not brick format aligned (16),
        and the ofm can overwrite the ifm, it is better to enforce ofm format to nhwc in
        order to reduce memory transactions"""

        op = ps.primary_op
        if not op.type.is_elementwise_op():
            return False

        depth = op.ofm_shapes[0][-1]
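        # An OFM depth that is already a multiple of 16 is brick aligned, so NHCWB16 adds no
        # padding and there is no reason to avoid it (e.g. depth 20 would be padded to 32,
        # depth 32 would not)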
        if (depth % 16) == 0:
            return False

        # Check if overwriting the inputs can be allowed
        OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
        outp = OpShapeTens(op.ofm_shapes[0], op.ofm)
        inps = []
        if op.ifm is not None:
            inps.append(OpShapeTens(op.ifm_shapes[0], op.ifm))
        if op.ifm2 is not None:
            inps.append(OpShapeTens(op.ifm_shapes[1], op.ifm2))

        # Find an input tensor that can be overwritten by the output
        for inp in inps:
            if (
                # check op input and output shapes allow overlapping
                inp.op_shape == outp.op_shape
                # check input tensor is valid
                and inp.tens is not None
                and inp.tens.shape != []
                # check input and output tensors are compatible
                and inp.tens.format == outp.tens.format
                and inp.tens.dtype == outp.tens.dtype
            ):
                if inp.tens.format == TensorFormat.NHWC:
                    return True

        return False

    def create_scheduler_representation(self, arch: ArchitectureFeatures):
        """Creates a Scheduler Graph representation"""
        # Temporary dict for creating connections between the Operations
        connections: Dict[Tensor, Connection] = {}
        # Memory required for the largest FeatureMap that has to be full
        min_memory_req = 0
        for ps in self.sg.passes:
            if ps.primary_op:
                # Set tensor format to NHCWB16 for output FeatureMaps, if possible
                for output in ps.outputs:
                    if output in self.sg.output_tensors or output.purpose != TensorPurpose.FeatureMap:
                        continue

                    if output.needs_linear_format:
                        continue

                    if self.avoid_nhcwb16_for_ofm(output, ps, arch):
                        output.needs_linear_format = True
                        continue

                    output.set_format(TensorFormat.NHCWB16, arch)

                # Create SchedulerOperations
                op = SchedulerOperation(ps, arch, self.nng)
                op.index = len(self.sched_ops)

                # Make connections
                if ps.ifm_tensor not in connections:
                    connections[ps.ifm_tensor] = Connection(ps.ifm_tensor)
                if ps.ifm2_tensor and ps.ifm2_tensor not in connections:
                    connections[ps.ifm2_tensor] = Connection(ps.ifm2_tensor)
                if ps.ofm_tensor not in connections:
                    connections[ps.ofm_tensor] = Connection(ps.ofm_tensor)

                op.add_ifm_connection(connections[ps.ifm_tensor])
                if ps.ifm2_tensor:
                    op.add_ifm2_connection(connections[ps.ifm2_tensor])
                op.add_ofm_connection(connections[ps.ofm_tensor])

                # Set requirements on the ifm/ofm buffers
                self.sched_ops.append(op)
                if ps.ifm_tensor in self.sg.input_tensors:
                    # This Op consumes a subgraph input
                    op.requires_full_ifm = True
                if ps.ifm2_tensor and ps.ifm2_tensor in self.sg.input_tensors:
                    # This Op consumes a subgraph input
                    op.requires_full_ifm2 = True
                if ps.ofm_tensor in self.sg.output_tensors:
                    # This Op produces a subgraph output
                    op.requires_full_ofm = True
                if ps.ifm_tensor.needs_linear_format:
                    op.requires_full_ifm = True
                if ps.ifm2_tensor and ps.ifm2_tensor.needs_linear_format:
                    op.requires_full_ifm2 = True
                if ps.ofm_tensor.needs_linear_format or ps.primary_op.memory_function == Op.ConcatSliceWrite:
                    op.requires_full_ofm = True
                if len(ps.primary_op.outputs) > 1 or len(ps.primary_op.outputs[0].consumer_list) > 1:
                    # Op has multiple outputs or consumers - requires full OFM
                    op.requires_full_ofm = True

                # Check memory requirements if this Op requires any full FeatureMaps
                op_memory_req = 0
                if op.requires_full_ifm:
                    op_memory_req += op.ifm_size_in_bytes()
                if op.requires_full_ifm2:
                    op_memory_req += op.ifm2_size_in_bytes()
                if op.requires_full_ofm:
                    op_memory_req += op.ofm_size_in_bytes()

                min_memory_req = max(op_memory_req, min_memory_req)

        # Theoretical minimum required memory - used to guide the cascade building
        self.min_memory_req = min_memory_req

    def create_initial_schedule(self) -> Schedule:
        """Creates an initial schedule with no cascading or buffering of any kind"""
        schedule = Schedule(self.sg, "MAX")
        for op in self.sched_ops:
            cost = op.create_scheduler_info(self.nng, op.ofm.shape)
            cost.cycles = self.estimate_op_performance(op, cost.block_config, op.ofm.shape.depth)
            schedule.cost_map[op] = cost

        return schedule

    def update_op_memory_snapshot(self, schedule: Schedule):
        memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]

        # Collect live ranges from tensors
        lr_graph = live_range.LiveRangeGraph()
        for mem_area, mem_type_set in memories_list:
            live_range.extract_live_ranges_from_schedule(
                self.sg,
                mem_area,
                mem_type_set,
                lr_graph,
            )

        # Populate time-array with memory used by live ranges
        temporal_usage = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area)
        schedule.memory_snapshot = temporal_usage

        # Set the peak memory usage
        schedule.fast_storage_peak_usage = max(temporal_usage, default=0)

    def estimate_op_performance(self, op: SchedulerOperation, block_config, ofm_depth):
        query = npu_performance.PerformanceQuery(op.op_type.npu_block_type)
        query.ifm_shape = op.ifm.shape
        query.ifm_memory_area = op.ifm.mem_area
        query.ifm_bits = op.ifm.dtype.size_in_bits()
        query.ifm_format = op.ifm.format
        query.ifm2_shape = op.ifm2 and op.ifm2.shape
        query.ifm2_memory_area = op.ifm2 and op.ifm2.mem_area
        query.ifm2_bits = op.ifm2 and op.ifm2.dtype.size_in_bits()
        query.ifm2_format = op.ifm2 and op.ifm2.format
        query.ofm_shape = op.ofm.shape.with_depth(ofm_depth)
        query.ofm_memory_area = op.ofm.mem_area
        query.ofm_bits = op.ofm.dtype.size_in_bits()
        query.ofm_format = op.ofm.format
        if op.parent_op.bias:
            query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth)
            query.const_memory_area = self.arch.fast_storage_mem_area

        query.kernel = op.kernel
        query.config = block_config

        return npu_performance.measure_cycle_cost(self.arch, op.op_type, op.activation and op.activation.op_type, query)

    def estimate_element_access(self, op: SchedulerOperation, block_config, ofm_depth):
        query = npu_performance.PerformanceQuery(op.op_type.npu_block_type)
        query.ifm_shape = op.ifm.shape
        query.ifm_memory_area = op.ifm.mem_area
        query.ifm_bits = op.ifm.dtype.size_in_bits()
        query.ifm_format = op.ifm.format
        query.ifm2_shape = op.ifm2 and op.ifm2.shape
        query.ifm2_memory_area = op.ifm2 and op.ifm2.mem_area
        query.ifm2_bits = op.ifm2 and op.ifm2.dtype.size_in_bits()
        query.ifm2_format = op.ifm2 and op.ifm2.format
        query.ofm_shape = op.ofm.shape.with_depth(ofm_depth)
        query.ofm_memory_area = op.ofm.mem_area
        query.ofm_bits = op.ofm.dtype.size_in_bits()
        query.ofm_format = op.ofm.format
        if op.parent_op.bias:
            query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth)
            query.const_memory_area = self.arch.fast_storage_mem_area

        query.kernel = op.kernel
        query.config = block_config

        return npu_performance.measure_element_access(self.arch, query)

    def propose_schedule_buffering(self, ref_schedule: Schedule, staging_limit_bytes):
        """Create a buffered schedule"""
        buffered_schedule = Schedule(self.sg, f"{ref_schedule.label}_BUFFERED")

        prev_op = None
        for sched_op in self.sched_ops:
            if sched_op not in ref_schedule.cost_map:
                # sched_op is not part of this sub-schedule - skip
                continue

            self.propose_operator_buffering(sched_op, prev_op, buffered_schedule, ref_schedule, staging_limit_bytes)
            prev_op = sched_op

        return buffered_schedule

    def propose_operator_buffering(
        self,
        sched_op: SchedulerOperation,
        prev_op: Optional[SchedulerOperation],
        buffered_schedule: Schedule,
        ref_schedule: Schedule,
        staging_limit_bytes,
    ):
        # Mild recursion might mean this Op has already been seen
        if sched_op in buffered_schedule.cost_map:
            return

        # Take the reference schedule as default costings for this schedule
        ref_cost = ref_schedule.cost_map[sched_op]
        cost = copy.copy(ref_cost)
        cost.slack_buffering_cycles = ref_cost.cycles.op_cycles
        memory_snapshot = ref_schedule.memory_snapshot
        ref_memory_usage = memory_snapshot[ref_cost.time_index] if ref_cost.time_index < len(memory_snapshot) else 0
        cost.slack_buffering_memory = staging_limit_bytes - ref_memory_usage
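        # slack_buffering_memory is the SRAM still available under the staging limit at this
        # point in time, e.g. a 384 kB staging limit with 250 kB already live leaves 134 kB
        # for weight buffering (illustrative figures)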
        buffered_schedule.cost_map[sched_op] = cost

        # Attempt weight buffering on anything with a weights tensor
        if sched_op.parent_op.weights:
            buffer_limit_bytes = cost.slack_buffering_memory

            # If applicable apply size limitation, but keep it within reason (ratio 1.5).
            # Size limitation is used when use_fast_storage_for_feature_maps has
            # detected that there are fms that do not fit in fast storage.
            if sched_op.evicted_fms_size and ((buffer_limit_bytes / sched_op.evicted_fms_size) >= 1.5):
                buffer_limit_bytes -= sched_op.evicted_fms_size

            self.propose_weight_buffering(
                sched_op.parent_op.weights,
                sched_op.parent_op.bias,
                sched_op,
                prev_op,
                buffered_schedule,
                ref_schedule,
                buffer_limit_bytes,
            )

        return cost

    def weights_needs_dma(self, weight_tensor):
        if weight_tensor and weight_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
            # Weights are in permanent storage
            # Only when permanent storage differs from feature map storage is there a point in moving the data
            if (
                weight_tensor.mem_area in (MemArea.Dram, MemArea.OffChipFlash)
                and self.arch.permanent_storage_mem_area != self.arch.fast_storage_mem_area
            ):
                return True
        return False

    def propose_weight_buffering(
        self,
        weight_tensor,
        scale_tensor,
        sched_op: SchedulerOperation,
        prev_op: SchedulerOperation,
        buffered_schedule: Schedule,
        ref_schedule: Schedule,
        buffer_limit_bytes,
    ):
        cost = buffered_schedule.cost_map[sched_op]
        prev_cost = buffered_schedule.cost_map.get(prev_op)
        ref_cost = ref_schedule.cost_map[sched_op]
        assert cost and ref_cost

        needs_dma = self.weights_needs_dma(weight_tensor)

        ofm_full_depth_slices = [0, ref_cost.stripe.depth]

        # Encode weights for the full depth
        full_weights, full_scales = weight_compressor.encode_weight_and_scale_tensor(
            self.arch,
            sched_op.parent_op,
            weight_tensor,
            scale_tensor,
            sched_op.kernel,
            cost.block_config,
            ofm_full_depth_slices,
        )
        full_weights_bytes = len(full_weights.buffer)
        cost.ofm_depth_slices = ofm_full_depth_slices

        # No buffering required - take all the weights from permanent storage
        if sched_op.op_type == Op.FullyConnected or not needs_dma:
            cost.npu_weights_tensor = full_weights
            cost.npu_scales_tensor = full_scales
            return

        encoded_weights: Optional[NpuWeightTensor] = full_weights
        encoded_scales = full_scales

        # How many NPU cycles are available under the previously executing
        # operator and SRAM unused for performing buffered DMA transfers
        slack_cycles = prev_cost.slack_buffering_cycles if prev_cost else 0
        slack_memory = prev_cost.slack_buffering_memory if prev_cost else 0

        # Force full depth for cascaded Ops
        if ref_cost.cascade != 0:
            weight_tensor_purpose = TensorSubPurpose.Standard
            weight_buffer_size = full_weights_bytes
            # Update the memory snapshot to reflect the added size of the weights
            ref_schedule.memory_snapshot[ref_cost.time_index] += weight_buffer_size
        else:
            # Estimate the buffering cycle time for the full set of weights
            full_transfer_cycles = npu_performance.measure_mem2mem_cycles(
                self.arch, weight_tensor.mem_area, self.arch.fast_storage_mem_area, full_weights_bytes
            )
            cost.full_weight_transfer_cycles = full_transfer_cycles

            # Calculate the amount of prebuffering necessary (or what is possible with limited
            # double buffer size)
            half_buffer_limit = buffer_limit_bytes // 2
            if full_transfer_cycles > slack_cycles:
                prebuffer_ratio = slack_cycles / full_transfer_cycles
                prebuffer_bytes = min(prebuffer_ratio * full_weights_bytes, half_buffer_limit)
            else:
                prebuffer_bytes = min(full_weights_bytes, half_buffer_limit)

            prebuffer_ratio = prebuffer_bytes / full_weights_bytes
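            # Illustrative example: with 4000 slack cycles, a 10000 cycle full transfer and
            # 80 kB of encoded weights, at most 0.4 * 80 kB = 32 kB can be prebuffered
            # (subject to the half-buffer limit), giving prebuffer_ratio = 0.4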

            # Have to split the weights if the initial buffering can't store
            # all of the compressed weights
            if prebuffer_bytes < full_weights_bytes:
                block_depth = cost.block_config.ofm_block.depth

                # Choose initial prebuffering depth (already buffer clamped)
                prebuffer_depth = ref_cost.stripe.depth * prebuffer_ratio
                prebuffer_depth = int(max(16, round_down(prebuffer_depth, ArchitectureFeatures.OFMSplitDepth)))

                # Calculate cycles executed during the prebuffer
                pre_op_cycles = self.estimate_op_performance(sched_op, cost.block_config, prebuffer_depth)
                buffering_depth = ref_cost.stripe.depth * (pre_op_cycles.op_cycles / full_transfer_cycles)

                # Choose initial buffering depth and clamp to the double buffering limit
                buffering_depth = round_up(buffering_depth, block_depth)
                buffering_bytes = (buffering_depth / ref_cost.stripe.depth) * full_weights_bytes
                if buffering_bytes > half_buffer_limit:
                    buffering_depth = (half_buffer_limit / full_weights_bytes) * ref_cost.stripe.depth

                while True:
                    # Attempt to buffer whole blocks
                    if buffering_depth > block_depth:
                        buffering_depth = round_down(buffering_depth, block_depth)
                    else:
                        buffering_depth = round_down(buffering_depth, ArchitectureFeatures.OFMSplitDepth)
                    buffering_depth = int(max(buffering_depth, ArchitectureFeatures.OFMSplitDepth))

                    # Create list of depth slices
                    depth_slices = [0]
                    if prebuffer_depth < ref_cost.stripe.depth:
                        depth_slices += list(range(prebuffer_depth, ref_cost.stripe.depth, buffering_depth))
                    depth_slices.append(ref_cost.stripe.depth)
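                    # e.g. prebuffer_depth 32, stripe depth 128 and buffering_depth 48 give
                    # depth_slices = [0, 32, 80, 128] (illustrative values)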

                    # Encode weights based on the depth slices
                    cost.ofm_depth_slices = depth_slices
                    encoded_weights, encoded_scales = weight_compressor.encode_weight_and_scale_tensor(
                        self.arch,
                        sched_op.parent_op,
                        weight_tensor,
                        scale_tensor,
                        sched_op.kernel,
                        cost.block_config,
                        cost.ofm_depth_slices,
                    )
                    assert encoded_weights is not None
                    # Chosen buffering might not fit at all, iterate until it does
                    # or until the minimum usable slice size is reached
                    if (
                        encoded_weights.double_buffer_size() <= buffer_limit_bytes
                        or prebuffer_depth == ArchitectureFeatures.OFMSplitDepth
                    ):
                        break

                    if buffering_depth > prebuffer_depth:
                        buffering_depth = round_up(buffering_depth // 2, ArchitectureFeatures.OFMSplitDepth)
                    else:
                        prebuffer_depth = round_up(prebuffer_depth // 2, ArchitectureFeatures.OFMSplitDepth)

                # Calculate cycles required to run the last op for use as future slack
                tail_cycles = self.estimate_op_performance(
                    sched_op, cost.block_config, depth_slices[-1] - depth_slices[-2]
                )
                cost.slack_buffering_cycles = tail_cycles.op_cycles

            # Determine whether the weights need to be double buffered
            weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes())

            # Only buffer weights if there's still space left for the buffer
            if weight_buffer_size <= buffer_limit_bytes:
                assert weight_buffer_size % 16 == 0
                # Determine whether to double buffer or single buffer
                double_buffer_size = encoded_weights.double_buffer_size()
                if (double_buffer_size <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)):
                    weight_tensor_purpose = TensorSubPurpose.DoubleBuffer
                else:
                    weight_tensor_purpose = TensorSubPurpose.Standard

                cost.buffered_weight_tensors = [
                    self.buffer_tensor(
                        encoded_weights,
                        weight_tensor_purpose,
                        encoded_weights.double_buffer_sizes[0],
                        weight_tensor.name + "_buffer",
                    )
                ]
                if weight_tensor_purpose == TensorSubPurpose.DoubleBuffer:
                    buf2 = self.buffer_tensor(
                        encoded_weights,
                        weight_tensor_purpose,
                        encoded_weights.double_buffer_sizes[1],
                        weight_tensor.name + "_buffer2",
                    )
                    cost.buffered_weight_tensors.append(buf2)

                # Note! OFM depth slices define slices as [0, s1, ... sn]. For example, [0, 70, 140] describes two slices
                # (0-70 and 70-140) but has a length of 3, which would result in idx = 3 % 2 = 1 if two buffers were used.
                last_used_buffer_idx = len(cost.ofm_depth_slices) % len(cost.buffered_weight_tensors)
                weight_buffer_size = encoded_weights.double_buffer_sizes[last_used_buffer_idx]

                if ref_cost.cascade == 0:
                    # Determine if the lifetime can be extended and pre-buffer the first weight buffer
                    # under the previous operation
                    cost.buffered_weight_tensors[0].pre_buffer = encoded_weights.double_buffer_size() < slack_memory

                cost.slack_buffering_memory -= weight_buffer_size
            else:
                # Don't slice or buffer - use the whole depth from persistent storage
                cost.ofm_depth_slices = ofm_full_depth_slices
                encoded_weights = full_weights
                encoded_scales = full_scales

        cost.npu_weights_tensor = encoded_weights
        cost.npu_scales_tensor = encoded_scales

    def buffer_tensor(self, src_tensor: Tensor, sub_purpose: TensorSubPurpose, buffer_size: int, name: str) -> Tensor:
        buffered_weight_tensor = Tensor([1, 1, 1, buffer_size], DataType.uint8, name)
        buffered_weight_tensor.src_tensor = src_tensor
        buffered_weight_tensor.mem_area = self.arch.fast_storage_mem_area
        buffered_weight_tensor.mem_type = MemType.Scratch_fast
        buffered_weight_tensor.purpose = TensorPurpose.Weights
        buffered_weight_tensor.sub_purpose = sub_purpose
        return buffered_weight_tensor

    def propose_minimal_schedule(self) -> Schedule:
        """Proposes scheduling parameters where every operator is subdivided into the smallest stripe that satisfies the
        next operator's stride"""
        min_schedule = Schedule(self.sg, "MIN")
        cost_map = min_schedule.cost_map

        # Keep track of the previous Op - which consumes the current Op's OFM
        prev_op: Optional[SchedulerOperation] = None
        for sched_op in reversed(self.sched_ops):
            min_stripe_height = prev_op.kernel.stride.y if prev_op else 1
            min_stripe = sched_op.ofm.shape.with_height(min_stripe_height)
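            # e.g. if the consuming Op has a vertical stride of 2, this Op must produce at least
            # 2 OFM rows per stripe so that the consumer can advance by one row of its own output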

            cost = sched_op.create_scheduler_info(self.nng, min_stripe)
            cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
            cost_map[sched_op] = cost

            prev_op = sched_op

        return min_schedule

    def propose_schedule_striping(self, final_stripe: Shape4D, label: str, ref_schedule: Schedule) -> Schedule:
        """Proposes new striping for a schedule. The stripe is derived from the ifm requirements of the next Op down"""
        ref_cost = ref_schedule.cost_map

        striped_schedule = Schedule(self.sg, label)
        stripe = final_stripe
        for sched_op in reversed(self.sched_ops):
            if sched_op not in ref_cost:
                # sched_op is not part of the sub-schedule - skip
                continue

            # Create a cost entry with the new stripe
            cost = sched_op.create_scheduler_info(self.nng, stripe)

            weight_tensor = cost.npu_weights_tensor
            for idx, buffered_tens in enumerate(ref_cost[sched_op].buffered_weight_tensors):
                # If the weights are buffered in the reference schedule they should be in the new proposal
                cost.buffered_weight_tensors.append(
                    self.buffer_tensor(
                        weight_tensor,
                        buffered_tens.sub_purpose,
                        weight_tensor.double_buffer_sizes[idx],
                        buffered_tens.name,
                    )
                )

            # Estimate performance
            cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
            striped_schedule.cost_map[sched_op] = cost

            # Calculate the preceding Op's stripe.

            # In certain cases where an upscaling Op is cascaded,
            # it may get instructed to produce an odd stripe height.
            # Thus we need to force it back to even heights.
            force_even_stripe_heights = False
            for op in self.sched_ops:
                # Check if the cascade has a Nearest Neighbor-op.
                # If that is the case, force the stripes to be even.
                if (
                    ref_cost.get(op, None)
                    and ref_cost.get(sched_op, None)
                    and ref_cost[op].cascade == ref_cost[sched_op].cascade
                    and is_nearest(op.resampling_mode)
                ):
                    force_even_stripe_heights = True
                    break
            upscaling_remainder = stripe.height % to_upscale(sched_op.resampling_mode)
            height = stripe.height + (stripe.height % 2 if force_even_stripe_heights else upscaling_remainder)
            stripe = sched_op.ifm.shape.with_height(height * sched_op.kernel.stride.y)
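            # e.g. with no upscaling, a consumer stripe height of 4 and a vertical stride of 2
            # require the preceding Op to produce an 8-row stripe (illustrative values)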

        return striped_schedule

    def estimate_schedule_memory_usage(self, schedule: Schedule, non_local_mem_usage: dict):
        """Estimates the memory usage of a schedule"""
        cost = schedule.cost_map
        cascades = schedule.cascades
        peak_mem_usage = 0
        for sched_op in self.sched_ops:
            if sched_op not in cost:
                # sched_op is not part of the sub-schedule - skip
                continue

            if cost[sched_op].cascade:
                # This Op is part of a cascade - use the cascade's memory usage
                cascade_info = cascades[cost[sched_op].cascade]
                op_mem_usage = cascade_info.mem_usage + non_local_mem_usage.get(sched_op, 0)
            else:
                # This Op is not part of a cascade - calculate the memory usage
                op_weight_buffer = sum(tens.storage_size() for tens in cost[sched_op].buffered_weight_tensors)

                op_mem_usage = (
                    sched_op.ifm_size_in_bytes()
                    + sched_op.ofm_size_in_bytes()
                    + op_weight_buffer
                    + non_local_mem_usage.get(sched_op, 0)
                )
            peak_mem_usage = max(op_mem_usage, peak_mem_usage)

        return peak_mem_usage

    def build_cascades_for_min_schedule(self, min_schedule: Schedule, max_template: Schedule, memory_limit: int):
        # Update memory snapshot
        self.sg.schedule = min_schedule
        self.update_op_memory_snapshot(min_schedule)

        # Calculate residual memory for Min schedule
        non_local_mem_usage = {}
        for sched_op in self.sched_ops:
            time_index = min_schedule.cost_map[sched_op].time_index

            if self.arch.is_spilling_enabled():
                # For Dedicated SRAM only the intermediate buffers are in SRAM, hence op_mem_usage is 0
                op_mem_usage = 0
            else:
                # The Min schedule only has ifm and ofm in SRAM (no buffered weight tensors)
                # Only include IFM's that are in the scratch area
                ifm = sched_op.ifm.connection.parent_tens
                ifm_size = (
                    0 if ifm.mem_type not in (MemType.Scratch, MemType.Scratch_fast) else sched_op.ifm_size_in_bytes()
                )
                ofm_size = 0 if ofm_can_reuse_ifm(sched_op) else sched_op.ofm_size_in_bytes()
                op_mem_usage = ifm_size + ofm_size

            non_local_mem_usage[sched_op] = min_schedule.memory_snapshot[time_index] - op_mem_usage
            assert non_local_mem_usage[sched_op] >= 0

        # Create cascades for Min schedule
        cascade_builder = CascadeBuilder(self.sched_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)
        cascade_builder.build_cascades(min_schedule, max_template, memory_limit)

    def optimize_sub_schedule(
        self, cascade_info: CascadeInfo, ref_schedule: Schedule, max_template: Schedule, memory_limit: int
    ) -> Schedule:
        """Extracts the Ops covered by the given cascade and creates a sub-schedule. The sub-schedule is optimized by
        proposing weight buffering and then continuously proposing new stripe sizes"""
        ref_cost = ref_schedule.cost_map
        # Extract the ops that are part of this sub-schedule
        start = cascade_info.start
        end = cascade_info.end
        sub_schedule_ops = self.sched_ops[start : end + 1]
        # Create a sub-schedule that contains only the costs for the Ops that are part of the sub-schedule
        sub_schedule = Schedule(self.sg, f"SUB_{start}_{end}")
        for sched_op in sub_schedule_ops:
            sub_schedule.cost_map[sched_op] = ref_cost[sched_op]

        sub_schedule.cascades[end] = cascade_info
        # Use the memory snapshot from the reference schedule
        sub_schedule.memory_snapshot = ref_schedule.memory_snapshot

        # Calculate memory usage that is live during the sub-schedule but not part of it
        time_for_cascade = ref_cost[sub_schedule_ops[0]].time_index
        mem_usage_parallel_to_sub_schedule = ref_schedule.memory_snapshot[time_for_cascade] - cascade_info.mem_usage
        # If the first Op's IFM has other consumers it has to live throughout the whole sub-schedule whether it's
        # included in a cascade or not. Not valid in Dedicated SRAM mode (spilling enabled).
        persistent_initial_ifm = (
            sub_schedule_ops[0].ifm_size_in_bytes()
            if not self.arch.is_spilling_enabled() and len(sub_schedule_ops[0].ifm.connection.consumers) > 1
            else 0
        )
        # Calculate non-local-mem-usage per Operator
        non_local_mem_usage = {}
        for idx, sched_op in enumerate(sub_schedule_ops):
            non_local_mem_usage[sched_op] = mem_usage_parallel_to_sub_schedule
            if idx != 0:
                non_local_mem_usage[sched_op] += persistent_initial_ifm

        cascade_builder = CascadeBuilder(sub_schedule_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)

        # Start by adding buffering
        buffered_sub_schedule = self.propose_schedule_buffering(
            sub_schedule, self.scheduler_options.optimization_sram_limit
        )
        # Copy the cascades over from the unbuffered-schedule
        buffered_sub_schedule.cascades = sub_schedule.cascades

        # Generate the possible stripings for the final Op in the sub-schedule
        final_ofm_shape = sub_schedule_ops[-1].ofm.shape

        # Skip testing the min stripe used in the MIN schedule since that will be used
        # anyway if no new cascades are created below
        last_op = sub_schedule_ops[-1]
        min_stripe_h = sub_schedule.cost_map[last_op].stripe.height + 1

        possible_stripes = [
            final_ofm_shape.with_height(stripe_h) for stripe_h in range(min_stripe_h, final_ofm_shape.height // 2 + 1)
        ]
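        # e.g. a 32-row final OFM where the MIN schedule used a 2-row stripe gives candidate
        # stripe heights 3..16 (illustrative values)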
Johan Alfvén2a285fc2022-08-17 14:59:58 +02001056 # Propose different striping
Jacob Bohlinfad72042021-08-24 21:51:41 +02001057 best_schedule = None
Johan Alfvén2a285fc2022-08-17 14:59:58 +02001058 max_nbr_of_cascades = 0
1059 for iteration, proposed_stripe in enumerate(possible_stripes):
Tim Halld8339a72021-05-27 18:49:40 +01001060 proposed_schedule = self.propose_schedule_striping(
1061 proposed_stripe, f"OPTIMIZED_{iteration}", buffered_sub_schedule
1062 )
1063
1064 cascade_builder.build_cascades(proposed_schedule, max_template, memory_limit)
Johan Alfvén2a285fc2022-08-17 14:59:58 +02001065 nbr_of_cascades = len(proposed_schedule.cascades)
Johan Alfvén2a285fc2022-08-17 14:59:58 +02001066 if iteration == 0:
1067 # First iteration - used as limit to prevent splitting up the cascades
1068 # Long cascades are better in order to reduce IFM/IFM dram bandwidth
1069 max_nbr_of_cascades = nbr_of_cascades
1070
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001071 # Check if proposal fits
1072 proposed_schedule_mem_usage = self.estimate_schedule_memory_usage(proposed_schedule, non_local_mem_usage)
Johan Alfvén2a285fc2022-08-17 14:59:58 +02001073 if (proposed_schedule_mem_usage) <= memory_limit and nbr_of_cascades <= max_nbr_of_cascades:
Tim Halld8339a72021-05-27 18:49:40 +01001074 best_schedule = proposed_schedule
Johan Alfvén2a285fc2022-08-17 14:59:58 +02001075
Tim Halld8339a72021-05-27 18:49:40 +01001076 if not proposed_schedule.cascades:
1077 # No cascading required - early exit
1078 break
1079 else:
Johan Alfvén2a285fc2022-08-17 14:59:58 +02001080 break
Tim Halld8339a72021-05-27 18:49:40 +01001081
1082 return best_schedule
1083
1084 def optimize_schedule(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001085 self,
1086 schedule: Schedule,
1087 max_sched: Schedule,
1088 max_template: Schedule,
1089 options: SchedulerOptions,
Tim Halld8339a72021-05-27 18:49:40 +01001090 ) -> Schedule:
1091 """Extracts sub-schedules based on the cascades and optimizes them and applies them to the final schedule"""
1092 sram_limit = options.optimization_sram_limit
1093 if max_sched.fast_storage_peak_usage < sram_limit and not self.arch.is_spilling_enabled():
1094 # Maximum performance schedule fits within the SRAM target
1095 return max_sched
1096
Jacob Bohlinfad72042021-08-24 21:51:41 +02001097 # Iterate over a copy of the cascades since they may change during the loop
1098 for cascade_info in list(schedule.cascades.values()):
Tim Halld8339a72021-05-27 18:49:40 +01001099 # Optimize the sub-schedule in this cascade
1100 opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, sram_limit)
Jacob Bohlinfad72042021-08-24 21:51:41 +02001101 if opt_sub_schedule:
1102 # Remove the existing cascade
1103 del schedule.cascades[cascade_info.end]
1104 # Update the sub-schedule Op and cascade costs to the full schedule
1105 schedule.cost_map.update(opt_sub_schedule.cost_map)
1106 schedule.cascades.update(opt_sub_schedule.cascades)
Tim Halld8339a72021-05-27 18:49:40 +01001107
1108 # Update memory snapshot
1109 self.sg.schedule = schedule
1110 self.update_op_memory_snapshot(schedule)
1111 # Propose schedule buffering to the optimized schedule
Tim Hall789e6f32021-06-17 17:02:31 +01001112 optimized_sched = self.propose_schedule_buffering(schedule, self.scheduler_options.optimization_sram_limit)
Tim Halld8339a72021-05-27 18:49:40 +01001113 # Copy the cascade's metadata from the unbuffered schedule
1114 optimized_sched.cascades = schedule.cascades
1115 return optimized_sched
1116
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001117 def optimize_weight_buffering_size(
1118 self,
1119 min_schedule: Schedule,
1120 options: SchedulerOptions,
1121 ):
1122 default_schedule = self.sg.schedule
Tim Hallc1be0872022-03-03 17:50:52 +00001123 npu_performance.calc_new_performance_for_network(self.nng, self.arch, None, False)
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001124 default_tot_cycles = self.nng.cycles[npu_performance.PassCycles.Total]
1125 default_dram_cycles = self.nng.cycles[npu_performance.PassCycles.DramAccess]
1126
1127 # Restore mem/type for scratched_fms
1128 for tens in self.scratched_fms:
1129 tens.mem_area = self.scratched_fms[tens][0]
1130 tens.mem_type = self.scratched_fms[tens][1]
1131
1132 self.update_op_memory_snapshot(self.sg.schedule)
1133
1134 # Collect live ranges from tensors
1135 memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
1136 lr_graph = live_range.LiveRangeGraph()
1137 for mem_area, mem_type_set in memories_list:
Johan Alfven6e281af2023-02-28 09:03:03 +01001138 live_range.extract_live_ranges_from_schedule(
1139 self.sg,
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001140 mem_area,
1141 mem_type_set,
1142 lr_graph,
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001143 )
1144
1145 # Find the relation between the sched_op and the buffering tensor
1146 weight_ops = {}
1147 for sched_op in self.sched_ops:
1148 cost = self.sg.schedule.cost_map[sched_op]
Rickard Bolinfd8b5002022-05-16 09:11:06 +00001149 for tens in cost.buffered_weight_tensors:
1150 weight_ops[tens] = sched_op
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001151
1152 # Filter out weight buffer live ranges
1153 weight_lrs = []
1154 for lr in lr_graph.lrs:
1155 for tens in lr.tensors:
1156 if weight_ops.get(tens):
1157 weight_lrs.append(lr)
1158 break
1159
1160 # See if any evicted fm overlaps with a weight buffering op.
1161 # If this is the case add a size limitation to the buffering op
1162 for lr in self.evicted_fms:
1163 for weight_lr in weight_lrs:
1164 if lr.overlaps_ranges(weight_lr):
1165 for tens in weight_lr.tensors:
1166 sched_op = weight_ops.get(tens)
1167 if sched_op:
1168 # Add size reduction to the op
1169 sched_op.evicted_fms_size += lr.size
1170 break
1171
1172 self.sg.schedule = min_schedule
1173 self.update_op_memory_snapshot(self.sg.schedule)
1174
1175 # Run schedule buffering - with weight buffer size reduction
1176 schedule = self.propose_schedule_buffering(self.sg.schedule, options.optimization_sram_limit)
1177 schedule.cascades = self.sg.schedule.cascades
1178 self.sg.schedule = schedule
1179
1180 # Apply new buffer schedule and calculate new performance
1181 self.update_op_memory_snapshot(self.sg.schedule)
1182 self.apply_schedule(self.sg.schedule)
1183 self.use_fast_storage_for_feature_maps(self.sg.schedule, options.optimization_sram_limit)
1184
Tim Hallc1be0872022-03-03 17:50:52 +00001185 npu_performance.calc_new_performance_for_network(self.nng, self.arch, None, False)
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001186 new_tot_cycles = self.nng.cycles[npu_performance.PassCycles.Total]
1187 new_dram_cycles = self.nng.cycles[npu_performance.PassCycles.DramAccess]
1188
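# Relative improvement vs the default schedule; positive values mean fewer cycles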
Tim Hall8bc7a652022-05-19 15:29:23 +01001189 improvement_tot = (
1190 round((default_tot_cycles - new_tot_cycles) / default_tot_cycles, 2) if default_tot_cycles != 0 else 0
1191 )
1192 improvement_dram = (
1193 round((default_dram_cycles - new_dram_cycles) / default_dram_cycles, 2) if default_dram_cycles != 0 else 0
1194 )
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001195
1196 # Compare both total and dram improvement
Johan Alfvén3dae1b62022-05-17 10:26:48 +02001197 if not (improvement_tot >= 0 and improvement_dram > 0):
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001198 # No improvement, restore the default schedule
1199 for sched_op in self.sched_ops:
1200 sched_op.evicted_fms_size = 0
1201
1202 for tens in self.scratched_fms:
1203 tens.mem_area = self.scratched_fms[tens][0]
1204 tens.mem_type = self.scratched_fms[tens][1]
1205
1206 self.sg.schedule = default_schedule
1207 self.update_op_memory_snapshot(self.sg.schedule)
1208 self.apply_schedule(self.sg.schedule)
1209 self.use_fast_storage_for_feature_maps(self.sg.schedule, options.optimization_sram_limit)
1210
Tim Halld8339a72021-05-27 18:49:40 +01001211 def apply_schedule(self, sched: Schedule):
1212 """Applies the given schedule as a final solution"""
1213 for sched_op in self.sched_ops:
1214 op_info = sched.cost_map[sched_op]
1215 cascade_info = sched.cascades.get(op_info.cascade, None)
1216 if cascade_info and sched_op in cascade_info.buffers:
1217 buffer_tens = sched_op.ifm.connection.parent_tens
1218 # Apply memory area and type
1219 buffer_tens.mem_area = self.arch.fast_storage_mem_area
1220 buffer_tens.mem_type = MemType.Scratch_fast
1221 # Apply Rolling buffer
1222 buffer_tens.set_format(TensorFormat.NHCWB16, self.arch)
1223 buffer_tens.set_new_sub_purpose(TensorSubPurpose.RollingBufferY, cascade_info.buffers[sched_op].height)
1224
1225 sched_op.parent_ps.block_config = op_info.block_config.old_style_representation()
1226
1227 # Ensure that the src_tensor reference is set correctly
Rickard Bolinfd8b5002022-05-16 09:11:06 +00001228 for tens in op_info.buffered_weight_tensors:
1229 tens.src_tensor = op_info.npu_weights_tensor
Tim Halld8339a72021-05-27 18:49:40 +01001230
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001231 def use_fast_storage_for_feature_maps(self, schedule, staging_limit):
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001232 """Finds the set of feature maps that fits within the staging limit and has the largest combined amount of
1233 access cycles, and moves those feature maps into fast storage"""
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001234 max_mem_usage = []
1235 base_mem_usage = []
1236 fast_storage_type = MemType.Scratch_fast
1237 fast_storage_mem_area = self.arch.fast_storage_mem_area
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001238 self.evicted_fms = []
Tim Halld8339a72021-05-27 18:49:40 +01001239
1240 # Force all OFMs to fast-storage
1241 for sched_op in self.sched_ops:
1242 cost = schedule.cost_map[sched_op]
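# Only consider OFMs of ops outside a cascade (cost.cascade == 0) that have dependants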
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001243 if cost.cascade == 0 and sched_op.get_dependants():
1244 ofm_tens = sched_op.ofm.connection.parent_tens
1245 if not any(cons is None for cons in ofm_tens.consumer_list):
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001246 if ofm_tens not in self.scratched_fms:
1247 # Remember default mem area and mem type, only done once
1248 self.scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
1249
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001250 ofm_tens.mem_area = fast_storage_mem_area
1251 ofm_tens.mem_type = fast_storage_type
Tim Halld8339a72021-05-27 18:49:40 +01001252
1253 # Collect live ranges from tensors
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001254 memories_list = [(fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
Tim Halld8339a72021-05-27 18:49:40 +01001255 lr_graph = live_range.LiveRangeGraph()
1256 for mem_area, mem_type_set in memories_list:
Johan Alfven6e281af2023-02-28 09:03:03 +01001257 live_range.extract_live_ranges_from_schedule(
1258 self.sg,
Jonas Ohlssond8575072022-03-30 10:30:25 +02001259 mem_area,
1260 mem_type_set,
1261 lr_graph,
Tim Halld8339a72021-05-27 18:49:40 +01001262 )
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001263 max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area)
Tim Halld8339a72021-05-27 18:49:40 +01001264
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001265 # If max_mem_usage does not exceed the staging limit at any point, all lrs fit and can stay in fast storage
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001266 if max(max_mem_usage) <= staging_limit:
1267 return
1268
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001269 # Build up the base memory usage by removing the mem_usage of the lrs we previously moved to fast-storage
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001270 base_mem_usage = np.array(max_mem_usage)
1271 curr_lrs = []
Tim Halld8339a72021-05-27 18:49:40 +01001272 for lr in lr_graph.lrs:
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001273 for tens in lr.tensors:
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001274 if self.scratched_fms.get(tens):
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001275 curr_lrs.append(lr)
1276 base_mem_usage[lr.start_time : lr.end_time + 1] -= lr.size
1277 break
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001278 competing_lrs = []
Johan Alfvén5c309712022-06-10 15:40:58 +02001279 competing_tens_access = {}
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001280
1281 # Evict live ranges that will never fit
1282 for lr in curr_lrs.copy():
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001283 base_usage = max(base_mem_usage[lr.start_time : lr.end_time + 1])
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001284 if base_usage + lr.size > staging_limit:
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001285 # Lr will never fit and may thus be evicted
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001286 self.evicted_fms.append(lr)
1287 FastStorageComponentAllocator.evict(lr, max_mem_usage, self.scratched_fms)
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001288 curr_lrs.remove(lr)
1289
1290 # Keep live ranges that will always fit in fast storage and let the remaining ones compete
1291 for lr in curr_lrs:
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001292 # Since max_mem_usage is the memory usage with all FMs still in fast-storage,
1293 # the memory limit cannot be exceeded if max_mem_usage does not.
1294 # Thus, the affected lrs can remain in fast-storage if the following is true
1295 if max(max_mem_usage[lr.start_time : lr.end_time + 1]) <= staging_limit:
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001296 FastStorageComponentAllocator.keep(lr, base_mem_usage)
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001297 else:
1298 competing_lrs.append(lr)
Johan Alfvén5c309712022-06-10 15:40:58 +02001299 for tens in lr.tensors:
1300 competing_tens_access[tens] = 0
1301
Johan Alfvén3a6325f2022-10-07 18:03:48 +02001302 # All lrs and their tensors have been handled if competing_lrs is empty, so we may return
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001303 if len(competing_lrs) == 0:
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001304 return
1305
Johan Alfvén5c309712022-06-10 15:40:58 +02001306 # Estimate element access for all tensors that are competing for a place in fast-storage.
1307 # This number is used when deciding which tensors stay in fast-storage
1308 for sched_op in self.sched_ops:
1309 cost = schedule.cost_map[sched_op]
1310
1311 if competing_tens_access.get(sched_op.ifm.connection.parent_tens) is not None:
1312 tens = sched_op.ifm.connection.parent_tens
1313 access = self.estimate_element_access(sched_op, cost.block_config, sched_op.ofm.shape.depth)
1314 competing_tens_access[tens] += access.ifm_read[0]
1315
1316 if sched_op.ifm2 and competing_tens_access.get(sched_op.ifm2.connection.parent_tens) is not None:
1317 tens = sched_op.ifm2.connection.parent_tens
1318 access = self.estimate_element_access(sched_op, cost.block_config, sched_op.ofm.shape.depth)
1319 competing_tens_access[tens] += access.ifm_read[1]
1320
1321 if competing_tens_access.get(sched_op.ofm.connection.parent_tens) is not None:
1322 tens = sched_op.ofm.connection.parent_tens
1323 access = self.estimate_element_access(sched_op, cost.block_config, sched_op.ofm.shape.depth)
1324 competing_tens_access[tens] += access.ofm_write
1325
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001326 # Sort live ranges "from left to right" on the time axis to simplify checking overlapping ranges
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001327 competing_lrs = sorted(competing_lrs, key=lambda lr: (lr.start_time, lr.end_time + 1, lr.size))
Johan Alfvén3a6325f2022-10-07 18:03:48 +02001328
1329 # Remove lrs whose live range is too long compared to the others.
1330 # They cause problems for the HillClimb Allocator when it has to
1331 # change the allocation indices in order to fit all the allocations into SRAM.
1332 # This problem only occurs in larger networks with complex graphs.
1333 #
1334 # Limit the number of items that allocate_component works with to at most MAX_EXHAUSTIVE_ITEMS
1335 # at a time; too many items give an excessively long compilation time.
1336 #
1337 # "Too long" is currently decided to be (based on experience from analyzing many networks):
1338 # Compare the lr at position i with the lr at position i + MAX_EXHAUSTIVE_ITEMS.
1339 # If the end times differ by at least MAX_EXHAUSTIVE_LIFE_RANGE then do not include the lr at position i.
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001340 if len(competing_lrs) > FastStorageComponentAllocator.MAX_EXHAUSTIVE_ITEMS:
1341 # Create a copy of the original list to iterate over because the original version is modified in-loop
Johan Alfvén3a6325f2022-10-07 18:03:48 +02001342 competing_lrs_copy = competing_lrs.copy()
1343 for i, lr in enumerate(competing_lrs_copy):
1344 lr_time = lr.end_time - lr.start_time
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001345 # Only check live ranges longer than MAX_EXHAUSTIVE_LIFE_RANGE
1346 if lr_time >= FastStorageComponentAllocator.MAX_EXHAUSTIVE_LIFE_RANGE:
1347 # Compare current lr with lr at position i + MAX_EXHAUSTIVE_ITEMS
1348 cmp_pos = min(i + FastStorageComponentAllocator.MAX_EXHAUSTIVE_ITEMS, len(competing_lrs) - 1)
Johan Alfvén3a6325f2022-10-07 18:03:48 +02001349
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001350 # Compare end times plus a margin of MAX_EXHAUSTIVE_LIFE_RANGE
1351 if (
1352 lr.end_time
1353 > competing_lrs_copy[cmp_pos].end_time + FastStorageComponentAllocator.MAX_EXHAUSTIVE_LIFE_RANGE
1354 ):
1355 # Current lr live time stands out, remove it. No use adding it to the
1356 # evicted_fms list since the lr should not be included in the fast storage allocation
1357 FastStorageComponentAllocator.evict(lr, max_mem_usage, self.scratched_fms)
1358 competing_lrs.remove(lr)
Johan Alfvén3a6325f2022-10-07 18:03:48 +02001359
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001360 # Split competing live ranges into components by finding disconnected groups of live ranges or components of
1361 # max size MAX_EXHAUSTIVE_ITEMS
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001362 start = 0
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001363 end_time = competing_lrs[0].end_time
1364 component_allocator = FastStorageComponentAllocator(base_mem_usage, max_mem_usage, staging_limit)
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001365 component_ranges = []
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001366 for i, lr in enumerate(competing_lrs):
Johan Alfvén3a6325f2022-10-07 18:03:48 +02001367 nbr_items = i - start
1368 if lr.start_time <= end_time and (nbr_items < FastStorageComponentAllocator.MAX_EXHAUSTIVE_ITEMS):
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001369 end_time = max(end_time, lr.end_time)
1370 else:
Johan Alfvén3a6325f2022-10-07 18:03:48 +02001371 # Number of items reached the maximum, or the current lr's start time
1372 # does not overlap with the previous lr's end time
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001373 component_ranges.append((start, i))
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001374 start = i
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001375 end_time = lr.end_time
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001376 component_ranges.append((start, len(competing_lrs)))
1377
1378 # Allocate each component separately
1379 for start, end in component_ranges:
1380 component_allocator.allocate_component(
1381 competing_lrs[start:end],
1382 max_mem_usage,
1383 base_mem_usage,
1384 self.scratched_fms,
1385 competing_tens_access,
1386 self.evicted_fms,
1387 )
1388 assert max(max_mem_usage) <= staging_limit, "Allocation exceeds staging limit"
Tim Halld8339a72021-05-27 18:49:40 +01001389
Tim Halld8339a72021-05-27 18:49:40 +01001390 def print_schedule(self, schedule: Schedule):
1391 print(f"Schedule: '{schedule.name}'")
1392 for sched_op in self.sched_ops:
1393 if sched_op not in schedule.cost_map:
1394 # Sub-schedule printing
1395 continue
1396
1397 op_info = schedule.cost_map[sched_op]
1398 print(f"\t{sched_op.index}: Operation {sched_op.name} - OFM {sched_op.ofm.shape}")
1399 print(f"\t\tType: {sched_op.op_type}")
1400 print(f"\t\tKernel: {sched_op.kernel}")
1401 print(f"{op_info}")
1402 mem_usage = (
1403 schedule.memory_snapshot[op_info.time_index]
1404 if op_info.time_index < len(schedule.memory_snapshot)
1405 else 0
1406 )
1407 print(f"\t\tSRAM Used: {mem_usage} bytes")
1408
Jonas Ohlsson25e700c2022-03-04 14:58:56 +01001409 print("\tCascades:")
Tim Halld8339a72021-05-27 18:49:40 +01001410 for i, cascade in enumerate(schedule.cascades.values()):
1411 print(f"\t\t{i}: {cascade.start} -> {cascade.end}, size: {cascade.mem_usage}")
Patrik Gustavssonfeeb06d2020-04-22 12:53:47 +02001412
Andreas Nevalainen27d36f02020-11-19 11:27:50 +01001413
Johan Alfven6e281af2023-02-28 09:03:03 +01001414def _update_memory_snapshot_for_all_npu_graphs(nng: Graph, arch: ArchitectureFeatures, schedulers):
1415 mem_area = arch.fast_storage_mem_area
1416 mem_type_set = set((MemType.Scratch, MemType.Scratch_fast))
1417
1418 # Collect live ranges for the full graph
1419 # extract_live_ranges_from_cascaded_passes will start from the root sg and
1420 # all sub graphs/cascaded passes will be visited and the correct time_index
1421 # will be set for all the tensors.
1422 lr_graph = live_range.LiveRangeGraph()
1423 live_range.extract_live_ranges_from_cascaded_passes(
1424 nng.get_root_subgraph(),
1425 mem_area,
1426 mem_type_set,
1427 lr_graph,
1428 Tensor.AllocationQuantum,
1429 )
1430 # Populate time-array with memory used by live ranges
1431 temporal_usage = lr_graph.get_temporal_memory_usage(arch.fast_storage_mem_area)
1432
1433 # Update snapshot for all the npu sub graphs
1434 # No longer needed by the scheduler, but npu_performance
1435 # uses this information so it must be in the correct state
1436 for sg in schedulers:
1437 sg.schedule.memory_snapshot = temporal_usage
1438 sg.schedule.fast_storage_peak_usage = max(temporal_usage, default=0)
1439
1440
Tim Halld8339a72021-05-27 18:49:40 +01001441def _update_tensor_allocation(nng: Graph, arch: ArchitectureFeatures, options):
1442 """
1443 Creates live ranges and runs tensor allocator for the current schedule
1444 (i.e. sg.schedule for all subgraphs), returns the maximum memory usage
1445 and updates SchedulerOpInfo.mem_usage for all operations in the schedule.
1446 """
1447 root_sg = nng.get_root_subgraph()
1448
1449 alloc_list = []
1450 if arch.is_spilling_enabled():
1451 mem_alloc_scratch_fast = (arch.fast_storage_mem_area, set((MemType.Scratch_fast,)))
1452 mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,)))
1453 # Order is important
1454 alloc_list.append(mem_alloc_scratch_fast)
1455 alloc_list.append(mem_alloc_scratch)
1456 else:
1457 mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))
1458 alloc_list.append(mem_alloc_scratch)
1459
1460 for mem_area, mem_type_set in alloc_list:
1461 tensor_allocation.allocate_tensors(
1462 nng,
1463 root_sg,
1464 arch,
1465 mem_area,
1466 mem_type_set,
1467 tensor_allocator=options.tensor_allocator,
1468 verbose_allocation=options.verbose_allocation,
1469 cpu_tensor_alignment=options.cpu_tensor_alignment,
Tim Hallcda4fcb2022-05-19 12:36:58 +01001470 hillclimb_max_iterations=options.hillclimb_max_iterations,
Tim Halld8339a72021-05-27 18:49:40 +01001471 )
1472
1473
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001474class FastStorageComponentAllocator:
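"""Exhaustively searches for the best subset of a component's competing live ranges to keep
in fast storage within the staging limit, favouring the tensors with the most element accesses."""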
Johan Alfvén5c309712022-06-10 15:40:58 +02001475 MAX_EXHAUSTIVE_LIFE_RANGE = 20
Johan Alfvén3a6325f2022-10-07 18:03:48 +02001476 MAX_EXHAUSTIVE_ITEMS = 20
Johan Alfvén5c309712022-06-10 15:40:58 +02001477
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001478 def __init__(self, base_mem_usage, max_mem_usage, staging_limit):
1479 self.base_mem_usage = base_mem_usage
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001480 self.max_mem_usage = max_mem_usage
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001481 self.staging_limit = staging_limit
1482 self.lrs = []
1483 self.evicted = []
1484 self.curr_evicted = []
1485 self.remaining_total_size = []
Johan Alfvén5c309712022-06-10 15:40:58 +02001486 self.best_score = 0
1487 self.competing_tens_access = {}
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001488
Johan Alfvén5c309712022-06-10 15:40:58 +02001489 def allocate_exhaustive(self, ix, score):
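# Depth-first search that, for each lr, tries keeping it in fast storage (when it fits)
# and evicting it (when it does not always fit). The best-scoring feasible assignment
# is recorded in self.evicted.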
1490 # Favour tensors with highest element access (score)
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001491 if ix >= self.num_lrs:
Johan Alfvén5c309712022-06-10 15:40:58 +02001492 if score > self.best_score:
1493 self.best_score = score
Louis Verhaard5c8f1e52022-02-23 14:13:07 +01001494 self.evicted = self.curr_evicted.copy()
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001495 return
1496
1497 lr = self.lrs[ix]
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001498
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001499 # If adding the tensor size to the base mem usage doesn't exceed the staging limit anywhere on the lr time
1500 # range, it can fit and the case where the tensor is included needs to be checked
1501 can_fit = max(self.base_mem_usage[lr.start_time : lr.end_time + 1]) + lr.size <= self.staging_limit
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001502 if can_fit:
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001503 # Tensor can fit, add tensor element access to the score and check case where tensor is included
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001504 self.curr_evicted[ix] = False
1505 self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, True)
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001506 self.allocate_exhaustive(ix + 1, score + self.competing_tens_access[lr.tensors[0]])
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001507 self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, False)
1508
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001509 # If the max mem usage doesn't exceed the staging limit anywhere on the lr time range, it always fits and the
1510 # case where the tensor is not included can be skipped
1511 always_fits = max(self.max_mem_usage[lr.start_time : lr.end_time + 1]) <= self.staging_limit
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001512 if not always_fits:
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001513 # Tensor doesn't always fit, check case when tensor is not included
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001514 self.curr_evicted[ix] = True
1515 self.max_mem_usage = self.update_mem_usage(self.max_mem_usage, lr, False)
Johan Alfvén5c309712022-06-10 15:40:58 +02001516 self.allocate_exhaustive(ix + 1, score)
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001517 self.max_mem_usage = self.update_mem_usage(self.max_mem_usage, lr, True)
1518
1519 @staticmethod
1520 def update_mem_usage(mem_usage, lr, increase):
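# Add (increase=True) or subtract (increase=False) the lr size over its whole live time range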
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001521 size = lr.size if increase else -lr.size
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001522 for t in range(lr.start_time, lr.end_time + 1):
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001523 mem_usage[t] += size
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001524 return mem_usage
1525
1526 @staticmethod
1527 def evict(lr, max_mem_usage, scratched_fms):
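# Remove the lr from the max memory usage and restore the original mem area/type of its tensors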
1528 for t in range(lr.start_time, lr.end_time + 1):
1529 max_mem_usage[t] -= lr.size
1530 for tens in lr.tensors:
1531 if tens in scratched_fms:
1532 tens.mem_area = scratched_fms[tens][0]
1533 tens.mem_type = scratched_fms[tens][1]
1534
1535 @staticmethod
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001536 def keep(lr, base_mem_usage):
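# Commit the lr to fast storage by adding its size to the base memory usage over its live time range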
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001537 for t in range(lr.start_time, lr.end_time + 1):
1538 base_mem_usage[t] += lr.size
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001539
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001540 def allocate_component(self, lrs, max_mem, min_mem, scratched_fms, competing_tens_access, evicted_fms):
1541 self.lrs = lrs
1542 self.num_lrs = len(lrs)
1543 self.evicted = [0] * self.num_lrs
1544 self.curr_evicted = [0] * self.num_lrs
1545 self.best_score = -1
1546 self.competing_tens_access = competing_tens_access
Johan Alfvén5c309712022-06-10 15:40:58 +02001547 # Recursively evaluate all combinations of keeping or evicting the lrs found in the component.
1548 # For every combination that fits within the staging_limit a score is calculated.
1549 # The combination with the highest score is then chosen. The score is calculated
1550 # as the sum of the actual element accesses (ifm reads and ofm writes) for all the
1551 # included tensors, so it is not necessarily the tensor with the biggest size that ends up
1552 # being included in the result.
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001553 self.allocate_exhaustive(0, 0)
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001554 # Optimal allocation has been found, move lrs accordingly
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001555 for i, lr in enumerate(self.lrs):
1556 if self.evicted[i]:
1557 self.evict(lr, max_mem, scratched_fms)
1558 if lr not in evicted_fms:
1559 evicted_fms.append(lr)
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001560 else:
Rickard Bolin1bd20f22022-12-07 08:54:53 +00001561 self.keep(lr, min_mem)
1562 if lr in evicted_fms:
1563 evicted_fms.remove(lr)
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001564
1565
Tim Halld8339a72021-05-27 18:49:40 +01001566def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_options: SchedulerOptions):
1567 """Entry point for the Scheduler"""
1568 # Initialize CPU subgraphs
1569 schedulers = dict()
1570 # Initialize schedulers with max schedule. Only schedule NPU subgraphs
Andreas Nevalainen27d36f02020-11-19 11:27:50 +01001571 for sg in nng.subgraphs:
Tim Halld8339a72021-05-27 18:49:40 +01001572 if sg.placement != PassPlacement.Npu:
1573 # Create cascaded passes for CPU Ops
1574 cascaded_passes = []
1575 for idx, ps in enumerate(sg.passes):
1576 cps = CascadedPass(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001577 ps.name,
1578 SchedulingStrategy.WeightStream,
1579 ps.inputs,
1580 [],
1581 ps.outputs,
1582 [ps],
1583 ps.placement,
1584 False,
Tim Halld8339a72021-05-27 18:49:40 +01001585 )
Andreas Nevalainen27d36f02020-11-19 11:27:50 +01001586
Tim Halld8339a72021-05-27 18:49:40 +01001587 cps.time = idx
1588 ps.cascade = cps
1589 cascaded_passes.append(cps)
Andreas Nevalainen27d36f02020-11-19 11:27:50 +01001590
Tim Halld8339a72021-05-27 18:49:40 +01001591 sg.cascaded_passes = cascaded_passes
1592 else:
1593 # Npu subgraph - create schedule
1594 scheduler = Scheduler(nng, sg, arch, scheduler_options)
1595 schedulers[sg] = scheduler
Andreas Nevalainen897cc142020-10-28 15:42:08 +01001596
Tim Halld8339a72021-05-27 18:49:40 +01001597 scheduler.create_scheduler_representation(arch)
1598 sg.sched_ops = scheduler.sched_ops
Andreas Nevalainen897cc142020-10-28 15:42:08 +01001599
Tim Halld8339a72021-05-27 18:49:40 +01001600 # Create the Max schedule template
1601 max_schedule_template = scheduler.create_initial_schedule()
1602 scheduler.max_schedule = max_schedule_template
Andreas Nevalainen897cc142020-10-28 15:42:08 +01001603
Tim Halld8339a72021-05-27 18:49:40 +01001604 # Create the optimised Max schedule
1605 sg.schedule = max_schedule_template
1606 scheduler.update_op_memory_snapshot(max_schedule_template)
Tim Hall789e6f32021-06-17 17:02:31 +01001607 opt_max_schedule = scheduler.propose_schedule_buffering(max_schedule_template, 1 << 32)
Tim Halld8339a72021-05-27 18:49:40 +01001608 sg.schedule = opt_max_schedule
1609 scheduler.update_op_memory_snapshot(opt_max_schedule)
Andreas Nevalainen897cc142020-10-28 15:42:08 +01001610
Tim Halld8339a72021-05-27 18:49:40 +01001611 # Create Min schedule
1612 min_schedule = scheduler.propose_minimal_schedule()
1613 initial_sram_limit = scheduler_options.optimization_sram_limit
1614 if scheduler_options.optimization_strategy == OptimizationStrategy.Size:
1615 initial_sram_limit = scheduler.min_memory_req
Andreas Nevalainen897cc142020-10-28 15:42:08 +01001616
Johan Alfvén255dad72022-07-16 18:27:05 +02001617 # Build cascades for Min schedule
1618 scheduler.build_cascades_for_min_schedule(min_schedule, max_schedule_template, initial_sram_limit)
Tim Halld8339a72021-05-27 18:49:40 +01001619 sg.schedule = min_schedule
1620 scheduler.update_op_memory_snapshot(min_schedule)
Andreas Nevalainen897cc142020-10-28 15:42:08 +01001621
Tim Halld8339a72021-05-27 18:49:40 +01001622 if scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
1623 # Create an optimized schedule
1624 sg.schedule = scheduler.optimize_schedule(
1625 min_schedule, opt_max_schedule, max_schedule_template, scheduler_options
1626 )
1627 scheduler.update_op_memory_snapshot(sg.schedule)
Andreas Nevalainen897cc142020-10-28 15:42:08 +01001628
Tim Halld8339a72021-05-27 18:49:40 +01001629 scheduler.apply_schedule(sg.schedule)
1630 scheduler.use_fast_storage_for_feature_maps(sg.schedule, scheduler_options.optimization_sram_limit)
Andreas Nevalainen897cc142020-10-28 15:42:08 +01001631
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001632 if scheduler_options.optimization_strategy == OptimizationStrategy.Performance and scheduler.evicted_fms:
1633 # It might be possible to gain performance by reducing
1634 # weight buffer size and instead fit fms in fast storage
1635 scheduler.optimize_weight_buffering_size(min_schedule, scheduler_options)
1636
Tim Halld8339a72021-05-27 18:49:40 +01001637 if scheduler_options.verbose_schedule:
1638 scheduler.print_schedule(sg.schedule)
Tim Hall79d07d22020-04-27 18:20:16 +01001639
Johan Alfven6e281af2023-02-28 09:03:03 +01001640 # Make a full live range calculation starting from the root sg
1641 _update_memory_snapshot_for_all_npu_graphs(nng, arch, schedulers)
1642
Tim Halld8339a72021-05-27 18:49:40 +01001643 # Evaluate schedule
1644 _update_tensor_allocation(nng, arch, options)