# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Groups Operators in a schedule together to form Cascades.
from .live_range import ofm_can_reuse_ifm
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Padding
from .shape4d import Shape4D

# NPU block types that are never considered for cascading. CascadeBuilder._is_cascadable()
# rejects any SchedulerOperation whose op_type.npu_block_type is listed here.
non_cascadable_blocks = (
    NpuBlockType.Default,
    NpuBlockType.VectorProduct,
    NpuBlockType.ReduceSum,
)


class CascadeInfo:
    """Contains metadata about a cascade

    Attributes (as populated by CascadeBuilder.build_cascades):
        start: index of the first SchedulerOperation in the cascade
        end: index of the last SchedulerOperation in the cascade
        buffers: dict mapping each cascaded op (except the first) to the shape of
            the buffer that sits between it and its producer
        mem_usage: memory used by the cascade itself, in bytes
    """

    def __init__(self, start, end, buffers, mem_usage: int):
        self.start = start
        self.end = end
        self.buffers = buffers
        self.mem_usage = mem_usage


class BufferMap:
    """Caches the buffers seen

    Each (producer, consumer) pair is sized once; later requests for the same
    pair return the cached (shape, size-in-bytes) tuple.
    """

    def __init__(self):
        # Maps (producer, consumer) -> (buffer_shape, buffer_size)
        self.buffer_map = {}

    def get_buffer(self, producer, consumer, cost):
        """Return the (shape, size in bytes) of the buffer between two SchedulerOperations.

        producer: op writing the feature map, or None
        consumer: op reading the feature map, or None
        cost: mapping from op to its scheduling cost; only read (via .stripe and
            .stripe_input) when a rolling buffer is used

        At least one of producer and consumer must be given.
        """
        assert producer or consumer
        key = (producer, consumer)
        if key in self.buffer_map:
            return self.buffer_map[key]

        # No cached buffer between these two SchedulerOperations
        if consumer is None:
            # There are either no consumers or multiple consumers - FeatureMap needs to be stored in full
            buffer_shape = producer.ofm.shape
            buffer_size = producer.ofm_size_in_bytes()
        elif producer is None:
            # First Op in subgraph or cascade - FeatureMap needs to be stored in full
            buffer_shape = consumer.ifm.shape
            buffer_size = consumer.ifm_size_in_bytes()
        elif producer.requires_full_ofm or consumer.requires_full_ifm:
            # FeatureMap needs to be stored in full
            buffer_shape = max(producer.ofm.shape, consumer.ifm.shape)
            buffer_size = max(producer.ofm_size_in_bytes(), consumer.ifm_size_in_bytes())
        else:
            # Use a rolling buffer
            buffer_shape = rolling_buffer_shape(cost[producer].stripe, cost[consumer].stripe_input)
            buffer_size = buffer_shape.elements() * producer.ofm.dtype.size_in_bytes()

        self.buffer_map[key] = (buffer_shape, buffer_size)
        return self.buffer_map[key]


def rolling_buffer_shape(producer_stripe: Shape4D, consumer_stripe_input: Shape4D) -> Shape4D:
    """Calculates the storage shape of the rolling buffer between two SchedulerOperations in a Cascade

    producer_stripe: OFM stripe shape written by the producer
    consumer_stripe_input: IFM stripe shape read by the consumer

    Returns a batch-1 shape whose height is the sum of both stripe heights rounded
    up to a multiple of the consumer's input stripe height, and whose depth is the
    producer stripe depth rounded up to a multiple of 16.
    """
    buffer_height = round_up(producer_stripe.height + consumer_stripe_input.height, consumer_stripe_input.height)
    # Striding on the consumer op can result in IFM widths that are narrower than the OFM width of the producer.
    # Therefore, the maximum of the two needs to be used.
    buffer_width = max(producer_stripe.width, consumer_stripe_input.width)
    # Rolling buffers have to conform to NHCWB16 format
    return Shape4D([1, buffer_height, buffer_width, round_up(producer_stripe.depth, 16)])


class CascadeBuilder:
    """Class for grouping SchedulerOperations into cascades"""

    def __init__(self, sched_ops, spilling, non_local_mem_usage=None):
        """Store the inputs for a later build_cascades() call.

        sched_ops: ordered sequence of SchedulerOperations to consider
        spilling: truthy when only the intermediate cascade buffers are placed in
            SRAM (the "Dedicated SRAM" case in build_cascades)
        non_local_mem_usage: optional mapping from op to extra memory (bytes) that
            is live while the op executes; defaults to an empty dict
        """
        self.sched_ops = sched_ops
        self.no_cascade = 0
        self.non_local_mem_usage = non_local_mem_usage if non_local_mem_usage else {}
        self.spilling = spilling

    def _is_cascadable(self, sched_op, cost) -> bool:
        """Checks if 'sched_op' can be cascaded"""

        return (
            sched_op.op_type.npu_block_type not in non_cascadable_blocks
            # Only striped ops (stripe shorter than the full OFM) can be cascaded
            and cost.stripe.height < sched_op.ofm.shape.height
            and sched_op.parent_op.read_offsets[0] is None
            and sched_op.parent_op.read_offsets[1] is None
            and self.elementwise_cascadable(sched_op)
            and not sched_op.parent_op.type.is_resize_op()
            and sched_op.parent_op.type != Op.Conv2DBackpropInputSwitchedBias
            and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE
        )

    def _estimate_sram_usage(self, sched_op, cost) -> int:
        """Estimate the SRAM required for the Op if all FeatureMaps are in SRAM"""
        if sched_op.parent_op.type.is_binary_elementwise_op():
            # ifm2 is scalar or constant and will always persist in permanent memory
            ifm2_size = 0
        else:
            ifm2_size = sched_op.ifm2_size_in_bytes()

        if sched_op.requires_full_ifm:
            ifm_size = sched_op.ifm_size_in_bytes()
        else:
            # Only one input stripe is resident; depth is padded to a multiple of 16
            ifm_size = (
                cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements()
                * sched_op.ifm.dtype.size_in_bytes()
            )

        if ofm_can_reuse_ifm(sched_op):
            # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0
            ofm_size = 0
        elif sched_op.requires_full_ofm:
            ofm_size = sched_op.ofm_size_in_bytes()
        else:
            # Only one output stripe is resident; depth is padded to a multiple of 16
            ofm_size = (
                cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
            )

        return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)

    @staticmethod
    def elementwise_cascadable(sched_op):
        """Check if the elementwise can be cascaded.

        Ops that are not binary elementwise always pass. A binary elementwise op
        passes only when its IFM is non-constant, non-scalar and not broadcast to
        the OFM, and its IFM2 is constant or scalar.
        """

        if sched_op.parent_op.type.is_binary_elementwise_op():
            ifm = sched_op.parent_op.ifm
            ifm2 = sched_op.parent_op.ifm2
            ofm = sched_op.parent_op.ofm

            # IFM must be non-constant/non-scalar/non-broadcast
            ifm_cascadable = not (ifm.is_const or ifm.is_scalar or ifm.is_broadcast(ofm))

            # IFM2 must be constant or scalar
            ifm2_cascadable = ifm2.is_const or ifm2.is_scalar

            return ifm_cascadable and ifm2_cascadable
        else:
            return True

    def build_cascades(self, ref_schedule, fallback_schedule, guiding_mem_limit):
        """Group the ops of 'ref_schedule' into cascades where that reduces SRAM usage.

        ref_schedule: schedule whose cost_map holds the (striped) costs used for ops
            that end up in a cascade; its cost_map and cascades are overwritten
        fallback_schedule: schedule whose cost_map holds the costs assigned to ops
            that are left outside any cascade
        guiding_mem_limit: initial value for the peak SRAM usage that cascades are
            measured against

        Returns 'ref_schedule' with cost_map replaced by the per-op cost chosen here
        and cascades set to a dict mapping each cascade's end index to its CascadeInfo.
        """
        ref_cost = ref_schedule.cost_map
        fallback_cost = fallback_schedule.cost_map
        cost = {}
        cascade_map = {}
        buffers = BufferMap()

        # Peak memory usage so far - updated continuously, unless dedicated SRAM where this is a hard limit
        peak_sram_usage = guiding_mem_limit

        idx = 0
        while idx < len(self.sched_ops):
            op = self.sched_ops[idx]
            if op in cost:
                # Already processed this Op
                idx += 1
                continue

            if not self._is_cascadable(op, ref_cost[op]):
                # Op is not a candidate for cascading - assign fallback cost
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)
                idx += 1
                continue

            # Propose a cascade starting with this Op
            cascade_start = op.index
            # Keep track of which Ops are in the proposed cascade as well as the best cascade so far
            ops_in_cascade = [op]
            ops_in_best_cascade = [op]
            # Get the size of the weight buffer(s)
            weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors)

            # The first IFM needs to be stored in full
            cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0

            # Sum of all intermediate cascade buffers (including weight buffers)
            cascade_buffers = weight_buffer
            # Best cascade size - Initially it's the fallback cost of the first Op in the cascade
            best_cascade_size = self._estimate_sram_usage(op, fallback_cost[op])

            # Op is the producer of the OFM consumed by the next Op to consider
            producer = op
            while True:
                dependants = producer.get_dependants()
                if len(dependants) != 1:
                    # producer is either the last Op in the schedule or the start of a branch
                    break

                current_op = dependants[0]
                if (
                    current_op in cost
                    or current_op not in ref_cost
                    or not self._is_cascadable(current_op, ref_cost[current_op])
                    or producer.ofm.shape != current_op.ifm.shape
                    or current_op.requires_full_ifm
                    or producer.requires_full_ofm
                ):
                    # Current op has already been processed or cannot be cascaded
                    break

                if producer.index + 1 != current_op.index:
                    # Cascading is possible, but requires reordering of operations in the schedule,
                    # this is currently not supported
                    break

                # Get the size of the FeatureMap buffers between current and neighbouring Ops
                op_full_ofm = current_op.ofm_size_in_bytes()
                _, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)

                # Get the size of the weight buffer(s)
                op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors)

                # Calculate the uncascaded memory requirement for current Op
                uncascaded_sram_usage = self._estimate_sram_usage(current_op, fallback_cost[current_op])

                # Add current Op to cascade
                ops_in_cascade.append(current_op)

                # Increase the accumulated intermediate buffers in the cascade
                cascade_buffers += op_ifm_buffer + op_weight_buffer

                if self.spilling:
                    # For Dedicated SRAM only the intermediate buffers are in SRAM
                    if uncascaded_sram_usage < peak_sram_usage or cascade_buffers > peak_sram_usage:
                        # Cascade until an Op fits in its entirety or the accumulated buffers no longer fit
                        break
                    else:
                        # Any addition to the cascade that fits is the new best cascade for Dedicated SRAM
                        ops_in_best_cascade = list(ops_in_cascade)
                        best_cascade_size = cascade_buffers

                else:
                    # Calculate the total size of the current cascade including non local mem usage
                    cascade_size = (
                        cascade_ifm_size + cascade_buffers + op_full_ofm + self.non_local_mem_usage.get(op, 0)
                    )

                    # Determine if cascading search should stop
                    if (uncascaded_sram_usage < peak_sram_usage and best_cascade_size < peak_sram_usage) or (
                        cascade_ifm_size + cascade_buffers
                    ) > best_cascade_size:
                        # Both the existing cascade and current Op fits or
                        # not possible to reduce cascade size any further
                        break

                    # One of two conditions will update the best cascade:
                    #
                    # - cascade_size < best_cascade_size or
                    # - cascade_size < uncascaded_sram_usage
                    #
                    # The last condition is illustrated below, showing an example where it is
                    # better to choose a larger cascade_size (with more OPs) because it will
                    # use less total SRAM usage.
                    #
                    # For simplicity, all featuremaps have same size.
                    #
                    # Cascade OP1-OP2, OP3 is standalone
                    #
                    #           -> |OP1| -> roll buffer -> |OP2| -> FM -> |OP3| -> FM
                    #          /
                    # |OP0| -> FM
                    #          \
                    #           -> ....
                    #
                    # best_cascade_size    : FM + roll buffer + FM
                    # uncascaded_sram_usage: FM + FM + FM
                    #
                    # compared with:
                    #
                    # Cascade OP1-OP3
                    #
                    #           -> |OP1| -> roll buffer -> |OP2| -> roll buffer -> |OP3| -> FM
                    #          /
                    # |OP0| -> FM
                    #          \
                    #           -> ....
                    #
                    # cascade_size         : FM + roll buffer + roll buffer + FM
                    #
                    # So, for this use case the comparison will be
                    #
                    # (FM + roll buffer + roll buffer + FM) < (FM + roll buffer + FM) or
                    # (FM + roll buffer + roll buffer + FM) < (FM + FM + FM)
                    #
                    # hence, better to choose Cascade OP1-OP3 in this case.
                    if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage:
                        best_cascade_size = cascade_size
                        ops_in_best_cascade = list(ops_in_cascade)

                producer = current_op

            if len(ops_in_best_cascade) > 1:
                # A cascade was created - assign cascade and ref_cost to all of the Ops
                cascade_end = cascade_start + (len(ops_in_best_cascade) - 1)
                buffers_in_cascade = {}
                prev_op = None
                for cascaded_op in ops_in_best_cascade:
                    assert cascade_start <= cascaded_op.index <= cascade_end
                    cost[cascaded_op] = ref_cost[cascaded_op]
                    cost[cascaded_op].cascade = cascade_end
                    if prev_op:
                        # Named so as not to shadow the module-level rolling_buffer_shape()
                        buffer_shape, _ = buffers.get_buffer(prev_op, cascaded_op, ref_cost)
                        buffers_in_cascade[cascaded_op] = buffer_shape

                    prev_op = cascaded_op

                # Create a CascadeInfo for the cascade, only store the actual size used by
                # the cascade so non local usage is removed. This is done in order to be
                # able to calculate the correct non local usage in the scheduler when
                # optimizing the sub schedules.
                cascade_map[cascade_end] = CascadeInfo(
                    cascade_start,
                    cascade_end,
                    buffers_in_cascade,
                    best_cascade_size - self.non_local_mem_usage.get(op, 0),
                )
                if not self.spilling:
                    # Update peak memory usage
                    peak_sram_usage = max(best_cascade_size, peak_sram_usage)
            else:
                # Assign fallback cost to the initial Op
                cost[op] = fallback_cost[op]
                if not self.spilling:
                    peak_sram_usage = max(self._estimate_sram_usage(op, fallback_cost[op]), peak_sram_usage)

        # Update costing and cascade information for the ref_schedule
        ref_schedule.cost_map = cost
        ref_schedule.cascades = cascade_map
        return ref_schedule