# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
# Can work with either a pass-packed subgraph or a scheduled subgraph.
from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_cascaded_pass
from .nn_graph import PassPlacement
from .operation import Op
from .tensor import MemType
from .tensor import Tensor


class LiveRange:
    def __init__(self, tens, alignment):
        self.tensors = []  # Tensors that are assigned to the same LiveRange will be allocated to the same address
        self.start_time = 99999999999
        self.end_time = -1
        self.size = 0
        self.name = ""
        self.alignment = alignment

        if tens:
            self.add_tensor(tens)

    def __str__(self):
        return "<live_range.LiveRange: '%s' start_time=%s, end_time=%s>" % (self.name, self.start_time, self.end_time)

    __repr__ = __str__

    def add_tensor(self, tens):
        if self.size == 0:
            self.size = tens.storage_size()
            self.name = tens.name  # LiveRange will be named after the first tensor added
        else:
            assert (
                self.size >= tens.storage_size()
            ), "Tensors assigned to the same LiveRange need to fit the size of the LiveRange."

        self.tensors.append(tens)

    def mark_usage(self, op_time):
        if op_time == -1:
            return
        op_time_start = op_time
        op_time_end = op_time + 1

        self.start_time = min(self.start_time, op_time_start)
        self.end_time = max(self.end_time, op_time_end)

    def overlaps_ranges(self, other):
        return max(self.start_time, other.start_time) < min(self.end_time, other.end_time)

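    # Illustrative example (added comment, not from the original module): live ranges use
    # half-open [start_time, end_time) intervals, so a range ending at time 5 does not
    # overlap a range starting at time 5:
    #   A(start=2, end=5) vs B(start=5, end=7): max(2, 5) < min(5, 7) is False -> no overlap
    #   A(start=2, end=5) vs C(start=4, end=6): max(2, 4) < min(5, 6) is True  -> overlap
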
    def overlaps_address(self, other):
        # Returns the first pair of tensors in this LiveRange and 'other' which have
        # overlapping addresses
        for tens in self.tensors:
            for other_tens in other.tensors:
                if max(tens.address, other_tens.address) < min(
                    tens.address + self.size, other_tens.address + other.size
                ):
                    return True, tens, other_tens

        return False, None, None

    def __lt__(self, other):
        if self.start_time != other.start_time:
            return self.start_time < other.start_time
        if self.end_time != other.end_time:
            return self.end_time < other.end_time
        if self.size != other.size:
            return self.size < other.size
        return self.name < other.name

    def set_address(self, address):
        # Set address of all tensors in LiveRange
        for tens in self.tensors:
            tens.address = address

        return address

    def get_alignment(self):
        return self.alignment

    def set_alignment(self, alignment):
        self.alignment = max(self.alignment, alignment)


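# Minimal usage sketch (illustrative comment only; tens_a and tens_b stand for arbitrary
# Tensor objects and are not defined in this module):
#
#   lr = LiveRange(tens_a, Tensor.AllocationQuantum)
#   lr.mark_usage(0)         # live during time step [0, 1)
#   lr.mark_usage(3)         # extends the range to [0, 4)
#   lr.add_tensor(tens_b)    # tens_b must fit within lr.size
#   lr.set_address(0x100)    # assigns the same address to every tensor in the range

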
class LiveRangeGraph:
    def __init__(self):
        self.ranges = {}  # tens -> range
        self.allowed_overlaps = {}  # (tens,tens) -> overlap_int
        self.ignore_tensors = set()
        self.processed_subgraphs = set()
        self.current_time = 0

    def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
        # Return the live range of the tensor (or any of its clones)
        for existing_tensor, rng in self.ranges.items():
            if tens.equivalent(existing_tensor):
                rng.set_alignment(alignment)
                return rng

        # No live range found for the tensor, create a new one
        rng = LiveRange(tens, alignment)
        self.ranges[tens] = rng
        return rng

    def fuse_ranges(self, in_tens, out_tens):
        live_range = self.get_or_create_range(in_tens)
        assert out_tens not in self.ranges, out_tens
        live_range.add_tensor(out_tens)
        self.ranges[out_tens] = live_range
        return live_range


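# Illustrative sketch (added comment, not from the original module; ifm_tens and ofm_tens
# are placeholder Tensor objects):
#
#   lr_graph = LiveRangeGraph()
#   rng = lr_graph.get_or_create_range(ifm_tens)  # creates a LiveRange for ifm_tens
#   rng.mark_usage(lr_graph.current_time)
#   lr_graph.fuse_ranges(ifm_tens, ofm_tens)      # ifm and ofm now share one LiveRange,
#                                                 # so they can be given the same address

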
def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
    if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
        return True
    if tens in lr_graph.ignore_tensors:
        return True
    if tens.name.endswith("reshape_shape_npu"):
        # Reshape tensor, no need to allocate
        lr_graph.ignore_tensors.add(tens)
        return True
    return False


# Tries to merge ifm/ofm live ranges for memory-only ops and elementwise ops
def merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set):
    for ps in sg.passes:
        if ps.placement == PassPlacement.MemoryOnly:
            # For memory-only passes, e.g. Reshape: add the input and output tensor to the same LiveRange
            input_tensor = ps.inputs[0]
            output_tensor = ps.outputs[0]
            if not tensor_should_be_ignored(lr_graph, input_tensor, target_mem_area, target_mem_type_set) and not (
                tensor_should_be_ignored(lr_graph, output_tensor, target_mem_area, target_mem_type_set)
            ):
                lr_graph.fuse_ranges(input_tensor, output_tensor)
        elif ps.is_element_wise:
            merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set)


# Tries to merge the ifm/ofm live ranges of an elementwise op
def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set):
    elem_op = None
    for op in ps.ops:
        if op.type.is_elementwise_op():
            assert elem_op is None
            elem_op = op

    if elem_op is not None and not tensor_should_be_ignored(
        lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set
    ):
        # Check if overwriting the inputs can be allowed
        if elem_op.type not in (Op.SHL, Op.SHR):
            inps = []
            if (
                elem_op.ifm is not None
                and elem_op.ifm.shape != []
                and elem_op.ifm.mem_area == target_mem_area
                and elem_op.ifm.mem_type in target_mem_type_set
            ):
                inps.append(elem_op.ifm)
            if (
                elem_op.ifm2 is not None
                and elem_op.ifm2.shape != []
                and elem_op.ifm2.mem_area == target_mem_area
                and elem_op.ifm2.mem_type in target_mem_type_set
            ):
                inps.append(elem_op.ifm2)

            if len(inps) > 0:
                for inp in inps:
                    # Check that the input matches the output format, dtype and shape (no broadcasting)
                    # and that the input has no other consumers
                    if (
                        inp.format == elem_op.ofm.format
                        and inp.dtype == elem_op.ofm.dtype
                        and inp.shape == elem_op.ofm.shape
                        and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
                    ):
                        lr_graph.fuse_ranges(inp, elem_op.ofm)
                        break


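# Illustrative example (added comment, not from the original module): for an Add op whose
# ifm, ifm2 and ofm all have the same shape, format and dtype in the target memory, and
# whose ifm has no other consumers, ifm and ofm are fused into one LiveRange so that the
# elementwise op may write its result over its first input.

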
def extract_live_ranges_from_passes(
    sg,
    target_mem_area,
    target_mem_type=set((MemType.Scratch, MemType.Scratch_fast)),
    ignore_subgraph_input_output_tensors=False,
):
    lr_graph = LiveRangeGraph()

    if ignore_subgraph_input_output_tensors:
        lr_graph.ignore_tensors.update(sg.input_tensors)
        lr_graph.ignore_tensors.update(sg.output_tensors)

    # Try to merge live ranges of operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type)

    for idx, ps in enumerate(sg.passes):
        ps.time = 2 * idx

        time_for_pass = ps.time

        for tens in ps.inputs + ps.intermediates + ps.outputs:
            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
                continue
            rng = lr_graph.get_or_create_range(tens)
            rng.mark_usage(time_for_pass)

    end_time = len(sg.passes) * 2
    for tens in sg.output_tensors:
        if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
            continue
        rng = lr_graph.get_or_create_range(tens)
        rng.mark_usage(end_time)

    return lr_graph


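# Illustrative timeline (added comment, not from the original module): with three passes,
# the pass times are 0, 2 and 4; a tensor used only by the second pass gets the live range
# [2, 3), and the subgraph output tensors are additionally marked at end_time = 6 so they
# stay live until the end of the subgraph.

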
def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    target_mem_type_set,
    use_ifm_ofm_overlap=True,
    ignore_subgraph_input_output_tensors=False,
    lr_graph=None,
    allocation_alignment=Tensor.AllocationQuantum,
):
    if lr_graph is None:
        lr_graph = LiveRangeGraph()

    if sg in lr_graph.processed_subgraphs:
        # if subgraph has been processed already, return the lr_graph as is
        return lr_graph

    if ignore_subgraph_input_output_tensors:
        lr_graph.ignore_tensors.update(sg.input_tensors)
        lr_graph.ignore_tensors.update(sg.output_tensors)

    # Try to merge live ranges of operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)

    for cps in sg.cascaded_passes:
        cps.time = lr_graph.current_time

        time_for_pass = cps.time

        for tens in cps.inputs:
            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
                continue
            rng = lr_graph.get_or_create_range(tens, allocation_alignment)
            rng.mark_usage(time_for_pass)

        cps_primary_op = cps.passes[0].primary_op

        if (
            cps_primary_op
            and cps_primary_op.type == Op.CustomNpuOp
            and MemType.Permanent_CPU not in target_mem_type_set
        ):
            # If the primary-op is an NpuOp that means this is where an Npu subgraph
            # is called. Go into said subgraph and extract live ranges before continuing.
            # Use default allocation alignment of 16 for Npu tensors
            npu_sg = cps_primary_op.attrs["subgraph"]
            lr_graph = extract_live_ranges_from_cascaded_passes(
                npu_sg, target_mem_area, target_mem_type_set, use_ifm_ofm_overlap, False, lr_graph,
            )
            # Set the new time after handling the Npu subgraph
            time_for_pass = lr_graph.current_time
            cps.time = time_for_pass

        for tens in cps.intermediates + cps.outputs:
            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
                continue
            rng = lr_graph.get_or_create_range(tens, allocation_alignment)
            rng.mark_usage(time_for_pass)

        if use_ifm_ofm_overlap:
            # fill allowed overlap for ifm and ofm tensor
            ifm_tensor = cps.passes[0].ifm_tensor
            ofm_tensor = cps.passes[-1].ofm_tensor
            if (
                ifm_tensor is not None
                and ofm_tensor is not None
                and not tensor_should_be_ignored(lr_graph, ifm_tensor, target_mem_area, target_mem_type_set)
                and not tensor_should_be_ignored(lr_graph, ofm_tensor, target_mem_area, target_mem_type_set)
            ):
                lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass(
                    cps
                )

        lr_graph.current_time += 2

    end_time = 0
    for rng in lr_graph.ranges.values():
        # Find the maximum end time of all live-ranges in the graph
        end_time = max(end_time, rng.end_time)

    for tens in sg.output_tensors:
        if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
            continue
        rng = lr_graph.get_or_create_range(tens, allocation_alignment)
        rng.mark_usage(end_time)

    # Add subgraph to set of processed subgraphs
    lr_graph.processed_subgraphs.add(sg)
    return lr_graph
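

# Illustrative call sketch (added comment, not from the original module; sg and
# target_mem_area are placeholders supplied by the caller):
#
#   lr_graph = extract_live_ranges_from_cascaded_passes(
#       sg, target_mem_area, set((MemType.Scratch, MemType.Scratch_fast))
#   )
#   for rng in sorted(lr_graph.ranges.values()):
#       print(rng, rng.size)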