blob: de001e568d33dce2513d3a245391f49915a01c92 [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
18# Can work with either a pass packed subgraph or a scheduled subgraph.
Louis Verhaard226ecaf2021-03-30 10:18:28 +020019from typing import List
20
Diego Russoe8a10452020-04-21 17:39:10 +010021from .nn_graph import PassPlacement
Louis Verhaardaee5d752020-09-30 09:01:52 +020022from .operation import Op
Patrik Gustavssoneca2e952020-05-27 09:15:11 +020023from .tensor import MemType
Diego Russoe8a10452020-04-21 17:39:10 +010024from .tensor import Tensor
Tim Hall79d07d22020-04-27 18:20:16 +010025
26
class LiveRange:
    """A time interval during which a group of tensors must stay allocated.

    All tensors added to one LiveRange get the same address (see set_address()).
    The interval is half-open: [start_time, end_time).
    """

    def __init__(self, tens, alignment):
        self.tensors = []  # tensors sharing this range (and therefore its address)
        self.start_time = 99999999999  # sentinel until mark_usage() is first called
        self.end_time = -1
        self.size = 0
        self.name = ""
        self.alignment = alignment
        if tens:
            self.add_tensor(tens)

    def __str__(self):
        return "<live_range.LiveRange: '%s' start_time=%s, end_time=%s>" % (self.name, self.start_time, self.end_time)

    __repr__ = __str__

    def add_tensor(self, tens):
        """Add ``tens`` to this range.

        While the range's size is still 0, the tensor being added sets the size and
        the name of the range. Any later tensor must fit within that size (asserted).
        """
        if self.size != 0:
            assert (
                self.size >= tens.storage_size()
            ), "Tensors assigned to the same LiveRange need to fit the size of the LiveRange."
        else:
            self.size = tens.storage_size()
            self.name = tens.name
        self.tensors.append(tens)

    def mark_usage(self, op_time):
        """Widen the range so it covers the slot [op_time, op_time + 1).

        A time of -1 is ignored.
        """
        if op_time != -1:
            self.start_time = min(self.start_time, op_time)
            self.end_time = max(self.end_time, op_time + 1)

    def overlaps_ranges(self, other):
        """Return True when the time intervals of ``self`` and ``other`` intersect."""
        latest_start = max(self.start_time, other.start_time)
        earliest_end = min(self.end_time, other.end_time)
        return latest_start < earliest_end

    def overlaps_address(self, other):
        """Look for a pair of tensors, one from each range, whose address spans overlap.

        A tensor is treated as occupying [address, address + size of its LiveRange).

        :return: (True, tensor_in_self, tensor_in_other) for the first overlapping pair,
            otherwise (False, None, None).
        """
        for mine in self.tensors:
            for theirs in other.tensors:
                if max(mine.address, theirs.address) < min(mine.address + self.size, theirs.address + other.size):
                    return True, mine, theirs
        return False, None, None

    def __lt__(self, other):
        """Order by start_time, then end_time, then size, and finally name."""
        return (self.start_time, self.end_time, self.size, self.name) < (
            other.start_time,
            other.end_time,
            other.size,
            other.name,
        )

    def set_address(self, address):
        """Assign ``address`` to every tensor in the range and return it."""
        for member in self.tensors:
            member.address = address
        return address

    def get_alignment(self):
        return self.alignment

    def set_alignment(self, alignment):
        """Raise the alignment to ``alignment`` if that is larger; it is never lowered."""
        self.alignment = max(self.alignment, alignment)
Tim Hall79d07d22020-04-27 18:20:16 +0100100
101
class LiveRangeGraph:
    """Collection of LiveRange objects, looked up by tensor."""

    def __init__(self):
        # Every LiveRange created by get_or_create_range(), in creation order
        self.lrs: List[LiveRange] = []
        # tensor -> LiveRange; fused tensors map to the same LiveRange object
        self.ranges = {}
        # Tensors that must not be given a live range
        self.ignore_tensors = set()
        # Subgraphs already handled by extract_live_ranges_from_cascaded_passes()
        self.processed_subgraphs = set()
        # Running time stamp, advanced by extract_live_ranges_from_cascaded_passes()
        self.current_time = 0

    def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
        """Return the live range of ``tens`` (or of an equivalent tensor), creating it if missing.

        When a range already exists, its alignment is raised to at least ``alignment``.
        A newly created range is registered in both ``ranges`` and ``lrs``.
        """
        found = next(
            (rng for known_tens, rng in self.ranges.items() if tens.equivalent(known_tens)),
            None,
        )
        if found is not None:
            found.set_alignment(alignment)
            return found

        fresh = LiveRange(tens, alignment)
        self.ranges[tens] = fresh
        self.lrs.append(fresh)
        return fresh

    def fuse_ranges(self, in_tens, out_tens):
        """Place ``out_tens`` in the live range of ``in_tens`` and return that shared range.

        ``out_tens`` must not already have a live range of its own (asserted).
        """
        shared = self.get_or_create_range(in_tens)
        assert out_tens not in self.ranges, out_tens
        shared.add_tensor(out_tens)
        self.ranges[out_tens] = shared
        return shared
129
130
def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
    """Return True if ``tens`` should not get a live range for this allocation.

    A tensor is ignored when any of the following holds:

    - it lies outside ``target_mem_area`` / ``target_mem_type_set``;
    - it is already listed in ``lr_graph.ignore_tensors``;
    - its name ends with "reshape_shape_npu". Such a reshape tensor needs no
      allocation, and as a side effect it is added to ``lr_graph.ignore_tensors``.
    """
    wrong_memory = tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set
    if wrong_memory or tens in lr_graph.ignore_tensors:
        return True
    if not tens.name.endswith("reshape_shape_npu"):
        return False
    # Reshape tensor: remember it so later queries short-circuit
    lr_graph.ignore_tensors.add(tens)
    return True
141
142
def merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set):
    """Try to merge IFM/OFM live ranges for memory-only and elementwise passes of ``sg``.

    For a MemoryOnly pass (e.g. Reshape), the first input and the first output are
    put in the same LiveRange, unless either of them is ignored for this allocation.
    For an elementwise pass, the decision is left to merge_elementwise_op_ranges().
    """
    for ps in sg.passes:
        if ps.placement == PassPlacement.MemoryOnly:
            src, dst = ps.inputs[0], ps.outputs[0]
            if tensor_should_be_ignored(lr_graph, src, target_mem_area, target_mem_type_set) or (
                tensor_should_be_ignored(lr_graph, dst, target_mem_area, target_mem_type_set)
            ):
                continue
            lr_graph.fuse_ranges(src, dst)
        elif ps.is_element_wise:
            merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set)
156
157
def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set):
    """Try to let the OFM of the pass's elementwise op share a live range with one of its inputs.

    At most one input is fused with the OFM. Inputs are tried in order (ifm, then ifm2).
    An input is eligible when it is non-scalar, lies in the target memory area/type, and
    matches the OFM's format, dtype and shape. Its ``consumer_list`` and ``ops`` must each
    contain exactly one entry.

    Nothing is done when any of the following holds:

    - the pass has no elementwise op;
    - the OFM is ignored for this allocation;
    - the op is SHL/SHR, whose inputs must not be overwritten.

    :param ps: pass whose ops are scanned; at most one of them may be elementwise (asserted)
    :param lr_graph: LiveRangeGraph updated through fuse_ranges()
    :param target_mem_area: memory area being allocated
    :param target_mem_type_set: memory types being allocated
    """
    elem_op = None
    for op in ps.ops:
        if op.type.is_elementwise_op():
            assert elem_op is None
            elem_op = op

    if elem_op is None:
        return
    ofm = elem_op.ofm
    if tensor_should_be_ignored(lr_graph, ofm, target_mem_area, target_mem_type_set):
        return
    # Overwriting the inputs is not allowed for shift operations
    if elem_op.type in (Op.SHL, Op.SHR):
        return

    def in_target_memory(tens):
        # Non-scalar tensor located in the memory area/type being allocated.
        # Both ifm and ifm2 are checked with this one predicate, so ifm2 is tested on
        # its own mem_type (the previous code tested ifm.mem_type for ifm2).
        return (
            tens.shape != []
            and tens.mem_area == target_mem_area
            and tens.mem_type in target_mem_type_set
        )

    # Keep each input's position among the op's inputs so that the matching
    # ifm_shapes entry is compared. Previously the index into the *filtered* list
    # was used, so ifm2 was checked against ifm's shape whenever ifm was rejected.
    # NOTE(review): assumes ifm_shapes holds one entry per non-None input, in
    # (ifm, ifm2) order - confirm against Operation.set_ifm_ofm_shapes().
    present_inputs = [inp for inp in (elem_op.ifm, elem_op.ifm2) if inp is not None]
    for idx, inp in enumerate(present_inputs):
        if not in_target_memory(inp):
            continue
        # Check input format, dtype, broadcasting, and that the input has a single consumer/op
        if (
            inp.format == ofm.format
            and inp.dtype == ofm.dtype
            and elem_op.ifm_shapes[idx] == elem_op.ofm_shapes[0]
            and len(inp.consumer_list) == 1
            and len(inp.ops) == 1
        ):
            lr_graph.fuse_ranges(inp, ofm)
            break
198
199
def extract_live_ranges_from_passes(
    sg, target_mem_area, target_mem_type_set=None, ignore_subgraph_input_output_tensors=False,
):
    """Build a new LiveRangeGraph from the (non-cascaded) passes of ``sg``.

    Pass number ``idx`` gets time ``2 * idx``, which is stored in ``ps.time``. Every
    non-ignored input, intermediate and output of the pass is marked as used at that
    time. The subgraph's output tensors are also marked at ``2 * len(sg.passes)``.

    :param sg: subgraph to analyse
    :param target_mem_area: memory area being allocated
    :param target_mem_type_set: memory types being allocated; defaults to
        {MemType.Scratch, MemType.Scratch_fast}
    :param ignore_subgraph_input_output_tensors: when True, the subgraph's input and
        output tensors get no live ranges
    :return: the populated LiveRangeGraph
    """
    lr_graph = LiveRangeGraph()

    if ignore_subgraph_input_output_tensors:
        for boundary_tensors in (sg.input_tensors, sg.output_tensors):
            lr_graph.ignore_tensors.update(boundary_tensors)

    mem_types = target_mem_type_set
    if mem_types is None:
        mem_types = {MemType.Scratch, MemType.Scratch_fast}

    # Try to merge live ranges of operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_op_ranges(sg, lr_graph, target_mem_area, mem_types)

    def mark_all(tensors, when):
        # Record a usage at `when` for every tensor that is not ignored
        for tens in tensors:
            if not tensor_should_be_ignored(lr_graph, tens, target_mem_area, mem_types):
                lr_graph.get_or_create_range(tens).mark_usage(when)

    for idx, ps in enumerate(sg.passes):
        ps.time = 2 * idx
        mark_all(ps.inputs + ps.intermediates + ps.outputs, ps.time)

    # Subgraph outputs must stay alive until after the last pass
    mark_all(sg.output_tensors, 2 * len(sg.passes))
    return lr_graph
235
236
def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    target_mem_type_set,
    ignore_subgraph_input_output_tensors=False,
    lr_graph=None,
    cpu_tensor_alignment=Tensor.AllocationQuantum,
):
    """Add live ranges for the cascaded passes of ``sg`` to ``lr_graph`` and return it.

    A new LiveRangeGraph is created when ``lr_graph`` is None. A subgraph already in
    ``lr_graph.processed_subgraphs`` is skipped and the graph is returned unchanged.

    Each cascaded pass is stamped with ``lr_graph.current_time``, which advances by 2
    per cascaded pass. Inputs are marked at the pass's start time.

    If the pass's primary op is a CustomNpuOp (and MemType.Permanent_CPU is not a
    target memory type), the referenced NPU subgraph is processed recursively first,
    using the default alignment (Tensor.AllocationQuantum). The pass's time is then
    moved to the time reached after that subgraph. Intermediates and outputs are
    marked at this (possibly moved) time.

    The subgraph's output tensors are finally marked at the largest end time of all
    ranges in the graph (at least 0).

    :param cpu_tensor_alignment: alignment requested for ranges created or looked up
        in this call
    """
    if lr_graph is None:
        lr_graph = LiveRangeGraph()

    if sg in lr_graph.processed_subgraphs:
        # Subgraph already handled: return the graph as it is
        return lr_graph

    if ignore_subgraph_input_output_tensors:
        lr_graph.ignore_tensors.update(sg.input_tensors)
        lr_graph.ignore_tensors.update(sg.output_tensors)

    # Try to merge live ranges of operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)

    def mark_all(tensors, when):
        # Record a usage at `when` for every tensor that is not ignored
        for tens in tensors:
            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
                continue
            lr_graph.get_or_create_range(tens, cpu_tensor_alignment).mark_usage(when)

    for cps in sg.cascaded_passes:
        now = lr_graph.current_time
        cps.time = now
        mark_all(cps.inputs, now)

        primary_op = cps.passes[0].primary_op
        calls_npu_subgraph = (
            primary_op and primary_op.type == Op.CustomNpuOp and MemType.Permanent_CPU not in target_mem_type_set
        )
        if calls_npu_subgraph:
            # The NPU subgraph is invoked here: extract its live ranges before continuing.
            # cpu_tensor_alignment is deliberately not forwarded, so NPU tensors use the default.
            lr_graph = extract_live_ranges_from_cascaded_passes(
                primary_op.attrs["subgraph"], target_mem_area, target_mem_type_set, False, lr_graph,
            )
            now = lr_graph.current_time
            cps.time = now

        mark_all(cps.intermediates + cps.outputs, now)
        lr_graph.current_time += 2

    # Subgraph outputs stay alive until the latest end time seen in the whole graph
    end_time = max([0] + [rng.end_time for rng in lr_graph.ranges.values()])
    mark_all(sg.output_tensors, end_time)

    lr_graph.processed_subgraphs.add(sg)
    return lr_graph