blob: 24f1f64cb557e70936c899cd0fb0db42930ece49 [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18# Description:
19# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
20# Can work with either a pass packed subgraph or a scheduled subgraph.
21
22from .tensor import Tensor, MemArea
23from .nn_graph import TensorPurpose, PassPlacement
24from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_cascaded_pass
25
26
class LiveRange:
    """A span of time during which one or more tensors share a single allocation.

    All tensors added to a LiveRange are given the same address by ``set_address``.
    The range's ``size`` is taken from the first added tensor with a non-zero storage
    size, and every later tensor must fit within it. The range is named after the
    first tensor added.
    """

    def __init__(self, tens):
        # Tensors that are assigned to the same LiveRange will be allocated to the same address
        self.tensors = []
        # Sentinel start time; mark_usage() lowers it with min() on first use
        self.start_time = 99999999999
        # Exclusive end time; -1 means the range has not been used yet
        self.end_time = -1
        self.size = 0
        self.name = ""

        if tens:
            self.add_tensor(tens)

    def __str__(self):
        return "<live_range.LiveRange: '%s' start_time=%s, end_time=%s>" % (self.name, self.start_time, self.end_time)

    __repr__ = __str__

    def add_tensor(self, tens):
        """Add ``tens`` to this LiveRange.

        The LiveRange is named after the first tensor added. The first tensor with a
        non-zero storage size fixes the size of the range; any tensor added after that
        must not be larger (checked with an assert).
        """
        if not self.tensors:
            # LiveRange will be named after the first tensor added, even when that
            # tensor has zero storage size and a later tensor determines the size.
            self.name = tens.name
        if self.size == 0:
            self.size = tens.storage_size()
        else:
            assert (
                self.size >= tens.storage_size()
            ), "Tensors assigned to the same LiveRange need to fit the size of the LiveRange."

        self.tensors.append(tens)

    def mark_usage(self, op_time):
        """Extend the range to cover the half-open interval [op_time, op_time + 1).

        An ``op_time`` of -1 is ignored and leaves the range unchanged.
        """
        if op_time == -1:
            return
        self.start_time = min(self.start_time, op_time)
        self.end_time = max(self.end_time, op_time + 1)

    def overlaps_ranges(self, other):
        """Return True if the time intervals of the two ranges intersect (end times are exclusive)."""
        return max(self.start_time, other.start_time) < min(self.end_time, other.end_time)

    def overlaps_address(self, other):
        """Find the first pair of tensors, one from each range, whose address intervals intersect.

        Each tensor occupies [address, address + size of its own LiveRange).

        Returns:
            ``(True, tens, other_tens)`` for the first overlapping pair found, otherwise
            ``(False, None, None)``.
        """
        for tens in self.tensors:
            for other_tens in other.tensors:
                if max(tens.address, other_tens.address) < min(
                    tens.address + self.size, other_tens.address + other.size
                ):
                    return True, tens, other_tens

        return False, None, None

    def __lt__(self, other):
        # Order by start time, then end time, then size, then name
        if self.start_time != other.start_time:
            return self.start_time < other.start_time
        if self.end_time != other.end_time:
            return self.end_time < other.end_time
        if self.size != other.size:
            return self.size < other.size
        return self.name < other.name

    def set_address(self, address):
        """Assign ``address`` to every tensor in the range whose address is still 0.

        The address is also copied to each such tensor's cpu/npu clone, if present.
        NOTE(review): address 0 is treated as "unaddressed", so a tensor that was
        deliberately placed at address 0 would be re-assigned here.
        """
        for tens in self.tensors:
            if tens.address == 0:
                tens.address = address
                # Also need to set the address to the tensor's cpu/npu clones
                if tens.cpu_tensor is not None:
                    tens.cpu_tensor.address = address
                if tens.npu_tensor is not None:
                    tens.npu_tensor.address = address

    def get_alignment(self):
        """Return the largest alignment required by any tensor in the range.

        An empty range falls back to ``Tensor.AllocationQuantum``.
        """
        if not self.tensors:
            return Tensor.AllocationQuantum
        # The leading 0 keeps the original "start from 0" semantics of the max
        return max(0, *(tens.alignment for tens in self.tensors))
108
109
def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area):
    """Fuse the input and output live ranges of every memory-only pass in ``sg``.

    For each pass with ``PassPlacement.MemoryOnly`` placement (e.g. Reshape), the
    first input and first output tensor are put into the same LiveRange.

    :param sg: subgraph whose ``passes`` are scanned
    :param lr_graph: LiveRangeGraph that receives the fused ranges
    :param tensor_should_be_ignored: predicate ``(tens, target_mem_area) -> bool``;
        a pair is not fused when either tensor is ignored. The output tensor is
        only checked when the input tensor is not ignored.
    :param target_mem_area: memory area forwarded to ``tensor_should_be_ignored``
    """
    for ps in sg.passes:
        if ps.placement != PassPlacement.MemoryOnly:
            continue

        input_tensor = ps.inputs[0]
        output_tensor = ps.outputs[0]
        # If the input or output tensor is tied to a Cpu tensor, i.e. a subgraph input
        # or output, fuse the live range with the Cpu tensor's live range instead.
        if input_tensor.cpu_tensor is not None:
            input_tensor = input_tensor.cpu_tensor
        if output_tensor.cpu_tensor is not None:
            output_tensor = output_tensor.cpu_tensor

        # Short-circuit order matters: the predicate may record tensors as ignored
        if tensor_should_be_ignored(input_tensor, target_mem_area) or tensor_should_be_ignored(
            output_tensor, target_mem_area
        ):
            continue
        lr_graph.fuse_ranges(input_tensor, output_tensor)
123 lr_graph.fuse_ranges(input_tensor, output_tensor)
124
125
class LiveRangeGraph:
    """Collection of LiveRanges for the tensors of one or more subgraphs.

    Several tensors may map to the same LiveRange object after ``fuse_ranges``.
    """

    def __init__(self):
        self.ranges = {}  # tensor -> LiveRange (fused tensors share one LiveRange)
        self.allowed_overlaps = {}  # (ifm tensor, ofm tensor) -> allowed overlap (int)
        self.ignore_tensors = set()  # tensors excluded from live-range tracking
        self.processed_subgraphs = set()  # subgraphs whose ranges were already extracted
        self.current_time = 0  # running time stamp used when walking cascaded passes

    def get_or_create_range(self, tens):
        """Return the LiveRange that holds ``tens`` or one of its npu/cpu clones.

        Ranges are searched in insertion order; if none matches, a new LiveRange
        for ``tens`` is created and registered under ``tens``.
        """
        for candidate_range in self.ranges.values():
            members = candidate_range.tensors
            if tens in members or tens.npu_tensor in members or tens.cpu_tensor in members:
                return candidate_range

        # No live range found for the tensor (or its clones): make a fresh one
        new_range = LiveRange(tens)
        self.ranges[tens] = new_range
        return new_range

    def fuse_ranges(self, in_tens, out_tens):
        """Add ``out_tens`` to the LiveRange of ``in_tens`` and return that shared range.

        ``out_tens`` must not already be registered in ``ranges`` (checked with an assert).
        """
        shared_range = self.get_or_create_range(in_tens)
        assert out_tens not in self.ranges, out_tens
        shared_range.add_tensor(out_tens)
        self.ranges[out_tens] = shared_range
        return shared_range
150 return live_range
151
152
def extract_live_ranges_from_passes(
    sg,
    target_mem_area,
    mark_output_tensors_overlapping_with_input_tensors=False,
    ignore_subgraph_input_output_tensors=False,
):
    """Build a fresh LiveRangeGraph from the (non-cascaded) passes of ``sg``.

    Pass ``i`` is stamped with time ``2 * i``. Inputs and intermediates are marked as
    used at the pass time. Outputs are marked at the pass time, or one step later for
    element-wise passes unless ``mark_output_tensors_overlapping_with_input_tensors``
    is set. Subgraph output tensors are additionally marked at ``2 * len(sg.passes)``.
    Only tensors whose ``mem_area`` equals ``target_mem_area`` are tracked.
    """
    lr_graph = LiveRangeGraph()

    if ignore_subgraph_input_output_tensors:
        for boundary_tensors in (sg.input_tensors, sg.output_tensors):
            lr_graph.ignore_tensors.update(boundary_tensors)

    def is_ignored(tens, mem_area):
        # Tensors outside the target memory area, or already ignored, get no range
        if tens.mem_area != mem_area or tens in lr_graph.ignore_tensors:
            return True
        # Reshape tensor, no need to allocate; remember it as ignored
        needs_no_allocation = tens.name.endswith("reshape_shape_npu")
        if needs_no_allocation:
            lr_graph.ignore_tensors.add(tens)
        return needs_no_allocation

    # Merge only memory operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_memory_op_ranges(sg, lr_graph, is_ignored, target_mem_area)

    for pass_index, ps in enumerate(sg.passes):
        ps.time = pass_time = 2 * pass_index

        # Inputs first, then intermediates, all used at the pass time
        for tens in (*ps.inputs, *ps.intermediates):
            if not is_ignored(tens, target_mem_area):
                lr_graph.get_or_create_range(tens).mark_usage(pass_time)

        for tens in ps.outputs:
            if is_ignored(tens, target_mem_area):
                continue
            out_range = lr_graph.get_or_create_range(tens)
            # Element-wise outputs are pushed one step later so they do not share
            # a time slot with the inputs, unless overlapping was requested
            delay_output = not mark_output_tensors_overlapping_with_input_tensors and ps.is_element_wise
            out_range.mark_usage(pass_time + 1 if delay_output else pass_time)

    # Subgraph outputs are kept alive up to a time past the last pass
    final_time = 2 * len(sg.passes)
    for tens in sg.output_tensors:
        if not is_ignored(tens, target_mem_area):
            lr_graph.get_or_create_range(tens).mark_usage(final_time)

    return lr_graph
213 return lr_graph
214
215
def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    mark_output_tensors_overlapping_with_input_tensors=False,
    use_ifm_ofm_overlap=True,
    ignore_subgraph_input_output_tensors=False,
    lr_graph=None,
):
    """Extract live ranges for the cascaded passes of ``sg`` into ``lr_graph``.

    Each cascaded pass is stamped with ``lr_graph.current_time``, which advances by 2
    after every cascaded pass. When a pass's primary op is an ``NpuOp`` and
    ``target_mem_area`` is Sram or Dram, the called NPU subgraph is processed
    recursively into the same graph before the pass's intermediates and outputs are
    marked, so time keeps advancing through the called subgraph. A subgraph that is
    already in ``lr_graph.processed_subgraphs`` is skipped. Output tensors of ``sg``
    are marked at the largest end time found among all ranges in the graph.

    :param sg: subgraph whose ``cascaded_passes`` are walked
    :param target_mem_area: only tensors with this ``mem_area`` are tracked
    :param mark_output_tensors_overlapping_with_input_tensors: when False, outputs of
        element-wise cascaded passes are marked one time step after the pass
    :param use_ifm_ofm_overlap: record the allowed ifm/ofm overlap for each cascaded pass
    :param ignore_subgraph_input_output_tensors: add ``sg``'s inputs/outputs to
        ``lr_graph.ignore_tensors``
    :param lr_graph: graph to extend; a new LiveRangeGraph is created when None
    :return: the (possibly newly created) LiveRangeGraph
    """
    if lr_graph is None:
        lr_graph = LiveRangeGraph()

    if sg in lr_graph.processed_subgraphs:
        # if subgraph has been processed already, return the lr_graph as is
        return lr_graph

    if ignore_subgraph_input_output_tensors:
        lr_graph.ignore_tensors.update(sg.input_tensors)
        lr_graph.ignore_tensors.update(sg.output_tensors)

    def tensor_should_be_ignored(tens, target_mem_area):
        if tens.mem_area != target_mem_area:
            return True
        if tens in lr_graph.ignore_tensors:
            return True
        if tens.name.endswith("reshape_shape_npu"):
            # Reshape tensor, no need to allocate
            lr_graph.ignore_tensors.add(tens)
            return True
        return False

    # Merge only memory operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area)

    # Called NPU subgraphs are only descended into for Sram/Dram allocation.
    # Loop-invariant, so evaluated once instead of per cascaded pass.
    descend_into_npu_subgraphs = target_mem_area in (MemArea.Sram, MemArea.Dram)

    for cps in sg.cascaded_passes:
        cps.time = lr_graph.current_time

        time_for_pass = cps.time

        is_element_wise = cps.is_element_wise

        for tens in cps.inputs:
            if tensor_should_be_ignored(tens, target_mem_area):
                continue
            rng = lr_graph.get_or_create_range(tens)
            rng.mark_usage(time_for_pass)

        cps_primary_op = cps.passes[0].primary_op
        if cps_primary_op and cps_primary_op.type == "NpuOp" and descend_into_npu_subgraphs:
            # If the primary-op is an NpuOp that means this is where an Npu subgraph
            # is called. Go into said subgraph and extract live ranges before continuing.
            npu_sg = cps_primary_op.attrs["subgraph"]
            lr_graph = extract_live_ranges_from_cascaded_passes(
                npu_sg,
                target_mem_area,
                mark_output_tensors_overlapping_with_input_tensors=mark_output_tensors_overlapping_with_input_tensors,
                use_ifm_ofm_overlap=use_ifm_ofm_overlap,
                ignore_subgraph_input_output_tensors=False,
                lr_graph=lr_graph,
            )
            # Set the new time after handling the Npu subgraph
            time_for_pass = lr_graph.current_time
            cps.time = time_for_pass

        for tens in cps.intermediates:
            if tensor_should_be_ignored(tens, target_mem_area):
                continue
            rng = lr_graph.get_or_create_range(tens)
            rng.mark_usage(time_for_pass)

        for tens in cps.outputs:
            if tensor_should_be_ignored(tens, target_mem_area):
                continue
            rng = lr_graph.get_or_create_range(tens)
            output_time = time_for_pass
            if not mark_output_tensors_overlapping_with_input_tensors and is_element_wise:
                output_time += 1
            rng.mark_usage(output_time)

        if use_ifm_ofm_overlap:
            # fill allowed overlap for ifm and ofm tensor
            ifm_tensor = cps.passes[0].ifm_tensor
            ofm_tensor = cps.passes[-1].ofm_tensor
            if (
                ifm_tensor is not None
                and ofm_tensor is not None
                and not tensor_should_be_ignored(ifm_tensor, target_mem_area)
                and not tensor_should_be_ignored(ofm_tensor, target_mem_area)
            ):
                lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass(
                    cps
                )

        lr_graph.current_time += 2

    # Find the maximum end time of all live-ranges in the graph (never below 0)
    end_time = 0
    for rng in lr_graph.ranges.values():
        end_time = max(end_time, rng.end_time)

    for tens in sg.output_tensors:
        if tensor_should_be_ignored(tens, target_mem_area):
            continue
        rng = lr_graph.get_or_create_range(tens)
        rng.mark_usage(end_time)

    # Add subgraph to set of processed subgraphs
    lr_graph.processed_subgraphs.add(sg)
    return lr_graph
324 return lr_graph