blob: 95872cfef40648fe7fca5efbdecefe9255996c86 [file] [log] [blame]
Johan Alfvén16573072022-12-22 10:12:54 +01001# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Tim Halld8339a72021-05-27 18:49:40 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Groups Operators in a schedule together to form Cascades.
Johan Alfvénfba0a7d2022-10-11 20:41:41 +020019from .live_range import ofm_can_reuse_ifm
Tim Halld8339a72021-05-27 18:49:40 +010020from .numeric_util import round_up
21from .operation import NpuBlockType
erik.andersson@arm.com6b2a0b42022-03-22 15:35:30 +010022from .operation import Op
Rickard Bolin9ae34552022-06-09 13:07:17 +000023from .operation import Padding
Tim Halld8339a72021-05-27 18:49:40 +010024from .shape4d import Shape4D
25
# NPU block types whose Ops are never considered for cascading (checked by CascadeBuilder._is_cascadable)
non_cascadable_blocks = (NpuBlockType.Default, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
31
32
class CascadeInfo:
    """Plain record describing a single cascade.

    As built by CascadeBuilder.build_cascades:
        start: index of the first Op in the cascade.
        end: index of the last Op in the cascade.
        buffers: mapping from each Op after the first to the shape of the buffer feeding it.
        mem_usage: memory usage recorded for the cascade.
    """

    def __init__(self, start, end, buffers, mem_usage: int):
        self.start, self.end = start, end
        self.buffers, self.mem_usage = buffers, mem_usage
41
42
class BufferMap:
    """Memoizes the (shape, size) of the FeatureMap buffer between two SchedulerOperations"""

    def __init__(self):
        # (producer, consumer) -> (buffer_shape, buffer_size)
        self.buffer_map = {}

    def get_buffer(self, producer, consumer, cost):
        """Return the (shape, size in bytes) of the buffer between producer and consumer.

        Either side may be None (but not both). cost is only consulted when a rolling buffer is
        needed. The first result for a (producer, consumer) pair is cached and returned on every
        later call for that pair, whatever cost is passed.
        """
        assert producer or consumer
        key = (producer, consumer)
        if key in self.buffer_map:
            return self.buffer_map[key]
        buffer = self._make_buffer(producer, consumer, cost)
        self.buffer_map[key] = buffer
        return buffer

    @staticmethod
    def _make_buffer(producer, consumer, cost):
        """Compute the (shape, size in bytes) of a buffer that is not yet cached."""
        if consumer is None:
            # No consumer, or several consumers - the producer's FeatureMap is stored in full
            return producer.ofm.shape, producer.ofm_size_in_bytes()
        if producer is None:
            # First Op in the subgraph or cascade - the consumer's FeatureMap is stored in full
            return consumer.ifm.shape, consumer.ifm_size_in_bytes()
        if producer.requires_full_ofm or consumer.requires_full_ifm:
            # One side needs the whole FeatureMap - size for the larger of the two
            return (
                max(producer.ofm.shape, consumer.ifm.shape),
                max(producer.ofm_size_in_bytes(), consumer.ifm_size_in_bytes()),
            )
        # Otherwise a rolling buffer between the producer's stripe and the consumer's input stripe
        shape = rolling_buffer_shape(cost[producer].stripe, cost[consumer].stripe_input)
        return shape, shape.elements() * producer.ofm.dtype.size_in_bytes()
74
75
def rolling_buffer_shape(producer_stripe: Shape4D, consumer_stripe_input: Shape4D) -> Shape4D:
    """Return the storage shape of the rolling buffer between two SchedulerOperations in a Cascade.

    Height: one producer stripe plus one consumer input stripe, rounded up to a whole number of
    consumer input stripes. Width: the wider of the two stripes. Depth: rounded up to 16.
    """
    consumer_height = consumer_stripe_input.height
    height = round_up(producer_stripe.height + consumer_height, consumer_height)
    # A striding consumer may read an IFM narrower than the producer's OFM, so take the larger width
    width = max(producer_stripe.width, consumer_stripe_input.width)
    # Rolling buffers must conform to the NHCWB16 format
    depth = round_up(producer_stripe.depth, 16)
    return Shape4D([1, height, width, depth])
Tim Halld8339a72021-05-27 18:49:40 +010084
85
class CascadeBuilder:
    """Class for grouping SchedulerOperations into cascades"""

    def __init__(self, sched_ops, spilling, non_local_mem_usage=None):
        """Create a builder for the Ops in sched_ops.

        Args:
            sched_ops: SchedulerOperations to consider, in schedule order.
            spilling: True for the dedicated-SRAM case, where only the intermediate cascade
                buffers are counted against SRAM and the memory limit is a hard limit.
            non_local_mem_usage: Optional mapping from SchedulerOperation to extra memory usage
                added to that Op's SRAM estimate; missing entries count as 0.
        """
        self.sched_ops = sched_ops
        self.no_cascade = 0
        self.non_local_mem_usage = non_local_mem_usage if non_local_mem_usage else {}
        self.spilling = spilling

    def _is_cascadable(self, sched_op, cost) -> bool:
        """Check if 'sched_op' can be cascaded.

        Requires a cascadable NPU block type, an OFM split into more than one stripe, no IFM/IFM2
        read offsets, a cascadable elementwise configuration, and none of: resize op,
        Conv2DBackpropInputSwitchedBias, TILE padding.
        """
        return (
            sched_op.op_type.npu_block_type not in non_cascadable_blocks
            and cost.stripe.height < sched_op.ofm.shape.height
            and sched_op.parent_op.read_offsets[0] is None
            and sched_op.parent_op.read_offsets[1] is None
            and self.elementwise_cascadable(sched_op)
            and not sched_op.parent_op.type.is_resize_op()
            and sched_op.parent_op.type != Op.Conv2DBackpropInputSwitchedBias
            and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE
        )

    def _estimate_sram_usage(self, sched_op, cost) -> int:
        """Estimate the SRAM (in bytes) required for the Op if all FeatureMaps are in SRAM.

        Sums the IFM (full, or one input stripe with depth rounded up to 16), IFM2, OFM (full,
        one stripe with depth rounded up to 16, or 0 when it can reuse the IFM buffer) and the
        Op's non-local memory usage.
        """
        if sched_op.parent_op.type.is_binary_elementwise_op():
            # ifm2 is scalar or constant and will always persist in permanent memory
            # NOTE(review): this is also applied to binary elementwise Ops whose ifm2 is a
            # FeatureMap (rejected by elementwise_cascadable but still estimated via the fallback
            # path in build_cascades) - confirm the underestimate is intended.
            ifm2_size = 0
        else:
            ifm2_size = sched_op.ifm2_size_in_bytes()
        if sched_op.requires_full_ifm:
            ifm_size = sched_op.ifm_size_in_bytes()
        else:
            ifm_size = (
                cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements()
                * sched_op.ifm.dtype.size_in_bytes()
            )
        if ofm_can_reuse_ifm(sched_op):
            # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0
            ofm_size = 0
        elif sched_op.requires_full_ofm:
            ofm_size = sched_op.ofm_size_in_bytes()
        else:
            ofm_size = (
                cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
            )

        return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)

    @staticmethod
    def elementwise_cascadable(sched_op):
        """Check if the elementwise can be cascaded.

        Non-binary-elementwise Ops always pass. A binary elementwise Op passes only when its IFM
        is non-constant, non-scalar and not broadcast against the OFM, and its IFM2 is constant
        or scalar.
        """
        if sched_op.parent_op.type.is_binary_elementwise_op():
            ifm = sched_op.parent_op.ifm
            ifm2 = sched_op.parent_op.ifm2
            ofm = sched_op.parent_op.ofm

            # IFM must be non-constant/non-scalar/non-broadcast
            ifm_cascadable = not (ifm.is_const or ifm.is_scalar or ifm.is_broadcast(ofm))

            # IFM2 must be constant or scalar
            ifm2_cascadable = ifm2.is_const or ifm2.is_scalar

            return ifm_cascadable and ifm2_cascadable
        else:
            return True

    @staticmethod
    def _weight_buffer_size(op_cost) -> int:
        """Return the summed storage size of the buffered weight tensors in op_cost."""
        return sum(tens.storage_size() for tens in op_cost.buffered_weight_tensors)

    def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
        """Group the scheduled Ops into cascades.

        Each Op that passes _is_cascadable() starts a proposed cascade. The proposal is greedily
        extended with the single dependant of the last Op while that is allowed, and the best
        prefix found is kept. Cascaded Ops get their ref_schedule cost, with .cascade set to the
        index of the last Op in the cascade. All other Ops get their fallback_schedule cost.

        Args:
            ref_schedule: Schedule whose cost_map holds the cascading costs; its cost_map and
                cascades attributes are replaced.
            fallback_schedule: Schedule whose cost_map holds the costs for non-cascaded Ops.
            guiding_mem_limit: Initial SRAM usage bound. Without spilling it is raised to the
                peak estimated usage seen so far; with spilling it is never changed.

        Returns:
            ref_schedule, with cost_map and cascades (cascade end index -> CascadeInfo) updated.
        """
        ref_cost = ref_schedule.cost_map
        fallback_cost = fallback_schedule.cost_map
        cost = {}
        cascade_map = {}
        buffers = BufferMap()

        # Peak memory usage so far - updated continuously, unless dedicated SRAM where this is a hard limit
        peak_sram_usage = guiding_mem_limit

        for op in self.sched_ops:
            if op in cost:
                # Already processed this Op (as part of an earlier cascade)
                continue

            if not self._is_cascadable(op, ref_cost[op]):
                # Op is not a candidate for cascading - assign fallback cost
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
                continue

            # Propose a cascade starting with this Op
            cascade_start = op.index
            # Keep track of which Ops are in the proposed cascade as well as the best cascade so far
            ops_in_cascade = [op]
            ops_in_best_cascade = [op]
            # Non-local memory usage of the first Op, included in every non-spilling cascade_size
            non_local_usage = self.non_local_mem_usage.get(op, 0)
            # Get the size of the weight buffer(s)
            weight_buffer = self._weight_buffer_size(ref_cost[op])

            # The first IFM needs to be stored in full
            cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0

            # Sum of all intermediate cascade buffers (including weight buffers)
            cascade_buffers = weight_buffer
            # Best cascade size - Initially it's the fallback cost of the first Op in the cascade
            best_cascade_size = self._estimate_sram_usage(op, fallback_cost[op])

            # Op is the producer of the OFM consumed by the next Op to consider
            producer = op
            while True:
                dependants = producer.get_dependants()
                if len(dependants) != 1:
                    # producer is either the last Op in the schedule or the start of a branch
                    break

                current_op = dependants[0]
                if (
                    current_op in cost
                    or current_op not in ref_cost
                    or not self._is_cascadable(current_op, ref_cost[current_op])
                    or producer.ofm.shape != current_op.ifm.shape
                    or current_op.requires_full_ifm
                    or producer.requires_full_ofm
                ):
                    # Current op has already been processed or cannot be cascaded
                    break

                if producer.index + 1 != current_op.index:
                    # Cascading is possible, but requires reordering of operations in the schedule,
                    # this is currently not supported
                    break

                # Get the size of the FeatureMap buffers between current and neighbouring Ops
                op_full_ofm = current_op.ofm_size_in_bytes()
                _, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)

                # Get the size of the weight buffer(s)
                op_weight_buffer = self._weight_buffer_size(ref_cost[current_op])

                # Calculate the uncascaded memory requirement for current Op
                uncascaded_sram_usage = self._estimate_sram_usage(current_op, fallback_cost[current_op])

                # Add current Op to cascade
                ops_in_cascade.append(current_op)

                # Increase the accumulated intermediate buffers in the cascade
                cascade_buffers += op_ifm_buffer + op_weight_buffer

                if self.spilling:
                    # For Dedicated SRAM only the intermediate buffers are in SRAM
                    if uncascaded_sram_usage < peak_sram_usage or cascade_buffers > peak_sram_usage:
                        # Cascade until an Op fits in its entirety or the accumulated buffers no longer fit
                        break
                    # Any addition to the cascade that fits is the new best cascade for Dedicated SRAM
                    ops_in_best_cascade = list(ops_in_cascade)
                    best_cascade_size = cascade_buffers
                else:
                    # Calculate the total size of the current cascade including non local mem usage
                    cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm + non_local_usage

                    # Determine if cascading search should stop
                    if (uncascaded_sram_usage < peak_sram_usage and best_cascade_size < peak_sram_usage) or (
                        cascade_ifm_size + cascade_buffers
                    ) > best_cascade_size:
                        # Both the existing cascade and the current Op fit, or
                        # it is not possible to reduce the cascade size any further
                        break

                    # One of two conditions updates the best cascade:
                    #
                    #   - cascade_size < best_cascade_size, or
                    #   - cascade_size < uncascaded_sram_usage
                    #
                    # The second condition covers cases where a larger cascade (with more Ops)
                    # uses less total SRAM. Example, with all feature maps (FM) the same size:
                    #
                    # Cascade OP1-OP2, OP3 is standalone:
                    #
                    #         -> |OP1| -> roll buffer -> |OP2| -> FM -> |OP3| -> FM
                    #        /
                    #   |OP0| -> FM
                    #        \
                    #         -> ....
                    #
                    #   best_cascade_size    : FM + roll buffer + FM
                    #   uncascaded_sram_usage: FM + FM + FM
                    #
                    # compared with cascade OP1-OP3:
                    #
                    #         -> |OP1| -> roll buffer -> |OP2| -> roll buffer -> |OP3| -> FM
                    #        /
                    #   |OP0| -> FM
                    #        \
                    #         -> ....
                    #
                    #   cascade_size         : FM + roll buffer + roll buffer + FM
                    #
                    # so the comparison is
                    #
                    #   (FM + roll buffer + roll buffer + FM) < (FM + roll buffer + FM) or
                    #   (FM + roll buffer + roll buffer + FM) < (FM + FM + FM)
                    #
                    # and cascade OP1-OP3 is the better choice.
                    if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage:
                        best_cascade_size = cascade_size
                        ops_in_best_cascade = list(ops_in_cascade)

                producer = current_op

            if len(ops_in_best_cascade) > 1:
                # A cascade was created - assign cascade and ref_cost to all of the Ops
                cascade_end = cascade_start + (len(ops_in_best_cascade) - 1)
                buffers_in_cascade = {}
                prev_op = None
                for cascaded_op in ops_in_best_cascade:
                    assert cascade_start <= cascaded_op.index <= cascade_end
                    cost[cascaded_op] = ref_cost[cascaded_op]
                    cost[cascaded_op].cascade = cascade_end
                    if prev_op is not None:
                        buffer_shape, _ = buffers.get_buffer(prev_op, cascaded_op, ref_cost)
                        buffers_in_cascade[cascaded_op] = buffer_shape

                    prev_op = cascaded_op

                # Only store the actual size used by the cascade, so non local usage is removed.
                # This lets the scheduler calculate the correct non local usage when optimizing
                # the sub schedules. With spilling, best_cascade_size is only the intermediate
                # buffers and never contained the non local usage, so nothing is removed.
                if self.spilling:
                    cascade_mem_usage = best_cascade_size
                else:
                    cascade_mem_usage = best_cascade_size - non_local_usage
                cascade_map[cascade_end] = CascadeInfo(
                    cascade_start,
                    cascade_end,
                    buffers_in_cascade,
                    cascade_mem_usage,
                )
                if not self.spilling:
                    # Update peak memory usage
                    peak_sram_usage = max(best_cascade_size, peak_sram_usage)
            else:
                # Assign fallback cost to the initial Op
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)

        # Update costing and cascade information for the ref_schedule
        ref_schedule.cost_map = cost
        ref_schedule.cascades = cascade_map
        return ref_schedule
351 return ref_schedule