# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
# Can work with either a pass packed subgraph or a scheduled subgraph.
Tim Hall79d07d22020-04-27 18:20:16 +010019from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_cascaded_pass
Diego Russoe8a10452020-04-21 17:39:10 +010020from .nn_graph import PassPlacement
Louis Verhaardaee5d752020-09-30 09:01:52 +020021from .operation import Op
Patrik Gustavssoneca2e952020-05-27 09:15:11 +020022from .tensor import MemType
Diego Russoe8a10452020-04-21 17:39:10 +010023from .tensor import Tensor
Tim Hall79d07d22020-04-27 18:20:16 +010024
25
class LiveRange:
    """The lifetime (in pass time slots) of one or more tensors sharing an allocation.

    Tensors assigned to the same LiveRange are allocated to the same address,
    so every tensor added after the first must fit within the range's size.
    """

    def __init__(self, tens, alignment):
        # Tensors that are assigned to the same LiveRange will be allocated to the same address
        self.tensors = []
        # Sentinel values: start is "infinitely late", end is "before any time slot",
        # so the first mark_usage() call establishes the real interval.
        self.start_time = 99999999999
        self.end_time = -1
        self.size = 0
        self.name = ""
        self.alignment = alignment

        if tens:
            self.add_tensor(tens)

    def __str__(self):
        return f"<live_range.LiveRange: '{self.name}' start_time={self.start_time}, end_time={self.end_time}>"

    __repr__ = __str__

    def add_tensor(self, tens):
        """Attach a tensor to this range; the first tensor fixes the range's size and name."""
        if self.size:
            assert (
                self.size >= tens.storage_size()
            ), "Tensors assigned to the same LiveRange need to fit the size of the LiveRange."
        else:
            # LiveRange will be named after the first tensor added
            self.size = tens.storage_size()
            self.name = tens.name

        self.tensors.append(tens)

    def mark_usage(self, op_time):
        """Widen the live interval to cover the half-open slot [op_time, op_time + 1)."""
        if op_time == -1:
            # -1 means "no usage"; leave the interval untouched
            return
        self.start_time = min(self.start_time, op_time)
        self.end_time = max(self.end_time, op_time + 1)

    def overlaps_ranges(self, other):
        """Return True when the two live intervals intersect in time."""
        return min(self.end_time, other.end_time) > max(self.start_time, other.start_time)

    def overlaps_address(self, other):
        """Return (True, tens, other_tens) for the first pair of tensors from the two
        ranges whose address spans intersect, or (False, None, None) if none do."""
        for own_tens in self.tensors:
            own_end = own_tens.address + self.size
            for other_tens in other.tensors:
                other_end = other_tens.address + other.size
                if max(own_tens.address, other_tens.address) < min(own_end, other_end):
                    return True, own_tens, other_tens

        return False, None, None

    def __lt__(self, other):
        # Order by start time, then end time, then size, then name — identical
        # semantics to comparing the fields one by one.
        own_key = (self.start_time, self.end_time, self.size, self.name)
        other_key = (other.start_time, other.end_time, other.size, other.name)
        return own_key < other_key

    def set_address(self, address):
        """Assign the given address to every tensor in this LiveRange."""
        for member in self.tensors:
            member.address = address

        return address

    def get_alignment(self):
        return self.alignment

    def set_alignment(self, alignment):
        # Alignment only ever grows; keep the strictest requirement seen so far
        self.alignment = max(self.alignment, alignment)
Tim Hall79d07d22020-04-27 18:20:16 +010099
100
def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area):
    """Fuse the input/output live ranges of every memory-only pass in the subgraph.

    Memory-only passes (e.g. Reshape) alias their input and output, so both
    tensors must end up in the same LiveRange.
    """
    for ps in sg.passes:
        if ps.placement != PassPlacement.MemoryOnly:
            continue
        in_tens = ps.inputs[0]
        out_tens = ps.outputs[0]
        if tensor_should_be_ignored(in_tens, target_mem_area):
            continue
        if tensor_should_be_ignored(out_tens, target_mem_area):
            continue
        lr_graph.fuse_ranges(in_tens, out_tens)
111
112
class LiveRangeGraph:
    """Collection of LiveRanges for the tensors of one or more subgraphs."""

    def __init__(self):
        # Maps each tensor to the LiveRange it belongs to
        self.ranges = {}
        # Maps (ifm_tensor, ofm_tensor) pairs to the number of bytes they may overlap
        self.allowed_overlaps = {}
        self.ignore_tensors = set()
        self.processed_subgraphs = set()
        self.current_time = 0

    def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
        """Return the live range of the tensor (or any of its clones), creating one if needed."""
        existing = next(
            (rng for known_tens, rng in self.ranges.items() if tens.equivalent(known_tens)),
            None,
        )
        if existing is not None:
            # Keep the strictest alignment requested for this range
            existing.set_alignment(alignment)
            return existing

        # No live range found for the tensor, create a new one
        new_rng = LiveRange(tens, alignment)
        self.ranges[tens] = new_rng
        return new_rng

    def fuse_ranges(self, in_tens, out_tens):
        """Extend in_tens' live range to also cover out_tens (used for aliasing ops)."""
        live_range = self.get_or_create_range(in_tens)
        assert out_tens not in self.ranges, out_tens
        live_range.add_tensor(out_tens)
        self.ranges[out_tens] = live_range
        return live_range
139
140
def extract_live_ranges_from_passes(
    sg,
    target_mem_area,
    mark_output_tensors_overlapping_with_input_tensors=False,
    ignore_subgraph_input_output_tensors=False,
):
    """Build a LiveRangeGraph for the tensors of a pass-packed subgraph.

    Each pass is given two time slots so that element-wise outputs can be
    marked one slot later than the inputs (unless the caller asks for outputs
    to overlap inputs).
    """
    lr_graph = LiveRangeGraph()

    if ignore_subgraph_input_output_tensors:
        lr_graph.ignore_tensors.update(sg.input_tensors)
        lr_graph.ignore_tensors.update(sg.output_tensors)

    def tensor_should_be_ignored(tens, target_mem_area):
        if tens.mem_area != target_mem_area:
            return True
        if tens in lr_graph.ignore_tensors:
            return True
        if tens.name.endswith("reshape_shape_npu"):
            # Reshape tensor, no need to allocate
            lr_graph.ignore_tensors.add(tens)
            return True
        return False

    # Merge only memory operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area)

    for idx, ps in enumerate(sg.passes):
        # Two time slots per pass; see module-level timing convention
        ps.time = 2 * idx
        time_for_pass = ps.time

        # Inputs and intermediates are live during the pass' own slot
        for tens_list in (ps.inputs, ps.intermediates):
            for tens in tens_list:
                if tensor_should_be_ignored(tens, target_mem_area):
                    continue
                lr_graph.get_or_create_range(tens).mark_usage(time_for_pass)

        for tens in ps.outputs:
            if tensor_should_be_ignored(tens, target_mem_area):
                continue
            output_time = time_for_pass
            if ps.is_element_wise and not mark_output_tensors_overlapping_with_input_tensors:
                # Element-wise outputs occupy the second of the pass' two slots
                output_time += 1
            lr_graph.get_or_create_range(tens).mark_usage(output_time)

    # Keep the subgraph's outputs alive until after the last pass
    end_time = 2 * len(sg.passes)
    for tens in sg.output_tensors:
        if tensor_should_be_ignored(tens, target_mem_area):
            continue
        lr_graph.get_or_create_range(tens).mark_usage(end_time)

    return lr_graph
202
203
def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    target_mem_type_set,
    mark_output_tensors_overlapping_with_input_tensors=False,
    use_ifm_ofm_overlap=True,
    ignore_subgraph_input_output_tensors=False,
    lr_graph=None,
    allocation_alignment=Tensor.AllocationQuantum,
):
    """Build (or extend) a LiveRangeGraph from a subgraph of cascaded passes.

    Only tensors whose mem_area equals target_mem_area AND whose mem_type is in
    target_mem_type_set are tracked; everything else is ignored.  Recurses into
    NPU subgraphs reached through CustomNpuOp primary ops, accumulating into the
    same lr_graph.  Each cascaded pass advances lr_graph.current_time by 2 so an
    element-wise output can be marked one slot after its inputs.

    Returns the (possibly shared/pre-existing) LiveRangeGraph.
    """
    if lr_graph is None:
        lr_graph = LiveRangeGraph()

    if sg in lr_graph.processed_subgraphs:
        # if subgraph has been processed already, return the lr_graph as is
        return lr_graph

    if ignore_subgraph_input_output_tensors:
        lr_graph.ignore_tensors.update(sg.input_tensors)
        lr_graph.ignore_tensors.update(sg.output_tensors)

    def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
        # A tensor takes part in allocation only if it lives in the targeted
        # memory area AND one of the targeted memory types.
        if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
            return True
        if tens in lr_graph.ignore_tensors:
            return True
        if tens.name.endswith("reshape_shape_npu"):
            # Reshape tensor, no need to allocate
            lr_graph.ignore_tensors.add(tens)
            return True
        return False

    # Local variant of the module-level merge_memory_op_ranges, extended with
    # the mem_type filter (shadows the module-level function inside this scope).
    def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set):
        for ps in sg.passes:
            if ps.placement == PassPlacement.MemoryOnly:
                # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
                input_tensor = ps.inputs[0]
                output_tensor = ps.outputs[0]
                if not tensor_should_be_ignored(input_tensor, target_mem_area, target_mem_type_set) and not (
                    tensor_should_be_ignored(output_tensor, target_mem_area, target_mem_type_set)
                ):
                    lr_graph.fuse_ranges(input_tensor, output_tensor)

    # Merge only memory operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set)

    for cps in sg.cascaded_passes:
        # Stamp the pass with the current global time before marking usage
        cps.time = lr_graph.current_time

        time_for_pass = cps.time

        is_element_wise = cps.is_element_wise

        # Inputs are live during the pass' own time slot
        for tens in cps.inputs:
            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                continue
            rng = lr_graph.get_or_create_range(tens, allocation_alignment)
            rng.mark_usage(time_for_pass)

        cps_primary_op = cps.passes[0].primary_op

        if (
            cps_primary_op
            and cps_primary_op.type == Op.CustomNpuOp
            and MemType.Permanent_CPU not in target_mem_type_set
        ):
            # If the primary-op is an NpuOp that means this is where an Npu subgraph
            # is called. Go into said subgraph and extract live ranges before continuing.
            # Use default allocation alignment of 16 for Npu tensors
            # (allocation_alignment is deliberately NOT forwarded to the recursive call).
            npu_sg = cps_primary_op.attrs["subgraph"]
            lr_graph = extract_live_ranges_from_cascaded_passes(
                npu_sg,
                target_mem_area,
                target_mem_type_set,
                mark_output_tensors_overlapping_with_input_tensors,
                use_ifm_ofm_overlap,
                False,
                lr_graph,
            )
            # Set the new time after handling the Npu subgraph
            time_for_pass = lr_graph.current_time
            cps.time = time_for_pass

        for tens in cps.intermediates:
            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                continue
            rng = lr_graph.get_or_create_range(tens, allocation_alignment)
            rng.mark_usage(time_for_pass)

        for tens in cps.outputs:
            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                continue
            rng = lr_graph.get_or_create_range(tens, allocation_alignment)
            output_time = time_for_pass
            if not mark_output_tensors_overlapping_with_input_tensors and is_element_wise:
                # Element-wise outputs take the second of the pass' two slots
                output_time += 1
            rng.mark_usage(output_time)

        if use_ifm_ofm_overlap:
            # fill allowed overlap for ifm and ofm tensor
            ifm_tensor = cps.passes[0].ifm_tensor
            ofm_tensor = cps.passes[-1].ofm_tensor
            if (
                ifm_tensor is not None
                and ofm_tensor is not None
                and not tensor_should_be_ignored(ifm_tensor, target_mem_area, target_mem_type_set)
                and not tensor_should_be_ignored(ofm_tensor, target_mem_area, target_mem_type_set)
            ):
                lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass(
                    cps
                )

        # Two time slots consumed per cascaded pass
        lr_graph.current_time += 2

    end_time = 0
    for rng in lr_graph.ranges.values():
        # Find the maximum end time of all live-ranges in the graph
        end_time = max(end_time, rng.end_time)

    # Keep the subgraph's output tensors alive until the very end of the graph
    for tens in sg.output_tensors:
        if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
            continue
        rng = lr_graph.get_or_create_range(tens, allocation_alignment)
        rng.mark_usage(end_time)

    # Add subgraph to set of processed subgraphs
    lr_graph.processed_subgraphs.add(sg)
    return lr_graph