MLBEDSW-4034: New Scheduler Size or Performance Optimisation

 - Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 65d3313..00a4dfc 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -13,1156 +13,1059 @@
 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
 # Description:
-# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
+# The scheduler creates and searches for an optimal plan for the network, selecting block configurations and
+# subdivisions for the Operators
 import copy
-import enum
-from functools import lru_cache
-
-import numpy as np
+from enum import auto
+from enum import IntEnum
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
 
 from . import live_range
 from . import npu_performance
-from . import stats_writer
+from . import tensor_allocation
+from . import weight_compressor
+from .architecture_allocator import ArchitectureBlockConfig
+from .architecture_allocator import find_block_config
+from .architecture_allocator import get_ifm_area_required
+from .architecture_allocator import to_upscale
+from .architecture_features import ArchitectureFeatures
+from .architecture_features import Block
+from .cascade_builder import CascadeBuilder
+from .cascade_builder import CascadeInfo
 from .data_type import DataType
-from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
 from .nn_graph import CascadedPass
+from .nn_graph import Graph
+from .nn_graph import Pass
 from .nn_graph import PassPlacement
-from .nn_graph import SchedulerRewrite
 from .nn_graph import SchedulingStrategy
-from .npu_performance import make_bandwidth_array
-from .npu_performance import make_cycles_array
-from .npu_performance import make_metrics_arrays
-from .npu_performance import PassCycles
+from .nn_graph import Subgraph
+from .numeric_util import round_down
+from .numeric_util import round_up
 from .operation import NpuBlockType
 from .operation import Op
-from .operation import Operation
-from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
-from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
+from .shape4d import Shape4D
 from .tensor import MemArea
 from .tensor import MemType
+from .tensor import Tensor
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
 from .tensor import TensorSubPurpose
 
 
-class ParetoMetric(enum.Enum):
-    BwCycMem = 1
-    BwCycMemBlkH = 2
+def shape_for_format(shape: Shape4D, tensor_format: TensorFormat) -> Shape4D:
+    if tensor_format == TensorFormat.NHCWB16:
+        return shape.with_depth(round_up(shape.depth, 16))
+
+    return shape
+
+
+class OptimizationStrategy(IntEnum):
+    """Enum defining the different optimization strategies for the Scheduler"""
+
+    Size = auto()
+    Performance = auto()
 
     def __str__(self):
         return self.name
 
 
-class SchedulerOptions:
+class SchedulerOpInfo:
+    """Contains metadata about a SchedulerOperation that is unique to one Schedule"""
+
     def __init__(
         self,
-        use_cascading=True,
-        verbose_schedule=False,
-        verbose_pareto_frontier_schedules=False,
-        use_ifm_streaming=True,
-        pareto_metric=ParetoMetric.BwCycMem,
-        use_nhcwb16_between_cascaded_passes=True,
-        cache_bias_scale_tensor=True,
+        block_config: ArchitectureBlockConfig,
+        weights_size: int,
+        stripe_input: Shape4D,
+        stripe_input2: Optional[Shape4D],
+        stripe: Shape4D,
     ):
-        self.use_cascading = use_cascading
+        self.block_config = block_config
+        self.weights_size = weights_size
+        self.stripe_input = stripe_input
+        self.stripe_input2 = stripe_input2
+        self.stripe = stripe
+        self.cascade = 0  # Assigned by CascadeBuilder. 0 means not part of a cascade
+        self.time_index = None  # Set by update_op_memory_snapshot
+        self.ofm_depth_slices: List[int] = [0, stripe.depth]
+        self.npu_weights_tensor = None
+        self.buffered_weight_tensor = None
+        self.cycles = None
+        self.slack_buffering_cycles = 0
+        self.slack_buffering_memory = 0
+        self.full_weight_transfer_cycles = 0
+
+    def copy(self):
+        res = SchedulerOpInfo(self.block_config, self.weights_size, self.stripe_input, self.stripe_input2, self.stripe,)
+        res.cascade = self.cascade
+        return res
+
+    def __str__(self):
+        res = f"\t\tBlock Config = {self.block_config}\n"
+        res += f"\t\tOFM Block = {self.block_config.ofm_block}\n"
+        res += f"\t\tIFM Stripe   = {self.stripe_input}\n"
+        res += f"\t\tIFM2 Stripe  = {self.stripe_input2}\n"
+        res += f"\t\tOFM Stripe   = {self.stripe}\n"
+        res += f"\t\tEncoded Weights = {self.npu_weights_tensor and len(self.npu_weights_tensor.buffer)} bytes\n"
+        res += (
+            f"\t\tWeight buffer = {self.buffered_weight_tensor and self.buffered_weight_tensor.storage_size()} bytes\n"
+        )
+        res += f"\t\tDepth slices = {self.ofm_depth_slices}\n"
+        res += f"\t\tAssigned Cascade = {self.cascade}"
+        return res
+
+
+class SchedulerOptions:
+    """Contains options for the Scheduler"""
+
+    def __init__(
+        self, optimization_strategy, sram_target, verbose_schedule,
+    ):
+        self.optimization_strategy = optimization_strategy
+        self.optimization_sram_limit = sram_target
         self.verbose_schedule = verbose_schedule
-        self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
-        self.use_ifm_streaming = use_ifm_streaming
-        self.pareto_metric = pareto_metric
-        self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes
-        self.cache_bias_scale_tensor = cache_bias_scale_tensor
 
-    def __str__(self):
-        return type(self).__name__ + ": " + str(self.__dict__)
+    def __str__(self) -> str:
+        return f"{type(self).__name__}: {str(self.__dict__)}"
 
     __repr__ = __str__
 
 
-class Strategy:
-    __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"
+class SchedulerTensor:
+    def __init__(self, shape, dt, mem_area, _format):
+        self.dtype = dt
+        self.mem_area = mem_area
+        self.shape = shape
+        self.format = _format
+        self.connection = None
 
-    def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
-        self.strat = strat
-        self.param = param
-        self.passes = passes
-        self.block_configs = block_configs
-        self.rewrite_list = (
-            rewrite_list  # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
+
+class SchedulerOperation:
+    """Scheduler internal representation of 'Operation'
+    This class can be seen as a node within the Scheduler Graph representation
+    """
+
+    def __init__(self, ps: Pass, arch: ArchitectureFeatures, nng: Graph):
+        self.arch = arch
+        self.parent_ps = ps
+        self.parent_op = ps.primary_op
+        self.name = ps.primary_op.name
+        self.op_type = ps.primary_op.type
+        self.activation = ps.primary_op.activation
+        self.kernel = ps.primary_op.kernel
+        self.resampling_mode = ps.primary_op.ifm.resampling_mode
+        self.uses_scalar = ps.primary_op.ifm2 is not None and (
+            ps.primary_op.ifm.shape == [] or ps.primary_op.ifm2.shape == []
         )
-        self.bws = bws
-        self.macs = macs
-        self.cycles = cycles
-        self.sram_used = sram_used
+        self.ifm_ublock = arch.ifm_ublock
 
-    def __eq__(self, other):
-        if self.strat != other.strat:
-            return False
-        if self.param != other.param:
-            return False
-        if self.block_configs != other.block_configs:
-            return False
-        if self.passes != other.passes:
-            return False
-        if (self.bws != other.bws).any():
-            return False
-        if self.macs != other.macs:
-            return False
-        if (self.cycles != other.cycles).any():
-            return False
-        if self.sram_used != other.sram_used:
-            return False
-        return True
+        self.ifm = SchedulerTensor(ps.ifm_shapes[0], ps.ifm_tensor.dtype, ps.ifm_tensor.mem_area, ps.ifm_tensor.format,)
 
-    def empty(self):
-        return not self.passes
+        self.ifm2 = None
+        if ps.ifm2_tensor:
+            self.ifm2 = SchedulerTensor(
+                ps.ifm_shapes[1], ps.ifm2_tensor.dtype, ps.ifm2_tensor.mem_area, ps.ifm2_tensor.format,
+            )
 
-    def key(self):
-        return self.passes[-1]
+        self.ofm = SchedulerTensor(ps.ofm_shapes[0], ps.ofm_tensor.dtype, ps.ofm_tensor.mem_area, ps.ofm_tensor.format,)
 
-    def clone(self):
-        return Strategy(
-            self.strat,
-            self.param,
-            self.passes,
-            self.block_configs,
-            self.rewrite_list,
-            self.bws,
-            self.macs,
-            self.cycles,
-            self.sram_used,
-        )
+        # Input volume width and height required to produce the smallest possible stripe
+        self.min_stripe_input_w, self.min_stripe_input_h = self._calculate_min_stripe_input()
 
-    def __str__(self):
-        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
-            self.strat,
-            self.passes,
-            self.rewrite_list,
-            self.bws,
-            self.macs,
-            self.cycles,
-            self.sram_used,
-        )
+        # Flags that mark whether this SchedulerOperation requires the full IFM/IFM2/OFM
+        self.requires_full_ifm = False
+        self.requires_full_ifm2 = False
+        self.requires_full_ofm = False
 
-    __repr__ = __str__
+        self.index = 0
 
+    def add_ifm_connection(self, conn: "Connection"):
+        """Add input connection to another SchedulerOperation or Subgraph Input"""
+        conn.consumers.append(self)
+        self.ifm.connection = conn
 
-class StrategySet:
-    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"
-
-    def __init__(self, strats=None):
-        if strats is None:
-            strats = dict()
-        self.strats = strats  # final pass in packed pass -> Strategy
-        self.bws, self.macs, self.cycles = make_metrics_arrays()
-        self.max_sram_used = 0
-        self.total_sram_used = 0
-
-    def update_statistics(self):
-        self.bws = make_bandwidth_array()
-        self.max_sram_used = 0
-        for ps, strat in self.strats.items():
-            self.bws += strat.bws
-            self.macs += strat.macs
-            self.cycles += strat.cycles
-            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
-            self.total_sram_used += strat.sram_used
-
-    def clone_add_strategy(self, new_strat):
-        key = new_strat.key()
-        if key in self.strats:
-            assert new_strat == self.strats[key]
-            return self
+    def add_ifm2_connection(self, conn: "Connection"):
+        """Add input connection to another SchedulerOperation or Subgraph Input"""
+        if self.ifm2:
+            conn.consumers.append(self)
+            self.ifm2.connection = conn
         else:
-            new_strats = dict(self.strats)
-            new_strats[key] = new_strat
-            new_set = StrategySet(new_strats)
-            new_set.bws = self.bws + new_strat.bws
-            new_set.macs = self.macs + new_strat.macs
-            new_set.cycles = self.cycles + new_strat.cycles
-            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
-            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
-            return new_set
+            assert False, f"Trying to set an IFM2 Connection to {self} which has no IFM2"
 
-    def __eq__(self, other):
-        if (self.bws != other.bws).any():
-            return False
-        if self.macs != other.macs:
-            return False
-        if (self.cycles != other.cycles).any():
-            return False
-        if self.max_sram_used != other.max_sram_used:
-            return False
-        if self.total_sram_used != other.total_sram_used:
-            return False
-        if self.strats != other.strats:
-            return False
-        return True
+    def add_ofm_connection(self, conn: "Connection"):
+        """Add output connection to another SchedulerOperation or Subgraph Output"""
+        conn.producers.append(self)
+        self.ofm.connection = conn
+
+    def get_dependants(self):
+        """Returns a list of the Ops that depend on this Operation's OFM"""
+        return self.ofm.connection.consumers
+
+    def ifm_size_in_bytes(self) -> int:
+        """Returns size of the IFM in bytes"""
+        ifm_storage_shape = shape_for_format(self.ifm.shape, self.ifm.format)
+        return round_up(ifm_storage_shape.elements() * self.ifm.dtype.size_in_bytes(), Tensor.AllocationQuantum)
+
+    def ifm2_size_in_bytes(self) -> int:
+        """Returns size of the IFM2 in bytes"""
+        if self.ifm2:
+            ifm2_storage_shape = shape_for_format(self.ifm2.shape, self.ifm2.format)
+            return round_up(ifm2_storage_shape.elements() * self.ifm2.dtype.size_in_bytes(), Tensor.AllocationQuantum)
+
+        return 0
+
+    def ofm_size_in_bytes(self) -> int:
+        """Returns size of the OFM in bytes"""
+        ofm_storage_shape = shape_for_format(self.ofm.shape, self.ofm.format)
+        return round_up(ofm_storage_shape.elements() * self.ofm.dtype.size_in_bytes(), Tensor.AllocationQuantum)
+
+    def create_scheduler_info(self, nng: Graph, stripe: Shape4D) -> SchedulerOpInfo:
+        """Returns schedule info about this SchedulerOperation based on how many ofm elements it should produce"""
+        ifm_shape = self.ifm.shape
+        ifm2_shape = self.ifm2 and self.ifm2.shape
+        ofm_shape = stripe
+
+        if ofm_shape != self.ofm.shape:
+            # Striped Op - Need to calculate stripe input volume
+            stripe_input_w, stripe_input_h = self._get_stripe_input_requirement(stripe)
+            # Ensure stripe input volume is within the full IFM volume
+            stripe_input_h = min(stripe_input_h, self.ifm.shape.height)
+            stripe_input_w = min(stripe_input_w, self.ifm.shape.width)
+            ifm_shape = ifm_shape.with_hw(stripe_input_h, stripe_input_w)
+
+            if self.ifm2:
+                stripe_input2_h = min(stripe_input_h, self.ifm2.shape.height)
+                stripe_input2_w = min(stripe_input_w, self.ifm2.shape.width)
+                ifm2_shape = ifm2_shape.with_hw(stripe_input2_h, stripe_input2_w)
+
+        block_config = self._get_block_config(ifm_shape, ifm2_shape, self.uses_scalar, ofm_shape)
+
+        scheduler_op_info = SchedulerOpInfo(block_config, 0, ifm_shape, ifm2_shape, ofm_shape)
+        if self.parent_op.weights:
+            # Default full-depth weight encoding with no buffering
+            scheduler_op_info.npu_weights_tensor = weight_compressor.encode_weight_and_scale_tensor(
+                self.arch,
+                self.parent_op,
+                self.parent_op.weights,
+                self.parent_op.bias,
+                self.kernel,
+                block_config,
+                [0, self.ofm.shape.depth],
+            )
+
+        self.parent_ps.block_config = block_config.old_style_representation()
+        return scheduler_op_info
+
+    def _get_stripe_input_requirement(self, stripe_shape: Shape4D) -> Tuple[int, int]:
+        """Returns the amount of IFM required to produce the stripe with shape:'stripe_shape'"""
+        ofm_shape_to_produce = Block.from_shape(stripe_shape.as_list())
+
+        return get_ifm_area_required(ofm_shape_to_produce, self.kernel, to_upscale(self.resampling_mode))
+
+    def _calculate_min_stripe_input(self) -> Tuple[int, int]:
+        """Returns the input (width, height) required to produce the smallest possible stripe (h,w = 1,1)"""
+        min_stripe = self.ofm.shape.with_hw(1, 1)
+        return self._get_stripe_input_requirement(min_stripe)
+
+    def _get_block_config(
+        self, ifm_shape: Shape4D, ifm2_shape: Optional[Shape4D], uses_scalar: bool, ofm_shape: Shape4D
+    ) -> ArchitectureBlockConfig:
+        # Returns a block config and SHRAM layout
+        lut_banks = 2 if self.parent_op.activation_lut else 0
+        return find_block_config(
+            self.arch,
+            self.op_type.npu_block_type,
+            ofm_shape,
+            ifm_shape,
+            ifm2_shape,
+            uses_scalar,
+            self.ifm.dtype.size_in_bits(),
+            self.kernel,
+            lut_banks,
+            self.parent_op.has_scaling(),
+            self.resampling_mode,
+        )
+
+
+class Connection:
+    """Scheduler internal representation of a Tensor that connects two SchedulerOperations
+    This class can be seen as an edge within the Scheduler Graph representation
+    """
+
+    def __init__(self, tensor: Tensor):
+        self.parent_tens = tensor
+
+        # SchedulerOperation relationships
+        self.producers: List[SchedulerOperation] = []
+        self.consumers: List[SchedulerOperation] = []
 
     def __str__(self):
-        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
-            self.max_sram_used,
-            list(ps.name for ps in self.strats),
-        )
+        return f"<Connection {self.parent_tens.name}>"
 
     __repr__ = __str__
 
 
-empty_strategy = Strategy(
-    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), 0, make_cycles_array(), 0
-)
-INFINITY = 1e30
+class Schedule:
+    """Class that contains a solution of how to schedule an NPU subgraph and its cost"""
 
-ABORT_SEARCH = []
+    def __init__(self, sg: Subgraph, label: str):
+        self.sg = sg
+        self.label = label
+        self.cost_map: Dict[SchedulerOperation, SchedulerOpInfo] = {}
+        self.cascades: Dict[int, CascadeInfo] = {}
+        self.fast_storage_peak_usage = 0
+        self.memory_snapshot = None
+
+    @property
+    def name(self):
+        return f"{self.sg.name}_{self.label}"
 
 
-def flatten_list_of_lists(lstlst):
-    lst = []
-    for v in lstlst:
-        lst.extend(v)
-    return lst
+class Scheduler:
+    """Main class of the Vela Scheduling"""
 
-
-class DynamicProgrammingScheduler:
-    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
+    def __init__(self, nng: Graph, sg: Subgraph, arch: ArchitectureFeatures, options: SchedulerOptions):
         self.nng = nng
         self.sg = sg
         self.arch = arch
-        self.sram_limit = sram_limit
-        self.options = copy.copy(options)
-        self.use_cascading = options.use_cascading
+        self.sched_ops: List[SchedulerOperation] = []
+        self.max_schedule = None
+        self.scheduler_options = options
 
-        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
-            self.use_ifm_ofm_overlap = False  # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM
-        else:
-            self.use_ifm_ofm_overlap = True
-
-        self.verbose_schedule = options.verbose_schedule
-        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
-        self.mem_area = MemArea.Sram
-
-        self.bandwidth_weights = arch.bandwidth_weights
-        self.cycles_weight = arch.cycles_weight
-        self.max_sram_used_weight = arch.max_sram_used_weight
-
-        self.n_combinations_searched = 0
-
-        self.pareto_max_candidates = 16
-
-        self.ifm_stream_npu_blocks = set(
-            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
-        )
-
-    num_pareto_metrics = 4
-    view_values = ",".join(["d"] * num_pareto_metrics)
-    order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]
-
-    def pareto_metric(self, candidate):
-        strat, strat_set = candidate
-        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
-        bws = strat.bws + strat_set.bws
-        last_block_height = 0
-        if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
-            last_block_height = strat.block_configs[-1][0]
-
-        return (
-            np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
-            strat_set.max_sram_used,
-            strat.sram_used,
-            last_block_height,
-        )
-
-    def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):
-
-        candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]
-
-        if len(candidates) <= 1:
-            return candidates
-        assert remove_equally_good_candidates
-        pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
-        ids = np.arange(len(candidates), dtype=np.int32)
-        for idx, cand in enumerate(candidates):
-            pareto_vals[idx] = self.pareto_metric(cand)
-
-        sort_order = np.argsort(
-            pareto_vals.view(DynamicProgrammingScheduler.view_values),
-            order=DynamicProgrammingScheduler.order_values,
-            axis=0,
-            kind="stable",
-        ).flatten()
-        pareto_vals = pareto_vals[sort_order]
-        ids = ids[sort_order]
-
-        pareto_frontier = []
-        while len(ids) > 0:
-            pareto_frontier.append(candidates[ids[0]])
-            not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
-            ids = ids[not_dominated_by_first]
-            pareto_vals = pareto_vals[not_dominated_by_first]
-
-        if len(pareto_frontier) > self.pareto_max_candidates:
-            pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
-            pareto_frontier = pareto_frontier[: self.pareto_max_candidates]
-
-        return pareto_frontier
-
-    def candidate_metric(self, candidate):
-        strat, strat_set = candidate
-        max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
-        bws = strat.bws + strat_set.bws
-        total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
-
-        return (
-            max_sram_used * self.max_sram_used_weight
-            + np.tensordot(bws, self.bandwidth_weights, axes=3)
-            + total_cycles * self.cycles_weight
-        )
-
-    def sort_by_candidate_metric(self, candidate_list):
-        sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
-        return sorted_list
-
-    def best_candidate(self, candidate_list):
-        if len(candidate_list) == 0:
-            return ABORT_SEARCH
-        if len(candidate_list) == 1:
-            return candidate_list[0]
-        sorted_list = self.sort_by_candidate_metric(candidate_list)
-        return sorted_list[0]
-
-    def graduate_strat(self, strat_type, sram_used, old_strat_data):
-        res = []
-        for old_strat, old_strat_set in old_strat_data:
-            if old_strat.sram_used + sram_used > self.sram_limit:
-                continue  # This strategy is bad, drop it
-            if old_strat_set.max_sram_used > self.sram_limit:
-                continue  # This strategy is bad, drop it
-            assert old_strat.strat == SchedulingStrategy.Unknown
-
-            new_strat = old_strat.clone()
-            new_strat.strat = strat_type
-            new_strat.sram_used = old_strat.sram_used + sram_used
-
-            if self.use_ifm_ofm_overlap:
-                overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
-                    new_strat.strat, new_strat.passes, new_strat.block_configs
-                )
-                new_strat.sram_used -= overlap
-
-            new_strat_set = old_strat_set.clone_add_strategy(new_strat)
-            res.append((empty_strategy, new_strat_set))
-        return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
-
-    def append_sram(self, sram_used, old_strat_data):
-        res = []
-        for old_strat, strat_set in old_strat_data:
-            assert old_strat.strat == SchedulingStrategy.Unknown
-            assert old_strat.sram_used == 0
-            new_strat = old_strat.clone()
-            new_strat.sram_used = old_strat.sram_used + sram_used
-
-            res.append((new_strat, strat_set))
-        return res
-
-    def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
-        res = []
-        for old_strat, strat_set in old_strat_data:
-            assert old_strat.strat == SchedulingStrategy.Unknown
-            new_strat = old_strat.clone()
-            bws, macs, cycles = metrics[:3]
-
-            new_strat.sram_used = old_strat.sram_used + sram_used
-            new_strat.block_configs = old_strat.block_configs + [block_config]
-            new_strat.bws = old_strat.bws + bws
-            new_strat.macs = old_strat.macs + macs
-            new_strat.cycles = old_strat.cycles + cycles
-            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
-                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
-            )
-
-            res.append((new_strat, strat_set))
-        return res
-
-    def append_sram_pass_block_config_performance_metrics_rewrite_list(
-        self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
-    ):
-        res = []
-        for old_strat, strat_set in old_strat_data:
-            assert old_strat.strat == SchedulingStrategy.Unknown
-            new_strat = old_strat.clone()
-            bws, macs, cycles = metrics[:3]
-            new_strat.sram_used = old_strat.sram_used + sram_used
-            new_strat.block_configs = old_strat.block_configs + [block_config]
-            new_strat.bws = old_strat.bws + bws
-            new_strat.macs = old_strat.macs + macs
-            new_strat.cycles = old_strat.cycles + cycles
-            new_strat.passes = old_strat.passes + [new_pass]
-            new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
-                self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
-            )
-            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
-            res.append((new_strat, strat_set))
-        return res
-
-    def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
-        res = []
-        for old_strat, strat_set in old_strat_data:
-            assert old_strat.strat == SchedulingStrategy.Unknown
-            new_strat = old_strat.clone()
-            new_strat.sram_used = old_strat.sram_used + sram_used
-            new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
-            res.append((new_strat, strat_set))
-        return res
-
-    def pass_to_strat(self, strat_data):
-        res = {}
-        for strat in strat_data[1].strats.values():
-            for ps in strat.passes:
-                res[ps] = strat
-        return res
-
-    def compatible_strats(self, a, b):
-        intersection = a.keys() & b.keys()
-        for k in intersection:
-            if a[k] != b[k]:
-                return False
-        return True
-
-    def collate_strats_for_passes(self, all_passes):
-        if len(all_passes) == 0:
-            return [(empty_strategy, StrategySet(dict()))]
-        if len(all_passes) == 1:
-            return all_passes[0]  # save some space in the common case
-        all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
-        prev_combos = [dict()]
-        for j, strand in enumerate(all_strands):
-            new_combos = []
-            for i, alt in enumerate(strand):
-                for prev in prev_combos:
-                    if self.compatible_strats(prev, alt):
-                        cmb = dict(prev)
-                        cmb.update(all_passes[j][i][1].strats)
-                        new_combos.append(cmb)
-            prev_combos = new_combos
-
-        res = []
-        for d in prev_combos:
-            s = StrategySet(d)
-            s.update_statistics()
-            res.append((empty_strategy, s))
-        return res
-
-    def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
-        # get the rest of the predecessors
-        other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
-        other_predecessor_data = self.search_pass_list(other_predecessors)
-
-        # pred strat data has an incomplete strategy, which we need
-        # to continue on, whereas the other ones have completed strategies.
-        # we need to merge these, but keep the incomplete strategy too.
-
-        res = []
-        for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
-            all_strats = [
-                [(empty_strategy, pred_pass_strat_set)],  # pred strat data but with a dummy empty strategy
-                other_predecessor_data,  # this one is fine to use as-is
-            ]
-            collated_strat_data = self.collate_strats_for_passes(all_strats)
-            strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
-            res.extend(strat_data)
-        return res
-
-    def calc_non_local_mem_usage(self):
-        ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
-        range_set = live_range.extract_live_ranges_from_passes(
-            self.sg, self.mem_area, ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
-        )
-        range_dict = range_set.ranges
-
-        # find which ranges overlap passes but aren't input/outputs of the passes.
-        # these won't be counted by the dynamic programming search and must be counted in manually.
-        end_pos = max(ps.time for ps in self.sg.passes) + 2
-        mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
-        non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)
-
-        for tens, rng in range_dict.items():
-            storage_size = tens.storage_size()
-            assert tens.mem_area == self.mem_area
-            mem_usage[rng.start_time : rng.end_time] += storage_size
-
+    def create_scheduler_representation(self, arch: ArchitectureFeatures):
+        """Creates a Scheduler Graph representation"""
+        # Temporary dict for creating connections between the Operations
+        connections: Dict[Tensor, Connection] = {}
+        # Memory required for the largest FeatureMap that has to be full
+        min_memory_req = 0
         for ps in self.sg.passes:
-            local_mem_usage = 0
-            for tens in ps.inputs + ps.outputs + ps.intermediates:
-                if tens.mem_area != self.mem_area:
-                    continue
-
-                local_mem_usage += tens.storage_size()
-
-            non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage
-
-        self.non_local_mem_usage = non_local_mem_usage
-
-    def search(self):
-        self.calc_non_local_mem_usage()
-        starting_passes = [ps for ps in self.sg.passes if not ps.successors]
-        strat_data = self.search_pass_list(starting_passes)
-
-        _, best_set = self.best_candidate(strat_data)
-
-        if self.verbose_pareto_frontier_schedules:
-            print(
-                "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
-                % (self.n_combinations_searched, len(strat_data))
-            )
-            for idx, (_, strat_set) in enumerate(strat_data):
-                extra = ""
-                if strat_set == best_set:
-                    extra = "(Best candidate)"
-                print("Candidate", idx, extra)
-                memory_used = {MemArea.Sram: strat_set.max_sram_used}
-                stats_writer.print_performance_metrics_for_strat(
-                    self.arch,
-                    "",
-                    strat_set.cycles,
-                    strat_set.macs,
-                    strat_set.bws,
-                    self.nng.batch_size,
-                    memory_used,
-                    len(self.sg.passes),
-                    len(strat_set.strats),
-                )
-
-        return best_set
-
-    def search_pass_list(self, pass_list):
-        all_strats = []
-        for ps in pass_list:
-            strat = self.search_output(ps)
-            all_strats.append(strat)
-        strat_data = self.collate_strats_for_passes(all_strats)
-        for strd in strat_data:
-            for ps in pass_list:
-                assert ps in strd[1].strats  # should have strategies for everything we asked to search
-        return strat_data
-
-    def search_predecessors(self, ps):
-
-        # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
-        # we have strats for all passes
-
-        pass_list = ps.dag_predecessors
-        strat_data = self.search_pass_list(pass_list)
-
-        return strat_data
-
-    @lru_cache(maxsize=None)
-    def search_output(self, ps):
-
-        assert ps in self.sg.passes
-        candidate_list = []
-
-        candidate_list.extend(self.search_weight_streaming_output(ps))
-
-        if self.options.use_ifm_streaming:
-            candidate_list.extend(self.search_ifm_streaming_output(ps))
-
-        best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)
-
-        if not best:
-            print(
-                "Warning: Dynamic search programming algorithm failed for pass %s, invoking fallback strategy"
-                % (ps.name,)
-            )
-            return self.search_predecessors(ps)
-
-        return best
-
-    def search_ifm_streaming_output(self, ps):
-        if ps.placement != PassPlacement.Npu:
-            return ABORT_SEARCH
-        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
-            return ABORT_SEARCH
-        strat_data = self.search_ifm_streaming_body(ps, False)
-
-        sram_used = self.non_local_mem_usage[ps.time]
-        for tens in ps.outputs:
-            if tens.mem_area == self.mem_area:
-                sram_used += tens.storage_size()
-
-        return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)
-
-    @lru_cache(maxsize=None)
-    def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
-        if ps.placement != PassPlacement.Npu:
-            return ABORT_SEARCH
-        if ps.npu_block_type not in self.ifm_stream_npu_blocks:
-            return ABORT_SEARCH
-        ifm_input_search_resuls = self.search_ifm_streaming_input(ps)
-        res = []
-
-        base_sram_used = 0
-        for tens in ps.intermediates:
-            if tens.mem_area == self.mem_area:
-                if tens.purpose == TensorPurpose.Weights:
-                    base_sram_used = tens.storage_size(self.arch.weight_estimation_scaling)
-                else:
-                    base_sram_used += tens.storage_size()
-
-        all_block_configs = self.get_block_configs(ps)
-        for block_config in all_block_configs:
-            all_strats = []
-
-            if self.use_cascading:
-                all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))
-
-            all_strats.extend(ifm_input_search_resuls)
-
-            rewrite_list = []
-            sram_used = base_sram_used
-
-            metrics = npu_performance.performance_metrics_for_pass(
-                self.arch,
-                ps,
-                block_config,
-                rewrite_list=rewrite_list,
-                force_outputs_to_fast_storage=force_outputs_to_fast_storage,
-            )
-
-            res.extend(
-                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
-                    sram_used, ps, block_config, metrics, rewrite_list, all_strats
-                )
-            )
-
-        self.n_combinations_searched += len(res)
-        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
-        return res
-
-    def avoid_for_cascading(self, pred_candidate):
-        for op in pred_candidate.ops:
-            if (
-                op.memory_function == Op.ConcatSliceWrite
-                and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area
-            ):
-                # For SRAM spilling, concat op is avoided as predecessor
-                return True
-            if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
-                # The op has consumers in other subgraphs
-                return True
-        return False
-
-    def search_ifm_streaming_partial(self, ps, block_config):
-        if ps.placement != PassPlacement.Npu:
-            return ABORT_SEARCH
-
-        if len(ps.inputs) < 1:
-            return ABORT_SEARCH
-
-        ifm_tensor = ps.ifm_tensor
-
-        if ifm_tensor is None:
-            return ABORT_SEARCH
-        if ifm_tensor.purpose != TensorPurpose.FeatureMap:
-            return ABORT_SEARCH
-        if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
-            return ABORT_SEARCH
-
-        pred_pass_list = []
-        for pred_candidate in ps.dag_predecessors:
-            if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
-                # we found a predecessor that produces this IFM tensor
-                if not ifm_tensor.needs_linear_format:
-                    # and NHCWB16 can be used
-                    if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
-                        # and it only has one successor, namely us
-                        if pred_candidate.placement == PassPlacement.Npu:
-                            if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
-                                # and it is on the Npu
-                                if not self.avoid_for_cascading(pred_candidate):
-                                    # and fusable - it's a candidate
-                                    pred_pass_list.append(pred_candidate)
-
-        if not pred_pass_list:
-            return ABORT_SEARCH
-
-        all_candidates = []
-        for pred_pass in pred_pass_list:
-            # recurse into the next pass
-            ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.arch.is_spilling_enabled())
-
-            strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
-            for strat_opt in strat_data:
-
-                pred_pass_block_config = strat_opt[0].block_configs[-1]
-                rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
-                    self.arch, pred_pass, pred_pass_block_config, ps, block_config
-                )
-                if rolling_buffer_dims is None:
-                    continue  # this does not pack properly, skip it.
-
-                sram_used = 0
-                for tens in ps.inputs:
-                    if tens != ifm_tensor:
-                        if tens.mem_area == self.mem_area:
-                            sram_used += tens.storage_size()
-
-                rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims
-
-                rewrite_list = [
-                    (
-                        SchedulerRewrite.ChangeTensorSubPurpose,
-                        ifm_tensor,
-                        TensorSubPurpose.RollingBufferY,
-                        rolling_buffer_y,
-                        None,
-                        ps,
-                    )
-                ]
-                sram_used += ifm_tensor.storage_size_for_sub_purpose(
-                    self.arch, TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
-                )
-
-                all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))
-
-        self.n_combinations_searched += len(all_candidates)
-        return all_candidates
-
-    def get_block_configs(self, ps):
-        if ps.placement != PassPlacement.Npu:
-            return [(1, 1, 1, 1)]  # default
-
-        block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)
-
-        # Take a limited number of the largest blocks
-        if self.arch.block_config_limit > 0:
-            # Sort by block area, followed by depth
-            block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
-            bound = min(len(block_configs), self.arch.block_config_limit)
-            # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
-            tmp = block_configs[:bound]
-            tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
-            block_configs = tmp
-
-        return block_configs
-
-    def search_ifm_streaming_input(self, ps):
-        sram_used = 0
-        for tens in ps.inputs:
-            if tens.mem_area == self.mem_area:
-                sram_used += tens.storage_size()
-
-        return self.append_sram(sram_used, self.search_predecessors(ps))
-
-    def search_weight_streaming_output(self, ps):
-        strat_data = self.search_weight_streaming_body(ps)
-
-        sram_used = self.non_local_mem_usage[ps.time]
-        for tens in ps.outputs:
-            if tens.mem_area == self.mem_area:
-                sram_used += tens.storage_size()
-
-        return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)
-
-    @lru_cache(maxsize=None)
-    def search_weight_streaming_body(self, ps):
-
-        strat_data = self.search_weight_streaming_input(ps)
-
-        res = []
-
-        all_block_configs = self.get_block_configs(ps)
-
-        for block_config in all_block_configs:
-
-            sram_used = 0
-            rewrite_list = []
-
-            for tens in ps.intermediates:
-                if tens.mem_area == self.mem_area:
-                    if tens.purpose == TensorPurpose.Weights:
-                        sram_used += tens.storage_size_for_sub_purpose(
-                            self.arch, TensorSubPurpose.DoubleBuffer, block_config[3]
-                        )
-                        rewrite_list.append(
-                            (
-                                SchedulerRewrite.ChangeTensorSubPurpose,
-                                tens,
-                                TensorSubPurpose.DoubleBuffer,
-                                block_config[3],
-                                None,
-                                ps,
-                            )
-                        )
-                    else:
-                        sram_used += tens.storage_size()
-
-            metrics = npu_performance.performance_metrics_for_pass(
-                self.arch, ps, block_config, rewrite_list=rewrite_list
-            )
-
-            res.extend(
-                self.append_sram_pass_block_config_performance_metrics_rewrite_list(
-                    sram_used, ps, block_config, metrics, rewrite_list, strat_data
-                )
-            )
-
-        self.n_combinations_searched += len(res)
-        res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
-        return res
-
-    def search_weight_streaming_input(self, ps):
-        sram_used = 0
-        for tens in ps.inputs:
-            if tens.mem_area == self.mem_area:
-                sram_used += tens.storage_size()
-
-        return self.append_sram(sram_used, self.search_predecessors(ps))
-
-    def apply_result(self, strat_set, arch):
-        pass_to_cascaded_pass = dict()
-        for _, strat in strat_set.strats.items():
-            # rewrite the tensors that need this first. e.g. make rolling buffers
-            inputs = []
-            intermediates = []
-            outputs = []
-
-            for ps in strat.passes:
-                inputs += ps.inputs
-                intermediates += ps.intermediates
-                outputs += ps.outputs
-
-            for tens in set(inputs) & set(outputs):
-                # tensors that are in both sets are intermediates
-
-                # find pass with input/output tensor, and check if they are both placed on NPU
-                input_placement = None
-                output_placement = None
-                for ps in strat.passes:
-                    if tens in ps.inputs:
-                        input_placement = ps.placement
-                    if tens in ps.outputs:
-                        output_placement = ps.placement
-                if input_placement == output_placement == PassPlacement.Npu:
-                    tens.set_format(TensorFormat.NHCWB16, arch)
-
-                intermediates.append(tens)
-                inputs.remove(tens)
-                outputs.remove(tens)
-
-            for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
-                if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
-                    tens.mem_area = self.arch.fast_storage_mem_area
-                    tens.mem_type = MemType.Scratch_fast
-                    tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
-                else:
-                    assert 0, "unknown rewrite_op " + str(rewrite_op)
-
-            is_element_wise = True
-            for ps in strat.passes:
-                assert ps.placement == strat.passes[0].placement
-                if not ps.is_element_wise:
-                    is_element_wise = False
-                    break
-
-            cascaded_pass = CascadedPass(
-                strat.passes[0].name,
-                strat.strat,
-                inputs,
-                intermediates,
-                outputs,
-                strat.passes,
-                strat.passes[0].placement,
-                is_element_wise,
-            )
-            assert strat.sram_used >= 0
-            cascaded_pass.sram_used = strat.sram_used
-
-            for idx, ps in enumerate(strat.passes):
-                assert ps not in pass_to_cascaded_pass
-                pass_to_cascaded_pass[ps] = cascaded_pass
-                ps.cascade = cascaded_pass
-                ps.block_config = strat.block_configs[idx]
-
-                if ps.placement == PassPlacement.Npu:
-                    ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
-                        self.arch, ps, ps.block_config
-                    )
-                    assert ps.shared_buffer is not None
-
-                sram_used = max(self.non_local_mem_usage[ps.time], 0)
-                for op in ps.ops:
-                    subgraph = op.attrs.get("subgraph")
-                    if subgraph:
-                        subgraph.base_sram_used = sram_used
-
-        # all passes should have a cascaded pass now
-        if len(pass_to_cascaded_pass) != len(self.sg.passes):
-            print(
-                "mismatch: we have %d passes, but only %d have cascaded passes associated"
-                % (len(self.sg.passes), len(pass_to_cascaded_pass))
-            )
-            for ps in self.sg.passes:
-                if ps not in pass_to_cascaded_pass:
-                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))
-
-            assert len(pass_to_cascaded_pass) == len(self.sg.passes)
-
-        cascaded_passes = []
-        if self.sg.placement == PassPlacement.Cpu:
-            # Retain the pass order for CPU subgraph
-            cascaded_passes = [ps.cascade for ps in self.sg.passes]
-        else:
-            # we have all the passes, but we need to put them in order and build predecessor/successor links.
-            visit_pass_set = set()
-
-            def visit_pass(ps):
-                if ps in visit_pass_set:
-                    return
-                visit_pass_set.add(ps)
-
-                cps = ps.cascade
-                dont_traverse = set(cps.passes)
-
-                for ps in cps.passes:
-                    for pred in ps.predecessors:
-                        if pred in dont_traverse:
-                            continue
-                        visit_pass(pred)
-
-                cascaded_passes.append(cps)
-
-            starting_passes = [ps for ps in self.sg.passes if not ps.successors]
-            for ps in starting_passes:
-                visit_pass(ps)
-
-        # reorder so startup init cascaded passes come first
-        def is_startup_cascaded_pass(cps):
-            if not cps.passes:
-                return False
-            return cps.placement == PassPlacement.StartupInit
-
-        cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
-            cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
-        ]
-
-        self.sg.cascaded_passes = cascaded_passes
-        self.sg.build_cascaded_pass_links()
-
-        # Check if NHCWB16 and/or fast storage can be used in between cascaded passes
-        # (NHCWB16 within cascaded passes has been handled earlier in this function)
-        if self.sg.placement == PassPlacement.Npu:
-            # Dictionary tensor -> list of ops, containing feature maps that can be attempted
-            # to be moved to fast storage
-            fast_storage_tensor_rewrites = {}
-            last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
-            # Memory only passes have no primary_op, so use the last op in ops
-            if last_op_in_subgraph is None:
-                last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].ops[-1]
-            for ps in self.sg.cascaded_passes:
-                if ps.placement != PassPlacement.Npu:
-                    continue
+            if ps.primary_op:
+                # Set tensor format to NHCWB16 for output FeatureMaps, if possible
                 for output in ps.outputs:
                     if output.purpose != TensorPurpose.FeatureMap:
                         continue
-
-                    use_NHCWB16 = not output.needs_linear_format
-                    use_fast_storage = True
-                    rewrites = []
-                    for op in output.consumer_list:
-                        if op is None:
-                            use_NHCWB16 = False
-                            use_fast_storage = False
-                            continue
-                        if op.type == Op.ReduceSum and output.dtype == DataType.int32:
-                            use_NHCWB16 = False
-                        elif op.type == Op.Reshape:
-                            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
-                            # consumers do not also need to perform a reshape or if the OFM is going to
-                            # be processed by CPU operations. No-op reshape consumers with empty lists
-                            # (those that have no consumers, or null-consumers used as list terminators)
-                            # must use normal NHWC output.
-                            def incompatible_consumers(oper):
-                                if oper and oper.type == Op.Reshape:
-                                    for consumer in oper.outputs[0].consumer_list:
-                                        yield from incompatible_consumers(consumer)
-                                yield not oper or not oper.run_on_npu or oper is last_op_in_subgraph
-
-                            if not any(incompatible_consumers(op)):
-
-                                def get_rewrites(oper):
-                                    if oper and oper.type == Op.Reshape:
-                                        for consumer in oper.outputs[0].consumer_list:
-                                            yield from get_rewrites(consumer)
-                                        yield oper
-
-                                rewrites.extend(get_rewrites(op))
-                                # Detect no-op reshapes by comparing their full input and output tensor shapes.
-                                inshape = op.ifm_shapes[0]
-                                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
-                                use_NHCWB16 &= compatible_shape and all(compatible_shape)
-                            else:
-                                use_NHCWB16 = False
-                                use_fast_storage = False
-                        use_NHCWB16 &= op.run_on_npu
-                        use_fast_storage &= op.run_on_npu
-
-                    if use_fast_storage:
-                        fast_storage_tensor_rewrites[output] = rewrites
-                    if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
+                    if not output.needs_linear_format:
                         output.set_format(TensorFormat.NHCWB16, arch)
-                        for rewrite_op in rewrites:
-                            rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
-            if arch.is_spilling_enabled():
-                # Remember feature maps that can be moved to fast storage for later use
-                # in use_fast_storage_for_feature_maps
-                self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites
+
+                # Wrap this pass in a SchedulerOperation, indexed by its position in self.sched_ops
+                op = SchedulerOperation(ps, arch, self.nng)
+                op.index = len(self.sched_ops)
+
+                # Make connections. NOTE(review): ps.ifm_tensor is assumed non-None here (only ifm2 is guarded)
+                if ps.ifm_tensor not in connections:
+                    connections[ps.ifm_tensor] = Connection(ps.ifm_tensor)
+                if ps.ifm2_tensor and ps.ifm2_tensor not in connections:
+                    connections[ps.ifm2_tensor] = Connection(ps.ifm2_tensor)
+                if ps.ofm_tensor not in connections:
+                    connections[ps.ofm_tensor] = Connection(ps.ofm_tensor)
+
+                op.add_ifm_connection(connections[ps.ifm_tensor])
+                if ps.ifm2_tensor:
+                    op.add_ifm2_connection(connections[ps.ifm2_tensor])
+                op.add_ofm_connection(connections[ps.ofm_tensor])
+
+                # Register the Op, then flag which FeatureMaps it must hold in full
+                self.sched_ops.append(op)
+                if ps.ifm_tensor in self.sg.input_tensors:
+                    # IFM is a subgraph input - it must be held in full
+                    op.requires_full_ifm = True
+                if ps.ifm2_tensor and ps.ifm2_tensor in self.sg.input_tensors:
+                    # IFM2 is a subgraph input - it must be held in full
+                    op.requires_full_ifm2 = True
+                if ps.ofm_tensor in self.sg.output_tensors:
+                    # OFM is a subgraph output - it must be produced in full
+                    op.requires_full_ofm = True
+                if ps.ifm_tensor.needs_linear_format:
+                    op.requires_full_ifm = True
+                if ps.ifm2_tensor and ps.ifm2_tensor.needs_linear_format:
+                    op.requires_full_ifm2 = True
+                if ps.ofm_tensor.needs_linear_format or ps.primary_op.memory_function == Op.ConcatSliceWrite:
+                    op.requires_full_ofm = True
+                if len(ps.primary_op.outputs) > 1 or len(ps.primary_op.outputs[0].consumer_list) > 1:
+                    # Op has multiple outputs or consumers - requires full OFM
+                    op.requires_full_ofm = True
+
+                # Bytes this Op needs to hold all of its full FeatureMaps at the same time
+                op_memory_req = 0
+                if op.requires_full_ifm:
+                    op_memory_req += op.ifm_size_in_bytes()
+                if op.requires_full_ifm2:
+                    op_memory_req += op.ifm2_size_in_bytes()
+                if op.requires_full_ofm:
+                    op_memory_req += op.ofm_size_in_bytes()
+
+                min_memory_req = max(op_memory_req, min_memory_req)  # worst single Op so far
+
+        # Theoretical minimum required memory (worst single Op) - used to guide the cascade building
+        self.min_memory_req = min_memory_req
+
+    def create_initial_schedule(self) -> Schedule:
+        """Create the "MAX" schedule: each Op costed using its full OFM shape, with no cascading or buffering"""
+        schedule = Schedule(self.sg, "MAX")
+        # One cost_map entry per Op; cycles are estimated at the full OFM depth
+        for op in self.sched_ops:
+            cost = op.create_scheduler_info(self.nng, op.ofm.shape)
+            cost.cycles = self.estimate_op_performance(op, cost.block_config, op.ofm.shape.depth)
+            schedule.cost_map[op] = cost
+
+        return schedule
+
+    def update_op_memory_snapshot(self, schedule: Schedule):  # Store fast-storage usage over time in schedule
+        memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
+
+        # Collect Scratch/Scratch_fast tensor live ranges of the root subgraph (not just self.sg)
+        lr_graph = live_range.LiveRangeGraph()
+        for mem_area, mem_type_set in memories_list:
+            live_range.extract_live_ranges_from_cascaded_passes(
+                self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum,
+            )
+
+        # Fast-storage memory in use at each time index, derived from the collected live ranges
+        temporal_usage = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area)
+        schedule.memory_snapshot = temporal_usage
+
+        # Peak fast-storage usage (0 if the usage array is empty)
+        schedule.fast_storage_peak_usage = max(temporal_usage, default=0)
+
+    def estimate_op_performance(self, op: SchedulerOperation, block_config, ofm_depth):
+        query = npu_performance.PerformanceQuery(op.op_type.npu_block_type)  # input to measure_cycle_cost below
+        query.ifm_shape = op.ifm.shape
+        query.ifm_memory_area = op.ifm.mem_area
+        query.ifm_bits = op.ifm.dtype.size_in_bits()
+        query.ifm_format = op.ifm.format
+        query.ifm2_shape = op.ifm2 and op.ifm2.shape  # each ifm2_* field is op.ifm2 (e.g. None) when it is falsy
+        query.ifm2_memory_area = op.ifm2 and op.ifm2.mem_area
+        query.ifm2_bits = op.ifm2 and op.ifm2.dtype.size_in_bits()
+        query.ifm2_format = op.ifm2 and op.ifm2.format
+        query.ofm_shape = op.ofm.shape.with_depth(ofm_depth)  # cost only ofm_depth channels
+        query.ofm_memory_area = op.ofm.mem_area
+        query.ofm_bits = op.ofm.dtype.size_in_bits()
+        query.ofm_format = op.ofm.format
+        if op.parent_op.bias:
+            query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth)
+            query.const_memory_area = self.arch.fast_storage_mem_area
+        # NOTE(review): const_shape uses the full OFM depth rather than ofm_depth - intended?
+        query.kernel = op.kernel
+        query.config = block_config
+
+        return npu_performance.measure_cycle_cost(self.arch, op.op_type, op.activation and op.activation.op_type, query)
+
+    def propose_schedule_buffering(self, ref_schedule: Schedule):
+        """Create a "<ref_schedule.label>_BUFFERED" schedule proposing buffering for each Op in ref_schedule"""
+        buffered_schedule = Schedule(self.sg, f"{ref_schedule.label}_BUFFERED")
+        staging_limit_bytes = self.scheduler_options.optimization_sram_limit
+        # Visit Ops in self.sched_ops order; prev_op is the last visited Op that is in ref_schedule
+        prev_op = None
+        for sched_op in self.sched_ops:
+            if sched_op not in ref_schedule.cost_map:
+                # sched_op is not part of this sub-schedule - skip
+                continue
+
+            self.propose_operator_buffering(sched_op, prev_op, buffered_schedule, ref_schedule, staging_limit_bytes)
+            prev_op = sched_op
+
+        return buffered_schedule
+
+    def propose_operator_buffering(
+        self,
+        sched_op: SchedulerOperation,
+        prev_op: SchedulerOperation,  # may be None (no previous Op)
+        buffered_schedule: Schedule,
+        ref_schedule: Schedule,
+        staging_limit_bytes,
+    ):  # Seed sched_op's cost in buffered_schedule from ref_schedule, then try weight buffering
+        # Already costed (e.g. reached via recursion): return None without re-proposing
+        if sched_op in buffered_schedule.cost_map:
+            return
+
+        # Take the reference schedule as default costings for this schedule
+        ref_cost = ref_schedule.cost_map[sched_op]
+        cost = copy.copy(ref_cost)  # shallow copy - nested objects are shared with ref_cost
+        cost.slack_buffering_cycles = ref_cost.cycles.op_cycles  # cycles the next Op may use for weight DMA
+        memory_snapshot = ref_schedule.memory_snapshot
+        ref_memory_usage = memory_snapshot[ref_cost.time_index] if ref_cost.time_index < len(memory_snapshot) else 0
+        cost.slack_buffering_memory = staging_limit_bytes - ref_memory_usage  # negative if usage exceeds the limit
+        buffered_schedule.cost_map[sched_op] = cost
+
+        # Attempt weight buffering on anything with a weights tensor
+        if sched_op.parent_op.weights:
+            self.propose_weight_buffering(
+                sched_op.parent_op.weights,
+                sched_op.parent_op.bias,
+                sched_op,
+                prev_op,
+                buffered_schedule,
+                ref_schedule,
+                cost.slack_buffering_memory,
+            )
+
+        return cost
+
+    def weights_needs_dma(self, weight_tensor):  # True if the weights need a DMA into fast storage
+        if weight_tensor and weight_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
+            # Weights are not in scratch memory (i.e. they are in permanent storage)
+            # Moving them only helps for Dram/OffChipFlash weights when permanent storage != fast storage
+            if (
+                weight_tensor.mem_area in (MemArea.Dram, MemArea.OffChipFlash)
+                and self.arch.permanent_storage_mem_area != self.arch.fast_storage_mem_area
+            ):
+                return True
+        return False
+
+    def propose_weight_buffering(
+        self,
+        weight_tensor,
+        scale_tensor,
+        sched_op: SchedulerOperation,
+        prev_op: SchedulerOperation,
+        buffered_schedule: Schedule,
+        ref_schedule: Schedule,
+        buffer_limit_bytes,
+    ):  # Selects OFM depth slicing and SRAM buffering for sched_op's weights; updates its cost entry in place
+        cost = buffered_schedule.cost_map[sched_op]
+        prev_cost = buffered_schedule.cost_map.get(prev_op)
+        ref_cost = ref_schedule.cost_map[sched_op]
+        assert cost and ref_cost
+
+        needs_dma = self.weights_needs_dma(weight_tensor)
+
+        ofm_full_depth_slices = [0, ref_cost.stripe.depth]
+
+        # Encode weights for the full depth
+        full_weights = weight_compressor.encode_weight_and_scale_tensor(
+            self.arch,
+            sched_op.parent_op,
+            weight_tensor,
+            scale_tensor,
+            sched_op.kernel,
+            cost.block_config,
+            ofm_full_depth_slices,
+        )
+        full_weights_bytes = len(full_weights.buffer)
+        cost.ofm_depth_slices = ofm_full_depth_slices
+
+        # No buffering required - take all the weights from permanent storage
+        if sched_op.op_type == Op.FullyConnected or not needs_dma:
+            cost.npu_weights_tensor = full_weights
+            return
+
+        encoded_weights = full_weights
+
+        # How many NPU cycles are available under the previously executing
+        # operator and SRAM unused for performing buffered DMA transfers
+        slack_cycles = prev_cost.slack_buffering_cycles if prev_cost else 0
+        slack_memory = prev_cost.slack_buffering_memory if prev_cost else 0
+
+        # Force full depth for cascaded Ops
+        if ref_cost.cascade != 0:
+            weight_tensor_purpose = TensorSubPurpose.Standard
+            weight_buffer_size = full_weights_bytes
+            # Update the memory snapshot to reflect the added size of the weights
+            ref_schedule.memory_snapshot[ref_cost.time_index] += weight_buffer_size
+        else:
+            # Estimate the buffering cycle time for the full set of weights
+            full_transfer_cycles = npu_performance.measure_mem2mem_cycles(
+                self.arch, weight_tensor.mem_area, self.arch.fast_storage_mem_area, full_weights_bytes
+            )
+            cost.full_weight_transfer_cycles = full_transfer_cycles
+
+            # Calculate the amount of prebuffering necessary (or what is possible with limited
+            # double buffer size)
+            half_buffer_limit = buffer_limit_bytes // 2
+            if full_transfer_cycles > slack_cycles:
+                prebuffer_ratio = slack_cycles / full_transfer_cycles
+                prebuffer_bytes = min(prebuffer_ratio * full_weights_bytes, half_buffer_limit)
+            else:
+                prebuffer_bytes = min(full_weights_bytes, half_buffer_limit)
+                prebuffer_ratio = prebuffer_bytes / full_weights_bytes
+
+            # Have to split the weights if the initial buffering can't store
+            # all of the compressed weights
+            if prebuffer_bytes < full_weights_bytes:
+                prebuffer_depth = int(ref_cost.stripe.depth * prebuffer_ratio)
+
+                # Round prebuffering down to nearest valid split depth
+                prebuffer_depth = int(max(16, round_down(prebuffer_depth, ArchitectureFeatures.OFMSplitDepth)))
+
+                while True:
+                    buffering_depth = max(cost.block_config.ofm_block.depth, prebuffer_depth)
+
+                    # Clamp buffering to the double buffering limit, rounded down to a valid split depth
+                    buffering_bytes = (buffering_depth / ref_cost.stripe.depth) * full_weights_bytes
+                    if buffering_bytes > half_buffer_limit:
+                        buffering_depth = (half_buffer_limit / full_weights_bytes) * ref_cost.stripe.depth
+                        buffering_depth = int(max(16, round_down(buffering_depth, ArchitectureFeatures.OFMSplitDepth)))
+
+                    # Create list of depth slices
+                    depth_slices = [0]
+                    if prebuffer_depth < ref_cost.stripe.depth:
+                        depth_slices += list(range(prebuffer_depth, ref_cost.stripe.depth, buffering_depth))
+                    depth_slices.append(ref_cost.stripe.depth)
+
+                    # Encode weights based depth slices
+                    cost.ofm_depth_slices = depth_slices
+                    encoded_weights = weight_compressor.encode_weight_and_scale_tensor(
+                        self.arch,
+                        sched_op.parent_op,
+                        weight_tensor,
+                        scale_tensor,
+                        sched_op.kernel,
+                        cost.block_config,
+                        cost.ofm_depth_slices,
+                    )
+
+                    # Chosen buffering might not fit at all, iterate until it does
+                    # or until the minimum usable slice size is reached
+                    if (
+                        encoded_weights.max_range_bytes <= half_buffer_limit
+                        or prebuffer_depth == ArchitectureFeatures.OFMSplitDepth
+                    ):
+                        break
+
+                    prebuffer_depth = round_up(prebuffer_depth // 2, ArchitectureFeatures.OFMSplitDepth)
+
+                # Calculate cycles required to run the last op for use as future slack
+                tail_cycles = self.estimate_op_performance(
+                    sched_op, cost.block_config, depth_slices[-1] - depth_slices[-2]
+                )
+                cost.slack_buffering_cycles = tail_cycles.op_cycles
+
+        # Determine whether the weights need to be double buffered
+        weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes)
+
+        # Only buffer weights if there's still space left for the buffer
+        if weight_buffer_size <= buffer_limit_bytes:
+            assert weight_buffer_size % 16 == 0
+            # Determine whether to double buffer or single buffer
+            if (weight_buffer_size * 2 <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)):
+                weight_buffer_size = weight_buffer_size * 2
+                weight_tensor_purpose = TensorSubPurpose.DoubleBuffer
+            else:
+                weight_tensor_purpose = TensorSubPurpose.Standard
+
+            cost.buffered_weight_tensor = Tensor(
+                [1, 1, 1, weight_buffer_size], DataType.uint8, weight_tensor.name + "_buffer"
+            )
+            cost.buffered_weight_tensor.src_tensor = encoded_weights
+            cost.buffered_weight_tensor.mem_area = self.arch.fast_storage_mem_area
+            cost.buffered_weight_tensor.mem_type = MemType.Scratch_fast
+            cost.buffered_weight_tensor.purpose = TensorPurpose.Weights
+            cost.buffered_weight_tensor.sub_purpose = weight_tensor_purpose
+            if ref_cost.cascade == 0:
+                # Determine if the lifetime can be extended and pre-buffer weights under the previous operation
+                cost.buffered_weight_tensor.pre_buffer = weight_buffer_size < slack_memory
+
+            cost.slack_buffering_memory -= weight_buffer_size
+        else:
+            # Don't slice or buffer - use the whole depth from persistent storage
+            cost.ofm_depth_slices = ofm_full_depth_slices
+            encoded_weights = full_weights
+
+        cost.npu_weights_tensor = encoded_weights
+
+    def propose_minimal_schedule(self) -> Schedule:
+        """Proposes scheduling parameters where every operator is subdivided into the smallest stripe that satisfies the
+        next operator's stride (Ops are visited last-to-first so the consumer's stride is known)"""
+        min_schedule = Schedule(self.sg, "MIN")
+        cost_map = min_schedule.cost_map
+
+        # Keep track of the previous Op - which consumes the current Op's OFM
+        prev_op = None
+        for sched_op in reversed(self.sched_ops):
+            min_stripe_height = prev_op.kernel.stride.y if prev_op else 1  # consumer's y-stride; 1 for the last Op
+            min_stripe = sched_op.ofm.shape.with_height(min_stripe_height)
+
+            cost = sched_op.create_scheduler_info(self.nng, min_stripe)
+            cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
+            cost_map[sched_op] = cost
+
+            prev_op = sched_op
+
+        return min_schedule
+
+    def propose_schedule_striping(self, final_stripe: Shape4D, label: str, ref_schedule: Schedule) -> Schedule:
+        """Proposes new striping for a schedule. Each Op's stripe is derived from the IFM needs of the Op after it"""
+        ref_cost = ref_schedule.cost_map
+
+        striped_schedule = Schedule(self.sg, label)
+        stripe = final_stripe  # stripe of the last Op; propagated backwards to the earlier Ops
+        for sched_op in reversed(self.sched_ops):
+            if sched_op not in ref_cost:
+                # sched_op is not part of the sub-schedule - skip
+                continue
+
+            # Create a cost entry with the new stripe
+            cost = sched_op.create_scheduler_info(self.nng, stripe)
+
+            # Copy the weight buffering from the reference schedule
+            cost.buffered_weight_tensor = ref_cost[sched_op].buffered_weight_tensor
+
+            # Estimate performance
+            cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth)
+            striped_schedule.cost_map[sched_op] = cost
+
+            # Calculate the preceding Op's stripe from this Op's IFM, scaling the height by the vertical stride
+            stripe = sched_op.ifm.shape.with_height(stripe.height * sched_op.kernel.stride.y)
+
+        return striped_schedule
+
+    def estimate_schedule_memory_usage(self, schedule: Schedule, non_local_mem_usage: dict):
+        """Estimates the peak memory usage over the Ops covered by a (sub-)schedule"""
+        cost = schedule.cost_map
+        cascades = schedule.cascades
+        peak_mem_usage = 0
+        for sched_op in self.sched_ops:
+            if sched_op not in cost:
+                # sched_op is not part of the sub-schedule - skip
+                continue
+
+            if cost[sched_op].cascade:
+                # This Op is part of a cascade - use the cascade's memory usage
+                cascade_info = cascades[cost[sched_op].cascade]
+                # Non-local memory usage is already included in the cascade_info
+                peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage)
+            else:
+                # This Op is not part of a cascade - sum IFM, OFM, weight buffer and non-local usage
+                op_weight_buffer = 0
+                if cost[sched_op].buffered_weight_tensor:
+                    op_weight_buffer = cost[sched_op].buffered_weight_tensor.storage_size()
+
+                op_mem_usage = (
+                    sched_op.ifm_size_in_bytes()
+                    + sched_op.ofm_size_in_bytes()
+                    + op_weight_buffer
+                    + non_local_mem_usage.get(sched_op, 0)
+                )
+                peak_mem_usage = max(op_mem_usage, peak_mem_usage)
+
+        return peak_mem_usage
+
+    def optimize_sub_schedule(
+        self, cascade_info: CascadeInfo, ref_schedule: Schedule, max_template: Schedule, memory_limit: int
+    ) -> Schedule:
+        """Extracts the Ops covered by the given cascade and creates a sub-schedule. The sub-schedule is optimized by
+        proposing weight buffering and then continuously proposing new stripe sizes"""
+        ref_cost = ref_schedule.cost_map
+        # Extract the ops that are part of this sub-schedule
+        start = cascade_info.start
+        end = cascade_info.end
+        sub_schedule_ops = self.sched_ops[start : end + 1]
+        # Create a sub-schedule that contains only the costs for the Ops that are part of the sub-schedule
+        sub_schedule = Schedule(self.sg, f"SUB_{start}_{end}")
+        for sched_op in sub_schedule_ops:
+            sub_schedule.cost_map[sched_op] = ref_cost[sched_op]
+
+        sub_schedule.cascades[end] = cascade_info
+        # Use the memory snapshot from the reference schedule
+        sub_schedule.memory_snapshot = ref_schedule.memory_snapshot
+
+        # Calculate memory usage that is live during the sub-schedule but not part of it
+        time_for_cascade = ref_cost[sub_schedule_ops[0]].time_index
+        mem_usage_parallel_to_sub_schedule = ref_schedule.memory_snapshot[time_for_cascade] - cascade_info.mem_usage
+        # If the first Op's IFM has other consumers it has to live throughout the whole sub-schedule whether it's
+        # included in a cascade or not
+        persistent_initial_ifm = (
+            sub_schedule_ops[0].ifm_size_in_bytes() if len(sub_schedule_ops[0].ifm.connection.consumers) > 1 else 0
+        )
+        # Calculate non-local-mem-usage per Operator (the persistent IFM is added for every Op but the first)
+        non_local_mem_usage = {}
+        for idx, sched_op in enumerate(sub_schedule_ops):
+            non_local_mem_usage[sched_op] = mem_usage_parallel_to_sub_schedule
+            if idx != 0:
+                non_local_mem_usage[sched_op] += persistent_initial_ifm
+
+        cascade_builder = CascadeBuilder(sub_schedule_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)
+
+        # Start by adding buffering
+        buffered_sub_schedule = self.propose_schedule_buffering(sub_schedule)
+        # Copy the cascades over from the unbuffered-schedule
+        buffered_sub_schedule.cascades = sub_schedule.cascades
+
+        # Generate the possible stripings (heights 1 .. half the OFM height) for the final Op in the sub-schedule
+        final_ofm_shape = sub_schedule_ops[-1].ofm.shape
+        possible_stripes = [
+            final_ofm_shape.with_height(stripe_h) for stripe_h in range(1, final_ofm_shape.height // 2 + 1)
+        ]
+
+        # Propose different striping - the possible stripes are proposed similarly to a binary search
+        best_schedule = buffered_sub_schedule  # fallback if no proposed striping fits the limit
+        iteration = 0
+        while len(possible_stripes) > 1:
+            proposed_stripe = possible_stripes[len(possible_stripes) // 2]
+            proposed_schedule = self.propose_schedule_striping(
+                proposed_stripe, f"OPTIMIZED_{iteration}", buffered_sub_schedule
+            )
+
+            cascade_builder.build_cascades(proposed_schedule, max_template, memory_limit)
+
+            # Check if proposal fits
+            proposed_schedule_mem_usage = self.estimate_schedule_memory_usage(proposed_schedule, non_local_mem_usage)
+            if (proposed_schedule_mem_usage) <= memory_limit:
+                # Remove all possible stripes smaller than this
+                possible_stripes = possible_stripes[len(possible_stripes) // 2 :]
+                best_schedule = proposed_schedule
+                if not proposed_schedule.cascades:
+                    # No cascading required - early exit
+                    break
+            else:
+                # Proposal doesn't fit within the limit - remove all possible stripes larger than this
+                possible_stripes = possible_stripes[: len(possible_stripes) // 2]
+
+            iteration += 1
+
+        return best_schedule
+
+    def optimize_schedule(
+        self, schedule: Schedule, max_sched: Schedule, max_template: Schedule, options: SchedulerOptions,
+    ) -> Schedule:
+        """Extracts sub-schedules based on the cascades and optimizes them and applies them to the final schedule"""
+        sram_limit = options.optimization_sram_limit
+        if max_sched.fast_storage_peak_usage < sram_limit and not self.arch.is_spilling_enabled():
+            # Maximum performance schedule fits within the SRAM target
+            return max_sched
+
+        # Extract the cascades (copied, since schedule.cascades is modified inside the loop)
+        cascades = list(schedule.cascades.values())
+        for cascade_info in cascades:
+            # Remove existing cascade from schedule
+            del schedule.cascades[cascade_info.end]
+            # Optimize the sub-schedule in this cascade
+            opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, sram_limit)
+            # Update the sub-schedule Op and cascade costs to the full schedule
+            schedule.cost_map.update(opt_sub_schedule.cost_map)
+            schedule.cascades.update(opt_sub_schedule.cascades)
+
+        # Update memory snapshot
+        self.sg.schedule = schedule
+        self.update_op_memory_snapshot(schedule)
+        # Propose schedule buffering to the optimized schedule
+        optimized_sched = self.propose_schedule_buffering(schedule)
+        # Copy the cascade's metadata from the unbuffered schedule
+        optimized_sched.cascades = schedule.cascades
+        return optimized_sched
+
+    def apply_schedule(self, sched: Schedule):
+        """Applies the given schedule as the final solution, writing its choices onto the tensors and passes"""
+        for sched_op in self.sched_ops:
+            op_info = sched.cost_map[sched_op]
+            cascade_info = sched.cascades.get(op_info.cascade, None)
+            if cascade_info and sched_op in cascade_info.buffers:  # this Op's IFM is a buffer inside the cascade
+                buffer_tens = sched_op.ifm.connection.parent_tens
+                # Apply memory area and type
+                buffer_tens.mem_area = self.arch.fast_storage_mem_area
+                buffer_tens.mem_type = MemType.Scratch_fast
+                # Apply Rolling buffer along Y, sized by the cascade's buffer height for this Op
+                buffer_tens.set_format(TensorFormat.NHCWB16, self.arch)
+                buffer_tens.set_new_sub_purpose(TensorSubPurpose.RollingBufferY, cascade_info.buffers[sched_op].height)
+
+            sched_op.parent_ps.block_config = op_info.block_config.old_style_representation()
+
+            # Ensure that the src_tensor reference is set correctly
+            if op_info.buffered_weight_tensor:
+                op_info.buffered_weight_tensor.src_tensor = op_info.npu_weights_tensor
+
+    def use_fast_storage_for_feature_maps(self, schedule: Schedule, memory_limit: int) -> None:
+        if self.arch.fast_storage_mem_area == self.arch.feature_map_storage_mem_area:
+            return  # feature maps already live in fast storage - nothing to move
+
+        # Move the OFMs of non-cascaded Ops that have dependants to fast storage, unless a consumer is None
+        for sched_op in self.sched_ops:
+            cost = schedule.cost_map[sched_op]
+            if cost.cascade == 0:
+                if sched_op.get_dependants():
+                    ofm_tens = sched_op.ofm.connection.parent_tens
+                    if not any(cons is None for cons in ofm_tens.consumer_list):
+                        ofm_tens.mem_area = self.arch.fast_storage_mem_area
+                        ofm_tens.mem_type = MemType.Scratch_fast
+
+        # Collect live ranges from tensors
+        memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
+        lr_graph = live_range.LiveRangeGraph()
+        for mem_area, mem_type_set in memories_list:
+            live_range.extract_live_ranges_from_cascaded_passes(
+                self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum,
+            )
+
+        # Iterate over live ranges and evict tensors that don't fit within memory_limit
+        fast_storage_snapshot = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area)
+        for lr in lr_graph.lrs:
+            if (
+                lr.mem_area == self.arch.fast_storage_mem_area
+                and max(fast_storage_snapshot[lr.start_time : lr.end_time + 1]) > memory_limit
+            ):
+                # Evict tensor to DRAM
+                for tens in lr.tensors:
+                    if tens.purpose == TensorPurpose.FeatureMap and tens.sub_purpose == TensorSubPurpose.Standard:
+                        # Can only evict unbuffered FeatureMaps
+                        tens.mem_area = self.arch.feature_map_storage_mem_area
+                        tens.mem_type = MemType.Scratch
+                        # Adjust the snapshot (NOTE(review): runs per evicted tensor - may over-subtract; confirm)
+                        fast_storage_snapshot[lr.start_time : lr.end_time + 1] -= lr.size
+
+    def move_constant_data(self):
+        """Determine if constant data can be moved from permanent storage to another memory area. A move
+        will generate a DMA command in the high-level command stream"""
+        for sched_op in self.sched_ops:
+            parent_op = sched_op.parent_op
+            is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in parent_op.inputs)
+            max_ifm_shram_avail = (  # half of the SHRAM left after reserving the output banks
+                (self.arch.available_shram_banks(is_lut_used) - self.arch.shram_reserved_output_banks)
+                * self.arch.shram_bank_size
+                // 2
+            )
+
+            for idx, tens in enumerate(parent_op.inputs):
+                if tens.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
+                    # Tensor is in permanent storage
+                    # Only when permanent storage differs from feature map storage is there a point in moving the data
+                    if (
+                        tens.mem_area in self.arch.permanent_storage_mem_area
+                        and self.arch.permanent_storage_mem_area != self.arch.feature_map_storage_mem_area
+                    ) or tens.purpose == TensorPurpose.LUT:
+                        if tens.purpose == TensorPurpose.LUT or (
+                            tens.purpose == TensorPurpose.FeatureMap
+                            and sched_op.op_type.is_binary_elementwise_op()
+                            and tens.shape != []
+                            and sched_op.ifm.shape != sched_op.ofm.shape
+                            and tens.storage_size() > max_ifm_shram_avail
+                        ):
+                            only_vector_product_consumers = all(
+                                oper and oper.type.npu_block_type == NpuBlockType.VectorProduct
+                                for oper in tens.consumers()
+                            )
+
+                            if (not only_vector_product_consumers) or tens.purpose == TensorPurpose.LUT:
+                                new_tens = tens.clone_into_fast_storage(self.arch)
+                                if tens.purpose == TensorPurpose.LUT:
+                                    new_tens.mem_area = MemArea.Shram
+
+                                new_tens.consumer_list.append(parent_op)
+                                parent_op.inputs[idx] = new_tens
+                                sched_op.parent_ps.inputs[idx] = new_tens
+
+    def print_schedule(self, schedule: Schedule):  # Debug dump of per-Op costs, SRAM usage and cascades
+        print(f"Schedule: '{schedule.name}'")
+        for sched_op in self.sched_ops:
+            if sched_op not in schedule.cost_map:
+                # Sub-schedule printing - skip Ops that this schedule does not cover
+                continue
+
+            op_info = schedule.cost_map[sched_op]
+            print(f"\t{sched_op.index}: Operation {sched_op.name}  - OFM {sched_op.ofm.shape}")
+            print(f"\t\tType: {sched_op.op_type}")
+            print(f"\t\tKernel: {sched_op.kernel}")
+            print(f"{op_info}")
+            mem_usage = (
+                schedule.memory_snapshot[op_info.time_index]
+                if op_info.time_index < len(schedule.memory_snapshot)
+                else 0
+            )
+            print(f"\t\tSRAM Used: {mem_usage} bytes")
+
+        print("\tCascades:")
+        for i, cascade in enumerate(schedule.cascades.values()):
+            print(f"\t\t{i}: {cascade.start} -> {cascade.end}, size: {cascade.mem_usage}")
 
 
-def move_scales_to_fast_storage(nng, arch):
+def _update_tensor_allocation(nng: Graph, arch: ArchitectureFeatures, options) -> None:
+    """
+    Creates live ranges and runs the tensor allocator for the current schedule
+    (i.e. sg.schedule for all subgraphs), once per scratch memory area/type set.
+    This function itself returns nothing.
+    """
+    root_sg = nng.get_root_subgraph()
+
+    alloc_list = []
+    if arch.is_spilling_enabled():
+        mem_alloc_scratch_fast = (arch.fast_storage_mem_area, set((MemType.Scratch_fast,)))
+        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,)))
+        # Order is important: Scratch_fast is allocated before Scratch
+        alloc_list.append(mem_alloc_scratch_fast)
+        alloc_list.append(mem_alloc_scratch)
+    else:  # Without spilling, Scratch and Scratch_fast are allocated together in the feature map area
+        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))
+        alloc_list.append(mem_alloc_scratch)
+
+    for mem_area, mem_type_set in alloc_list:
+        tensor_allocation.allocate_tensors(
+            nng,
+            root_sg,
+            arch,
+            mem_area,
+            mem_type_set,
+            tensor_allocator=options.tensor_allocator,
+            verbose_allocation=options.verbose_allocation,
+            cpu_tensor_alignment=options.cpu_tensor_alignment,
+        )
+
+
+def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_options: SchedulerOptions):
+    """Entry point for the Scheduler"""
+    # Scheduler instances for the NPU subgraphs, keyed by subgraph
+    schedulers = dict()
+    # CPU subgraphs get one CascadedPass per pass; NPU subgraphs are scheduled below
     for sg in nng.subgraphs:
-        # IFM streamed ops reads bias tensors several times, move these to fast storage
-        for cp in sg.cascaded_passes:
-            if cp.strategy == SchedulingStrategy.IfmStream:
-                # Calculate SRAM usage
-                new_size = 0
-                all_tens = []
-                for ps in cp.passes:
-                    pass_tens = np.array([ps.ifm_tensor, ps.ifm2_tensor, ps.ofm_tensor, ps.weight_tensor])
-                    pass_tens = np.append(pass_tens, ps.intermediates)
-                    for tens in pass_tens:
-                        if tens and tens.mem_area == MemArea.Sram and tens not in all_tens:
-                            all_tens.append(tens)
-                            new_size += tens.storage_size()
+        if sg.placement != PassPlacement.Npu:
+            # Create cascaded passes for CPU Ops (one per pass)
+            cascaded_passes = []
+            for idx, ps in enumerate(sg.passes):
+                cps = CascadedPass(
+                    ps.name, SchedulingStrategy.WeightStream, ps.inputs, [], ps.outputs, [ps], ps.placement, False,
+                )
 
-                cp.sram_used = new_size
+                cps.time = idx
+                ps.cascade = cps
+                cascaded_passes.append(cps)
 
-                for ps in cp.passes:
-                    if ps.scale_tensor:
-                        tens = ps.scale_tensor
+            sg.cascaded_passes = cascaded_passes
+        else:
+            # Npu subgraph - create schedule
+            scheduler = Scheduler(nng, sg, arch, scheduler_options)
+            schedulers[sg] = scheduler
 
-                        # Find op using scale tensor
-                        op = next((op for op in ps.ops if tens in op.inputs), None)
-                        assert op
+            scheduler.create_scheduler_representation(arch)
+            sg.sched_ops = scheduler.sched_ops
+            scheduler.move_constant_data()
 
-                        # Create fast storage tensor
-                        new_tens = tens.clone_into_fast_storage(arch)
-                        new_tens.consumer_list = tens.consumer_list.copy()
-                        new_tens.purpose = TensorPurpose.FSBias
-                        new_tens_size = new_tens.storage_size()
+            # Create the Max schedule template
+            max_schedule_template = scheduler.create_initial_schedule()
+            scheduler.max_schedule = max_schedule_template
 
-                        if (cp.sram_used + new_tens_size) <= arch.sram_size:
-                            # Create DMA cmd
-                            dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma")
-                            dma_cmd.inputs = [tens]
-                            dma_cmd.set_output_tensor(new_tens)
-                            dma_cmd.attrs["source"] = tens.mem_area
-                            dma_cmd.attrs["destination"] = new_tens.mem_area
-                            dma_cmd.run_on_npu = True
+            # Create the optimised Max schedule (the template with weight buffering proposed)
+            sg.schedule = max_schedule_template
+            scheduler.update_op_memory_snapshot(max_schedule_template)
+            opt_max_schedule = scheduler.propose_schedule_buffering(max_schedule_template)
+            sg.schedule = opt_max_schedule
+            scheduler.update_op_memory_snapshot(opt_max_schedule)
 
-                            tens.consumer_list.clear()
-                            tens.consumer_list.append(dma_cmd)
+            # Create Min schedule; the Size strategy cascades against min_memory_req instead of the SRAM limit
+            min_schedule = scheduler.propose_minimal_schedule()
+            initial_sram_limit = scheduler_options.optimization_sram_limit
+            if scheduler_options.optimization_strategy == OptimizationStrategy.Size:
+                initial_sram_limit = scheduler.min_memory_req
 
-                            # Replace tensor and op
-                            idx = op.inputs.index(tens)
-                            op.inputs[idx] = new_tens
+            cascade_builder = CascadeBuilder(scheduler.sched_ops, arch.is_spilling_enabled())
+            cascade_builder.build_cascades(min_schedule, max_schedule_template, initial_sram_limit)
+            sg.schedule = min_schedule
+            scheduler.update_op_memory_snapshot(min_schedule)
 
-                            ps.ops.insert(0, dma_cmd)
-                            ps.scale_tensor = new_tens
-                            ps.intermediates.append(new_tens)
-                            ps.cascade.intermediates.append(new_tens)
+            if scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
+                # Create an optimized schedule starting from the cascaded Min schedule
+                sg.schedule = scheduler.optimize_schedule(
+                    min_schedule, opt_max_schedule, max_schedule_template, scheduler_options
+                )
+                scheduler.update_op_memory_snapshot(sg.schedule)
 
-                            cp.sram_used += new_tens_size
+            scheduler.apply_schedule(sg.schedule)
+            scheduler.use_fast_storage_for_feature_maps(sg.schedule, scheduler_options.optimization_sram_limit)
 
+            if scheduler_options.verbose_schedule:
+                scheduler.print_schedule(sg.schedule)
 
-def schedule_passes(nng, arch, options: SchedulerOptions):
-
-    for sg in nng.subgraphs:
-        sg.base_sram_used = 0
-
-    for sg in nng.subgraphs:
-        # re-entering the same nodes from different contexts requires us to
-        # build a simplified directed acyclic (DAG) version of the graph to
-        # use for traversal, rather than using a visit dictionary. this avoids
-        # recursing infinitely due to loops.
-        sg.build_pass_dag_predecessors()
-
-        dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)
-
-        strat_set = dps.search()
-
-        dps.apply_result(strat_set, arch)
-
-        if options.verbose_schedule:
-            sg.print_cascaded_passes()
-
-
-def _calc_tens_to_cps(sg, tensor_rewrites):
-    # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption.
-    # Returns dictionary tensor -> list of cascaded passes
-    # Note: if cascaded passes are A, B, C, D, and a tensor is output
-    # of A and input to D, then it also consumes SRAM in passes B and C.
-    if "tens_to_cps" in sg.scheduling_info:
-        return sg.scheduling_info["tens_to_cps"]
-    # Determine life-time of tensors
-    min_index = {}
-    max_index = {}
-    index = 0
-    cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu]
-    for cps in cps_list:
-        for tens in cps.inputs + cps.outputs:
-            if tens in tensor_rewrites:
-                min_index[tens] = min(index, min_index.get(tens, len(cps_list)))
-                max_index[tens] = index
-        index += 1
-    # Convert to affected cps-es
-    tens_to_cps = {}
-    for tens in min_index:
-        tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1]
-    sg.scheduling_info["tens_to_cps"] = tens_to_cps
-    return tens_to_cps
-
-
-def use_fast_storage_for_feature_maps(sg, sram_limit, arch):
-    # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes.
-    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
-    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
-    # Sort tensors first on life-time (smallest first), then on size (biggest first)
-    tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps])
-    for _, _, _, tens in tens_list:
-        cps_list = tens_to_cps[tens]
-        if len(cps_list) < 1:
-            continue
-        sz = tens.storage_size()
-        fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list])
-        if fits_in_fast_storage:
-            tens.mem_area = arch.fast_storage_mem_area
-            tens.mem_type = MemType.Scratch_fast
-            tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
-            assert tens in tensor_rewrites
-            # Also rewrite reshapes
-            for rewrite_op in tensor_rewrites[tens]:
-                tens2 = rewrite_op.outputs[0]
-                tens2.mem_area = arch.fast_storage_mem_area
-                tens2.mem_type = MemType.Scratch_fast
-                tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
-            for cps in cps_list:
-                cps.sram_used += sz
-
-
-def undo_use_fast_storage(sg, arch):
-    # Undoes the effects of a previous call to use_fast_storage_for_feature_maps
-    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
-    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
-    mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap]
-    for tens, cps_list in tens_to_cps.items():
-        if tens.mem_type == MemType.Scratch_fast:
-            sz = tens.storage_size()
-            tens.mem_area = mem_area
-            tens.mem_type = MemType.Scratch
-            # Also undo reshapes
-            for rewrite_op in tensor_rewrites[tens]:
-                tens2 = rewrite_op.outputs[0]
-                tens2.mem_area = mem_area
-                tens2.mem_type = MemType.Scratch
-            for cps in cps_list:
-                cps.sram_used -= sz
+    # Evaluate schedule: run tensor allocation over the final schedules of all subgraphs
+    _update_tensor_allocation(nng, arch, options)