MLBEDSW-2816: Fix assert in scheduler

  - Use the non-local memory usage as the base SRAM usage for a subgraph
    (see the sketch after the diff)
  - Make avoid_for_spilling more generic so it covers all memory configs,
    and rename it to avoid_for_cascading (see the sketch below)
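
Editorial sketch, not part of the patch: a minimal, self-contained illustration
of the renamed check, assuming simplified stand-in classes (Arch, Op, Tensor,
Pass) and placeholder mem-area strings; in the real scheduler the method reads
the same two fields from self.arch, as the diff below shows.

    # Simplified stand-ins for the Vela classes (assumption, for illustration only).
    from collections import namedtuple

    Arch = namedtuple("Arch", "feature_map_storage_mem_area fast_storage_mem_area")
    Op = namedtuple("Op", "type outputs")
    Tensor = namedtuple("Tensor", "consumer_list")
    Pass = namedtuple("Pass", "ops")

    def avoid_for_cascading(arch, pred_candidate):
        for op in pred_candidate.ops:
            # The concat early-out now only applies when feature maps spill
            # out of fast storage (SRAM spilling); other memory configs fall
            # through instead of returning early as avoid_for_spilling did.
            if (
                op.type == "ConcatSliceWrite"
                and arch.feature_map_storage_mem_area != arch.fast_storage_mem_area
            ):
                return True
            # Ops with consumers in other subgraphs are avoided in every config.
            if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
                return True
        return False

    concat = Pass(ops=[Op("ConcatSliceWrite", [Tensor(["only_consumer"])])])
    print(avoid_for_cascading(Arch("Dram", "Sram"), concat))  # True: spilling config
    print(avoid_for_cascading(Arch("Sram", "Sram"), concat))  # False: no spilling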

Change-Id: I99cd30fe6a8ba075d5a70dc2138aa0635afaadb3
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 47f8a47..24453d8 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -254,7 +254,11 @@
         self.pareto_max_candidates = 16
 
         self.ifm_stream_npu_blocks = set(
-            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
+            (
+                NpuBlockType.ConvolutionMxN,
+                NpuBlockType.ConvolutionDepthWise,
+                NpuBlockType.Pooling,
+            )
         )
 
     num_pareto_metrics = 4
@@ -519,7 +523,7 @@
         if self.verbose_pareto_frontier_schedules:
             print(
                 "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
-                % (self.n_combinations_searched, len(strat_data,))
+                % (self.n_combinations_searched, len(strat_data))
             )
             for idx, (_, strat_set) in enumerate(strat_data):
                 extra = ""
@@ -645,13 +649,13 @@
         res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
         return res
 
-    def avoid_for_spilling(self, pred_candidate):
-        if self.arch.feature_map_storage_mem_area == self.arch.fast_storage_mem_area:
-            return False
-
-        # For SRAM spilling, concat op is avoided as predecessor
+    def avoid_for_cascading(self, pred_candidate):
         for op in pred_candidate.ops:
-            if op.type == "ConcatSliceWrite":
+            if (
+                op.type == "ConcatSliceWrite"
+                and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area
+            ):
+                # For SRAM spilling, concat op is avoided as predecessor
                 return True
             if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
                 # The op has consumers in other subgraphs
@@ -685,7 +689,7 @@
                         if pred_candidate.placement == PassPlacement.Npu:
                             if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                                 # and it is on the Npu
-                                if not self.avoid_for_spilling(pred_candidate):
+                                if not self.avoid_for_cascading(pred_candidate):
                                     # and fusable - it's a candidate
                                     pred_pass_list.append(pred_candidate)
 
@@ -896,10 +900,11 @@
                     )
                     assert ps.shared_buffer is not None
 
+                sram_used = max(self.non_local_mem_usage[ps.time], 0)
                 for op in ps.ops:
                     subgraph = op.attrs.get("subgraph")
                     if subgraph:
-                        subgraph.base_sram_used = cascaded_pass.sram_used
+                        subgraph.base_sram_used = sram_used
 
         # all passes should have a cascaded pass now
         if len(pass_to_cascaded_pass) != len(self.sg.passes):
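
Editorial sketch, not part of the patch: the last hunk derives a subgraph's
base SRAM usage from the scheduler's non-local memory usage at the pass's time
step, clamped at zero, instead of from the cascaded pass's own sram_used. The
numbers, the Subgraph stand-in and the set_base_sram helper below are all
hypothetical, used only to show the clamping behaviour.

    # Made-up per-time-step non-local usage; the negative entry exists only
    # to demonstrate the max(..., 0) clamp added by the patch.
    non_local_mem_usage = [150_000, -32]

    class Subgraph:
        base_sram_used = 0

    def set_base_sram(subgraph, time, cascaded_pass_sram_used):
        # Old behaviour: subgraph.base_sram_used = cascaded_pass_sram_used
        # New behaviour: use the non-local usage at this time step, clamped at zero.
        subgraph.base_sram_used = max(non_local_mem_usage[time], 0)

    sg = Subgraph()
    set_base_sram(sg, time=0, cascaded_pass_sram_used=64_000)
    print(sg.base_sram_used)  # 150000
    set_base_sram(sg, time=1, cascaded_pass_sram_used=64_000)
    print(sg.base_sram_used)  # 0, clamped rather than negative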