MLBEDSW-6296: Regression caused by bigger weight buffering size

- Because bigger weight buffer sizes are now used, there are use cases
where feature maps are evicted from SRAM, causing total performance to drop.
- A way to improve this is to limit the memory for those weight buffering ops
so that the feature maps fit back into SRAM, and then check whether total
performance improves (sketched below).
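- Roughly, the new pass works as in the hedged pseudo-Python below (helper
names such as measure_cycles, overlapping_weight_ops and restore are
illustrative only, not the real scheduler.py API):

    # Sketch of the flow added in optimize_weight_buffering_size();
    # evicted_fms is assumed to come from use_fast_storage_for_feature_maps()
    default_total, default_dram = measure_cycles(default_schedule)
    for fm_lr in evicted_fms:
        for sched_op in overlapping_weight_ops(fm_lr):
            # Ask the op to shrink its weight buffer by the evicted fm size
            sched_op.evicted_fms_size += fm_lr.size
    new_schedule = propose_schedule_buffering(min_schedule, sram_limit)
    new_total, new_dram = measure_cycles(new_schedule)
    if not (new_total < default_total and new_dram < default_dram):
        # No gain on both metrics: roll back to the default schedule
        restore(default_schedule)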

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: Ibfaff330677185186af9f6362dfbe04824a329f6
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index e73a26d..3cfde28 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -24,6 +24,7 @@
 from collections import namedtuple
 from enum import auto
 from enum import IntEnum
+from typing import Any
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -216,6 +217,8 @@
         self.requires_full_ifm2 = False
         self.requires_full_ofm = False
 
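+        # Total size of fms evicted from fast storage that overlap this
+        # op's weight buffer; when set, it limits the buffer size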
+        self.evicted_fms_size = 0
+
         self.index = 0
 
     def add_ifm_connection(self, conn: "Connection"):
@@ -374,6 +377,9 @@
         self.max_schedule: Optional[Schedule] = None
         self.scheduler_options = options
 
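+        # Default mem area/mem type for fms forced to fast storage, and
+        # the live ranges of fms evicted from fast storage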
+        self.scratched_fms: Dict[Tensor, Any] = {}
+        self.evicted_fms: List[live_range.LiveRange] = []
+
     def avoid_nhcwb16_for_ofm(self, tens, ps, arch):
         # Only run this check for opt strategy Size
         if self.scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
@@ -579,6 +585,14 @@
 
         # Attempt weight buffering on anything with a weights tensor
         if sched_op.parent_op.weights:
+            buffer_limit_bytes = cost.slack_buffering_memory
+
+            # If applicable, apply a size limitation, but keep it within reason:
+            # only limit when the buffer is at least 1.5x the evicted size.
+            # The limitation is used when use_fast_storage_for_feature_maps has
+            # detected fms that do not fit in fast storage.
+            if sched_op.evicted_fms_size and ((buffer_limit_bytes / sched_op.evicted_fms_size) >= 1.5):
+                buffer_limit_bytes -= sched_op.evicted_fms_size
+
             self.propose_weight_buffering(
                 sched_op.parent_op.weights,
                 sched_op.parent_op.bias,
@@ -586,7 +600,7 @@
                 prev_op,
                 buffered_schedule,
                 ref_schedule,
-                cost.slack_buffering_memory,
+                buffer_limit_bytes,
             )
 
         return cost
@@ -966,6 +980,97 @@
         optimized_sched.cascades = schedule.cascades
         return optimized_sched
 
+    def optimize_weight_buffering_size(
+        self,
+        min_schedule: Schedule,
+        options: SchedulerOptions,
+    ):
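+        """Limit the weight buffer size of ops whose buffering overlaps
+        evicted fms, re-run schedule buffering, and keep the new schedule
+        only if both total and DRAM access cycles improve; otherwise the
+        default schedule is restored."""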
+        default_schedule = self.sg.schedule
+        npu_performance.calc_new_performance_for_network(self.nng, self.arch)
+        default_tot_cycles = self.nng.cycles[npu_performance.PassCycles.Total]
+        default_dram_cycles = self.nng.cycles[npu_performance.PassCycles.DramAccess]
+
+        # Restore default mem area/mem type for scratched_fms
+        for tens in self.scratched_fms:
+            tens.mem_area = self.scratched_fms[tens][0]
+            tens.mem_type = self.scratched_fms[tens][1]
+
+        self.update_op_memory_snapshot(self.sg.schedule)
+
+        # Collect live ranges from tensors
+        memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
+        lr_graph = live_range.LiveRangeGraph()
+        for mem_area, mem_type_set in memories_list:
+            live_range.extract_live_ranges_from_cascaded_passes(
+                self.nng.get_root_subgraph(),
+                mem_area,
+                mem_type_set,
+                lr_graph,
+                Tensor.AllocationQuantum,
+            )
+
+        # Map each buffered weight tensor to its sched_op
+        weight_ops = {}
+        for sched_op in self.sched_ops:
+            cost = self.sg.schedule.cost_map[sched_op]
+            if cost.buffered_weight_tensor:
+                weight_ops[cost.buffered_weight_tensor] = sched_op
+
+        # Collect the live ranges of the weight buffer tensors
+        weight_lrs = []
+        for lr in lr_graph.lrs:
+            for tens in lr.tensors:
+                if weight_ops.get(tens):
+                    weight_lrs.append(lr)
+                    break
+
+        # See if any evicted fm overlaps with a weight buffering op.
+        # If so, add a size limitation to that buffering op
+        for lr in self.evicted_fms:
+            for weight_lr in weight_lrs:
+                if lr.overlaps_ranges(weight_lr):
+                    for tens in weight_lr.tensors:
+                        sched_op = weight_ops.get(tens)
+                        if sched_op:
+                            # Add size reduction to the op
+                            sched_op.evicted_fms_size += lr.size
+                            break
+
+        self.sg.schedule = min_schedule
+        self.update_op_memory_snapshot(self.sg.schedule)
+
+        # Run schedule buffering with the weight buffer size reduction applied
+        schedule = self.propose_schedule_buffering(self.sg.schedule, options.optimization_sram_limit)
+        schedule.cascades = self.sg.schedule.cascades
+        self.sg.schedule = schedule
+
+        # Apply new buffer schedule and calc new performance
+        self.update_op_memory_snapshot(self.sg.schedule)
+        self.apply_schedule(self.sg.schedule)
+        self.use_fast_storage_for_feature_maps(self.sg.schedule, options.optimization_sram_limit)
+
+        npu_performance.calc_new_performance_for_network(self.nng, self.arch)
+        new_tot_cycles = self.nng.cycles[npu_performance.PassCycles.Total]
+        new_dram_cycles = self.nng.cycles[npu_performance.PassCycles.DramAccess]
+
+        improvement_tot = round((default_tot_cycles - new_tot_cycles) / default_tot_cycles, 2)
+        improvement_dram = round((default_dram_cycles - new_dram_cycles) / default_dram_cycles, 2)
+
+        # Compare both total and dram improvement
+        if not (improvement_tot > 0 and improvement_dram > 0):
+            # No improvement, restore the default schedule
+            for sched_op in self.sched_ops:
+                sched_op.evicted_fms_size = 0
+
+            for tens in self.scratched_fms:
+                tens.mem_area = self.scratched_fms[tens][0]
+                tens.mem_type = self.scratched_fms[tens][1]
+
+            self.sg.schedule = default_schedule
+            self.update_op_memory_snapshot(self.sg.schedule)
+            self.apply_schedule(self.sg.schedule)
+            self.use_fast_storage_for_feature_maps(self.sg.schedule, options.optimization_sram_limit)
+
     def apply_schedule(self, sched: Schedule):
         """Applies the given schedule as a final solution"""
         for sched_op in self.sched_ops:
@@ -987,11 +1092,11 @@
                 op_info.buffered_weight_tensor.src_tensor = op_info.npu_weights_tensor
 
     def use_fast_storage_for_feature_maps(self, schedule, staging_limit):
-        scratched_fms = {}
         max_mem_usage = []
         base_mem_usage = []
         fast_storage_type = MemType.Scratch_fast
         fast_storage_mem_area = self.arch.fast_storage_mem_area
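+        # Reset the list of evicted fms; it is rebuilt below for live
+        # ranges that can never fit in fast storage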
+        self.evicted_fms = []
 
         # Force all OFMs to fast-storage
         for sched_op in self.sched_ops:
@@ -999,8 +1104,10 @@
             if cost.cascade == 0 and sched_op.get_dependants():
                 ofm_tens = sched_op.ofm.connection.parent_tens
                 if not any(cons is None for cons in ofm_tens.consumer_list):
-                    if ofm_tens not in scratched_fms:
-                        scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
+                    if ofm_tens not in self.scratched_fms:
+                        # Remember default mem area and mem type, only done once
+                        self.scratched_fms[ofm_tens] = (ofm_tens.mem_area, ofm_tens.mem_type)
+
                     ofm_tens.mem_area = fast_storage_mem_area
                     ofm_tens.mem_type = fast_storage_type
 
@@ -1027,17 +1134,17 @@
         curr_lrs = []
         for lr in lr_graph.lrs:
             for tens in lr.tensors:
-                if scratched_fms.get(tens):
+                if self.scratched_fms.get(tens):
                     curr_lrs.append(lr)
                     base_mem_usage[lr.start_time : lr.end_time + 1] -= lr.size
                     break
-
         competing_lrs = []
         for lr in curr_lrs:
             base_usage = max(base_mem_usage[lr.start_time : lr.end_time + 1])
             # If true, the lr will never fit and may thus be evicted
             if base_usage + lr.size > staging_limit:
-                FastStorageComponentAllocator.evict(lr, max_mem_usage, scratched_fms)
+                self.evicted_fms.append(lr)
+                FastStorageComponentAllocator.evict(lr, max_mem_usage, self.scratched_fms)
                 continue
             # Since max_mem_usage is the memory usage with all FMs still in fast-storage,
             # the memory limit cannot be exceeded if max_mem_usage does not.
@@ -1068,13 +1175,18 @@
                     max_mem_usage,
                     base_mem_usage,
                     staging_limit,
-                    scratched_fms,
+                    self.scratched_fms,
                 )
                 start = i
                 start_time = lr.start_time
                 end_time = lr.end_time
         component_allocator.allocate_component(
-            component_allocator, competing_lrs[start:sz], max_mem_usage, base_mem_usage, staging_limit, scratched_fms
+            component_allocator,
+            competing_lrs[start:sz],
+            max_mem_usage,
+            base_mem_usage,
+            staging_limit,
+            self.scratched_fms,
         )
 
     def move_constant_data(self):
@@ -1327,6 +1439,11 @@
             scheduler.apply_schedule(sg.schedule)
             scheduler.use_fast_storage_for_feature_maps(sg.schedule, scheduler_options.optimization_sram_limit)
 
+            if scheduler_options.optimization_strategy == OptimizationStrategy.Performance and scheduler.evicted_fms:
+                # It might be possible to gain performance by reducing
+                # weight buffer sizes so that fms instead fit in fast storage
+                scheduler.optimize_weight_buffering_size(min_schedule, scheduler_options)
+
             if scheduler_options.verbose_schedule:
                 scheduler.print_schedule(sg.schedule)