vela: Improve block configuration and weight buffering algorithm

 - Update block config selection to take into account partial
   IFM fetches at the edge of non-whole OFM block data.
 - Change the scheduler depth slicing for the networks in MLBEDSW-4637
   to improve buffering. This helps general performance by buffering
   larger depth slices (see the sketch below).
 - Bug fix for opt_max_schedule always being fitted to SRAM, which
   prevented the optimisation step from running in some cases.
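
   Illustrative sketch of the new buffering-depth selection (a model of
   the algorithm in this patch, not the vela implementation; round_up,
   round_down and OFM_SPLIT_DEPTH stand in for vela's utilities and
   ArchitectureFeatures.OFMSplitDepth):

     OFM_SPLIT_DEPTH = 16

     def round_down(x, align):
         return (int(x) // align) * align

     def round_up(x, align):
         return ((int(x) + align - 1) // align) * align

     def choose_buffering_depth(stripe_depth, block_depth,
                                full_weights_bytes, pre_op_cycles,
                                full_transfer_cycles, half_buffer_limit):
         # Pick a depth whose weight transfer takes about as long as
         # computing the prebuffered slice, so the DMA hides behind it
         depth = stripe_depth * (pre_op_cycles / full_transfer_cycles)

         # Prefer whole OFM blocks, then clamp to the double-buffer limit
         depth = round_up(depth, block_depth)
         if (depth / stripe_depth) * full_weights_bytes > half_buffer_limit:
             depth = (half_buffer_limit / full_weights_bytes) * stripe_depth

         # Round down to whole blocks when possible, otherwise to the
         # OFM split granularity
         if depth > block_depth:
             depth = round_down(depth, block_depth)
         else:
             depth = round_down(depth, OFM_SPLIT_DEPTH)
         return int(max(depth, OFM_SPLIT_DEPTH))

     # e.g. 256-deep stripe, 32-deep blocks, weight transfer twice as
     # long as the prebuffer compute: buffer 128 of 256 output channels
     print(choose_buffering_depth(256, 32, 65536, 1000, 2000, 40960))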

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I97642c5adec3bb684b1daabf2b81574c27d4eef2
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index dfb8867..de2189b 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -459,10 +459,9 @@
 
         return npu_performance.measure_cycle_cost(self.arch, op.op_type, op.activation and op.activation.op_type, query)
 
-    def propose_schedule_buffering(self, ref_schedule: Schedule):
+    def propose_schedule_buffering(self, ref_schedule: Schedule, staging_limit_bytes: int):
         """Create a buffered schedule"""
         buffered_schedule = Schedule(self.sg, f"{ref_schedule.label}_BUFFERED")
-        staging_limit_bytes = self.scheduler_options.optimization_sram_limit
 
         prev_op = None
         for sched_op in self.sched_ops:
@@ -588,24 +587,35 @@
                 prebuffer_bytes = min(prebuffer_ratio * full_weights_bytes, half_buffer_limit)
             else:
                 prebuffer_bytes = min(full_weights_bytes, half_buffer_limit)
-                prebuffer_ratio = prebuffer_bytes / full_weights_bytes
+
+            prebuffer_ratio = prebuffer_bytes / full_weights_bytes
 
             # Have to split the weights if the initial buffering can't store
             # all of the compressed weights
             if prebuffer_bytes < full_weights_bytes:
-                prebuffer_depth = int(ref_cost.stripe.depth * prebuffer_ratio)
+                block_depth = cost.block_config.ofm_block.depth
 
-                # Round prebuffering down to nearest valid split depth
+                # Choose initial prebuffering depth (already buffer clamped)
+                prebuffer_depth = ref_cost.stripe.depth * prebuffer_ratio
                 prebuffer_depth = int(max(16, round_down(prebuffer_depth, ArchitectureFeatures.OFMSplitDepth)))
 
-                while True:
-                    buffering_depth = max(cost.block_config.ofm_block.depth, prebuffer_depth)
+                # Calculate cycles executed during the prebuffer
+                pre_op_cycles = self.estimate_op_performance(sched_op, cost.block_config, prebuffer_depth)
+                buffering_depth = ref_cost.stripe.depth * (pre_op_cycles.op_cycles / full_transfer_cycles)
 
-                    # Clamp buffering to the double buffering limit
-                    buffering_bytes = (buffering_depth / ref_cost.stripe.depth) * full_weights_bytes
-                    if buffering_bytes > half_buffer_limit:
-                        buffering_depth = (half_buffer_limit / full_weights_bytes) * ref_cost.stripe.depth
-                        buffering_depth = int(max(16, round_down(prebuffer_depth, ArchitectureFeatures.OFMSplitDepth)))
+                # Choose initial buffering depth and clamp to the double buffering limit
+                buffering_depth = round_up(buffering_depth, block_depth)
+                buffering_bytes = (buffering_depth / ref_cost.stripe.depth) * full_weights_bytes
+                if buffering_bytes > half_buffer_limit:
+                    buffering_depth = (half_buffer_limit / full_weights_bytes) * ref_cost.stripe.depth
+
+                while True:
+                    # Attempt to buffer whole blocks
+                    if buffering_depth > block_depth:
+                        buffering_depth = round_down(buffering_depth, block_depth)
+                    else:
+                        buffering_depth = round_down(buffering_depth, ArchitectureFeatures.OFMSplitDepth)
+                    buffering_depth = int(max(buffering_depth, ArchitectureFeatures.OFMSplitDepth))
 
                     # Create list of depth slices
                     depth_slices = [0]
@@ -633,7 +643,10 @@
                     ):
                         break
 
-                    prebuffer_depth = round_up(prebuffer_depth // 2, ArchitectureFeatures.OFMSplitDepth)
+                    if buffering_depth > prebuffer_depth:
+                        buffering_depth = round_up(buffering_depth // 2, ArchitectureFeatures.OFMSplitDepth)
+                    else:
+                        prebuffer_depth = round_up(prebuffer_depth // 2, ArchitectureFeatures.OFMSplitDepth)
 
                 # Calculate cycles required to run the last op for use as future slack
                 tail_cycles = self.estimate_op_performance(
@@ -790,7 +803,9 @@
         cascade_builder = CascadeBuilder(sub_schedule_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)
 
         # Start by adding buffering
-        buffered_sub_schedule = self.propose_schedule_buffering(sub_schedule)
+        buffered_sub_schedule = self.propose_schedule_buffering(
+            sub_schedule, self.scheduler_options.optimization_sram_limit
+        )
         # Copy the cascades over from the unbuffered-schedule
         buffered_sub_schedule.cascades = sub_schedule.cascades
 
@@ -852,7 +867,7 @@
         self.sg.schedule = schedule
         self.update_op_memory_snapshot(schedule)
         # Propose schedule buffering to the optimized schedule
-        optimized_sched = self.propose_schedule_buffering(schedule)
+        optimized_sched = self.propose_schedule_buffering(schedule, self.scheduler_options.optimization_sram_limit)
         # Copy the cascade's metadata from the unbuffered schedule
         optimized_sched.cascades = schedule.cascades
         return optimized_sched
@@ -1047,7 +1062,7 @@
             # Create the optimised Max schedule
             sg.schedule = max_schedule_template
             scheduler.update_op_memory_snapshot(max_schedule_template)
-            opt_max_schedule = scheduler.propose_schedule_buffering(max_schedule_template)
+            opt_max_schedule = scheduler.propose_schedule_buffering(max_schedule_template, 1 << 32)
             sg.schedule = opt_max_schedule
             scheduler.update_op_memory_snapshot(opt_max_schedule)
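
A minimal sketch of how the chosen prebuffer and buffering depths expand
into the scheduler's list of depth slices (make_depth_slices is a
hypothetical name used here for illustration, not a vela function):

  def make_depth_slices(stripe_depth, prebuffer_depth, buffering_depth):
      # Boundaries into the OFM depth: the first slice runs on the
      # prebuffered weights, later slices each cover buffering_depth
      slices = [0]
      if prebuffer_depth < stripe_depth:
          slices += list(range(prebuffer_depth, stripe_depth, buffering_depth))
      slices.append(stripe_depth)
      return slices

  # e.g. stripe depth 128, prebuffer 32, buffering 48
  print(make_depth_slices(128, 32, 48))  # [0, 32, 80, 128]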