MLBEDSW-4034: New Scheduler Size or Performance Optimisation

 - Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b
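 - live_range.py: LiveRange now records its mem_area, mark_usage() takes an
   op_length and clamps negative start times, and set_buffer_size() lets a
   cascade's rolling buffer report its SRAM footprint; LiveRangeGraph gains
   update_endtime() and get_temporal_memory_usage() for per-time-step memory
   snapshots, and NPU subgraph live ranges are now extracted with the new
   _extract_live_ranges_from_schedule() and create_linear_live_range_graph()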

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index de001e5..d75a167 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -18,10 +18,14 @@
 # Can work with either a pass packed subgraph or a scheduled subgraph.
 from typing import List
 
+import numpy as np
+
 from .nn_graph import PassPlacement
 from .operation import Op
+from .tensor import MemArea
 from .tensor import MemType
 from .tensor import Tensor
+from .tensor import TensorPurpose
 
 
 class LiveRange:
@@ -32,6 +36,7 @@
         self.size = 0
         self.name = ""
         self.alignment = alignment
+        self.mem_area = tens.mem_area if tens else MemArea.Unknown
 
         if tens:
             self.add_tensor(tens)
@@ -52,15 +57,19 @@
 
         self.tensors.append(tens)
 
-    def mark_usage(self, op_time):
-        if op_time == -1:
+    def mark_usage(self, op_time, op_length=1):
+        op_time_start = max(op_time, 0)
+        op_time_end = op_time + op_length
+        if op_time_end <= op_time_start:
             return
-        op_time_start = op_time
-        op_time_end = op_time + 1
 
         self.start_time = min(self.start_time, op_time_start)
         self.end_time = max(self.end_time, op_time_end)
 
+    def set_buffer_size(self, buffer_size):
+        self.size = buffer_size
+        self.mem_area = MemArea.Sram
+
     def overlaps_ranges(self, other):
         return max(self.start_time, other.start_time) < min(self.end_time, other.end_time)
 
@@ -106,6 +115,7 @@
         self.ignore_tensors = set()
         self.processed_subgraphs = set()
         self.current_time = 0
+        self.end_time = None
 
     def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
         # Return the live range of the tensor (or any of its clones)
@@ -127,6 +137,23 @@
         self.ranges[out_tens] = live_range
         return live_range
 
+    def update_endtime(self):
+        self.end_time = 0
+        for rng in self.ranges.values():
+            self.end_time = max(self.end_time, rng.end_time)
+        return self.end_time + 1
+
+    def get_temporal_memory_usage(self, target_mem_area):
+        if not self.end_time:
+            self.update_endtime()
+        usage = np.zeros(self.end_time, dtype=np.int32)
+        for rng in self.ranges.values():
+            if rng.mem_area == target_mem_area:
+                # End time is inclusive
+                usage[rng.start_time : rng.end_time + 1] += rng.size
+
+        return usage
+
 
 def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
     if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
@@ -279,9 +306,7 @@
             # is called. Go into said subgraph and extract live ranges before continuing.
             # Use default allocation alignment of 16 for Npu tensors
             npu_sg = cps_primary_op.attrs["subgraph"]
-            lr_graph = extract_live_ranges_from_cascaded_passes(
-                npu_sg, target_mem_area, target_mem_type_set, False, lr_graph,
-            )
+            lr_graph = _extract_live_ranges_from_schedule(npu_sg, target_mem_area, target_mem_type_set, lr_graph)
             # Set the new time after handling the Npu subgraph
             time_for_pass = lr_graph.current_time
             cps.time = time_for_pass
@@ -308,3 +333,89 @@
     # Add subgraph to set of processed subgraphs
     lr_graph.processed_subgraphs.add(sg)
     return lr_graph
+
+
+def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_graph):
+    assert lr_graph is not None
+    sg_time = lr_graph.current_time
+    for ps in sg.passes:
+        for tens in ps.inputs + ps.outputs + ps.intermediates:
+            if tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
+                lr_graph, tens, target_mem_area, target_mem_type_set
+            ):
+                continue
+
+            rng = lr_graph.get_or_create_range(tens)
+            rng.mark_usage(sg_time)
+
+    for sched_op, op_info in sg.schedule.cost_map.items():
+        if op_info.npu_weights_tensor and not (
+            tensor_should_be_ignored(lr_graph, op_info.npu_weights_tensor, target_mem_area, target_mem_type_set)
+        ):
+            rng = lr_graph.get_or_create_range(op_info.npu_weights_tensor)
+            rng.mark_usage(sg_time)
+
+    lr_graph.current_time += 1
+    return lr_graph
+
+
+def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
+    time_for_cascade = {}
+    for sched_op in sg.sched_ops:
+        op_info = sg.schedule.cost_map[sched_op]
+        cascade = op_info.cascade
+        cascade_info = sg.schedule.cascades.get(cascade, None)
+
+        time_to_set = time_for_cascade.get(cascade, lr_graph.current_time)
+
+        op_info.time_index = time_to_set
+
+        # Mark usage for all tensors related to this Pass
+        ps = sched_op.parent_ps
+        for tens in ps.inputs + ps.outputs + ps.intermediates:
+            if (
+                target_mem_area == MemArea.Sram
+                and cascade_info
+                and tens == ps.ifm_tensor
+                and sched_op in cascade_info.buffers
+            ):
+                # This tensor is a rolling buffer in a cascade and the size of the LiveRange needs to be modified
+                # for enabling temporal memory snapshots without modifying the original Tensor
+                rng = lr_graph.get_or_create_range(tens)
+                rng.set_buffer_size(cascade_info.buffers[sched_op].elements() * sched_op.ifm.dtype.size_in_bytes())
+            elif (
+                tens.purpose == TensorPurpose.Weights
+                or tens.purpose == TensorPurpose.FSBias
+                or tens.mem_type not in target_mem_type_set
+                or tens.mem_area != target_mem_area
+            ):
+                continue
+
+            else:
+                rng = lr_graph.get_or_create_range(tens)
+
+            rng.mark_usage(time_to_set)
+
+        weight_tens = op_info.buffered_weight_tensor
+        if weight_tens and weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
+            rng = lr_graph.get_or_create_range(weight_tens)
+            if weight_tens.pre_buffer:
+                rng.mark_usage(time_to_set - 1, 2)
+            else:
+                rng.mark_usage(time_to_set)
+
+        if time_to_set == lr_graph.current_time:
+            lr_graph.current_time += 2
+
+        if cascade != 0:
+            time_for_cascade[cascade] = time_to_set
+
+    end_time = lr_graph.update_endtime()
+
+    for tens in sg.output_tensors:
+        if tens.mem_type not in target_mem_type_set or tens.mem_area != target_mem_area:
+            continue
+        rng = lr_graph.get_or_create_range(tens)
+        rng.mark_usage(end_time)
+
+    return lr_graph
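
Below is a minimal usage sketch, not part of the patch, showing how the
LiveRange / LiveRangeGraph additions above can be driven by hand to obtain a
per-time-step SRAM snapshot. The byte sizes, time steps and dictionary keys
are invented for illustration, and the ranges are registered directly in
lr_graph.ranges rather than through get_or_create_range(), which expects real
Tensor objects.

    # Minimal sketch: exercise the new LiveRange/LiveRangeGraph API by hand.
    from ethosu.vela.live_range import LiveRange, LiveRangeGraph
    from ethosu.vela.tensor import MemArea

    lr_graph = LiveRangeGraph()

    # A cascade's rolling buffer: no backing Tensor, an explicit byte size
    # (set_buffer_size places the range in MemArea.Sram), and a lifetime
    # spanning two scheduling time steps.
    rolling_buffer = LiveRange(None, 16)
    rolling_buffer.set_buffer_size(4096)
    rolling_buffer.mark_usage(0, op_length=2)

    # A plain feature map that only becomes live at time step 1.
    ofm = LiveRange(None, 16)
    ofm.set_buffer_size(8192)
    ofm.mark_usage(1)

    # Register the ranges directly for the sketch (get_or_create_range needs
    # real Tensor objects), then query the temporal SRAM usage.
    lr_graph.ranges["rolling_buffer"] = rolling_buffer
    lr_graph.ranges["ofm"] = ofm

    lr_graph.update_endtime()
    usage = lr_graph.get_temporal_memory_usage(MemArea.Sram)
    print("SRAM usage per time step:", usage, "peak:", usage.max())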