MLBEDSW-5839: Port of improved spilling behaviour

Ported the improved spilling behaviour from Regor into Vela.
This reworks use_fast_storage_for_feature_maps and introduces
the new FastStorageComponentAllocator class.

Signed-off-by: erik.andersson@arm.com <erik.andersson@arm.com>
Change-Id: I34785840c905a79750a62863773015b00fb43387
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 1aaaadd..fc94e9d 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -42,7 +42,10 @@
             self.add_tensor(tens)
 
     def __str__(self):
-        return "<live_range.LiveRange: '%s' start_time=%s, end_time=%s>" % (self.name, self.start_time, self.end_time)
+        return (
+            f"<live_range.LiveRange: {self.start_time}-{self.end_time}, "
+            f"size={self.size}, '{self.name}' #:{len(self.tensors)}>"
+        )
 
     __repr__ = __str__
 
@@ -142,10 +145,10 @@
 
     def get_temporal_memory_usage(self, target_mem_area):
         usage = np.zeros(self.update_endtime(), dtype=np.int32)
-        for rng in self.ranges.values():
-            if rng.mem_area == target_mem_area:
+        for lr in self.lrs:
+            if lr.mem_area == target_mem_area:
                 # End time is inclusive
-                usage[rng.start_time : rng.end_time + 1] += rng.size
+                usage[lr.start_time : lr.end_time + 1] += lr.size
 
         return usage
 
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 782e8d9..d160777 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -25,6 +25,8 @@
 from typing import Optional
 from typing import Tuple
 
+import numpy as np
+
 from . import live_range
 from . import npu_performance
 from . import tensor_allocation
@@ -899,43 +901,92 @@
             if op_info.buffered_weight_tensor:
                 op_info.buffered_weight_tensor.src_tensor = op_info.npu_weights_tensor
 
-    def use_fast_storage_for_feature_maps(self, schedule: Schedule, memory_limit: int):
-        if self.arch.fast_storage_mem_area == self.arch.feature_map_storage_mem_area:
-            return
+    def use_fast_storage_for_feature_maps(self, schedule, staging_limit):
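+        # scratched_fms records each feature map moved to fast storage together with
+        # its original (mem_area, mem_type), so that it can be restored if evicted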
+        scratched_fms = {}
+        max_mem_usage = []
+        base_mem_usage = []
+        fast_storage_type = MemType.Scratch_fast
+        fast_storage_mem_area = self.arch.fast_storage_mem_area
 
         # Force all OFMs to fast-storage
         for sched_op in self.sched_ops:
             cost = schedule.cost_map[sched_op]
-            if cost.cascade == 0:
-                if sched_op.get_dependants():
-                    ofm_tens = sched_op.ofm.connection.parent_tens
-                    if not any(cons is None for cons in ofm_tens.consumer_list):
-                        ofm_tens.mem_area = self.arch.fast_storage_mem_area
-                        ofm_tens.mem_type = MemType.Scratch_fast
+            if cost.cascade == 0 and sched_op.get_dependants():
+                ofm_tens = sched_op.ofm.connection.parent_tens
+                if not any(cons is None for cons in ofm_tens.consumer_list):
+                    if ofm_tens not in scratched_fms:
+                        scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
+                    ofm_tens.mem_area = fast_storage_mem_area
+                    ofm_tens.mem_type = fast_storage_type
 
         # Collect live ranges from tensors
-        memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
+        memories_list = [(fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
         lr_graph = live_range.LiveRangeGraph()
         for mem_area, mem_type_set in memories_list:
             live_range.extract_live_ranges_from_cascaded_passes(
                 self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum,
             )
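+        # Memory usage per time step with all candidate feature maps placed in fast storage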
+        max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area)
 
-        # Iterate over live ranges and evict tensors that doesn't fit
-        fast_storage_snapshot = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area)
+        # Everything fits in fast storage if the staging limit is never exceeded, so we are done
+        if max(max_mem_usage) <= staging_limit:
+            return
+
+        # Build up the base memory usage by removing the
+        # mem_usage of the lrs we previously moved to fast-storage
+        base_mem_usage = np.array(max_mem_usage)
+        curr_lrs = []
         for lr in lr_graph.lrs:
-            if (
-                lr.mem_area == self.arch.fast_storage_mem_area
-                and max(fast_storage_snapshot[lr.start_time : lr.end_time + 1]) > memory_limit
-            ):
-                # Evict tensor to DRAM
-                for tens in lr.tensors:
-                    if tens.purpose == TensorPurpose.FeatureMap and tens.sub_purpose == TensorSubPurpose.Standard:
-                        # Can only evict unbuffered FeatureMaps
-                        tens.mem_area = self.arch.feature_map_storage_mem_area
-                        tens.mem_type = MemType.Scratch
-                        # Adjust the snapshot
-                        fast_storage_snapshot[lr.start_time : lr.end_time + 1] -= lr.size
+            for tens in lr.tensors:
+                if scratched_fms.get(tens):
+                    curr_lrs.append(lr)
+                    base_mem_usage[lr.start_time : lr.end_time + 1] -= lr.size
+                    break
+
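+        # Classify the candidate lrs: evict those that can never fit, keep those that
+        # always fit, and let the component allocator decide for the rest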
+        competing_lrs = []
+        for lr in curr_lrs:
+            base_usage = max(base_mem_usage[lr.start_time : lr.end_time + 1])
+            # If the lr does not fit even on top of the base usage, it can never fit, so evict it
+            if base_usage + lr.size > staging_limit:
+                FastStorageComponentAllocator.evict(lr, max_mem_usage, scratched_fms)
+                continue
+            # Since max_mem_usage is the memory usage with all FMs still in fast-storage,
+            # the staging limit cannot be exceeded as long as max_mem_usage stays within it.
+            # Thus, the lr can remain in fast-storage if the following holds
+            if max(max_mem_usage[lr.start_time : lr.end_time + 1]) <= staging_limit:
+                FastStorageComponentAllocator.keep(lr, base_mem_usage, staging_limit)
+            else:
+                competing_lrs.append(lr)
+        sz = len(competing_lrs)
+        # If sz is zero, all lrs and their tensors have already been handled, so we can return
+        if sz == 0:
+            return
+
+        competing_lrs = sorted(competing_lrs, key=lambda lr: (lr.start_time, lr.end_time + 1, lr.size))
+        start = 0
+        start_time = competing_lrs[0].start_time
+        end_time = competing_lrs[0].end_time
+        component_allocator = FastStorageComponentAllocator(base_mem_usage, max_mem_usage, staging_limit)
+        # Build up components and then allocate each separately
+        for i, lr in enumerate(competing_lrs):
+            if lr.start_time <= end_time and i - start < component_allocator.max_exhaustive_size:
+                start_time = min(start_time, lr.start_time)
+                end_time = max(end_time, lr.end_time)
+            else:
+                component_allocator.allocate_component(
+                    component_allocator,
+                    competing_lrs[start:i],
+                    max_mem_usage,
+                    base_mem_usage,
+                    staging_limit,
+                    scratched_fms,
+                )
+                start = i
+                start_time = lr.start_time
+                end_time = lr.end_time
+        component_allocator.allocate_component(
+            component_allocator, competing_lrs[start:sz], max_mem_usage, base_mem_usage, staging_limit, scratched_fms
+        )
 
     def move_constant_data(self):
         """Determine if  data, can be moved from permanent storage to another memory area. A move
@@ -1039,6 +1090,87 @@
         )
 
 
+class FastStorageComponentAllocator:
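+    """Exhaustively decides which feature maps in a component of overlapping
+    live ranges are kept in fast storage and which are evicted, maximising
+    the total size kept within the staging limit"""
+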
+    def __init__(self, base_mem_usage, max_mem_usage, staging_limit):
+        self.base_mem_usage = base_mem_usage
+        self.max_mem_usage = list(max_mem_usage)
+        self.staging_limit = staging_limit
+        self.lrs = []
+        self.evicted = []
+        self.curr_evicted = []
+        self.remaining_total_size = []
+        self.best_allocated_size = 0
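+        # Upper bound on the number of lrs in a component, keeping the exhaustive search tractable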
+        self.max_exhaustive_size = 20
+
+    def allocate_exhaustive(self, ix, alloc_size):
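+        # Depth-first search over keep/evict decisions for each lr, recording the
+        # combination that maximises the total size allocated to fast storage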
+        if ix >= len(self.lrs):
+            if alloc_size > self.best_allocated_size:
+                self.best_allocated_size = alloc_size
+                self.evicted = self.curr_evicted
+            return
+
+        lr = self.lrs[ix]
+        for t in range(lr.start_time, lr.end_time):
+            assert self.base_mem_usage[t] <= self.max_mem_usage[t]
+        base_usage = max(self.base_mem_usage[lr.start_time : lr.end_time + 1])
+        can_fit = base_usage + lr.size <= self.staging_limit
+        always_fits = can_fit
+
+        if can_fit:
+            max_usage = max(self.max_mem_usage[lr.start_time : lr.end_time + 1])
+            always_fits = max_usage <= self.staging_limit
+
+        if can_fit or always_fits:
+            self.curr_evicted[ix] = False
+            self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, True)
+            self.allocate_exhaustive(ix + 1, alloc_size + lr.size)
+            self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, False)
+
+        if not always_fits:
+            self.curr_evicted[ix] = True
+            self.max_mem_usage = self.update_mem_usage(self.max_mem_usage, lr, False)
+            self.allocate_exhaustive(ix + 1, alloc_size)
+            self.max_mem_usage = self.update_mem_usage(self.max_mem_usage, lr, True)
+
+    @staticmethod
+    def update_mem_usage(mem_usage, lr, increase):
+        for t in range(lr.start_time, lr.end_time + 1):
+            mem_usage[t] += lr.size if increase else -lr.size
+            assert mem_usage[t] >= 0
+        return mem_usage
+
+    @staticmethod
+    def evict(lr, max_mem_usage, scratched_fms):
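+        # Restore the lr's feature maps to their original memory and remove the lr
+        # from the fast storage usage snapshot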
+        for t in range(lr.start_time, lr.end_time + 1):
+            max_mem_usage[t] -= lr.size
+        for tens in lr.tensors:
+            if tens in scratched_fms:
+                tens.mem_area = scratched_fms[tens][0]
+                tens.mem_type = scratched_fms[tens][1]
+
+    @staticmethod
+    def keep(lr, base_mem_usage, staging_limit):
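+        # Commit the lr to fast storage by adding its size to the base memory usage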
+        for t in range(lr.start_time, lr.end_time + 1):
+            base_mem_usage[t] += lr.size
+            assert base_mem_usage[t] <= staging_limit
+
+    def allocate_component(self, allocator, lrs, max_mem, min_mem, staging_limit, scratched_fms):
+        sz = len(lrs)
+        allocator.lrs = lrs
+        allocator.evicted = [0] * sz
+        allocator.curr_evicted = [0] * sz
+        allocator.best_allocated_size = -1
+        # Recursively evaluate all keep/evict combinations for the lrs in the component
+        allocator.allocate_exhaustive(0, 0)
+
+        # Optimal allocation has been found, move lrs accordingly
+        for i, e in enumerate(allocator.evicted):
+            if e:
+                self.evict(lrs[i], max_mem, scratched_fms)
+            else:
+                self.keep(lrs[i], min_mem, staging_limit)
+
+
 def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_options: SchedulerOptions):
     """Entry point for the Scheduler"""
     # Initialize CPU subgraphs