blob: 2795b668853c7ad1190f45325b979b0fb49bf0c8 [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
18# Can work with either a pass packed subgraph or a scheduled subgraph.
Louis Verhaard226ecaf2021-03-30 10:18:28 +020019from typing import List
20
Tim Halld8339a72021-05-27 18:49:40 +010021import numpy as np
22
Louis Verhaardaee5d752020-09-30 09:01:52 +020023from .operation import Op
Tim Halld8339a72021-05-27 18:49:40 +010024from .tensor import MemArea
Patrik Gustavssoneca2e952020-05-27 09:15:11 +020025from .tensor import MemType
Diego Russoe8a10452020-04-21 17:39:10 +010026from .tensor import Tensor
Tim Halld8339a72021-05-27 18:49:40 +010027from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010028
29
class LiveRange:
    """A span of schedule time during which one or more tensors occupy memory.

    Every tensor attached to the same LiveRange shares a single allocation:
    ``set_address`` gives them all the same address. The size (and the name) of
    the range is taken from the first tensor added while the size is still 0.
    """

    def __init__(self, tens, alignment):
        # Tensors that are assigned to the same LiveRange will be allocated to the same address
        self.tensors = []
        # Sentinel bounds: the first mark_usage() call shrinks start_time and grows end_time
        self.start_time = 99999999999
        self.end_time = -1
        self.size = 0
        self.name = ""
        self.alignment = alignment
        if tens:
            self.mem_area = tens.mem_area
            self.add_tensor(tens)
        else:
            self.mem_area = MemArea.Unknown

    def __str__(self):
        return "<live_range.LiveRange: '%s' start_time=%s, end_time=%s>" % (self.name, self.start_time, self.end_time)

    __repr__ = __str__

    def add_tensor(self, tens):
        """Attach *tens* to this range.

        The first tensor (added while ``size`` is 0) fixes the range's size and name;
        later tensors must fit inside that size.
        """
        if self.size != 0:
            assert (
                self.size >= tens.storage_size()
            ), "Tensors assigned to the same LiveRange need to fit the size of the LiveRange."
        else:
            self.size = tens.storage_size()
            self.name = tens.name
        self.tensors.append(tens)

    def mark_usage(self, op_time, op_length=1):
        """Widen the range so it covers ``op_time`` .. ``op_time + op_length``.

        The start is clamped at time 0; a usage that ends at or before the clamped
        start is ignored.
        """
        first = max(op_time, 0)
        last = op_time + op_length
        if last > first:
            self.start_time = min(self.start_time, first)
            self.end_time = max(self.end_time, last)

    def set_buffer_size(self, buffer_size):
        """Override the range's size (and place it in SRAM) without touching the tensors."""
        self.size = buffer_size
        self.mem_area = MemArea.Sram

    def overlaps_ranges(self, other):
        """Return True if the time spans of *self* and *other* intersect."""
        latest_start = max(self.start_time, other.start_time)
        earliest_end = min(self.end_time, other.end_time)
        return latest_start < earliest_end

    def overlaps_address(self, other):
        """Find the first pair of tensors (one from each range) whose address spans overlap.

        Returns ``(True, tensor_of_self, tensor_of_other)`` for the first overlapping
        pair found, otherwise ``(False, None, None)``.
        """
        for mine in self.tensors:
            for theirs in other.tensors:
                lo = max(mine.address, theirs.address)
                hi = min(mine.address + self.size, theirs.address + other.size)
                if lo < hi:
                    return True, mine, theirs
        return False, None, None

    def __lt__(self, other):
        # Order by start time, then end time, then size, then name
        return (self.start_time, self.end_time, self.size, self.name) < (
            other.start_time,
            other.end_time,
            other.size,
            other.name,
        )

    def set_address(self, address):
        """Assign *address* to every tensor in the range and return it."""
        for member in self.tensors:
            member.address = address
        return address

    def get_alignment(self):
        return self.alignment

    def set_alignment(self, alignment):
        """Raise the range's alignment to *alignment* if it is larger (never lowers it)."""
        self.alignment = max(self.alignment, alignment)
Tim Hall79d07d22020-04-27 18:20:16 +0100108
109
class LiveRangeGraph:
    """Collection of LiveRanges built up while walking one or more subgraphs.

    ``lrs`` holds every range exactly once, while ``ranges`` maps each tensor to its
    range; a range fused via ``fuse_ranges`` is reachable from several tensors.
    """

    def __init__(self):
        self.lrs: List[LiveRange] = []  # List of all created ranges (each range appears once)
        self.ranges = {}  # tens -> range (fused ranges appear under several tensors)
        self.ignore_tensors = set()
        self.processed_subgraphs = set()
        self.current_time = 0
        self.end_time = None

    def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
        """Return the live range of *tens* (or of any tensor equivalent to it), creating one if needed.

        When an existing range is reused its alignment is raised to at least *alignment*.
        """
        # Return the live range of the tensor (or any of its clones)
        for existing_tensor, rng in self.ranges.items():
            if tens.equivalent(existing_tensor):
                rng.set_alignment(alignment)
                return rng

        # No live range found for the tensor, create a new one
        rng = LiveRange(tens, alignment)
        self.ranges[tens] = rng
        self.lrs.append(rng)
        return rng

    def fuse_ranges(self, in_tens, out_tens):
        """Make *out_tens* share the live range (and therefore the address) of *in_tens*."""
        live_range = self.get_or_create_range(in_tens)
        assert out_tens not in self.ranges, out_tens
        live_range.add_tensor(out_tens)
        # The same LiveRange object is now registered under two tensors
        self.ranges[out_tens] = live_range
        return live_range

    def update_endtime(self):
        """Store the latest end time of any range in ``self.end_time`` and return it plus one.

        Because end times are treated as inclusive, the returned value is the number
        of time slots needed to cover every range.
        """
        self.end_time = 0
        for rng in self.ranges.values():
            self.end_time = max(self.end_time, rng.end_time)
        return self.end_time + 1

    def get_temporal_memory_usage(self, target_mem_area):
        """Return an int32 array with the total size of live ranges in *target_mem_area* per time slot.

        The end time is always recomputed, because ranges may have been added or
        extended since the last call.
        """
        # update_endtime() returns max end time + 1, which is the array length needed
        # for the inclusive end-time slices below not to drop the final slot.
        usage = np.zeros(self.update_endtime(), dtype=np.int32)
        # Iterate lrs rather than ranges.values(): fused ranges appear once per tensor
        # in self.ranges and would otherwise be counted several times.
        for rng in self.lrs:
            if rng.mem_area == target_mem_area:
                # End time is inclusive
                usage[rng.start_time : rng.end_time + 1] += rng.size

        return usage
155
Tim Hall79d07d22020-04-27 18:20:16 +0100156
def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
    """Return True if *tens* must not get a live range in this allocation pass.

    A tensor is skipped in any of these cases:
    - it lives outside *target_mem_area* or outside *target_mem_type_set*;
    - it is already in ``lr_graph.ignore_tensors``;
    - its name ends in "reshape_shape_npu".

    A reshape-shape tensor that reaches the last check is also added to
    ``lr_graph.ignore_tensors``.
    """
    outside_target = tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set
    if outside_target or tens in lr_graph.ignore_tensors:
        return True
    is_reshape_shape = tens.name.endswith("reshape_shape_npu")
    if is_reshape_shape:
        # Reshape tensor, no need to allocate
        lr_graph.ignore_tensors.add(tens)
    return is_reshape_shape
167
168
def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set):
    """Try to let an elementwise op write its OFM in place over one of its inputs.

    Nothing happens in any of these cases:
    - *sched_op* is not elementwise;
    - its OFM is ignored by ``tensor_should_be_ignored``;
    - the op is SHL/SHR.

    Otherwise the first input (IFM, then IFM2) that qualifies has its live range
    fused with the OFM, so both share one allocation. An input qualifies when it:
    - is not None;
    - has a non-empty shape;
    - lives in the target memory area/type;
    - matches the OFM's format and dtype;
    - has an op shape equal to the OFM's op shape;
    - has exactly one consumer and one producer op.
    """
    if not sched_op.op_type.is_elementwise_op():
        return
    elem_op = sched_op.parent_op
    ofm = elem_op.ofm
    if tensor_should_be_ignored(lr_graph, ofm, target_mem_area, target_mem_type_set):
        return
    # Check if overwriting the inputs can be allowed
    if elem_op.type in (Op.SHL, Op.SHR):
        return

    # Pair each candidate with the index of its own entry in ifm_shapes (0 for IFM,
    # 1 for IFM2). The shape check then uses the right op shape even when the IFM was
    # filtered out and only IFM2 remains.
    candidates = []
    for shape_idx, inp in enumerate((elem_op.ifm, elem_op.ifm2)):
        if (
            inp is not None
            and inp.shape != []
            and inp.mem_area == target_mem_area
            and inp.mem_type in target_mem_type_set
        ):
            candidates.append((shape_idx, inp))

    for shape_idx, inp in candidates:
        # check input format, dtype, broadcasting or if there are more input consumers
        if (
            inp.format == ofm.format
            and inp.dtype == ofm.dtype
            and elem_op.ifm_shapes[shape_idx] == elem_op.ofm_shapes[0]
            and len(inp.consumer_list) == 1
            and len(inp.ops) == 1
        ):
            lr_graph.fuse_ranges(inp, ofm)
            break
Tim Hall79d07d22020-04-27 18:20:16 +0100203
204
def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    target_mem_type_set,
    ignore_subgraph_input_output_tensors=False,
    lr_graph=None,
    cpu_tensor_alignment=Tensor.AllocationQuantum,
):
    """Build (or extend) a LiveRangeGraph from the cascaded passes of *sg*.

    Each cascaded pass occupies two time steps. Its inputs are marked live at the
    pass's time, and its intermediates and outputs at the (possibly updated) pass
    time. When a pass's primary op is a CustomNpuOp, and MemType.Permanent_CPU is
    not among the targeted types, the referenced NPU subgraph's schedule is walked
    first. The pass then resumes at the time the NPU subgraph finished at. Finally
    the subgraph outputs are marked live at the latest end time in the graph.

    Args:
        sg: subgraph providing ``cascaded_passes``, ``input_tensors`` and ``output_tensors``.
        target_mem_area: only tensors in this memory area get live ranges.
        target_mem_type_set: only tensors whose mem_type is in this set get live ranges.
        ignore_subgraph_input_output_tensors: add *sg*'s inputs and outputs to the ignore set.
        lr_graph: graph to extend; a fresh LiveRangeGraph is created when None.
        cpu_tensor_alignment: alignment used for ranges created directly by this function.

    Returns:
        The (possibly new) LiveRangeGraph. A subgraph that was already processed is
        not walked again.
    """

    def _mark(graph, tensors, when):
        # Give each tensor that is not ignored a live range covering time `when`
        for tens in tensors:
            if not tensor_should_be_ignored(graph, tens, target_mem_area, target_mem_type_set):
                graph.get_or_create_range(tens, cpu_tensor_alignment).mark_usage(when)

    if lr_graph is None:
        lr_graph = LiveRangeGraph()
    elif sg in lr_graph.processed_subgraphs:
        # if subgraph has been processed already, return the lr_graph as is
        return lr_graph

    if ignore_subgraph_input_output_tensors:
        for boundary_tensors in (sg.input_tensors, sg.output_tensors):
            lr_graph.ignore_tensors.update(boundary_tensors)

    for cps in sg.cascaded_passes:
        now = lr_graph.current_time
        cps.time = now
        _mark(lr_graph, cps.inputs, now)

        primary_op = cps.passes[0].primary_op
        enters_npu_subgraph = (
            primary_op and primary_op.type == Op.CustomNpuOp and MemType.Permanent_CPU not in target_mem_type_set
        )
        if enters_npu_subgraph:
            # The primary op calls an NPU subgraph here; extract its live ranges before continuing.
            # NPU tensors get get_or_create_range's default alignment (Tensor.AllocationQuantum).
            lr_graph = _extract_live_ranges_from_schedule(
                primary_op.attrs["subgraph"], target_mem_area, target_mem_type_set, lr_graph
            )
            # The pass resumes at the time the NPU subgraph finished at
            now = lr_graph.current_time
            cps.time = now

        _mark(lr_graph, cps.intermediates + cps.outputs, now)
        lr_graph.current_time += 2

    # Latest end time over every range in the graph, never below 0
    end_time = max([0] + [rng.end_time for rng in lr_graph.ranges.values()])
    _mark(lr_graph, sg.output_tensors, end_time)

    # Add subgraph to set of processed subgraphs
    lr_graph.processed_subgraphs.add(sg)
    return lr_graph
Tim Halld8339a72021-05-27 18:49:40 +0100273
274
def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_graph):
    """Add live ranges for the tensors of *sg* to *lr_graph*, all at a single time step.

    The following tensors have their usage marked at ``lr_graph.current_time``:
    - every pass tensor that is not a Weights tensor and is not ignored;
    - every non-None ``npu_weights_tensor`` / ``npu_scales_tensor`` in the schedule's
      cost map that is not ignored.

    The graph's clock is then advanced by one step.

    Args:
        sg: subgraph providing ``passes`` and ``schedule.cost_map``.
        target_mem_area / target_mem_type_set: memory filter passed to ``tensor_should_be_ignored``.
        lr_graph: graph to extend; must not be None.

    Returns:
        *lr_graph*.
    """
    assert lr_graph is not None
    sg_time = lr_graph.current_time
    for ps in sg.passes:
        for tens in ps.inputs + ps.outputs + ps.intermediates:
            # Raw Weights tensors are skipped; the cost map's npu weight/scale tensors are added below
            if tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
                lr_graph, tens, target_mem_area, target_mem_type_set
            ):
                continue
            rng = lr_graph.get_or_create_range(tens)
            rng.mark_usage(sg_time)

    # Only the per-op info is needed here, not the op keys
    for op_info in sg.schedule.cost_map.values():
        for tensor in (op_info.npu_weights_tensor, op_info.npu_scales_tensor):
            if tensor and not tensor_should_be_ignored(lr_graph, tensor, target_mem_area, target_mem_type_set):
                rng = lr_graph.get_or_create_range(tensor)
                rng.mark_usage(sg_time)

    lr_graph.current_time += 1
    return lr_graph
295
296
def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
    """Add live ranges for the scheduled ops of *sg* to *lr_graph* and return it.

    For each scheduled op, in order:
    - try to fuse elementwise IFM/OFM ranges;
    - pick a time index, shared by all ops of the same non-zero cascade;
    - mark the op's pass tensors live at that time;
    - mark the buffered weight tensor live, from one step earlier when pre-buffered.

    Two special cases apply to the pass tensors:
    - Weights/FSBias tensors and tensors outside the target memory are skipped.
    - When targeting SRAM, a cascade's rolling-buffer IFM has its range size
      overridden with the buffer size.

    Finally the subgraph outputs are marked live at ``lr_graph.update_endtime()``.
    """
    schedule = sg.schedule
    targeting_sram = target_mem_area == MemArea.Sram
    cascade_time = {}  # cascade -> time index reused by every later op in that cascade

    def in_target_memory(tens):
        return tens.mem_type in target_mem_type_set and tens.mem_area == target_mem_area

    for sched_op in sg.sched_ops:
        merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set)

        op_info = schedule.cost_map[sched_op]
        cascade = op_info.cascade
        cascade_info = schedule.cascades.get(cascade, None)
        op_time = cascade_time.get(cascade, lr_graph.current_time)
        op_info.time_index = op_time

        # Mark usage for all tensors related to this Pass
        ps = sched_op.parent_ps
        for tens in ps.inputs + ps.outputs + ps.intermediates:
            is_rolling_buffer = (
                targeting_sram and cascade_info and tens == ps.ifm_tensor and sched_op in cascade_info.buffers
            )
            if is_rolling_buffer:
                # Size the LiveRange to the rolling buffer so temporal memory snapshots are
                # correct, without modifying the original Tensor
                rng = lr_graph.get_or_create_range(tens)
                rolling_buffer = cascade_info.buffers[sched_op]
                rng.set_buffer_size(rolling_buffer.elements() * sched_op.ifm.dtype.size_in_bytes())
            elif tens.purpose in (TensorPurpose.Weights, TensorPurpose.FSBias) or not in_target_memory(tens):
                continue
            else:
                rng = lr_graph.get_or_create_range(tens)
            rng.mark_usage(op_time)

        weight_tens = op_info.buffered_weight_tensor
        if weight_tens and in_target_memory(weight_tens):
            # Pre-buffered weights are marked live from the previous time step as well
            start, length = (op_time - 1, 2) if weight_tens.pre_buffer else (op_time, 1)
            lr_graph.get_or_create_range(weight_tens).mark_usage(start, length)

        # Only advance the clock when this op did not reuse an earlier op's cascade time
        if op_time == lr_graph.current_time:
            lr_graph.current_time += 2

        # NOTE(review): cascade 0 appears to mean "not cascaded" - confirm with the scheduler
        if cascade != 0:
            cascade_time[cascade] = op_time

    end_time = lr_graph.update_endtime()
    for tens in sg.output_tensors:
        if in_target_memory(tens):
            lr_graph.get_or_create_range(tens).mark_usage(end_time)

    return lr_graph