MLBEDSW-7062: Clean up and add comments to scheduler

Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I026facce572ddce4249e05529f2bb1d285552ab9
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index dd66ec8..4e0599e 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -831,6 +831,8 @@
                 )
                 cost.buffered_weight_tensors.append(buf2)
 
+            # Note! OFM depth slices define slices as [0, s1, ... sn]. For example, [0, 70, 140] describes two slices
+            # (0-70 and 70-140) but has a length of 3, which would result in idx = 3 % 2 = 1 if two buffers were used.
             last_used_buffer_idx = len(cost.ofm_depth_slices) % len(cost.buffered_weight_tensors)
             weight_buffer_size = encoded_weights.double_buffer_sizes[last_used_buffer_idx]
 
@@ -986,7 +988,7 @@
             non_local_mem_usage[sched_op] = min_schedule.memory_snapshot[time_index] - op_mem_usage
             assert non_local_mem_usage[sched_op] >= 0
 
-        # Crate cascades for Min schedule
+        # Create cascades for Min schedule
         cascade_builder = CascadeBuilder(self.sched_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)
         cascade_builder.build_cascades(min_schedule, max_template, memory_limit)
 
@@ -1053,17 +1055,14 @@
             )
 
             cascade_builder.build_cascades(proposed_schedule, max_template, memory_limit)
-
-            # Check if proposal fits
-            proposed_schedule_mem_usage = self.estimate_schedule_memory_usage(proposed_schedule, non_local_mem_usage)
-
             nbr_of_cascades = len(proposed_schedule.cascades)
-
             if iteration == 0:
                 # First iteration - used as limit to prevent splitting up the cascades
                 # Long cascades are better in order to reduce IFM/IFM dram bandwidth
                 max_nbr_of_cascades = nbr_of_cascades
 
+            # Check if proposal fits
+            proposed_schedule_mem_usage = self.estimate_schedule_memory_usage(proposed_schedule, non_local_mem_usage)
             if (proposed_schedule_mem_usage) <= memory_limit and nbr_of_cascades <= max_nbr_of_cascades:
                 best_schedule = proposed_schedule
 
@@ -1224,6 +1223,8 @@
                 tens.src_tensor = op_info.npu_weights_tensor
 
     def use_fast_storage_for_feature_maps(self, schedule, staging_limit):
+        """Finds the set of feature maps that fits within the staging limit and, combined, has the largest amount of
+        access cycles, and moves those feature maps into fast storage"""
         max_mem_usage = []
         base_mem_usage = []
         fast_storage_type = MemType.Scratch_fast
@@ -1256,12 +1257,11 @@
             )
         max_mem_usage = lr_graph.get_temporal_memory_usage(fast_storage_mem_area)
 
-        # If true, everything fits and we can proceed
+        # If max_mem_usage does not exceed the staging limit at any point, all lrs fit and can stay in fast storage
         if max(max_mem_usage) <= staging_limit:
             return
 
-        # Build up the base memory usage by removing the
-        # mem_usage of the lrs we previously moved to fast-storage
+        # Build up the base memory usage by removing the mem_usage of the lrs we previously moved to fast-storage
         base_mem_usage = np.array(max_mem_usage)
         curr_lrs = []
         for lr in lr_graph.lrs:
@@ -1272,26 +1272,30 @@
                     break
         competing_lrs = []
         competing_tens_access = {}
-        for lr in curr_lrs:
+
+        # Evict live ranges that will never fit
+        for lr in curr_lrs.copy():
             base_usage = max(base_mem_usage[lr.start_time : lr.end_time + 1])
-            # If true, the lr will never fit and may thus be evicted
             if base_usage + lr.size > staging_limit:
+                # Lr will never fit and may thus be evicted
                 self.evicted_fms.append(lr)
                 FastStorageComponentAllocator.evict(lr, max_mem_usage, self.scratched_fms)
-                continue
+                curr_lrs.remove(lr)
+
+        # Keep live ranges that will always fit in fast storage and let the remaining ones compete
+        for lr in curr_lrs:
             # Since max_mem_usage is the memory usage with all FMs still in fast-storage,
             # the memory limit cannot be exceeded if max_mem_usage does not.
             # Thus, the affected lrs can remain in fast-storage if the following is true
             if max(max_mem_usage[lr.start_time : lr.end_time + 1]) <= staging_limit:
-                FastStorageComponentAllocator.keep(lr, base_mem_usage, staging_limit)
+                FastStorageComponentAllocator.keep(lr, base_mem_usage)
             else:
                 competing_lrs.append(lr)
                 for tens in lr.tensors:
                     competing_tens_access[tens] = 0
 
-        competing_lrs_sz = len(competing_lrs)
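-        # All lrs and their tensors have been handled if competing_lrs_sz is zero, we may thus return
+        # All lrs and their tensors have been handled if there are no competing lrs, we may thus return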
-        if competing_lrs_sz == 0:
+        if len(competing_lrs) == 0:
             return
 
         # Estimate element access for all tensors that are competing for a place in fast-storage.
@@ -1314,6 +1318,7 @@
                 access = self.estimate_element_access(sched_op, cost.block_config, sched_op.ofm.shape.depth)
                 competing_tens_access[tens] += access.ofm_write
 
+        # Sort live ranges "from left to right" on the time axis to simplify checking overlapping ranges
         competing_lrs = sorted(competing_lrs, key=lambda lr: (lr.start_time, lr.end_time + 1, lr.size))
 
         # Remove lrs that have a live range that is too long compared to others.
@@ -1327,33 +1332,32 @@
         # Too long is currently decided to be (based on experience, analyzing many networks):
         # Compare lr at postion i with lr at position i + MAX_EXHAUSTIVE_ITEMS.
         # If end time differs by at least MAX_EXHAUSTIVE_LIFE_RANGE then do not include lr at position i.
-        if competing_lrs_sz > FastStorageComponentAllocator.MAX_EXHAUSTIVE_ITEMS:
-            # create a copy of the original list to iterate over because the original version is modified in-loop
+        if len(competing_lrs) > FastStorageComponentAllocator.MAX_EXHAUSTIVE_ITEMS:
+            # Create a copy of the original list to iterate over because the original version is modified in-loop
             competing_lrs_copy = competing_lrs.copy()
             for i, lr in enumerate(competing_lrs_copy):
                 lr_time = lr.end_time - lr.start_time
-                if lr_time < FastStorageComponentAllocator.MAX_EXHAUSTIVE_LIFE_RANGE:
-                    # Skip small ranges
-                    continue
+                # Only check live ranges longer than MAX_EXHAUSTIVE_LIFE_RANGE
+                if lr_time >= FastStorageComponentAllocator.MAX_EXHAUSTIVE_LIFE_RANGE:
+                    # Compare current lr with lr at position i + MAX_EXHAUSTIVE_ITEMS in the unmodified copy
+                    cmp_pos = min(i + FastStorageComponentAllocator.MAX_EXHAUSTIVE_ITEMS, len(competing_lrs_copy) - 1)
 
-                # Compare current lr with lr at position lr + MAX_EXHAUSTIVE_ITEMS
-                cmp_pos = min(i + FastStorageComponentAllocator.MAX_EXHAUSTIVE_ITEMS, competing_lrs_sz - 1)
+                    # Compare end times plus a margin of MAX_EXHAUSTIVE_LIFE_RANGE
+                    if (
+                        lr.end_time
+                        > competing_lrs_copy[cmp_pos].end_time + FastStorageComponentAllocator.MAX_EXHAUSTIVE_LIFE_RANGE
+                    ):
+                        # Current lr live time stands out, remove it. No use adding it to the
+                        # evicted_fms list since the lr should not be included in the fast storage allocation
+                        FastStorageComponentAllocator.evict(lr, max_mem_usage, self.scratched_fms)
+                        competing_lrs.remove(lr)
 
-                # Compare end times + plus a margin by MAX_EXHAUSTIVE_LIFE_RANGE
-                if (
-                    lr.end_time
-                    > competing_lrs_copy[cmp_pos].end_time + FastStorageComponentAllocator.MAX_EXHAUSTIVE_LIFE_RANGE
-                ):
-                    # Current lr live time stands out, remove it. No use adding it to the
-                    # evicted_fms list since the lr should not be included in the fast storage allocation
-                    FastStorageComponentAllocator.evict(lr, max_mem_usage, self.scratched_fms)
-                    competing_lrs.remove(lr)
-
+        # Split competing live ranges into components by finding disconnected groups of live ranges or components of
+        # max size MAX_EXHAUSTIVE_ITEMS
         start = 0
         end_time = competing_lrs[0].end_time
-        competing_lrs_sz = len(competing_lrs)
         component_allocator = FastStorageComponentAllocator(base_mem_usage, max_mem_usage, staging_limit)
-        # Build up components and then allocate each separately
+        component_ranges = []
         for i, lr in enumerate(competing_lrs):
             nbr_items = i - start
             if lr.start_time <= end_time and (nbr_items < FastStorageComponentAllocator.MAX_EXHAUSTIVE_ITEMS):
@@ -1361,32 +1365,26 @@
             else:
                 # Number items reached max items or current lr's start time
                 # does not overlap with previous lr's end time
-                component_allocator.allocate_component(
-                    component_allocator,
-                    competing_lrs[start:i],
-                    max_mem_usage,
-                    base_mem_usage,
-                    staging_limit,
-                    self.scratched_fms,
-                    competing_tens_access,
-                    self.evicted_fms,
-                )
+                component_ranges.append((start, i))
                 start = i
                 end_time = lr.end_time
-        component_allocator.allocate_component(
-            component_allocator,
-            competing_lrs[start:competing_lrs_sz],
-            max_mem_usage,
-            base_mem_usage,
-            staging_limit,
-            self.scratched_fms,
-            competing_tens_access,
-            self.evicted_fms,
-        )
+        component_ranges.append((start, len(competing_lrs)))
+
+        # Allocate each component separately
+        for start, end in component_ranges:
+            component_allocator.allocate_component(
+                competing_lrs[start:end],
+                max_mem_usage,
+                base_mem_usage,
+                self.scratched_fms,
+                competing_tens_access,
+                self.evicted_fms,
+            )
+        assert max(max_mem_usage) <= staging_limit, "Allocation exceeds staging limit"
 
     def move_constant_data(self):
-        """Determine if  data, can be moved from permanent storage to another memory area. A move
-        will generate a DMA command in the high-level command stream"""
+        """Determine if data can be moved from permanent storage to another memory area. A move will generate a DMA
+        command in the high-level command stream"""
         for sched_op in self.sched_ops:
             parent_op = sched_op.parent_op
             is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in parent_op.inputs)
@@ -1493,7 +1491,7 @@
 
     def __init__(self, base_mem_usage, max_mem_usage, staging_limit):
         self.base_mem_usage = base_mem_usage
-        self.max_mem_usage = list(max_mem_usage)
+        self.max_mem_usage = max_mem_usage
         self.staging_limit = staging_limit
         self.lrs = []
         self.evicted = []
@@ -1504,32 +1502,29 @@
 
     def allocate_exhaustive(self, ix, score):
         # Favour tensors with highest element access (score)
-        if ix >= len(self.lrs):
+        if ix >= self.num_lrs:
             if score > self.best_score:
                 self.best_score = score
                 self.evicted = self.curr_evicted.copy()
             return
 
         lr = self.lrs[ix]
-        for t in range(lr.start_time, lr.end_time):
-            assert self.base_mem_usage[t] <= self.max_mem_usage[t]
-        base_usage = max(self.base_mem_usage[lr.start_time : lr.end_time + 1])
-        can_fit = base_usage + lr.size <= self.staging_limit
-        always_fits = can_fit
 
+        # If adding the tensor size to the base mem usage doesn't exceed the staging limit anywhere on the lr time
+        # range, it can fit and the case where the tensor is included needs to be checked
+        can_fit = max(self.base_mem_usage[lr.start_time : lr.end_time + 1]) + lr.size <= self.staging_limit
         if can_fit:
-            max_usage = max(self.max_mem_usage[lr.start_time : lr.end_time + 1])
-            always_fits = max_usage <= self.staging_limit
-
-        if can_fit or always_fits:
+            # Tensor can fit, add tensor element access to the score and check case where tensor is included
             self.curr_evicted[ix] = False
             self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, True)
-            tens = lr.tensors[0]
-            # Tensor is being included - add tensor element access to the score
-            self.allocate_exhaustive(ix + 1, score + self.competing_tens_access[tens])
+            self.allocate_exhaustive(ix + 1, score + self.competing_tens_access[lr.tensors[0]])
             self.base_mem_usage = self.update_mem_usage(self.base_mem_usage, lr, False)
 
+        # If the max mem usage doesn't exceed the staging limit anywhere on the lr time range, it always fits and the
+        # case where the tensor is not included can be skipped
+        always_fits = max(self.max_mem_usage[lr.start_time : lr.end_time + 1]) <= self.staging_limit
         if not always_fits:
+            # Tensor doesn't always fit, check case when tensor is not included
             self.curr_evicted[ix] = True
             self.max_mem_usage = self.update_mem_usage(self.max_mem_usage, lr, False)
             self.allocate_exhaustive(ix + 1, score)
@@ -1537,9 +1532,9 @@
 
     @staticmethod
     def update_mem_usage(mem_usage, lr, increase):
+        size = lr.size if increase else -lr.size
         for t in range(lr.start_time, lr.end_time + 1):
-            mem_usage[t] += lr.size if increase else -lr.size
-            assert mem_usage[t] >= 0
+            mem_usage[t] += size
         return mem_usage
 
     @staticmethod
@@ -1552,38 +1547,34 @@
                 tens.mem_type = scratched_fms[tens][1]
 
     @staticmethod
-    def keep(lr, base_mem_usage, staging_limit):
+    def keep(lr, base_mem_usage):
         for t in range(lr.start_time, lr.end_time + 1):
             base_mem_usage[t] += lr.size
-            assert base_mem_usage[t] <= staging_limit
 
-    def allocate_component(
-        self, allocator, lrs, max_mem, min_mem, staging_limit, scratched_fms, competing_tens_access, evicted_fms
-    ):
-        sz = len(lrs)
-        allocator.lrs = lrs
-        allocator.evicted = [0] * len(lrs)
-        allocator.curr_evicted = [0] * sz
-        allocator.best_score = -1
-        allocator.competing_tens_access = competing_tens_access
+    def allocate_component(self, lrs, max_mem, min_mem, scratched_fms, competing_tens_access, evicted_fms):
+        self.lrs = lrs
+        self.num_lrs = len(lrs)
+        self.evicted = [0] * self.num_lrs
+        self.curr_evicted = [0] * self.num_lrs
+        self.best_score = -1
+        self.competing_tens_access = competing_tens_access
         # Recursively evaluate all permutations of allocations of the lrs found in the component.
         # For every permutation that fits within the staging_limit there is a score calculated.
         # The permutation with the highest score will then be chosen. The score is calculated
         # as the sum of the actual element access (ifm read and ofm write) for all the
         # including tensors. So it is not necessary the tensor with the biggest size that ends up
         # being included in the result.
-        allocator.allocate_exhaustive(0, 0)
-
+        self.allocate_exhaustive(0, 0)
         # Optimal allocation has been found, move lrs accordingly
-        for i, e in enumerate(allocator.evicted):
-            if e:
-                self.evict(lrs[i], max_mem, scratched_fms)
-                if lrs[i] not in evicted_fms:
-                    evicted_fms.append(lrs[i])
+        for i, lr in enumerate(self.lrs):
+            if self.evicted[i]:
+                self.evict(lr, max_mem, scratched_fms)
+                if lr not in evicted_fms:
+                    evicted_fms.append(lr)
             else:
-                self.keep(lrs[i], min_mem, staging_limit)
-                if lrs[i] in evicted_fms:
-                    evicted_fms.remove(lrs[i])
+                self.keep(lr, min_mem)
+                if lr in evicted_fms:
+                    evicted_fms.remove(lr)
 
 
 def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_options: SchedulerOptions):