blob: b4a4f8767d2da8aadfac507131ba53edcc160abf [file] [log] [blame]
# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Groups Operators in a schedule together to form Cascades.
Johan Alfvén255dad72022-07-16 18:27:05 +020019from collections import namedtuple
20
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +020021from .high_level_command_to_npu_op import ifm_ifm2_correct_order
Tim Halld8339a72021-05-27 18:49:40 +010022from .numeric_util import round_up
23from .operation import NpuBlockType
erik.andersson@arm.com6b2a0b42022-03-22 15:35:30 +010024from .operation import Op
Rickard Bolin9ae34552022-06-09 13:07:17 +000025from .operation import Padding
Tim Halld8339a72021-05-27 18:49:40 +010026from .shape4d import Shape4D
27
# NPU block types whose operators are never cascaded; CascadeBuilder._is_cascadable rejects any
# SchedulerOperation whose op_type.npu_block_type is one of these.
non_cascadable_blocks = (NpuBlockType.Default, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
33
34
class CascadeInfo:
    """Record describing one cascade: its op index range, its buffers and its memory usage"""

    def __init__(self, start, end, buffers, mem_usage: int):
        # First and last SchedulerOperation index covered by the cascade
        self.start, self.end = start, end
        # Buffers inside the cascade (CascadeBuilder.build_cascades stores rolling buffer shapes keyed by
        # the consuming SchedulerOperation)
        self.buffers = buffers
        # Estimated memory usage of the whole cascade
        self.mem_usage = mem_usage
43
44
class BufferMap:
    """Memoises the FeatureMap buffer (shape, size in bytes) needed between pairs of SchedulerOperations"""

    def __init__(self):
        # (producer, consumer) -> (buffer_shape, buffer_size)
        self.buffer_map = {}

    def get_buffer(self, producer, consumer, cost):
        """Return the (shape, size) of the buffer between 'producer' and 'consumer', computing it on first use.

        Either side may be None (but not both). 'cost' maps SchedulerOperations to their cost objects and is
        only consulted when a rolling buffer is needed. Results are cached per (producer, consumer) pair, so
        later calls for the same pair return the first result regardless of 'cost'.
        """
        assert producer or consumer
        key = (producer, consumer)
        if key not in self.buffer_map:
            self.buffer_map[key] = self._compute_buffer(producer, consumer, cost)
        return self.buffer_map[key]

    @staticmethod
    def _compute_buffer(producer, consumer, cost):
        """Work out the (shape, size) of a buffer that is not cached yet."""
        if consumer is None:
            # No consumer or multiple consumers: the producer's OFM is stored in full
            return producer.ofm.shape, producer.ofm_size_in_bytes()
        if producer is None:
            # First Op in the subgraph or cascade: the consumer's IFM is stored in full
            return consumer.ifm.shape, consumer.ifm_size_in_bytes()
        if producer.requires_full_ofm or consumer.requires_full_ifm:
            # One side needs the whole FeatureMap: keep the larger of the two
            full_shape = max(producer.ofm.shape, consumer.ifm.shape)
            full_size = max(producer.ofm_size_in_bytes(), consumer.ifm_size_in_bytes())
            return full_shape, full_size
        # Otherwise a rolling buffer sized by the producer stripe and consumer input stripe suffices
        rolling_shape = rolling_buffer_shape(cost[producer].stripe, cost[consumer].stripe_input)
        return rolling_shape, rolling_shape.elements() * producer.ofm.dtype.size_in_bytes()
76
77
def rolling_buffer_shape(producer_stripe: Shape4D, consumer_stripe_input: Shape4D) -> Shape4D:
    """Return the storage shape of the rolling buffer between two SchedulerOperations in a Cascade.

    The height holds one producer stripe plus one consumer input stripe, rounded up with round_up() to the
    consumer input stripe height. The depth is the producer stripe depth rounded up with round_up(..., 16).
    """
    stripe_step = consumer_stripe_input.height
    height = round_up(producer_stripe.height + stripe_step, stripe_step)
    # Rolling buffers have to conform to NHCWB16 format
    depth = round_up(producer_stripe.depth, 16)
    return consumer_stripe_input.with_height(height).with_depth(depth)
83
84
class CascadeBuilder:
    """Class for grouping SchedulerOperations into cascades"""

    # (shape seen by the Op, tensor) pair used by _is_mergeable. Created once here rather than on every
    # call, because building a namedtuple class is comparatively expensive and _is_mergeable runs inside
    # the cascade search loop (via _estimate_sram_usage).
    _OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])

    def __init__(self, sched_ops, spilling, non_local_mem_usage=None):
        """
        :param sched_ops: SchedulerOperations to group, visited in this order by build_cascades
        :param spilling: True for the Dedicated SRAM case, where only the intermediate cascade buffers are
            in SRAM and the memory limit is fixed; False when all FeatureMaps are assumed to be in SRAM
        :param non_local_mem_usage: optional map from SchedulerOperation to extra memory usage that is
            added to that Op's estimates (defaults to an empty map)
        """
        self.sched_ops = sched_ops
        self.no_cascade = 0
        self.non_local_mem_usage = non_local_mem_usage if non_local_mem_usage else {}
        self.spilling = spilling

    def _is_cascadable(self, sched_op, cost) -> bool:
        """Checks if 'sched_op' can be cascaded.

        An Op is a candidate when all of the following hold:
        - its NPU block type is not in non_cascadable_blocks;
        - its stripe (from 'cost') is shorter than its OFM, i.e. it is actually striped;
        - it reads both inputs without read offsets;
        - it passes element_wise_cascading_conformity;
        - it is not a resize op, not a Conv2DBackpropInputSwitchedBias and not TILE padded.
        """
        parent_op = sched_op.parent_op
        return (
            sched_op.op_type.npu_block_type not in non_cascadable_blocks
            and cost.stripe.height < sched_op.ofm.shape.height
            and parent_op.read_offsets[0] is None
            and parent_op.read_offsets[1] is None
            and self.element_wise_cascading_conformity(sched_op)
            and not parent_op.type.is_resize_op()
            and parent_op.type != Op.Conv2DBackpropInputSwitchedBias
            and parent_op.attrs.get("padding", None) != Padding.TILE
        )

    def _is_mergeable(self, sched_op) -> bool:
        """Return True if the elementwise 'sched_op' may write its OFM over one of its input buffers.

        Code based on merge_elementwise_op_ranges from live_range.py. Requirements:
        - the OFM has a single producer;
        - some input has the same Op shape, format and dtype as the OFM;
        - that input is not a scalar (empty shape) and has exactly one consumer.
        """
        if not sched_op.op_type.is_elementwise_op():
            return False

        elem_op = sched_op.parent_op

        # Check if overwriting the inputs can be allowed
        outp = self._OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm)

        # check output tensor only has one producer
        if len(outp.tens.ops) != 1:
            return False

        inps = []
        if elem_op.ifm is not None:
            inps.append(self._OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm))
        if elem_op.ifm2 is not None:
            inps.append(self._OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2))

        # find an input tensor that can be overwritten by the output
        for inp in inps:
            if (
                # check op input and output shapes allow overlapping
                inp.op_shape == outp.op_shape
                # check input tensor is valid
                and inp.tens is not None
                and inp.tens.shape != []
                # check input and output tensors are compatible
                and inp.tens.format == outp.tens.format
                and inp.tens.dtype == outp.tens.dtype
                # check input tensor only has one consumer
                and len(inp.tens.consumer_list) == 1
            ):
                return True

        return False

    def _estimate_sram_usage(self, sched_op, cost) -> int:
        """Estimate the SRAM required for the Op if all FeatureMaps are in SRAM.

        The IFM and OFM are counted in full when the Op requires it. Otherwise only one stripe from 'cost'
        is counted, with its depth rounded up with round_up(..., 16). The full IFM2 is always counted. The
        OFM counts as 0 when it can reuse an input buffer (see _is_mergeable). The Op's non-local memory
        usage is added on top.
        """
        ifm2_size = sched_op.ifm2_size_in_bytes()
        if sched_op.requires_full_ifm:
            ifm_size = sched_op.ifm_size_in_bytes()
        else:
            ifm_size = (
                cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements()
                * sched_op.ifm.dtype.size_in_bytes()
            )
        if sched_op.requires_full_ofm:
            ofm_size = sched_op.ofm_size_in_bytes()
        else:
            ofm_size = (
                cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
            )

        if self._is_mergeable(sched_op):
            # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0
            ofm_size = 0

        return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)

    @staticmethod
    def element_wise_cascading_conformity(sched_op):
        """Check the inputs of the op to see if it's a candidate for cascading.

        Binary elementwise ops that have both inputs are only accepted when ifm2 is produced by a Const op
        and the operand order passes ifm_ifm2_correct_order(). Every other op is accepted by this check.
        """

        ifm = sched_op.parent_op.ifm
        ifm2 = sched_op.parent_op.ifm2

        # Cascading elementwise operations with reverse operand order is not handled
        if sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2:
            # Require a constant second operand; binary ops whose ifm2 is variable are conservatively
            # treated as non-cascadable, even if ifm is constant
            ifm2_const = ifm2.ops != [] and ifm2.ops[0].type == Op.Const
            return ifm_ifm2_correct_order(ifm.shape, ifm2.shape) and ifm2_const
        else:
            # Not a binary elementwise op, or one of its inputs is missing - we cannot rule out cascadability
            return True

    def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
        """Group the SchedulerOperations into cascades and assign a cost to every Op.

        Ops are visited in order. For each Op that is not yet assigned and is cascadable, a cascade is grown
        greedily along the chain of single dependants. Growth stops at the first Op that cannot join. The
        best prefix of that chain, judged by estimated memory usage, is kept. Ops in a kept cascade get
        their reference cost. All other Ops get their fallback cost.

        :param ref_schedule: schedule whose cost_map holds the striped (cascading) costs. Its cost_map and
            cascades attributes are replaced with the result, and it is returned.
        :param fallback_schedule: schedule whose cost_map holds the costs used for Ops that are not cascaded
        :param guiding_mem_limit: initial peak memory usage. It is a hard limit when spilling (Dedicated
            SRAM); otherwise it is raised as larger Ops/cascades are encountered.
        :return: ref_schedule
        """
        ref_cost = ref_schedule.cost_map
        fallback_cost = fallback_schedule.cost_map
        cost = {}
        cascade_map = {}
        buffers = BufferMap()

        # Peak memory usage so far - updated continuously, unless dedicated SRAM where this is a hard limit
        peak_sram_usage = guiding_mem_limit

        for op in self.sched_ops:
            if op in cost:
                # Already processed this Op (it belongs to an earlier cascade)
                continue

            if not self._is_cascadable(op, ref_cost[op]):
                # Op is not a candidate for cascading - assign fallback cost
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
                continue

            # Propose a cascade starting with this Op
            cascade_start = op.index
            # Keep track of which Ops are in the proposed cascade as well as the best cascade so far
            ops_in_cascade = [op]
            ops_in_best_cascade = [op]
            # Get the size of the weight buffer(s)
            weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors)

            # The first IFM needs to be stored in full
            cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0

            # Add non-local memory usage
            cascade_ifm_size += self.non_local_mem_usage.get(op, 0)

            # Sum of all intermediate cascade buffers (including weight buffers)
            cascade_buffers = weight_buffer
            # Best cascade size - Initially it's the fallback cost of the first Op in the cascade
            best_cascade_size = self._estimate_sram_usage(op, fallback_cost[op])

            # Op is the producer of the OFM consumed by the next Op to consider
            producer = op
            while True:
                dependants = producer.get_dependants()
                if len(dependants) != 1:
                    # producer is either the last Op in the schedule or the start of a branch
                    break

                current_op = dependants[0]
                if (
                    current_op in cost
                    or current_op not in ref_cost
                    or not self._is_cascadable(current_op, ref_cost[current_op])
                    or producer.ofm.shape != current_op.ifm.shape
                    or current_op.requires_full_ifm
                    or producer.requires_full_ofm
                ):
                    # Current op has already been processed or cannot be cascaded
                    break

                if producer.index + 1 != current_op.index:
                    # Cascading is possible, but requires reordering of operations in the schedule,
                    # this is currently not supported
                    break

                # Get the size of the FeatureMap buffers between current and neighbouring Ops
                op_full_ifm = current_op.ifm_size_in_bytes()
                op_full_ofm = current_op.ofm_size_in_bytes()
                _, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)

                # Get the size of the weight buffer(s)
                op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors)

                # Calculate the uncascaded memory requirement for current Op
                uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)

                # Add current Op to cascade
                ops_in_cascade.append(current_op)

                # Increase the accumulated intermediate buffers in the cascade
                cascade_buffers += op_ifm_buffer + op_weight_buffer

                if self.spilling:
                    # For Dedicated SRAM only the intermediate buffers are in SRAM
                    if uncascaded_sram_usage < peak_sram_usage or cascade_buffers > peak_sram_usage:
                        # Cascade until an Op fits in its entirety or the accumulated buffers no longer fit
                        break
                    # Any addition to the cascade that fits is the new best cascade for Dedicated SRAM
                    ops_in_best_cascade = list(ops_in_cascade)
                    best_cascade_size = cascade_buffers
                else:
                    # Calculate the total size of the current cascade
                    cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm

                    # Stop searching when either:
                    # - the current Op on its own and the best cascade so far both fit within the peak
                    #   usage, so there is nothing to gain by cascading further, or
                    # - the first IFM plus the accumulated buffers already exceed the best cascade size
                    #   (adding Ops only adds further buffers)
                    both_fit = uncascaded_sram_usage < peak_sram_usage and best_cascade_size < peak_sram_usage
                    if both_fit or (cascade_ifm_size + cascade_buffers) > best_cascade_size:
                        break

                    # One of two conditions will update the best cascade:
                    #   - cascade_size < best_cascade_size, or
                    #   - cascade_size < uncascaded_sram_usage
                    #
                    # The second condition prefers a larger cascade (with more Ops) when it uses less
                    # total SRAM. Example, with all FeatureMaps (FM) the same size:
                    #
                    #   Cascade OP1-OP2, OP3 standalone:
                    #           -> |OP1| -> roll buffer -> |OP2| -> FM -> |OP3| -> FM
                    #          /
                    #   |OP0| -> FM
                    #          \
                    #           -> ....
                    #
                    #     best_cascade_size    : FM + roll buffer + FM
                    #     uncascaded_sram_usage: FM + FM + FM
                    #
                    #   Cascade OP1-OP3:
                    #           -> |OP1| -> roll buffer -> |OP2| -> roll buffer -> |OP3| -> FM
                    #          /
                    #   |OP0| -> FM
                    #          \
                    #           -> ....
                    #
                    #     cascade_size         : FM + roll buffer + roll buffer + FM
                    #
                    #   The comparison is then
                    #     (FM + roll buffer + roll buffer + FM) < (FM + roll buffer + FM)  or
                    #     (FM + roll buffer + roll buffer + FM) < (FM + FM + FM)
                    #   so Cascade OP1-OP3 is chosen.
                    if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage:
                        best_cascade_size = cascade_size
                        ops_in_best_cascade = list(ops_in_cascade)

                producer = current_op

            if len(ops_in_best_cascade) > 1:
                # A cascade was created - assign cascade and ref_cost to all of the Ops
                cascade_end = cascade_start + (len(ops_in_best_cascade) - 1)
                buffers_in_cascade = {}
                prev_op = None
                for cascaded_op in ops_in_best_cascade:
                    assert cascade_start <= cascaded_op.index <= cascade_end
                    cost[cascaded_op] = ref_cost[cascaded_op]
                    cost[cascaded_op].cascade = cascade_end
                    if prev_op is not None:
                        # Named so as not to shadow the module-level rolling_buffer_shape() function
                        buffer_shape, _ = buffers.get_buffer(prev_op, cascaded_op, ref_cost)
                        buffers_in_cascade[cascaded_op] = buffer_shape

                    prev_op = cascaded_op

                # Create a CascadeInfo for the cascade
                cascade_map[cascade_end] = CascadeInfo(
                    cascade_start, cascade_end, buffers_in_cascade, best_cascade_size
                )
                if not self.spilling:
                    # Update peak memory usage
                    peak_sram_usage = max(best_cascade_size, peak_sram_usage)
            else:
                # Assign fallback cost to the initial Op
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)

        # Update costing and cascade information for the ref_schedule
        ref_schedule.cost_map = cost
        ref_schedule.cascades = cascade_map
        return ref_schedule
379 return ref_schedule