blob: d75a167db4d4cb9f00b5822f3b493d65d6e33bad [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
18# Can work with either a pass packed subgraph or a scheduled subgraph.
Louis Verhaard226ecaf2021-03-30 10:18:28 +020019from typing import List
20
Tim Halld8339a72021-05-27 18:49:40 +010021import numpy as np
22
Diego Russoe8a10452020-04-21 17:39:10 +010023from .nn_graph import PassPlacement
Louis Verhaardaee5d752020-09-30 09:01:52 +020024from .operation import Op
Tim Halld8339a72021-05-27 18:49:40 +010025from .tensor import MemArea
Patrik Gustavssoneca2e952020-05-27 09:15:11 +020026from .tensor import MemType
Diego Russoe8a10452020-04-21 17:39:10 +010027from .tensor import Tensor
Tim Halld8339a72021-05-27 18:49:40 +010028from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010029
30
class LiveRange:
    """A time interval during which one or more tensors occupy the same memory.

    Tensors that are assigned to the same LiveRange will be allocated to the same
    address. The range takes its name and initial size from the first tensor added.
    """

    def __init__(self, tens, alignment):
        self.tensors = []
        # Sentinel bounds: the first successful mark_usage() overwrites both
        self.start_time = 99999999999
        self.end_time = -1
        self.size = 0
        self.name = ""
        self.alignment = alignment
        if tens:
            self.mem_area = tens.mem_area
            self.add_tensor(tens)
        else:
            self.mem_area = MemArea.Unknown

    def __str__(self):
        return f"<live_range.LiveRange: '{self.name!s}' start_time={self.start_time!s}, end_time={self.end_time!s}>"

    __repr__ = __str__

    def add_tensor(self, tens):
        """Attach ``tens`` to this range.

        The first tensor added (while size is 0) sets the range's size and name.
        Later tensors must fit within that size.
        """
        if self.size != 0:
            assert (
                self.size >= tens.storage_size()
            ), "Tensors assigned to the same LiveRange need to fit the size of the LiveRange."
        else:
            self.size = tens.storage_size()
            self.name = tens.name
        self.tensors.append(tens)

    def mark_usage(self, op_time, op_length=1):
        """Widen the range to cover [max(op_time, 0), op_time + op_length).

        Does nothing when that clamped interval is empty.
        """
        first = max(op_time, 0)
        last = op_time + op_length
        if last > first:
            self.start_time = min(self.start_time, first)
            self.end_time = max(self.end_time, last)

    def set_buffer_size(self, buffer_size):
        """Override the range size (used for rolling buffers) and place it in SRAM."""
        self.size = buffer_size
        self.mem_area = MemArea.Sram

    def overlaps_ranges(self, other):
        """True if the time intervals intersect; ranges that only touch do not overlap."""
        return max(self.start_time, other.start_time) < min(self.end_time, other.end_time)

    def overlaps_address(self, other):
        """Find the first pair of tensors whose address spans intersect.

        Each tensor's span is [address, address + owning range size).

        Returns:
            ``(True, tens, other_tens)`` for the first overlapping pair, otherwise
            ``(False, None, None)``.
        """
        for mine in self.tensors:
            my_end = mine.address + self.size
            for theirs in other.tensors:
                if max(mine.address, theirs.address) < min(my_end, theirs.address + other.size):
                    return True, mine, theirs
        return False, None, None

    def __lt__(self, other):
        # Lexicographic order: start_time, then end_time, size and finally name
        return (self.start_time, self.end_time, self.size, self.name) < (
            other.start_time,
            other.end_time,
            other.size,
            other.name,
        )

    def set_address(self, address):
        """Give every tensor in the range the same ``address`` and return it."""
        for member in self.tensors:
            member.address = address
        return address

    def get_alignment(self):
        return self.alignment

    def set_alignment(self, alignment):
        """Raise the alignment to ``alignment`` if larger; it never decreases."""
        self.alignment = max(self.alignment, alignment)
Tim Hall79d07d22020-04-27 18:20:16 +0100109
110
class LiveRangeGraph:
    """Owns the live ranges built for one or more subgraphs.

    ``fuse_ranges`` can map several tensors to the same LiveRange. ``ranges`` may
    therefore contain the same LiveRange under several keys, while ``lrs`` holds
    every LiveRange exactly once.
    """

    def __init__(self):
        self.lrs: List[LiveRange] = []  # List of all created ranges, each exactly once
        self.ranges = {}  # tens -> range (fused tensors share one range)
        self.ignore_tensors = set()  # Tensors that must not get a live range
        self.processed_subgraphs = set()  # Subgraphs already walked
        self.current_time = 0  # Next time slot handed out by the extraction functions
        self.end_time = None  # Latest end time of any range; set by update_endtime()

    def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
        """Return the LiveRange of ``tens`` (or an equivalent tensor), creating it if needed.

        When an existing range is returned, its alignment is raised to at least
        ``alignment``.
        """
        # Return the live range of the tensor (or any of its clones)
        for existing_tensor, rng in self.ranges.items():
            if tens.equivalent(existing_tensor):
                rng.set_alignment(alignment)
                return rng

        # No live range found for the tensor, create a new one
        rng = LiveRange(tens, alignment)
        self.ranges[tens] = rng
        self.lrs.append(rng)
        return rng

    def fuse_ranges(self, in_tens, out_tens):
        """Add ``out_tens`` to the live range of ``in_tens`` so they share one address.

        ``out_tens`` must not already own a range.
        """
        live_range = self.get_or_create_range(in_tens)
        assert out_tens not in self.ranges, out_tens
        live_range.add_tensor(out_tens)
        self.ranges[out_tens] = live_range
        return live_range

    def update_endtime(self):
        """Recompute ``end_time`` as the latest end time of any range (at least 0).

        Returns:
            ``end_time + 1``: the number of time slots needed to index every
            (inclusive) end time.
        """
        self.end_time = max([0] + [lr.end_time for lr in self.lrs])
        return self.end_time + 1

    def get_temporal_memory_usage(self, target_mem_area):
        """Return per-time-slot memory usage of ``target_mem_area``.

        Entry ``t`` of the returned int32 array is the summed size of every distinct
        live range in ``target_mem_area`` whose [start_time, end_time] interval
        (end inclusive) contains ``t``.
        """
        # Recompute on every call so ranges added after an earlier call are included.
        # The array has end_time + 1 slots so the inclusive end of the last range fits;
        # a shorter array would make NumPy silently clip the slice below.
        usage = np.zeros(self.update_endtime(), dtype=np.int32)
        # Walk distinct ranges: self.ranges holds fused ranges under several tensor keys,
        # which would count them more than once.
        for lr in self.lrs:
            if lr.mem_area == target_mem_area:
                # End time is inclusive
                usage[lr.start_time : lr.end_time + 1] += lr.size

        return usage
156
Tim Hall79d07d22020-04-27 18:20:16 +0100157
def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
    """Decide whether ``tens`` should be left out of the live range graph.

    A tensor is ignored when:
    - it lives outside ``target_mem_area``;
    - its mem_type is not in ``target_mem_type_set``;
    - it is already in ``lr_graph.ignore_tensors``; or
    - its name ends with "reshape_shape_npu". Such tensors are also added to
      ``lr_graph.ignore_tensors``.
    """
    wrong_memory = tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set
    if wrong_memory or tens in lr_graph.ignore_tensors:
        return True
    if not tens.name.endswith("reshape_shape_npu"):
        return False
    # Reshape shape tensor: nothing to allocate, and remember it for later lookups
    lr_graph.ignore_tensors.add(tens)
    return True
168
169
170# Tries merging of ifm/ofm live ranges for memory only ops and elementwise ops
171def merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set):
172 for ps in sg.passes:
173 if ps.placement == PassPlacement.MemoryOnly:
174 # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
175 input_tensor = ps.inputs[0]
176 output_tensor = ps.outputs[0]
177 if not tensor_should_be_ignored(lr_graph, input_tensor, target_mem_area, target_mem_type_set) and not (
178 tensor_should_be_ignored(lr_graph, output_tensor, target_mem_area, target_mem_type_set)
179 ):
180 lr_graph.fuse_ranges(input_tensor, output_tensor)
181 elif ps.is_element_wise:
182 merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set)
183
184
# Tries to merge the ifm/ofm live ranges of an elementwise op
def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set):
    """Let the OFM of the pass's elementwise op share a live range with one of its inputs.

    The pass is expected to contain at most one elementwise op (asserted). Nothing
    is fused when:
    - there is no elementwise op;
    - its OFM is ignored for this memory target; or
    - the op is SHL/SHR.

    Otherwise the first of ifm, ifm2 that passes every check below is fused with
    the OFM, and the search stops.
    """
    elem_op = None
    for op in ps.ops:
        if op.type.is_elementwise_op():
            assert elem_op is None
            elem_op = op

    if elem_op is None or tensor_should_be_ignored(lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set):
        return
    # Overwriting the inputs in place is never attempted for shift ops
    if elem_op.type in (Op.SHL, Op.SHR):
        return

    ofm = elem_op.ofm
    # Pair each input with its own slot in elem_op.ifm_shapes (ifm -> 0, ifm2 -> 1).
    # Indexing by position among the *filtered* inputs compared ifm2 against ifm's
    # shape whenever ifm was filtered out.
    for slot, inp in ((0, elem_op.ifm), (1, elem_op.ifm2)):
        if (
            inp is not None
            and inp.shape != []
            # The input itself must live in the memory being allocated
            and inp.mem_area == target_mem_area
            and inp.mem_type in target_mem_type_set
            # Same format and dtype, and no broadcasting
            and inp.format == ofm.format
            and inp.dtype == ofm.dtype
            and elem_op.ifm_shapes[slot] == elem_op.ofm_shapes[0]
            # Input must have a single consumer (and a single entry in inp.ops)
            and len(inp.consumer_list) == 1
            and len(inp.ops) == 1
        ):
            lr_graph.fuse_ranges(inp, ofm)
            break
225
226
def extract_live_ranges_from_passes(
    sg, target_mem_area, target_mem_type_set=None, ignore_subgraph_input_output_tensors=False,
):
    """Build a new LiveRangeGraph for a pass-packed subgraph.

    Pass ``idx`` is stamped with ``ps.time = 2 * idx``. Every input, intermediate
    and output tensor of a pass is marked live at that time. Subgraph outputs are
    additionally marked live at ``2 * len(sg.passes)``, after the last pass.
    ``target_mem_type_set`` defaults to {Scratch, Scratch_fast}.
    """
    lr_graph = LiveRangeGraph()

    if ignore_subgraph_input_output_tensors:
        for io_tensors in (sg.input_tensors, sg.output_tensors):
            lr_graph.ignore_tensors.update(io_tensors)

    if target_mem_type_set is None:
        target_mem_type_set = {MemType.Scratch, MemType.Scratch_fast}

    def _track(tens, time):
        # Create/extend the tensor's range unless it is ignored for this target
        if not tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
            lr_graph.get_or_create_range(tens).mark_usage(time)

    # Try to merge live ranges of operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)

    for idx, ps in enumerate(sg.passes):
        pass_time = 2 * idx
        ps.time = pass_time
        for tens in ps.inputs + ps.intermediates + ps.outputs:
            _track(tens, pass_time)

    final_time = 2 * len(sg.passes)
    for tens in sg.output_tensors:
        _track(tens, final_time)

    return lr_graph
262
263
def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    target_mem_type_set,
    ignore_subgraph_input_output_tensors=False,
    lr_graph=None,
    cpu_tensor_alignment=Tensor.AllocationQuantum,
):
    """Add the live ranges of a cascaded-pass subgraph to ``lr_graph`` (new if None).

    Each cascaded pass gets ``lr_graph.current_time`` as its time, and the clock then
    advances by 2. When a pass's primary op is an Op.CustomNpuOp (and
    MemType.Permanent_CPU is not targeted), the called NPU subgraph is processed
    first. Its intermediates and outputs are then marked at the clock value after
    that subgraph. A subgraph already in ``lr_graph.processed_subgraphs`` is skipped.
    """
    if lr_graph is None:
        lr_graph = LiveRangeGraph()

    # A subgraph is only walked once, even if reached from several places
    if sg in lr_graph.processed_subgraphs:
        return lr_graph

    if ignore_subgraph_input_output_tensors:
        lr_graph.ignore_tensors.update(sg.input_tensors)
        lr_graph.ignore_tensors.update(sg.output_tensors)

    # Try to merge live ranges of operations in the NPU subgraphs
    if sg.placement == PassPlacement.Npu:
        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)

    def _mark(tensors, time):
        # CPU-side tensors use cpu_tensor_alignment
        for tens in tensors:
            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
                continue
            lr_graph.get_or_create_range(tens, cpu_tensor_alignment).mark_usage(time)

    for cps in sg.cascaded_passes:
        cps.time = lr_graph.current_time
        _mark(cps.inputs, cps.time)

        primary_op = cps.passes[0].primary_op
        calls_npu_subgraph = (
            primary_op and primary_op.type == Op.CustomNpuOp and MemType.Permanent_CPU not in target_mem_type_set
        )
        if calls_npu_subgraph:
            # This is where an NPU subgraph is called: extract its live ranges before
            # continuing. NPU tensors get the default alignment, not cpu_tensor_alignment.
            npu_sg = primary_op.attrs["subgraph"]
            lr_graph = _extract_live_ranges_from_schedule(npu_sg, target_mem_area, target_mem_type_set, lr_graph)
            # The NPU subgraph advanced the clock; re-time this pass after it
            cps.time = lr_graph.current_time

        _mark(cps.intermediates + cps.outputs, cps.time)
        lr_graph.current_time += 2

    # Subgraph outputs stay live until the latest end time of any range (at least 0)
    final_time = max([0] + [rng.end_time for rng in lr_graph.ranges.values()])
    _mark(sg.output_tensors, final_time)

    lr_graph.processed_subgraphs.add(sg)
    return lr_graph
Tim Halld8339a72021-05-27 18:49:40 +0100336
337
def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_graph):
    """Mark every relevant tensor of ``sg`` live at one shared time slot.

    The slot is ``lr_graph.current_time``. Pass tensors with purpose Weights are
    skipped. Each op's ``npu_weights_tensor`` from the schedule cost map is marked
    instead. The clock then advances by 1.
    """
    assert lr_graph is not None
    now = lr_graph.current_time

    def _wanted(tens):
        return not tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set)

    for ps in sg.passes:
        for tens in ps.inputs + ps.outputs + ps.intermediates:
            if tens.purpose != TensorPurpose.Weights and _wanted(tens):
                lr_graph.get_or_create_range(tens).mark_usage(now)

    for op_info in sg.schedule.cost_map.values():
        weights = op_info.npu_weights_tensor
        if weights and _wanted(weights):
            lr_graph.get_or_create_range(weights).mark_usage(now)

    lr_graph.current_time += 1
    return lr_graph
360
361
def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
    """Add the live ranges of a scheduled NPU subgraph to ``lr_graph``.

    Each scheduled op is stamped with ``op_info.time_index``. Ops sharing a
    non-zero cascade id reuse the time of the first op seen in that cascade; the
    clock only advances (by 2) when an op starts a new time slot. Subgraph outputs
    are marked at the value returned by ``lr_graph.update_endtime()``.
    """

    def _in_target_memory(tens):
        return tens.mem_type in target_mem_type_set and tens.mem_area == target_mem_area

    cascade_time = {}
    for sched_op in sg.sched_ops:
        op_info = sg.schedule.cost_map[sched_op]
        cascade = op_info.cascade
        cascade_info = sg.schedule.cascades.get(cascade, None)

        op_time = cascade_time.get(cascade, lr_graph.current_time)
        op_info.time_index = op_time

        # Mark usage for all tensors related to this Pass
        ps = sched_op.parent_ps
        for tens in ps.inputs + ps.outputs + ps.intermediates:
            is_rolling_buffer = (
                target_mem_area == MemArea.Sram
                and cascade_info
                and tens == ps.ifm_tensor
                and sched_op in cascade_info.buffers
            )
            if is_rolling_buffer:
                # Rolling buffer inside a cascade: resize the LiveRange (not the Tensor)
                # so temporal memory snapshots reflect the real buffer size
                rng = lr_graph.get_or_create_range(tens)
                buffer_shape = cascade_info.buffers[sched_op]
                rng.set_buffer_size(buffer_shape.elements() * sched_op.ifm.dtype.size_in_bytes())
            elif tens.purpose in (TensorPurpose.Weights, TensorPurpose.FSBias) or not _in_target_memory(tens):
                continue
            else:
                rng = lr_graph.get_or_create_range(tens)
            rng.mark_usage(op_time)

        weight_tens = op_info.buffered_weight_tensor
        if weight_tens and _in_target_memory(weight_tens):
            rng = lr_graph.get_or_create_range(weight_tens)
            if weight_tens.pre_buffer:
                # Pre-buffered weights are live from one slot before the op
                # (presumably so the fetch overlaps the previous op - confirm)
                rng.mark_usage(op_time - 1, 2)
            else:
                rng.mark_usage(op_time)

        # Advance the clock only when this op opened a new time slot
        if op_time == lr_graph.current_time:
            lr_graph.current_time += 2

        # Cascade id 0 is never shared; other cascade ids reuse this op's time
        if cascade != 0:
            cascade_time[cascade] = op_time

    end_time = lr_graph.update_endtime()
    for tens in sg.output_tensors:
        if _in_target_memory(tens):
            lr_graph.get_or_create_range(tens).mark_usage(end_time)

    return lr_graph