MLBEDSW-7430: Remove non local mem usage from cascade info

- There is a latent bug when calculating the mem usage parallel to the
sub schedule. The error is the calculation done when optimizing the sub
schedules. There the cascade size is withdrawn from the snapshot usage
to decide non local memory usage. The problem is that the cascade mem
usage actually also includes non local memory so the end result will be
zero. This is normally not a problem but it will be when starting to
optimize sub schedule when optimizing for Size.

- The solution is to not include the non local usage in the cascade
info, the scheduler already have this information.

- Corrected usage of persistent initial IFM. This size should not be
included for Dedicated SRAM since only intermediate buffers are in SRAM.

- Added some comment to clarify the code in the cascade builder.

Change-Id: I473b36e0d69550ab6565f4ef028195636b362997
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 0e651b9..95872cf 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -188,9 +188,6 @@
             # The first IFM needs to be stored in full
             cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
 
-            # Add non-local memory usage
-            cascade_ifm_size += self.non_local_mem_usage.get(op, 0)
-
             # Sum of all intermediate cascade buffers (including weight buffers)
             cascade_buffers = weight_buffer
             # Best cascade size - Initially it's the fallback cost of the first Op in the cascade
@@ -248,8 +245,10 @@
                         best_cascade_size = cascade_buffers
 
                 else:
-                    # Calculate the total size of the current cascade
-                    cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
+                    # Calculate the total size of the current cascade including non local mem usage
+                    cascade_size = (
+                        cascade_ifm_size + cascade_buffers + op_full_ofm + self.non_local_mem_usage.get(op, 0)
+                    )
 
                     # Determine if cascading search should stop
                     if (
@@ -257,7 +256,8 @@
                         and best_cascade_size < peak_sram_usage
                         or (cascade_ifm_size + cascade_buffers) > best_cascade_size
                     ):
-                        # Both the existing cascade and current Op fits
+                        # Both the existing cascade and current Op fits or
+                        # not possible to reduce cascade size any further
                         break
 
                     """
@@ -306,7 +306,7 @@
                     hence, better to choose Cascade OP1-OP3 in this case.
                     """
                     if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage:
-                        best_cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
+                        best_cascade_size = cascade_size
                         ops_in_best_cascade = [op for op in ops_in_cascade]
 
                 producer = current_op
@@ -326,9 +326,15 @@
 
                     prev_op = cascaded_op
 
-                # Create a CascadeInfo for the cascade
+                # Create a CascadeInfo for the cascade, only store the actual size used by
+                # the cascade so non local usage is removed. This is done in order to be
+                # able to calculate the correct non local usage in the scheduler when
+                # optimizing the sub schedules.
                 cascade_map[cascade_end] = CascadeInfo(
-                    cascade_start, cascade_end, buffers_in_cascade, best_cascade_size
+                    cascade_start,
+                    cascade_end,
+                    buffers_in_cascade,
+                    best_cascade_size - self.non_local_mem_usage.get(op, 0),
                 )
                 if not self.spilling:
                     # Update peak memory usage
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 16531c2..83e19bc 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -952,8 +952,7 @@
             if cost[sched_op].cascade:
                 # This Op is part of a cascade - use the cascade's memory usage
                 cascade_info = cascades[cost[sched_op].cascade]
-                # Non-local memory usage is already included in the cascade_info
-                peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage)
+                op_mem_usage = cascade_info.mem_usage + non_local_mem_usage.get(sched_op, 0)
             else:
                 # This Op is not part of a cascade - calculate the memory usage
                 op_weight_buffer = sum(tens.storage_size() for tens in cost[sched_op].buffered_weight_tensors)
@@ -964,7 +963,7 @@
                     + op_weight_buffer
                     + non_local_mem_usage.get(sched_op, 0)
                 )
-                peak_mem_usage = max(op_mem_usage, peak_mem_usage)
+            peak_mem_usage = max(op_mem_usage, peak_mem_usage)
 
         return peak_mem_usage
 
@@ -1021,9 +1020,11 @@
         time_for_cascade = ref_cost[sub_schedule_ops[0]].time_index
         mem_usage_parallel_to_sub_schedule = ref_schedule.memory_snapshot[time_for_cascade] - cascade_info.mem_usage
         # If the first Op's IFM has other consumers it has to live throughout the whole sub-schedule whether it's
-        # included in a cascade or not
+        # included in a cascade or not. Not valid in Dedicated SRAM mode (spilling enabled).
         persistent_initial_ifm = (
-            sub_schedule_ops[0].ifm_size_in_bytes() if len(sub_schedule_ops[0].ifm.connection.consumers) > 1 else 0
+            sub_schedule_ops[0].ifm_size_in_bytes()
+            if not self.arch.is_spilling_enabled() and len(sub_schedule_ops[0].ifm.connection.consumers) > 1
+            else 0
         )
         # Calculate non-local-mem-usage per Operator
         non_local_mem_usage = {}