MLBEDSW-7147: Enable weight buffering when opt for Size

- When optimizing for Size the scheduler does not try to add weight
buffering to the schedule since this would add extra SRAM usage to
the peak usage. However, for all other ops that uses less SRAM than
the peak there is memory available that could be used for weight
buffering and hence improve the performance.

- Removed limitation to only run optimize schedule when optimizing
for Performance. Regardless of optimizing for Performance or Size the
scheduler flow is the same except that the limit for max SRAM usage is
different.

Change-Id: I6880b35655e37b4916a9c15150f0b8e5126a1cd8
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 2174a6e..4befad4 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -416,6 +416,9 @@
         self.max_schedule: Optional[Schedule] = None
         self.scheduler_options = options
 
+        # sram limit can be changed when scheduling for Size
+        self.sram_limit = options.optimization_sram_limit
+
         self.scratched_fms: Dict[Tensor, Any] = {}
         self.evicted_fms: List[live_range.LiveRange] = []
 
@@ -1045,9 +1048,7 @@
         cascade_builder = CascadeBuilder(sub_schedule_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)
 
         # Start by adding buffering
-        buffered_sub_schedule = self.propose_schedule_buffering(
-            sub_schedule, self.scheduler_options.optimization_sram_limit
-        )
+        buffered_sub_schedule = self.propose_schedule_buffering(sub_schedule, self.sram_limit)
         # Copy the cascades over from the unbuffered-schedule
         buffered_sub_schedule.cascades = sub_schedule.cascades
 
@@ -1095,12 +1096,10 @@
         schedule: Schedule,
         max_sched: Schedule,
         max_template: Schedule,
-        options: SchedulerOptions,
     ) -> Schedule:
         """Extracts sub-schedules based on the cascades and optimizes them and applies them to the final schedule"""
-        verbose_progress = options.verbose_progress
-        sram_limit = options.optimization_sram_limit
-        if max_sched.fast_storage_peak_usage < sram_limit and not self.arch.is_spilling_enabled():
+        verbose_progress = self.scheduler_options.verbose_progress
+        if max_sched.fast_storage_peak_usage < self.sram_limit and not self.arch.is_spilling_enabled():
             # Maximum performance schedule fits within the SRAM target
             return max_sched
 
@@ -1109,7 +1108,7 @@
         for index, cascade_info in enumerate(cascades):
             progress_print(verbose_progress, "Processing cascade", index, cascades)
             # Optimize the sub-schedule in this cascade
-            opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, sram_limit)
+            opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, self.sram_limit)
             if opt_sub_schedule:
                 # Remove the existing cascade
                 del schedule.cascades[cascade_info.end]
@@ -1121,7 +1120,7 @@
         self.sg.schedule = schedule
         self.update_op_memory_snapshot(schedule)
         # Propose schedule buffering to the optimized schedule
-        optimized_sched = self.propose_schedule_buffering(schedule, self.scheduler_options.optimization_sram_limit)
+        optimized_sched = self.propose_schedule_buffering(schedule, self.sram_limit)
         # Copy the cascade's metadata from the unbuffered schedule
         optimized_sched.cascades = schedule.cascades
         return optimized_sched
@@ -1129,9 +1128,8 @@
     def optimize_weight_buffering_size(
         self,
         min_schedule: Schedule,
-        options: SchedulerOptions,
     ):
-        verbose_progress = options.verbose_progress
+        verbose_progress = self.scheduler_options.verbose_progress
         default_schedule = self.sg.schedule
         npu_performance.calc_new_performance_for_network(self.nng, self.arch, None, False)
         default_tot_cycles = self.nng.cycles[npu_performance.PassCycles.Total]
@@ -1181,14 +1179,14 @@
         self.update_op_memory_snapshot(self.sg.schedule)
 
         # Run schedule buffering - with weight buffer size reduction
-        schedule = self.propose_schedule_buffering(self.sg.schedule, options.optimization_sram_limit)
+        schedule = self.propose_schedule_buffering(self.sg.schedule, self.sram_limit)
         schedule.cascades = self.sg.schedule.cascades
         self.sg.schedule = schedule
 
         # Apply new buffer schdule and calc new performance
         self.update_op_memory_snapshot(self.sg.schedule)
         self.apply_schedule(self.sg.schedule)
-        self.use_fast_storage_for_feature_maps(self.sg.schedule, options.optimization_sram_limit)
+        self.use_fast_storage_for_feature_maps(self.sg.schedule, self.sram_limit)
 
         npu_performance.calc_new_performance_for_network(self.nng, self.arch, None, False)
         new_tot_cycles = self.nng.cycles[npu_performance.PassCycles.Total]
@@ -1214,7 +1212,7 @@
             self.sg.schedule = default_schedule
             self.update_op_memory_snapshot(self.sg.schedule)
             self.apply_schedule(self.sg.schedule)
-            self.use_fast_storage_for_feature_maps(self.sg.schedule, options.optimization_sram_limit)
+            self.use_fast_storage_for_feature_maps(self.sg.schedule, self.sram_limit)
 
     def apply_schedule(self, sched: Schedule):
         """Applies the given schedule as a final solution"""
@@ -1628,7 +1626,7 @@
             progress_print(verbose_progress, "Creating minimal schedule")
             # Create Min schedule
             min_schedule = scheduler.propose_minimal_schedule()
-            initial_sram_limit = scheduler_options.optimization_sram_limit
+            initial_sram_limit = scheduler.sram_limit
             if scheduler_options.optimization_strategy == OptimizationStrategy.Size:
                 initial_sram_limit = scheduler.min_memory_req
 
@@ -1638,22 +1636,25 @@
             sg.schedule = min_schedule
             scheduler.update_op_memory_snapshot(min_schedule)
 
-            if scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
+            if scheduler_options.optimization_strategy == OptimizationStrategy.Size:
                 progress_print(verbose_progress, "Creating schedule optimized for performance")
-                # Create an optimized schedule
-                sg.schedule = scheduler.optimize_schedule(
-                    min_schedule, opt_max_schedule, max_schedule_template, scheduler_options
-                )
-                scheduler.update_op_memory_snapshot(sg.schedule)
+                # Update sram limit to peak usage from the minimum scheduler when optimizing for Size.
+                # Then optimize schedule can be called for both OptimizationStrategy Performance and Size
+                # as long the max sram usage is <= scheduler.sram_limit
+                scheduler.sram_limit = min_schedule.fast_storage_peak_usage
+
+            # Create an optimized schedule
+            sg.schedule = scheduler.optimize_schedule(min_schedule, opt_max_schedule, max_schedule_template)
+            scheduler.update_op_memory_snapshot(sg.schedule)
 
             scheduler.apply_schedule(sg.schedule)
-            scheduler.use_fast_storage_for_feature_maps(sg.schedule, scheduler_options.optimization_sram_limit)
+            scheduler.use_fast_storage_for_feature_maps(sg.schedule, scheduler.sram_limit)
 
             if scheduler_options.optimization_strategy == OptimizationStrategy.Performance and scheduler.evicted_fms:
                 progress_print(verbose_progress, "Optimizing weight buffering size")
                 # It might be possible to gain performance by reducing
                 # weight buffer size and instead fit fms in fast storage
-                scheduler.optimize_weight_buffering_size(min_schedule, scheduler_options)
+                scheduler.optimize_weight_buffering_size(min_schedule)
 
             if scheduler_options.verbose_schedule:
                 scheduler.print_schedule(sg.schedule)