blob: fe00b6229ad046c966f3d531543596b1b2671363 [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
18# Can work with either a pass packed subgraph or a scheduled subgraph.
Tim Hall79d07d22020-04-27 18:20:16 +010019from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_cascaded_pass
Diego Russoe8a10452020-04-21 17:39:10 +010020from .nn_graph import PassPlacement
Patrik Gustavssoneca2e952020-05-27 09:15:11 +020021from .tensor import MemType
Diego Russoe8a10452020-04-21 17:39:10 +010022from .tensor import Tensor
Tim Hall79d07d22020-04-27 18:20:16 +010023
24
class LiveRange:
    """A time interval during which one or more tensors occupy a shared piece of memory.

    Tensors that are assigned to the same LiveRange will be allocated to the same address.
    The range's size and name come from the first (non-zero-sized) tensor added. Later
    tensors must fit inside that size.

    Times are half-open: a usage at time ``t`` extends the range to cover ``[t, t + 1)``.
    """

    def __init__(self, tens):
        # Tensors that are assigned to the same LiveRange will be allocated to the same address
        self.tensors = []
        # Placeholder "unused" interval; the first mark_usage() call narrows it via min()/max()
        self.start_time = 99999999999
        self.end_time = -1
        self.size = 0
        self.name = ""

        # Explicit None check: the range may be created empty (LiveRange(None))
        if tens is not None:
            self.add_tensor(tens)

    def __str__(self):
        return "<live_range.LiveRange: '%s' start_time=%s, end_time=%s>" % (self.name, self.start_time, self.end_time)

    __repr__ = __str__

    def add_tensor(self, tens):
        """Add ``tens`` to this range.

        While the range size is still 0, the tensor defines the size and the range is named
        after it. Otherwise the tensor's storage size must not exceed the range size.
        """
        if self.size == 0:
            self.size = tens.storage_size()
            self.name = tens.name  # LiveRange will be named after the first tensor added
        else:
            assert (
                self.size >= tens.storage_size()
            ), "Tensors assigned to the same LiveRange need to fit the size of the LiveRange."

        self.tensors.append(tens)

    def mark_usage(self, op_time):
        """Extend the range so that it covers ``[op_time, op_time + 1)``.

        An ``op_time`` of -1 is ignored.
        """
        if op_time == -1:
            return
        self.start_time = min(self.start_time, op_time)
        self.end_time = max(self.end_time, op_time + 1)

    def overlaps_ranges(self, other):
        """Return True if the half-open time intervals of ``self`` and ``other`` intersect."""
        return max(self.start_time, other.start_time) < min(self.end_time, other.end_time)

    def overlaps_address(self, other):
        """Find the first pair of tensors (one from each range) whose address spans intersect.

        Each tensor is treated as occupying ``[address, address + size of its range)``.

        Returns:
            ``(True, tens, other_tens)`` for the first overlapping pair, else ``(False, None, None)``.
        """
        for tens in self.tensors:
            for other_tens in other.tensors:
                if max(tens.address, other_tens.address) < min(
                    tens.address + self.size, other_tens.address + other.size
                ):
                    return True, tens, other_tens

        return False, None, None

    def __lt__(self, other):
        # Order by start time, then end time, then size, then name
        if self.start_time != other.start_time:
            return self.start_time < other.start_time
        if self.end_time != other.end_time:
            return self.end_time < other.end_time
        if self.size != other.size:
            return self.size < other.size
        return self.name < other.name

    def set_address(self, address):
        """Give ``address`` to every tensor in the range that does not have one yet.

        A tensor that already has an address keeps it. This is only allowed when the range
        holds exactly one tensor. The chosen address is also written to the tensor's
        ``cpu_tensor`` / ``npu_tensor`` clones when they exist.

        Returns:
            The address given to the last tensor processed, or ``address`` itself when the
            range holds no tensors. (Previously an empty range raised UnboundLocalError.)
        """
        addr = address
        for tens in self.tensors:
            if tens.address is None:
                addr = address
            else:
                # Limit to single tensor for the lr if the tensor address already assigned
                assert len(self.tensors) == 1
                addr = tens.address
            tens.address = addr
            # Also need to set the address to the tensor's cpu/npu clones
            if tens.cpu_tensor is not None:
                tens.cpu_tensor.address = addr
            if tens.npu_tensor is not None:
                tens.npu_tensor.address = addr
        return addr

    def get_alignment(self):
        """Return the largest alignment of the range's tensors.

        An empty range returns ``Tensor.AllocationQuantum``.
        """
        if not self.tensors:
            return Tensor.AllocationQuantum
        # Seeded with 0 so the result is never negative
        return max(0, *(tens.alignment for tens in self.tensors))
112
113
def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area):
    """Fuse the live range of each memory-only pass's first input with that of its first output.

    Memory-only passes (e.g. Reshape) get their input and output tensor placed in one LiveRange.
    When either tensor has a Cpu clone (i.e. it is a subgraph input or output), the Cpu clone's
    live range is fused instead. A pair is skipped if ``tensor_should_be_ignored`` reports
    either side (input checked first) as ignored for ``target_mem_area``.
    """
    memory_only_passes = (ps for ps in sg.passes if ps.placement == PassPlacement.MemoryOnly)
    for ps in memory_only_passes:
        src = ps.inputs[0]
        dst = ps.outputs[0]
        # Prefer the Cpu-side clone so the fused range is the one shared with the subgraph boundary
        if src.cpu_tensor is not None:
            src = src.cpu_tensor
        if dst.cpu_tensor is not None:
            dst = dst.cpu_tensor
        if tensor_should_be_ignored(src, target_mem_area):
            continue
        if tensor_should_be_ignored(dst, target_mem_area):
            continue
        lr_graph.fuse_ranges(src, dst)
128
129
class LiveRangeGraph:
    """Mapping from tensors to their LiveRange, plus bookkeeping used while extracting ranges.

    Several tensors may map to the same LiveRange object once they have been fused.
    """

    def __init__(self):
        self.current_time = 0
        self.processed_subgraphs = set()
        self.ignore_tensors = set()
        # (tens, tens) -> overlap_int
        self.allowed_overlaps = {}
        # tens -> range
        self.ranges = {}

    def get_or_create_range(self, tens):
        """Return the range holding ``tens`` or one of its npu/cpu clones.

        If no such range exists, create a new one keyed by ``tens``.
        """
        found = next(
            (
                rng
                for rng in self.ranges.values()
                if any(candidate in rng.tensors for candidate in (tens, tens.npu_tensor, tens.cpu_tensor))
            ),
            None,
        )
        if found is not None:
            return found

        # Neither the tensor nor its clones belong to a range yet: start a new one
        created = LiveRange(tens)
        self.ranges[tens] = created
        return created

    def fuse_ranges(self, in_tens, out_tens):
        """Add ``out_tens`` to the live range of ``in_tens`` and return that shared range.

        ``out_tens`` must not already be a key in ``self.ranges``.
        """
        shared = self.get_or_create_range(in_tens)
        assert out_tens not in self.ranges, out_tens
        shared.add_tensor(out_tens)
        self.ranges[out_tens] = shared
        return shared
155
156
def extract_live_ranges_from_passes(
    sg,
    target_mem_area,
    mark_output_tensors_overlapping_with_input_tensors=False,
    ignore_subgraph_input_output_tensors=False,
):
    """Build a fresh LiveRangeGraph from the (non-cascaded) passes of ``sg``.

    - Each pass is stamped with ``ps.time = 2 * index``.
    - Its inputs and intermediates are marked as used at that time.
    - Its outputs are marked at that time too, except for element-wise passes. There they are
      marked one step later unless ``mark_output_tensors_overlapping_with_input_tensors`` is
      set.
    - The subgraph's output tensors are finally marked at ``2 * len(sg.passes)``.

    Tensors outside ``target_mem_area``, tensors in the ignore set, and tensors named
    ``*reshape_shape_npu`` get no range. With ``ignore_subgraph_input_output_tensors``, the
    subgraph's own inputs and outputs are ignored as well. For NPU subgraphs, memory-only
    passes have their input and output ranges fused first.
    """
    lr_graph = LiveRangeGraph()

    if ignore_subgraph_input_output_tensors:
        lr_graph.ignore_tensors.update(sg.input_tensors)
        lr_graph.ignore_tensors.update(sg.output_tensors)

    def tensor_should_be_ignored(tens, target_mem_area):
        if tens.mem_area != target_mem_area or tens in lr_graph.ignore_tensors:
            return True
        if tens.name.endswith("reshape_shape_npu"):
            # Reshape tensor, no need to allocate; cache it in the ignore set
            lr_graph.ignore_tensors.add(tens)
            return True
        return False

    def mark_all(tensors, time):
        # Record a usage at `time` for every tensor that is not ignored
        for tens in tensors:
            if not tensor_should_be_ignored(tens, target_mem_area):
                lr_graph.get_or_create_range(tens).mark_usage(time)

    # Merge only memory operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area)

    for idx, ps in enumerate(sg.passes):
        ps.time = 2 * idx
        pass_time = ps.time

        mark_all(ps.inputs, pass_time)
        mark_all(ps.intermediates, pass_time)

        delay_outputs = not mark_output_tensors_overlapping_with_input_tensors and ps.is_element_wise
        mark_all(ps.outputs, pass_time + 1 if delay_outputs else pass_time)

    mark_all(sg.output_tensors, len(sg.passes) * 2)

    return lr_graph
218
219
def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    target_mem_type_set,
    mark_output_tensors_overlapping_with_input_tensors=False,
    use_ifm_ofm_overlap=True,
    ignore_subgraph_input_output_tensors=False,
    lr_graph=None,
):
    """Add the live ranges of ``sg``'s cascaded passes to ``lr_graph`` (a new graph if None).

    A subgraph that is already in ``lr_graph.processed_subgraphs`` is skipped and the graph
    is returned unchanged.

    Timing:
    - Each cascaded pass starts at ``lr_graph.current_time``, which then advances by 2.
    - Inputs are marked at the pass time.
    - If the pass's primary op is an ``NpuOp`` (and ``MemType.Permanent_CPU`` is not in
      ``target_mem_type_set``), the referenced NPU subgraph is processed recursively into the
      same graph. The pass time is then moved to the graph time after it.
    - Intermediates are marked at the (possibly moved) pass time.
    - Outputs are marked there too, except element-wise passes, where they are marked one step
      later unless ``mark_output_tensors_overlapping_with_input_tensors`` is set.

    With ``use_ifm_ofm_overlap``, the allowed IFM/OFM overlap of each pass is recorded in
    ``lr_graph.allowed_overlaps``. Finally, the subgraph outputs are marked at the largest
    end time found in the whole graph (floored at 0).

    Only tensors in ``target_mem_area`` whose ``mem_type`` is in ``target_mem_type_set`` are
    tracked. Tensors in the ignore set and ``*reshape_shape_npu`` tensors are skipped.
    """
    if lr_graph is None:
        lr_graph = LiveRangeGraph()

    if sg in lr_graph.processed_subgraphs:
        # if subgraph has been processed already, return the lr_graph as is
        return lr_graph

    if ignore_subgraph_input_output_tensors:
        lr_graph.ignore_tensors.update(sg.input_tensors)
        lr_graph.ignore_tensors.update(sg.output_tensors)

    def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
        if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
            return True
        if tens in lr_graph.ignore_tensors:
            return True
        if tens.name.endswith("reshape_shape_npu"):
            # Reshape tensor, no need to allocate; cache it in the ignore set
            lr_graph.ignore_tensors.add(tens)
            return True
        return False

    def is_tracked(tens):
        return not tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)

    def mark_tracked(tensors, time):
        # Record a usage at `time` for every tracked tensor, creating ranges as needed
        for tens in tensors:
            if is_tracked(tens):
                lr_graph.get_or_create_range(tens).mark_usage(time)

    # Merge only memory operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        for ps in sg.passes:
            if ps.placement != PassPlacement.MemoryOnly:
                continue
            # Memory-only pass (e.g. Reshape): input and output share one LiveRange.
            # Tensors tied to a Cpu tensor (subgraph inputs/outputs) fuse via their Cpu clone.
            src = ps.inputs[0]
            dst = ps.outputs[0]
            if src.cpu_tensor is not None:
                src = src.cpu_tensor
            if dst.cpu_tensor is not None:
                dst = dst.cpu_tensor
            if is_tracked(src) and is_tracked(dst):
                lr_graph.fuse_ranges(src, dst)

    for cps in sg.cascaded_passes:
        cps.time = lr_graph.current_time
        time_for_pass = cps.time
        is_element_wise = cps.is_element_wise

        mark_tracked(cps.inputs, time_for_pass)

        primary_op = cps.passes[0].primary_op
        if primary_op and primary_op.type == "NpuOp" and MemType.Permanent_CPU not in target_mem_type_set:
            # This pass calls an Npu subgraph: extract its live ranges first so that they are
            # timed before this pass's intermediates and outputs.
            lr_graph = extract_live_ranges_from_cascaded_passes(
                primary_op.attrs["subgraph"],
                target_mem_area,
                target_mem_type_set,
                mark_output_tensors_overlapping_with_input_tensors,
                use_ifm_ofm_overlap,
                False,
                lr_graph,
            )
            # Resume timing after the Npu subgraph
            time_for_pass = lr_graph.current_time
            cps.time = time_for_pass

        mark_tracked(cps.intermediates, time_for_pass)

        output_time = time_for_pass
        if not mark_output_tensors_overlapping_with_input_tensors and is_element_wise:
            output_time += 1
        mark_tracked(cps.outputs, output_time)

        if use_ifm_ofm_overlap:
            # fill allowed overlap for ifm and ofm tensor
            ifm_tensor = cps.passes[0].ifm_tensor
            ofm_tensor = cps.passes[-1].ofm_tensor
            both_tracked = (
                ifm_tensor is not None
                and ofm_tensor is not None
                and is_tracked(ifm_tensor)
                and is_tracked(ofm_tensor)
            )
            if both_tracked:
                lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass(
                    cps
                )

        lr_graph.current_time += 2

    # Latest end time over every range in the graph (including other subgraphs), never below 0
    end_time = max([0] + [rng.end_time for rng in lr_graph.ranges.values()])
    mark_tracked(sg.output_tensors, end_time)

    # Add subgraph to set of processed subgraphs
    lr_graph.processed_subgraphs.add(sg)
    return lr_graph
346 return lr_graph