blob: ccf49297ecb697ee215def3259e19cd7f37048ea [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
18# Can work with either a pass packed subgraph or a scheduled subgraph.
Tim Hallffe8e282021-06-24 18:29:53 +010019from collections import namedtuple
Louis Verhaard226ecaf2021-03-30 10:18:28 +020020from typing import List
21
Tim Halld8339a72021-05-27 18:49:40 +010022import numpy as np
23
Louis Verhaardaee5d752020-09-30 09:01:52 +020024from .operation import Op
Tim Halld8339a72021-05-27 18:49:40 +010025from .tensor import MemArea
Patrik Gustavssoneca2e952020-05-27 09:15:11 +020026from .tensor import MemType
Diego Russoe8a10452020-04-21 17:39:10 +010027from .tensor import Tensor
Tim Halld8339a72021-05-27 18:49:40 +010028from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010029
30
class LiveRange:
    """Time span, size and alignment shared by a group of tensors.

    All tensors held by one LiveRange receive the same address when
    set_address() is called. The span is widened by every mark_usage() call.
    """

    def __init__(self, tens, alignment):
        self.tensors = []  # every tensor here ends up at the same address
        self.start_time = 99999999999  # sentinel, lowered by the first mark_usage()
        self.end_time = -1  # sentinel, raised by the first mark_usage()
        self.size = 0
        self.name = ""
        self.alignment = alignment

        if tens:
            self.mem_area = tens.mem_area
            self.add_tensor(tens)
        else:
            self.mem_area = MemArea.Unknown

    def __str__(self):
        lifetime = f"{self.start_time}-{self.end_time}"
        return f"<live_range.LiveRange: {lifetime}, size={self.size}, '{self.name}' #:{len(self.tensors)}>"

    __repr__ = __str__

    def add_tensor(self, tens):
        # A range whose size is still 0 takes its size and name from the tensor
        # being added; afterwards tensors may only be added if they fit.
        if self.size != 0:
            assert (
                self.size >= tens.storage_size()
            ), "Tensors assigned to the same LiveRange need to fit the size of the LiveRange."
        else:
            self.size = tens.storage_size()
            self.name = tens.name

        self.tensors.append(tens)

    def mark_usage(self, op_time, op_length=1):
        # Widen the span so that it covers [max(op_time, 0), op_time + op_length].
        # A usage that ends before its (clamped) start is dropped.
        first = max(op_time, 0)
        last = op_time + op_length
        if last < first:
            return

        self.start_time = min(self.start_time, first)
        self.end_time = max(self.end_time, last)

    def set_buffer_size(self, buffer_size):
        # Overrides the size taken from the tensor and moves the range to Sram
        self.mem_area = MemArea.Sram
        self.size = buffer_size

    def overlaps_ranges(self, other):
        # True when the two time spans share more than a single boundary point
        latest_start = max(self.start_time, other.start_time)
        earliest_end = min(self.end_time, other.end_time)
        return latest_start < earliest_end

    def overlaps_address(self, other):
        # Looks for the first tensor pair (one from each range) whose address
        # intervals intersect. Interval length is the LiveRange size, not the
        # tensor's own size. Returns (found, tensor_from_self, tensor_from_other).
        for mine in self.tensors:
            for theirs in other.tensors:
                low = max(mine.address, theirs.address)
                high = min(mine.address + self.size, theirs.address + other.size)
                if low < high:
                    return True, mine, theirs

        return False, None, None

    def __lt__(self, other):
        # Ordering key: start time, then end time, then size, then name
        own_key = (self.start_time, self.end_time, self.size, self.name)
        other_key = (other.start_time, other.end_time, other.size, other.name)
        return own_key < other_key

    def set_address(self, address):
        # Give every tensor in the range the same address and hand it back
        for member in self.tensors:
            member.address = address

        return address

    def get_alignment(self):
        return self.alignment

    def set_alignment(self, alignment):
        # The alignment can only grow; a smaller request leaves it untouched
        if alignment > self.alignment:
            self.alignment = alignment
Tim Hall79d07d22020-04-27 18:20:16 +0100112
113
class LiveRangeGraph:
    """Holds every LiveRange created so far together with a running time counter."""

    def __init__(self):
        self.lrs: List[LiveRange] = []  # all ranges, in creation order
        self.ranges = {}  # tensor -> the LiveRange it was assigned to
        self.processed_subgraphs = set()
        self.current_time = 0
        self.end_time = None  # filled in by update_endtime()

    def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
        # Reuse the range of the first registered tensor that `tens` is equivalent
        # to (e.g. one of its clones); the requested alignment is merged into it.
        existing = next((rng for known, rng in self.ranges.items() if tens.equivalent(known)), None)
        if existing is not None:
            existing.set_alignment(alignment)
            return existing

        # Nothing equivalent registered yet: start a fresh range for this tensor
        created = LiveRange(tens, alignment)
        self.ranges[tens] = created
        self.lrs.append(created)
        return created

    def fuse_ranges(self, in_tens, out_tens):
        # Put out_tens into the range of in_tens so that both share one address.
        # out_tens must not have been registered before.
        shared = self.get_or_create_range(in_tens)
        assert out_tens not in self.ranges, out_tens
        shared.add_tensor(out_tens)
        self.ranges[out_tens] = shared
        return shared

    def update_endtime(self):
        # Freeze the current time as the graph end time; returns that time plus one
        frozen = self.current_time
        self.end_time = frozen
        return frozen + 1

    def get_temporal_memory_usage(self, target_mem_area):
        # One int32 slot per time step; note that this call also refreshes end_time
        usage = np.zeros(self.update_endtime(), dtype=np.int32)
        in_area = (lr for lr in self.lrs if lr.mem_area == target_mem_area)
        for lr in in_area:
            # End time is inclusive, hence the +1 on the slice bound
            alive = slice(lr.start_time, lr.end_time + 1)
            usage[alive] += lr.size

        return usage
154
Tim Hall79d07d22020-04-27 18:20:16 +0100155
def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
    """Return True unless the tensor is in the target memory area and has a targeted memory type."""
    if tens.mem_area != target_mem_area:
        return True
    return tens.mem_type not in target_mem_type_set
160
161
def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
    """Try to let the OFM of an elementwise op share a live range with one of its inputs.

    Nothing happens unless sched_op is an elementwise op. At most one input is
    fused with the output: the first one (ifm, then ifm2) that passes every check.
    """

    def _skip(tens):
        # Write-protected tensors are never candidates, nor are tensors outside the target memory
        return tens.ifm_write_protected or tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)

    if not sched_op.op_type.is_elementwise_op():
        return

    elem_op = sched_op.parent_op
    ofm = elem_op.ofm
    if _skip(ofm):
        return

    ofm_shape = elem_op.ofm_shapes[0]

    # (op shape, tensor) pairs for the inputs that are present
    candidates = []
    if elem_op.ifm is not None:
        candidates.append((elem_op.ifm_shapes[0], elem_op.ifm))
    if elem_op.ifm2 is not None:
        candidates.append((elem_op.ifm_shapes[1], elem_op.ifm2))

    def _can_be_overwritten(in_shape, in_tens):
        return (
            # op input and output shapes must allow overlapping
            in_shape == ofm_shape
            # the input tensor must be valid
            and in_tens is not None
            and in_tens.shape != []
            and not _skip(in_tens)
            # input and output tensors must be compatible
            and in_tens.format == ofm.format
            and in_tens.dtype == ofm.dtype
            # the input may only have one consumer
            and len(in_tens.consumer_list) == 1
            # the output may only have one producer
            and len(ofm.ops) == 1
        )

    for in_shape, in_tens in candidates:
        if _can_be_overwritten(in_shape, in_tens):
            lr_graph.fuse_ranges(in_tens, ofm)
            return
Tim Hall79d07d22020-04-27 18:20:16 +0100200
201
def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    target_mem_type_set,
    lr_graph=None,
    cpu_tensor_alignment=Tensor.AllocationQuantum,
):
    """Walk the cascaded passes of a subgraph and record when each relevant tensor is used.

    A new LiveRangeGraph is created when none is passed in. A subgraph that was
    already processed for this graph is skipped. Returns the (possibly new) graph.
    """
    if lr_graph is None:
        lr_graph = LiveRangeGraph()

    if sg in lr_graph.processed_subgraphs:
        # Already walked once: the graph is returned untouched
        return lr_graph

    def _mark(graph, tensors, when):
        # Record a usage at `when` for every tensor that lives in the targeted memory
        for tens in tensors:
            if not tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                graph.get_or_create_range(tens, cpu_tensor_alignment).mark_usage(when)

    for cps in sg.cascaded_passes:
        pass_time = lr_graph.current_time
        cps.time = pass_time

        _mark(lr_graph, cps.inputs, pass_time)

        primary_op = cps.passes[0].primary_op
        calls_npu_subgraph = (
            primary_op
            and primary_op.type == Op.CustomNpuOp
            and MemType.Permanent_CPU not in target_mem_type_set
        )
        if calls_npu_subgraph:
            # A CustomNpuOp is the call site of an Npu subgraph: its live ranges are
            # extracted first (with the default alignment used by get_or_create_range),
            # and the pass is then re-timed to wherever that extraction left the clock.
            npu_sg = primary_op.attrs["subgraph"]
            lr_graph = _extract_live_ranges_from_schedule(npu_sg, target_mem_area, target_mem_type_set, lr_graph)
            pass_time = lr_graph.current_time
            cps.time = pass_time

        _mark(lr_graph, cps.intermediates + cps.outputs, pass_time)

        lr_graph.current_time += 2

    # Subgraph outputs are marked at the latest end time of any range (never below 0)
    latest_end = max((rng.end_time for rng in lr_graph.ranges.values()), default=0)
    _mark(lr_graph, sg.output_tensors, max(0, latest_end))

    # Remember the subgraph so that it is not walked a second time
    lr_graph.processed_subgraphs.add(sg)
    return lr_graph
Tim Halld8339a72021-05-27 18:49:40 +0100265
266
def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_graph):
    """Mark every relevant tensor of the subgraph as used at one single time step.

    All usages are recorded at the graph's current time, which is then advanced
    by one. Returns lr_graph.
    """
    assert lr_graph is not None
    now = lr_graph.current_time

    # Pass tensors: everything except weights and tensors outside the target memory
    for ps in sg.passes:
        for tens in ps.inputs + ps.outputs + ps.intermediates:
            skip = tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
                tens, target_mem_area, target_mem_type_set
            )
            if not skip:
                lr_graph.get_or_create_range(tens).mark_usage(now)

    # Npu weight and scale tensors recorded in the schedule, when they are set
    for op_info in sg.schedule.cost_map.values():
        for tensor in (op_info.npu_weights_tensor, op_info.npu_scales_tensor):
            if not tensor:
                continue
            if tensor_should_be_ignored(tensor, target_mem_area, target_mem_type_set):
                continue
            lr_graph.get_or_create_range(tensor).mark_usage(now)

    lr_graph.current_time += 1
    return lr_graph
287
288
def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
    """Record tensor usages for every scheduled op of a subgraph.

    Each op gets a time index (stored on its cost-map entry). Ops that share a
    non-zero cascade id share the time index of the first op seen in that cascade.
    Returns lr_graph.
    """
    cascade_times = {}  # cascade id -> time index used by that cascade

    for sched_op in sg.sched_ops:
        merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set)

        op_info = sg.schedule.cost_map[sched_op]
        cascade = op_info.cascade
        cascade_info = sg.schedule.cascades.get(cascade, None)

        op_time = cascade_times.get(cascade, lr_graph.current_time)
        op_info.time_index = op_time

        # Usages of all tensors that belong to the op's pass
        ps = sched_op.parent_ps
        for tens in ps.inputs + ps.outputs + ps.intermediates:
            is_rolling_buffer = (
                target_mem_area == MemArea.Sram
                and cascade_info
                and tens == ps.ifm_tensor
                and sched_op in cascade_info.buffers
            )
            if is_rolling_buffer:
                # The IFM is a rolling buffer inside a cascade: the range size is replaced
                # by the buffer size so that temporal memory snapshots are right without
                # touching the original Tensor
                rng = lr_graph.get_or_create_range(tens)
                rng.set_buffer_size(cascade_info.buffers[sched_op].elements() * sched_op.ifm.dtype.size_in_bytes())
            else:
                not_tracked = (
                    tens.purpose == TensorPurpose.Weights
                    or tens.purpose == TensorPurpose.FSBias
                    or tens.mem_type not in target_mem_type_set
                    or tens.mem_area != target_mem_area
                )
                if not_tracked:
                    continue
                rng = lr_graph.get_or_create_range(tens)

            rng.mark_usage(op_time)

        # Usages of the buffered weight tensors
        weight_buffers = op_info.buffered_weight_tensors
        for idx, weight_tens in enumerate(weight_buffers):
            if not (weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area):
                continue

            rng = lr_graph.get_or_create_range(weight_tens)
            # A pre-buffered tensor becomes live one time step earlier
            start_time, length = (op_time - 1, 2) if weight_tens.pre_buffer else (op_time, 1)
            if len(weight_buffers) > 1:
                last_idx = len(op_info.ofm_depth_slices) % len(weight_buffers)
                # Double buffering: reduce end time of the buffer that is not used last
                if last_idx != idx:
                    length -= 1
            rng.mark_usage(start_time, length)

        # Time only advances when this op consumed the current slot, i.e. it did
        # not reuse the slot of a cascade seen earlier
        if op_time == lr_graph.current_time:
            lr_graph.current_time += 2

        # Cascade id 0 is never remembered (presumably it means "not in a cascade" - TODO confirm)
        if cascade != 0:
            cascade_times[cascade] = op_time

    end_time = lr_graph.update_endtime()

    # Subgraph outputs in the targeted memory are marked at the value returned by update_endtime()
    for tens in sg.output_tensors:
        if tens.mem_type in target_mem_type_set and tens.mem_area == target_mem_area:
            lr_graph.get_or_create_range(tens).mark_usage(end_time)

    return lr_graph