# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# The scheduler creates and searches for an optimal plan for the network, selecting block configurations and
# subdivisions for the Operators
import copy
from collections import namedtuple
from enum import auto
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np

from . import live_range
from . import npu_performance
from . import tensor_allocation
from . import weight_compressor
from .architecture_allocator import ArchitectureBlockConfig
from .architecture_allocator import find_block_config
from .architecture_allocator import get_ifm_area_required
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
from .cascade_builder import CascadeBuilder
from .cascade_builder import CascadeInfo
from .data_type import DataType
from .nn_graph import CascadedPass
from .nn_graph import Graph
from .nn_graph import Pass
from .nn_graph import PassPlacement
from .nn_graph import SchedulingStrategy
from .nn_graph import Subgraph
from .numeric_util import round_down
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .shape4d import Shape4D
from .tensor import MemArea
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


def shape_for_format(shape: Shape4D, tensor_format: TensorFormat) -> Shape4D:
    if tensor_format == TensorFormat.NHCWB16:
        return shape.with_depth(round_up(shape.depth, 16))

    return shape
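
# Worked example of the helper above: NHCWB16 pads the depth axis to the next
# multiple of 16, so a 1x8x8x20 feature map is treated as 1x8x8x32 when storage
# sizes are computed, while an NHWC tensor keeps its depth of 20 unchanged.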


class OptimizationStrategy(IntEnum):
    """Enum defining the different optimization strategies for the Scheduler"""

    Size = auto()
    Performance = auto()

    def __str__(self):
        return self.name


class SchedulerOpInfo:
    """Contains metadata about a SchedulerOperation that is unique to one Schedule"""

    def __init__(
        self,
        block_config: ArchitectureBlockConfig,
        weights_size: int,
        stripe_input: Shape4D,
        stripe_input2: Optional[Shape4D],
        stripe: Shape4D,
    ):
        self.block_config = block_config
        self.weights_size = weights_size
        self.stripe_input = stripe_input
        self.stripe_input2 = stripe_input2
        self.stripe = stripe
        self.cascade = 0  # Assigned by CascadeBuilder. 0 means not part of a cascade
        self.time_index = None  # Set by update_op_memory_snapshot
        self.ofm_depth_slices: List[int] = [0, stripe.depth]
        self.npu_weights_tensor = None
        self.npu_scales_tensor = None
        self.buffered_weight_tensor = None
        self.cycles = None
        self.slack_buffering_cycles = 0
        self.slack_buffering_memory = 0
        self.full_weight_transfer_cycles = 0

    def copy(self):
        res = SchedulerOpInfo(self.block_config, self.weights_size, self.stripe_input, self.stripe_input2, self.stripe,)
        res.cascade = self.cascade
        return res

    def __str__(self):
        res = f"\t\tBlock Config = {self.block_config}\n"
        res += f"\t\tOFM Block = {self.block_config.ofm_block}\n"
        res += f"\t\tIFM Stripe = {self.stripe_input}\n"
        res += f"\t\tIFM2 Stripe = {self.stripe_input2}\n"
        res += f"\t\tOFM Stripe = {self.stripe}\n"
        res += f"\t\tEncoded Weights = {self.npu_weights_tensor and len(self.npu_weights_tensor.buffer)} bytes\n"
        res += (
            f"\t\tWeight buffer = {self.buffered_weight_tensor and self.buffered_weight_tensor.storage_size()} bytes\n"
        )
        res += f"\t\tDepth slices = {self.ofm_depth_slices}\n"
        res += f"\t\tAssigned Cascade = {self.cascade}"
        return res


class SchedulerOptions:
    """Contains options for the Scheduler"""

    def __init__(
        self, optimization_strategy, sram_target, verbose_schedule,
    ):
        self.optimization_strategy = optimization_strategy
        self.optimization_sram_limit = sram_target
        self.verbose_schedule = verbose_schedule

    def __str__(self) -> str:
        return f"{type(self).__name__}: {str(self.__dict__)}"

    __repr__ = __str__


class SchedulerTensor:
    def __init__(self, shape, dt, mem_area, _format):
        self.dtype = dt
        self.mem_area = mem_area
        self.shape = shape
        self.format = _format
        self.connection = None


class SchedulerOperation:
    """Scheduler internal representation of 'Operation'
    This class can be seen as a node within the Scheduler Graph representation
    """

    def __init__(self, ps: Pass, arch: ArchitectureFeatures, nng: Graph):
        self.arch = arch
        self.parent_ps = ps
        self.parent_op = ps.primary_op
        self.name = ps.primary_op.name
        self.op_type = ps.primary_op.type
        self.activation = ps.primary_op.activation
        self.kernel = ps.primary_op.kernel
        self.resampling_mode = ps.primary_op.ifm.resampling_mode
        self.uses_scalar = ps.primary_op.ifm2 is not None and (
            ps.primary_op.ifm.shape == [] or ps.primary_op.ifm2.shape == []
        )
        self.ifm_ublock = arch.ifm_ublock

        self.ifm = SchedulerTensor(ps.ifm_shapes[0], ps.ifm_tensor.dtype, ps.ifm_tensor.mem_area, ps.ifm_tensor.format,)

        self.ifm2 = None
        if ps.ifm2_tensor:
            self.ifm2 = SchedulerTensor(
                ps.ifm_shapes[1], ps.ifm2_tensor.dtype, ps.ifm2_tensor.mem_area, ps.ifm2_tensor.format,
            )

        self.ofm = SchedulerTensor(ps.ofm_shapes[0], ps.ofm_tensor.dtype, ps.ofm_tensor.mem_area, ps.ofm_tensor.format,)

        # Input volume width and height required to produce the smallest possible stripe
        self.min_stripe_input_w, self.min_stripe_input_h = self._calculate_min_stripe_input()

        # Flags that mark whether this SchedulerOperation requires full IFM/OFM
        self.requires_full_ifm = False
        self.requires_full_ifm2 = False
        self.requires_full_ofm = False

        self.index = 0

    def add_ifm_connection(self, conn: "Connection"):
        """Add input connection to another SchedulerOperation or Subgraph Input"""
        conn.consumers.append(self)
        self.ifm.connection = conn

    def add_ifm2_connection(self, conn: "Connection"):
        """Add input connection to another SchedulerOperation or Subgraph Input"""
        if self.ifm2:
            conn.consumers.append(self)
            self.ifm2.connection = conn
        else:
            assert False, f"Trying to set an IFM2 Connection to {self} which has no IFM2"

    def add_ofm_connection(self, conn: "Connection"):
        """Add output connection to another SchedulerOperation or Subgraph Output"""
        conn.producers.append(self)
        self.ofm.connection = conn

    def get_dependants(self):
        """Returns a list of the Ops that depend on this Operation's OFM"""
        return self.ofm.connection.consumers

    def ifm_size_in_bytes(self) -> int:
        """Returns size of the IFM in bytes"""
        ifm_storage_shape = shape_for_format(self.ifm.shape, self.ifm.format)
        return round_up(ifm_storage_shape.elements() * self.ifm.dtype.size_in_bytes(), Tensor.AllocationQuantum)

    def ifm2_size_in_bytes(self) -> int:
        """Returns size of the IFM2 in bytes"""
        if self.ifm2:
            ifm2_storage_shape = shape_for_format(self.ifm2.shape, self.ifm2.format)
            return round_up(ifm2_storage_shape.elements() * self.ifm2.dtype.size_in_bytes(), Tensor.AllocationQuantum)

        return 0

    def ofm_size_in_bytes(self) -> int:
        """Returns size of the OFM in bytes"""
        ofm_storage_shape = shape_for_format(self.ofm.shape, self.ofm.format)
        return round_up(ofm_storage_shape.elements() * self.ofm.dtype.size_in_bytes(), Tensor.AllocationQuantum)
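
    # Worked example for the size helpers above (informal; assumes the tensor
    # allocation quantum is 16 bytes): an int8 OFM of 1x16x16x20 stored as
    # NHCWB16 is padded to a depth of 32, i.e. 1 * 16 * 16 * 32 = 8192 one-byte
    # elements, so ofm_size_in_bytes() reports 8192 (already a multiple of 16).
    # The same tensor kept in NHWC would be 16 * 16 * 20 = 5120 bytes.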

    def create_scheduler_info(self, nng: Graph, stripe: Shape4D) -> SchedulerOpInfo:
        """Returns schedule info about this SchedulerOperation based on how many ofm elements it should produce"""
        ifm_shape = self.ifm.shape
        ifm2_shape = self.ifm2 and self.ifm2.shape
        ofm_shape = stripe

        if ofm_shape != self.ofm.shape:
            # Striped Op - Need to calculate stripe input volume
            stripe_input_w, stripe_input_h = self._get_stripe_input_requirement(stripe)
            # Ensure stripe input volume is within the full IFM volume
            stripe_input_h = min(stripe_input_h, self.ifm.shape.height)
            stripe_input_w = min(stripe_input_w, self.ifm.shape.width)
            ifm_shape = ifm_shape.with_hw(stripe_input_h, stripe_input_w)

            if self.ifm2:
                stripe_input2_h = min(stripe_input_h, self.ifm2.shape.height)
                stripe_input2_w = min(stripe_input_w, self.ifm2.shape.width)
                ifm2_shape = ifm2_shape.with_hw(stripe_input2_h, stripe_input2_w)

        block_config = self._get_block_config(ifm_shape, ifm2_shape, self.uses_scalar, ofm_shape)

        scheduler_op_info = SchedulerOpInfo(block_config, 0, ifm_shape, ifm2_shape, ofm_shape)
        if self.parent_op.weights:
            # Default full-depth weight encoding with no buffering
            (
                scheduler_op_info.npu_weights_tensor,
                scheduler_op_info.npu_scales_tensor,
            ) = weight_compressor.encode_weight_and_scale_tensor(
                self.arch,
                self.parent_op,
                self.parent_op.weights,
                self.parent_op.bias,
                self.kernel,
                block_config,
                [0, self.ofm.shape.depth],
            )

        self.parent_ps.block_config = block_config.old_style_representation()
        return scheduler_op_info

    def _get_stripe_input_requirement(self, stripe_shape: Shape4D) -> Tuple[int, int]:
        """Returns the amount of IFM required to produce the stripe with shape:'stripe_shape'"""
        ofm_shape_to_produce = Block.from_shape(stripe_shape.as_list())

        return get_ifm_area_required(ofm_shape_to_produce, self.kernel, self.resampling_mode)

    def _calculate_min_stripe_input(self) -> Tuple[int, int]:
        # Calculate the input volume required height and width for the smallest possible stripe (h,w = 1,1)
        min_stripe = self.ofm.shape.with_hw(1, 1)
        return self._get_stripe_input_requirement(min_stripe)

    def _get_block_config(
        self, ifm_shape: Shape4D, ifm2_shape: Optional[Shape4D], uses_scalar: bool, ofm_shape: Shape4D
    ) -> ArchitectureBlockConfig:
        # Returns a block config and SHRAM layout
        lut_banks = 2 if self.parent_op.activation_lut else 0
        return find_block_config(
            self.arch,
            self.op_type.npu_block_type,
            ofm_shape,
            ifm_shape,
            ifm2_shape,
            uses_scalar,
            self.ifm.dtype.size_in_bits(),
            self.kernel,
            lut_banks,
            self.parent_op.has_scaling(),
            self.resampling_mode,
        )


class Connection:
    """Scheduler internal representation of a Tensor that connects two SchedulerOperations
    This class can be seen as an edge within the Scheduler Graph representation
    """

    def __init__(self, tensor: Tensor):
        self.parent_tens = tensor

        # SchedulerOperation relationships
        self.producers: List[SchedulerOperation] = []
        self.consumers: List[SchedulerOperation] = []

    def __str__(self):
        return f"<Connection {self.parent_tens.name}>"

    __repr__ = __str__


class Schedule:
    """Class that contains a solution of how to schedule an NPU subgraph and its cost"""

    def __init__(self, sg: Subgraph, label: str):
        self.sg = sg
        self.label = label
        self.cost_map: Dict[SchedulerOperation, SchedulerOpInfo] = {}
        self.cascades: Dict[int, CascadeInfo] = {}
        self.fast_storage_peak_usage = 0
        self.memory_snapshot = None

    @property
    def name(self):
        return f"{self.sg.name}_{self.label}"


class Scheduler:
    """Main class of the Vela Scheduler"""

    def __init__(self, nng: Graph, sg: Subgraph, arch: ArchitectureFeatures, options: SchedulerOptions):
        self.nng = nng
        self.sg = sg
        self.arch = arch
        self.sched_ops: List[SchedulerOperation] = []
        self.max_schedule = None
        self.scheduler_options = options

    def avoid_nhcwb16_for_ofm(self, tens, ps, arch):
        # Only run this check for opt strategy Size
        if self.scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
            return False

        op = ps.primary_op
        if not op.type.is_elementwise_op():
            return False

        depth = op.ofm_shapes[0][-1]
        if (depth % 16) == 0:
            return False

        # Check if overwriting the inputs can be allowed
        OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
        outp = OpShapeTens(op.ofm_shapes[0], op.ofm)
        inps = []
        if op.ifm is not None:
            inps.append(OpShapeTens(op.ifm_shapes[0], op.ifm))
        if op.ifm2 is not None:
            inps.append(OpShapeTens(op.ifm_shapes[1], op.ifm2))

        # Find an input tensor that can be overwritten by the output
        for inp in inps:
            if (
                # check op input and output shapes allow overlapping
                inp.op_shape == outp.op_shape
                # check input tensor is valid
                and inp.tens is not None
                and inp.tens.shape != []
                # check input and output tensors are compatible
                and inp.tens.format == outp.tens.format
                and inp.tens.dtype == outp.tens.dtype
            ):
                if inp.tens.format == TensorFormat.NHWC:
                    return True

        return False
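
    # Summary of avoid_nhcwb16_for_ofm(): when optimizing for Size, an elementwise
    # Op whose OFM depth is not a multiple of 16 keeps a linear (NHWC) OFM if one of
    # its inputs has the same op shape, dtype and format and is itself NHWC, so that
    # the output is allowed to overwrite that input instead of requiring a separate
    # NHCWB16 feature map.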

    def create_scheduler_representation(self, arch: ArchitectureFeatures):
        """Creates a Scheduler Graph representation"""
        # Temporary dict for creating connections between the Operations
        connections: Dict[Tensor, Connection] = {}
        # Memory required for the largest FeatureMap that has to be full
        min_memory_req = 0
        for ps in self.sg.passes:
            if ps.primary_op:
                # Set tensor format to NHCWB16 for output FeatureMaps, if possible
                for output in ps.outputs:
                    if output in self.sg.output_tensors or output.purpose != TensorPurpose.FeatureMap:
                        continue

                    if output.needs_linear_format:
                        continue

                    if self.avoid_nhcwb16_for_ofm(output, ps, arch):
                        output.needs_linear_format = True
                        continue

                    output.set_format(TensorFormat.NHCWB16, arch)

                # Create SchedulerOperations
                op = SchedulerOperation(ps, arch, self.nng)
                op.index = len(self.sched_ops)

                # Make connections
                if ps.ifm_tensor not in connections:
                    connections[ps.ifm_tensor] = Connection(ps.ifm_tensor)
                if ps.ifm2_tensor and ps.ifm2_tensor not in connections:
                    connections[ps.ifm2_tensor] = Connection(ps.ifm2_tensor)
                if ps.ofm_tensor not in connections:
                    connections[ps.ofm_tensor] = Connection(ps.ofm_tensor)

                op.add_ifm_connection(connections[ps.ifm_tensor])
                if ps.ifm2_tensor:
                    op.add_ifm2_connection(connections[ps.ifm2_tensor])
                op.add_ofm_connection(connections[ps.ofm_tensor])

                # Set requirements on the ifm/ofm buffers
                self.sched_ops.append(op)
                if ps.ifm_tensor in self.sg.input_tensors:
                    # This Op consumes a subgraph input
                    op.requires_full_ifm = True
                if ps.ifm2_tensor and ps.ifm2_tensor in self.sg.input_tensors:
                    # This Op consumes a subgraph input
                    op.requires_full_ifm2 = True
                if ps.ofm_tensor in self.sg.output_tensors:
                    # This Op produces a subgraph output
                    op.requires_full_ofm = True
                if ps.ifm_tensor.needs_linear_format:
                    op.requires_full_ifm = True
                if ps.ifm2_tensor and ps.ifm2_tensor.needs_linear_format:
                    op.requires_full_ifm2 = True
                if ps.ofm_tensor.needs_linear_format or ps.primary_op.memory_function == Op.ConcatSliceWrite:
                    op.requires_full_ofm = True
                if len(ps.primary_op.outputs) > 1 or len(ps.primary_op.outputs[0].consumer_list) > 1:
                    # Op has multiple outputs or consumers - requires full OFM
                    op.requires_full_ofm = True

                # Check memory requirements if this Op requires any full FeatureMaps
                op_memory_req = 0
                if op.requires_full_ifm:
                    op_memory_req += op.ifm_size_in_bytes()
                if op.requires_full_ifm2:
                    op_memory_req += op.ifm2_size_in_bytes()
                if op.requires_full_ofm:
                    op_memory_req += op.ofm_size_in_bytes()

                min_memory_req = max(op_memory_req, min_memory_req)

        # Theoretical minimum required memory - used to guide the cascade building
        self.min_memory_req = min_memory_req

    def create_initial_schedule(self) -> Schedule:
        """Creates an initial schedule with no cascading or buffering of any kind"""
        schedule = Schedule(self.sg, "MAX")
        for op in self.sched_ops:
            cost = op.create_scheduler_info(self.nng, op.ofm.shape)
            cost.cycles = self.estimate_op_performance(op, cost.block_config, op.ofm.shape.depth)
            schedule.cost_map[op] = cost

        return schedule

    def update_op_memory_snapshot(self, schedule: Schedule):
        memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]

        # Collect live ranges from tensors
        lr_graph = live_range.LiveRangeGraph()
        for mem_area, mem_type_set in memories_list:
            live_range.extract_live_ranges_from_cascaded_passes(
                self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum,
            )

        # Populate time-array with memory used by live ranges
        temporal_usage = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area)
        schedule.memory_snapshot = temporal_usage

        # Set the peak memory usage
        schedule.fast_storage_peak_usage = max(temporal_usage, default=0)

    def estimate_op_performance(self, op: SchedulerOperation, block_config, ofm_depth):
        query = npu_performance.PerformanceQuery(op.op_type.npu_block_type)
        query.ifm_shape = op.ifm.shape
        query.ifm_memory_area = op.ifm.mem_area
        query.ifm_bits = op.ifm.dtype.size_in_bits()
        query.ifm_format = op.ifm.format
        query.ifm2_shape = op.ifm2 and op.ifm2.shape
        query.ifm2_memory_area = op.ifm2 and op.ifm2.mem_area
        query.ifm2_bits = op.ifm2 and op.ifm2.dtype.size_in_bits()
        query.ifm2_format = op.ifm2 and op.ifm2.format
        query.ofm_shape = op.ofm.shape.with_depth(ofm_depth)
        query.ofm_memory_area = op.ofm.mem_area
        query.ofm_bits = op.ofm.dtype.size_in_bits()
        query.ofm_format = op.ofm.format
        if op.parent_op.bias:
            query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth)
            query.const_memory_area = self.arch.fast_storage_mem_area

        query.kernel = op.kernel
        query.config = block_config

        return npu_performance.measure_cycle_cost(self.arch, op.op_type, op.activation and op.activation.op_type, query)

    def propose_schedule_buffering(self, ref_schedule: Schedule, staging_limit_bytes):
        """Create a buffered schedule"""
        buffered_schedule = Schedule(self.sg, f"{ref_schedule.label}_BUFFERED")

        prev_op = None
        for sched_op in self.sched_ops:
            if sched_op not in ref_schedule.cost_map:
                # sched_op is not part of this sub-schedule - skip
                continue

            self.propose_operator_buffering(sched_op, prev_op, buffered_schedule, ref_schedule, staging_limit_bytes)
            prev_op = sched_op

        return buffered_schedule

    def propose_operator_buffering(
        self,
        sched_op: SchedulerOperation,
        prev_op: SchedulerOperation,
        buffered_schedule: Schedule,
        ref_schedule: Schedule,
        staging_limit_bytes,
    ):
        # Mild recursion might mean this Op has already been seen
        if sched_op in buffered_schedule.cost_map:
            return

        # Take the reference schedule as default costings for this schedule
        ref_cost = ref_schedule.cost_map[sched_op]
        cost = copy.copy(ref_cost)
        cost.slack_buffering_cycles = ref_cost.cycles.op_cycles
        memory_snapshot = ref_schedule.memory_snapshot
        ref_memory_usage = memory_snapshot[ref_cost.time_index] if ref_cost.time_index < len(memory_snapshot) else 0
        cost.slack_buffering_memory = staging_limit_bytes - ref_memory_usage
        buffered_schedule.cost_map[sched_op] = cost

        # Attempt weight buffering on anything with a weights tensor
        if sched_op.parent_op.weights:
            self.propose_weight_buffering(
                sched_op.parent_op.weights,
                sched_op.parent_op.bias,
                sched_op,
                prev_op,
                buffered_schedule,
                ref_schedule,
                cost.slack_buffering_memory,
            )

        return cost

    def weights_needs_dma(self, weight_tensor):
        if weight_tensor and weight_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
            # Weights are in permanent storage
            # Moving the data is only worthwhile if the permanent storage differs from the feature map storage
            if (
                weight_tensor.mem_area in (MemArea.Dram, MemArea.OffChipFlash)
                and self.arch.permanent_storage_mem_area != self.arch.fast_storage_mem_area
            ):
                return True
        return False
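
    # Example: on a system where the weights live in Dram or OffChipFlash and the
    # fast storage is a different (on-chip) memory area, weights_needs_dma() returns
    # True and the encoded weight stream becomes a candidate for DMA buffering; when
    # permanent storage and fast storage are the same area there is nothing to gain
    # from a copy and it returns False.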

    def propose_weight_buffering(
        self,
        weight_tensor,
        scale_tensor,
        sched_op: SchedulerOperation,
        prev_op: SchedulerOperation,
        buffered_schedule: Schedule,
        ref_schedule: Schedule,
        buffer_limit_bytes,
    ):
        cost = buffered_schedule.cost_map[sched_op]
        prev_cost = buffered_schedule.cost_map.get(prev_op)
        ref_cost = ref_schedule.cost_map[sched_op]
        assert cost and ref_cost

        needs_dma = self.weights_needs_dma(weight_tensor)

        ofm_full_depth_slices = [0, ref_cost.stripe.depth]

        # Encode weights for the full depth
        full_weights, full_scales = weight_compressor.encode_weight_and_scale_tensor(
            self.arch,
            sched_op.parent_op,
            weight_tensor,
            scale_tensor,
            sched_op.kernel,
            cost.block_config,
            ofm_full_depth_slices,
        )
        full_weights_bytes = len(full_weights.buffer)
        cost.ofm_depth_slices = ofm_full_depth_slices

        # No buffering required - take all the weights from permanent storage
        if sched_op.op_type == Op.FullyConnected or not needs_dma:
            cost.npu_weights_tensor = full_weights
            cost.npu_scales_tensor = full_scales
            return

        encoded_weights = full_weights
        encoded_scales = full_scales

        # How many NPU cycles and how much SRAM are available under the previously
        # executing operator for performing buffered DMA transfers
        slack_cycles = prev_cost.slack_buffering_cycles if prev_cost else 0
        slack_memory = prev_cost.slack_buffering_memory if prev_cost else 0

        # Force full depth for cascaded Ops
        if ref_cost.cascade != 0:
            weight_tensor_purpose = TensorSubPurpose.Standard
            weight_buffer_size = full_weights_bytes
            # Update the memory snapshot to reflect the added size of the weights
            ref_schedule.memory_snapshot[ref_cost.time_index] += weight_buffer_size
        else:
            # Estimate the buffering cycle time for the full set of weights
            full_transfer_cycles = npu_performance.measure_mem2mem_cycles(
                self.arch, weight_tensor.mem_area, self.arch.fast_storage_mem_area, full_weights_bytes
            )
            cost.full_weight_transfer_cycles = full_transfer_cycles

            # Calculate the amount of prebuffering necessary (or what is possible with a limited
            # double buffer size)
            half_buffer_limit = buffer_limit_bytes // 2
            if full_transfer_cycles > slack_cycles:
                prebuffer_ratio = slack_cycles / full_transfer_cycles
                prebuffer_bytes = min(prebuffer_ratio * full_weights_bytes, half_buffer_limit)
            else:
                prebuffer_bytes = min(full_weights_bytes, half_buffer_limit)

            prebuffer_ratio = prebuffer_bytes / full_weights_bytes

            # Have to split the weights if the initial buffering can't store
            # all of the compressed weights
            if prebuffer_bytes < full_weights_bytes:
                block_depth = cost.block_config.ofm_block.depth

                # Choose initial prebuffering depth (already buffer clamped)
                prebuffer_depth = ref_cost.stripe.depth * prebuffer_ratio
                prebuffer_depth = int(max(16, round_down(prebuffer_depth, ArchitectureFeatures.OFMSplitDepth)))

                # Calculate cycles executed during the prebuffer
                pre_op_cycles = self.estimate_op_performance(sched_op, cost.block_config, prebuffer_depth)
                buffering_depth = ref_cost.stripe.depth * (pre_op_cycles.op_cycles / full_transfer_cycles)

                # Choose initial buffering depth and clamp to the double buffering limit
                buffering_depth = round_up(buffering_depth, block_depth)
                buffering_bytes = (buffering_depth / ref_cost.stripe.depth) * full_weights_bytes
                if buffering_bytes > half_buffer_limit:
                    buffering_depth = (half_buffer_limit / full_weights_bytes) * ref_cost.stripe.depth

                while True:
                    # Attempt to buffer whole blocks
                    if buffering_bytes > block_depth:
                        buffering_depth = round_down(buffering_depth, block_depth)
                    else:
                        buffering_depth = round_down(buffering_depth, ArchitectureFeatures.OFMSplitDepth)
                    buffering_depth = int(max(buffering_depth, ArchitectureFeatures.OFMSplitDepth))

                    # Create list of depth slices
                    depth_slices = [0]
                    if prebuffer_depth < ref_cost.stripe.depth:
                        depth_slices += list(range(prebuffer_depth, ref_cost.stripe.depth, buffering_depth))
                    depth_slices.append(ref_cost.stripe.depth)

                    # Encode weights based on the depth slices
                    cost.ofm_depth_slices = depth_slices
                    encoded_weights, encoded_scales = weight_compressor.encode_weight_and_scale_tensor(
                        self.arch,
                        sched_op.parent_op,
                        weight_tensor,
                        scale_tensor,
                        sched_op.kernel,
                        cost.block_config,
                        cost.ofm_depth_slices,
                    )

                    # Chosen buffering might not fit at all, iterate until it does
                    # or until the minimum usable slice size is reached
                    if (
                        encoded_weights.max_range_bytes <= half_buffer_limit
                        or prebuffer_depth == ArchitectureFeatures.OFMSplitDepth
                    ):
                        break

                    if buffering_depth > prebuffer_depth:
                        buffering_depth = round_up(buffering_depth // 2, ArchitectureFeatures.OFMSplitDepth)
                    else:
                        prebuffer_depth = round_up(prebuffer_depth // 2, ArchitectureFeatures.OFMSplitDepth)

                # Calculate cycles required to run the last op for use as future slack
                tail_cycles = self.estimate_op_performance(
                    sched_op, cost.block_config, depth_slices[-1] - depth_slices[-2]
                )
                cost.slack_buffering_cycles = tail_cycles.op_cycles

            # Determine whether the weights need to be double buffered
            weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes)

        # Only buffer weights if there's still space left for the buffer
        if weight_buffer_size <= buffer_limit_bytes:
            assert weight_buffer_size % 16 == 0
            # Determine whether to double buffer or single buffer
            if (weight_buffer_size * 2 <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)):
                weight_buffer_size = weight_buffer_size * 2
                weight_tensor_purpose = TensorSubPurpose.DoubleBuffer
            else:
                weight_tensor_purpose = TensorSubPurpose.Standard

            cost.buffered_weight_tensor = self.buffer_tensor(
                encoded_weights, weight_tensor_purpose, weight_buffer_size, weight_tensor.name
            )
            if ref_cost.cascade == 0:
                # Determine if the lifetime can be extended and pre-buffer weights under the previous operation
                cost.buffered_weight_tensor.pre_buffer = weight_buffer_size < slack_memory

            cost.slack_buffering_memory -= weight_buffer_size
        else:
            # Don't slice or buffer - use the whole depth from persistent storage
            cost.ofm_depth_slices = ofm_full_depth_slices
            encoded_weights = full_weights
            encoded_scales = full_scales

        cost.npu_weights_tensor = encoded_weights
        cost.npu_scales_tensor = encoded_scales

    def buffer_tensor(self, src_tensor: Tensor, sub_purpose: TensorSubPurpose, buffer_size: int, name: str) -> Tensor:
        buffered_weight_tensor = Tensor([1, 1, 1, buffer_size], DataType.uint8, name + "_buffer")
        buffered_weight_tensor.src_tensor = src_tensor
        buffered_weight_tensor.mem_area = self.arch.fast_storage_mem_area
        buffered_weight_tensor.mem_type = MemType.Scratch_fast
        buffered_weight_tensor.purpose = TensorPurpose.Weights
        buffered_weight_tensor.sub_purpose = sub_purpose
        return buffered_weight_tensor
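
    # Note: the tensor returned by buffer_tensor() is just a uint8 placeholder of
    # 'buffer_size' bytes in fast storage (Scratch_fast); the encoded weights it
    # refers to via src_tensor are what actually get copied into it, and a
    # DoubleBuffer sub-purpose simply reserves twice the single-slice size so one
    # half can be filled while the other is being read.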

    def propose_minimal_schedule(self) -> Schedule:
        """Proposes scheduling parameters where every operator is subdivided into the smallest stripe that satisfies
        the next operator's stride"""
        min_schedule = Schedule(self.sg, "MIN")
        cost_map = min_schedule.cost_map

        # Keep track of the previous Op - which consumes the current Op's OFM
        prev_op = None
        for sched_op in reversed(self.sched_ops):
            min_stripe_height = prev_op.kernel.stride.y if prev_op else 1
            min_stripe = sched_op.ofm.shape.with_height(min_stripe_height)

            cost = sched_op.create_scheduler_info(self.nng, min_stripe)
            cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
            cost_map[sched_op] = cost

            prev_op = sched_op

        return min_schedule
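
    # Example: if an Op's OFM is consumed by a convolution with vertical stride 2,
    # the producer's minimal stripe is two OFM rows high (the consumer advances its
    # input window two rows at a time); the final Op in the schedule, which has no
    # consumer here, falls back to a one-row stripe.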

    def propose_schedule_striping(self, final_stripe: Shape4D, label: str, ref_schedule: Schedule) -> Schedule:
        """Proposes new striping for a schedule. The stripe is derived from the ifm requirements of the next Op down"""
        ref_cost = ref_schedule.cost_map

        striped_schedule = Schedule(self.sg, label)
        stripe = final_stripe
        for sched_op in reversed(self.sched_ops):
            if sched_op not in ref_cost:
                # sched_op is not part of the sub-schedule - skip
                continue

            # Create a cost entry with the new stripe
            cost = sched_op.create_scheduler_info(self.nng, stripe)

            if ref_cost[sched_op].buffered_weight_tensor:
                # If the weights are buffered in the reference schedule they should be in the new proposal
                weight_tensor = cost.npu_weights_tensor
                cost.buffered_weight_tensor = self.buffer_tensor(
                    weight_tensor, TensorSubPurpose.Standard, len(weight_tensor.buffer), weight_tensor.name
                )

            # Estimate performance
            cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
            striped_schedule.cost_map[sched_op] = cost

            # Calculate the preceding Op's stripe
            stripe = sched_op.ifm.shape.with_height(stripe.height * sched_op.kernel.stride.y)

        return striped_schedule

    def estimate_schedule_memory_usage(self, schedule: Schedule, non_local_mem_usage: dict):
        """Estimates the memory usage of a schedule"""
        cost = schedule.cost_map
        cascades = schedule.cascades
        peak_mem_usage = 0
        for sched_op in self.sched_ops:
            if sched_op not in cost:
                # sched_op is not part of the sub-schedule - skip
                continue

            if cost[sched_op].cascade:
                # This Op is part of a cascade - use the cascade's memory usage
                cascade_info = cascades[cost[sched_op].cascade]
                # Non-local memory usage is already included in the cascade_info
                peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage)
            else:
                # This Op is not part of a cascade - calculate the memory usage
                op_weight_buffer = 0
                if cost[sched_op].buffered_weight_tensor:
                    op_weight_buffer = cost[sched_op].buffered_weight_tensor.storage_size()

                op_mem_usage = (
                    sched_op.ifm_size_in_bytes()
                    + sched_op.ofm_size_in_bytes()
                    + op_weight_buffer
                    + non_local_mem_usage.get(sched_op, 0)
                )
                peak_mem_usage = max(op_mem_usage, peak_mem_usage)

        return peak_mem_usage

    def optimize_sub_schedule(
        self, cascade_info: CascadeInfo, ref_schedule: Schedule, max_template: Schedule, memory_limit: int
    ) -> Schedule:
        """Extracts the Ops covered by the given cascade and creates a sub-schedule. The sub-schedule is optimized by
        proposing weight buffering and then continuously proposing new stripe sizes"""
        ref_cost = ref_schedule.cost_map
        # Extract the ops that are part of this sub-schedule
        start = cascade_info.start
        end = cascade_info.end
        sub_schedule_ops = self.sched_ops[start : end + 1]
        # Create a sub-schedule that contains only the costs for the Ops that are part of the sub-schedule
        sub_schedule = Schedule(self.sg, f"SUB_{start}_{end}")
        for sched_op in sub_schedule_ops:
            sub_schedule.cost_map[sched_op] = ref_cost[sched_op]

        sub_schedule.cascades[end] = cascade_info
        # Use the memory snapshot from the reference schedule
        sub_schedule.memory_snapshot = ref_schedule.memory_snapshot

        # Calculate memory usage that is live during the sub-schedule but not part of it
        time_for_cascade = ref_cost[sub_schedule_ops[0]].time_index
        mem_usage_parallel_to_sub_schedule = ref_schedule.memory_snapshot[time_for_cascade] - cascade_info.mem_usage
        # If the first Op's IFM has other consumers it has to live throughout the whole sub-schedule whether it's
        # included in a cascade or not
        persistent_initial_ifm = (
            sub_schedule_ops[0].ifm_size_in_bytes() if len(sub_schedule_ops[0].ifm.connection.consumers) > 1 else 0
        )
        # Calculate non-local-mem-usage per Operator
        non_local_mem_usage = {}
        for idx, sched_op in enumerate(sub_schedule_ops):
            non_local_mem_usage[sched_op] = mem_usage_parallel_to_sub_schedule
            if idx != 0:
                non_local_mem_usage[sched_op] += persistent_initial_ifm

        cascade_builder = CascadeBuilder(sub_schedule_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)

        # Start by adding buffering
        buffered_sub_schedule = self.propose_schedule_buffering(
            sub_schedule, self.scheduler_options.optimization_sram_limit
        )
        # Copy the cascades over from the unbuffered-schedule
        buffered_sub_schedule.cascades = sub_schedule.cascades

        # Generate the possible stripings for the final Op in the sub-schedule
        final_ofm_shape = sub_schedule_ops[-1].ofm.shape
        possible_stripes = [
            final_ofm_shape.with_height(stripe_h) for stripe_h in range(1, final_ofm_shape.height // 2 + 1)
        ]

        # Propose different striping - the possible stripes are proposed similarly to a binary search
        best_schedule = None
        iteration = 0
        while len(possible_stripes) > 1:
            proposed_stripe = possible_stripes[len(possible_stripes) // 2]
            proposed_schedule = self.propose_schedule_striping(
                proposed_stripe, f"OPTIMIZED_{iteration}", buffered_sub_schedule
            )

            cascade_builder.build_cascades(proposed_schedule, max_template, memory_limit)

            # Check if proposal fits
            proposed_schedule_mem_usage = self.estimate_schedule_memory_usage(proposed_schedule, non_local_mem_usage)
            if (proposed_schedule_mem_usage) <= memory_limit:
                # Remove all possible stripes smaller than this
                possible_stripes = possible_stripes[len(possible_stripes) // 2 :]
                best_schedule = proposed_schedule
                if not proposed_schedule.cascades:
                    # No cascading required - early exit
                    break
            else:
                # Proposal doesn't fit within the limit - remove all possible stripes larger than this
                possible_stripes = possible_stripes[: len(possible_stripes) // 2]

            iteration += 1

        return best_schedule
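
    # The loop above is effectively a binary search over the final Op's stripe
    # height: the candidates run from one row up to half the OFM height, and each
    # probe keeps the taller half of the candidates if the proposal fits within the
    # memory limit or the shorter half if it does not, so only about
    # log2(ofm_height) proposals need to be costed.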

    def optimize_schedule(
        self, schedule: Schedule, max_sched: Schedule, max_template: Schedule, options: SchedulerOptions,
    ) -> Schedule:
        """Extracts sub-schedules based on the cascades and optimizes them and applies them to the final schedule"""
        sram_limit = options.optimization_sram_limit
        if max_sched.fast_storage_peak_usage < sram_limit and not self.arch.is_spilling_enabled():
            # Maximum performance schedule fits within the SRAM target
            return max_sched

        # Iterate over a copy of the cascades since they may change during the loop
        for cascade_info in list(schedule.cascades.values()):
            # Optimize the sub-schedule in this cascade
            opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, sram_limit)
            if opt_sub_schedule:
                # Remove the existing cascade
                del schedule.cascades[cascade_info.end]
                # Update the sub-schedule Op and cascade costs to the full schedule
                schedule.cost_map.update(opt_sub_schedule.cost_map)
                schedule.cascades.update(opt_sub_schedule.cascades)

        # Update memory snapshot
        self.sg.schedule = schedule
        self.update_op_memory_snapshot(schedule)
        # Propose schedule buffering to the optimized schedule
        optimized_sched = self.propose_schedule_buffering(schedule, self.scheduler_options.optimization_sram_limit)
        # Copy the cascade's metadata from the unbuffered schedule
        optimized_sched.cascades = schedule.cascades
        return optimized_sched

    def apply_schedule(self, sched: Schedule):
        """Applies the given schedule as a final solution"""
        for sched_op in self.sched_ops:
            op_info = sched.cost_map[sched_op]
            cascade_info = sched.cascades.get(op_info.cascade, None)
            if cascade_info and sched_op in cascade_info.buffers:
                buffer_tens = sched_op.ifm.connection.parent_tens
                # Apply memory area and type
                buffer_tens.mem_area = self.arch.fast_storage_mem_area
                buffer_tens.mem_type = MemType.Scratch_fast
                # Apply Rolling buffer
                buffer_tens.set_format(TensorFormat.NHCWB16, self.arch)
                buffer_tens.set_new_sub_purpose(TensorSubPurpose.RollingBufferY, cascade_info.buffers[sched_op].height)

            sched_op.parent_ps.block_config = op_info.block_config.old_style_representation()

            # Ensure that the src_tensor reference is set correctly
            if op_info.buffered_weight_tensor:
                op_info.buffered_weight_tensor.src_tensor = op_info.npu_weights_tensor

    def use_fast_storage_for_feature_maps(self, schedule, staging_limit):
        scratched_fms = {}
        max_mem_usage = []
        base_mem_usage = []
        fast_storage_type = MemType.Scratch_fast
        fast_storage_mem_area = self.arch.fast_storage_mem_area

        # Force all OFMs to fast-storage
        for sched_op in self.sched_ops:
            cost = schedule.cost_map[sched_op]
            if cost.cascade == 0 and sched_op.get_dependants():
                ofm_tens = sched_op.ofm.connection.parent_tens
                if not any(cons is None for cons in ofm_tens.consumer_list):
                    if ofm_tens not in scratched_fms:
                        scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
                    ofm_tens.mem_area = fast_storage_mem_area
                    ofm_tens.mem_type = fast_storage_type

        # Collect live ranges from tensors
        memories_list = [(fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
        lr_graph = live_range.LiveRangeGraph()
        for mem_area, mem_type_set in memories_list:
            live_range.extract_live_ranges_from_cascaded_passes(
                self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum,
            )
        max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area)

        # If true, everything fits and we can proceed
        if max(max_mem_usage) <= staging_limit:
            return

        # Build up the base memory usage by removing the
        # mem_usage of the lrs we previously moved to fast-storage
        base_mem_usage = np.array(max_mem_usage)
        curr_lrs = []
        for lr in lr_graph.lrs:
            for tens in lr.tensors:
                if scratched_fms.get(tens):
                    curr_lrs.append(lr)
                    base_mem_usage[lr.start_time : lr.end_time + 1] -= lr.size
                    break

        competing_lrs = []
        for lr in curr_lrs:
            base_usage = max(base_mem_usage[lr.start_time : lr.end_time + 1])
            # If true, the lr will never fit and may thus be evicted
            if base_usage + lr.size > staging_limit:
                FastStorageComponentAllocator.evict(lr, max_mem_usage, scratched_fms)
                continue
            # Since max_mem_usage is the memory usage with all FMs still in fast-storage,
            # the memory limit cannot be exceeded if max_mem_usage does not.
            # Thus, the affected lrs can remain in fast-storage if the following is true
            if max(max_mem_usage[lr.start_time : lr.end_time + 1]) <= staging_limit:
                FastStorageComponentAllocator.keep(lr, base_mem_usage, staging_limit)
            else:
                competing_lrs.append(lr)
        sz = len(competing_lrs)
        # All lrs and their tensors have been handled if sz is zero, we may thus return
        if sz == 0:
            return

        competing_lrs = sorted(competing_lrs, key=lambda lr: (lr.start_time, lr.end_time + 1, lr.size))
        start = 0
        start_time = competing_lrs[0].start_time
        end_time = competing_lrs[0].end_time
        component_allocator = FastStorageComponentAllocator(base_mem_usage, max_mem_usage, staging_limit)
        # Build up components and then allocate each separately
        for i, lr in enumerate(competing_lrs):
            if lr.start_time <= end_time and i - start < component_allocator.max_exhaustive_size:
                start_time = min(start_time, lr.start_time)
                end_time = max(end_time, lr.end_time)
            else:
                component_allocator.allocate_component(
                    component_allocator,
                    competing_lrs[start:i],
                    max_mem_usage,
                    base_mem_usage,
                    staging_limit,
                    scratched_fms,
                )
                start = i
                start_time = lr.start_time
                end_time = lr.end_time
        component_allocator.allocate_component(
            component_allocator, competing_lrs[start:sz], max_mem_usage, base_mem_usage, staging_limit, scratched_fms
        )

    def move_constant_data(self):
        """Determine if data can be moved from permanent storage to another memory area. A move
        will generate a DMA command in the high-level command stream"""
        for sched_op in self.sched_ops:
            parent_op = sched_op.parent_op
            is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in parent_op.inputs)
            max_ifm_shram_avail = (
                (self.arch.available_shram_banks(is_lut_used) - self.arch.shram_reserved_output_banks)
                * self.arch.shram_bank_size
                // 2
            )

            for idx, tens in enumerate(parent_op.inputs):
                if tens.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
                    # Tensor is in permanent storage
                    # Moving the data is only worthwhile if the permanent storage differs from the feature map storage
                    if (
                        tens.mem_area in self.arch.permanent_storage_mem_area
                        and self.arch.permanent_storage_mem_area != self.arch.feature_map_storage_mem_area
                    ) or tens.purpose == TensorPurpose.LUT:
                        if tens.purpose == TensorPurpose.LUT or (
                            # For elementwise broadcast
                            tens.purpose == TensorPurpose.FeatureMap
                            and sched_op.op_type.is_binary_elementwise_op()
                            and tens.shape != []
                            and sched_op.ifm.shape != sched_op.ofm.shape
                            and parent_op.write_shape is None
                            and tens.storage_size() > max_ifm_shram_avail
                        ):
                            only_vector_product_consumers = all(
                                oper and oper.type.npu_block_type == NpuBlockType.VectorProduct
                                for oper in tens.consumers()
                            )

                            if (not only_vector_product_consumers) or tens.purpose == TensorPurpose.LUT:
                                new_tens = tens.clone_into_fast_storage(self.arch)
                                if tens.purpose == TensorPurpose.LUT:
                                    new_tens.mem_area = MemArea.Shram

                                new_tens.consumer_list.append(parent_op)
                                parent_op.inputs[idx] = new_tens
                                # If the index is out of range, IFM and IFM2 are the same tensor
                                # and pass inputs don't have duplicates
                                if idx < len(sched_op.parent_ps.inputs):
                                    sched_op.parent_ps.inputs[idx] = new_tens

    def print_schedule(self, schedule: Schedule):
        print(f"Schedule: '{schedule.name}'")
        for sched_op in self.sched_ops:
            if sched_op not in schedule.cost_map:
                # Sub-schedule printing
                continue

            op_info = schedule.cost_map[sched_op]
            print(f"\t{sched_op.index}: Operation {sched_op.name} - OFM {sched_op.ofm.shape}")
            print(f"\t\tType: {sched_op.op_type}")
            print(f"\t\tKernel: {sched_op.kernel}")
            print(f"{op_info}")
            mem_usage = (
                schedule.memory_snapshot[op_info.time_index]
                if op_info.time_index < len(schedule.memory_snapshot)
                else 0
            )
            print(f"\t\tSRAM Used: {mem_usage} bytes")

        print("\tCascades:")
        for i, cascade in enumerate(schedule.cascades.values()):
            print(f"\t\t{i}: {cascade.start} -> {cascade.end}, size: {cascade.mem_usage}")


def _update_tensor_allocation(nng: Graph, arch: ArchitectureFeatures, options):
    """
    Creates live ranges and runs tensor allocator for the current schedule
    (i.e. sg.schedule for all subgraphs), returns the maximum memory usage
    and updates SchedulerOpInfo.mem_usage for all operations in the schedule.
    """
    root_sg = nng.get_root_subgraph()

    alloc_list = []
    if arch.is_spilling_enabled():
        mem_alloc_scratch_fast = (arch.fast_storage_mem_area, set((MemType.Scratch_fast,)))
        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,)))
        # Order is important
        alloc_list.append(mem_alloc_scratch_fast)
        alloc_list.append(mem_alloc_scratch)
    else:
        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))
        alloc_list.append(mem_alloc_scratch)

    for mem_area, mem_type_set in alloc_list:
        tensor_allocation.allocate_tensors(
            nng,
            root_sg,
            arch,
            mem_area,
            mem_type_set,
            tensor_allocator=options.tensor_allocator,
            verbose_allocation=options.verbose_allocation,
            cpu_tensor_alignment=options.cpu_tensor_alignment,
        )


class FastStorageComponentAllocator:
    def __init__(self, base_mem_usage, max_mem_usage, staging_limit):
        self.base_mem_usage = base_mem_usage
        self.max_mem_usage = list(max_mem_usage)
        self.staging_limit = staging_limit
        self.lrs = []
        self.evicted = []
        self.curr_evicted = []
        self.remaining_total_size = []
        self.best_allocated_size = 0
        self.max_exhaustive_size = 20

    def allocate_exhaustive(self, ix, alloc_size):
        if ix >= len(self.lrs):
            if alloc_size > self.best_allocated_size:
                self.best_allocated_size = alloc_size
                self.evicted = self.curr_evicted
            return

        lr = self.lrs[ix]
        for t in range(lr.start_time, lr.end_time):
            assert self.base_mem_usage[t] <= self.max_mem_usage[t]
        base_usage = max(self.base_mem_usage[lr.start_time : lr.end_time + 1])
        can_fit = base_usage + lr.size <= self.staging_limit
        always_fits = can_fit

        if can_fit:
            max_usage = max(self.max_mem_usage[lr.start_time : lr.end_time + 1])
            always_fits = max_usage <= self.staging_limit

        if can_fit or always_fits:
            self.curr_evicted[ix] = False
            self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, True)
            self.allocate_exhaustive(ix + 1, alloc_size + lr.size)
            self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, False)

        if not always_fits:
            self.curr_evicted[ix] = True
            self.max_mem_usage = self.update_mem_usage(self.max_mem_usage, lr, False)
            self.allocate_exhaustive(ix + 1, alloc_size)
            self.max_mem_usage = self.update_mem_usage(self.max_mem_usage, lr, True)

    @staticmethod
    def update_mem_usage(mem_usage, lr, increase):
        for t in range(lr.start_time, lr.end_time + 1):
            mem_usage[t] += lr.size if increase else -lr.size
            assert mem_usage[t] >= 0
        return mem_usage

    @staticmethod
    def evict(lr, max_mem_usage, scratched_fms):
        for t in range(lr.start_time, lr.end_time + 1):
            max_mem_usage[t] -= lr.size
        for tens in lr.tensors:
            if tens in scratched_fms:
                tens.mem_area = scratched_fms[tens][0]
                tens.mem_type = scratched_fms[tens][1]

    @staticmethod
    def keep(lr, base_mem_usage, staging_limit):
        for t in range(lr.start_time, lr.end_time + 1):
            base_mem_usage[t] += lr.size
            assert base_mem_usage[t] <= staging_limit
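
    # allocate_exhaustive() recursively tries keeping or evicting each live range,
    # i.e. up to 2^n combinations for a component of n live ranges, and records the
    # combination that keeps the most bytes in fast storage; this is why components
    # are capped at max_exhaustive_size (20) live ranges when they are built up in
    # use_fast_storage_for_feature_maps().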

    def allocate_component(self, allocator, lrs, max_mem, min_mem, staging_limit, scratched_fms):
        sz = len(lrs)
        allocator.lrs = lrs
        allocator.evicted = [0] * len(lrs)
        allocator.curr_evicted = [0] * sz
        allocator.best_allocated_size = -1
        # Recursively evaluate all permutations of allocations of the lrs found in the component
        allocator.allocate_exhaustive(0, 0)

        # Optimal allocation has been found, move lrs accordingly
        for i, e in enumerate(allocator.evicted):
            if e:
                self.evict(lrs[i], max_mem, scratched_fms)
            else:
                self.keep(lrs[i], min_mem, staging_limit)


def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_options: SchedulerOptions):
    """Entry point for the Scheduler"""
    # Initialize CPU subgraphs
    schedulers = dict()
    # Initialize schedulers with max schedule. Only schedule NPU subgraphs
    for sg in nng.subgraphs:
        if sg.placement != PassPlacement.Npu:
            # Create cascaded passes for CPU Ops
            cascaded_passes = []
            for idx, ps in enumerate(sg.passes):
                cps = CascadedPass(
                    ps.name, SchedulingStrategy.WeightStream, ps.inputs, [], ps.outputs, [ps], ps.placement, False,
                )

                cps.time = idx
                ps.cascade = cps
                cascaded_passes.append(cps)

            sg.cascaded_passes = cascaded_passes
        else:
            # Npu subgraph - create schedule
            scheduler = Scheduler(nng, sg, arch, scheduler_options)
            schedulers[sg] = scheduler

            scheduler.create_scheduler_representation(arch)
            sg.sched_ops = scheduler.sched_ops
            scheduler.move_constant_data()

            # Create the Max schedule template
            max_schedule_template = scheduler.create_initial_schedule()
            scheduler.max_schedule = max_schedule_template

            # Create the optimised Max schedule
            sg.schedule = max_schedule_template
            scheduler.update_op_memory_snapshot(max_schedule_template)
            opt_max_schedule = scheduler.propose_schedule_buffering(max_schedule_template, 1 << 32)
            sg.schedule = opt_max_schedule
            scheduler.update_op_memory_snapshot(opt_max_schedule)

            # Create Min schedule
            min_schedule = scheduler.propose_minimal_schedule()
            initial_sram_limit = scheduler_options.optimization_sram_limit
            if scheduler_options.optimization_strategy == OptimizationStrategy.Size:
                initial_sram_limit = scheduler.min_memory_req

            cascade_builder = CascadeBuilder(scheduler.sched_ops, arch.is_spilling_enabled())
            cascade_builder.build_cascades(min_schedule, max_schedule_template, initial_sram_limit)
            sg.schedule = min_schedule
            scheduler.update_op_memory_snapshot(min_schedule)

            if scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
                # Create an optimized schedule
                sg.schedule = scheduler.optimize_schedule(
                    min_schedule, opt_max_schedule, max_schedule_template, scheduler_options
                )
                scheduler.update_op_memory_snapshot(sg.schedule)

            scheduler.apply_schedule(sg.schedule)
            scheduler.use_fast_storage_for_feature_maps(sg.schedule, scheduler_options.optimization_sram_limit)

            if scheduler_options.verbose_schedule:
                scheduler.print_schedule(sg.schedule)

    # Evaluate schedule
    _update_tensor_allocation(nng, arch, options)