# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Groups Operators in a schedule together to form Cascades.
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .shape4d import Shape4D

non_cascadable_blocks = (
    NpuBlockType.Default,
    NpuBlockType.VectorProduct,
    NpuBlockType.ReduceSum,
)
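# The block types above are never grouped into a cascade; cascading relies on producing the OFM
# stripe by stripe, which these block types are assumed not to support.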


class CascadeInfo:
    """Contains metadata about a cascade"""

    def __init__(self, start, end, buffers, mem_usage: int):
        self.start = start
        self.end = end
        self.buffers = buffers
        self.mem_usage = mem_usage


class BufferMap:
    """Caches the buffers seen"""

    def __init__(self):
        self.buffer_map = {}

    def get_buffer(self, producer, consumer, cost):
        assert producer or consumer
        key = (producer, consumer)
        if key not in self.buffer_map:
            # No cached buffer between these two SchedulerOperations
            if consumer is None:
                # There are either no consumers or multiple consumers - FeatureMap needs to be stored in full
                buffer_shape = producer.ofm.shape
                buffer_size = producer.ofm_size_in_bytes()
            elif producer is None:
                # First Op in subgraph or cascade - FeatureMap needs to be stored in full
                buffer_shape = consumer.ifm.shape
                buffer_size = consumer.ifm_size_in_bytes()
            elif producer.requires_full_ofm or consumer.requires_full_ifm:
                # FeatureMap needs to be stored in full
                buffer_shape = max(producer.ofm.shape, consumer.ifm.shape)
                buffer_size = max(producer.ofm_size_in_bytes(), consumer.ifm_size_in_bytes())
            else:
                # Use a rolling buffer
                buffer_shape = rolling_buffer_shape(cost[producer].stripe, cost[consumer].stripe_input)
                buffer_size = buffer_shape.elements() * producer.ofm.dtype.size_in_bytes()

            self.buffer_map[key] = (buffer_shape, buffer_size)

        return self.buffer_map[key]


def rolling_buffer_shape(producer_stripe: Shape4D, consumer_stripe_input: Shape4D) -> Shape4D:
    """Calculates the storage shape of the rolling buffer between two SchedulerOperations in a Cascade"""
    buffer_height = round_up(producer_stripe.height + consumer_stripe_input.height, consumer_stripe_input.height)
    # Rolling buffers have to conform to NHCWB16 format
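    # Worked example (made-up numbers): a producer stripe height of 4 and a consumer stripe-input
    # height of 6 give buffer_height = round_up(4 + 6, 6) = 12 rows, while a producer stripe depth
    # of 20 is rounded up to 32 channels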
    return consumer_stripe_input.with_height(buffer_height).with_depth(round_up(producer_stripe.depth, 16))


class CascadeBuilder:
    """Class for grouping SchedulerOperations into cascades"""

    def __init__(self, sched_ops, spilling, non_local_mem_usage=None):
        self.sched_ops = sched_ops
        self.no_cascade = 0
        self.non_local_mem_usage = non_local_mem_usage if non_local_mem_usage else {}
        self.spilling = spilling

    def _is_cascadable(self, sched_op, cost) -> bool:
        """Checks if 'sched_op' can be cascaded"""

        return (
            sched_op.op_type.npu_block_type not in non_cascadable_blocks
            and cost.stripe.height < sched_op.ofm.shape.height
            and sched_op.parent_op.read_offsets[0] is None
            and sched_op.parent_op.read_offsets[1] is None
            and self.element_wise_cascading_conformity(sched_op)
        )

    def _estimate_sram_usage(self, sched_op, cost) -> int:
        """Estimate the SRAM required for the Op if all FeatureMaps are in SRAM"""
        ifm2_size = sched_op.ifm2_size_in_bytes()
        if sched_op.requires_full_ifm:
            ifm_size = sched_op.ifm_size_in_bytes()
        else:
            ifm_size = (
                cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements()
                * sched_op.ifm.dtype.size_in_bytes()
            )
        if sched_op.requires_full_ofm:
            ofm_size = sched_op.ofm_size_in_bytes()
        else:
            ofm_size = (
                cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
            )

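        # Illustrative arithmetic (made-up numbers): an int8 OFM stripe of 1x8x16x20 is padded to a
        # depth of 32 for NHCWB16 and contributes 1 * 8 * 16 * 32 = 4096 bytes to the estimate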
        return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)

    @staticmethod
    def element_wise_cascading_conformity(sched_op):
        """Check the inputs of the op to see if it's a candidate for cascading."""
        # Cascading sub-operators of Softmax results in a crash when handling Sub and RescaleAdd ops

        ifm = sched_op.parent_op.ifm
        ifm2 = sched_op.parent_op.ifm2

        if sched_op.op_type in [Op.RescaleAdd]:
            return False

        if sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2:
            # We cannot rule out cascadability if at least one IFM is constant
            return Op.Const in (ifm.ops[0], ifm2.ops[0])
        else:
            # Either one IFM is not variable or it is not a binary elementwise op - we cannot rule out cascadability
            return True

    def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
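        """Group the operations in self.sched_ops into cascades.

        The reference schedule supplies the striped costs used inside a cascade, the fallback
        schedule supplies the uncascaded costs, and guiding_mem_limit seeds the running peak
        SRAM estimate. Cascades are grown greedily along chains of single-consumer Ops, and the
        reference schedule is returned with its cost_map and cascades updated.

        Illustrative call (variable names are hypothetical, not part of this module):
            builder = CascadeBuilder(sched_ops, spilling=False)
            schedule = builder.build_cascades(ref_schedule, fallback_schedule, mem_limit)
        """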
        ref_cost = ref_schedule.cost_map
        fallback_cost = fallback_schedule.cost_map
        cost = {}
        cascade_map = {}
        buffers = BufferMap()

        # Peak memory usage so far - updated continuously, unless dedicated SRAM where this is a hard limit
        peak_sram_usage = guiding_mem_limit

        idx = 0
        while idx < len(self.sched_ops):
            op = self.sched_ops[idx]
            if op in cost:
                # Already processed this Op
                idx += 1
                continue

            if not self._is_cascadable(op, ref_cost[op]):
                # Op is not a candidate for cascading - assign fallback cost
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
                idx += 1
                continue

            # Propose a cascade starting with this Op
            cascade_start = op.index
            # Keep track of which Ops are in the proposed cascade as well as the best cascade so far
            ops_in_cascade = [op]
            ops_in_best_cascade = [op]
            # Get the size of the weight buffer(s)
            weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors)

            # The first IFM needs to be stored in full
            cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0

            # Add non-local memory usage
            cascade_ifm_size += self.non_local_mem_usage.get(op, 0)

            # Sum of all intermediate cascade buffers (including weight buffers)
            cascade_buffers = weight_buffer
            # Best cascade size - Initially it's the fallback cost of the first Op in the cascade
            best_cascade_size = self._estimate_sram_usage(op, fallback_cost[op])

            # Op is the producer of the OFM consumed by the next Op to consider
            producer = op
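            # Greedily try to extend the cascade: follow the chain of single-consumer Ops and stop
            # when the next Op cannot be cascaded or the memory criteria below say so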
            while True:
                dependants = producer.get_dependants()
                if len(dependants) != 1:
                    # producer is either the last Op in the schedule or the start of a branch
                    break

                current_op = dependants[0]
                if (
                    current_op in cost
                    or current_op not in ref_cost
                    or not self._is_cascadable(current_op, ref_cost[current_op])
                    or producer.ofm.shape != current_op.ifm.shape
                    or current_op.requires_full_ifm
                    or producer.requires_full_ofm
                ):
                    # Current op has already been processed or cannot be cascaded
                    break

                if producer.index + 1 != current_op.index:
                    # Cascading is possible, but requires reordering of operations in the schedule,
                    # this is currently not supported
                    break

                # Get the size of the FeatureMap buffers between current and neighbouring Ops
                op_full_ifm = current_op.ifm_size_in_bytes()
                op_full_ofm = current_op.ofm_size_in_bytes()
                _, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)

                # Get the size of the weight buffer(s)
                op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors)

                # Calculate the uncascaded memory requirement for current Op
                uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)

                # Add current Op to cascade
                ops_in_cascade.append(current_op)

                # Increase the accumulated intermediate buffers in the cascade
                cascade_buffers += op_ifm_buffer + op_weight_buffer

                if self.spilling:
                    # For Dedicated SRAM only the intermediate buffers are in SRAM
                    if uncascaded_sram_usage < peak_sram_usage or cascade_buffers > peak_sram_usage:
                        # Cascade until an Op fits in its entirety or the accumulated buffers no longer fit
                        break
                    else:
                        # Any addition to the cascade that fits is the new best cascade for Dedicated SRAM
                        ops_in_best_cascade = [op for op in ops_in_cascade]
                        best_cascade_size = cascade_buffers

                else:
                    # Calculate the total size of the current cascade
                    cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm

                    # Determine if cascading search should stop
                    if (
                        uncascaded_sram_usage < peak_sram_usage
                        and best_cascade_size < peak_sram_usage
                        or (cascade_ifm_size + cascade_buffers) > best_cascade_size
                    ):
                        # Stop: both the uncascaded Op and the best cascade so far fit within the
                        # peak SRAM usage, or the growing cascade already exceeds the best one
                        break

                    # Determine if current cascade is the best so far
                    if cascade_size < best_cascade_size:
                        best_cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
                        ops_in_best_cascade = [op for op in ops_in_cascade]

                producer = current_op

            if len(ops_in_best_cascade) > 1:
                # A cascade was created - assign cascade and ref_cost to all of the Ops
                cascade_end = cascade_start + (len(ops_in_best_cascade) - 1)
                buffers_in_cascade = {}
                prev_op = None
                for cascaded_op in ops_in_best_cascade:
                    assert cascade_start <= cascaded_op.index <= cascade_end
                    cost[cascaded_op] = ref_cost[cascaded_op]
                    cost[cascaded_op].cascade = cascade_end
                    if prev_op:
                        rolling_buffer_shape, _ = buffers.get_buffer(prev_op, cascaded_op, ref_cost)
                        buffers_in_cascade[cascaded_op] = rolling_buffer_shape

                    prev_op = cascaded_op

                # Create a CascadeInfo for the cascade
                cascade_map[cascade_end] = CascadeInfo(
                    cascade_start, cascade_end, buffers_in_cascade, best_cascade_size
                )
                if not self.spilling:
                    # Update peak memory usage
                    peak_sram_usage = max(best_cascade_size, peak_sram_usage)
            else:
                # Assign fallback cost to the initial Op
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)

        # Update costing and cascade information for the ref_schedule
        ref_schedule.cost_map = cost
        ref_schedule.cascades = cascade_map
        return ref_schedule