blob: b042ba73010d6a176222b16c5709f67a7c73e0d1 [file] [log] [blame]
Tim Halld8339a72021-05-27 18:49:40 +01001# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Groups Operators in a schedule together to form Cascades.
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +020019from .high_level_command_to_npu_op import ifm_ifm2_correct_order
Johan Alfvénfba0a7d2022-10-11 20:41:41 +020020from .live_range import ofm_can_reuse_ifm
Tim Halld8339a72021-05-27 18:49:40 +010021from .numeric_util import round_up
22from .operation import NpuBlockType
erik.andersson@arm.com6b2a0b42022-03-22 15:35:30 +010023from .operation import Op
Rickard Bolin9ae34552022-06-09 13:07:17 +000024from .operation import Padding
Tim Halld8339a72021-05-27 18:49:40 +010025from .shape4d import Shape4D
26
# NPU block types that CascadeBuilder._is_cascadable never accepts as cascade members
non_cascadable_blocks = (NpuBlockType.Default, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
32
33
class CascadeInfo:
    """Plain record describing one cascade.

    CascadeBuilder.build_cascades creates one per cascade. It passes the op index range
    (start/end, both inclusive), a dict of the rolling buffer shapes between consecutive
    ops, and the estimated memory usage of the cascade.
    """

    def __init__(self, start, end, buffers, mem_usage: int):
        # Index range of the cascade (inclusive at both ends)
        self.start, self.end = start, end
        # Stored by reference, not copied
        self.buffers = buffers
        self.mem_usage = mem_usage
42
43
class BufferMap:
    """Memoises the feature map buffer needed between pairs of SchedulerOperations.

    Entries are keyed on the (producer, consumer) pair, so each pair is sized only once.
    """

    def __init__(self):
        # (producer, consumer) -> (buffer_shape, buffer_size_in_bytes)
        self.buffer_map = {}

    def get_buffer(self, producer, consumer, cost):
        """Returns the (shape, size in bytes) of the buffer between 'producer' and 'consumer'.

        Either side may be None, but not both. 'cost' maps SchedulerOperations to cost entries
        exposing 'stripe'/'stripe_input'; it is only consulted when a rolling buffer is sized.
        The first result for a pair is cached and returned on every later call for that pair.
        """
        assert producer or consumer
        pair = (producer, consumer)
        cached = self.buffer_map.get(pair)
        if cached is not None:
            return cached

        entry = self._size_buffer(producer, consumer, cost)
        self.buffer_map[pair] = entry
        return entry

    @staticmethod
    def _size_buffer(producer, consumer, cost):
        """Computes the (shape, size in bytes) of the buffer between 'producer' and 'consumer'."""
        if consumer is None:
            # There are either no consumers or multiple consumers - the whole OFM is kept
            return producer.ofm.shape, producer.ofm_size_in_bytes()
        if producer is None:
            # First Op in subgraph or cascade - the whole IFM is kept
            return consumer.ifm.shape, consumer.ifm_size_in_bytes()
        if producer.requires_full_ofm or consumer.requires_full_ifm:
            # One side needs the complete FeatureMap - take max() of both shapes and both sizes
            return (
                max(producer.ofm.shape, consumer.ifm.shape),
                max(producer.ofm_size_in_bytes(), consumer.ifm_size_in_bytes()),
            )
        # Otherwise the two ops communicate through a rolling buffer
        shape = rolling_buffer_shape(cost[producer].stripe, cost[consumer].stripe_input)
        return shape, shape.elements() * producer.ofm.dtype.size_in_bytes()
75
76
def rolling_buffer_shape(producer_stripe: Shape4D, consumer_stripe_input: Shape4D) -> Shape4D:
    """Returns the storage shape of the rolling buffer between two cascaded SchedulerOperations.

    The shape is the consumer's input stripe with two changes. The height fits one producer
    stripe plus one consumer input stripe, rounded up to a multiple of the consumer input
    stripe height. The depth is the producer stripe depth rounded up to a multiple of 16.
    """
    consumer_height = consumer_stripe_input.height
    height = round_up(producer_stripe.height + consumer_height, consumer_height)
    # Rolling buffers have to conform to NHCWB16 format
    depth = round_up(producer_stripe.depth, 16)
    return consumer_stripe_input.with_height(height).with_depth(depth)
82
83
class CascadeBuilder:
    """Class for grouping SchedulerOperations into cascades"""

    def __init__(self, sched_ops, spilling, non_local_mem_usage=None):
        """
        :param sched_ops: the SchedulerOperations to group. build_cascades walks them in list
                          order and only cascades ops with consecutive 'index' values
        :param spilling: True for Dedicated SRAM. Only the intermediate cascade buffers are then
                         counted in SRAM, and the guiding memory limit is treated as a hard limit
        :param non_local_mem_usage: optional mapping SchedulerOperation -> extra memory usage that
                                    is added to every SRAM estimate for that op (defaults to empty)
        """
        self.sched_ops = sched_ops
        self.no_cascade = 0
        self.non_local_mem_usage = non_local_mem_usage if non_local_mem_usage else {}
        self.spilling = spilling

    def _is_cascadable(self, sched_op, cost) -> bool:
        """Checks if 'sched_op' can be cascaded, given its (striped) 'cost'"""
        return (
            sched_op.op_type.npu_block_type not in non_cascadable_blocks
            # The op must be split into more than one stripe along the OFM height
            and cost.stripe.height < sched_op.ofm.shape.height
            # Neither input may be read from an offset
            and sched_op.parent_op.read_offsets[0] is None
            and sched_op.parent_op.read_offsets[1] is None
            and self.elementwise_cascading_correct_order(sched_op)
            and not sched_op.parent_op.type.is_resize_op()
            and sched_op.parent_op.type != Op.Conv2DBackpropInputSwitchedBias
            and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE
        )

    def _estimate_sram_usage(self, sched_op, cost) -> int:
        """Estimate the SRAM required for the Op if all FeatureMaps are in SRAM

        IFM and OFM are counted in full when the op requires it, otherwise as one stripe whose
        depth is rounded up to a multiple of 16. The full IFM2 size and any non-local memory
        usage registered for the op are added on top.
        """
        ifm2_size = sched_op.ifm2_size_in_bytes()
        if sched_op.requires_full_ifm:
            ifm_size = sched_op.ifm_size_in_bytes()
        else:
            ifm_size = (
                cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements()
                * sched_op.ifm.dtype.size_in_bytes()
            )
        if ofm_can_reuse_ifm(sched_op):
            # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0
            ofm_size = 0
        elif sched_op.requires_full_ofm:
            ofm_size = sched_op.ofm_size_in_bytes()
        else:
            ofm_size = (
                cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
            )

        return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)

    @staticmethod
    def _is_const_tensor(tens) -> bool:
        """Returns True if the first op in 'tens.ops' is an Op.Const (False when 'tens.ops' is empty)"""
        return tens.ops != [] and tens.ops[0].type == Op.Const

    @staticmethod
    def elementwise_cascading_conformity(sched_op):
        """Check the inputs of the op to see if it's a candidate for cascading.

        Binary elementwise ops are only candidates if at least one of their IFMs is constant.
        All other ops are always candidates.
        """
        if sched_op.parent_op.type.is_binary_elementwise_op():
            # We cannot rule out cascadability if at least one IFM is constant
            ifm_const = CascadeBuilder._is_const_tensor(sched_op.parent_op.ifm)
            ifm2_const = CascadeBuilder._is_const_tensor(sched_op.parent_op.ifm2)
            return ifm_const or ifm2_const
        # Not a binary elementwise op - we cannot rule out cascadability
        return True

    @staticmethod
    def elementwise_cascading_correct_order(sched_op):
        """Check that a binary elementwise op has a constant ifm2 and ifm/ifm2 in the correct order.

        Ops that are not binary elementwise always pass this check.
        """
        if sched_op.parent_op.type.is_binary_elementwise_op():
            ifm2_const = CascadeBuilder._is_const_tensor(sched_op.parent_op.ifm2)

            # ifm_ifm2_correct_order needs full shape
            correct_order = ifm_ifm2_correct_order(sched_op.ifm.shape, sched_op.ifm2.shape)
            return ifm2_const and correct_order
        return True

    def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
        """Group the SchedulerOperations into cascades.

        Each unprocessed, cascadable op starts a proposed cascade. The cascade is extended
        greedily through single-dependant chains of consecutive ops, and the best cascade found
        is kept. Ops in a cascade get their ref_schedule cost, with '.cascade' set to the index
        of the last op in the cascade. All other ops get their fallback_schedule cost.

        :param ref_schedule: schedule whose cost_map holds the striped (cascadable) costs.
                             Its cost_map and cascades are replaced and it is returned
        :param fallback_schedule: schedule whose cost_map is used for ops that are not cascaded
        :param guiding_mem_limit: initial peak SRAM usage. It is a hard limit when spilling,
                                  otherwise it grows as larger usages are seen
        :return: ref_schedule, with cost_map and cascades (end index -> CascadeInfo) updated
        """
        ref_cost = ref_schedule.cost_map
        fallback_cost = fallback_schedule.cost_map
        cost = {}
        cascade_map = {}
        buffers = BufferMap()

        # Peak memory usage so far - updated continuously, unless dedicated SRAM where this is a hard limit
        peak_sram_usage = guiding_mem_limit

        for op in self.sched_ops:
            if op in cost:
                # Already processed this Op (as part of an earlier cascade)
                continue

            if not self._is_cascadable(op, ref_cost[op]):
                # Op is not a candidate for cascading - assign fallback cost
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
                continue

            # Propose a cascade starting with this Op
            cascade_start = op.index
            # Keep track of which Ops are in the proposed cascade as well as the best cascade so far
            ops_in_cascade = [op]
            ops_in_best_cascade = [op]
            # Get the size of the weight buffer(s)
            weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors)

            # The first IFM needs to be stored in full
            cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0

            # Add non-local memory usage
            cascade_ifm_size += self.non_local_mem_usage.get(op, 0)

            # Sum of all intermediate cascade buffers (including weight buffers)
            cascade_buffers = weight_buffer
            # Best cascade size - Initially it's the fallback cost of the first Op in the cascade
            best_cascade_size = self._estimate_sram_usage(op, fallback_cost[op])

            # Op is the producer of the OFM consumed by the next Op to consider
            producer = op
            while True:
                dependants = producer.get_dependants()
                if len(dependants) != 1:
                    # producer is either the last Op in the schedule or the start of a branch
                    break

                current_op = dependants[0]
                if (
                    current_op in cost
                    or current_op not in ref_cost
                    or not self._is_cascadable(current_op, ref_cost[current_op])
                    or producer.ofm.shape != current_op.ifm.shape
                    or current_op.requires_full_ifm
                    or producer.requires_full_ofm
                ):
                    # Current op has already been processed or cannot be cascaded
                    break

                if producer.index + 1 != current_op.index:
                    # Cascading is possible, but requires reordering of operations in the schedule,
                    # this is currently not supported
                    break

                # Get the size of the FeatureMap buffers between current and neighbouring Ops
                op_full_ifm = current_op.ifm_size_in_bytes()
                op_full_ofm = current_op.ofm_size_in_bytes()
                _, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)

                # Get the size of the weight buffer(s)
                op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors)

                # Calculate the uncascaded memory requirement for current Op
                uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)

                # Add current Op to cascade
                ops_in_cascade.append(current_op)

                # Increase the accumulated intermediate buffers in the cascade
                cascade_buffers += op_ifm_buffer + op_weight_buffer

                if self.spilling:
                    # For Dedicated SRAM only the intermediate buffers are in SRAM
                    if uncascaded_sram_usage < peak_sram_usage or cascade_buffers > peak_sram_usage:
                        # Cascade until an Op fits in its entirety or the accumulated buffers no longer fit
                        break
                    # Any addition to the cascade that fits is the new best cascade for Dedicated SRAM
                    ops_in_best_cascade = list(ops_in_cascade)
                    best_cascade_size = cascade_buffers
                else:
                    # Calculate the total size of the current cascade
                    cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm

                    # Stop the search when either
                    # - both the current Op (uncascaded) and the best cascade so far fit within the peak usage, or
                    # - the cascade, not counting its final OFM, is already larger than the best cascade
                    both_fit = uncascaded_sram_usage < peak_sram_usage and best_cascade_size < peak_sram_usage
                    if both_fit or (cascade_ifm_size + cascade_buffers) > best_cascade_size:
                        break

                    # One of two conditions will update the best cascade:
                    #
                    #   - cascade_size < best_cascade_size or
                    #   - cascade_size < uncascaded_sram_usage
                    #
                    # The last condition is illustrated below, showing an example where it is
                    # better to choose a larger cascade_size (with more OPs) because it will
                    # use less total SRAM usage.
                    #
                    # For simplicity, all featuremaps have same size.
                    #
                    # Cascade OP1-OP2, OP3 is standalone
                    #
                    #          -> |OP1| -> roll buffer -> |OP2| -> FM -> |OP3| -> FM
                    #         /
                    # |OP0| -> FM
                    #         \
                    #          -> ....
                    #
                    #   best_cascade_size    : FM + roll buffer + FM
                    #   uncascaded_sram_usage: FM + FM + FM
                    #
                    # compared with:
                    #
                    # Cascade OP1-OP3
                    #
                    #          -> |OP1| -> roll buffer -> |OP2| -> roll buffer -> |OP3| -> FM
                    #         /
                    # |OP0| -> FM
                    #         \
                    #          -> ....
                    #
                    #   cascade_size         : FM + roll buffer + roll buffer + FM
                    #
                    # So, for this use case the comparison will be
                    #
                    #   (FM + roll buffer + roll buffer + FM) < (FM + roll buffer + FM) or
                    #   (FM + roll buffer + roll buffer + FM) < (FM + FM + FM)
                    #
                    # hence, better to choose Cascade OP1-OP3 in this case.
                    if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage:
                        best_cascade_size = cascade_size
                        ops_in_best_cascade = list(ops_in_cascade)

                producer = current_op

            if len(ops_in_best_cascade) > 1:
                # A cascade was created - assign cascade and ref_cost to all of the Ops
                cascade_end = cascade_start + (len(ops_in_best_cascade) - 1)
                buffers_in_cascade = {}
                prev_op = None
                for cascaded_op in ops_in_best_cascade:
                    assert cascade_start <= cascaded_op.index <= cascade_end
                    cost[cascaded_op] = ref_cost[cascaded_op]
                    cost[cascaded_op].cascade = cascade_end
                    if prev_op:
                        # Record the rolling buffer between each op and its predecessor in the cascade
                        buffer_shape, _ = buffers.get_buffer(prev_op, cascaded_op, ref_cost)
                        buffers_in_cascade[cascaded_op] = buffer_shape

                    prev_op = cascaded_op

                # Create a CascadeInfo for the cascade
                cascade_map[cascade_end] = CascadeInfo(
                    cascade_start, cascade_end, buffers_in_cascade, best_cascade_size
                )
                if not self.spilling:
                    # Update peak memory usage
                    peak_sram_usage = max(best_cascade_size, peak_sram_usage)
            else:
                # Assign fallback cost to the initial Op
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)

        # Update costing and cascade information for the ref_schedule
        ref_schedule.cost_map = cost
        ref_schedule.cascades = cascade_map
        return ref_schedule
350 return ref_schedule