Tim Halld8339a72021-05-27 18:49:40 +01001# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Halld8339a72021-05-27 18:49:40 +010016#
Tim Hall79d07d22020-04-27 18:20:16 +010017# Description:
Tim Halld8339a72021-05-27 18:49:40 +010018# The scheduler creates and searches for an optimal plan for the network, selecting block configurations and
19# subdivisions for the Operators
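# In outline: the Scheduler below builds a graph of SchedulerOperations, creates an
# unconstrained "MAX" schedule and a minimally striped "MIN" schedule, and then optimizes
# between the two (cascading, striping and weight buffering) towards the SRAM target.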
Jonas Ohlsson845e2322022-03-01 12:39:55 +010020# For class-name forward references in the type annotations (see PEP 563).
21from __future__ import annotations
22
Diego Russoea6111a2020-04-14 18:41:58 +010023import copy
Johan Alfvén5e0ae552022-02-09 21:20:10 +010024from collections import namedtuple
Tim Halld8339a72021-05-27 18:49:40 +010025from enum import auto
26from enum import IntEnum
Johan Alfvén6f4cb032022-05-05 08:42:46 +020027from typing import Any
Tim Halld8339a72021-05-27 18:49:40 +010028from typing import Dict
29from typing import List
30from typing import Optional
31from typing import Tuple
Jonas Ohlsson845e2322022-03-01 12:39:55 +010032from typing import TYPE_CHECKING
33
34# Import needed for Type annotations. Only import for Type checking to avoid run-time errors due to cyclic import.
35if TYPE_CHECKING:
36 from .npu_performance import CycleCost
Diego Russoea6111a2020-04-14 18:41:58 +010037
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +010038import numpy as np
39
Diego Russoea6111a2020-04-14 18:41:58 +010040from . import live_range
Tim Hall79d07d22020-04-27 18:20:16 +010041from . import npu_performance
Tim Halld8339a72021-05-27 18:49:40 +010042from . import tensor_allocation
43from . import weight_compressor
44from .architecture_allocator import ArchitectureBlockConfig
45from .architecture_allocator import find_block_config
46from .architecture_allocator import get_ifm_area_required
Tim Halld8339a72021-05-27 18:49:40 +010047from .architecture_features import ArchitectureFeatures
48from .architecture_features import Block
49from .cascade_builder import CascadeBuilder
50from .cascade_builder import CascadeInfo
Fredrik Svedberg880e7352020-08-25 11:31:47 +020051from .data_type import DataType
Diego Russoe8a10452020-04-21 17:39:10 +010052from .nn_graph import CascadedPass
Tim Halld8339a72021-05-27 18:49:40 +010053from .nn_graph import Graph
54from .nn_graph import Pass
Diego Russoe8a10452020-04-21 17:39:10 +010055from .nn_graph import PassPlacement
Diego Russoe8a10452020-04-21 17:39:10 +010056from .nn_graph import SchedulingStrategy
Tim Halld8339a72021-05-27 18:49:40 +010057from .nn_graph import Subgraph
58from .numeric_util import round_down
59from .numeric_util import round_up
Diego Russoe8a10452020-04-21 17:39:10 +010060from .operation import NpuBlockType
Louis Verhaardaee5d752020-09-30 09:01:52 +020061from .operation import Op
Tim Halld8339a72021-05-27 18:49:40 +010062from .shape4d import Shape4D
Diego Russoe8a10452020-04-21 17:39:10 +010063from .tensor import MemArea
Patrik Gustavssoneca2e952020-05-27 09:15:11 +020064from .tensor import MemType
Tim Halld8339a72021-05-27 18:49:40 +010065from .tensor import Tensor
Diego Russoe8a10452020-04-21 17:39:10 +010066from .tensor import TensorFormat
67from .tensor import TensorPurpose
68from .tensor import TensorSubPurpose
Jonas Ohlsson845e2322022-03-01 12:39:55 +010069from .weight_compressor import NpuWeightTensor
Jacob Bohlin1a666972020-09-11 10:04:15 +020070
Tim Hall79d07d22020-04-27 18:20:16 +010071
Tim Halld8339a72021-05-27 18:49:40 +010072def shape_for_format(shape: Shape4D, tensor_format: TensorFormat) -> Shape4D:
73 if tensor_format == TensorFormat.NHCWB16:
74 return shape.with_depth(round_up(shape.depth, 16))
75
76 return shape
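# Illustrative behaviour (hypothetical shapes): for NHCWB16 the depth is rounded up to the
# next multiple of 16, e.g. Shape4D(1, 8, 8, 20) -> Shape4D(1, 8, 8, 32); NHWC shapes are
# returned unchanged.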
77
78
79class OptimizationStrategy(IntEnum):
80 """Enum defining the different optimization strategies for the Scheduler"""
81
82 Size = auto()
83 Performance = auto()
Tim Hall79d07d22020-04-27 18:20:16 +010084
85 def __str__(self):
86 return self.name
87
88
Tim Halld8339a72021-05-27 18:49:40 +010089class SchedulerOpInfo:
90 """Contains metadata about a SchedulerOperation that is unique to one Schedule"""
91
Tim Hall79d07d22020-04-27 18:20:16 +010092 def __init__(
93 self,
Tim Halld8339a72021-05-27 18:49:40 +010094 block_config: ArchitectureBlockConfig,
95 weights_size: int,
96 stripe_input: Shape4D,
97 stripe_input2: Optional[Shape4D],
98 stripe: Shape4D,
Tim Hall79d07d22020-04-27 18:20:16 +010099 ):
Tim Halld8339a72021-05-27 18:49:40 +0100100 self.block_config = block_config
101 self.weights_size = weights_size
102 self.stripe_input = stripe_input
103 self.stripe_input2 = stripe_input2
104 self.stripe = stripe
105 self.cascade = 0 # Assigned by CascadeBuilder. 0 means not part of a cascade
106 self.time_index = None # Set by update_op_memory_snapshot
107 self.ofm_depth_slices: List[int] = [0, stripe.depth]
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100108 self.npu_weights_tensor: Optional[NpuWeightTensor] = None
109 self.npu_scales_tensor: Optional[NpuWeightTensor] = None
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000110 self.buffered_weight_tensors: List[Tensor] = []
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100111 self.cycles: Optional[CycleCost] = None
Tim Halld8339a72021-05-27 18:49:40 +0100112 self.slack_buffering_cycles = 0
113 self.slack_buffering_memory = 0
114 self.full_weight_transfer_cycles = 0
115
116 def copy(self):
Jonas Ohlssond8575072022-03-30 10:30:25 +0200117 res = SchedulerOpInfo(
118 self.block_config,
119 self.weights_size,
120 self.stripe_input,
121 self.stripe_input2,
122 self.stripe,
123 )
Tim Halld8339a72021-05-27 18:49:40 +0100124 res.cascade = self.cascade
125 return res
126
127 def __str__(self):
128 res = f"\t\tBlock Config = {self.block_config}\n"
129 res += f"\t\tOFM Block = {self.block_config.ofm_block}\n"
130 res += f"\t\tIFM Stripe = {self.stripe_input}\n"
131 res += f"\t\tIFM2 Stripe = {self.stripe_input2}\n"
132 res += f"\t\tOFM Stripe = {self.stripe}\n"
133 res += f"\t\tEncoded Weights = {self.npu_weights_tensor and len(self.npu_weights_tensor.buffer)} bytes\n"
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000134 for idx, tens in enumerate(self.buffered_weight_tensors):
135 res += f"\t\tWeight buffer{idx + 1} = {tens.storage_size()} bytes\n"
Tim Halld8339a72021-05-27 18:49:40 +0100136 res += f"\t\tDepth slices = {self.ofm_depth_slices}\n"
137 res += f"\t\tAssigned Cascade = {self.cascade}"
138 return res
139
140
141class SchedulerOptions:
142 """Contains options for the Scheduler"""
143
144 def __init__(
Jonas Ohlssond8575072022-03-30 10:30:25 +0200145 self,
146 optimization_strategy,
147 sram_target,
148 verbose_schedule,
Tim Halld8339a72021-05-27 18:49:40 +0100149 ):
150 self.optimization_strategy = optimization_strategy
151 self.optimization_sram_limit = sram_target
Tim Hall79d07d22020-04-27 18:20:16 +0100152 self.verbose_schedule = verbose_schedule
Tim Hall79d07d22020-04-27 18:20:16 +0100153
Tim Halld8339a72021-05-27 18:49:40 +0100154 def __str__(self) -> str:
155 return f"{type(self).__name__}: {str(self.__dict__)}"
Tim Hall79d07d22020-04-27 18:20:16 +0100156
157 __repr__ = __str__
158
159
Tim Halld8339a72021-05-27 18:49:40 +0100160class SchedulerTensor:
161 def __init__(self, shape, dt, mem_area, _format):
162 self.dtype = dt
163 self.mem_area = mem_area
164 self.shape = shape
165 self.format = _format
166 self.connection = None
Tim Hall79d07d22020-04-27 18:20:16 +0100167
Tim Halld8339a72021-05-27 18:49:40 +0100168
169class SchedulerOperation:
170 """Scheduler internal representation of 'Operation'
171 This class can be seen as a node within the Scheduler Graph representation
172 """
173
174 def __init__(self, ps: Pass, arch: ArchitectureFeatures, nng: Graph):
175 self.arch = arch
176 self.parent_ps = ps
177 self.parent_op = ps.primary_op
178 self.name = ps.primary_op.name
179 self.op_type = ps.primary_op.type
180 self.activation = ps.primary_op.activation
181 self.kernel = ps.primary_op.kernel
Tim Hall3c5cfe92022-03-16 16:31:57 +0000182 self.resampling_mode = ps.primary_op.ifm_resampling_mode
Tim Halld8339a72021-05-27 18:49:40 +0100183 self.uses_scalar = ps.primary_op.ifm2 is not None and (
184 ps.primary_op.ifm.shape == [] or ps.primary_op.ifm2.shape == []
Tim Hall79d07d22020-04-27 18:20:16 +0100185 )
Tim Halld8339a72021-05-27 18:49:40 +0100186 self.ifm_ublock = arch.ifm_ublock
Tim Hall79d07d22020-04-27 18:20:16 +0100187
Jonas Ohlssond8575072022-03-30 10:30:25 +0200188 self.ifm = SchedulerTensor(
189 ps.ifm_shapes[0],
190 ps.ifm_tensor.dtype,
191 ps.ifm_tensor.mem_area,
192 ps.ifm_tensor.format,
193 )
Tim Hall79d07d22020-04-27 18:20:16 +0100194
Tim Halld8339a72021-05-27 18:49:40 +0100195 self.ifm2 = None
196 if ps.ifm2_tensor:
197 self.ifm2 = SchedulerTensor(
Jonas Ohlssond8575072022-03-30 10:30:25 +0200198 ps.ifm_shapes[1],
199 ps.ifm2_tensor.dtype,
200 ps.ifm2_tensor.mem_area,
201 ps.ifm2_tensor.format,
Tim Halld8339a72021-05-27 18:49:40 +0100202 )
Tim Hall79d07d22020-04-27 18:20:16 +0100203
Jonas Ohlssond8575072022-03-30 10:30:25 +0200204 self.ofm = SchedulerTensor(
205 ps.ofm_shapes[0],
206 ps.ofm_tensor.dtype,
207 ps.ofm_tensor.mem_area,
208 ps.ofm_tensor.format,
209 )
Tim Hall79d07d22020-04-27 18:20:16 +0100210
Tim Halld8339a72021-05-27 18:49:40 +0100211 # Input volume width and height required to produce the smallest possible stripe
212 self.min_stripe_input_w, self.min_stripe_input_h = self._calculate_min_stripe_input()
Tim Hall79d07d22020-04-27 18:20:16 +0100213
Tim Halld8339a72021-05-27 18:49:40 +0100214        # Flags that mark whether this SchedulerOperation requires full IFM/OFM
215 self.requires_full_ifm = False
216 self.requires_full_ifm2 = False
217 self.requires_full_ofm = False
Tim Hall79d07d22020-04-27 18:20:16 +0100218
Johan Alfvén6f4cb032022-05-05 08:42:46 +0200219 self.evicted_fms_size = 0
220
Tim Halld8339a72021-05-27 18:49:40 +0100221 self.index = 0
Tim Hall79d07d22020-04-27 18:20:16 +0100222
Tim Halld8339a72021-05-27 18:49:40 +0100223 def add_ifm_connection(self, conn: "Connection"):
224 """Add input connection to another SchedulerOperation or Subgraph Input"""
225 conn.consumers.append(self)
226 self.ifm.connection = conn
Tim Hall79d07d22020-04-27 18:20:16 +0100227
Tim Halld8339a72021-05-27 18:49:40 +0100228 def add_ifm2_connection(self, conn: "Connection"):
229 """Add input connection to another SchedulerOperation or Subgraph Input"""
230 if self.ifm2:
231 conn.consumers.append(self)
232 self.ifm2.connection = conn
Tim Hall79d07d22020-04-27 18:20:16 +0100233 else:
Tim Halld8339a72021-05-27 18:49:40 +0100234 assert False, f"Trying to set an IFM2 Connection to {self} which has no IFM2"
Tim Hall79d07d22020-04-27 18:20:16 +0100235
Tim Halld8339a72021-05-27 18:49:40 +0100236 def add_ofm_connection(self, conn: "Connection"):
237 """Add output connection to another SchedulerOperation or Subgraph Output"""
238 conn.producers.append(self)
239 self.ofm.connection = conn
240
241 def get_dependants(self):
242 """Returns a list of the Ops that depend on this Operation's OFM"""
243 return self.ofm.connection.consumers
244
245 def ifm_size_in_bytes(self) -> int:
246 """Returns size of the IFM in bytes"""
247 ifm_storage_shape = shape_for_format(self.ifm.shape, self.ifm.format)
248 return round_up(ifm_storage_shape.elements() * self.ifm.dtype.size_in_bytes(), Tensor.AllocationQuantum)
249
250 def ifm2_size_in_bytes(self) -> int:
251 """Returns size of the IFM2 in bytes"""
252 if self.ifm2:
253 ifm2_storage_shape = shape_for_format(self.ifm2.shape, self.ifm2.format)
254 return round_up(ifm2_storage_shape.elements() * self.ifm2.dtype.size_in_bytes(), Tensor.AllocationQuantum)
255
256 return 0
257
258 def ofm_size_in_bytes(self) -> int:
259 """Returns size of the OFM in bytes"""
260 ofm_storage_shape = shape_for_format(self.ofm.shape, self.ofm.format)
261 return round_up(ofm_storage_shape.elements() * self.ofm.dtype.size_in_bytes(), Tensor.AllocationQuantum)
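        # Worked example (hypothetical, assuming a 16-byte Tensor.AllocationQuantum): an int8
        # NHCWB16 feature map of Shape4D(1, 16, 16, 20) is stored with its depth rounded to 32,
        # i.e. 1 * 16 * 16 * 32 = 8192 bytes, already a multiple of the allocation quantum.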
262
263 def create_scheduler_info(self, nng: Graph, stripe: Shape4D) -> SchedulerOpInfo:
264 """Returns schedule info about this SchedulerOperation based on how many ofm elements it should produce"""
265 ifm_shape = self.ifm.shape
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100266 ifm2_shape = self.ifm2.shape if self.ifm2 is not None else None
Tim Halld8339a72021-05-27 18:49:40 +0100267 ofm_shape = stripe
268
269 if ofm_shape != self.ofm.shape:
270 # Striped Op - Need to calculate stripe input volume
271 stripe_input_w, stripe_input_h = self._get_stripe_input_requirement(stripe)
272 # Ensure stripe input volume is within the full IFM volume
273 stripe_input_h = min(stripe_input_h, self.ifm.shape.height)
274 stripe_input_w = min(stripe_input_w, self.ifm.shape.width)
275 ifm_shape = ifm_shape.with_hw(stripe_input_h, stripe_input_w)
276
277 if self.ifm2:
278 stripe_input2_h = min(stripe_input_h, self.ifm2.shape.height)
279 stripe_input2_w = min(stripe_input_w, self.ifm2.shape.width)
280 ifm2_shape = ifm2_shape.with_hw(stripe_input2_h, stripe_input2_w)
281
282 block_config = self._get_block_config(ifm_shape, ifm2_shape, self.uses_scalar, ofm_shape)
283
284 scheduler_op_info = SchedulerOpInfo(block_config, 0, ifm_shape, ifm2_shape, ofm_shape)
285 if self.parent_op.weights:
286 # Default full-depth weight encoding with no buffering
Tim Halld784af72021-06-08 21:25:57 +0100287 (
288 scheduler_op_info.npu_weights_tensor,
289 scheduler_op_info.npu_scales_tensor,
290 ) = weight_compressor.encode_weight_and_scale_tensor(
Tim Halld8339a72021-05-27 18:49:40 +0100291 self.arch,
292 self.parent_op,
293 self.parent_op.weights,
294 self.parent_op.bias,
295 self.kernel,
296 block_config,
297 [0, self.ofm.shape.depth],
298 )
299
300 self.parent_ps.block_config = block_config.old_style_representation()
301 return scheduler_op_info
302
303 def _get_stripe_input_requirement(self, stripe_shape: Shape4D) -> Tuple[int, int]:
304 """Returns the amount of IFM required to produce the stripe with shape:'stripe_shape'"""
305 ofm_shape_to_produce = Block.from_shape(stripe_shape.as_list())
306
Fredrik Svedberg3ff7a4a2021-09-29 10:08:04 +0200307 return get_ifm_area_required(ofm_shape_to_produce, self.kernel, self.resampling_mode)
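        # For example (assuming stride 1, dilation 1 and no IFM upscaling), a 3x3 kernel needs
        # a 3x3 IFM area per OFM element, so a 1x1 OFM stripe requires a 3x3 input area.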
Tim Halld8339a72021-05-27 18:49:40 +0100308
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100309 def _calculate_min_stripe_input(self) -> Tuple[int, int]:
Tim Halld8339a72021-05-27 18:49:40 +0100310 # Calculate the input volume required height and width for the smallest possible stripe (h,w = 1,1)
311 min_stripe = self.ofm.shape.with_hw(1, 1)
312 return self._get_stripe_input_requirement(min_stripe)
313
314 def _get_block_config(
315 self, ifm_shape: Shape4D, ifm2_shape: Optional[Shape4D], uses_scalar: bool, ofm_shape: Shape4D
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100316 ) -> Optional[ArchitectureBlockConfig]:
Tim Halld8339a72021-05-27 18:49:40 +0100317 # Returns a block config and SHRAM layout
318 lut_banks = 2 if self.parent_op.activation_lut else 0
319 return find_block_config(
320 self.arch,
321 self.op_type.npu_block_type,
322 ofm_shape,
323 ifm_shape,
324 ifm2_shape,
325 uses_scalar,
326 self.ifm.dtype.size_in_bits(),
327 self.kernel,
328 lut_banks,
329 self.parent_op.has_scaling(),
330 self.resampling_mode,
331 )
332
333
334class Connection:
335 """Scheduler internal representation of a Tensor that connects two SchedulerOperations
336 This class can be seen as an edge within the Scheduler Graph representation
337 """
338
339 def __init__(self, tensor: Tensor):
340 self.parent_tens = tensor
341
342 # SchedulerOperation relationships
343 self.producers: List[SchedulerOperation] = []
344 self.consumers: List[SchedulerOperation] = []
Tim Hall79d07d22020-04-27 18:20:16 +0100345
346 def __str__(self):
Tim Halld8339a72021-05-27 18:49:40 +0100347 return f"<Connection {self.parent_tens.name}>"
Tim Hall79d07d22020-04-27 18:20:16 +0100348
349 __repr__ = __str__
350
351
Tim Halld8339a72021-05-27 18:49:40 +0100352class Schedule:
353 """Class that contains a solution of how to schedule an NPU subgraph and its cost"""
Tim Hall79d07d22020-04-27 18:20:16 +0100354
Tim Halld8339a72021-05-27 18:49:40 +0100355 def __init__(self, sg: Subgraph, label: str):
356 self.sg = sg
357 self.label = label
358 self.cost_map: Dict[SchedulerOperation, SchedulerOpInfo] = {}
359 self.cascades: Dict[int, CascadeInfo] = {}
360 self.fast_storage_peak_usage = 0
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100361 self.memory_snapshot: Optional[List[int]] = None
Tim Halld8339a72021-05-27 18:49:40 +0100362
363 @property
364 def name(self):
365 return f"{self.sg.name}_{self.label}"
Tim Hall79d07d22020-04-27 18:20:16 +0100366
367
Tim Halld8339a72021-05-27 18:49:40 +0100368class Scheduler:
369 """Main class of the Vela Scheduling"""
Tim Hall79d07d22020-04-27 18:20:16 +0100370
Tim Halld8339a72021-05-27 18:49:40 +0100371 def __init__(self, nng: Graph, sg: Subgraph, arch: ArchitectureFeatures, options: SchedulerOptions):
Tim Hall79d07d22020-04-27 18:20:16 +0100372 self.nng = nng
373 self.sg = sg
374 self.arch = arch
Ayaan Masoodb801dda2022-02-22 11:28:55 +0000375 self.sched_ops: List[SchedulerOperation] = []
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100376 self.max_schedule: Optional[Schedule] = None
Tim Halld8339a72021-05-27 18:49:40 +0100377 self.scheduler_options = options
Tim Hall79d07d22020-04-27 18:20:16 +0100378
Johan Alfvén6f4cb032022-05-05 08:42:46 +0200379 self.scratched_fms: Dict[Tensor, Any] = {}
380 self.evicted_fms: List[live_range.LiveRange] = []
381
Johan Alfvén5e0ae552022-02-09 21:20:10 +0100382 def avoid_nhcwb16_for_ofm(self, tens, ps, arch):
383 # Only run this check for opt strategy Size
384 if self.scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
385 return False
386
387 op = ps.primary_op
388 if not op.type.is_elementwise_op():
389 return False
390
391 depth = op.ofm_shapes[0][-1]
392 if (depth % 16) == 0:
393 return False
394
395 # Check if overwriting the inputs can be allowed
396 OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
397 outp = OpShapeTens(op.ofm_shapes[0], op.ofm)
398 inps = []
399 if op.ifm is not None:
400 inps.append(OpShapeTens(op.ifm_shapes[0], op.ifm))
401 if op.ifm2 is not None:
402 inps.append(OpShapeTens(op.ifm_shapes[1], op.ifm2))
403
404 # Find an input tensor that can be overwritten by the output
405 for inp in inps:
406 if (
407 # check op input and output shapes allow overlapping
408 inp.op_shape == outp.op_shape
409 # check input tensor is valid
410 and inp.tens is not None
411 and inp.tens.shape != []
412 # check input and output tensors are compatible
413 and inp.tens.format == outp.tens.format
414 and inp.tens.dtype == outp.tens.dtype
415 ):
416 if inp.tens.format == TensorFormat.NHWC:
417 return True
418
419 return False
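    # Illustrative case (hypothetical): under the Size strategy, an elementwise Add with OFM
    # depth 20 (not a multiple of 16) whose NHWC IFM matches the OFM shape, format and dtype
    # keeps its OFM linear, so the output can overwrite the input in place.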
420
Tim Halld8339a72021-05-27 18:49:40 +0100421 def create_scheduler_representation(self, arch: ArchitectureFeatures):
422 """Creates a Scheduler Graph representation"""
423 # Temporary dict for creating connections between the Operations
424 connections: Dict[Tensor, Connection] = {}
425 # Memory required for the largest FeatureMap that has to be full
426 min_memory_req = 0
Tim Hall79d07d22020-04-27 18:20:16 +0100427 for ps in self.sg.passes:
Tim Halld8339a72021-05-27 18:49:40 +0100428 if ps.primary_op:
429 # Set tensor format to NHCWB16 for output FeatureMaps, if possible
Louis Verhaard0b9c9a32020-09-15 14:05:38 +0200430 for output in ps.outputs:
Jacob Bohlina5e8c1c2021-06-14 13:33:39 +0200431 if output in self.sg.output_tensors or output.purpose != TensorPurpose.FeatureMap:
Patrik Gustavssonfeeb06d2020-04-22 12:53:47 +0200432 continue
Johan Alfvén5e0ae552022-02-09 21:20:10 +0100433
434 if output.needs_linear_format:
435 continue
436
437 if self.avoid_nhcwb16_for_ofm(output, ps, arch):
438 output.needs_linear_format = True
439 continue
440
441 output.set_format(TensorFormat.NHCWB16, arch)
Tim Halld8339a72021-05-27 18:49:40 +0100442
443 # Create SchedulerOperations
444 op = SchedulerOperation(ps, arch, self.nng)
445 op.index = len(self.sched_ops)
446
447 # Make connections
448 if ps.ifm_tensor not in connections:
449 connections[ps.ifm_tensor] = Connection(ps.ifm_tensor)
450 if ps.ifm2_tensor and ps.ifm2_tensor not in connections:
451 connections[ps.ifm2_tensor] = Connection(ps.ifm2_tensor)
452 if ps.ofm_tensor not in connections:
453 connections[ps.ofm_tensor] = Connection(ps.ofm_tensor)
454
455 op.add_ifm_connection(connections[ps.ifm_tensor])
456 if ps.ifm2_tensor:
457 op.add_ifm2_connection(connections[ps.ifm2_tensor])
458 op.add_ofm_connection(connections[ps.ofm_tensor])
459
460 # Set requirements on the ifm/ofm buffers
461 self.sched_ops.append(op)
462 if ps.ifm_tensor in self.sg.input_tensors:
463 # This Op consumes a subgraph input
464 op.requires_full_ifm = True
465 if ps.ifm2_tensor and ps.ifm2_tensor in self.sg.input_tensors:
466 # This Op consumes a subgraph input
467 op.requires_full_ifm2 = True
468 if ps.ofm_tensor in self.sg.output_tensors:
469 # This Op produces a subgraph output
470 op.requires_full_ofm = True
471 if ps.ifm_tensor.needs_linear_format:
472 op.requires_full_ifm = True
473 if ps.ifm2_tensor and ps.ifm2_tensor.needs_linear_format:
474 op.requires_full_ifm2 = True
475 if ps.ofm_tensor.needs_linear_format or ps.primary_op.memory_function == Op.ConcatSliceWrite:
476 op.requires_full_ofm = True
477 if len(ps.primary_op.outputs) > 1 or len(ps.primary_op.outputs[0].consumer_list) > 1:
478 # Op has multiple outputs or consumers - requires full OFM
479 op.requires_full_ofm = True
480
481 # Check memory requirements if this Op requires any full FeatureMaps
482 op_memory_req = 0
483 if op.requires_full_ifm:
484 op_memory_req += op.ifm_size_in_bytes()
485 if op.requires_full_ifm2:
486 op_memory_req += op.ifm2_size_in_bytes()
487 if op.requires_full_ofm:
488 op_memory_req += op.ofm_size_in_bytes()
489
490 min_memory_req = max(op_memory_req, min_memory_req)
491
492 # Theoretical minimum required memory - used to guide the cascade building
493 self.min_memory_req = min_memory_req
494
495 def create_initial_schedule(self) -> Schedule:
496 """Creates an initial schedule with no cascading or buffering of any kind"""
497 schedule = Schedule(self.sg, "MAX")
Tim Halld8339a72021-05-27 18:49:40 +0100498 for op in self.sched_ops:
499 cost = op.create_scheduler_info(self.nng, op.ofm.shape)
500 cost.cycles = self.estimate_op_performance(op, cost.block_config, op.ofm.shape.depth)
501 schedule.cost_map[op] = cost
502
503 return schedule
504
505 def update_op_memory_snapshot(self, schedule: Schedule):
506 memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
507
508 # Collect live ranges from tensors
509 lr_graph = live_range.LiveRangeGraph()
510 for mem_area, mem_type_set in memories_list:
511 live_range.extract_live_ranges_from_cascaded_passes(
Jonas Ohlssond8575072022-03-30 10:30:25 +0200512 self.nng.get_root_subgraph(),
513 mem_area,
514 mem_type_set,
515 lr_graph,
516 Tensor.AllocationQuantum,
Tim Halld8339a72021-05-27 18:49:40 +0100517 )
518
519 # Populate time-array with memory used by live ranges
520 temporal_usage = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area)
521 schedule.memory_snapshot = temporal_usage
522
523 # Set the peak memory usage
524 schedule.fast_storage_peak_usage = max(temporal_usage, default=0)
525
526 def estimate_op_performance(self, op: SchedulerOperation, block_config, ofm_depth):
527 query = npu_performance.PerformanceQuery(op.op_type.npu_block_type)
528 query.ifm_shape = op.ifm.shape
529 query.ifm_memory_area = op.ifm.mem_area
530 query.ifm_bits = op.ifm.dtype.size_in_bits()
531 query.ifm_format = op.ifm.format
532 query.ifm2_shape = op.ifm2 and op.ifm2.shape
533 query.ifm2_memory_area = op.ifm2 and op.ifm2.mem_area
534 query.ifm2_bits = op.ifm2 and op.ifm2.dtype.size_in_bits()
535 query.ifm2_format = op.ifm2 and op.ifm2.format
536 query.ofm_shape = op.ofm.shape.with_depth(ofm_depth)
537 query.ofm_memory_area = op.ofm.mem_area
538 query.ofm_bits = op.ofm.dtype.size_in_bits()
539 query.ofm_format = op.ofm.format
540 if op.parent_op.bias:
541 query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth)
542 query.const_memory_area = self.arch.fast_storage_mem_area
543
544 query.kernel = op.kernel
545 query.config = block_config
546
547 return npu_performance.measure_cycle_cost(self.arch, op.op_type, op.activation and op.activation.op_type, query)
548
Tim Hall789e6f32021-06-17 17:02:31 +0100549 def propose_schedule_buffering(self, ref_schedule: Schedule, staging_limit_bytes):
Tim Halld8339a72021-05-27 18:49:40 +0100550 """Create a buffered schedule"""
551 buffered_schedule = Schedule(self.sg, f"{ref_schedule.label}_BUFFERED")
Tim Halld8339a72021-05-27 18:49:40 +0100552
553 prev_op = None
554 for sched_op in self.sched_ops:
555 if sched_op not in ref_schedule.cost_map:
556 # sched_op is not part of this sub-schedule - skip
557 continue
558
559 self.propose_operator_buffering(sched_op, prev_op, buffered_schedule, ref_schedule, staging_limit_bytes)
560 prev_op = sched_op
561
562 return buffered_schedule
563
564 def propose_operator_buffering(
565 self,
566 sched_op: SchedulerOperation,
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100567 prev_op: Optional[SchedulerOperation],
Tim Halld8339a72021-05-27 18:49:40 +0100568 buffered_schedule: Schedule,
569 ref_schedule: Schedule,
570 staging_limit_bytes,
571 ):
572 # Mild recursion might mean this Op has already been seen
573 if sched_op in buffered_schedule.cost_map:
574 return
575
576 # Take the reference schedule as default costings for this schedule
577 ref_cost = ref_schedule.cost_map[sched_op]
578 cost = copy.copy(ref_cost)
579 cost.slack_buffering_cycles = ref_cost.cycles.op_cycles
580 memory_snapshot = ref_schedule.memory_snapshot
581 ref_memory_usage = memory_snapshot[ref_cost.time_index] if ref_cost.time_index < len(memory_snapshot) else 0
582 cost.slack_buffering_memory = staging_limit_bytes - ref_memory_usage
583 buffered_schedule.cost_map[sched_op] = cost
584
585 # Attempt weight buffering on anything with a weights tensor
586 if sched_op.parent_op.weights:
Johan Alfvén6f4cb032022-05-05 08:42:46 +0200587 buffer_limit_bytes = cost.slack_buffering_memory
588
589            # If applicable, apply a size limitation, but keep it within reason (ratio 1.5).
590            # The size limitation is used when use_fast_storage_for_feature_maps has
591            # detected that there are fms that do not fit in fast storage.
592 if sched_op.evicted_fms_size and ((buffer_limit_bytes / sched_op.evicted_fms_size) >= 1.5):
593 buffer_limit_bytes -= sched_op.evicted_fms_size
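            # Hypothetical numbers: with 600 kB of slack and 300 kB of evicted feature maps the
            # ratio is 2.0 (>= 1.5), so the buffer limit is reduced to 300 kB; with 500 kB of
            # evicted feature maps the ratio is only 1.2 and no reduction is applied.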
594
Tim Halld8339a72021-05-27 18:49:40 +0100595 self.propose_weight_buffering(
596 sched_op.parent_op.weights,
597 sched_op.parent_op.bias,
598 sched_op,
599 prev_op,
600 buffered_schedule,
601 ref_schedule,
Johan Alfvén6f4cb032022-05-05 08:42:46 +0200602 buffer_limit_bytes,
Tim Halld8339a72021-05-27 18:49:40 +0100603 )
604
605 return cost
606
607 def weights_needs_dma(self, weight_tensor):
608 if weight_tensor and weight_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
609 # Weights are in permanent storage
610            # Only when permanent storage differs from feature map storage is there a point in moving the data
611 if (
612 weight_tensor.mem_area in (MemArea.Dram, MemArea.OffChipFlash)
613 and self.arch.permanent_storage_mem_area != self.arch.fast_storage_mem_area
614 ):
615 return True
616 return False
617
618 def propose_weight_buffering(
619 self,
620 weight_tensor,
621 scale_tensor,
622 sched_op: SchedulerOperation,
623 prev_op: SchedulerOperation,
624 buffered_schedule: Schedule,
625 ref_schedule: Schedule,
626 buffer_limit_bytes,
627 ):
628 cost = buffered_schedule.cost_map[sched_op]
629 prev_cost = buffered_schedule.cost_map.get(prev_op)
630 ref_cost = ref_schedule.cost_map[sched_op]
631 assert cost and ref_cost
632
633 needs_dma = self.weights_needs_dma(weight_tensor)
634
635 ofm_full_depth_slices = [0, ref_cost.stripe.depth]
636
637 # Encode weights for the full depth
Tim Halld784af72021-06-08 21:25:57 +0100638 full_weights, full_scales = weight_compressor.encode_weight_and_scale_tensor(
Tim Halld8339a72021-05-27 18:49:40 +0100639 self.arch,
640 sched_op.parent_op,
641 weight_tensor,
642 scale_tensor,
643 sched_op.kernel,
644 cost.block_config,
645 ofm_full_depth_slices,
646 )
647 full_weights_bytes = len(full_weights.buffer)
648 cost.ofm_depth_slices = ofm_full_depth_slices
649
650 # No buffering required - take all the weights from permanent storage
651 if sched_op.op_type == Op.FullyConnected or not needs_dma:
652 cost.npu_weights_tensor = full_weights
Tim Halld784af72021-06-08 21:25:57 +0100653 cost.npu_scales_tensor = full_scales
Tim Halld8339a72021-05-27 18:49:40 +0100654 return
655
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100656 encoded_weights: Optional[NpuWeightTensor] = full_weights
Tim Halld784af72021-06-08 21:25:57 +0100657 encoded_scales = full_scales
Tim Halld8339a72021-05-27 18:49:40 +0100658
659 # How many NPU cycles are available under the previously executing
660        # operator, and how much SRAM is unused, for performing buffered DMA transfers
661 slack_cycles = prev_cost.slack_buffering_cycles if prev_cost else 0
662 slack_memory = prev_cost.slack_buffering_memory if prev_cost else 0
663
664 # Force full depth for cascaded Ops
665 if ref_cost.cascade != 0:
666 weight_tensor_purpose = TensorSubPurpose.Standard
667 weight_buffer_size = full_weights_bytes
668 # Update the memory snapshot to reflect the added size of the weights
669 ref_schedule.memory_snapshot[ref_cost.time_index] += weight_buffer_size
670 else:
671 # Estimate the buffering cycle time for the full set of weights
672 full_transfer_cycles = npu_performance.measure_mem2mem_cycles(
673 self.arch, weight_tensor.mem_area, self.arch.fast_storage_mem_area, full_weights_bytes
674 )
675 cost.full_weight_transfer_cycles = full_transfer_cycles
676
677 # Calculate the amount of prebuffering necessary (or what is possible with limited
678            # double-buffer size)
679 half_buffer_limit = buffer_limit_bytes // 2
680 if full_transfer_cycles > slack_cycles:
681 prebuffer_ratio = slack_cycles / full_transfer_cycles
682 prebuffer_bytes = min(prebuffer_ratio * full_weights_bytes, half_buffer_limit)
683 else:
684 prebuffer_bytes = min(full_weights_bytes, half_buffer_limit)
Tim Hall789e6f32021-06-17 17:02:31 +0100685
686 prebuffer_ratio = prebuffer_bytes / full_weights_bytes
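            # Hypothetical example: a 4000-cycle full weight transfer with only 1000 slack
            # cycles gives prebuffer_ratio 0.25; for 64 kB of encoded weights and a 32 kB
            # half-buffer limit this prebuffers min(0.25 * 64 kB, 32 kB) = 16 kB up front.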
Tim Halld8339a72021-05-27 18:49:40 +0100687
688 # Have to split the weights if the initial buffering can't store
689 # all of the compressed weights
690 if prebuffer_bytes < full_weights_bytes:
Tim Hall789e6f32021-06-17 17:02:31 +0100691 block_depth = cost.block_config.ofm_block.depth
Tim Halld8339a72021-05-27 18:49:40 +0100692
Tim Hall789e6f32021-06-17 17:02:31 +0100693 # Choose initial prebuffering depth (already buffer clamped)
694 prebuffer_depth = ref_cost.stripe.depth * prebuffer_ratio
Tim Halld8339a72021-05-27 18:49:40 +0100695 prebuffer_depth = int(max(16, round_down(prebuffer_depth, ArchitectureFeatures.OFMSplitDepth)))
696
Tim Hall789e6f32021-06-17 17:02:31 +0100697 # Calculate cycles executed during the prebuffer
698 pre_op_cycles = self.estimate_op_performance(sched_op, cost.block_config, prebuffer_depth)
699 buffering_depth = ref_cost.stripe.depth * (pre_op_cycles.op_cycles / full_transfer_cycles)
Tim Halld8339a72021-05-27 18:49:40 +0100700
Tim Hall789e6f32021-06-17 17:02:31 +0100701 # Choose initial buffering depth and clamp to the double buffering limit
702 buffering_depth = round_up(buffering_depth, block_depth)
703 buffering_bytes = (buffering_depth / ref_cost.stripe.depth) * full_weights_bytes
704 if buffering_bytes > half_buffer_limit:
705 buffering_depth = (half_buffer_limit / full_weights_bytes) * ref_cost.stripe.depth
706
707 while True:
708 # Attempt to buffer whole blocks
Johan Alfvéncce7f2d2022-04-08 10:47:09 +0200709 if buffering_depth > block_depth:
Tim Hall789e6f32021-06-17 17:02:31 +0100710 buffering_depth = round_down(buffering_depth, block_depth)
711 else:
712 buffering_depth = round_down(buffering_depth, ArchitectureFeatures.OFMSplitDepth)
713 buffering_depth = int(max(buffering_depth, ArchitectureFeatures.OFMSplitDepth))
Tim Halld8339a72021-05-27 18:49:40 +0100714
715 # Create list of depth slices
716 depth_slices = [0]
717 if prebuffer_depth < ref_cost.stripe.depth:
718 depth_slices += list(range(prebuffer_depth, ref_cost.stripe.depth, buffering_depth))
719 depth_slices.append(ref_cost.stripe.depth)
720
721                # Encode weights based on the depth slices
722 cost.ofm_depth_slices = depth_slices
Tim Halld784af72021-06-08 21:25:57 +0100723 encoded_weights, encoded_scales = weight_compressor.encode_weight_and_scale_tensor(
Tim Halld8339a72021-05-27 18:49:40 +0100724 self.arch,
725 sched_op.parent_op,
726 weight_tensor,
727 scale_tensor,
728 sched_op.kernel,
729 cost.block_config,
730 cost.ofm_depth_slices,
731 )
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100732 assert encoded_weights is not None
Tim Halld8339a72021-05-27 18:49:40 +0100733 # Chosen buffering might not fit at all, iterate until it does
734 # or until the minimum usable slice size is reached
735 if (
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000736 encoded_weights.double_buffer_size() <= buffer_limit_bytes
Tim Halld8339a72021-05-27 18:49:40 +0100737 or prebuffer_depth == ArchitectureFeatures.OFMSplitDepth
738 ):
739 break
740
Tim Hall789e6f32021-06-17 17:02:31 +0100741 if buffering_depth > prebuffer_depth:
742 buffering_depth = round_up(buffering_depth // 2, ArchitectureFeatures.OFMSplitDepth)
743 else:
744 prebuffer_depth = round_up(prebuffer_depth // 2, ArchitectureFeatures.OFMSplitDepth)
Tim Halld8339a72021-05-27 18:49:40 +0100745
746 # Calculate cycles required to run the last op for use as future slack
747 tail_cycles = self.estimate_op_performance(
748 sched_op, cost.block_config, depth_slices[-1] - depth_slices[-2]
749 )
750 cost.slack_buffering_cycles = tail_cycles.op_cycles
751
752 # Determine whether the weights need to be double buffered
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000753 weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes())
Tim Halld8339a72021-05-27 18:49:40 +0100754
755 # Only buffer weights if there's still space left for the buffer
756 if weight_buffer_size <= buffer_limit_bytes:
757 assert weight_buffer_size % 16 == 0
758 # Determine whether to double buffer or single buffer
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000759 double_buffer_size = encoded_weights.double_buffer_size()
760 if (double_buffer_size <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)):
Tim Halld8339a72021-05-27 18:49:40 +0100761 weight_tensor_purpose = TensorSubPurpose.DoubleBuffer
762 else:
763 weight_tensor_purpose = TensorSubPurpose.Standard
764
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000765 cost.buffered_weight_tensors = [
766 self.buffer_tensor(
767 encoded_weights,
768 weight_tensor_purpose,
769 encoded_weights.double_buffer_sizes[0],
770 weight_tensor.name + "_buffer",
771 )
772 ]
773 if weight_tensor_purpose == TensorSubPurpose.DoubleBuffer:
774 buf2 = self.buffer_tensor(
775 encoded_weights,
776 weight_tensor_purpose,
777 encoded_weights.double_buffer_sizes[1],
778 weight_tensor.name + "_buffer2",
779 )
780 cost.buffered_weight_tensors.append(buf2)
781
782 last_used_buffer_idx = len(cost.ofm_depth_slices) % len(cost.buffered_weight_tensors)
783 weight_buffer_size = encoded_weights.double_buffer_sizes[last_used_buffer_idx]
784
Tim Halld8339a72021-05-27 18:49:40 +0100785 if ref_cost.cascade == 0:
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000786 # Determine if the lifetime can be extended and pre-buffer the first weight buffer
787 # under the previous operation
788 cost.buffered_weight_tensors[0].pre_buffer = encoded_weights.double_buffer_size() < slack_memory
Tim Halld8339a72021-05-27 18:49:40 +0100789
790 cost.slack_buffering_memory -= weight_buffer_size
791 else:
792 # Don't slice or buffer - use the whole depth from persistent storage
793 cost.ofm_depth_slices = ofm_full_depth_slices
794 encoded_weights = full_weights
Tim Halld784af72021-06-08 21:25:57 +0100795 encoded_scales = full_scales
Tim Halld8339a72021-05-27 18:49:40 +0100796
797 cost.npu_weights_tensor = encoded_weights
Tim Halld784af72021-06-08 21:25:57 +0100798 cost.npu_scales_tensor = encoded_scales
Tim Halld8339a72021-05-27 18:49:40 +0100799
Jacob Bohlineee9e5d2021-08-17 17:44:45 +0200800 def buffer_tensor(self, src_tensor: Tensor, sub_purpose: TensorSubPurpose, buffer_size: int, name: str) -> Tensor:
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000801 buffered_weight_tensor = Tensor([1, 1, 1, buffer_size], DataType.uint8, name)
Jacob Bohlineee9e5d2021-08-17 17:44:45 +0200802 buffered_weight_tensor.src_tensor = src_tensor
803 buffered_weight_tensor.mem_area = self.arch.fast_storage_mem_area
804 buffered_weight_tensor.mem_type = MemType.Scratch_fast
805 buffered_weight_tensor.purpose = TensorPurpose.Weights
806 buffered_weight_tensor.sub_purpose = sub_purpose
807 return buffered_weight_tensor
808
Tim Halld8339a72021-05-27 18:49:40 +0100809 def propose_minimal_schedule(self) -> Schedule:
810 """Proposes scheduling parameters where every operator is subdivided into the smallest stripe that satisfies the
811        next operator's stripe"""
812 min_schedule = Schedule(self.sg, "MIN")
813 cost_map = min_schedule.cost_map
814
815 # Keep track of the previous Op - which consumes the current Op's OFM
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100816 prev_op: Optional[SchedulerOperation] = None
Tim Halld8339a72021-05-27 18:49:40 +0100817 for sched_op in reversed(self.sched_ops):
818 min_stripe_height = prev_op.kernel.stride.y if prev_op else 1
819 min_stripe = sched_op.ofm.shape.with_height(min_stripe_height)
820
821 cost = sched_op.create_scheduler_info(self.nng, min_stripe)
822 cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
823 cost_map[sched_op] = cost
824
825 prev_op = sched_op
826
827 return min_schedule
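    # For example (hypothetical): if the consuming operator has a vertical stride of 2, the
    # producing operator's minimal stripe is two OFM rows high; the last operator in the
    # schedule, which has no consumer here, falls back to a one-row stripe.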
828
829 def propose_schedule_striping(self, final_stripe: Shape4D, label: str, ref_schedule: Schedule) -> Schedule:
830 """Proposes new striping for a schedule. The stripe is derived from the ifm requirements of the next Op down"""
831 ref_cost = ref_schedule.cost_map
832
833 striped_schedule = Schedule(self.sg, label)
834 stripe = final_stripe
835 for sched_op in reversed(self.sched_ops):
836 if sched_op not in ref_cost:
837 # sched_op is not part of the sub-schedule - skip
838 continue
839
840 # Create a cost entry with the new stripe
841 cost = sched_op.create_scheduler_info(self.nng, stripe)
842
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000843 weight_tensor = cost.npu_weights_tensor
844 for idx, buffered_tens in enumerate(ref_cost[sched_op].buffered_weight_tensors):
Jacob Bohlineee9e5d2021-08-17 17:44:45 +0200845 # If the weights are buffered in the reference schedule they should be in the new proposal
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000846 cost.buffered_weight_tensors.append(
847 self.buffer_tensor(
848 weight_tensor,
849 buffered_tens.sub_purpose,
850 weight_tensor.double_buffer_sizes[idx],
851 buffered_tens.name,
852 )
Jacob Bohlineee9e5d2021-08-17 17:44:45 +0200853 )
Tim Halld8339a72021-05-27 18:49:40 +0100854
855 # Estimate performance
856 cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
857 striped_schedule.cost_map[sched_op] = cost
858
859            # Calculate the preceding Op's stripe
860 stripe = sched_op.ifm.shape.with_height(stripe.height * sched_op.kernel.stride.y)
861
862 return striped_schedule
863
864 def estimate_schedule_memory_usage(self, schedule: Schedule, non_local_mem_usage: dict):
865 """Estimates the memory usage of a schedule"""
866 cost = schedule.cost_map
867 cascades = schedule.cascades
868 peak_mem_usage = 0
869 for sched_op in self.sched_ops:
870 if sched_op not in cost:
871 # sched_op is not part of the sub-schedule - skip
872 continue
873
874 if cost[sched_op].cascade:
875 # This Op is part of a cascade - use the cascade's memory usage
876 cascade_info = cascades[cost[sched_op].cascade]
877 # Non-local memory usage is already included in the cascade_info
878 peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage)
879 else:
880 # This Op is not part of a cascade - calculate the memory usage
Rickard Bolinfd8b5002022-05-16 09:11:06 +0000881 op_weight_buffer = sum(tens.storage_size() for tens in cost[sched_op].buffered_weight_tensors)
Tim Halld8339a72021-05-27 18:49:40 +0100882
883 op_mem_usage = (
884 sched_op.ifm_size_in_bytes()
885 + sched_op.ofm_size_in_bytes()
886 + op_weight_buffer
887 + non_local_mem_usage.get(sched_op, 0)
888 )
889 peak_mem_usage = max(op_mem_usage, peak_mem_usage)
890
891 return peak_mem_usage
892
893 def optimize_sub_schedule(
894 self, cascade_info: CascadeInfo, ref_schedule: Schedule, max_template: Schedule, memory_limit: int
895 ) -> Schedule:
896 """Extracts the Ops covered by the given cascade and creates a sub-schedule. The sub-schedule is optimized by
897        proposing weight buffering and then continuously proposing new stripe sizes"""
898 ref_cost = ref_schedule.cost_map
899 # Extract the ops that are part of this sub-schedule
900 start = cascade_info.start
901 end = cascade_info.end
902 sub_schedule_ops = self.sched_ops[start : end + 1]
903 # Create a sub-schedule that contains only the costs for the Ops that are part of the sub-schedule
904 sub_schedule = Schedule(self.sg, f"SUB_{start}_{end}")
905 for sched_op in sub_schedule_ops:
906 sub_schedule.cost_map[sched_op] = ref_cost[sched_op]
907
908 sub_schedule.cascades[end] = cascade_info
909 # Use the memory snapshot from the reference schedule
910 sub_schedule.memory_snapshot = ref_schedule.memory_snapshot
911
912 # Calculate memory usage that is live during the sub-schedule but not part of it
913 time_for_cascade = ref_cost[sub_schedule_ops[0]].time_index
914 mem_usage_parallel_to_sub_schedule = ref_schedule.memory_snapshot[time_for_cascade] - cascade_info.mem_usage
915 # If the first Op's IFM has other consumers it has to live throughout the whole sub-schedule whether it's
916 # included in a cascade or not
917 persistent_initial_ifm = (
918 sub_schedule_ops[0].ifm_size_in_bytes() if len(sub_schedule_ops[0].ifm.connection.consumers) > 1 else 0
919 )
920 # Calculate non-local-mem-usage per Operator
921 non_local_mem_usage = {}
922 for idx, sched_op in enumerate(sub_schedule_ops):
923 non_local_mem_usage[sched_op] = mem_usage_parallel_to_sub_schedule
924 if idx != 0:
925 non_local_mem_usage[sched_op] += persistent_initial_ifm
926
927 cascade_builder = CascadeBuilder(sub_schedule_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)
928
929 # Start by adding buffering
Tim Hall789e6f32021-06-17 17:02:31 +0100930 buffered_sub_schedule = self.propose_schedule_buffering(
931 sub_schedule, self.scheduler_options.optimization_sram_limit
932 )
Tim Halld8339a72021-05-27 18:49:40 +0100933 # Copy the cascades over from the unbuffered-schedule
934 buffered_sub_schedule.cascades = sub_schedule.cascades
935
936 # Generate the possible stripings for the final Op in the sub-schedule
937 final_ofm_shape = sub_schedule_ops[-1].ofm.shape
938 possible_stripes = [
939 final_ofm_shape.with_height(stripe_h) for stripe_h in range(1, final_ofm_shape.height // 2 + 1)
940 ]
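        # E.g. a final OFM height of 64 gives candidate stripe heights 1..32; each proposal
        # below halves the remaining candidates, similar to a binary search over stripe height.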
941
942 # Propose different striping - the possible stripes are proposed similarly to a binary search
Jacob Bohlinfad72042021-08-24 21:51:41 +0200943 best_schedule = None
Tim Halld8339a72021-05-27 18:49:40 +0100944 iteration = 0
945 while len(possible_stripes) > 1:
946 proposed_stripe = possible_stripes[len(possible_stripes) // 2]
947 proposed_schedule = self.propose_schedule_striping(
948 proposed_stripe, f"OPTIMIZED_{iteration}", buffered_sub_schedule
949 )
950
951 cascade_builder.build_cascades(proposed_schedule, max_template, memory_limit)
952
953 # Check if proposal fits
954 proposed_schedule_mem_usage = self.estimate_schedule_memory_usage(proposed_schedule, non_local_mem_usage)
955 if (proposed_schedule_mem_usage) <= memory_limit:
956 # Remove all possible stripes smaller than this
957 possible_stripes = possible_stripes[len(possible_stripes) // 2 :]
958 best_schedule = proposed_schedule
959 if not proposed_schedule.cascades:
960 # No cascading required - early exit
961 break
962 else:
963 # Proposal doesn't fit within the limit - remove all possible stripes larger than this
964 possible_stripes = possible_stripes[: len(possible_stripes) // 2]
965
966 iteration += 1
967
968 return best_schedule
969
970 def optimize_schedule(
Jonas Ohlssond8575072022-03-30 10:30:25 +0200971 self,
972 schedule: Schedule,
973 max_sched: Schedule,
974 max_template: Schedule,
975 options: SchedulerOptions,
Tim Halld8339a72021-05-27 18:49:40 +0100976 ) -> Schedule:
977 """Extracts sub-schedules based on the cascades and optimizes them and applies them to the final schedule"""
978 sram_limit = options.optimization_sram_limit
979 if max_sched.fast_storage_peak_usage < sram_limit and not self.arch.is_spilling_enabled():
980 # Maximum performance schedule fits within the SRAM target
981 return max_sched
982
Jacob Bohlinfad72042021-08-24 21:51:41 +0200983 # Iterate over a copy of the cascades since they may change during the loop
984 for cascade_info in list(schedule.cascades.values()):
Tim Halld8339a72021-05-27 18:49:40 +0100985 # Optimize the sub-schedule in this cascade
986 opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, sram_limit)
Jacob Bohlinfad72042021-08-24 21:51:41 +0200987 if opt_sub_schedule:
988 # Remove the existing cascade
989 del schedule.cascades[cascade_info.end]
990 # Update the sub-schedule Op and cascade costs to the full schedule
991 schedule.cost_map.update(opt_sub_schedule.cost_map)
992 schedule.cascades.update(opt_sub_schedule.cascades)
Tim Halld8339a72021-05-27 18:49:40 +0100993
994 # Update memory snapshot
995 self.sg.schedule = schedule
996 self.update_op_memory_snapshot(schedule)
997 # Propose schedule buffering to the optimized schedule
Tim Hall789e6f32021-06-17 17:02:31 +0100998 optimized_sched = self.propose_schedule_buffering(schedule, self.scheduler_options.optimization_sram_limit)
Tim Halld8339a72021-05-27 18:49:40 +0100999 # Copy the cascade's metadata from the unbuffered schedule
1000 optimized_sched.cascades = schedule.cascades
1001 return optimized_sched
1002
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001003 def optimize_weight_buffering_size(
1004 self,
1005 min_schedule: Schedule,
1006 options: SchedulerOptions,
1007 ):
1008 default_schedule = self.sg.schedule
Tim Hallc1be0872022-03-03 17:50:52 +00001009 npu_performance.calc_new_performance_for_network(self.nng, self.arch, None, False)
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001010 default_tot_cycles = self.nng.cycles[npu_performance.PassCycles.Total]
1011 default_dram_cycles = self.nng.cycles[npu_performance.PassCycles.DramAccess]
1012
1013 # Restore mem/type for scratched_fms
1014 for tens in self.scratched_fms:
1015 tens.mem_area = self.scratched_fms[tens][0]
1016 tens.mem_type = self.scratched_fms[tens][1]
1017
1018 self.update_op_memory_snapshot(self.sg.schedule)
1019
1020 # Collect live ranges from tensors
1021 memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
1022 lr_graph = live_range.LiveRangeGraph()
1023 for mem_area, mem_type_set in memories_list:
1024 live_range.extract_live_ranges_from_cascaded_passes(
1025 self.nng.get_root_subgraph(),
1026 mem_area,
1027 mem_type_set,
1028 lr_graph,
1029 Tensor.AllocationQuantum,
1030 )
1031
1032 # Find the relation between the sched_op and the buffering tensor
1033 weight_ops = {}
1034 for sched_op in self.sched_ops:
1035 cost = self.sg.schedule.cost_map[sched_op]
Rickard Bolinfd8b5002022-05-16 09:11:06 +00001036 for tens in cost.buffered_weight_tensors:
1037 weight_ops[tens] = sched_op
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001038
1039 # Filter out weight buffer live ranges
1040 weight_lrs = []
1041 for lr in lr_graph.lrs:
1042 for tens in lr.tensors:
1043 if weight_ops.get(tens):
1044 weight_lrs.append(lr)
1045 break
1046
1047 # See if any evicted fm overlaps with a weight buffering op.
1048        # If this is the case, add a size limitation to the buffering op
1049 for lr in self.evicted_fms:
1050 for weight_lr in weight_lrs:
1051 if lr.overlaps_ranges(weight_lr):
1052 for tens in weight_lr.tensors:
1053 sched_op = weight_ops.get(tens)
1054 if sched_op:
1055 # Add size reduction to the op
1056 sched_op.evicted_fms_size += lr.size
1057 break
1058
1059 self.sg.schedule = min_schedule
1060 self.update_op_memory_snapshot(self.sg.schedule)
1061
1062 # Run schedule buffering - with weight buffer size reduction
1063 schedule = self.propose_schedule_buffering(self.sg.schedule, options.optimization_sram_limit)
1064 schedule.cascades = self.sg.schedule.cascades
1065 self.sg.schedule = schedule
1066
1067        # Apply the new buffer schedule and calculate the new performance
1068 self.update_op_memory_snapshot(self.sg.schedule)
1069 self.apply_schedule(self.sg.schedule)
1070 self.use_fast_storage_for_feature_maps(self.sg.schedule, options.optimization_sram_limit)
1071
Tim Hallc1be0872022-03-03 17:50:52 +00001072 npu_performance.calc_new_performance_for_network(self.nng, self.arch, None, False)
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001073 new_tot_cycles = self.nng.cycles[npu_performance.PassCycles.Total]
1074 new_dram_cycles = self.nng.cycles[npu_performance.PassCycles.DramAccess]
1075
Tim Hall8bc7a652022-05-19 15:29:23 +01001076 improvement_tot = (
1077 round((default_tot_cycles - new_tot_cycles) / default_tot_cycles, 2) if default_tot_cycles != 0 else 0
1078 )
1079 improvement_dram = (
1080 round((default_dram_cycles - new_dram_cycles) / default_dram_cycles, 2) if default_dram_cycles != 0 else 0
1081 )
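        # Hypothetical numbers: 1_000_000 default total cycles reduced to 950_000 gives
        # improvement_tot = 0.05, and 200_000 DRAM-access cycles reduced to 150_000 gives
        # improvement_dram = 0.25; the new schedule is kept only if both checks below pass.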
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001082
1083 # Compare both total and dram improvement
Johan Alfvén3dae1b62022-05-17 10:26:48 +02001084 if not (improvement_tot >= 0 and improvement_dram > 0):
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001085 # No improvement, restore the default schedule
1086 for sched_op in self.sched_ops:
1087 sched_op.evicted_fms_size = 0
1088
1089 for tens in self.scratched_fms:
1090 tens.mem_area = self.scratched_fms[tens][0]
1091 tens.mem_type = self.scratched_fms[tens][1]
1092
1093 self.sg.schedule = default_schedule
1094 self.update_op_memory_snapshot(self.sg.schedule)
1095 self.apply_schedule(self.sg.schedule)
1096 self.use_fast_storage_for_feature_maps(self.sg.schedule, options.optimization_sram_limit)
1097
Tim Halld8339a72021-05-27 18:49:40 +01001098 def apply_schedule(self, sched: Schedule):
1099 """Applies the given schedule as a final solution"""
1100 for sched_op in self.sched_ops:
1101 op_info = sched.cost_map[sched_op]
1102 cascade_info = sched.cascades.get(op_info.cascade, None)
1103 if cascade_info and sched_op in cascade_info.buffers:
1104 buffer_tens = sched_op.ifm.connection.parent_tens
1105 # Apply memory area and type
1106 buffer_tens.mem_area = self.arch.fast_storage_mem_area
1107 buffer_tens.mem_type = MemType.Scratch_fast
1108 # Apply Rolling buffer
1109 buffer_tens.set_format(TensorFormat.NHCWB16, self.arch)
1110 buffer_tens.set_new_sub_purpose(TensorSubPurpose.RollingBufferY, cascade_info.buffers[sched_op].height)
1111
1112 sched_op.parent_ps.block_config = op_info.block_config.old_style_representation()
1113
1114 # Ensure that the src_tensor reference is set correctly
Rickard Bolinfd8b5002022-05-16 09:11:06 +00001115 for tens in op_info.buffered_weight_tensors:
1116 tens.src_tensor = op_info.npu_weights_tensor
Tim Halld8339a72021-05-27 18:49:40 +01001117
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001118 def use_fast_storage_for_feature_maps(self, schedule, staging_limit):
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001119 max_mem_usage = []
1120 base_mem_usage = []
1121 fast_storage_type = MemType.Scratch_fast
1122 fast_storage_mem_area = self.arch.fast_storage_mem_area
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001123 self.evicted_fms = []
Tim Halld8339a72021-05-27 18:49:40 +01001124
1125 # Force all OFMs to fast-storage
1126 for sched_op in self.sched_ops:
1127 cost = schedule.cost_map[sched_op]
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001128 if cost.cascade == 0 and sched_op.get_dependants():
1129 ofm_tens = sched_op.ofm.connection.parent_tens
1130 if not any(cons is None for cons in ofm_tens.consumer_list):
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001131 if ofm_tens not in self.scratched_fms:
1132 # Remember default mem area and mem type, only done once
1133 self.scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
1134
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001135 ofm_tens.mem_area = fast_storage_mem_area
1136 ofm_tens.mem_type = fast_storage_type
Tim Halld8339a72021-05-27 18:49:40 +01001137
1138 # Collect live ranges from tensors
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001139 memories_list = [(fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
Tim Halld8339a72021-05-27 18:49:40 +01001140 lr_graph = live_range.LiveRangeGraph()
1141 for mem_area, mem_type_set in memories_list:
1142 live_range.extract_live_ranges_from_cascaded_passes(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001143 self.nng.get_root_subgraph(),
1144 mem_area,
1145 mem_type_set,
1146 lr_graph,
1147 Tensor.AllocationQuantum,
Tim Halld8339a72021-05-27 18:49:40 +01001148 )
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001149 max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area)
Tim Halld8339a72021-05-27 18:49:40 +01001150
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001151 # If true, everything fits and we can proceed
1152 if max(max_mem_usage) <= staging_limit:
1153 return
1154
1155 # Build up the base memory usage by removing the
1156 # mem_usage of the lrs we previously moved to fast-storage
1157 base_mem_usage = np.array(max_mem_usage)
1158 curr_lrs = []
Tim Halld8339a72021-05-27 18:49:40 +01001159 for lr in lr_graph.lrs:
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001160 for tens in lr.tensors:
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001161 if self.scratched_fms.get(tens):
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001162 curr_lrs.append(lr)
1163 base_mem_usage[lr.start_time : lr.end_time + 1] -= lr.size
1164 break
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001165 competing_lrs = []
1166 for lr in curr_lrs:
1167 base_usage = max(base_mem_usage[lr.start_time : lr.end_time + 1])
1168 # If true, the lr will never fit and may thus be evicted
1169 if base_usage + lr.size > staging_limit:
Johan Alfvén6f4cb032022-05-05 08:42:46 +02001170 self.evicted_fms.append(lr)
1171 FastStorageComponentAllocator.evict(lr, max_mem_usage, self.scratched_fms)
erik.andersson@arm.comde6cb642022-02-02 14:03:15 +01001172 continue
1173 # Since max_mem_usage is the memory usage with all FMs still in fast-storage,
1174 # the memory limit cannot be exceeded if max_mem_usage does not.
1175 # Thus, the affected lrs can remain in fast-storage if the following is true
1176 if max(max_mem_usage[lr.start_time : lr.end_time + 1]) <= staging_limit:
1177 FastStorageComponentAllocator.keep(lr, base_mem_usage, staging_limit)
1178 else:
1179 competing_lrs.append(lr)
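            # In summary: a live range that can never fit is evicted back to its original
            # memory, one that always fits is kept in fast storage, and the remainder
            # compete for the leftover space and are allocated component by component below.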
1180 sz = len(competing_lrs)
1181        # If sz is zero, all lrs and their tensors have been handled and we may return
1182 if sz == 0:
1183 return
1184
1185 competing_lrs = sorted(competing_lrs, key=lambda lr: (lr.start_time, lr.end_time + 1, lr.size))
1186 start = 0
1187 start_time = competing_lrs[0].start_time
1188 end_time = competing_lrs[0].end_time
1189 component_allocator = FastStorageComponentAllocator(base_mem_usage, max_mem_usage, staging_limit)
1190 # Build up components and then allocate each separately
        for i, lr in enumerate(competing_lrs):
            if lr.start_time <= end_time and i - start < component_allocator.max_exhaustive_size:
                start_time = min(start_time, lr.start_time)
                end_time = max(end_time, lr.end_time)
            else:
                component_allocator.allocate_component(
                    component_allocator,
                    competing_lrs[start:i],
                    max_mem_usage,
                    base_mem_usage,
                    staging_limit,
                    self.scratched_fms,
                )
                start = i
                start_time = lr.start_time
                end_time = lr.end_time
        component_allocator.allocate_component(
            component_allocator,
            competing_lrs[start:sz],
            max_mem_usage,
            base_mem_usage,
            staging_limit,
            self.scratched_fms,
        )

    def move_constant_data(self):
        """Determine if data can be moved from permanent storage to another memory area. A move
        will generate a DMA command in the high-level command stream"""
        for sched_op in self.sched_ops:
            parent_op = sched_op.parent_op
            is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in parent_op.inputs)
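            # Half of the SHRAM banks not reserved for the OFM (or the LUT, when one is
            # used), in bytes; used below to decide whether a broadcast input is too
            # large to be held in SHRAM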
            max_ifm_shram_avail = (
                (self.arch.available_shram_banks(is_lut_used) - self.arch.shram_reserved_output_banks)
                * self.arch.shram_bank_size
                // 2
            )

            for idx, tens in enumerate(parent_op.inputs):
                if tens.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
                    # Tensor is in permanent storage
                    # Moving the data is only worthwhile if the permanent storage differs
                    # from the feature map storage, or if the tensor is a LUT
                    if (
                        tens.mem_area in self.arch.permanent_storage_mem_area
                        and self.arch.permanent_storage_mem_area != self.arch.feature_map_storage_mem_area
                    ) or tens.purpose == TensorPurpose.LUT:
                        if tens.purpose == TensorPurpose.LUT or (
                            # For elementwise broadcast
                            tens.purpose == TensorPurpose.FeatureMap
                            and sched_op.op_type.is_binary_elementwise_op()
                            and tens.shape != []
                            and sched_op.ifm.shape != sched_op.ofm.shape
                            and parent_op.write_shape is None
                            and tens.storage_size() > max_ifm_shram_avail
                        ):
                            only_vector_product_consumers = all(
                                oper and oper.type.npu_block_type == NpuBlockType.VectorProduct
                                for oper in tens.consumers()
                            )

                            if (not only_vector_product_consumers) or tens.purpose == TensorPurpose.LUT:
                                new_tens = tens.clone_into_fast_storage(self.arch)
                                if tens.purpose == TensorPurpose.LUT:
                                    new_tens.mem_area = MemArea.Shram

                                new_tens.consumer_list.append(parent_op)
                                parent_op.inputs[idx] = new_tens
                                # If the index is out of range, IFM and IFM2 are the same tensor
                                # and pass inputs don't have duplicates
                                if idx < len(sched_op.parent_ps.inputs):
                                    sched_op.parent_ps.inputs[idx] = new_tens

    def print_schedule(self, schedule: Schedule):
        print(f"Schedule: '{schedule.name}'")
        for sched_op in self.sched_ops:
            if sched_op not in schedule.cost_map:
                # Sub-schedule printing
                continue

            op_info = schedule.cost_map[sched_op]
            print(f"\t{sched_op.index}: Operation {sched_op.name} - OFM {sched_op.ofm.shape}")
            print(f"\t\tType: {sched_op.op_type}")
            print(f"\t\tKernel: {sched_op.kernel}")
            print(f"{op_info}")
            mem_usage = (
                schedule.memory_snapshot[op_info.time_index]
                if op_info.time_index < len(schedule.memory_snapshot)
                else 0
            )
            print(f"\t\tSRAM Used: {mem_usage} bytes")

        print("\tCascades:")
        for i, cascade in enumerate(schedule.cascades.values()):
            print(f"\t\t{i}: {cascade.start} -> {cascade.end}, size: {cascade.mem_usage}")


def _update_tensor_allocation(nng: Graph, arch: ArchitectureFeatures, options):
    """
    Creates live ranges and runs the tensor allocator for the current schedule
    (i.e. sg.schedule for all subgraphs), returns the maximum memory usage
    and updates SchedulerOpInfo.mem_usage for all operations in the schedule.
    """
    root_sg = nng.get_root_subgraph()

    alloc_list = []
    if arch.is_spilling_enabled():
        mem_alloc_scratch_fast = (arch.fast_storage_mem_area, set((MemType.Scratch_fast,)))
        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,)))
        # Order is important
        alloc_list.append(mem_alloc_scratch_fast)
        alloc_list.append(mem_alloc_scratch)
    else:
        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))
        alloc_list.append(mem_alloc_scratch)

    for mem_area, mem_type_set in alloc_list:
        tensor_allocation.allocate_tensors(
            nng,
            root_sg,
            arch,
            mem_area,
            mem_type_set,
            tensor_allocator=options.tensor_allocator,
            verbose_allocation=options.verbose_allocation,
            cpu_tensor_alignment=options.cpu_tensor_alignment,
        )


class FastStorageComponentAllocator:
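    """Exhaustively searches each component of competing live ranges to decide which
    feature maps stay in fast storage and which are evicted back to their default
    memory, maximizing the total size kept within the staging limit"""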
    def __init__(self, base_mem_usage, max_mem_usage, staging_limit):
        self.base_mem_usage = base_mem_usage
        self.max_mem_usage = list(max_mem_usage)
        self.staging_limit = staging_limit
        self.lrs = []
        self.evicted = []
        self.curr_evicted = []
        self.remaining_total_size = []
        self.best_allocated_size = 0
        self.max_exhaustive_size = 20

    def allocate_exhaustive(self, ix, alloc_size):
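        """Recursively tries keeping (curr_evicted[ix] = False) and evicting
        (curr_evicted[ix] = True) each lr in the component, backtracking the usage
        arrays after each branch, and records the combination that keeps the largest
        total size in self.evicted"""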
        if ix >= len(self.lrs):
            if alloc_size > self.best_allocated_size:
                self.best_allocated_size = alloc_size
                self.evicted = self.curr_evicted.copy()
            return

        lr = self.lrs[ix]
        for t in range(lr.start_time, lr.end_time):
            assert self.base_mem_usage[t] <= self.max_mem_usage[t]
        base_usage = max(self.base_mem_usage[lr.start_time : lr.end_time + 1])
        can_fit = base_usage + lr.size <= self.staging_limit
        always_fits = can_fit

        if can_fit:
            max_usage = max(self.max_mem_usage[lr.start_time : lr.end_time + 1])
            always_fits = max_usage <= self.staging_limit

        if can_fit or always_fits:
            self.curr_evicted[ix] = False
            self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, True)
            self.allocate_exhaustive(ix + 1, alloc_size + lr.size)
            self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, False)

        if not always_fits:
            self.curr_evicted[ix] = True
            self.max_mem_usage = self.update_mem_usage(self.max_mem_usage, lr, False)
            self.allocate_exhaustive(ix + 1, alloc_size)
            self.max_mem_usage = self.update_mem_usage(self.max_mem_usage, lr, True)

    @staticmethod
    def update_mem_usage(mem_usage, lr, increase):
        for t in range(lr.start_time, lr.end_time + 1):
            mem_usage[t] += lr.size if increase else -lr.size
            assert mem_usage[t] >= 0
        return mem_usage

    @staticmethod
    def evict(lr, max_mem_usage, scratched_fms):
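        # Remove the lr from the fast-storage usage and restore its tensors to the
        # default mem area and mem type recorded in scratched_fms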
        for t in range(lr.start_time, lr.end_time + 1):
            max_mem_usage[t] -= lr.size
        for tens in lr.tensors:
            if tens in scratched_fms:
                tens.mem_area = scratched_fms[tens][0]
                tens.mem_type = scratched_fms[tens][1]

    @staticmethod
    def keep(lr, base_mem_usage, staging_limit):
        for t in range(lr.start_time, lr.end_time + 1):
            base_mem_usage[t] += lr.size
            assert base_mem_usage[t] <= staging_limit

    def allocate_component(self, allocator, lrs, max_mem, min_mem, staging_limit, scratched_fms):
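        """Finds the best keep/evict combination for one component of live ranges and
        applies it by moving the evicted lrs back to their default memory"""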
        sz = len(lrs)
        allocator.lrs = lrs
        allocator.evicted = [0] * len(lrs)
        allocator.curr_evicted = [0] * sz
        allocator.best_allocated_size = -1
        # Recursively evaluate all combinations of keep/evict decisions for the lrs in the component
        allocator.allocate_exhaustive(0, 0)

        # Optimal allocation has been found, move lrs accordingly
        for i, e in enumerate(allocator.evicted):
            if e:
                self.evict(lrs[i], max_mem, scratched_fms)
            else:
                self.keep(lrs[i], min_mem, staging_limit)


def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_options: SchedulerOptions):
    """Entry point for the Scheduler"""
    # Initialize one scheduler per NPU subgraph
    schedulers = dict()
    # Create cascaded passes for CPU subgraphs and a schedule for each NPU subgraph
    for sg in nng.subgraphs:
        if sg.placement != PassPlacement.Npu:
            # Create cascaded passes for CPU Ops
            cascaded_passes = []
            for idx, ps in enumerate(sg.passes):
                cps = CascadedPass(
                    ps.name,
                    SchedulingStrategy.WeightStream,
                    ps.inputs,
                    [],
                    ps.outputs,
                    [ps],
                    ps.placement,
                    False,
                )

                cps.time = idx
                ps.cascade = cps
                cascaded_passes.append(cps)

            sg.cascaded_passes = cascaded_passes
        else:
            # Npu subgraph - create schedule
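            # The steps below: build the scheduler representation and move constant data,
            # create the Max schedule template and an optimised, buffered variant of it,
            # create the Min schedule and build cascades for it, optionally search for a
            # schedule between the two that fits the SRAM limit, then apply the schedule
            # and try to place feature maps in fast storage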
            scheduler = Scheduler(nng, sg, arch, scheduler_options)
            schedulers[sg] = scheduler

            scheduler.create_scheduler_representation(arch)
            sg.sched_ops = scheduler.sched_ops
            scheduler.move_constant_data()

            # Create the Max schedule template
            max_schedule_template = scheduler.create_initial_schedule()
            scheduler.max_schedule = max_schedule_template

            # Create the optimised Max schedule
            sg.schedule = max_schedule_template
            scheduler.update_op_memory_snapshot(max_schedule_template)
            opt_max_schedule = scheduler.propose_schedule_buffering(max_schedule_template, 1 << 32)
            sg.schedule = opt_max_schedule
            scheduler.update_op_memory_snapshot(opt_max_schedule)

            # Create Min schedule
            min_schedule = scheduler.propose_minimal_schedule()
            initial_sram_limit = scheduler_options.optimization_sram_limit
            if scheduler_options.optimization_strategy == OptimizationStrategy.Size:
                initial_sram_limit = scheduler.min_memory_req

            cascade_builder = CascadeBuilder(scheduler.sched_ops, arch.is_spilling_enabled())
            cascade_builder.build_cascades(min_schedule, max_schedule_template, initial_sram_limit)
            sg.schedule = min_schedule
            scheduler.update_op_memory_snapshot(min_schedule)

            if scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
                # Create an optimized schedule
                sg.schedule = scheduler.optimize_schedule(
                    min_schedule, opt_max_schedule, max_schedule_template, scheduler_options
                )
                scheduler.update_op_memory_snapshot(sg.schedule)

            scheduler.apply_schedule(sg.schedule)
            scheduler.use_fast_storage_for_feature_maps(sg.schedule, scheduler_options.optimization_sram_limit)

            if scheduler_options.optimization_strategy == OptimizationStrategy.Performance and scheduler.evicted_fms:
                # It might be possible to gain performance by reducing the weight buffer size
                # and instead fitting feature maps in fast storage
                scheduler.optimize_weight_buffering_size(min_schedule, scheduler_options)

            if scheduler_options.verbose_schedule:
                scheduler.print_schedule(sg.schedule)

    # Evaluate schedule
    _update_tensor_allocation(nng, arch, options)